# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import abc
from typing import ClassVar, Optional, Tuple, Union
import onnx
import onnx.helper
import onnxscript.ir
# Element-type enum used as the ``dtype`` of every TensorType subclass.
_DType = onnxscript.ir.DataType
# A single dimension: a fixed size (int), a symbolic name (str), or unknown (None).
_DimType = Union[int, str, type(None)]
# A shape annotation: a tuple of dimensions, a single dimension (rank 1),
# or ``...`` (unknown rank).
_ShapeType = Union[Tuple[_DimType, ...], _DimType, type(Ellipsis)]

# Shape-annotated types created by ``TensorType.__class_getitem__``, keyed by
# ``(dtype, shape)`` so that e.g. ``FLOAT[10]`` always yields the same class.
_tensor_type_shape_cache: dict[tuple[_DType, _ShapeType], type[TensorType]] = {}
# Maps each dtype to its unique unshaped TensorType subclass (e.g. FLOAT).
tensor_type_registry: dict[_DType, type[TensorType]] = {}
def _check_dim(dim):
    """Raise ``TypeError`` unless ``dim`` is an int, a str, or ``None``.

    Integers are fixed sizes, strings are symbolic dimension names and
    ``None`` marks an unknown dimension.
    """
    acceptable = dim is None or isinstance(dim, (int, str))
    if not acceptable:
        raise TypeError(f"Invalid dimension {dim}")
def _check_shape(shape):
    """Validate a shape annotation, raising ``TypeError`` on a bad dimension.

    ``shape`` may be a tuple of dimensions, a single dimension, or ``Ellipsis``
    (unknown rank). Every dimension is checked with ``_check_dim``.
    """
    if isinstance(shape, tuple):
        for dim in shape:
            _check_dim(dim)
    # ``is not`` (identity) rather than ``!=``: Ellipsis is a singleton, and an
    # arbitrary ``__ne__`` (e.g. on array-likes) must not decide validity.
    elif shape is not Ellipsis:
        _check_dim(shape)
