Source code for onnxscript.testing

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

__all__ = [
    "assert_isomorphic",
    "assert_isomorphic_graph",
    "assert_isomorphic_function",
    "assert_onnx_proto_equal",
]

import difflib
import math
from typing import Any, Collection, Sequence

import google.protobuf.message
import onnx
from onnx import parser

import onnxscript


def assert_isomorphic(graph_or_function_1, graph_or_function_2):
    """Assert two graphs or functions are isomorphic.

    Both arguments are first normalized with ``_to_function_or_graph`` and then
    compared with ``_isomorphic``.

    Args:
        graph_or_function_1: The first graph or function.
        graph_or_function_2: The second graph or function.

    Raises:
        AssertionError: If the two inputs are not isomorphic.
        TypeError: If an input cannot be converted, or if one input is a
            function and the other is a graph.
    """
    fg1 = _to_function_or_graph(graph_or_function_1)
    fg2 = _to_function_or_graph(graph_or_function_2)
    # Raise explicitly instead of using a bare `assert`: assert statements are
    # removed when Python runs with -O, which would silently disable the check.
    if not _isomorphic(fg1, fg2):
        raise AssertionError("The two graphs or functions are not isomorphic.")
def assert_isomorphic_graph(graph1, graph2):
    """Assert two graphs are isomorphic.

    Both arguments are first normalized with ``_to_graph_proto`` and then
    compared with ``_isomorphic``.

    Args:
        graph1: The first graph.
        graph2: The second graph.

    Raises:
        AssertionError: If the two graphs are not isomorphic.
        TypeError: If an input cannot be converted to a GraphProto.
    """
    proto1 = _to_graph_proto(graph1)
    proto2 = _to_graph_proto(graph2)
    # Raise explicitly instead of using a bare `assert`: assert statements are
    # removed when Python runs with -O, which would silently disable the check.
    if not _isomorphic(proto1, proto2):
        raise AssertionError("The two graphs are not isomorphic.")
def assert_isomorphic_function(fn1, fn2):
    """Assert two functions are isomorphic.

    Both arguments are first normalized with ``_to_function_proto`` and then
    compared with ``_isomorphic``.

    Args:
        fn1: The first function.
        fn2: The second function.

    Raises:
        AssertionError: If the two functions are not isomorphic.
        TypeError: If an input cannot be converted to a FunctionProto.
    """
    proto1 = _to_function_proto(fn1)
    proto2 = _to_function_proto(fn2)
    # Raise explicitly instead of using a bare `assert`: assert statements are
    # removed when Python runs with -O, which would silently disable the check.
    if not _isomorphic(proto1, proto2):
        raise AssertionError("The two functions are not isomorphic.")
def _default_equality_op(x, y):
    """Default comparison used by the `_same_*` helpers: plain `==`."""
    return x == y


def _same_optional(field, obj1, obj2, equals=_default_equality_op):
    """Check two proto object have same value for optional field.

    This is restricted to simple field types where == comparison is sufficient,
    unless a custom `equals` callable is supplied.

    Returns True when the field is unset on both objects, or set on both with
    values that `equals` considers the same.
    """
    if obj1.HasField(field):
        return obj2.HasField(field) and equals(getattr(obj1, field), getattr(obj2, field))
    return not obj2.HasField(field)


def _same_repeated(values1, values2, equals=_default_equality_op):
    """Check two sequences have the same length and pairwise-`equals` elements."""
    if len(values1) != len(values2):
        return False
    return all(equals(val1, val2) for val1, val2 in zip(values1, values2))


def _same_string_string_map(proto1, proto2):
    """Compare repeated StringStringEntryProto as maps."""

    def to_map(proto):
        return {x.key: x.value for x in proto}

    return to_map(proto1) == to_map(proto2)


