# Source code for onnxscript.tools.transformers_models.llama

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=import-outside-toplevel
from __future__ import annotations

from typing import Any, Sequence

import torch

import onnxscript.tools.transformers_models


def get_llama_model(
    input_dims: Sequence[tuple[int, int]] = ((2, 8), (4, 7), (9, 15)),
    hidden_size: int = 16,
    num_hidden_layers: int = 1,
    vocab_size: int = 1024,
    intermediate_size: int = 16,
    max_position_embeddings: int = 1024,
    num_attention_heads: int = 2,
    _attn_implementation: str = "eager",  # needed value to remove graph breaks
    with_mask: bool = True,
) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]:
    """
    Returns a model.
    See `LlamaConfig
    <https://huggingface.co/docs/transformers/main/en/model_doc/llama#transformers.LlamaConfig>`_.
    The parameters are chosen for a unit test configuration.
    """
    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaModel

    # Input 0 (input_ids) is dynamic in both batch and sequence length;
    # when a mask is used, input 1 (attention_mask) follows the same layout.
    dynamic_shapes: dict = {0: {0: "batch", 1: "length"}}
    if with_mask:
        dynamic_shapes[1] = {0: "batch", 1: "length"}

    config = LlamaConfig(
        num_hidden_layers=num_hidden_layers,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        max_position_embeddings=max_position_embeddings,
        num_attention_heads=num_attention_heads,
    )
    if _attn_implementation:
        config._attn_implementation = _attn_implementation  # pylint: disable=protected-access

    if with_mask:

        class LlamaModelWrapperMask(torch.nn.Module):
            # Wraps LlamaModel so forward returns a plain tuple (export friendly).
            def __init__(self, config):
                super().__init__()
                self.model = LlamaModel(config)

            def forward(self, input_ids, attention_mask):
                outputs = self.model(
                    input_ids, attention_mask=attention_mask, use_cache=False
                )
                return outputs.to_tuple()

        def make_example_with_mask(batch: int, seq: int):
            # Random token ids plus a lower-triangular float mask.
            ids = onnxscript.tools.transformers_models.ids_tensor(
                [batch, seq], vocab_size
            )
            mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))
            assert mask.dtype == torch.float32
            return ids, mask

        example_args_collection = [
            make_example_with_mask(batch, seq) for batch, seq in input_dims
        ]
        return LlamaModelWrapperMask(config), example_args_collection, dynamic_shapes

    # no mask: single-input wrapper

    class LlamaModelWrapper(torch.nn.Module):
        # Wraps LlamaModel so forward returns a plain tuple (export friendly).
        def __init__(self, config):
            super().__init__()
            self.model = LlamaModel(config)

        def forward(self, input_ids):
            outputs = self.model(input_ids, use_cache=False)
            return outputs.to_tuple()

    example_args_collection = [
        (onnxscript.tools.transformers_models.ids_tensor([batch, seq], vocab_size),)
        for batch, seq in input_dims
    ]
    return LlamaModelWrapper(config), example_args_collection, dynamic_shapes


def get_llama_model_from_config(
    warmup: int = 5,
    repeat: int = 10,
    config: str = "small",
    num_hidden_layers: int = 1,
    implementation: str = "eager",
    dynamic_shapes: bool = False,
    with_mask: bool = True,
) -> tuple[Any, list[tuple[torch.Tensor, ...]], dict]:
    """
    Returns a LLaMA model to test or benchmark.

    Args:
        warmup: Number of inputs to generate.
        repeat: Number of inputs to generate for repeat.
        config: small, medium or large
        num_hidden_layers: Number of hidden layers.
        implementation: eager or sdpa
        with_mask: One or two inputs.
        dynamic_shapes: dynamic shapes or not

    Returns:
        Model and list of inputs.

    Raises:
        ValueError: If ``config`` is not small, medium, large or default.
    """
    # Per-size hyper-parameters; "default" is an alias for "large"
    # (dimensions close to the original LLaMA-7B configuration).
    size_presets: dict[str, dict[str, int]] = {
        "small": dict(
            hidden_size=16,
            vocab_size=1024,
            intermediate_size=16,
            max_position_embeddings=1024,
            num_attention_heads=2,
        ),
        "medium": dict(
            hidden_size=1024,
            vocab_size=1024,
            intermediate_size=1024,
            max_position_embeddings=1024,
            num_attention_heads=2,
        ),
        "large": dict(
            hidden_size=4096,
            vocab_size=32000,
            intermediate_size=11008,
            max_position_embeddings=2048,
            num_attention_heads=32,
        ),
    }
    size_presets["default"] = size_presets["large"]

    if config not in size_presets:
        raise ValueError(f"Unexpected configuration {config!r}.")

    return get_llama_model(  # type: ignore[arg-type]
        input_dims=onnxscript.tools.transformers_models.get_input_dims_for_llm(
            dynamic_shapes, warmup, repeat
        ),
        num_hidden_layers=num_hidden_layers,
        _attn_implementation=implementation,
        with_mask=with_mask,
        **size_presets[config],
    )