class TensorType(abc.ABC):
    """ONNX Script representation of a tensor type supporting shape annotations.

    A scalar-tensor of rank 0::

        tensor: FLOAT

    A tensor of unknown rank::

        tensor: FLOAT[...]

    A tensor of rank 2 of unknown dimensions, with symbolic names::

        tensor: FLOAT['M', 'N']

    A tensor of rank 2 of known dimensions::

        tensor: FLOAT[128, 1024]

    A tensor of rank 1 whose single dimension is unknown::

        tensor: FLOAT[None]

    Element types are declared by subclassing with a ``dtype`` keyword, e.g.
    ``class FLOAT(TensorType, dtype=...)``. Subscripting such a class returns a
    cached, shape-annotated subclass of ``TensorType``. TensorTypes are never
    instantiated.
    """

    dtype: ClassVar[_DType]
    shape: ClassVar[Optional[_ShapeType]]

    def __new__(cls):
        raise NotImplementedError("TensorTypes cannot be instantiated")

    def __init_subclass__(cls, dtype: _DType, shape: Optional[_ShapeType] = None):
        """Record ``dtype`` and ``shape`` on a newly defined subclass.

        Unshaped subclasses (``shape is None``) are registered in
        ``tensor_type_registry``; at most one may exist per dtype. Shaped
        subclasses are only validated.

        Raises:
            ValueError: if an unshaped subclass already exists for ``dtype``.
            TypeError: if ``shape`` contains an invalid dimension.
        """
        cls.dtype = dtype
        cls.shape = shape
        if shape is None:
            existing_cls = tensor_type_registry.get(dtype)
            if existing_cls is not None:
                raise ValueError(
                    f"Invalid usage: subclass {existing_cls!r} "
                    f"already defined for dtype={dtype}"
                )
            tensor_type_registry[dtype] = cls
        else:
            _check_shape(shape)

    def __class_getitem__(cls, shape: Optional[_ShapeType]) -> type[TensorType]:
        """Return the shape-annotated variant of this tensor type.

        Results are cached per ``(dtype, shape)``, so repeated subscripts with
        the same shape return the same class object.

        Raises:
            ValueError: if this type already carries a shape.
            TypeError: if ``shape`` contains an invalid dimension.
        """
        if cls.shape is not None:
            raise ValueError("Invalid usage: shape already specified.")
        if shape is None:
            # Treat FLOAT[NONE] as 1-dimensional tensor with unknown dimension
            shape = (None,)
        # Validate before building the cache key: an unhashable shape such as a
        # list would otherwise fail with an opaque "unhashable type" error.
        _check_shape(shape)
        key = (cls.dtype, shape)
        shaped_type = _tensor_type_shape_cache.get(key)
        if shaped_type is None:
            shaped_type = type(
                cls.__name__, (TensorType,), {}, dtype=cls.dtype, shape=shape
            )
            _tensor_type_shape_cache[key] = shaped_type
        return shaped_type

    @classmethod
    def to_type_proto(cls) -> onnx.TypeProto:
        """Build the ``onnx.TypeProto`` for this tensor type.

        An unshaped type (e.g. ``FLOAT``) becomes a rank-0 tensor, ``FLOAT[...]``
        a tensor of unknown rank, a tuple shape keeps its dimensions, and a
        single dimension (e.g. ``FLOAT[10]``) becomes a rank-1 tensor.
        """
        if cls.shape is None:
            shape = ()  # "FLOAT" is treated as a scalar
        elif cls.shape is Ellipsis:
            shape = None  # "FLOAT[...]" is a tensor of unknown rank
        elif isinstance(cls.shape, tuple):
            shape = cls.shape  # example: "FLOAT[10,20]"
        else:
            shape = [cls.shape]  # example: "FLOAT[10]"
        return onnx.helper.make_tensor_type_proto(cls.dtype, shape)

    @classmethod
    def to_string(cls) -> str:
        """Return the type string derived from the class name, e.g. ``tensor(float)``."""
        return f"tensor({cls.__name__.lower()})"
[docs]
class FLOAT(TensorType, dtype=onnxscript.ir.DataType.FLOAT):
    """Tensor type whose element type is ``DataType.FLOAT``.

    Subscript it to add a shape, e.g. ``FLOAT[128, 1024]`` or ``FLOAT[...]``.
    """
[docs]
class UINT8(TensorType, dtype=onnxscript.ir.DataType.UINT8):
    """Tensor type whose element type is ``DataType.UINT8``."""
[docs]
class INT8(TensorType, dtype=onnxscript.ir.DataType.INT8):
    """Tensor type whose element type is ``DataType.INT8``."""
[docs]
class UINT16(TensorType, dtype=onnxscript.ir.DataType.UINT16):
    """Tensor type whose element type is ``DataType.UINT16``."""
[docs]
class INT16(TensorType, dtype=onnxscript.ir.DataType.INT16):
    """Tensor type whose element type is ``DataType.INT16``."""
[docs]
class INT32(TensorType, dtype=onnxscript.ir.DataType.INT32):
    """Tensor type whose element type is ``DataType.INT32``."""
[docs]
class INT64(TensorType, dtype=onnxscript.ir.DataType.INT64):
    """Tensor type whose element type is ``DataType.INT64``."""
[docs]
class STRING(TensorType, dtype=onnxscript.ir.DataType.STRING):
    """Tensor type whose element type is ``DataType.STRING``."""
[docs]
class BOOL(TensorType, dtype=onnxscript.ir.DataType.BOOL):
    """Tensor type whose element type is ``DataType.BOOL``."""
[docs]
class FLOAT16(TensorType, dtype=onnxscript.ir.DataType.FLOAT16):
    """Tensor type whose element type is ``DataType.FLOAT16``."""