def _same_tensor(tp1, tp2):
    """Check two tensor protos carry the same data (names are ignored)."""
    if tp1.dims != tp2.dims:
        return False
    if not _same_optional("data_type", tp1, tp2):
        return False
    # Segmented representation not supported yet
    if tp1.HasField("segment") or tp2.HasField("segment"):
        return False
    if tp1.float_data != tp2.float_data:
        return False
    if tp1.int32_data != tp2.int32_data:
        return False
    if tp1.string_data != tp2.string_data:
        return False
    if tp1.int64_data != tp2.int64_data:
        return False
    if tp1.uint64_data != tp2.uint64_data:
        return False
    if tp1.double_data != tp2.double_data:
        return False
    # Tensor content may be stored in raw_data instead of the typed *_data
    # fields; without this check tensors differing only there compare equal.
    if not _same_optional("raw_data", tp1, tp2):
        return False
    # Ignore name for comparison:
    # if not _same_optional("name", tp1, tp2): return False
    if not _same_optional("doc_string", tp1, tp2):
        return False
    if not _same_optional("data_location", tp1, tp2):
        return False
    if not _same_string_string_map(tp1.external_data, tp2.external_data):
        return False
    return True


def _same_dim(dim1, dim2):
    """Check two shape dimensions agree on both `dim_value` and `dim_param`."""
    return _same_optional("dim_value", dim1, dim2) and _same_optional("dim_param", dim1, dim2)


def _same_shape(shape1, shape2):
    """Check two shapes have the same dimensions, in order."""
    return _same_repeated(shape1.dim, shape2.dim, _same_dim)


def _same_tensor_type(tt1, tt2):
    """Check two tensor types have the same element type and (optional) shape."""
    return (tt1.elem_type == tt2.elem_type) and _same_optional("shape", tt1, tt2, _same_shape)


def _same_type(tp1, tp2):
    """Check two type protos are the same."""
    # Handles only tensor type at this point.
    return _same_optional("tensor_type", tp1, tp2, _same_tensor_type)


def _same_value_info(vi1, vi2):
    """Check two value infos agree on name, type and doc_string."""
    return (
        _same_optional("name", vi1, vi2)
        and _same_optional("type", vi1, vi2, _same_type)
        and _same_optional("doc_string", vi1, vi2)
    )


def _same_attr(attr1, attr2, graph_equality):
    """Check two attribute protos hold the same value.

    `graph_equality` is the callable used to compare graph-valued attributes.
    Attributes using `sparse_tensor` or `tp` are never considered equal.
    """
    # no name check; names used to match attributes already.
    for field in ["type", "ref_attr_name", "f", "i", "s"]:
        if not _same_optional(field, attr1, attr2):
            return False

    if not _same_optional("t", attr1, attr2, _same_tensor):
        return False

    if not _same_repeated(attr1.tensors, attr2.tensors, _same_tensor):
        return False

    for field in ["floats", "ints", "strings"]:
        if getattr(attr1, field) != getattr(attr2, field):
            return False

    if not _same_optional("g", attr1, attr2, graph_equality):
        return False

    if not _same_repeated(attr1.graphs, attr2.graphs, graph_equality):
        return False

    for field in ["sparse_tensor", "tp"]:
        # TODO(gramalingam): check for more complex fields
        if attr1.HasField(field) or attr2.HasField(field):
            return False
    return True


def _same_attrs(attrs1, attrs2, graph_equality):
    """Check two attribute lists are the same, matching attributes by name."""
    if len(attrs1) != len(attrs2):
        return False
    attrs1map = {a.name: a for a in attrs1}
    for attr2 in attrs2:
        if attr2.name not in attrs1map:
            return False
        attr1 = attrs1map[attr2.name]
        if not _same_attr(attr1, attr2, graph_equality):
            return False
    return True


def _ioname(x):
    """Return the name of an input/output of a function or graph"""
    return x.name if isinstance(x, onnx.ValueInfoProto) else x


