AWS Trainium & Inferentia documentation
Contributing Custom Models for Training
Contributing Custom Models for Training
This guide explains how to add custom model implementations to the optimum/neuron/models/training/
directory. Custom models are needed to support distributed training features like tensor parallelism, pipeline parallelism, and sequence parallelism on AWS Trainium devices.
Architecture Components
1. NeuronModelMixin
The `NeuronModelMixin` class provides core functionality:
- `from_pretrained()`: Loads regular Transformers weights into custom implementations
- `save_pretrained()`: Saves sharded checkpoints with consolidation metadata
- Pipeline parallelism support through `PIPELINE_*` attributes
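For example, loading a stock Hugging Face checkpoint into a custom Neuron implementation and saving a sharded checkpoint looks roughly like the sketch below. This is a minimal sketch: the import paths follow the directory layout described in this guide, and the `tensor_parallel_size` argument and exact `from_pretrained()` signature should be checked against `config.py` and `modeling_utils.py`.

```python
from optimum.neuron.models.training import LlamaForCausalLM
from optimum.neuron.models.training.config import TrainingNeuronConfig

# Describes how the model should be parallelized on Trainium devices
trn_config = TrainingNeuronConfig(tensor_parallel_size=8)

# Loads regular Transformers weights and shards them for the custom implementation
model = LlamaForCausalLM.from_pretrained(
    "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
    trn_config,
)

# Saves a sharded checkpoint together with the metadata needed for consolidation
model.save_pretrained("output/llama-tiny")
```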
2. Weight Transformation Specs
Transformation specs handle converting weights between:
- Original Transformers format → Custom parallel format (during loading)
- Custom parallel format → Original Transformers format (during checkpoint consolidation)
Key transformation spec types:
- `FusedLinearsSpec`: Handles fused linear layers (e.g., `gate_up_proj`)
- `GQAQKVColumnParallelLinearSpec`: Handles grouped query attention projections when the tensor parallel size is greater than the number of key-value heads
For complete API documentation of all transformation specs and utility functions, see the Model Weight Transformation Specs API Reference.
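To build intuition for what these specs do, here is a conceptual, library-free illustration of a `gate_up_proj` fusion; the real `FusedLinearsSpec` additionally handles stride-based interleaving across tensor parallel ranks and LoRA adapters.

```python
import torch

hidden, intermediate = 128, 512
gate = torch.randn(intermediate, hidden)  # gate_proj.weight from the original checkpoint
up = torch.randn(intermediate, hidden)    # up_proj.weight from the original checkpoint

# Loading direction: fuse the two projections along the output (column) axis
gate_up = torch.cat([gate, up], dim=0)    # gate_up_proj.weight, shape (2 * intermediate, hidden)

# Consolidation direction: split the fused weight back into the original parameters
gate_back, up_back = torch.split(gate_up, [intermediate, intermediate], dim=0)
assert torch.equal(gate_back, gate) and torch.equal(up_back, up)
```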
3. Parallel Layers
Use these parallel layers from `neuronx_distributed`:
- `ColumnParallelLinear`: Splits the weight matrix along the output dimension
- `RowParallelLinear`: Splits the weight matrix along the input dimension
- `ParallelEmbedding`: Splits the embedding table across ranks
- `GQAQKVColumnParallelLinear`: Specialized for grouped query attention projections when the tensor parallel size is greater than the number of key-value heads
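As a rough intuition for how these layers shard their weights per tensor parallel rank, consider the following back-of-the-envelope example (the dimensions are purely illustrative):

```python
tp_size, hidden, intermediate, vocab = 8, 4096, 11008, 32000

# ColumnParallelLinear(hidden, intermediate): output dimension split across ranks
column_parallel_shard = (intermediate // tp_size, hidden)   # (1376, 4096)

# RowParallelLinear(intermediate, hidden): input dimension split, outputs all-reduced
row_parallel_shard = (hidden, intermediate // tp_size)      # (4096, 1376)

# ParallelEmbedding(vocab, hidden): vocabulary rows split across ranks
embedding_shard = (vocab // tp_size, hidden)                # (4000, 4096)
```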
Implementation Steps
Step 1: Create Model Structure
Create a new directory: optimum/neuron/models/training/your_model/
__init__.py
from .modeling_your_model import YourModelForCausalLM, YourModel
__all__ = ["YourModelForCausalLM", "YourModel"]
Step 2: Implement the Model Building Blocks
modeling_your_model.py
Imports and Dependencies
import torch
from torch import nn
from neuronx_distributed.parallel_layers.layers import (
ColumnParallelLinear,
RowParallelLinear,
ParallelEmbedding,
)
from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear
from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size
from transformers import PreTrainedModel
from transformers.models.your_model import YourModelConfig
from ..config import TrainingNeuronConfig
from ..modeling_utils import NeuronModelMixin
from ..transformations_utils import (
CustomModule,
FusedLinearsSpec,
GQAQKVColumnParallelLinearSpec,
ModelWeightTransformationSpecs,
)
Embedding Layer
class YourModelEmbeddings(nn.Module):
def __init__(self, config, trn_config):
super().__init__()
self.embed_tokens = ParallelEmbedding(
config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
)
MLP Layer with Fused Linears
Important: Any module that has transformation specs must inherit from `CustomModule` to ensure proper handling of weight transformations, and the transformation specs must be defined in the `self.specs` attribute.
class YourModelMLP(nn.Module, CustomModule):
def __init__(self, config, trn_config):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
# Fused gate and up projections
self.gate_up_proj = ColumnParallelLinear(
self.hidden_size,
2 * self.intermediate_size,
stride=2, # Important for proper sharding
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
)
self.down_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
input_is_parallel=True,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
)
# Define transformation specs
self.specs = ModelWeightTransformationSpecs()
self.specs.add_spec(
FusedLinearsSpec(
fused_linear_name="gate_up_proj",
linear_names=["gate_proj", "up_proj"],
bias=False,
fuse_axis="column", # Fuse along output dimension
original_dims=[self.intermediate_size, self.intermediate_size],
)
)
Attention Layer
The attention layer implementation depends on the model’s architecture and tensor parallel configuration. There are three main variants:
1. Separate Q, K, V Projections (Default)
class YourModelAttention(nn.Module, CustomModule):
def __init__(self, config, trn_config, layer_idx):
super().__init__()
self.config = config
self.num_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // self.num_heads
# Separate projections for Q, K, V
self.q_proj = ColumnParallelLinear(
config.hidden_size,
self.num_heads * self.head_dim,
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
)
self.k_proj = ColumnParallelLinear(
config.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
)
self.v_proj = ColumnParallelLinear(
config.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
config.hidden_size,
bias=False,
input_is_parallel=True,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
)
# No transformation specs needed - regular parallel layers
self.specs = ModelWeightTransformationSpecs()
2. Fused QKV Projection (Multi-Head Attention)
class YourModelAttention(nn.Module, CustomModule):
def __init__(self, config, trn_config, layer_idx):
super().__init__()
# ... (same setup as above)
tp_size = get_tensor_model_parallel_size()
# Only use fused QKV when num_heads == num_key_value_heads (no GQA)
if trn_config.fuse_qkv and self.num_heads == self.num_key_value_heads:
self.qkv_proj = ColumnParallelLinear(
config.hidden_size,
3 * self.num_heads * self.head_dim, # Q + K + V
stride=3, # Important for proper sharding
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
)
# Define transformation specs for fused QKV
self.specs = ModelWeightTransformationSpecs()
self.specs.add_spec(
FusedLinearsSpec(
fused_linear_name="qkv_proj",
linear_names=["q_proj", "k_proj", "v_proj"],
bias=False,
fuse_axis="column",
original_dims=[self.num_heads * self.head_dim] * 3,
)
)
self.split_size = self.num_heads * self.head_dim // tp_size
3. GQA QKV Projection (Required for Challenging TP Configurations)
class YourModelAttention(nn.Module, CustomModule):
def __init__(self, config, trn_config, layer_idx):
super().__init__()
# ... (same setup as above)
tp_size = get_tensor_model_parallel_size()
# Use GQA QKV when KV heads can't be evenly distributed across TP ranks
# This happens when: num_key_value_heads < tp_size or num_key_value_heads % tp_size != 0
self.qkv_linear = (self.num_key_value_heads < tp_size) or (self.num_key_value_heads % tp_size != 0)
if self.qkv_linear:
# Calculate KV size multiplier to ensure even distribution
if trn_config.kv_size_multiplier is None:
self.kv_size_multiplier = trn_config.auto_kv_size_multiplier(self.num_key_value_heads)
else:
self.kv_size_multiplier = trn_config.kv_size_multiplier
self.qkv_proj = GQAQKVColumnParallelLinear(
config.hidden_size,
[self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim],
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
kv_size_multiplier=self.kv_size_multiplier,
fuse_qkv=trn_config.fuse_qkv,
dtype=config.torch_dtype,
)
# Define transformation specs for GQA QKV
self.specs = ModelWeightTransformationSpecs()
self.specs.add_spec(
GQAQKVColumnParallelLinearSpec(
gqa_qkv_projection_name="qkv_proj",
query_projection_name="q_proj",
key_projection_name="k_proj",
value_projection_name="v_proj",
output_projection_name="o_proj",
num_attention_heads=self.num_heads,
num_key_value_heads=self.num_key_value_heads,
kv_size_multiplier=self.kv_size_multiplier,
q_output_size_per_partition=self.qkv_proj.q_output_size_per_partition,
kv_output_size_per_partition=self.qkv_proj.kv_output_size_per_partition,
fuse_qkv=trn_config.fuse_qkv,
)
)
When to Use Each Variant:
- Separate Q, K, V: Default approach, works for all configurations but may be less efficient
- Fused QKV: Use when `num_heads == num_key_value_heads` (no grouped query attention) and `fuse_qkv=True`
- GQA QKV: Required when using grouped query attention with challenging tensor parallel configurations where KV heads cannot be evenly distributed across TP ranks
The choice is typically determined by:
tp_size = get_tensor_model_parallel_size()
use_gqa_qkv = (num_key_value_heads < tp_size) or (num_key_value_heads % tp_size != 0)
use_fused_qkv = trn_config.fuse_qkv and (num_heads == num_key_value_heads) and not use_gqa_qkv
Step 3: Implement Main Model Classes
Base Model
class YourPreTrainedModel(PreTrainedModel):
config_class = YourModelConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["YourModelDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
class YourModel(NeuronModelMixin, YourPreTrainedModel):
def __init__(self, config: YourModelConfig, trn_config: TrainingNeuronConfig):
YourPreTrainedModel.__init__(self, config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.trn_config = trn_config
self.embed_tokens = ParallelEmbedding(...)
self.layers = nn.ModuleList([
YourModelDecoderLayer(config, trn_config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = YourModelRMSNorm(...)
self.post_init()
CausalLM Model
class YourModelForCausalLM(NeuronModelMixin, YourPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
# Pipeline parallelism support
SUPPORTS_PIPELINE_PARALLELISM = True
PIPELINE_TRANSFORMER_LAYER_CLS = YourModelDecoderLayer
PIPELINE_INPUT_NAMES = ["input_ids", "attention_mask"]
def __init__(self, config, trn_config):
super().__init__(config)
self.trn_config = trn_config
self.model = YourModel(config, trn_config)
self.vocab_size = config.vocab_size
self.lm_head = ColumnParallelLinear(
config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
dtype=config.torch_dtype,
)
self.post_init()
Step 4: Register Model
Update optimum/neuron/models/training/__init__.py
:
from .your_model import YourModelForCausalLM, YourModel
__all__ = [..., "YourModelForCausalLM", "YourModel"]
Update optimum/neuron/models/training/auto_models.py
:
from .your_model.modeling_your_model import YourModelForCausalLM, YourModel
# Register the base model (without head)
register_neuron_model_for_training("your_model", "model")(YourModel)
# Register the CausalLM model
register_neuron_model_for_training("your_model", "text-generation")(YourModelForCausalLM)
Here "your_model"
is the corresponding to the model_type
attribute of your model’s configuration class.
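If your model's configuration class comes from Transformers (as imported in Step 2), `model_type` is already defined there. The hypothetical snippet below only illustrates the attribute that the registration string must match:

```python
from transformers import PretrainedConfig

class YourModelConfig(PretrainedConfig):
    # Must match the string passed to register_neuron_model_for_training
    model_type = "your_model"
```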
Best Practices
1. Parallel Layer Configuration
- Use `gather_output=False` for intermediate layers
- Set `input_is_parallel=True` for layers that receive parallel input
- Configure `sequence_parallel_enabled` consistently across layers
- Use appropriate `stride` values for proper weight sharding
2. Weight Transformation Specs
- Always define specs for modules that use fused or parallel layers
- Use the `CustomModule` mixin for any module with transformation specs
- Ensure spec parameter names match the actual module structure
- Test both regular and LoRA weight transformations
3. Pipeline Parallelism
- Set `SUPPORTS_PIPELINE_PARALLELISM = True` for supported models
- Define `PIPELINE_TRANSFORMER_LAYER_CLS` as your decoder layer class
- List all input names in `PIPELINE_INPUT_NAMES`
4. Flash Attention Support
- Set `_supports_flash_attn_2 = True` if your model supports it
- Implement both eager and flash attention paths
- Use appropriate attention function dispatching (a sketch follows below)
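A minimal sketch of such dispatching is shown below. The `flash_attention` helper is a placeholder for whatever fused kernel you use on Trainium, not a library function; the dispatch key mirrors the standard Transformers `_attn_implementation` values.

```python
import math
import torch

def eager_attention(q, k, v, mask=None):
    # Explicit softmax(QK^T / sqrt(d)) V with an optional additive mask
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores + mask
    return torch.softmax(scores, dim=-1) @ v

def flash_attention(q, k, v):
    # Stand-in: replace with your fused flash-attention kernel
    return eager_attention(q, k, v)

def dispatch_attention(attn_implementation, q, k, v, mask=None):
    if attn_implementation == "flash_attention_2":
        return flash_attention(q, k, v)
    return eager_attention(q, k, v, mask)
```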
Testing Your Implementation
The training tests in `tests/training/` provide a comprehensive testing framework that validates numerical correctness, distributed training scenarios, and checkpoint compatibility.
Most of the tests are not designed to be run on every custom modeling implementation, but rather to validate the core functionality of the Optimum Neuron training infrastructure.
With that in mind, here’s what you need to implement for your custom modeling:
1. Custom Modeling Validation
The `test_custom_modeling.py` file validates that your custom implementation produces identical outputs to the original Transformers model.
Update `tests/training/test_custom_modeling.py`:
CUSTOM_MODELINGS_TO_TEST = [
# ... existing models ...
("YourModelForCausalLM", "your-org/your-model-name"),
]
Important: For custom modeling validation tests, use small/tiny models to ensure CI efficiency. The test models should have:
- Small vocabulary size (e.g., 1000-8000 tokens)
- Few layers (e.g., 2-4 layers)
- Small hidden dimensions (e.g., 128-512)
- Minimal attention heads (e.g., 4-8 heads)
Examples of good test models for custom modeling validation:
"michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random"
- 4 layers, 4 KV heads"michaelbenayoun/granite-tiny-4kv-heads-4layers-random"
- Tiny Granite model"michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random"
- Tiny Qwen3 model
Key test your model must pass:
def test_custom_modeling_matches_original() # Output matching
- Numerical Correctness: Ensures custom models match Transformers outputs exactly
- Parallelization Support: Tests various QKV implementations (regular, fused, GQA)
2. End-to-End Training Validation
The `test_overfit.py` file validates training convergence. To include your model in end-to-end training validation, you must add it to the parametrized test cases.

Update `tests/training/test_overfit.py`:
@pytest.mark.parametrize(
"model_class_name,model_name_or_path,learning_rate,warmup_ratio,training_kwargs,use_flash_attention_2,max_expected_loss,max_length,num_steps",
[
# ... existing models ...
[
"YourModelForCausalLM",
"your-org/your-model-name",
1e-4,
0.03,
{},
True,
0.5,
2048,
50,
],
],
ids=[
# ... existing model IDs ...
"your-org/your-model-name",
],
)
This validates:
- Convergence Validation: Ensures models can overfit simple datasets
Your model will be tested in:
def test_overfit_custom_modeling_causal_lm() # Basic training (your model included)
3. Auto Model Loading
The `test_modeling_auto.py` file validates that your model can be loaded using the `NeuronModel` and `NeuronModelForCausalLM` auto classes. To include your model in these tests, you must add it to the test cases.

Update `tests/training/test_modeling_auto.py`:
@pytest.mark.parametrize("from_pretrained", [False, True], ids=["from_config", "from_pretrained"])
@distributed_test(world_size=1)
@is_trainium_test
def test_auto_model_with_supported_architecture(from_pretrained):
trn_config = TrainingNeuronConfig()
kwargs = {"torch_dtype": torch.bfloat16}
for model_name_or_path in [
"michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
"michaelbenayoun/granite-tiny-4kv-heads-4layers-random",
"michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random",
"your-org/your-model-name", # Add your model here
]:
# ... rest of test logic
@pytest.mark.parametrize("from_pretrained", [False, True], ids=["from_config", "from_pretrained"])
@distributed_test(world_size=1)
@is_trainium_test
def test_auto_model_for_causal_lm_with_supported_architecture(from_pretrained):
trn_config = TrainingNeuronConfig()
kwargs = {"torch_dtype": torch.bfloat16}
for model_name_or_path in [
"michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
"michaelbenayoun/granite-tiny-4kv-heads-4layers-random",
"michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random",
"your-org/your-model-name", # Add your model here
]:
# ... rest of test logic
This validates:
- Auto Model Loading: Tests that `NeuronModel.from_pretrained()` and `NeuronModel.from_config()` work correctly
- Auto CausalLM Loading: Tests that `NeuronModelForCausalLM.from_pretrained()` and `NeuronModelForCausalLM.from_config()` work correctly
4. Running Tests
Tests require AWS Trainium instances. Run specific test categories:
# Run all custom modeling tests
pytest tests/training/test_custom_modeling.py -v
# Run specific model tests
pytest tests/training/test_custom_modeling.py -v -k "your_model"
# Run end-to-end training validation
pytest tests/training/test_overfit.py -v
5. Test Requirements
Your implementation must:
- Pass numerical correctness tests against original Transformers implementation
- Support parallelization strategies (at minimum DP and TP; PP support recommended)
- Handle various QKV implementations (regular, fused, GQA)
- Support checkpoint consolidation for distributed training
- Support LoRA training if applicable
- Demonstrate convergence through overfitting tests
The testing framework ensures your custom model maintains compatibility with the existing Optimum Neuron training infrastructure while delivering expected performance and correctness guarantees.
Common Issues
- Weight Shape Mismatches: Ensure transformation specs handle tensor shapes correctly
- Pipeline Parallelism Errors: Check that all required attributes are set
- Memory Issues: Consider gradient checkpointing and activation recomputation (see the snippet after this list)
- Attention Compatibility: Verify attention implementations work with your model architecture
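For the memory issues above, one option is to enable gradient checkpointing through the standard Transformers API, assuming your model keeps `supports_gradient_checkpointing = True` as configured in Step 3:

```python
# Recompute activations during the backward pass instead of storing them
model.gradient_checkpointing_enable()
```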
Additional Resources
This guide provides the foundation for implementing custom models. For complete examples and advanced patterns, reference these existing implementations:
- LLaMA: `optimum/neuron/models/training/llama/modeling_llama.py` - Complete implementation covering all attention variants (regular, fused, and GQA) and a fused MLP
- Qwen3: `optimum/neuron/models/training/qwen3/modeling_qwen3.py` - Demonstrates how to adapt the LLaMA implementation for Qwen3 with `q_norm` and `k_norm` layers
Key files to study:
- `optimum/neuron/models/training/modeling_utils.py` - Base `NeuronModelMixin` class
- `optimum/neuron/models/training/transformations_utils.py` - Weight transformation specifications
- `optimum/neuron/models/training/config.py` - `TrainingNeuronConfig` for parallelism settings