Source code for onnxscript.converter

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import ast
import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    NoReturn,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import onnx

import onnxscript
from onnxscript import irbuilder, onnx_types, sourceinfo, values
from onnxscript import type_annotation as ta
from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation

PY_VERSION_GE_39 = ast_utils.PY_VERSION_GE_39


logger = logging.getLogger("onnxscript")


# Python-to-IR converter:


def not_allowed(construct):
    return f"{construct}not supported."


class TranslationError(Exception):
    def __init__(self, *args: object) -> None:
        super().__init__(*args)


def warn(msg):
    logger.warning(msg)


def fail(msg) -> NoReturn:
    raise TranslationError(msg)


def fail_if(cond, msg):
    if cond:
        raise TranslationError(msg)


def ignore(cond, msg):
    if cond:
        warn(msg)


# map from python operators to ONNX ops
primop_map = {
    ast.Add: "Add",
    ast.And: "And",
    ast.BitAnd: "And",
    ast.BitOr: "Or",
    ast.Div: "Div",
    ast.Eq: "Equal",
    ast.Gt: "Greater",
    ast.GtE: "GreaterOrEqual",
    ast.Lt: "Less",
    ast.LtE: "LessOrEqual",
    ast.MatMult: "MatMul",
    ast.Mod: "Mod",
    ast.Mult: "Mul",
    ast.Not: "Not",
    ast.NotEq: "NotEqual",
    ast.Or: "Or",
    ast.Pow: "Pow",
    ast.Sub: "Sub",
    ast.USub: "Neg",
}


class Variable:
    """Represents an ONNX variable.

    TODO(rama): Consider merging this with IRVar. However, "castable" is specific to this
    converter.
    """

    def __init__(self, name: str, castable: bool = False):
        """Initialize the instance.

        Args:
           name: Name of the ONNX variable
           castable: Whether this variable is castable to a desired target type.
              Used for ONNX variables representing constants created from python values
              like 0 or 1 or 0.5 which are treated as polymorphic values castable to other
              types as needed.
        """
        self.name = name
        self.is_castable = castable

    def __str__(self) -> str:
        return self.name


if TYPE_CHECKING:
    # The type-alias LocalSymValue represents the types of values that local names in a
    # script-function may be bound to during translation, (ONNX IR values).
    # TODO(rama): Rationalize this and values.SymbolValue

    LocalSymValue = Union[values.SymbolValue, irbuilder.IRFunction]

    # The type-alias PyValue is used to represent the types of python values that may be used
    # in an ONNX Script function.
    # TODO(rama): Flesh out the set of valid types here. These include values such as
    # 1 (int), 1.0 (float), [2, 4], [1.0], etc. which will be converted to ONNX, for
    # use as value-parameters or attribute-parameters in an ONNX call (Node).

    PyValue = Any

    # The type-alias SymValue denotes values that an identifier may be bound to during
    # translation. A local name will be bound to a LocalSymValue, while a global name
    # will be bound to a PyValue.

    SymValue = Union[LocalSymValue, PyValue]

    # PreferredName is a type-alias used to represent the preferred name used in the generated
    # ONNX for a value returned by an expression. There is no guarantee that the specified
    # name will be used exactly. The converter will modify the name (with a suffix),
    # if necesssary, to ensure that it is unique (to ensure ONNX's SSA requirement).

    PreferredName = str

    # The type-alias OnnxVar indicates variable names used in the generated ONNX.
    OnnxVarName = str


