# Source code for onnxscript.main

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint disable: protected-access
from __future__ import annotations

import ast
import inspect
import sys
from typing import Any, Callable, Optional, Sequence

import onnx.helper

import onnxscript
from onnxscript import converter, irbuilder, values
from onnxscript._internal import ast_utils


def script_check(
    f: ast.FunctionDef,
    opset: values.Opset,
    global_names: dict[str, Any],
    source: str,
    default_opset: Optional[values.Opset] = None,
) -> irbuilder.IRFunction:
    """Verify that ``f`` is written in the ONNXScript subset of Python.

    The check is performed by translating ``f``: a successful translation
    is what establishes that the function is valid ONNXScript.

    Args:
        f: Parsed function definition to check.
        opset: Opset the function belongs to.
        global_names: Names visible to the function at its definition site.
        source: Source text of the function.
        default_opset: Opset used for operators not found in ``opset``.

    Returns:
        The IR function produced by the translation.
    """
    # TODO: cleanup Converter interface/API, separating checker from converter.
    translator_kwargs = dict(
        opset=opset,
        global_names=global_names,
        source=source,
        default_opset=default_opset,
    )
    translator = converter.Converter(**translator_kwargs)
    ir_function = translator.translate_function_def(f)
    return ir_function


def script(
    opset: Optional[values.Opset] = None,
    default_opset: Optional[values.Opset] = None,
    **kwargs: Any,
) -> Callable[[Callable], onnxscript.OnnxFunction]:
    """Main decorator. Declares a function as an onnx function.

    Args:
        opset: Opset the function belongs to (see :ref:`l-api-opsets`).
            When omitted, ``values.Opset("this", 1)`` is used.
        default_opset: Opset to use for operators not in the function's opset.
        kwargs: Additional keyword arguments, forwarded to the
            :class:`onnxscript.OnnxFunction` that is created.

    Returns:
        A decorator that converts a Python function into an instance of
        :class:`onnxscript.values.OnnxFunction`.

    Raises:
        TypeError: If ``opset`` is not a :class:`onnxscript.values.Opset`
            (typically because ``@script`` was used instead of ``@script()``),
            or if the decorated object is not a plain Python function.

    Example:
    ::

        @script()
        def log2(x):
            two = op.Constant(value=make_tensor('two', TensorProto.FLOAT, [1], [2]))
            return op.Div(op.Log(x), op.CastLike(op.Log(two), x))

    Or:
    ::

        from onnxscript.onnx_opset import opset16

        @script(opset16)
        def log2(x):
            two = op.Constant(value=make_tensor('two', TensorProto.FLOAT, [1], [2]))
            return op.Div(op.Log(x), op.CastLike(op.Log(two), x))
    """
    # Explicit None check: a truthiness test would silently replace any
    # user-supplied opset that happens to evaluate as falsy.
    if opset is None:
        opset = values.Opset("this", 1)
    if not isinstance(opset, values.Opset):
        # With a bare ``@script`` the decorated function itself arrives here.
        raise TypeError(
            "Script parameter must be an opset. Did you use @script instead of @script()?"
        )

    def transform(f: Callable) -> onnxscript.OnnxFunction:
        if not inspect.isfunction(f):
            raise TypeError("The ONNXScript decorator should be applied to functions only.")

        src, f_ast = ast_utils.get_src_and_ast(f)
        # The script should be compiled using the globals/locals at the definition site.
        # This allows the script to reference names defined outside the script,
        # which is used for a few different purposes.
        # The following is an approximate solution that works for normal use.
        module = inspect.getmodule(f)
        closure = inspect.getclosurevars(f)
        # inspect.getmodule returns None when the defining module cannot be
        # located (e.g. functions created by exec in a custom namespace); fall
        # back to the function's own globals instead of crashing.
        if module is not None:
            env = module.__dict__.copy()
        else:
            env = dict(f.__globals__)
        env.update(closure.nonlocals)
        result = script_check(f_ast, opset, env, src, default_opset=default_opset)
        # TODO: add transformations.
        return onnxscript.OnnxFunction(opset, f, result, src, kwargs)

    return transform
def graph() -> Callable[[Callable], values.OnnxClosure]:
    """A parametric decorator used to annotate nested-functions that are used as graph-attributes.

    Returns:
        A decorator that returns its input function, but attaches a graph_proto
        attribute representing the input function. The translation is not
        done at this time, but previously when the outer-level function was
        translated to an OnnxFunction. The decorator just looks up and
        retrieves the GraphProto representation previously generated.

    Example:
    ::

        @script()
        def cumulative_sum(X: INT64['N']):

            # Translation of cumulative_sum by @script will also translate Sum
            # into a GraphProto, which will be stored in the OnnxFunction generated
            # for cumulative_sum. At run-time (in eager-mode), the @graph decorator
            # retrieves the pre-computed GraphProto and attaches it to the Sum function.
            @graph()
            def Sum(sum_in, next):
                sum_out = sum_in + next
                scan_out = op.Identity(sum_out)
                return sum_out, scan_out

            zero = op.Constant(value_int=0)
            # The call to higher-order operator Scan below uses the above function
            # Sum as a graph-attribute.
            all_sum, result = op.Scan(zero, X, body=Sum, num_scan_inputs=1)
            return result
    """
    # This is a bit fragile. We want to get the ONNXFunction object representing
    # the outer-scope ONNXScript function from the execution stack. The caller of
    # @graph is the original script function (cumulative_sum in the above example),
    # and the caller of that function is the wrapper function/method in the
    # corresponding OnnxFunction object.
    # Currently, there is no support for eager-mode execution of nested functions,
    # so we don't need to handle doubly nested functions (e.g., a function defined
    # inside Sum in the above example).
    # NOTE(review): the wrapper is looked up three frames up (not two), so the
    # call path presumably has one extra frame between the OnnxFunction method
    # and the script function. Confirm against OnnxFunction before changing
    # either depth.
    function_frame = sys._getframe(1)  # pylint: disable=protected-access
    wrapper_frame = sys._getframe(3)  # pylint: disable=protected-access
    # Raises KeyError if the frame three levels up is not a method with ``self``,
    # i.e. when graph() is called outside eager execution of an OnnxFunction.
    onnx_function = wrapper_frame.f_locals["self"]
    nested_functions = onnx_function.function_ir.nested_functions

    def transform(f: Callable) -> values.OnnxClosure:
        # Presumably nested_functions is keyed by the nested function's Python
        # name, as produced when the outer function was translated.
        return values.OnnxClosure(nested_functions[f.__name__], function_frame, f)

    return transform


def is_converted_fun(f: Any) -> bool:
    """Return True if f is a function converted by onnxscript decorator."""
    return isinstance(f, onnxscript.OnnxFunction)


def export_onnx_lib(functions: Sequence[values.OnnxFunction], filename: str) -> None:
    """Save ``functions`` to ``filename`` as a library of ONNX functions.

    Args:
        functions: Converted functions; each is serialized with
            ``to_function_proto()``.
        filename: Destination path passed to ``onnx.save``.
    """
    # Since we don't yet have LibProto defined, we use a ModelProto as a temporary
    # container for the list of functions exported as a library, with an empty graph
    # and dummy opset_imports.
    model = onnx.helper.make_model(
        onnx.GraphProto(),
        functions=[f.to_function_proto() for f in functions],
        producer_name="p2o",
        opset_imports=[onnx.helper.make_opsetid("", 15)],
    )
    onnx.save(model, filename)