class _Matcher:
    """An isomorphism matcher for two functions or two graphs."""

    def __init__(self, fg1, fg2, outer_scope) -> None:
        """Prepare a matcher for `fg1` against `fg2`.

        `outer_scope` is the matcher of the enclosing graph/function when
        matching sub-graphs, or None at top level.
        """

        def defmap(f):
            """Compute a map from variables v to their definition-sites.

            A definition-site (n, i) indicates the i-th output of n-th node
            The special value (-1, i) is used to indicate the i-th input of a function/graph.
            """
            result = {}
            for i, x in enumerate(f.input):
                result[_ioname(x)] = (-1, i)
            for ni, n in enumerate(f.node):
                for xi, x in enumerate(n.output):
                    result[x] = (ni, xi)
            return result

        self.defmap1 = defmap(fg1)
        self.defmap2 = defmap(fg2)
        self.fg1 = fg1
        self.fg2 = fg2
        # Maps a node index in fg1 to the node index in fg2 it was matched with.
        self.node_mapping: dict[int, int] = {}
        self.outer_scope = outer_scope

    def same_value(self, var1, var2):
        """Match two variables (strings)."""
        if var1 == "":
            return var2 == ""
        if var2 == "":
            return False
        if var1 not in self.defmap1 or var2 not in self.defmap2:
            # If one of the variables is in current scope, or if there is no outer scope, fail
            if (var1 in self.defmap1) or (var2 in self.defmap2) or (self.outer_scope is None):
                return False
            # Both variables are in outer-scopes. Delay check until later
            return self.outer_scope.same_value(var1, var2)
        (node1, index1) = self.defmap1[var1]
        (node2, index2) = self.defmap2[var2]
        return (index1 == index2) and self.same_node(node1, node2)

    def same_node(self, n1, n2):
        """Match two node-indices. The special node-index -1 represents inputs."""
        if (n1 == -1) and (n2 == -1):
            return True  # Both are inputs
        if (n1 == -1) or (n2 == -1):
            return False  # Only one is input
        if n1 in self.node_mapping:
            return self.node_mapping[n1] == n2
        node1 = self.fg1.node[n1]
        node2 = self.fg2.node[n2]
        if node1.op_type != node2.op_type:
            return False
        if node1.domain != node2.domain:
            return False
        # check attrs
        if not _same_attrs(node1.attribute, node2.attribute, self.same_sub_graph):
            return False
        if not self.same_value_list(node1.input, node2.input):
            return False

        # Nodes represent same computation. Cache the comparison result.
        self.node_mapping[n1] = n2
        return True

    def same_value_list(self, list1, list2):
        """Match two lists of variables (either a string or ValueInfoProto)"""
        if len(list1) != len(list2):
            return False
        return all(self.same_value(_ioname(x), _ioname(y)) for x, y in zip(list1, list2))

    def same_sub_graph(self, g1, g2):
        """Match two sub-graphs."""
        sub_graph_matcher = _Matcher(g1, g2, self)
        return sub_graph_matcher.same_graph()

    def same_graph(self):
        """Match two sub-graphs."""
        g1 = self.fg1
        g2 = self.fg2
        if not _same_repeated(g1.input, g2.input, _same_value_info):
            return False
        if g1.initializer or g2.initializer:
            return False  # TODO
        if g1.sparse_initializer or g2.sparse_initializer:
            return False  # TODO
        if not self.same_value_list(g1.output, g2.output):
            return False
        # TODO completeness tests!
        return True

    def same_function(self):
        """Match (top-level) two functions."""
        # Ok for function names/domain to be different.

        if len(self.fg1.input) != len(self.fg2.input):
            return False
        if set(self.fg1.attribute) != set(self.fg2.attribute):
            return False

        # Opset imports must be same (but possibly in different order):

        # Convert opset-imports into a dictionary
        def imports(f):
            # Assumes each domain has only one entry in a valid FunctionProto
            return {entry.domain: entry.version for entry in f.opset_import}

        if imports(self.fg1) != imports(self.fg2):
            return False

        # Now do a specific form of isomorphism check: Both must compute the same
        # set of operations, possibly in different order as long as they respect
        # the topological-sort order requirement. The two may use different names
        # for intermediate-values, as long as the computation is the same.

        if len(self.fg1.node) != len(self.fg2.node):
            return False

        if not self.same_value_list(self.fg1.output, self.fg2.output):
            return False

        # We do not allow for unused values in the function, which are
        # hard to handle in an isomorphism check.
        if len(self.node_mapping) != len(self.fg1.node):
            return False
        if len(set(self.node_mapping.values())) != len(self.fg2.node):
            return False

        return True


