Introducing ONNX Script: authoring ONNX with the ease of Python

An ONNX Refresher

ONNX models are flexible, standardized representations of machine learning models that can be executed across a gamut of hardware platforms and runtime environments, from large-scale cloud-based supercomputers to resource-constrained edge devices such as your web browser and phone.

Typically, machine learning models are developed using higher level frameworks such as PyTorch and TensorFlow. While these frameworks tend to be productive for iterating on the development of models, the models are not typically deployed to production in this fashion. Instead, they are exported to ONNX by facilities provided by the frameworks, and then optimized for a particular target by tools such as Olive.
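For example, exporting a simple PyTorch module to ONNX takes only a few lines. The sketch below uses PyTorch’s built-in exporter; the model and file name are purely illustrative:

import torch
import torch.nn as nn

# An illustrative model; any nn.Module works the same way
model = nn.Linear(4, 2)
dummy_input = torch.randn(1, 4)

# Trace the model with a sample input and serialize it to ONNX
torch.onnx.export(model, (dummy_input,), "linear.onnx")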

Beyond its graph format, canonically represented using Protobuf, ONNX consists of a standard set of primitive operators that are implemented by runtimes and hardware vendors alike. With this broad ecosystem in mind, ONNX aims to keep the number of these operators low, encouraging composability through ONNX functions; a small operator set reduces the overhead for every runtime and hardware vendor that supports ONNX.

Announcing ONNX Script

ONNX Script is a new open-source library for directly authoring ONNX models in Python with a focus on clean, idiomatic Python syntax and composability through ONNX-native functions. Critically, it is also the foundation upon which we are building the new PyTorch ONNX exporter to support TorchDynamo – the future of PyTorch.

Prior to ONNX Script, authoring ONNX models required deep knowledge of the specification and serialization format itself. While eventually a more convenient helper API was introduced that largely abstracted the serialization format, it still required deep familiarity with ONNX constructs.

ONNX Script takes a new approach by integrating deeply with Python on two fronts: first, it provides a strongly typed API for all operators in ONNX (all 186 as of opset 19). This allows existing Python tooling, linters, and IDEs to provide valuable feedback and enforce correctness. Second, ONNX Script supports idiomatic Python language constructs to make authoring ONNX more natural, including support for conditionals and loops, binary and unary operators, subscripting, slicing, and more. For example, the expression a + b in Python would translate to the ONNX operator Add(a, b).

Let’s look at how we might implement GELU using ONNX Script and compare it with the onnx.helper API; to be clear, the examples below produce the same ONNX model. For reference, we’ll use this definition of GELU to guide the ONNX implementations, where Φ is the cumulative distribution function of the standard normal distribution:

GELU(x) = x · Φ(x) = 0.5 · x · (1 + erf(x / √2))

GELU in ONNX Script

import math
from onnxscript import (
    script, opset18 as op, FLOAT
)

M_SQRT1_2 = math.sqrt(0.5)

@script()
def gelu(X: FLOAT[...]):
    phiX = 0.5 * (op.Erf(M_SQRT1_2 * X) + 1.0)
    return X * phiX

gelu_model = gelu.to_model_proto()
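
Since to_model_proto returns a standard onnx.ModelProto, the result plugs directly into existing ONNX tooling; for example, we can validate and save it as usual:

import onnx

# Validate the generated model and serialize it to disk
onnx.checker.check_model(gelu_model)
onnx.save_model(gelu_model, "gelu.onnx")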

GELU with the ONNX Helper API

import math
import onnx
import onnx.helper

gelu_model = onnx.helper.make_model(
    ir_version=8,
    opset_imports=[onnx.helper.make_operatorsetid("", 18)],
    graph=onnx.helper.make_graph(
        name="Gelu",
        nodes=[
            onnx.helper.make_node(
                "Constant",
                inputs=[],
                outputs=["tmp"],
                name="n0",
                value=onnx.helper.make_tensor(
                    "tmp", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]
                ),
            ),
            onnx.helper.make_node(
                "Constant",
                inputs=[],
                outputs=["M_SQRT1_2"],
                name="n1",
                value=onnx.helper.make_tensor(
                    "M_SQRT1_2",
                    onnx.TensorProto.FLOAT,
                    dims=[],
                    vals=[math.sqrt(0.5)],
                ),
            ),
            onnx.helper.make_node(
                "CastLike",
                inputs=["M_SQRT1_2", "X"],
                outputs=["M_SQRT1_2_cast"],
                name="n2",
            ),
            onnx.helper.make_node(
                "Mul",
                inputs=["M_SQRT1_2_cast", "X"],
                outputs=["tmp_0"],
                name="n3",
            ),
            onnx.helper.make_node(
                "Erf", inputs=["tmp_0"], outputs=["tmp_1"], name="n4"
            ),
            onnx.helper.make_node(
                "Constant",
                inputs=[],
                outputs=["tmp_2"],
                name="n5",
                value=onnx.helper.make_tensor(
                    "tmp_2",
                    onnx.TensorProto.FLOAT,
                    dims=[],
                    vals=[1.0],
                ),
            ),
            onnx.helper.make_node(
                "CastLike",
                inputs=["tmp_2", "tmp_1"],
                outputs=["tmp_2_cast"],
                name="n6",
            ),
            onnx.helper.make_node(
                "Add",
                inputs=["tmp_1", "tmp_2_cast"],
                outputs=["tmp_3"],
                name="n7",
            ),
            onnx.helper.make_node(
                "CastLike",
                inputs=["tmp", "tmp_3"],
                outputs=["tmp_cast"],
                name="n8",
            ),
            onnx.helper.make_node(
                "Mul",
                inputs=["tmp_cast", "tmp_3"],
                outputs=["phiX"],
                name="n9",
            ),
            onnx.helper.make_node(
                "Mul",
                inputs=["X", "phiX"],
                outputs=["return_val"],
                name="n10",
            ),
        ],
        inputs=[
            onnx.helper.make_value_info(
                name="X",
                type_proto=onnx.helper.make_tensor_type_proto(
                    elem_type=onnx.TensorProto.FLOAT, shape=[]
                ),
            )
        ],
        outputs=[
            onnx.helper.make_value_info(
                name="return_val",
                type_proto=onnx.helper.make_tensor_type_proto(
                    elem_type=onnx.TensorProto.FLOAT, shape=[]
                ),
            )
        ],
    ),
)

As you can see, ONNX Script emphasizes the familiar readability and productivity of Python while expressing an ONNX model that can be statically reasoned about by existing Python and ONNX tooling.

This also means ONNX comes alive within the context of your existing tooling and development environments, be it debugging in Visual Studio Code or demonstrating concepts in a Jupyter Notebook — ONNX Script integrates naturally.

Why are we investing in ONNX Script?

Much has changed since ONNX support for PyTorch was originally introduced over five years ago in PyTorch 0.3.0. With PyTorch 2.0, TorchDynamo signals the eventual deprecation of TorchScript, which means a major overhaul of the ONNX exporter is necessary. We are fully embracing this as an opportunity to revisit the fundamentals upon which the exporter is built, and ONNX Script is its new foundation. We began this effort in November of last year and have worked closely with PyTorch engineers to ensure TorchDynamo is a fully capable starting point for exporting high-fidelity ONNX for years to come.

One of the first streams of work we started was the development of what we call Torchlib, a pure ONNX implementation of the operators in PyTorch – namely Core ATen IR and Prims IR; and of course, these operators are implemented in ONNX Script! This approach greatly simplifies the central responsibility of the exporter as it “just” needs to project FX graph nodes produced by TorchDynamo into ONNX graph nodes, without concerning itself with the implementation details of individual operators.
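
To give a flavor of Torchlib, a hypothetical, highly simplified entry might look like the following sketch (the real implementations live in the onnxscript repository and handle many more details, such as dtype promotion and optional arguments):

from onnxscript import opset18 as op, script, FLOAT

@script()
def aten_relu(self: FLOAT[...]) -> FLOAT[...]:
    # torch.relu projects directly onto the ONNX Relu operator
    return op.Relu(self)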

We will cover the new PyTorch ONNX exporter in a separate post with more depth as PyTorch 2.1 approaches, but for now the key takeaway is that ONNX Script is pervasive throughout our renewed approach. For those willing to try bold new things, the new exporter is available as a preview in PyTorch nightly via the torch.onnx.dynamo_export API.
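
At the time of writing, trying the preview looks roughly like this (a sketch; the model is illustrative):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
example_input = torch.randn(1, 4)

# Export through the new TorchDynamo-based exporter (preview in PyTorch nightly)
export_output = torch.onnx.dynamo_export(model, example_input)
export_output.save("model.onnx")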

By deeply weaving ONNX Script support into the PyTorch ONNX exporter, we have also made it possible to augment PyTorch model code with specialized ONNX functions as custom operators. We introduced initial support for this in the TorchScript exporter starting with PyTorch 1.13 and continue to refine this capability in the new exporter.

An End-to-End Example

Let’s look at an example slightly more complicated than GELU. In fact, the following example is adapted directly from the new PyTorch ONNX exporter, implementing support for torch.chunk, which attempts to split a tensor into the specified number of chunks.

from typing import Sequence
from onnxscript import opset18 as op, script, FLOAT, INT64

@script()
def aten_chunk(
    tensor: FLOAT[...], chunks: int, dim: int = 0,
) -> Sequence[FLOAT[...]]:
    neg_1 = op.Constant(value_ints=[-1])

    # Get size of specified dim
    dim_size = op.Shape(tensor)[dim]

    # Ceiling-divide dim_size by chunks to get the number of elements per chunk
    num_per_chunk = dim_size / chunks + op.Cast(dim_size % chunks > 0, to=INT64.dtype)

    # Compute the number of full chunks (integer division)
    num_chunk = dim_size / num_per_chunk

    # Build a split-size list like [n, n, n, ...] with num_chunk entries
    list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))

    remainder = dim_size % num_per_chunk
    if remainder > 0:
        # Append the remainder, yielding [n, n, ..., r]
        list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)

    return op.SplitToSequence(tensor, list_split, axis=dim)

We start by importing from onnxscript the ONNX opset we wish to use (version 18 in this case), the @script decorator, and the FLOAT and INT64 tensor types. In ONNX Script, tensor shapes are denoted by subscripting the type, either concretely such as FLOAT[2, 10], symbolically such as FLOAT["M", "N"], or as FLOAT[...] when the shape is unknown. Without subscripting (just FLOAT), the type indicates a scalar (a tensor of rank 0).
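
To make those cases concrete, the annotations below illustrate each form:

from onnxscript import FLOAT

x: FLOAT[2, 10]      # fixed shape: 2 x 10
y: FLOAT["M", "N"]   # symbolic shape with named dimensions
z: FLOAT[...]        # unknown shape
w: FLOAT             # a scalar (tensor of rank 0)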

Next, we define the aten_chunk function with type annotations and implement the body of the function using both built-in Python syntax and explicit invocations of ONNX operators. The example uses various binary expressions and an if statement, but many other idiomatic Python constructs are also supported.

We also need to define a simple model that calls our ONNX Script function so we can export and verify an end-to-end example:

@script()
def ten_chunks_model(tensor: FLOAT["M"]):
    return aten_chunk(tensor, chunks=10)

This model will simply split the provided tensor into up to ten tensors, but it also demonstrates that ONNX functions can of course call other ONNX functions, not just built-in ONNX operators.

We’ll now export our ONNX Script model to ONNX and explore it in Netron. Decorating a function with @script allows it to be exported via its to_model_proto method.

import onnx
onnx.save_model(
    ten_chunks_model.to_model_proto(),
    "ten_chunks_model.onnx",
)

If we open ten_chunks_model.onnx in Netron, we can observe the composability of ONNX functions and how the Python code was translated into an ONNX model.

Exploring the ONNX model in Netron

The graphs depict our two ONNX functions; we can observe the original input tensor flowing from ten_chunks_model into aten_chunk along with the attribute chunks=10. A sequence of ≤ 10 tensors is returned. As one might expect, functions in ONNX can be defined once and invoked any number of times within a model. Read more about core ONNX concepts.

Iterating & Debugging

Finally, we should test our model! ONNX Script makes this easy since it provides a mechanism for eagerly evaluating the model through either ONNX Runtime or the new ONNX Reference Evaluator. Of note, ONNX Script has built-in support for NumPy arrays as input and output values, easing the overhead of creating and filling ONNX tensors.

import numpy as np
tensor = np.array(range(0, 48), dtype=np.float32)
chunked_tensors = ten_chunks_model(tensor)

from pprint import pprint
pprint(tensor)
pprint(chunked_tensors)

array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
       13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
       26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38.,
       39., 40., 41., 42., 43., 44., 45., 46., 47.], dtype=float32)
[array([0., 1., 2., 3., 4.], dtype=float32),
 array([5., 6., 7., 8., 9.], dtype=float32),
 array([10., 11., 12., 13., 14.], dtype=float32),
 array([15., 16., 17., 18., 19.], dtype=float32),
 array([20., 21., 22., 23., 24.], dtype=float32),
 array([25., 26., 27., 28., 29.], dtype=float32),
 array([30., 31., 32., 33., 34.], dtype=float32),
 array([35., 36., 37., 38., 39.], dtype=float32),
 array([40., 41., 42., 43., 44.], dtype=float32),
 array([45., 46., 47.], dtype=float32)]

Because ONNX Script’s eager mode evaluates the model on an op-by-op basis, it is conducive to debugging ONNX using standard Python tooling such as pdb directly or through richer IDE and editor integrations provided by Visual Studio and Visual Studio Code.

Screenshot of Visual Studio Code debugging ONNX Script

A screenshot of Visual Studio Code debugging the model while stopped on a breakpoint to inspect the dim_size variable and call stack.

Along with debuggability, IntelliSense support is front-and-center with ONNX Script. We rely heavily on Python type annotations to enforce correctness of ONNX and to make ONNX more discoverable, including inline documentation on hover tooltips and code completion suggestions. A single click will take you to expanded online documentation in your browser as well.

Screenshot of Visual Studio Code IntelliSense for ONNX Script

A screenshot of Visual Studio Code displaying a hover tooltip for the ONNX Expand operator with inline documentation linking to full online documentation.

What’s next for ONNX Script?

In summary, ONNX Script offers a new Python-first programming model for authoring ONNX models that integrates with the existing rich ecosystem of Python tooling and environments.

Going forward, we envision ONNX Script as a means for defining and extending ONNX itself. New core operators and higher order functions that are intended to become part of the ONNX standard could be authored in ONNX Script as well, reducing the time and effort it takes for the standard to evolve. We have proven this is a viable strategy by developing Torchlib, upon which the upcoming PyTorch Dynamo-based ONNX exporter is built.

Over the coming months, we will also support converting ONNX into ONNX Script to enable seamless editing of existing models, which can play a key role for optimization passes, but also allow for maintaining and evolving ONNX models more naturally. We also intend to propose ONNX Script for inclusion directly within the ONNX GitHub organization soon, under the Linux Foundation umbrella.

Check out ONNX Script today on GitHub or install with pip install git+https://github.com/microsoft/onnxscript. We look forward to your feedback!

Finally, a huge thank you to the wonderful engineering team at Microsoft that has brought us to this point: Bowen Bao, Aaron Bockover, Shubham Bhokare, Jacky Chen, Wei-Sheng Chin, Justin Chu, Thiago Crepaldi, Xavier Dupre, Liqun Fu, Xiaowu Hu, Ganesan Ramalingam, Ti-Tai Wang, Jay Zhang.