Question

How to create a Python type alias for a parametrized type

jaxtyping provides type annotations that take a str as a parameter (as opposed to a type), e.g.:

Float[Array, "dim1 dim2"]

Let's say I would like to create a type alias which combines the Float and Array part. This means I would like to be able to write

MyOwnType["dim1 dim2"]

instead. As far as I understand, I cannot use TypeAlias or TypeVar, because the generic parameter (here "dim1 dim2") is not a type but a value, namely a str instance.

Is there a concise Pythonic way to achieve this?

EDIT:

I tried the following and it does not work:

class _Singleton:

    def __getitem__(self, shape: str) -> Float:
        return Float[Array, shape]

MyOwnType = _Singleton()

Using MyOwnType["dim1 dim2"] as a function parameter annotation makes mypy complain: Variable "MyOwnType" is not valid as a type

Solution:

Based on the answer of @chepner this was the final solution:

class MyOwnType(Generic[Shape]):
    def __class_getitem__(cls, shape: str) -> Float:
        return Float[Array, shape]
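
The snippet above omits its imports and the definition of Shape. A minimal sketch of how it might be wired up and used in an annotation follows. It assumes Shape is a plain TypeVar, and it substitutes typing.Annotated[list, shape] for Float[Array, shape] so that it runs without jaxtyping installed; the function total is hypothetical. It only checks runtime behaviour, not what a particular static checker accepts.

```python
from typing import Annotated, Generic, TypeVar, get_args

# Assumption: Shape is an ordinary TypeVar; the post does not show its definition.
Shape = TypeVar("Shape")


class MyOwnType(Generic[Shape]):
    # Overrides the __class_getitem__ inherited from Generic, so the
    # subscript is handled here instead of by typing's own machinery.
    def __class_getitem__(cls, shape: str):
        # Stand-in for Float[Array, shape], to avoid depending on jaxtyping.
        return Annotated[list, shape]


# Hypothetical function using the alias as a parameter annotation.
def total(x: MyOwnType["dim1 dim2"]) -> float:
    return float(sum(x))


# The annotation is the object returned by __class_getitem__.
annotation = total.__annotations__["x"]
```

With the real jaxtyping types, the return statement would be the one from the snippet above, Float[Array, shape].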

Solution

You need to override/define __class_getitem__, not __getitem__ (which applies to instances of _Singleton, not _Singleton itself).

class _Singleton:

    def __class_getitem__(cls, shape: str) -> Float:
        return Float[Array, shape]
2024-07-15
chepner
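
The difference between the two hooks can be checked at runtime. The sketch below is illustrative only: the class names ViaInstance and ViaClass are hypothetical, and typing.Annotated[list, shape] stands in for Float[Array, shape] so that the example runs without jaxtyping installed.

```python
from typing import Annotated


class ViaInstance:
    # __getitem__ is looked up on the type of the object being subscripted,
    # so it only handles subscripts on instances of ViaInstance.
    def __getitem__(self, shape: str):
        return Annotated[list, shape]


class ViaClass:
    # __class_getitem__ is implicitly a classmethod and handles ViaClass[...].
    def __class_getitem__(cls, shape: str):
        return Annotated[list, shape]


from_instance = ViaInstance()["dim1 dim2"]  # subscript on an instance
from_class = ViaClass["dim1 dim2"]          # subscript on the class itself

try:
    ViaInstance["dim1 dim2"]  # the class defines no __class_getitem__
    class_subscript_failed = False
except TypeError:
    class_subscript_failed = True
```

Both subscripts return an equal object at runtime. The practical difference is that ViaClass is a class, which a type checker can treat as a type, whereas a variable holding a ViaInstance object is not, which is consistent with the mypy message quoted in the question.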