Source code for onnxscript.optimizer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import logging
from typing import Any

import onnx
import onnx.shape_inference

from onnxscript import ir, rewriter
from onnxscript.optimizer import _constant_folding, _inliner
from onnxscript.optimizer.constant_folding import fold_constants
from onnxscript.optimizer.remove_unused import remove_unused_nodes
from onnxscript.optimizer.remove_unused_function import remove_unused_functions
from onnxscript.optimizer.simple_function_folding import (
    inline_functions_with_unused_outputs,
    inline_simple_functions,
)
from onnxscript.rewriter import (
    broadcast_to_matmul,
    cast_constant_of_shape,
    gemm_to_matmul_add,
    no_op,
)

logger = logging.getLogger(__name__)

_DEFAULT_REWRITE_RULES = [
    *no_op.rules.rules,  # TODO: merge this rule into constant folding?
    *broadcast_to_matmul.rules.rules,
    gemm_to_matmul_add.rule,
    *cast_constant_of_shape.rules.rules,
]


[docs] def optimize( model: onnx.ModelProto, num_iterations: int = 2, *, onnx_shape_inference: bool = True, stop_if_no_change: bool = True, external_data_folder: str = "", **kwargs: Any, ) -> onnx.ModelProto: """Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc. Args: model (onnx.ModelProto): The model to optimize. num_iterations (int, optional): Number of iterations to perform. onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model. Set this to False to turn off onnx shape inference, and rely on model carried shapes and types. This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries the symbolic shapes recorded from dynamo tracing. stop_if_no_change (bool, optional): Whether to stop if no change is detected. external_data_folder (str, optional): The folder to store external data. **kwargs: Additional keyword arguments. For BC purposes. """ if kwargs.pop("function_aware_folding", None) is not None: logger.warning( "'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. " "To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. " "This would turn off incremental onnx shape inference and rely on model carried shapes and types. " "See 'onnx_shape_inference' for more details." ) for _ in range(num_iterations): if onnx_shape_inference: if model.ByteSize() < 1024 * 1024 * 1024 * 2: # NOTE: strict mode is disabled because it crashes on the models # that have different shapes inferred from the model carried shapes. # The case can be found in: # https://github.com/microsoft/onnxscript/issues/1443 model = onnx.shape_inference.infer_shapes( model, check_type=True, strict_mode=False, data_prop=True ) else: logger.warning( "The model size is too large for full model shape inference. " "Skipping this step." ) inline_simple_functions(model) modified = fold_constants( model, external_data_folder, onnx_shape_inference=onnx_shape_inference ) remove_unused_nodes(model) inline_simple_functions(model) model = remove_unused_functions(model) inline_functions_with_unused_outputs(model) # NOTE: This is general rewrite rules model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES) if stop_if_no_change and not modified: logger.debug("Stopping after %d iterations.", _) break for node in model.graph.node: logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name) for function in model.functions: for node in function.node: logger.debug( "Function %s::%s node %s::%s name %s.", function.domain, function.name, node.domain, node.op_type, node.name, ) return model
def optimize_ir( model: ir.Model, num_iterations: int = 2, *, onnx_shape_inference: bool = True, stop_if_no_change: bool = True, ) -> None: del stop_if_no_change # Looks like rewriter doesn't support this yet. _inliner.inline(model) for _ in range(num_iterations): _constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference) rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES) remove_unused_nodes(model) __all__ = [ "fold_constants", "remove_unused_nodes", "optimize", "optimize_ir", ]