# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""ONNX IR enums that matches the ONNX spec."""
from __future__ import annotations
import enum
import ml_dtypes
import numpy as np
class AttributeType(enum.IntEnum):
    """Enum for the types of ONNX attributes.

    Members are plain integers (``enum.IntEnum``), so they compare equal to
    their numeric value. ``repr()`` and ``str()`` both return the bare member
    name (for example ``"FLOAT"``) rather than the default
    ``<AttributeType.FLOAT: 1>`` form.
    """

    UNDEFINED = 0
    FLOAT = 1
    INT = 2
    STRING = 3
    TENSOR = 4
    GRAPH = 5
    FLOATS = 6
    INTS = 7
    STRINGS = 8
    TENSORS = 9
    GRAPHS = 10
    SPARSE_TENSOR = 11
    SPARSE_TENSORS = 12
    TYPE_PROTO = 13
    TYPE_PROTOS = 14

    def __repr__(self) -> str:
        """Returns the member name only, e.g. ``"GRAPH"``."""
        return self.name

    def __str__(self) -> str:
        """Returns the same text as ``repr()`` so both stay in sync."""
        return self.__repr__()
class DataType(enum.IntEnum):
    """Enum for the data types of ONNX tensors, defined in ``onnx.TensorProto``.

    Members are plain integers (``enum.IntEnum``). ``repr()`` and ``str()``
    both return the bare member name (for example ``"FLOAT"``).

    Conversions to and from numpy dtypes are driven by the module-level
    lookup tables ``_NP_TYPE_TO_DATA_TYPE`` and ``_DATA_TYPE_TO_NP_TYPE``;
    element sizes come from ``_ITEMSIZE_MAP``.
    """

    # NOTE: Naming: It is tempting to use shorter and more modern names like f32, i64,
    # but we should stick to the names used in the ONNX spec for consistency.
    UNDEFINED = 0
    FLOAT = 1
    UINT8 = 2
    INT8 = 3
    UINT16 = 4
    INT16 = 5
    INT32 = 6
    INT64 = 7
    STRING = 8
    BOOL = 9
    FLOAT16 = 10
    DOUBLE = 11
    UINT32 = 12
    UINT64 = 13
    COMPLEX64 = 14
    COMPLEX128 = 15
    BFLOAT16 = 16
    FLOAT8E4M3FN = 17
    FLOAT8E4M3FNUZ = 18
    FLOAT8E5M2 = 19
    FLOAT8E5M2FNUZ = 20
    UINT4 = 21
    INT4 = 22
    FLOAT4E2M1 = 23

    @classmethod
    def from_numpy(cls, dtype: np.dtype) -> DataType:
        """Returns the ONNX data type for the numpy dtype.

        Args:
            dtype: The numpy dtype to look up in ``_NP_TYPE_TO_DATA_TYPE``.

        Returns:
            The ``DataType`` member registered for ``dtype``.

        Raises:
            TypeError: If the data type is not supported by ONNX.
        """
        # Single lookup (EAFP) instead of a membership test followed by a
        # second indexing operation.
        try:
            data_type = _NP_TYPE_TO_DATA_TYPE[dtype]
        except KeyError:
            # ``from None`` keeps the traceback focused on the TypeError, as
            # callers saw before the lookup was restructured.
            raise TypeError(f"Unsupported numpy data type: {dtype}") from None
        return cls(data_type)

    @property
    def itemsize(self) -> float:
        """Returns the size of the data type in bytes.

        The value is read from ``_ITEMSIZE_MAP``. It is a float because the
        4-bit types (``UINT4``, ``INT4``, ``FLOAT4E2M1``) report ``0.5``.

        Raises:
            KeyError: If the member has no entry in ``_ITEMSIZE_MAP``
                (``UNDEFINED`` has none).
        """
        return _ITEMSIZE_MAP[self]

    def numpy(self) -> np.dtype:
        """Returns the numpy dtype for the ONNX data type.

        Returns:
            The numpy dtype registered for this member in
            ``_DATA_TYPE_TO_NP_TYPE``.

        Raises:
            TypeError: If the data type is not supported by numpy.
        """
        try:
            return _DATA_TYPE_TO_NP_TYPE[self]
        except KeyError:
            raise TypeError(
                f"Numpy does not support ONNX data type: {self}"
            ) from None

    def __repr__(self) -> str:
        """Returns the member name only, e.g. ``"BFLOAT16"``."""
        return self.name

    def __str__(self) -> str:
        """Returns the same text as ``repr()`` so both stay in sync."""
        return self.__repr__()
