Getting started with ONNX IR 🌱
The ONNX IR ships with the ONNX Script package and is available as `onnxscript.ir`.
To create an IR object from an ONNX file, load it as a `ModelProto` and call `ir.from_proto()` or `ir.serde.deserialize_model`:
import pathlib
import onnx
from onnxscript import ir

# Path to the example model, resolved relative to the installed package:
# ir.__file__ is onnxscript/ir/__init__.py, so three ``.parent`` hops reach the
# directory that contains the ``onnxscript`` package. In a source checkout this
# is presumably the repository root (which holds ``testdata/``). When
# onnxscript is installed into site-packages it resolves to site-packages
# itself, where the test data is not present. In that case point
# ``model_path`` at your own ``.onnx`` file.
model_path = (
    pathlib.Path(ir.__file__).parent.parent.parent
    / "testdata"
    / "dort_models"
    / "llama_forward.onnx"
)
if not model_path.is_file():
    # Fail early with guidance instead of a bare error from deep inside onnx.load.
    raise FileNotFoundError(
        f"Example model not found at {model_path}. "
        "Set model_path to an existing .onnx file (the testdata directory is "
        "only available in a source checkout of onnxscript)."
    )

# Load the model as onnx.ModelProto
model_proto = onnx.load(model_path)
# Create an IR object from the model
model = ir.serde.deserialize_model(model_proto)
---------------------------------------------------------------------------
FileNotFoundError Traceback (most recent call last)
Cell In[1], line 8
5 from onnxscript import ir
7 # Load the model as onnx.ModelProto
----> 8 model_proto = onnx.load(
9 pathlib.Path(ir.__file__).parent.parent.parent
10 / "testdata"
11 / "dort_models"
12 / "llama_forward.onnx"
13 )
15 # Create an IR object from the model
16 model = ir.serde.deserialize_model(model_proto)
File /opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/onnx/__init__.py:210, in load_model(f, format, load_external_data)
189 def load_model(
190 f: IO[bytes] | str | os.PathLike,
191 format: _SupportedFormat | None = None, # noqa: A002
192 load_external_data: bool = True,
193 ) -> ModelProto:
194 """Loads a serialized ModelProto into memory.
195
196 Args:
(...)
208 Loaded in-memory ModelProto.
209 """
--> 210 model = _get_serializer(format, f).deserialize_proto(_load_bytes(f), ModelProto())
212 if load_external_data:
213 model_filepath = _get_file_path(f)
File /opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/onnx/__init__.py:147, in _load_bytes(f)
145 else:
146 f = typing.cast(Union[str, os.PathLike], f)
--> 147 with open(f, "rb") as readable:
148 content = readable.read()
149 return content
FileNotFoundError: [Errno 2] No such file or directory: '/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/testdata/dort_models/llama_forward.onnx'
Now we can explore the IR object.
# len() of a graph gives the number of nodes it contains.
node_count = len(model.graph)
print(f"The main graph has {node_count} nodes.")
The main graph has 279 nodes.
All inputs
# The graph's input values, in declaration order.
graph_inputs = model.graph.inputs
print(graph_inputs)
[Input('primals_8', type=Tensor(FLOAT), shape=[2,1,1024,1024], producer=None, index=None), Input('primals_1', type=Tensor(FLOAT), shape=[16,16], producer=None, index=None), Input('primals_6', type=Tensor(FLOAT), shape=[2,1024,16], producer=None, index=None), Input('primals_4', type=Tensor(FLOAT), shape=[16,16], producer=None, index=None), Input('primals_2', type=Tensor(FLOAT), shape=[16,16], producer=None, index=None), Input('primals_3', type=Tensor(FLOAT), shape=[16,16], producer=None, index=None), Input('primals_5', type=Tensor(FLOAT), shape=[4], producer=None, index=None), Input('primals_7', type=Tensor(INT64), shape=[1,1024], producer=None, index=None)]
All outputs
# The graph's output values, in declaration order.
graph_outputs = model.graph.outputs
print(graph_outputs)
[Value('view', type=Tensor(FLOAT), shape=[2048,16], producer=True, index=0), Value('t_6', type=Tensor(FLOAT), shape=[16,16], producer=True, index=0), Value('transpose_8', type=Tensor(FLOAT), shape=[4,8,1024], producer=True, index=0), Value('cat', type=Tensor(FLOAT), shape=[1,1024,8], producer=True, index=0), Value('transpose_9', type=Tensor(FLOAT), shape=[4,8,1024], producer=True, index=0), Value('transpose_10', type=Tensor(FLOAT), shape=[4,1024,8], producer=True, index=0), Value('detach_3', type=Tensor(FLOAT), shape=[2,2,1024,1024], producer=True, index=0), Value('transpose_7', type=Tensor(FLOAT), shape=[4,1024,1024], producer=True, index=0), Value('view_19', type=Tensor(FLOAT), shape=[2048,16], producer=True, index=0), Value('view_20', type=Tensor(FLOAT), shape=[2,1024,16], producer=True, index=0)]
Nodes that use the first input
first_input = model.graph.inputs[0]
# uses() yields (consuming node, input position) pairs; materialize them to print.
print(list(first_input.uses()))
[(Node(name='Slice_83', domain='', op_type='Slice', inputs=(Input('primals_8', type=Tensor(FLOAT), shape=[2,1,1024,1024], producer=None, index=None), Value('_val_11', type=None, shape=None, producer=True, index=0), Value('_val_15', type=None, shape=None, producer=True, index=0), Value('_val_19', type=None, shape=None, producer=True, index=0), Value('_val_23', type=None, shape=None, producer=True, index=0)), attributes=OrderedDict(), overload='', outputs=(Value('slice_8', type=Tensor(FLOAT), shape=[2,1,1024,1024], producer=True, index=0),), version=None, doc_string=''), 0)]
The node that produces the last graph output, and the index of that value among the node's outputs
last_output = model.graph.outputs[-1]
# producer() is the node that computes this value ...
print(last_output.producer())
# ... and index() is the value's position among that node's outputs.
print(last_output.index())
%"view_20"<FLOAT,[2,1024,16]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"mm_3", %"_val_285")
0
Examine a function
# Functions are looked up by a (domain, name, overload) tuple.
function_key = ("pkg.onnxscript.torch_lib", "aten_view", "")
print(model.functions[function_key])
<
opset_imports={'': 18},
>
def pkg.onnxscript.torch_lib::aten_view(
inputs=(
%"self"<?,?>,
%"size"<?,?>
),
outputs=(
%"return_val"<?,?>
),
) {
0 | # n0
%"size_0"<?,?> ⬅️ ::Cast(%"size") {to=7}
1 | # n1
%"return_val"<?,?> ⬅️ ::Reshape(%"self", %"size_0")
return %"return_val"<?,?>
}
Print the graph
# Set page=True to use a pager in the terminal so long outputs are scrollable
model.graph.display(page=False)
graph( name=main_graph, inputs=( %"primals_8"<FLOAT,[2,1,1024,1024]>, %"primals_1"<FLOAT,[16,16]>, %"primals_6"<FLOAT,[2,1024,16]>, %"primals_4"<FLOAT,[16,16]>, %"primals_2"<FLOAT,[16,16]>, %"primals_3"<FLOAT,[16,16]>, %"primals_5"<FLOAT,[4]>, %"primals_7"<INT64,[1,1024]> ), outputs=( %"view"<FLOAT,[2048,16]>, %"t_6"<FLOAT,[16,16]>, %"transpose_8"<FLOAT,[4,8,1024]>, %"cat"<FLOAT,[1,1024,8]>, %"transpose_9"<FLOAT,[4,8,1024]>, %"transpose_10"<FLOAT,[4,1024,8]>, %"detach_3"<FLOAT,[2,2,1024,1024]>, %"transpose_7"<FLOAT,[4,1024,1024]>, %"view_19"<FLOAT,[2048,16]>, %"view_20"<FLOAT,[2,1024,16]> ), ) { 0 | # Constant_67 %"_val_8"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 1 | # Cast_68 %"_val_9"<?,?> ⬅️ ::Cast(%"_val_8") {to=7} 2 | # Constant_69 %"_val_10"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 3 | # Reshape_70 %"_val_11"<?,?> ⬅️ ::Reshape(%"_val_9", %"_val_10") {allowzero=0} 4 | # Constant_71 %"_val_12"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 5 | # Cast_72 %"_val_13"<?,?> ⬅️ ::Cast(%"_val_12") {to=7} 6 | # Constant_73 %"_val_14"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 7 | # Reshape_74 %"_val_15"<?,?> ⬅️ ::Reshape(%"_val_13", %"_val_14") {allowzero=0} 8 | # Constant_75 %"_val_16"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 9 | # Cast_76 %"_val_17"<?,?> ⬅️ ::Cast(%"_val_16") {to=7} 10 | # Constant_77 %"_val_18"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 11 | # Reshape_78 %"_val_19"<?,?> ⬅️ ::Reshape(%"_val_17", %"_val_18") {allowzero=0} 12 | # Constant_79 %"_val_20"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 13 | # Cast_80 %"_val_21"<?,?> ⬅️ ::Cast(%"_val_20") {to=7} 14 | # Constant_81 %"_val_22"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 15 | # Reshape_82 %"_val_23"<?,?> ⬅️ ::Reshape(%"_val_21", %"_val_22") {allowzero=0} 16 | # Slice_83 %"slice_8"<FLOAT,[2,1,1024,1024]> ⬅️ ::Slice(%"primals_8", %"_val_11", %"_val_15", %"_val_19", %"_val_23") 17 | # aten_t_84 %"t"<FLOAT,[16,16]> ⬅️ 
pkg.onnxscript.torch_lib::aten_t(%"primals_1") 18 | # Constant_85 %"_val_26"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[2]>(name='')} 19 | # aten_view_86 %"view"<FLOAT,[2048,16]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"primals_6", %"_val_26") 20 | # aten_t_87 %"t_3"<FLOAT,[16,16]> ⬅️ pkg.onnxscript.torch_lib::aten_t(%"primals_4") 21 | # aten_t_88 %"t_1"<FLOAT,[16,16]> ⬅️ pkg.onnxscript.torch_lib::aten_t(%"primals_2") 22 | # aten_t_89 %"t_2"<FLOAT,[16,16]> ⬅️ pkg.onnxscript.torch_lib::aten_t(%"primals_3") 23 | # aten_unsqueeze_90 %"unsqueeze"<FLOAT,[1,4]> ⬅️ pkg.onnxscript.torch_lib::aten_unsqueeze(%"primals_5") {dim=0} 24 | # Constant_91 %"_val_32"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 25 | # Cast_92 %"_val_33"<?,?> ⬅️ ::Cast(%"_val_32") {to=7} 26 | # Constant_93 %"_val_34"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 27 | # Reshape_94 %"_val_35"<?,?> ⬅️ ::Reshape(%"_val_33", %"_val_34") {allowzero=0} 28 | # Constant_95 %"_val_36"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 29 | # Cast_96 %"_val_37"<?,?> ⬅️ ::Cast(%"_val_36") {to=7} 30 | # Constant_97 %"_val_38"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 31 | # Reshape_98 %"_val_39"<?,?> ⬅️ ::Reshape(%"_val_37", %"_val_38") {allowzero=0} 32 | # Constant_99 %"_val_40"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 33 | # Cast_100 %"_val_41"<?,?> ⬅️ ::Cast(%"_val_40") {to=7} 34 | # Constant_101 %"_val_42"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 35 | # Reshape_102 %"_val_43"<?,?> ⬅️ ::Reshape(%"_val_41", %"_val_42") {allowzero=0} 36 | # Constant_103 %"_val_44"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 37 | # Cast_104 %"_val_45"<?,?> ⬅️ ::Cast(%"_val_44") {to=7} 38 | # Constant_105 %"_val_46"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 39 | # Reshape_106 %"_val_47"<?,?> ⬅️ ::Reshape(%"_val_45", %"_val_46") {allowzero=0} 40 | # Slice_107 %"slice_2"<INT64,[1,1024]> ⬅️ ::Slice(%"primals_7", %"_val_35", %"_val_39", %"_val_43", 
%"_val_47") 41 | # Constant_108 %"_val_49"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 42 | # Cast_109 %"_val_50"<?,?> ⬅️ ::Cast(%"_val_49") {to=7} 43 | # Constant_110 %"_val_51"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 44 | # Reshape_111 %"_val_52"<?,?> ⬅️ ::Reshape(%"_val_50", %"_val_51") {allowzero=0} 45 | # Constant_112 %"_val_53"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 46 | # Cast_113 %"_val_54"<?,?> ⬅️ ::Cast(%"_val_53") {to=7} 47 | # Constant_114 %"_val_55"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 48 | # Reshape_115 %"_val_56"<?,?> ⬅️ ::Reshape(%"_val_54", %"_val_55") {allowzero=0} 49 | # Constant_116 %"_val_57"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 50 | # Cast_117 %"_val_58"<?,?> ⬅️ ::Cast(%"_val_57") {to=7} 51 | # Constant_118 %"_val_59"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 52 | # Reshape_119 %"_val_60"<?,?> ⬅️ ::Reshape(%"_val_58", %"_val_59") {allowzero=0} 53 | # Constant_120 %"_val_61"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 54 | # Cast_121 %"_val_62"<?,?> ⬅️ ::Cast(%"_val_61") {to=7} 55 | # Constant_122 %"_val_63"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 56 | # Reshape_123 %"_val_64"<?,?> ⬅️ ::Reshape(%"_val_62", %"_val_63") {allowzero=0} 57 | # Slice_124 %"slice_9"<FLOAT,[2,1,1024,1024]> ⬅️ ::Slice(%"slice_8", %"_val_52", %"_val_56", %"_val_60", %"_val_64") 58 | # aten_mm_125 %"mm"<FLOAT,[2048,16]> ⬅️ pkg.onnxscript.torch_lib::aten_mm(%"view", %"t") 59 | # aten_t_126 %"t_6"<FLOAT,[16,16]> ⬅️ pkg.onnxscript.torch_lib::aten_t(%"t_3") 60 | # aten_mm_127 %"mm_1"<FLOAT,[2048,16]> ⬅️ pkg.onnxscript.torch_lib::aten_mm(%"view", %"t_1") 61 | # aten_mm_128 %"mm_2"<FLOAT,[2048,16]> ⬅️ pkg.onnxscript.torch_lib::aten_mm(%"view", %"t_2") 62 | # Constant_129 %"_val_70"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 63 | # Cast_130 %"_val_71"<?,?> ⬅️ ::Cast(%"_val_70") {to=7} 64 | # Constant_131 %"_val_72"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 65 | # 
Reshape_132 %"_val_73"<?,?> ⬅️ ::Reshape(%"_val_71", %"_val_72") {allowzero=0} 66 | # Constant_133 %"_val_74"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 67 | # Cast_134 %"_val_75"<?,?> ⬅️ ::Cast(%"_val_74") {to=7} 68 | # Constant_135 %"_val_76"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 69 | # Reshape_136 %"_val_77"<?,?> ⬅️ ::Reshape(%"_val_75", %"_val_76") {allowzero=0} 70 | # Constant_137 %"_val_78"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 71 | # Cast_138 %"_val_79"<?,?> ⬅️ ::Cast(%"_val_78") {to=7} 72 | # Constant_139 %"_val_80"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 73 | # Reshape_140 %"_val_81"<?,?> ⬅️ ::Reshape(%"_val_79", %"_val_80") {allowzero=0} 74 | # Constant_141 %"_val_82"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 75 | # Cast_142 %"_val_83"<?,?> ⬅️ ::Cast(%"_val_82") {to=7} 76 | # Constant_143 %"_val_84"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 77 | # Reshape_144 %"_val_85"<?,?> ⬅️ ::Reshape(%"_val_83", %"_val_84") {allowzero=0} 78 | # Slice_145 %"slice_1"<FLOAT,[1,4]> ⬅️ ::Slice(%"unsqueeze", %"_val_73", %"_val_77", %"_val_81", %"_val_85") 79 | # aten_unsqueeze_146 %"unsqueeze_2"<INT64,[1,1,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_unsqueeze(%"slice_2") {dim=1} 80 | # Constant_147 %"_val_88"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 81 | # Cast_148 %"_val_89"<?,?> ⬅️ ::Cast(%"_val_88") {to=7} 82 | # Constant_149 %"_val_90"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 83 | # Reshape_150 %"_val_91"<?,?> ⬅️ ::Reshape(%"_val_89", %"_val_90") {allowzero=0} 84 | # Constant_151 %"_val_92"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 85 | # Cast_152 %"_val_93"<?,?> ⬅️ ::Cast(%"_val_92") {to=7} 86 | # Constant_153 %"_val_94"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 87 | # Reshape_154 %"_val_95"<?,?> ⬅️ ::Reshape(%"_val_93", %"_val_94") {allowzero=0} 88 | # Constant_155 %"_val_96"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 89 | # Cast_156 
%"_val_97"<?,?> ⬅️ ::Cast(%"_val_96") {to=7} 90 | # Constant_157 %"_val_98"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 91 | # Reshape_158 %"_val_99"<?,?> ⬅️ ::Reshape(%"_val_97", %"_val_98") {allowzero=0} 92 | # Constant_159 %"_val_100"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 93 | # Cast_160 %"_val_101"<?,?> ⬅️ ::Cast(%"_val_100") {to=7} 94 | # Constant_161 %"_val_102"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 95 | # Reshape_162 %"_val_103"<?,?> ⬅️ ::Reshape(%"_val_101", %"_val_102") {allowzero=0} 96 | # Slice_163 %"slice_10"<FLOAT,[2,1,1024,1024]> ⬅️ ::Slice(%"slice_9", %"_val_91", %"_val_95", %"_val_99", %"_val_103") 97 | # Constant_164 %"_val_105"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[3]>(name='')} 98 | # aten_view_165 %"view_1"<FLOAT,[2,1024,16]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"mm", %"_val_105") 99 | # Constant_166 %"_val_107"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[3]>(name='')} 100 | # aten_view_167 %"view_3"<FLOAT,[2,1024,16]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"mm_1", %"_val_107") 101 | # Constant_168 %"_val_109"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[3]>(name='')} 102 | # aten_view_169 %"view_5"<FLOAT,[2,1024,16]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"mm_2", %"_val_109") 103 | # aten_unsqueeze_170 %"unsqueeze_1"<FLOAT,[1,4,1]> ⬅️ pkg.onnxscript.torch_lib::aten_unsqueeze(%"slice_1") {dim=2} 104 | # Constant_171 %"_val_112"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 105 | # Cast_172 %"_val_113"<?,?> ⬅️ ::Cast(%"_val_112") {to=7} 106 | # Constant_173 %"_val_114"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 107 | # Reshape_174 %"_val_115"<?,?> ⬅️ ::Reshape(%"_val_113", %"_val_114") {allowzero=0} 108 | # Constant_175 %"_val_116"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 109 | # Cast_176 %"_val_117"<?,?> ⬅️ ::Cast(%"_val_116") {to=7} 110 | # Constant_177 %"_val_118"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 111 | # Reshape_178 %"_val_119"<?,?> ⬅️ 
::Reshape(%"_val_117", %"_val_118") {allowzero=0} 112 | # Constant_179 %"_val_120"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 113 | # Cast_180 %"_val_121"<?,?> ⬅️ ::Cast(%"_val_120") {to=7} 114 | # Constant_181 %"_val_122"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 115 | # Reshape_182 %"_val_123"<?,?> ⬅️ ::Reshape(%"_val_121", %"_val_122") {allowzero=0} 116 | # Constant_183 %"_val_124"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 117 | # Cast_184 %"_val_125"<?,?> ⬅️ ::Cast(%"_val_124") {to=7} 118 | # Constant_185 %"_val_126"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 119 | # Reshape_186 %"_val_127"<?,?> ⬅️ ::Reshape(%"_val_125", %"_val_126") {allowzero=0} 120 | # Slice_187 %"slice_3"<INT64,[1,1,1024]> ⬅️ ::Slice(%"unsqueeze_2", %"_val_115", %"_val_119", %"_val_123", %"_val_127") 121 | # Constant_188 %"_val_129"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[4]>(name='')} 122 | # aten_view_189 %"view_6"<FLOAT,[2,1024,2,8]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"view_1", %"_val_129") 123 | # Constant_190 %"_val_131"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[4]>(name='')} 124 | # aten_view_191 %"view_7"<FLOAT,[2,1024,2,8]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"view_3", %"_val_131") 125 | # Constant_192 %"_val_133"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[4]>(name='')} 126 | # aten_view_193 %"view_8"<FLOAT,[2,1024,2,8]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"view_5", %"_val_133") 127 | # Constant_194 %"_val_135"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[3]>(name='')} 128 | # aten_expand_195 %"expand"<FLOAT,[1,4,1]> ⬅️ pkg.onnxscript.torch_lib::aten_expand(%"unsqueeze_1", %"_val_135") 129 | # Cast_196 %"_to_copy"<FLOAT,[1,1,1024]> ⬅️ ::Cast(%"slice_3") {to=1} 130 | # Transpose_197 %"transpose"<FLOAT,[2,2,1024,8]> ⬅️ ::Transpose(%"view_6") {perm=[0, 2, 1, 3]} 131 | # Transpose_198 %"transpose_1"<FLOAT,[2,2,1024,8]> ⬅️ ::Transpose(%"view_7") {perm=[0, 2, 1, 3]} 132 | # Transpose_199 
%"transpose_2"<FLOAT,[2,2,1024,8]> ⬅️ ::Transpose(%"view_8") {perm=[0, 2, 1, 3]} 133 | # Constant_200 %"_val_141"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[3]>(name='')} 134 | # aten_expand_201 %"expand_1"<FLOAT,[1,4,1]> ⬅️ pkg.onnxscript.torch_lib::aten_expand(%"expand", %"_val_141") 135 | # Constant_202 %"_val_143"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[3]>(name='')} 136 | # aten_expand_203 %"expand_2"<FLOAT,[1,1,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_expand(%"_to_copy", %"_val_143") 137 | # Constant_204 %"_val_145"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 138 | # Cast_205 %"_val_146"<?,?> ⬅️ ::Cast(%"_val_145") {to=7} 139 | # Constant_206 %"_val_147"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 140 | # Reshape_207 %"_val_148"<?,?> ⬅️ ::Reshape(%"_val_146", %"_val_147") {allowzero=0} 141 | # Constant_208 %"_val_149"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 142 | # Cast_209 %"_val_150"<?,?> ⬅️ ::Cast(%"_val_149") {to=7} 143 | # Constant_210 %"_val_151"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 144 | # Reshape_211 %"_val_152"<?,?> ⬅️ ::Reshape(%"_val_150", %"_val_151") {allowzero=0} 145 | # Constant_212 %"_val_153"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 146 | # Cast_213 %"_val_154"<?,?> ⬅️ ::Cast(%"_val_153") {to=7} 147 | # Constant_214 %"_val_155"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 148 | # Reshape_215 %"_val_156"<?,?> ⬅️ ::Reshape(%"_val_154", %"_val_155") {allowzero=0} 149 | # Constant_216 %"_val_157"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 150 | # Cast_217 %"_val_158"<?,?> ⬅️ ::Cast(%"_val_157") {to=7} 151 | # Constant_218 %"_val_159"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 152 | # Reshape_219 %"_val_160"<?,?> ⬅️ ::Reshape(%"_val_158", %"_val_159") {allowzero=0} 153 | # Slice_220 %"slice_4"<FLOAT,[2,2,1024,4]> ⬅️ ::Slice(%"transpose", %"_val_148", %"_val_152", %"_val_156", %"_val_160") 154 | # Constant_221 %"_val_162"<?,?> ⬅️ ::Constant() 
{value=TensorProtoTensor<INT64,[]>(name='')} 155 | # Cast_222 %"_val_163"<?,?> ⬅️ ::Cast(%"_val_162") {to=7} 156 | # Constant_223 %"_val_164"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 157 | # Reshape_224 %"_val_165"<?,?> ⬅️ ::Reshape(%"_val_163", %"_val_164") {allowzero=0} 158 | # Constant_225 %"_val_166"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 159 | # Cast_226 %"_val_167"<?,?> ⬅️ ::Cast(%"_val_166") {to=7} 160 | # Constant_227 %"_val_168"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 161 | # Reshape_228 %"_val_169"<?,?> ⬅️ ::Reshape(%"_val_167", %"_val_168") {allowzero=0} 162 | # Constant_229 %"_val_170"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 163 | # Cast_230 %"_val_171"<?,?> ⬅️ ::Cast(%"_val_170") {to=7} 164 | # Constant_231 %"_val_172"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 165 | # Reshape_232 %"_val_173"<?,?> ⬅️ ::Reshape(%"_val_171", %"_val_172") {allowzero=0} 166 | # Constant_233 %"_val_174"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 167 | # Cast_234 %"_val_175"<?,?> ⬅️ ::Cast(%"_val_174") {to=7} 168 | # Constant_235 %"_val_176"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 169 | # Reshape_236 %"_val_177"<?,?> ⬅️ ::Reshape(%"_val_175", %"_val_176") {allowzero=0} 170 | # Slice_237 %"slice_5"<FLOAT,[2,2,1024,4]> ⬅️ ::Slice(%"transpose", %"_val_165", %"_val_169", %"_val_173", %"_val_177") 171 | # Constant_238 %"_val_179"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 172 | # Cast_239 %"_val_180"<?,?> ⬅️ ::Cast(%"_val_179") {to=7} 173 | # Constant_240 %"_val_181"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 174 | # Reshape_241 %"_val_182"<?,?> ⬅️ ::Reshape(%"_val_180", %"_val_181") {allowzero=0} 175 | # Constant_242 %"_val_183"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 176 | # Cast_243 %"_val_184"<?,?> ⬅️ ::Cast(%"_val_183") {to=7} 177 | # Constant_244 %"_val_185"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 178 | # Reshape_245 %"_val_186"<?,?> ⬅️ ::Reshape(%"_val_184", 
%"_val_185") {allowzero=0} 179 | # Constant_246 %"_val_187"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 180 | # Cast_247 %"_val_188"<?,?> ⬅️ ::Cast(%"_val_187") {to=7} 181 | # Constant_248 %"_val_189"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 182 | # Reshape_249 %"_val_190"<?,?> ⬅️ ::Reshape(%"_val_188", %"_val_189") {allowzero=0} 183 | # Constant_250 %"_val_191"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 184 | # Cast_251 %"_val_192"<?,?> ⬅️ ::Cast(%"_val_191") {to=7} 185 | # Constant_252 %"_val_193"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 186 | # Reshape_253 %"_val_194"<?,?> ⬅️ ::Reshape(%"_val_192", %"_val_193") {allowzero=0} 187 | # Slice_254 %"slice_6"<FLOAT,[2,2,1024,4]> ⬅️ ::Slice(%"transpose_1", %"_val_182", %"_val_186", %"_val_190", %"_val_194") 188 | # Constant_255 %"_val_196"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 189 | # Cast_256 %"_val_197"<?,?> ⬅️ ::Cast(%"_val_196") {to=7} 190 | # Constant_257 %"_val_198"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 191 | # Reshape_258 %"_val_199"<?,?> ⬅️ ::Reshape(%"_val_197", %"_val_198") {allowzero=0} 192 | # Constant_259 %"_val_200"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 193 | # Cast_260 %"_val_201"<?,?> ⬅️ ::Cast(%"_val_200") {to=7} 194 | # Constant_261 %"_val_202"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 195 | # Reshape_262 %"_val_203"<?,?> ⬅️ ::Reshape(%"_val_201", %"_val_202") {allowzero=0} 196 | # Constant_263 %"_val_204"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 197 | # Cast_264 %"_val_205"<?,?> ⬅️ ::Cast(%"_val_204") {to=7} 198 | # Constant_265 %"_val_206"<?,?> ⬅️ ::Constant() {value_ints=[-1]} 199 | # Reshape_266 %"_val_207"<?,?> ⬅️ ::Reshape(%"_val_205", %"_val_206") {allowzero=0} 200 | # Constant_267 %"_val_208"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')} 201 | # Cast_268 %"_val_209"<?,?> ⬅️ ::Cast(%"_val_208") {to=7} 202 | # Constant_269 %"_val_210"<?,?> ⬅️ ::Constant() 
{value_ints=[-1]} 203 | # Reshape_270 %"_val_211"<?,?> ⬅️ ::Reshape(%"_val_209", %"_val_210") {allowzero=0} 204 | # Slice_271 %"slice_7"<FLOAT,[2,2,1024,4]> ⬅️ ::Slice(%"transpose_1", %"_val_199", %"_val_203", %"_val_207", %"_val_211") 205 | # Constant_272 %"_val_213"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[4]>(name='')} 206 | # aten_expand_273 %"expand_6"<FLOAT,[2,2,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_expand(%"transpose_2", %"_val_213") 207 | # Constant_274 %"_val_215"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[3]>(name='')} 208 | # aten_view_275 %"view_9"<FLOAT,[1,4,1]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"expand_1", %"_val_215") 209 | # Constant_276 %"_val_217"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[3]>(name='')} 210 | # aten_view_277 %"view_10"<FLOAT,[1,1,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"expand_2", %"_val_217") 211 | # aten_neg_278 %"neg"<FLOAT,[2,2,1024,4]> ⬅️ pkg.onnxscript.torch_lib::aten_neg(%"slice_5") 212 | # aten_neg_279 %"neg_1"<FLOAT,[2,2,1024,4]> ⬅️ pkg.onnxscript.torch_lib::aten_neg(%"slice_7") 213 | # aten_clone_280 %"clone_3"<FLOAT,[2,2,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_clone(%"expand_6") {memory_format=} 214 | # aten_bmm_281 %"bmm"<FLOAT,[1,4,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_bmm(%"view_9", %"view_10") 215 | # SequenceConstruct_282 %"223"<?,?> ⬅️ ::SequenceConstruct(%"neg", %"slice_4") 216 | # aten_cat_283 %"cat_1"<FLOAT,[2,2,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_cat(%"223") {dim=-1} 217 | # SequenceConstruct_284 %"225"<?,?> ⬅️ ::SequenceConstruct(%"neg_1", %"slice_6") 218 | # aten_cat_285 %"cat_2"<FLOAT,[2,2,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_cat(%"225") {dim=-1} 219 | # Constant_286 %"_val_227"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[3]>(name='')} 220 | # aten_view_287 %"view_16"<FLOAT,[4,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"clone_3", %"_val_227") 221 | # Constant_288 %"_val_229"<?,?> ⬅️ ::Constant() 
{value=TensorProtoTensor<INT64,[3]>(name='')} 222 | # aten_view_289 %"view_11"<FLOAT,[1,4,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"bmm", %"_val_229") 223 | # Transpose_290 %"transpose_8"<FLOAT,[4,8,1024]> ⬅️ ::Transpose(%"view_16") {perm=[0, 2, 1]} 224 | # Transpose_291 %"transpose_3"<FLOAT,[1,1024,4]> ⬅️ ::Transpose(%"view_11") {perm=[0, 2, 1]} 225 | # SequenceConstruct_292 %"233"<?,?> ⬅️ ::SequenceConstruct(%"transpose_3", %"transpose_3") 226 | # aten_cat_293 %"cat"<FLOAT,[1,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_cat(%"233") {dim=-1} 227 | # aten_cos_294 %"cos"<FLOAT,[1,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_cos(%"cat") 228 | # aten_sin_295 %"sin"<FLOAT,[1,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_sin(%"cat") 229 | # aten_unsqueeze_296 %"unsqueeze_3"<FLOAT,[1,1,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_unsqueeze(%"cos") {dim=1} 230 | # aten_unsqueeze_297 %"unsqueeze_4"<FLOAT,[1,1,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_unsqueeze(%"sin") {dim=1} 231 | # aten_mul_298 %"mul"<FLOAT,[2,2,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_mul(%"transpose", %"unsqueeze_3") 232 | # aten_mul_299 %"mul_2"<FLOAT,[2,2,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_mul(%"transpose_1", %"unsqueeze_3") 233 | # aten_mul_300 %"mul_1"<FLOAT,[2,2,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_mul(%"cat_1", %"unsqueeze_4") 234 | # aten_mul_301 %"mul_3"<FLOAT,[2,2,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_mul(%"cat_2", %"unsqueeze_4") 235 | # aten_add_302 %"add"<FLOAT,[2,2,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_add(%"mul", %"mul_1") {alpha=1.0} 236 | # aten_add_303 %"add_1"<FLOAT,[2,2,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_add(%"mul_2", %"mul_3") {alpha=1.0} 237 | # Constant_304 %"_val_245"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[4]>(name='')} 238 | # aten_expand_305 %"expand_3"<FLOAT,[2,2,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_expand(%"add", %"_val_245") 239 | # Transpose_306 %"transpose_4"<FLOAT,[2,2,8,1024]> ⬅️ ::Transpose(%"add_1") {perm=[0, 1, 3, 
2]} 240 | # aten_clone_307 %"clone"<FLOAT,[2,2,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_clone(%"expand_3") {memory_format=} 241 | # Constant_308 %"_val_249"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[4]>(name='')} 242 | # aten_expand_309 %"expand_4"<FLOAT,[2,2,8,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_expand(%"transpose_4", %"_val_249") 243 | # Constant_310 %"_val_251"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[3]>(name='')} 244 | # aten_view_311 %"view_12"<FLOAT,[4,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"clone", %"_val_251") 245 | # aten_clone_312 %"clone_1"<FLOAT,[2,2,8,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_clone(%"expand_4") {memory_format=} 246 | # Transpose_313 %"transpose_9"<FLOAT,[4,8,1024]> ⬅️ ::Transpose(%"view_12") {perm=[0, 2, 1]} 247 | # Constant_314 %"_val_255"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[3]>(name='')} 248 | # aten_view_315 %"view_13"<FLOAT,[4,8,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"clone_1", %"_val_255") 249 | # aten_bmm_316 %"bmm_1"<FLOAT,[4,1024,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_bmm(%"view_12", %"view_13") 250 | # Transpose_317 %"transpose_10"<FLOAT,[4,1024,8]> ⬅️ ::Transpose(%"view_13") {perm=[0, 2, 1]} 251 | # Constant_318 %"_val_259"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[4]>(name='')} 252 | # aten_view_319 %"view_14"<FLOAT,[2,2,1024,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"bmm_1", %"_val_259") 253 | # Constant_320 %"_val_261"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<FLOAT,[]>(name='')} 254 | # aten_div_321 %"div"<FLOAT,[2,2,1024,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_div(%"view_14", %"_val_261") 255 | # aten_add_322 %"add_2"<FLOAT,[2,2,1024,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_add(%"div", %"slice_10") {alpha=1.0} 256 | # aten_softmax_no_dtype_323 %"_softmax"<FLOAT,[2,2,1024,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_softmax_no_dtype(%"add_2") {dim=-1} 257 | # aten_detach_324 %"detach"<FLOAT,[2,2,1024,1024]> ⬅️ 
pkg.onnxscript.torch_lib::aten_detach(%"_softmax") 258 | # aten_clone_325 %"clone_2"<FLOAT,[2,2,1024,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_clone(%"_softmax") {memory_format=} 259 | # aten_detach_326 %"detach_1"<FLOAT,[2,2,1024,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_detach(%"detach") 260 | # Constant_327 %"_val_268"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[4]>(name='')} 261 | # aten_expand_328 %"expand_5"<FLOAT,[2,2,1024,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_expand(%"clone_2", %"_val_268") 262 | # aten_detach_329 %"detach_2"<FLOAT,[2,2,1024,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_detach(%"detach_1") 263 | # Constant_330 %"_val_271"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[3]>(name='')} 264 | # aten_view_331 %"view_15"<FLOAT,[4,1024,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"expand_5", %"_val_271") 265 | # aten_detach_332 %"detach_3"<FLOAT,[2,2,1024,1024]> ⬅️ pkg.onnxscript.torch_lib::aten_detach(%"detach_2") 266 | # aten_bmm_333 %"bmm_2"<FLOAT,[4,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_bmm(%"view_15", %"view_16") 267 | # Transpose_334 %"transpose_7"<FLOAT,[4,1024,1024]> ⬅️ ::Transpose(%"view_15") {perm=[0, 2, 1]} 268 | # Constant_335 %"_val_276"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[4]>(name='')} 269 | # aten_view_336 %"view_17"<FLOAT,[2,2,1024,8]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"bmm_2", %"_val_276") 270 | # Transpose_337 %"transpose_5"<FLOAT,[2,1024,2,8]> ⬅️ ::Transpose(%"view_17") {perm=[0, 2, 1, 3]} 271 | # aten_clone_338 %"clone_4"<FLOAT,[2,1024,2,8]> ⬅️ pkg.onnxscript.torch_lib::aten_clone(%"transpose_5") {memory_format=} 272 | # Constant_339 %"_val_280"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[3]>(name='')} 273 | # aten_view_340 %"view_18"<FLOAT,[2,1024,16]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"clone_4", %"_val_280") 274 | # Constant_341 %"_val_282"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[2]>(name='')} 275 | # aten_view_342 %"view_19"<FLOAT,[2048,16]> ⬅️ 
pkg.onnxscript.torch_lib::aten_view(%"view_18", %"_val_282") 276 | # aten_mm_343 %"mm_3"<FLOAT,[2048,16]> ⬅️ pkg.onnxscript.torch_lib::aten_mm(%"view_19", %"t_3") 277 | # Constant_344 %"_val_285"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[3]>(name='')} 278 | # aten_view_345 %"view_20"<FLOAT,[2,1024,16]> ⬅️ pkg.onnxscript.torch_lib::aten_view(%"mm_3", %"_val_285") return %"view"<FLOAT,[2048,16]>, %"t_6"<FLOAT,[16,16]>, %"transpose_8"<FLOAT,[4,8,1024]>, %"cat"<FLOAT,[1,1024,8]>, %"transpose_9"<FLOAT,[4,8,1024]>, %"transpose_10"<FLOAT,[4,1024,8]>, %"detach_3"<FLOAT,[2,2,1024,1024]>, %"transpose_7"<FLOAT,[4,1024,1024]>, %"view_19"<FLOAT,[2048,16]>, %"view_20"<FLOAT,[2,1024,16]> }
Convert from the IR object back to ModelProto
# Convert the IR object back to an onnx.ModelProto.
model_proto_back = ir.serde.serialize_model(model)
Next steps
Read the introductions for a more detailed overview of the IR (documentation in progress 🚧).