Contributing Custom Models for Training

This guide explains how to add custom model implementations to the optimum/neuron/models/training/ directory. Custom models are needed to support distributed training features like tensor parallelism, pipeline parallelism, and sequence parallelism on AWS Trainium devices.

Architecture Components

1. NeuronModelMixin

The NeuronModelMixin class provides core functionality:

  • from_pretrained(): Loads regular Transformers weights into custom implementations
  • save_pretrained(): Saves sharded checkpoints with consolidation metadata
  • Pipeline parallelism support through PIPELINE_* attributes
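
A rough usage sketch of from_pretrained() and save_pretrained() follows. The exact signatures are defined by NeuronModelMixin and the TrainingNeuronConfig argument names are assumptions, so check modeling_utils.py and config.py; LlamaForCausalLM here is the existing LLaMA implementation shipped in optimum/neuron/models/training.

import torch

from optimum.neuron.models.training import LlamaForCausalLM
from optimum.neuron.models.training.config import TrainingNeuronConfig

# Parallelism settings for the run (field names are illustrative).
trn_config = TrainingNeuronConfig(tensor_parallel_size=8)

# from_pretrained() loads the regular Transformers weights and maps them onto
# the parallel custom implementation using the transformation specs.
model = LlamaForCausalLM.from_pretrained(
    "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
    trn_config=trn_config,
    torch_dtype=torch.bfloat16,
)

# save_pretrained() writes sharded checkpoints together with the metadata
# needed to consolidate them back into the original Transformers format.
model.save_pretrained("output/llama-tiny")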

2. Weight Transformation Specs

Transformation specs handle converting weights between:

  • Original Transformers format → Custom parallel format (during loading)
  • Custom parallel format → Original Transformers format (during checkpoint consolidation)

Key transformation spec types:

  • FusedLinearsSpec: Handles fused linear layers (e.g., gate_up_proj)
  • GQAQKVColumnParallelLinearSpec: Handles grouped query attention projections when the tensor parallel size is greater than the number of key-value heads

For complete API documentation of all transformation specs and utility functions, see the Model Weight Transformation Specs API Reference.

3. Parallel Layers

Use these parallel layers from neuronx_distributed:

  • ColumnParallelLinear: Splits weight matrix along output dimension
  • RowParallelLinear: Splits weight matrix along input dimension
  • ParallelEmbedding: Splits embedding table across ranks
  • GQAQKVColumnParallelLinear: Specialized for grouped query attention projections when the tensor parallel size is greater than the number of key-value heads
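
To see why the layer options matter, here is a minimal sketch (assuming the neuronx_distributed model-parallel state is already initialized): a ColumnParallelLinear with gather_output=False keeps its output sharded, and the RowParallelLinear that follows it declares input_is_parallel=True, so the only collective needed for the pair is the all-reduce inside the row-parallel layer.

from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, RowParallelLinear

hidden_size, intermediate_size = 2048, 8192

# Output dimension is sharded across TP ranks; keep it sharded (no gather).
up_proj = ColumnParallelLinear(hidden_size, intermediate_size, bias=False, gather_output=False)

# Input dimension is sharded; the layer all-reduces its partial output.
down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, input_is_parallel=True)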

Implementation Steps

Step 1: Create Model Structure

Create a new directory: optimum/neuron/models/training/your_model/

__init__.py

from .modeling_your_model import YourModelForCausalLM, YourModel

__all__ = ["YourModelForCausalLM", "YourModel"]

Step 2: Implement the Model Building Blocks

modeling_your_model.py

Imports and Dependencies

import torch
from torch import nn
from neuronx_distributed.parallel_layers.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    ParallelEmbedding,
)
from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear
from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size  # used by the attention variants below
from transformers import PreTrainedModel
from transformers.models.your_model import YourModelConfig

from ..config import TrainingNeuronConfig
from ..modeling_utils import NeuronModelMixin
from ..transformations_utils import (
    CustomModule,
    FusedLinearsSpec,
    GQAQKVColumnParallelLinearSpec,
    ModelWeightTransformationSpecs,
)

Embedding Layer

class YourModelEmbeddings(nn.Module):
    def __init__(self, config, trn_config):
        super().__init__()
        self.embed_tokens = ParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            dtype=config.torch_dtype,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
        )

MLP Layer with Fused Linears

Important: Any module that has transformation specs must inherit from CustomModule to ensure proper handling of weight transformations, and the transformation specs must be defined in the self.specs attribute.

class YourModelMLP(nn.Module, CustomModule):
    def __init__(self, config, trn_config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        
        # Fused gate and up projections
        self.gate_up_proj = ColumnParallelLinear(
            self.hidden_size,
            2 * self.intermediate_size,
            stride=2,  # Important for proper sharding
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        
        # Define transformation specs
        self.specs = ModelWeightTransformationSpecs()
        self.specs.add_spec(
            FusedLinearsSpec(
                fused_linear_name="gate_up_proj",
                linear_names=["gate_proj", "up_proj"],
                bias=False,
                fuse_axis="column",  # Fuse along output dimension
                original_dims=[self.intermediate_size, self.intermediate_size],
            )
        )
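
A possible forward pass for this MLP, shown as a sketch: the SiLU gating is an assumption borrowed from LLaMA-style models, and the local fused output can be split evenly because stride=2 keeps the gate and up shards contiguous on each rank.

    def forward(self, hidden_states):
        # The local output of gate_up_proj is [gate_shard, up_shard], each of
        # size intermediate_size // tp_size, so an even chunk recovers both halves.
        gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
        # The activation is model-specific; SiLU gating is used here as an example.
        return self.down_proj(torch.nn.functional.silu(gate) * up)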

Attention Layer

The attention layer implementation depends on the model’s architecture and tensor parallel configuration. There are three main variants:

1. Separate Q, K, V Projections (Default)

class YourModelAttention(nn.Module, CustomModule):
    def __init__(self, config, trn_config, layer_idx):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // self.num_heads
        
        # Separate projections for Q, K, V
        self.q_proj = ColumnParallelLinear(
            config.hidden_size,
            self.num_heads * self.head_dim,
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        self.k_proj = ColumnParallelLinear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        self.v_proj = ColumnParallelLinear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            input_is_parallel=True,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        
        # No transformation specs needed - regular parallel layers
        self.specs = ModelWeightTransformationSpecs()

2. Fused QKV Projection (Multi-Head Attention)

class YourModelAttention(nn.Module, CustomModule):
    def __init__(self, config, trn_config, layer_idx):
        super().__init__()
        # ... (same setup as above)
        
        tp_size = get_tensor_model_parallel_size()
        
        # Only use fused QKV when num_heads == num_key_value_heads (no GQA)
        if trn_config.fuse_qkv and self.num_heads == self.num_key_value_heads:
            self.qkv_proj = ColumnParallelLinear(
                config.hidden_size,
                3 * self.num_heads * self.head_dim,  # Q + K + V
                stride=3,  # Important for proper sharding
                bias=False,
                gather_output=False,
                sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
                dtype=config.torch_dtype,
            )
            
            # Define transformation specs for fused QKV
            self.specs = ModelWeightTransformationSpecs()
            self.specs.add_spec(
                FusedLinearsSpec(
                    fused_linear_name="qkv_proj",
                    linear_names=["q_proj", "k_proj", "v_proj"],
                    bias=False,
                    fuse_axis="column",
                    original_dims=[self.num_heads * self.head_dim] * 3,
                )
            )
            self.split_size = self.num_heads * self.head_dim // tp_size
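
In the forward pass, the fused projection is applied once and its local output is split back into Q, K, and V using split_size (a sketch; the rest of the attention computation proceeds exactly as in the separate-projection variant):

    def forward(self, hidden_states):
        # With stride=3, each rank holds [q_shard, k_shard, v_shard], each of
        # size num_heads * head_dim // tp_size (= self.split_size).
        qkv = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv.split(self.split_size, dim=-1)
        # ... reshape into heads and compute attention as usual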

3. GQA QKV Projection (Required for Challenging TP Configurations)

class YourModelAttention(nn.Module, CustomModule):
    def __init__(self, config, trn_config, layer_idx):
        super().__init__()
        # ... (same setup as above)
        
        tp_size = get_tensor_model_parallel_size()
        
        # Use GQA QKV when KV heads can't be evenly distributed across TP ranks
        # This happens when: num_key_value_heads < tp_size or num_key_value_heads % tp_size != 0
        self.qkv_linear = (self.num_key_value_heads < tp_size) or (self.num_key_value_heads % tp_size != 0)
        
        if self.qkv_linear:
            # Calculate KV size multiplier to ensure even distribution
            if trn_config.kv_size_multiplier is None:
                self.kv_size_multiplier = trn_config.auto_kv_size_multiplier(self.num_key_value_heads)
            else:
                self.kv_size_multiplier = trn_config.kv_size_multiplier
                
            self.qkv_proj = GQAQKVColumnParallelLinear(
                config.hidden_size,
                [self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim],
                bias=False,
                gather_output=False,
                sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
                kv_size_multiplier=self.kv_size_multiplier,
                fuse_qkv=trn_config.fuse_qkv,
                dtype=config.torch_dtype,
            )
            
            # Define transformation specs for GQA QKV
            self.specs = ModelWeightTransformationSpecs()
            self.specs.add_spec(
                GQAQKVColumnParallelLinearSpec(
                    gqa_qkv_projection_name="qkv_proj",
                    query_projection_name="q_proj",
                    key_projection_name="k_proj", 
                    value_projection_name="v_proj",
                    output_projection_name="o_proj",
                    num_attention_heads=self.num_heads,
                    num_key_value_heads=self.num_key_value_heads,
                    kv_size_multiplier=self.kv_size_multiplier,
                    q_output_size_per_partition=self.qkv_proj.q_output_size_per_partition,
                    kv_output_size_per_partition=self.qkv_proj.kv_output_size_per_partition,
                    fuse_qkv=trn_config.fuse_qkv,
                )
            )

When to Use Each Variant:

  • Separate Q, K, V: Default approach, works for all configurations but may be less efficient
  • Fused QKV: Use when num_heads == num_key_value_heads (no grouped query attention) and fuse_qkv=True
  • GQA QKV: Required when using grouped query attention with challenging tensor parallel configurations where KV heads cannot be evenly distributed across TP ranks

The choice is typically determined by:

tp_size = get_tensor_model_parallel_size()
use_gqa_qkv = (num_key_value_heads < tp_size) or (num_key_value_heads % tp_size != 0)
use_fused_qkv = trn_config.fuse_qkv and (num_heads == num_key_value_heads) and not use_gqa_qkv

Step 3: Implement Main Model Classes

Base Model

class YourPreTrainedModel(PreTrainedModel, NeuronModelMixin):
    config_class = YourModelConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["YourModelDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True


class YourModel(NeuronModelMixin, YourPreTrainedModel):
    def __init__(self, config: YourModelConfig, trn_config: TrainingNeuronConfig):
        YourPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.trn_config = trn_config
        
        self.embed_tokens = ParallelEmbedding(...)
        self.layers = nn.ModuleList([
            YourModelDecoderLayer(config, trn_config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = YourModelRMSNorm(...)
        
        self.post_init()

CausalLM Model

class YourModelForCausalLM(NeuronModelMixin, YourPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]
    
    # Pipeline parallelism support
    SUPPORTS_PIPELINE_PARALLELISM = True
    PIPELINE_TRANSFORMER_LAYER_CLS = YourModelDecoderLayer
    PIPELINE_INPUT_NAMES = ["input_ids", "attention_mask"]
    
    def __init__(self, config, trn_config):
        super().__init__(config)
        self.trn_config = trn_config
        self.model = YourModel(config, trn_config)
        self.vocab_size = config.vocab_size
        
        self.lm_head = ColumnParallelLinear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            gather_output=False,
            dtype=config.torch_dtype,
        )
        
        self.post_init()

Step 4: Register Model

Update optimum/neuron/models/training/__init__.py:

from .your_model import YourModelForCausalLM, YourModel

__all__ = [..., "YourModelForCausalLM", "YourModel"]

Update optimum/neuron/models/training/auto_models.py:

from .your_model.modeling_your_model import YourModelForCausalLM, YourModel

# Register the base model (without head)
register_neuron_model_for_training("your_model", "model")(YourModel)

# Register the CausalLM model
register_neuron_model_for_training("your_model", "text-generation")(YourModelForCausalLM)

Here "your_model" is the corresponding to the model_type attribute of your model’s configuration class.

Best Practices

1. Parallel Layer Configuration

  • Use gather_output=False for intermediate layers
  • Set input_is_parallel=True for layers that receive parallel input
  • Configure sequence_parallel_enabled consistently across layers
  • Use appropriate stride values for proper weight sharding

2. Weight Transformation Specs

  • Always define specs for modules that use fused or parallel layers
  • Use CustomModule mixin for any module with transformation specs
  • Ensure spec parameter names match the actual module structure
  • Test both regular and LoRA weight transformations

3. Pipeline Parallelism

  • Set SUPPORTS_PIPELINE_PARALLELISM = True for supported models
  • Define PIPELINE_TRANSFORMER_LAYER_CLS as your decoder layer class
  • List all input names in PIPELINE_INPUT_NAMES

4. Flash Attention Support

  • Set _supports_flash_attn_2 = True if your model supports it
  • Implement both eager and flash attention paths
  • Use appropriate attention function dispatching
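
A minimal dispatch sketch, illustrative only: it branches on an attention-implementation string (in practice derived from config._attn_implementation) and uses plain PyTorch ops as stand-ins; a real Trainium implementation would call the NKI flash attention kernel on the flash path.

import math
import torch

def attention(query, key, value, attn_implementation="eager"):
    # On Trainium, the "flash_attention_2" branch would call the NKI flash
    # attention kernel; scaled_dot_product_attention is only a stand-in here.
    if attn_implementation == "flash_attention_2":
        return torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=True)

    # Eager path: explicit softmax(Q K^T / sqrt(d)) V with a causal mask.
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    seq_len = query.size(-2)
    causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    probs = torch.softmax(scores + causal_mask, dim=-1)
    return probs @ value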

Testing Your Implementation

The training tests in tests/training/ provide a comprehensive testing framework that validates numerical correctness, distributed training scenarios, and checkpoint compatibility. Most of the tests are not designed to be run on every custom modeling implementation, but rather to validate the core functionality of the Optimum Neuron training infrastructure. With that in mind, here’s what you need to implement for your custom modeling:

1. Custom Modeling Validation

The test_custom_modeling.py file validates that your custom implementation produces identical outputs to the original Transformers model:

Update tests/training/test_custom_modeling.py:

CUSTOM_MODELINGS_TO_TEST = [
    # ... existing models ...
    ("YourModelForCausalLM", "your-org/your-model-name"),
]

Important: For custom modeling validation tests, use small/tiny models to ensure CI efficiency. The test models should have:

  • Small vocabulary size (e.g., 1000-8000 tokens)
  • Few layers (e.g., 2-4 layers)
  • Small hidden dimensions (e.g., 128-512)
  • Minimal attention heads (e.g., 4-8 heads)

Examples of good test models for custom modeling validation:

  • "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random" - 4 layers, 4 KV heads
  • "michaelbenayoun/granite-tiny-4kv-heads-4layers-random" - Tiny Granite model
  • "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random" - Tiny Qwen3 model

Key test your model must pass:

def test_custom_modeling_matches_original()  # Output matching

This validates:

  • Numerical Correctness: Ensures custom models match Transformers outputs exactly
  • Parallelization Support: Tests various QKV implementations (regular, fused, GQA)

2. End-to-End Training Validation

The test_overfit.py file validates training convergence. To include your model in end-to-end training validation, you must add it to the parametrized test cases:

Update tests/training/test_overfit.py:

@pytest.mark.parametrize(
    "model_class_name,model_name_or_path,learning_rate,warmup_ratio,training_kwargs,use_flash_attention_2,max_expected_loss,max_length,num_steps",
    [
        # ... existing models ...
        [
            "YourModelForCausalLM",
            "your-org/your-model-name",
            1e-4,
            0.03,
            {},
            True,
            0.5,
            2048,
            50,
        ],
    ],
    ids=[
        # ... existing model IDs ...
        "your-org/your-model-name",
    ],
)

This validates:

  • Convergence Validation: Ensures models can overfit simple datasets

Your model will be tested in:

def test_overfit_custom_modeling_causal_lm()       # Basic training (your model included)

3. Auto Model Loading

The test_modeling_auto.py file validates that your model can be loaded using the NeuronModel and NeuronModelForCausalLM auto classes. To include your model in these tests, you must add it to the test cases:

Update tests/training/test_modeling_auto.py:

@pytest.mark.parametrize("from_pretrained", [False, True], ids=["from_config", "from_pretrained"])
@distributed_test(world_size=1)
@is_trainium_test
def test_auto_model_with_supported_architecture(from_pretrained):
    trn_config = TrainingNeuronConfig()
    kwargs = {"torch_dtype": torch.bfloat16}
    for model_name_or_path in [
        "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
        "michaelbenayoun/granite-tiny-4kv-heads-4layers-random", 
        "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random",
        "your-org/your-model-name",  # Add your model here
    ]:
        # ... rest of test logic

@pytest.mark.parametrize("from_pretrained", [False, True], ids=["from_config", "from_pretrained"])
@distributed_test(world_size=1)
@is_trainium_test
def test_auto_model_for_causal_lm_with_supported_architecture(from_pretrained):
    trn_config = TrainingNeuronConfig()
    kwargs = {"torch_dtype": torch.bfloat16}
    for model_name_or_path in [
        "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
        "michaelbenayoun/granite-tiny-4kv-heads-4layers-random",
        "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random", 
        "your-org/your-model-name",  # Add your model here
    ]:
        # ... rest of test logic

This validates:

  • Auto Model Loading: Tests that NeuronModel.from_pretrained() and NeuronModel.from_config() work correctly
  • Auto CausalLM Loading: Tests that NeuronModelForCausalLM.from_pretrained() and NeuronModelForCausalLM.from_config() work correctly

4. Running Tests

Tests require AWS Trainium instances. Run specific test categories:

# Run all custom modeling tests
pytest tests/training/test_custom_modeling.py -v

# Run specific model tests
pytest tests/training/test_custom_modeling.py -v -k "your_model"

# Run end-to-end training validation
pytest tests/training/test_overfit.py -v

5. Test Requirements

Your implementation must:

  1. Pass numerical correctness tests against original Transformers implementation
  2. Support parallelization strategies (at minimum DP and TP; PP support recommended)
  3. Handle various QKV implementations (regular, fused, GQA)
  4. Support checkpoint consolidation for distributed training
  5. Support LoRA training if applicable
  6. Demonstrate convergence through overfitting tests

The testing framework ensures your custom model maintains compatibility with the existing Optimum Neuron training infrastructure while delivering expected performance and correctness guarantees.

Common Issues

  • Weight Shape Mismatches: Ensure transformation specs handle tensor shapes correctly
  • Pipeline Parallelism Errors: Check that all required attributes are set
  • Memory Issues: Consider gradient checkpointing and activation recomputation
  • Attention Compatibility: Verify attention implementations work with your model architecture

Additional Resources

This guide provides the foundation for implementing custom models. For complete examples and advanced patterns, reference these existing implementations:

  • LLaMA: optimum/neuron/models/training/llama/modeling_llama.py - Complete implementation covering the regular, fused, and GQA attention variants, plus a fused MLP
  • Qwen3: optimum/neuron/models/training/qwen3/modeling_qwen3.py - Demonstrates how to adapt the Llama implementation for Qwen3 with q_norm and k_norm layers

Key files to study:

  • optimum/neuron/models/training/modeling_utils.py - Base NeuronModelMixin class
  • optimum/neuron/models/training/transformations_utils.py - Weight transformation specifications
  • optimum/neuron/models/training/config.py - TrainingNeuronConfig for parallelism settings