[docs]
class DOUBLE(TensorType, dtype=onnxscript.ir.DataType.DOUBLE):
    """Tensor type whose element type is ``DataType.DOUBLE``."""
[docs]
class UINT32(TensorType, dtype=onnxscript.ir.DataType.UINT32):
    """Tensor type whose element type is ``DataType.UINT32``."""
[docs]
class UINT64(TensorType, dtype=onnxscript.ir.DataType.UINT64):
    """Tensor type whose element type is ``DataType.UINT64``."""
[docs]
class COMPLEX64(TensorType, dtype=onnxscript.ir.DataType.COMPLEX64):
    """Tensor type whose element type is ``DataType.COMPLEX64``."""
[docs]
class COMPLEX128(TensorType, dtype=onnxscript.ir.DataType.COMPLEX128):
    """Tensor type whose element type is ``DataType.COMPLEX128``."""
[docs]
class BFLOAT16(TensorType, dtype=onnxscript.ir.DataType.BFLOAT16):
    """Tensor type whose element type is ``DataType.BFLOAT16``."""
class FLOAT8E4M3FN(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E4M3FN):
    """Tensor type whose element type is ``DataType.FLOAT8E4M3FN``."""
class FLOAT8E4M3FNUZ(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E4M3FNUZ):
    """Tensor type whose element type is ``DataType.FLOAT8E4M3FNUZ``."""
class FLOAT8E5M2(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E5M2):
    """Tensor type whose element type is ``DataType.FLOAT8E5M2``."""
class FLOAT8E5M2FNUZ(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E5M2FNUZ):
    """Tensor type whose element type is ``DataType.FLOAT8E5M2FNUZ``."""
class INT4(TensorType, dtype=onnxscript.ir.DataType.INT4):
    """Tensor type whose element type is ``DataType.INT4``."""
class UINT4(TensorType, dtype=onnxscript.ir.DataType.UINT4):
    """Tensor type whose element type is ``DataType.UINT4``."""
def _dim_to_onnxscript_repr(dim) -> str:
    """Format one shape dimension as an onnxscript shape entry (int, quoted name or None)."""
    if dim.HasField("dim_value"):
        return str(dim.dim_value)
    if dim.HasField("dim_param"):
        return repr(dim.dim_param)
    return "None"


def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:
    """Converts an onnx type into the string representation of the type in *onnxscript*.

    A tensor type maps to one of three forms:

    - ``NAME[...]`` when it has no shape field (unknown rank);
    - ``NAME`` when its shape is present but has no dimensions (rank 0);
    - ``NAME[d0,d1,...]`` otherwise.

    Each ``d`` is the integer ``dim_value``, the quoted ``dim_param`` (e.g.
    ``'N'``), or ``None`` when neither is set.

    Args:
        onnx_type: an instance of onnx TypeProto

    Returns:
        The string representation of the type in onnxscript

    Raises:
        NotImplementedError: if ``onnx_type`` is not a tensor type; only tensor
            types are currently supported.
    """
    if not onnx_type.HasField("tensor_type"):
        raise NotImplementedError(f"Unable to translate type {onnx_type!r} into onnxscript type.")
    tensor_type = onnx_type.tensor_type
    name = onnx.TensorProto.DataType.Name(tensor_type.elem_type)
    if not tensor_type.HasField("shape"):
        # No shape information at all: a tensor of unknown rank.
        return f"{name}[...]"
    dims = [_dim_to_onnxscript_repr(d) for d in tensor_type.shape.dim]
    if not dims:
        # A present-but-empty shape is rank 0; onnxscript spells that as the
        # bare type name (e.g. ``FLOAT``).
        return name
    return f"{name}[{','.join(dims)}]"
# Alias for the type used to annotate ONNX values in onnxscript.
# Currently, only tensor types are supported (onnx_type_to_onnxscript_repr
# raises NotImplementedError for anything else), so this is simply TensorType.
# Need to expand support for other ONNX types.
ONNXType = TensorType