# Size in bytes of one element of each ONNX data type; consumed by
# ``DataType.itemsize``. ``DataType.UNDEFINED`` intentionally has no entry here.
_ITEMSIZE_MAP: dict[DataType, float] = {
    DataType.FLOAT: 4,
    DataType.UINT8: 1,
    DataType.INT8: 1,
    DataType.UINT16: 2,
    DataType.INT16: 2,
    DataType.INT32: 4,
    DataType.INT64: 8,
    DataType.STRING: 1,
    DataType.BOOL: 1,
    DataType.FLOAT16: 2,
    DataType.DOUBLE: 8,
    DataType.UINT32: 4,
    DataType.UINT64: 8,
    DataType.COMPLEX64: 8,
    DataType.COMPLEX128: 16,
    DataType.BFLOAT16: 2,
    DataType.FLOAT8E4M3FN: 1,
    DataType.FLOAT8E4M3FNUZ: 1,
    DataType.FLOAT8E5M2: 1,
    DataType.FLOAT8E5M2FNUZ: 1,
    # 4-bit types occupy half a byte per element.
    DataType.UINT4: 0.5,
    DataType.INT4: 0.5,
    DataType.FLOAT4E2M1: 0.5,
}
# Numpy dtype to ONNX DataType; consumed by ``DataType.from_numpy``.
# We use ml_dtypes to support dtypes that are not in numpy.
_NP_TYPE_TO_DATA_TYPE: dict[np.dtype, DataType] = {
    # Dtypes that numpy provides natively.
    np.dtype("bool"): DataType.BOOL,
    np.dtype("complex128"): DataType.COMPLEX128,
    np.dtype("complex64"): DataType.COMPLEX64,
    np.dtype("float16"): DataType.FLOAT16,
    np.dtype("float32"): DataType.FLOAT,
    np.dtype("float64"): DataType.DOUBLE,
    np.dtype("int16"): DataType.INT16,
    np.dtype("int32"): DataType.INT32,
    np.dtype("int64"): DataType.INT64,
    np.dtype("int8"): DataType.INT8,
    np.dtype("object"): DataType.STRING,
    np.dtype("uint16"): DataType.UINT16,
    np.dtype("uint32"): DataType.UINT32,
    np.dtype("uint64"): DataType.UINT64,
    np.dtype("uint8"): DataType.UINT8,
    # Dtypes supplied by ml_dtypes.
    np.dtype(ml_dtypes.bfloat16): DataType.BFLOAT16,
    np.dtype(ml_dtypes.float8_e4m3fn): DataType.FLOAT8E4M3FN,
    np.dtype(ml_dtypes.float8_e4m3fnuz): DataType.FLOAT8E4M3FNUZ,
    np.dtype(ml_dtypes.float8_e5m2): DataType.FLOAT8E5M2,
    np.dtype(ml_dtypes.float8_e5m2fnuz): DataType.FLOAT8E5M2FNUZ,
    np.dtype(ml_dtypes.int4): DataType.INT4,
    np.dtype(ml_dtypes.uint4): DataType.UINT4,
}
# TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE
# Register the 4-bit float only when the installed ml_dtypes exposes it;
# otherwise the table is left untouched.
if hasattr(ml_dtypes, "float4_e2m1fn"):
    _NP_TYPE_TO_DATA_TYPE[np.dtype(ml_dtypes.float4_e2m1fn)] = DataType.FLOAT4E2M1
# ONNX DataType to Numpy dtype.
# Built by inverting _NP_TYPE_TO_DATA_TYPE once all of its entries (including
# the optional ml_dtypes ones) have been registered; consumed by
# ``DataType.numpy``.
_DATA_TYPE_TO_NP_TYPE: dict[DataType, np.dtype] = {
    data_type: np_dtype for np_dtype, data_type in _NP_TYPE_TO_DATA_TYPE.items()
}