def _isomorphic(fg1, fg2):
    """Checks that two function/graph bodies are isomorphic.

    Assumes that the inputs are valid FunctionProto/GraphProto.
    Use a separate check to verify that the inputs satisfy
    FunctionProto/GraphProto requirements (like no duplicate attributes).

    Raises:
        TypeError: If the inputs are not both FunctionProto or both GraphProto.
    """
    matcher = _Matcher(fg1, fg2, None)
    if isinstance(fg1, onnx.FunctionProto):
        if not isinstance(fg2, onnx.FunctionProto):
            raise TypeError("Both inputs must be same type (function or graph)")
        return matcher.same_function()
    if isinstance(fg1, onnx.GraphProto):
        if not isinstance(fg2, onnx.GraphProto):
            raise TypeError("Both inputs must be same type (function or graph)")
        return matcher.same_graph()
    raise TypeError("Inputs must be either a FunctionProto or GraphProto")


def _to_function_proto(f):
    """Convert a FunctionProto, OnnxFunction or string to a FunctionProto.

    Raises:
        TypeError: If `f` is none of the supported types.
    """
    if isinstance(f, onnx.FunctionProto):
        return f
    if isinstance(f, onnxscript.OnnxFunction):
        return f.to_function_proto()
    if isinstance(f, str):
        return parser.parse_function(f)
    raise TypeError(f"Cannot convert {type(f)} to FunctionProto")


def _to_graph_proto(g):
    """Convert a GraphProto, OnnxFunction or string to a GraphProto.

    Raises:
        TypeError: If `g` is none of the supported types.
    """
    if isinstance(g, onnx.GraphProto):
        return g
    if isinstance(g, onnxscript.OnnxFunction):
        return g.to_model_proto().graph
    if isinstance(g, str):
        return parser.parse_graph(g)
    # The target of this conversion is a GraphProto, not a ModelProto.
    raise TypeError(f"Cannot convert {type(g)} to GraphProto")


def _to_function_or_graph(obj):
    """Convert a FunctionProto, GraphProto, ModelProto or OnnxFunction.

    A ModelProto yields its graph; an OnnxFunction yields its FunctionProto.

    Raises:
        TypeError: If `obj` is none of the supported types.
    """
    if isinstance(obj, onnx.FunctionProto):
        return obj
    if isinstance(obj, onnx.GraphProto):
        return obj
    if isinstance(obj, onnx.ModelProto):
        return obj.graph
    if isinstance(obj, onnxscript.OnnxFunction):
        return obj.to_function_proto()
    raise TypeError(f"Cannot convert {type(obj)} to FunctionProto or GraphProto")


def _opset_import_key(opset_import: onnx.OperatorSetIdProto) -> tuple[str, int]:
    """Sort/identity key for an opset import: (domain, version)."""
    return (opset_import.domain, opset_import.version)


def _value_info_key(value_info: onnx.ValueInfoProto) -> str:
    """Sort/identity key for a value info: its name."""
    return value_info.name


def _function_key(function: onnx.FunctionProto) -> tuple[str, str, str]:
    """Sort/identity key for a function: (domain, name, overload).

    `overload` falls back to "" when the proto has no such attribute.
    """
    return (function.domain, function.name, getattr(function, "overload", ""))


def _find_duplicates(with_duplicates: Collection[Any]) -> list[Any]:
    """Return a list of duplicated elements in a collection.

    An element is reported once for every occurrence after its first.
    """
    seen = set()
    duplicates = []
    for x in with_duplicates:
        if x in seen:
            duplicates.append(x)
        seen.add(x)
    return duplicates
def _both_nan(x: Any, y: Any) -> bool:
    """Return True when `x` and `y` are both float NaN values."""
    return isinstance(x, float) and isinstance(y, float) and math.isnan(x) and math.isnan(y)