[docs] class Converter: """Main class to translate python code into ONNX operators. Args: ir_builder: convert AST node into ONNX structures, if None, class :class:`onnxscript.irbuilder.IRBuilder` is used The class uses logger `onnxscript`. Logging can be enabled with the following code: :: import logging logging.basicConfig(level=logging.DEBUG) Or if you need to enable only the logger used by this module: :: import logging logger = logging.getLogger('onnxscript') logger.setLevel(logging.DEBUG) console = logging.StreamHandler() logger.addHandler(console) """ def __init__( self, ir_builder: Optional[irbuilder.IRBuilder] = None, opset: Optional[values.Opset] = None, global_names: Optional[dict[str, Any]] = None, source: Optional[str] = None, default_opset: Optional[values.Opset] = None, ): self.ir_builder = ir_builder or irbuilder.IRBuilder() self.source = source if global_names is not None: # We make a copy in case function eval modifies it. self.globals = global_names.copy() self.this_module = opset self.default_opset_ = default_opset # States initialized by `_init_function_translation` self._outer: List[irbuilder.IRFunction] = [] self._current_fn: irbuilder.IRFunction = None self._nextvar: int = 0 self._used_vars: set[str] = set() self._locals: List[Dict[str, LocalSymValue]] = [{}] @property def default_opset(self) -> values.Opset: if self.default_opset_ is None: raise RuntimeError( "default_opset must be specified in script for functions " "that do not contain any use of an ONNX opset." ) return self.default_opset_ def _set_default_opset(self, opset: values.Opset, node: ast.AST) -> None: if opset.domain != "": return if self.default_opset_ is not None: if ( opset.domain != self.default_opset_.domain or opset.version != self.default_opset_.version ): self.fail( node, f"Two distincts opset were used ({opset} != {self.default_opset_})." ) else: self.default_opset_ = opset def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]: """Find the (first) ONNX opset used in the function, if any.""" # Search for a Call expression of form "op.OpName(...)" if isinstance(node, ast.Call): if isinstance(node.func, ast.Attribute): opset_expr = node.func.value if isinstance(opset_expr, ast.Name): if opset_expr.id in self.globals: opset = self.globals[opset_expr.id] if isinstance(opset, values.Opset) and opset.domain == "": return opset for child in ast.iter_child_nodes(node): res = self._find_onnx_opset(child) if res is not None: return res return None def _init_function_translation(self) -> None: """Initialize self for translating a new (top-level) function.""" self._outer = [] self._current_fn: Optional[irbuilder.IRFunction] = None self._nextvar = 0 self._used_vars = set() self._locals: List[Dict[str, LocalSymValue]] = [{}] def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo: return sourceinfo.SourceInfo(node, self.source, self._current_fn.name) def _message(self, node: ast.AST, error_msg: str) -> str: """Constructs an error _message containing source information about an ast node.""" return self._source_of(node).msg(error_msg) def warn(self, node: ast.AST, error_msg: str) -> None: warn(self._message(node, error_msg)) def fail(self, node: ast.AST, error_msg: str) -> NoReturn: fail(self._message(node, error_msg)) # Name resolution and namescopes: This component handles the following aspects: # * Name-scopes are different in Python and the generated ONNX: # - Control-flow blocks (a loop body or the then-or-else block of an if-stmt) # form part of the same name-scope in Python, but will be mapped to a nested # name-scope (as a sub-graph) in ONNX. # * Script-time name-value tracking: Name _lookup during script-time returns # statically-known information about the value the name will have at runtime. def _enter_scope(self, name: str, parent_node: ast.AST): """Enter a control-flow block (a loop body or if-then-else branch). The block is translated into a nested-scope in ONNX. """ self._outer.insert(0, self._current_fn) self._current_fn = self.ir_builder.new_function(name) self._locals.insert(0, {}) logger.debug("Converter:_enter_scope:%d:node:%s", len(self._locals), type(parent_node)) def _exit_scope(self) -> irbuilder.IRFunction: """Exit from a control-flow block (a loop body or if-then-else branch).""" logger.debug("Converter:_exit_scope:%d", len(self._locals)) graph = self._current_fn self._current_fn = self._outer.pop(0) self._locals.pop(0) return graph def _current_scope(self) -> Dict[str, LocalSymValue]: return self._locals[0] def _bind(self, name: str, val: LocalSymValue) -> None: logger.debug("Converter:_bind:%s", name) self._locals[0][name] = val def _lookup( self, name: str, info: sourceinfo.SourceInfo, raise_exception: bool = True ) -> SymValue: for scope in self._locals: if name in scope: return scope[name] if name in self.globals: return self.globals[name] if raise_exception: raise ValueError(info.msg(f"Unbound name: {name}.")) return None def generate_unique_name(self, candidate: str = "tmp") -> str: # TODO(justinchuby): Can we reduce the O complexity of this function? r = candidate while r in self._used_vars: r = f"{candidate}_{self._nextvar}" self._nextvar = self._nextvar + 1 self._used_vars.add(r) return r def _make_onnx_attr( self, attrname: str, attrval: Any, attrtype: Optional[int] = None ) -> irbuilder.IRAttributeValue: def tensor_name_generator() -> str: """Return name to be used for tensor, if we need to create one.""" return self.generate_unique_name(f"attr_{attrname}") proto = autocast.pyvalue_to_onnx_attribute( attrname, attrval, tensor_name_generator, attrtype ) return self.ir_builder.make_attr(proto) def _to_onnx_attr_ref( self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo] ) -> irbuilder.IRAttributeValue: pytype = val.typeinfo attrtype = ta.pytype_to_attrtype(pytype) attrname = None if attrtype is onnx.AttributeProto.FLOAT: attrname = "value_float" elif attrtype is onnx.AttributeProto.INT: attrname = "value_int" elif attrtype is onnx.AttributeProto.STRING: attrname = "value_string" elif attrtype is onnx.AttributeProto.INTS: attrname = "value_ints" else: msg = f"Unsupported attribute type {pytype!r}." fail(info.msg(msg) if info else msg) return self.ir_builder.make_attr_ref(attrname, val.value, pytype) def _to_onnx_var( self, val: values.SymbolValue | PyValue, target: Optional[PreferredName] = None, info: Optional[sourceinfo.SourceInfo] = None, ) -> Variable: if isinstance(val, values.AttrRef): # promote attribute to value result = self.generate_unique_name(target or "tmp") attr = self._to_onnx_attr_ref(val, info) self.emit([result], values.Op(self.default_opset, "Constant"), [], [attr]) if ta.base_type_is_bool(val.typeinfo): # ONNX attributes use an int-encoding for bools, but ONNX tensor types # distinguish between int and bool. So we cast the int tensor to a bool tensor, # to promote a (python) bool attribute to a ONNX bool tensor. result_as_bool = self.generate_unique_name(result + "_as_bool") cast_attr = self._make_onnx_attr("to", onnx_types.BOOL.dtype) self.emit( [result_as_bool], values.Op(self.default_opset, "Cast"), [result], [cast_attr], ) return Variable(result_as_bool, True) return Variable(result, True) if isinstance(val, values.Dynamic): return Variable(val.value) # Assume value is a python-value convertible to a tensor # TODO: check if value is convertible to a TensorProto, so that we can # produce a better error _message otherwise return self._emit_const(val, target or "tmp", info) def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> Variable: return self._to_onnx_var(self._lookup(py_var, info), target=py_var, info=info) def emit( self, outputs: Sequence[str], callee: values.Op | str, inputs: Sequence[Optional[str]], attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None, sub_functions: Optional[dict[str, onnx.FunctionProto]] = None, ): if not isinstance(callee, values.Op): callee = values.Op(self.default_opset, callee) if attrs is None: attrs = [] if sub_functions is None: sub_functions = {} self.ir_builder.add_stmt( self._current_fn, outputs, callee, inputs, attrs, sub_functions, ) def _emit_const( self, pyvalue: PyValue, suggested_name: Optional[PreferredName], info: sourceinfo.SourceInfo, ) -> Variable: if suggested_name is None: if isinstance(pyvalue, int): if pyvalue >= 0: suggested_name = f"int64_{pyvalue}" else: suggested_name = f"int64_m{abs(pyvalue)}" elif ( isinstance(pyvalue, list) and len(pyvalue) == 1 and isinstance(pyvalue[0], int) ): if pyvalue[0] >= 0: suggested_name = f"int64_{pyvalue[0]}_1d" else: suggested_name = f"int64_m{abs(pyvalue[0])}_1d" else: suggested_name = "const" ovar = self.generate_unique_name(suggested_name) try: tensor = autocast.pyvalue_to_onnx_tensor(ovar, pyvalue) except ValueError as e: fail(info.msg(str(e))) attr = self._make_onnx_attr("value", tensor) self.emit([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) return Variable(ovar, True) def _emit_copy(self, original_var: str, suggested_name: str) -> str: """Emits a copy statement, using the ONNX Identity operator.""" new_var = self.generate_unique_name(suggested_name) self.emit([new_var], "Identity", [original_var]) return new_var def _is_constant_expr(self, node: ast.AST) -> None: if isinstance(node, ast.UnaryOp): return self._is_constant_expr(node.operand) if isinstance( node, ( ast.Call, ast.BinOp, ast.UnaryOp, ast.Compare, ast.Num, ast.Str, ast.Attribute, ast.List, ast.Load, ast.NameConstant, ast.Constant, ast.Str, ), ): return all(self._is_constant_expr(c) for c in ast.iter_child_nodes(node)) return False def _eval_constant_expr(self, expr: ast.AST) -> PyValue: """Evaluates a sub-expression that is assumed to represent a constant value. The expression can refer only to global names (inherited from the scope where the script is evaluated) and cannot refer to local names defined within the script.) Further, these expressions are assumed to be constants. Thus, any subsequent mutation of any state/variables (used in computing this constant value) will potentially lead to unexpected behavior (such as divergence between eager-mode execution and evaluation of the ONNX function.) """ # TODO: assert (self._is_constant_expr(expr)) # TODO: Refine types locals: dict[Any, Any] = {} expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset) cpl = compile(expr, filename="<ast>", mode="eval") try: return eval(cpl, self.globals, locals) # pylint: disable=eval-used except NameError as e: raise NameError( self._message( expr, f"Missing names, globals contains {list(self.globals)!r}, " f"locals {list(locals)!r}.", ) ) from e def _translate_attr( self, attr_name: str, expr: ast.AST, attr_meta: Optional[onnx.defs.OpSchema.Attribute] = None, ) -> Optional[irbuilder.IRAttributeValue]: """Translate an attribute-value specification of the form `attr_name=<expr>` in a call to an op. expr is an AST. The following cases are supported: * Expr evaluates to a script-time constant (a python-value) that can be mapped into an ONNX attribute value, or * Expr evaluates to None, in which case None is returned, or * Expr must be an attribute-reference, that is a name representing an attribute-parameter of a containing function. """ if isinstance(expr, ast.Name): val = self._lookup(expr.id, self._source_of(expr)) if isinstance(val, values.AttrRef): attr_ref = self.ir_builder.make_attr_ref(attr_name, val.value, val.typeinfo) if attr_meta is not None and (attr_ref.type != attr_meta.type): self.fail( expr, f"Attribute type '{attr_ref.type}' does not match expected type '{attr_meta.type}'", ) return attr_ref if isinstance(val, irbuilder.IRFunction): # Check that outer-scope variables referenced by function have same value # at function-definition site and use-as-attribute site, to avoid errors. for pyvar, previous in val.outer_scope_variables: current = self._lookup(pyvar, self._source_of(expr)) if current.value != previous.value: self.fail( expr, f"Outer scope variable '{pyvar}' referenced by function " f"'{expr.id!r}' modified.", ) # Create GraphProto attribute val = val.to_graph_proto() else: val = self._eval_constant_expr(expr) # In ONNX, there is no way to explicitly specify a None value for an attribute. # Instead, the attribute must be omitted from the attribute list. # Hence, we do not create an attribute-proto if the value is None. # The caller is responsible for omitting such attribute-values from the list of attributes # in a NodeProto. if val is None: if attr_meta and attr_meta.required: self.fail(expr, f"Attribute '{attr_name}' is required.") return None attr_type = attr_meta.type if attr_meta else None attr = self._make_onnx_attr(attr_name, val, attr_type) if attr_meta and (attr.type != attr_meta.type): self.fail( expr, f"Attribute type '{attr.type}' does not match expected type '{attr_meta.type}'", ) return attr def _translate_docstring(self, node: ast.Expr) -> None: if hasattr(node.value, "value"): # python 3.8+ return self.ir_builder.add_docstring(self._current_fn, node.value.value) raise TypeError( f"Unexpected type {type(node)!r} for node. Unsupoorted version of python." ) def _translate_expr( self, node: ast.AST, target: Optional[PreferredName] = None ) -> Variable: """Expression-translation generates "IR statements/nodes" that compute the value of the expression into a target-variable, and returns the variable that is assigned this value. """ if isinstance(node, ast.Call): r = self._translate_call_expr(node) elif isinstance(node, (ast.BinOp, ast.BitAnd, ast.BitOr)): r = self._translate_binary_op_expr(node) elif isinstance(node, ast.UnaryOp): r = self._translate_unary_op_expr(node) elif isinstance(node, ast.Compare): r = self._translate_compare_expr(node) elif isinstance(node, ast.Name): r = self._translate_name_expr(node) elif isinstance(node, ast.Subscript): r = self._translate_subscript_expr(node, target) elif self._is_constant_expr(node): r = self._emit_const(self._eval_constant_expr(node), target, self._source_of(node)) else: raise ValueError( self._message(node, f"Unsupported expression type {type(node)!r}.") ) if isinstance(r, Variable): return r callee, args, attrs = r target = "tmp" if target is None else target assert isinstance(target, str) result = self.generate_unique_name(target) self.emit([result], callee, args, attrs) return Variable(result) def _translate_opt_expr(self, node: ast.expr) -> Optional[Variable]: """Translation of an expression where "None" is permitted (eg., for an optional argument). None is represented as a NameConstant in Python 3.7 and Constant in Python 3.9. """ if isinstance(node, (ast.NameConstant, ast.Constant)) and (node.value is None): return None return self._translate_expr(node) def _translate_subscript_expr( self, node: ast.Subscript, target: Optional[PreferredName] ) -> Variable: """List of supported syntaxes is below. `A` is a tensor or an expression equivalent to a tensor. :: A[:, 1] A[:2, 0] A[:2, :1] A[2:0:-1] A[1:] A[:2] A[1:-1] A[1:2] A[-1] A[0] A[:0:-1] *i* is a tensor holding one integer. :: A[i] A[i+1:i+2] Fully supported for python 3.9+. :: A[i:i+j, k] Not supported: :: A[::-1] """ var = self._translate_expr(node.value) var_name = var.name if target is None: target = f"{var_name}_subscripted" target = self.generate_unique_name(target) indices = ast_utils.normalize_subscript_expr(node) info = self._source_of(node.slice if PY_VERSION_GE_39 else node) # Create cached int constants: # TODO: Do this at a graph-scope level. cached_int_consts = {} def const_1d(value, name: Optional[str] = None): nonlocal cached_int_consts if value not in cached_int_consts: cached_int_consts[value] = self._emit_const([value], name, info) return cached_int_consts[value] def one_1d(): return const_1d(1) # Max/min 64-bit int values are used to represent default values for start/stop in Slice. maxint = (1 << 63) - 1 minint = -(1 << 63) def translate_slice_component( node_arg, default_value: Optional[int] = None ) -> tuple[str, Optional[int]]: """Translate optional start/stop/step component of a Slice expression.""" if node_arg is None: if default_value is None: # TODO: Emit "Where(step > 0, pos_default, neg_default)" raise RuntimeError( "Default start/stop not supported when step direction is unknown." ) return const_1d(default_value), default_value if self._is_constant_expr(node_arg): cst = self._eval_constant_expr(node_arg) if isinstance(cst, int): return const_1d(cst), cst else: raise RuntimeError(f"Slice component type must be int, not {type(cst)}") else: name = self._translate_expr(node_arg).name reshaped = self.generate_unique_name(f"{name}_reshaped") self.emit( [reshaped], values.Op(self.default_opset, "Reshape"), [name, one_1d().name], [], ) return reshaped, None def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: """Translate slice-expression of the form from:to:step.""" step_name, step = translate_slice_component(slice_expr.step, 1) if step is None: # Step direction unknown. # TODO: Handle default-values using runtime check on sign of step. lower_name, _ = translate_slice_component(slice_expr.lower, None) upper_name, _ = translate_slice_component(slice_expr.upper, None) elif step > 0: lower_name, _ = translate_slice_component(slice_expr.lower, 0) upper_name, _ = translate_slice_component(slice_expr.upper, maxint) else: lower_name, _ = translate_slice_component(slice_expr.lower, maxint) upper_name, _ = translate_slice_component(slice_expr.upper, minint) return (lower_name, upper_name, step_name) # An input like X[2] is translated into a Gather op. # An input like X[1:5:2] is translated into a Slice op. # An input like X[2, 3] is translated into a Slice + Squeeze (instead of two Gathers), # as an optimization. # An input like X[I, J] is translated into two Gathers (which is correct whatever the # rank of I and J) # To replace multiple Gathers by the Slice we need to know that the index-values # are scalars. # As the first step, we partition the index elements into four kinds: Slice (eg., 1:5:2), # known-to-be-scalar (eg., 2), other-tensor (eg., I), skip/no-op (that is, just ":") sliced_indices: List[Tuple[int, ast.expr]] = [] scalar_indices: List[Tuple[int, ast.expr]] = [] non_scalar_indices: List[Tuple[int, ast.expr]] = [] for axis, elt in enumerate(indices): if isinstance(elt, ast.Slice): # Add to sliced_indices, unless it is "::", which is a no-op. if not (elt.lower is None and elt.upper is None and elt.step is None): sliced_indices.append((axis, elt)) elif self._is_constant_expr(elt) and isinstance( self._eval_constant_expr(elt), int ): scalar_indices.append((axis, elt)) else: non_scalar_indices.append((axis, elt)) if not (sliced_indices or scalar_indices or non_scalar_indices): # Edge case: no index specified. Eg. A[:, :] self.emit([target], "Identity", [var_name]) return Variable(target) if sliced_indices or len(scalar_indices) > 1: # We emit a Slice operation if we have any indices like 1:5:2 or if the number of # scalar indices (like 2) is more than 1. starts = [] ends = [] axes = [] steps = [] squeezed_axes = [] for axis, expr in scalar_indices: # Treat a scalar index i as slice "i:i+1:1", but squeeze the axis finally. # TODO: handle negative i index = self._eval_constant_expr(expr) squeezed_axes.append(axis) kwargs = dict( lineno=getattr(expr, "lineno", node.lineno), col_offset=getattr(expr, "col_offset", node.col_offset), ) element = ast.Slice( ast.Constant(index, **kwargs), ast.Constant(index + 1, **kwargs), ast.Constant(1, **kwargs), ) sliced_indices.append((axis, element)) scalar_indices = [] for axis, element in sliced_indices: axis_var = const_1d(axis) inputs = translate_slice(element) starts.append(inputs[0]) ends.append(inputs[1]) axes.append(axis_var.name) steps.append(inputs[2]) if len(starts) > 1: axis_0_attr = self._make_onnx_attr("axis", 0) start_name = self.generate_unique_name(f"{var_name}_start") self.emit([start_name], "Concat", starts, [axis_0_attr]) end_name = self.generate_unique_name(f"{var_name}_end") self.emit([end_name], "Concat", ends, [axis_0_attr]) axes_name = self.generate_unique_name(f"{var_name}_axis") self.emit([axes_name], "Concat", axes, [axis_0_attr]) steps_name = self.generate_unique_name(f"{var_name}_step") self.emit([steps_name], "Concat", steps, [axis_0_attr]) else: start_name = starts[0] end_name = ends[0] axes_name = axes[0] steps_name = steps[0] if squeezed_axes: sliced_name = self.generate_unique_name(f"{var_name}_sliced") self.emit( [sliced_name], "Slice", [var_name, start_name, end_name, axes_name, steps_name], ) squeezed_axes = self._emit_const(squeezed_axes, "squeezed_axes", info) if non_scalar_indices: # use temporary to store result of squeeze result = self.generate_unique_name(f"{var_name}_squeezed") else: # store squeezed result in final target result = target self.emit([result], "Squeeze", [sliced_name, squeezed_axes]) else: if non_scalar_indices: # use temporary to store result of Slice result = self.generate_unique_name(f"{var_name}_sliced") else: # store result of Slice in final target result = target slice_inputs = [var_name, start_name, end_name, axes_name, steps_name] self.emit([result], "Slice", slice_inputs) else: result = var_name non_scalar_indices.extend(scalar_indices) if non_scalar_indices: last_axis, _ = non_scalar_indices[-1] else: # TODO(justinchuby): Clarify what last_axis should be when non_scalar_indices is False last_axis = None for axis, index_expr in non_scalar_indices: index_value = self._translate_expr(index_expr) axis_attr = self._make_onnx_attr("axis", axis) # use Gather to perform indexing # Assign gathered value to either temporary or final target if axis != last_axis: # use temporary to store result of Gather gathered = self.generate_unique_name(f"{var_name}_axis_{axis}") else: # store result of Gather in final target gathered = target self.emit([gathered], "Gather", [str(result), index_value], [axis_attr]) result = gathered return Variable(result) def _translate_call_expr(self, node: ast.Call): """Translates a call-expression.""" callee = self._translate_callee_expr(node.func) param_schemas = callee.param_schemas() # If the callee's schema is available, we use it to determine the inputs and attributes. # Otherwise, we map named arguments to attributes and positional arguments to inputs. if param_schemas: kwargs = {x.arg: x.value for x in node.keywords} args, attrs = param_manipulation.separate_input_attributes_from_arguments( param_schemas, node.args, kwargs, fill_defaults=False ) args = [self._translate_opt_expr(x) for x in args] attrs = [ self._translate_attr(x, y, callee.op_schema.attributes[x]) for x, y in attrs.items() ] else: args = [self._translate_opt_expr(x) for x in node.args] attrs = [self._translate_attr(x.arg, x.value) for x in node.keywords] args = autocast.static_cast_inputs(self, callee.op_schema, args) # In ONNX, there is no way to explicitly specify a None value for an attribute. # Instead, the attribute must be omitted from the attribute list. # Hence, we do not create an attribute-proto if the value is None. attrs = [attr for attr in attrs if attr is not None] return callee, args, attrs def _cast_like_binary_expression(self, op, left, right): schema = op.op_schema return autocast.static_cast_inputs(self, schema, (left, right)) def _translate_binary_op_expr(self, node: ast.BinOp): op = type(node.op) if op not in primop_map: raise ValueError(self._message(node, f"Unsupported operator {op!r}.")) attr = [] if isinstance(node.op, ast.Mod) and self._is_constant_expr(node.right): # specific case X % f where f is a float. # attribute fmod=1 is added in that case. cst = self._eval_constant_expr(node.right) if isinstance(cst, float): attr = [self._make_onnx_attr("fmod", 1)] op = values.Op(self.default_opset, primop_map[op]) left, right = self._cast_like_binary_expression( op, self._translate_expr(node.left), self._translate_expr(node.right) ) return op, [left, right], attr def _translate_unary_op_expr(self, node): op = type(node.op) if op not in primop_map: raise ValueError(self._message(node, self).msg(f"Unsupported operator {op!r}.")) if self._is_constant_expr(node.operand): # This function changed the constant node.operand # and returns it. The function calling this one # should intercept this call and replace node # by node.operand. # This mechanism does not handle somthing like `(-(-5))`. if hasattr(node.operand, "value"): # python 3.8+ val = node.operand.value else: raise TypeError( f"Unable to guess constant value from type {type(node.operand)!r} " f"and attributes {dir(node.operand)!r}." ) if op == ast.USub: cst = ast.Constant(-val, lineno=node.lineno, col_offset=node.col_offset) return self._translate_expr(cst) if op == ast.UAdd: return self._translate_expr(node.operand) opname = primop_map[op] operand = self._translate_expr(node.operand) return values.Op(self.default_opset, opname), [operand], [] def _translate_compare_expr(self, node): # TODO: handle multiple comparisons in one expression assert len(node.ops) == 1 assert len(node.comparators) == 1 op = type(node.ops[0]) if op not in primop_map: raise ValueError(self._message(node, f"Unsupported operator {op!r}.")) opname = primop_map[op] left = self._translate_expr(node.left) right = self._translate_expr(node.comparators[0]) # NotEqual is not a standard ONNX op, and needs to be translated into # an Equal op/node followed by a Not op/node. op = values.Op(self.default_opset, opname if opname != "NotEqual" else "Equal") left, right = self._cast_like_binary_expression(op, left, right) if opname == "NotEqual": tmp = self.generate_unique_name() self.emit([tmp], op, [left, right]) not_op = values.Op(self.default_opset, "Not") return not_op, [tmp], [] return op, [left, right], [] def _translate_name_expr(self, node: ast.Name) -> Variable: return self._py_var_to_onnx_var(node.id, self._source_of(node)) # pylint: disable=inconsistent-return-statements def _translate_opset_expr(self, node: ast.Attribute) -> values.Opset: """Return an Opset""" if isinstance(node, ast.Name): val = self._lookup(node.id, self._source_of(node), raise_exception=False) if isinstance(val, values.Opset): return val self.fail(node, f"'{node.id}' is not an instance of type Opset but {type(val)}.") elif isinstance(node, ast.Attribute): self.fail(node, "Nested module unimplemented.") # TODO else: self.fail(node, "Invalid opset expression.") # pylint: enable=inconsistent-return-statements def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable=R1710 """Return an Op""" if isinstance(node, ast.Attribute): module = self._translate_opset_expr(node.value) self._set_default_opset(module, node) opname = node.attr if opname in module: return values.Op(module, node.attr) warn(f"'{opname}' is not a known op in '{module}'") return values.Op(module, node.attr) if isinstance(node, ast.Name): function_name = node.id found = self._lookup(function_name, self._source_of(node), raise_exception=False) if isinstance(found, onnxscript.OnnxFunction): self._current_fn.add_called_function(found) return found if isinstance(found, values.Op): return found if not found: if function_name not in self.default_opset: warn( f"Unknown function name {function_name!r}. " f"The ONNX graph may not work." ) return values.Op(self.default_opset, function_name) self.fail(node, "Invalid callee") def _translate_stmt(self, node: ast.stmt, index_of_stmt=None) -> None: """Statement translation: A single Python statement is mapped into a sequence of IR statements. """ if isinstance(node, ast.Assign): return self._translate_assign_stmt(node) if isinstance(node, ast.AnnAssign): return self._translate_assign_stmt(node) if isinstance(node, ast.Return): if index_of_stmt is not None: return self._translate_return_stmt(node) raise ValueError( self._message( node, "Return statements are not permitted inside control-flow statements." ) ) if isinstance(node, ast.If): return self._translate_if_stmt(node) if isinstance(node, (ast.For, ast.While)): return self._translate_loop_stmt(node) if ast_utils.is_doc_string(node): if index_of_stmt == 0: return self._translate_docstring(node) return None if isinstance(node, ast.FunctionDef): return self._translate_nested_function_def(node) if ast_utils.is_print_call(node): return None raise ValueError(self._message(node, f"Unsupported statement type '{type(node)!r}'.")) def _translate_assign_stmt(self, stmt: Union[ast.Assign, ast.AnnAssign]) -> None: def assign(lhs: ast.AST, rhs: ast.AST) -> None: if isinstance(lhs, ast.Name): # Assignments of the form "x = SomeExpression" info = self._source_of(lhs) lhs = lhs.id t = self._translate_expr(rhs, lhs).name if isinstance(stmt, ast.AnnAssign): typeinfo = self._eval_constant_expr(stmt.annotation) else: typeinfo = None var = values.Dynamic(t, values.DynamicKind.Intermediate, info, typeinfo) self._bind(lhs, var) elif isinstance(lhs, ast.Tuple): # Assignments of the form "x, y, z = op.SomeOp(...)" if not isinstance(rhs, ast.Call): self.fail( rhs, f"RHS must be a Call expression for unpacking, found: '{type(rhs)!r}'", ) callee, inputs, attrs = self._translate_call_expr(rhs) def generate_onnx_name(x: ast.AST): if not isinstance(x, ast.Name): self.fail(x, f"LHS must be a Name for unpacking, found: '{type(x)!r}'") onnx_name = self.generate_unique_name(x.id) self._bind( x.id, values.Dynamic( onnx_name, values.DynamicKind.Intermediate, self._source_of(x) ), ) return onnx_name outputs = [generate_onnx_name(x) for x in lhs.elts] self.emit(outputs, callee, inputs, attrs) else: self.fail(lhs, f"Unsupported construct in LHS of assignment: '{type(lhs)!r}'") if isinstance(stmt, ast.Assign): targets = stmt.targets else: targets = [stmt.target] if len(targets) != 1: # Assignments of the form "x = y = SomeExpression" self.fail(stmt, "Multi-assignment not supported.") lhs = targets[0] rhs = stmt.value if isinstance(rhs, ast.Tuple): # Assignments of the form "... = Expression1, Expression2" if not isinstance(lhs, ast.Tuple): # Assignments of the form "single_var = Expression1, Expression2". # We do not support tuple-typed variables. self.fail(lhs, f"Left term must be a tuple not '{type(lhs)!r}'.") # Parallel assignments of the form "x, y = Expression1, Expression2" if len(lhs.elts) != len(rhs.elts): self.fail( stmt, "Expected same number of elements on lhs and rhs of assignments." ) for p, r in zip(lhs.elts, rhs.elts): assign(p, r) else: assign(lhs, rhs) def _translate_return_stmt(self, stmt: ast.Return) -> None: def check_num_outputs(n): if self.returntype is not None: if n != len(self.returntype): raise SyntaxError( self._message( stmt, f"Mismatch in number of return values and types. Keyword " f"'return' cannot be used in a subgraph (test, loop). " f"returntype is {self.returntype!r}, num_outputs={n!r}.", ) ) def ret(exp, i, suffix): preferred_name = f"return_val{suffix}" return_var = self._translate_expr(exp, preferred_name).name val = self._lookup(return_var, self._source_of(exp), False) if val and val.kind == values.DynamicKind.Input: # In ONNX, a graph-input cannot be an output of the graph. # We need to insert a copy. return_var = self._emit_copy(return_var, preferred_name) for prev_output in self._current_fn.outputs: if prev_output.name == return_var: # ONNX does not allow duplicate output names. return_var = self._emit_copy(return_var, f"{return_var}_copy") break if self.returntype is None: t = None else: t = self.returntype[i] self.ir_builder.add_output(self._current_fn, return_var, t, self._source_of(stmt)) return return_var val = stmt.value assert val is not None, "Return statement without return-value not supported." if isinstance(val, ast.Tuple): check_num_outputs(len(val.elts)) return [ret(exp, i, str(i)) for i, exp in enumerate(val.elts)] check_num_outputs(1) return ret(val, 0, "") def _translate_if_stmt(self, stmt: ast.If) -> None: if hasattr(stmt, "live_out"): live_defs = list( stmt.live_out.intersection(analysis.assigned_vars(stmt, self._message)) ) else: live_defs = list(analysis.assigned_vars(stmt, self._message)) test = self._translate_expr(stmt.test, "cond").name lineno = self._source_of(stmt).lineno thenGraph, sub_fct_then = self._translate_block( stmt.body, f"thenGraph_{lineno}", live_defs, parent_stmt=stmt ) thenAttr = self._make_onnx_attr("then_branch", thenGraph) elseGraph, sub_fct_else = self._translate_block( stmt.orelse, f"elseGraph_{lineno}", live_defs, parent_stmt=stmt ) elseAttr = self._make_onnx_attr("else_branch", elseGraph) def rename(x): r = self.generate_unique_name(x) self._bind( x, values.Dynamic(r, values.DynamicKind.Intermediate, self._source_of(stmt)), ) return r # no break condition renamed = [rename(x) for x in live_defs] if not renamed: self.fail(stmt, "A subgraph for a test do not have any output variable.") sub_functions = {} sub_functions.update(sub_fct_then) sub_functions.update(sub_fct_else) if renamed == [test]: self.fail(stmt, f"Input and output cannot be the same {renamed!r}.") self.emit( renamed, values.Op(self.default_opset, "If"), [test], [thenAttr, elseAttr], sub_functions=sub_functions, ) def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: # loop-variable if isinstance(loop_stmt, ast.For): if not isinstance(loop_stmt.target, ast.Name): self.fail(loop_stmt, "For loop target must be a single variable.") p_loop_var = loop_stmt.target.id # iter iter = loop_stmt.iter assert isinstance(iter, ast.Call), "Loop bound not a call." if not isinstance(iter.func, ast.Name): self.fail(loop_stmt, f"Unsupported loop bound {iter.func!r}.") if iter.func.id != "range": self.fail( loop_stmt, "Unsupported loop bound, only function 'range' is allowed." ) if not iter.args or len(iter.args) != 1: self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.") assert not iter.keywords, "Unsupported loop bound." o_loop_bound = self._translate_expr(iter.args[0], "loop_bound").name o_cond_var = self.generate_unique_name("cond_in") i_cond_var = o_cond_var cond_while = None o_loop_condition = "" # No condition for a for loop. elif isinstance(loop_stmt, ast.While): test = loop_stmt.test if not isinstance(test, ast.Name): self.fail( loop_stmt, "Unexpected condition type {type(loop_stmt)!r} for a while loop, " "it should be 'while <condition_name>:'.", ) p_loop_var = "infinite_loop" o_loop_bound = "" i_cond_var = test.id cond_while = test.id o_cond_var = None o_loop_condition = self._translate_name_expr(test) # we need to go through all the instructions to see # which instruction defines the condition test.id else: self.fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.") # analyze loop body exposed_uses = analysis.exposed_uses(loop_stmt.body, self._message) vars_def_in_loop = analysis.assigned_vars(loop_stmt.body, self._message) loop_state_vars = vars_def_in_loop.intersection(exposed_uses | loop_stmt.live_out) scan_outputs = set() # TODO outputs = list(loop_state_vars | scan_outputs) # loop-condition: # o_loop_condition = self._emit_const(True, "true", self._source_of(loop_stmt)) # build loop_body self._enter_scope("loop_body", loop_stmt) o_loop_var = self.generate_unique_name(p_loop_var) self.ir_builder.add_input( self._current_fn, o_loop_var, onnx_types.INT64, self._source_of(loop_stmt), ) self._bind( p_loop_var, values.Dynamic(o_loop_var, values.DynamicKind.Loop, self._source_of(loop_stmt)), ) self.ir_builder.add_input( self._current_fn, i_cond_var, onnx_types.BOOL, self._source_of(loop_stmt), ) for pv in loop_state_vars: ov = self.generate_unique_name(pv) # TODO: retrieve the annotation for variable pv is any is specified. # typeinfo = self._eval_constant_expr(pv.annotation) typeinfo = None self.ir_builder.add_input( self._current_fn, ov, typeinfo, self._source_of(loop_stmt) ) self._bind( pv, values.Dynamic(ov, values.DynamicKind.Loop, self._source_of(loop_stmt)), ) condition_name = None operator_name = "Identity" for i, s in enumerate(loop_stmt.body): # We first need to intercept a break instruction in test block. # It must be something like `if <condition_name>: break`. # This instruction must be the last of the loop body. if isinstance(s, ast.If) and len(s.body) == 1 and isinstance(s.body[0], ast.Break): if not isinstance(s.test, ast.Name): self.fail( s, f"Instruction break can be introduced with test but it must be " f"if <condition>: break. However condition is of type " f"{type(s.test)!r}.", ) if i != len(loop_stmt.body) - 1: self.fail(s, "Instruction break must be the last one of the loop.") current_scope = self._current_scope() if s.test.id not in current_scope: self.fail( loop_stmt, f"Unable to find condition variable {s.test.id!r} in known " f"variables {list(current_scope)!r}.", ) condition_name = current_scope[s.test.id].value operator_name = "Not" continue self._translate_stmt(s) o_cond_out = self.generate_unique_name("cond_out") if cond_while is not None: # Loop while current_scope = self._current_scope() if cond_while not in current_scope: self.fail( loop_stmt, f"Unable to find condition variable {cond_while!r} in known " f"variables {list(current_scope)!r}.", ) o_cond_var = current_scope[cond_while].value self.emit( [o_cond_out], values.Op(self.default_opset, operator_name), [condition_name or o_cond_var], [], ) self.ir_builder.add_output( self._current_fn, o_cond_out, onnx_types.BOOL, self._source_of(loop_stmt), ) for pv in loop_state_vars: ov = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)).name if ov not in self._current_fn.assigned_names: # When converting the loop-body into a graph, we need to handle # identity assignments of the form "x = y" inside the loop body # specially if y represents a value computed outside the loop body. # In this case, we create a copy of y, treating the statement as # shorthand for "x = op.Identity(y)". ov = self._emit_copy(ov, pv) # TODO: retrieve variable type for the annotation if any. typeinfo = None self.ir_builder.add_output( self._current_fn, ov, typeinfo, self._source_of(loop_stmt) ) body = self._exit_scope() inputs = [o_loop_bound, o_loop_condition] + [ self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)).name for pv in loop_state_vars ] graph, sub_functions = body.to_graph_and_functions() attrs = [self._make_onnx_attr("body", graph)] info = self._source_of(loop_stmt) def rename(x): r = self.generate_unique_name(x) self._bind(x, values.Dynamic(r, values.DynamicKind.Output, info)) return r onnx_outputs = [rename(x) for x in outputs] self.emit( onnx_outputs, "Loop", inputs, attrs, sub_functions=sub_functions, ) def _translate_block( self, stmts: Sequence[ast.stmt], name: str, live_defs: Sequence[str], parent_stmt: ast.stmt, ): """Translation of a statement-block to GraphProto attribute.""" info_stmt = stmts[0] if len(stmts) > 0 else parent_stmt source = self._source_of(info_stmt) self._enter_scope(name, None) for s in stmts: self._translate_stmt(s) for pvar in live_defs: if pvar in self._current_scope(): pv_val = self._current_scope()[pvar] output = self._to_onnx_var(pv_val, pvar).name if output not in self._current_fn.assigned_names: # To return an outer-scope variable, an ONNX Graph has to # use an explicit copy via Identity. output = self._emit_copy(output, pvar) self.ir_builder.add_output( self._current_fn, output, pv_val.typeinfo, source, ) else: pv_val = None for scope in self._locals: # TODO: skip _current_scope if pvar in scope: pv_val = scope[pvar] break if pv_val is None: self.fail( stmts[0], f"Variable {pvar} is not assigned a value along a conditional " f"branch, known variables: {list(self._locals)}.", ) # introduce a copy ovar = self._emit_copy(self._to_onnx_var(pv_val, pvar).name, pvar) # TODO: retrieve the annotation if any. typeinfo = None self.ir_builder.add_output(self._current_fn, ovar, typeinfo, source) graph = self._exit_scope() return graph.to_graph_and_functions() def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None: """Translate a nested function definition.""" self._enter_scope(fn.name, fn) self._translate_function_def_common(fn) function_ir = self._exit_scope() outer_scope_vars = analysis.outer_scope_variables(fn, self._message) function_ir.outer_scope_variables = [ (var, self._lookup(var, self._source_of(fn))) for var in outer_scope_vars ] self._bind(fn.name, function_ir) # TODO: Does not yet handle nested functions within nested functions. self._current_fn.add_nested_function(function_ir) def _translate_function_signature_common( self, fn: ast.FunctionDef ) -> irbuilder.IRFunction: """Translate a function signature (top-level or nested).""" args = fn.args if args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg: warn(f"{fn.name}: Unsupported feature in function signature.") for i, x in enumerate(args.args): arg_with_default_start_index = len(args.args) - len(args.defaults) if args.defaults and i >= arg_with_default_start_index: default_value = self._eval_constant_expr( args.defaults[i - arg_with_default_start_index] ) else: default_value = None if x.annotation: typeinfo = self._eval_constant_expr(x.annotation) if not ta.is_valid_type(typeinfo): self.warn( x.annotation, f"Unsupported type annotation for argument {x.arg}.", ) typeinfo = None else: # The code can only be exported as a function. typeinfo = None if typeinfo and ta.is_attr_type(typeinfo): self.ir_builder.add_attr_parameter( self._current_fn, x.arg, ta.pytype_to_attrtype(typeinfo), default_value, ) self._bind(x.arg, values.AttrRef(x.arg, typeinfo, self._source_of(x))) else: self.ir_builder.add_input( self._current_fn, x.arg, typeinfo, self._source_of(x) ) self._used_vars.add(x.arg) self._bind( x.arg, values.Dynamic(x.arg, values.DynamicKind.Input, self._source_of(x)), ) if fn.returns: type_annotation = self._eval_constant_expr(fn.returns) self.returntype = ta.get_return_types(type_annotation) invalid = False for t in self.returntype: if not ta.is_valid_type(t): self.warn( fn.returns, f"Unsupported type annotation for return value {t}.", ) invalid = True if invalid: self.returntype = None else: self.returntype = None return self._current_fn def _translate_function_def_common(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: """Translate a function definition, including the signature and its body.""" logger.debug("Converter:_translate_function_def_common:%s", fn.name) _ = self._translate_function_signature_common(fn) for i, s in enumerate(fn.body): self._translate_stmt(s, index_of_stmt=i) return self._current_fn def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: if isinstance(stmt, ast.FunctionDef): self._init_function_translation() if self.default_opset_ is None: opset = self._find_onnx_opset(stmt) if opset: self._set_default_opset(opset, stmt) domain = self.this_module.domain self._current_fn = self.ir_builder.new_function(stmt.name, domain, True) analysis.do_liveness_analysis(stmt, self._message) fn_ir = self._translate_function_def_common(stmt) fn_ir.debug_print() self.this_module.add_function_def(fn_ir) return fn_ir raise ValueError(f"Unsupported top-level statement type {type(stmt)!r}.")
[docs] def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: """Translate a (top-level) function signature.""" domain = self.this_module.domain self._current_fn = self.ir_builder.new_function(fn.name, domain, True) return self._translate_function_signature_common(fn)