def assert_onnx_proto_equal(
    a: google.protobuf.message.Message | Any, b: google.protobuf.message.Message | Any
) -> None:
    """Assert that two ONNX protos are equal.

    Equality is defined as having the same fields with the same values. When
    a field takes the default value, it is considered equal to the field
    not being set.

    Sequential fields with name `opset_import`, `value_info`, and `functions` are
    compared disregarding the order of their elements.

    Two float NaN values are considered equal.

    Args:
        a: The first ONNX proto.
        b: The second ONNX proto.

    Raises:
        AssertionError: If the two protos differ in type or in any field.
    """
    # Raise explicitly instead of using `assert`, which is removed under
    # `python -O` and would let mismatched types slip through.
    if type(a) != type(b):  # pylint: disable=unidiomatic-typecheck
        raise AssertionError(f"Type not equal: {type(a)} != {type(b)}")

    a_fields = {field.name for field, _ in a.ListFields()}
    b_fields = {field.name for field, _ in b.ListFields()}
    all_fields = sorted(a_fields | b_fields)
    for field in all_fields:
        # Obtain the default value if the field is not set. This way we can compare the two fields.
        a_value = getattr(a, field)
        b_value = getattr(b, field)
        if (
            isinstance(a_value, Sequence)
            and isinstance(b_value, Sequence)
            and not isinstance(a_value, (str, bytes))
            and not isinstance(b_value, (str, bytes))
        ):
            # Order-insensitive fields are sorted by an identity key first, and
            # the key lists themselves are compared to give a targeted error.
            a_keys: list[Any] = []
            b_keys: list[Any] = []
            if field == "opset_import":
                a_value = sorted(a_value, key=_opset_import_key)
                b_value = sorted(b_value, key=_opset_import_key)
                a_keys = [_opset_import_key(opset_import) for opset_import in a_value]
                b_keys = [_opset_import_key(opset_import) for opset_import in b_value]
            elif field == "value_info":
                a_value = sorted(a_value, key=_value_info_key)
                b_value = sorted(b_value, key=_value_info_key)
                a_keys = [_value_info_key(value_info) for value_info in a_value]
                b_keys = [_value_info_key(value_info) for value_info in b_value]
            elif field == "functions":
                a_value = sorted(a_value, key=_function_key)
                b_value = sorted(b_value, key=_function_key)
                a_keys = [_function_key(functions) for functions in a_value]
                b_keys = [_function_key(functions) for functions in b_value]

            if a_keys != b_keys:
                keys_only_in_a = set(a_keys) - set(b_keys)
                keys_only_in_b = set(b_keys) - set(a_keys)
                error_message = (
                    f"Field {field} not equal: keys_only_in_a={keys_only_in_a}, keys_only_in_b={keys_only_in_b}. "
                    f"Field type: {type(a_value)}. "
                    f"Duplicated a_keys: {_find_duplicates(a_keys)}, duplicated b_keys: {_find_duplicates(b_keys)}"
                )
                raise AssertionError(error_message)
            # Check length first
            if len(a_value) != len(b_value):
                error_message = (
                    f"Field {field} not equal: len(a)={len(a_value)}, len(b)={len(b_value)} "
                    f"Field type: {type(a_value)}"
                )
                raise AssertionError(error_message)
            # Check every element (lengths are known to match at this point)
            for i, (a_value_i, b_value_i) in enumerate(zip(a_value, b_value)):
                if isinstance(a_value_i, google.protobuf.message.Message) and isinstance(
                    b_value_i, google.protobuf.message.Message
                ):
                    try:
                        assert_onnx_proto_equal(a_value_i, b_value_i)
                    except AssertionError as e:
                        error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}, a_value_i: {a_value_i}, b_value_i: {b_value_i}"
                        raise AssertionError(error_message) from e
                elif a_value_i != b_value_i and not _both_nan(a_value_i, b_value_i):
                    # NaN != NaN in Python, so NaN pairs are excluded above.
                    error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}"
                    error_message += "".join(
                        "\n" + line
                        for line in difflib.ndiff(
                            str(a_value_i).splitlines(), str(b_value_i).splitlines()
                        )
                    )
                    raise AssertionError(error_message)
        elif isinstance(a_value, google.protobuf.message.Message) and isinstance(
            b_value, google.protobuf.message.Message
        ):
            assert_onnx_proto_equal(a_value, b_value)
        elif a_value != b_value and not _both_nan(a_value, b_value):
            # NaN != NaN in Python, so NaN pairs are excluded above.
            error_message = f"Field {field} not equal. field_a: {a_value}, field_b: {b_value}"
            raise AssertionError(error_message)