
Distributed Training with optimum-neuron

AWS Trainium instances provide powerful infrastructure for training large language models at scale. A trn1.32xlarge instance contains 16 Neuron devices with 32 cores total, offering 512GB of memory (16GB per core).

However, training large models presents a fundamental challenge: by default, each Neuron core operates as an independent data-parallel worker, requiring the entire model, gradients, and optimizer state (approximately 4× the model size) to fit within a single core’s 16GB memory limit, with additional space needed for activations.
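To make this constraint concrete, here is a rough back-of-the-envelope estimate for a 7B-parameter model (a sketch, assuming bf16 weights, gradients, and Adam moments, which is where the roughly 4× figure comes from; activations come on top):

# Rough per-core memory estimate for plain data parallelism (no sharding).
# Assumption: weights, gradients, and both Adam moments are stored in bf16 (2 bytes each).
num_params = 7e9           # example: a 7B-parameter model
bytes_per_param = 2        # bf16

weights_gb = num_params * bytes_per_param / 1e9
grads_gb = weights_gb
adam_moments_gb = 2 * weights_gb   # momentum + variance

total_gb = weights_gb + grads_gb + adam_moments_gb
print(f"~{total_gb:.0f} GB per core before activations (vs. a 16GB limit)")
# -> ~56 GB, far beyond a single core, hence the need for parallelism.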

For models that exceed these memory constraints, optimum-neuron provides sophisticated parallelism strategies that distribute computation and memory across multiple devices, enabling you to train models that would be impossible to fit on individual cores:

Parallelism Strategies Overview

1. ZeRO-1 (Optimizer State Sharding)

ZeRO-1 is an optimizer-level optimization that reduces memory usage without changing your model architecture.

How it works: Shards the optimizer state (e.g. Adam momentum and variance) across data-parallel ranks instead of replicating it on each device.

Memory savings: Reduces the per-device optimizer state memory to 1/data_parallel_size of its unsharded size.

When to use: Always beneficial when training with multiple devices, regardless of model size.
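Continuing the rough estimate above, a small sketch of how ZeRO-1 changes the per-core optimizer footprint once its state is sharded across data-parallel ranks (illustrative numbers only):

# Sketch: per-core optimizer-state memory with and without ZeRO-1.
# Same assumptions as above: 7B parameters, bf16 Adam moments (2 bytes each).
num_params = 7e9
adam_moments_gb = num_params * 2 * 2 / 1e9   # momentum + variance

data_parallel_size = 32   # e.g. all 32 cores of a trn1.32xlarge acting as data-parallel workers
print(f"without ZeRO-1: {adam_moments_gb:.0f} GB per core")
print(f"with ZeRO-1:    {adam_moments_gb / data_parallel_size:.2f} GB per core")
# Weights and gradients are unchanged; only the optimizer state is sharded.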

2. Tensor Parallelism (Intra-layer Model Parallelism)

Tensor Parallelism splits individual model layers across multiple devices.

How it works: Shards matrix multiplications (linear layers, attention) along rows or columns across devices. Each device computes part of each layer, requiring communication between devices for each forward/backward pass.

Memory savings: Reduces the per-device model parameter memory to 1/tensor_parallel_size of the full model.

When to use: When your model is too large to fit on a single device, even after applying ZeRO-1.

Typical deployment: Usually applied within a single node (intra-node) due to high communication requirements.

Trade-offs: Increases communication overhead between devices, which can slow down training if overused.
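The sketch below illustrates the idea with plain PyTorch on CPU: a linear layer's weight is split column-wise across two hypothetical devices, each device computes its slice, and the partial outputs are gathered back. This is a conceptual illustration only; the custom modeling implementations in optimum-neuron and the Neuron collectives handle this for you.

import torch

# Conceptual column-parallel linear layer: y = x @ W, with W split along its output dimension.
torch.manual_seed(0)
x = torch.randn(4, 16)                        # (batch, hidden)
W = torch.randn(16, 32)                       # full weight matrix

W_shard_0, W_shard_1 = W.chunk(2, dim=1)      # each "device" holds half of the output columns

y_shard_0 = x @ W_shard_0                     # computed on device 0
y_shard_1 = x @ W_shard_1                     # computed on device 1
y = torch.cat([y_shard_0, y_shard_1], dim=1)  # all-gather of the partial outputs

assert torch.allclose(y, x @ W, atol=1e-5)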

3. Sequence Parallelism (Activation Sharding)

Sequence parallelism is an optimization that works alongside Tensor Parallelism to further reduce memory usage.

How it works: Shards activations along the sequence dimension in regions where tensors are not already sharded by tensor parallelism.

Memory savings: Reduces activation memory proportional to sequence length, especially beneficial for long sequences.

When to use: Always enable it when using tensor parallelism; it provides additional memory savings with minimal overhead.

Requirement: Only works in combination with tensor parallelism.
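Conceptually, each rank only keeps seq_len / tensor_parallel_size of the tokens in those regions, as in this small plain-PyTorch illustration (not the actual Neuron implementation):

import torch

# Conceptual activation sharding along the sequence axis.
batch, seq_len, hidden = 1, 4096, 4096
activations = torch.randn(batch, seq_len, hidden)

tp_size = 4
shards = activations.chunk(tp_size, dim=1)   # each rank keeps seq_len / tp_size tokens
print(shards[0].shape)                       # torch.Size([1, 1024, 4096])
# Per-rank activation memory in these regions drops by a factor of tp_size.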

4. Pipeline Parallelism (Inter-layer Model Parallelism)

Pipeline Parallelism splits the model layer-wise, assigning different groups of consecutive layers to different devices.

How it works: Divides your model into stages, with each stage containing consecutive layers running on different devices. Uses microbatching to keep all devices busy.

Memory savings: Reduces the per-device model parameter memory to 1/pipeline_parallel_size of the full model.

When to use: For very large models that don’t fit even with tensor parallelism, or when you want to scale across many devices with less communication overhead than tensor parallelism.

Typical deployment: Usually applied across multiple nodes (inter-node) to scale to larger numbers of devices while minimizing high-bandwidth communication requirements.

Trade-offs: Introduces pipeline bubbles (idle time) and requires careful tuning of microbatch sizes.
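The microbatching idea can be illustrated with a toy two-stage model in plain PyTorch: the per-device batch is split into microbatches that flow through the stages one after another, so that in a real pipeline the stages can overlap work. This is only a sketch of the scheduling idea, not the actual pipeline engine used by optimum-neuron.

import torch

# Toy illustration of microbatching for a 2-stage pipeline.
global_batch = torch.randn(8, 16)                 # per-device batch of 8 samples
num_microbatches = 4
microbatches = global_batch.chunk(num_microbatches, dim=0)

stage_0 = torch.nn.Linear(16, 16)                 # layers placed on device 0
stage_1 = torch.nn.Linear(16, 2)                  # layers placed on device 1

outputs = []
for mb in microbatches:
    hidden = stage_0(mb)                          # stage 0 can start the next microbatch
    outputs.append(stage_1(hidden))               # while stage 1 processes this one
loss = torch.cat(outputs).sum()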

The good news is that these techniques can be combined, and optimum-neuron makes it very easy!

All the training examples in the optimum-neuron repo use these parallelism features via the NeuronTrainer.

How to enable ZeRO-1?

ZeRO-1 can be enabled either through the NeuronTrainer or directly with the NeuronAccelerator.

Via the NeuronTrainer

from optimum.neuron import NeuronTrainingArguments, NeuronTrainer

# Enable ZeRO-1 in the training arguments
training_args = NeuronTrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=1,
    zero_1=True,  # Enable ZeRO-1
    bf16=True,
    # ... other training arguments
)

trainer = NeuronTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()

Since the example scripts use the NeuronTrainer, you can enable ZeRO-1 when using them by adding the --zero_1 flag to your command line.

For example:

torchrun --nproc_per_node=2 examples/training/qwen3/finetune_qwen3.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --per_device_train_batch_size 1 \
    --block_size 1024 \
    --bf16 \
    --zero_1 \
    --tensor_parallel_size 2 \
    --output_dir my_training/

Via the NeuronAccelerator

When using the NeuronAccelerator directly, you need to create a TrainingNeuronConfig and enable ZeRO-1 separately:

from torch.optim import AdamW
from optimum.neuron import NeuronAccelerator
from optimum.neuron.models.training.config import TrainingNeuronConfig

# Create the training configuration
trn_config = TrainingNeuronConfig()

# Create accelerator with ZeRO-1 enabled
accelerator = NeuronAccelerator(
    trn_config=trn_config,
    zero_1=True,  # Enable ZeRO-1
    mixed_precision="bf16",
)

model = ...  # Your model instance
optimizer = AdamW(model.parameters(), lr=5e-5)

# Prepare model and optimizer
model, optimizer = accelerator.prepare(model, optimizer)

How to enable Tensor Parallelism?

Tensor Parallelism can be used with either the NeuronTrainer or NeuronAccelerator.

Important: Tensor parallelism requires models that have a custom modeling implementation in optimum.neuron.models.training.
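In practice this means loading the model from the custom model class rather than from transformers, for example (a sketch following the same pattern as the pipeline parallelism examples below; the trn_config is generated automatically from the training arguments):

from optimum.neuron import NeuronTrainingArguments
from optimum.neuron.models.training import LlamaForCausalLM  # custom modeling implementation

training_args = NeuronTrainingArguments(
    output_dir="./output",
    bf16=True,
    tensor_parallel_size=8,
)

# The custom model class must be used directly so that its layers can be sharded.
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    trn_config=training_args.trn_config,
)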

When doing Tensor Parallelism, you have several important settings:

  1. The tensor_parallel_size: Ideally it should be the smallest value for which the model fits in memory.
  2. Whether sequence parallelism should be enabled: Sequence parallelism shards the activations along the sequence axis outside of the tensor-parallel regions, which saves additional memory.

When using distributed training, the training script is launched by torchrun, which spawns one worker per core. Each worker loads its shard of the model, and the parameters are dispatched automatically across the cores. The tensor_parallel_size is the number of workers across which the model parameters are sharded.

Via the NeuronTrainer

from optimum.neuron import NeuronTrainingArguments, NeuronTrainer

# Configure tensor parallelism in training arguments
training_args = NeuronTrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=1,
    bf16=True,
    tensor_parallel_size=8,
    # ... other training arguments
)

trainer = NeuronTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()

Since the example scripts use the NeuronTrainer, you can enable Tensor Parallelism when using them by specifying the --tensor_parallel_size argument.

For example:

torchrun --nproc_per_node=8 examples/training/qwen3/finetune_qwen3.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --per_device_train_batch_size 1 \
    --block_size 1024 \
    --bf16 \
    --tensor_parallel_size 8 \
    --output_dir my_training/

Via the NeuronAccelerator

When using the NeuronAccelerator directly, you configure tensor parallelism through the TrainingNeuronConfig:

from torch.optim import AdamW
from optimum.neuron import NeuronAccelerator
from optimum.neuron.models.training.config import TrainingNeuronConfig

# Configure tensor parallelism
trn_config = TrainingNeuronConfig(
    tensor_parallel_size=8,
    sequence_parallel_enabled=True,
    checkpoint_dir=None,  # Can be specified when resuming from checkpoint
)

accelerator = NeuronAccelerator(
    trn_config=trn_config,
    mixed_precision="bf16",
)

model = ...  # Your model instance
optimizer = AdamW(model.parameters(), lr=5e-5)

model, optimizer = accelerator.prepare(model, optimizer)

How to enable Pipeline Parallelism?

Pipeline Parallelism allows you to split your model layers across multiple devices, enabling training of very large models that would not fit on a single device, or even a single node.

Important: Pipeline parallelism requires models that have a custom modeling implementation in optimum.neuron.models.training and declare SUPPORTS_PIPELINE_PARALLELISM = True.
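If you are unsure whether a given modeling class supports it, you can check the class attribute mentioned above (a small sketch):

from optimum.neuron.models.training import LlamaForCausalLM

# Custom modeling classes that support pipeline parallelism declare this class attribute.
print(getattr(LlamaForCausalLM, "SUPPORTS_PIPELINE_PARALLELISM", False))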

Configuration Options

Pipeline parallelism has several configuration parameters:

  • pipeline_parallel_size: Number of pipeline stages (devices to split layers across)
  • pipeline_parallel_num_microbatches: Number of microbatches for pipeline scheduling
  • When pipeline parallelism is enabled, ZeRO-1 can be automatically applied to the pipeline parallel optimizer

Via the NeuronTrainer

from optimum.neuron import NeuronTrainingArguments, NeuronTrainer
from optimum.neuron.models.training import LlamaForCausalLM  # Custom model implementation

# Configure pipeline parallelism in training arguments
training_args = NeuronTrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,  # Will be split into microbatches
    bf16=True,
    tensor_parallel_size=2,
    pipeline_parallel_size=4,                    # Split model across 4 pipeline stages
    pipeline_parallel_num_microbatches=4,        # Number of microbatches
    zero_1=True,                                 # Enable ZeRO-1 with pipeline parallelism
    # ... other training arguments
)

# Load model using custom implementation - must be done with the model class directly
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    trn_config=training_args.trn_config  # Pass the auto-generated trn_config
)

trainer = NeuronTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()

Via the NeuronAccelerator

from optimum.neuron import NeuronAccelerator
from optimum.neuron.models.training.config import TrainingNeuronConfig
from optimum.neuron.models.training import LlamaForCausalLM
from torch.optim import AdamW

# Configure combined parallelism strategies
trn_config = TrainingNeuronConfig(
    tensor_parallel_size=2,
    pipeline_parallel_size=4,
    pipeline_parallel_num_microbatches=4,
    sequence_parallel_enabled=True,
)

accelerator = NeuronAccelerator(
    trn_config=trn_config,
    zero_1=True,  # Can combine with ZeRO-1
    mixed_precision="bf16",
)

# Load model with custom implementation
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    trn_config=trn_config
)

optimizer = AdamW(model.parameters(), lr=5e-5)
model, optimizer = accelerator.prepare(model, optimizer)

When using pipeline parallelism, the total number of processes should be at least tensor_parallel_size * pipeline_parallel_size. For example, with tensor_parallel_size=2 and pipeline_parallel_size=4, you need 8 processes total.
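A small sketch of that process arithmetic, useful when choosing --nproc_per_node (plain Python, assuming the world size is a multiple of tensor_parallel_size * pipeline_parallel_size):

# Sketch: relation between the world size and the parallelism degrees.
tensor_parallel_size = 2
pipeline_parallel_size = 4
world_size = 32   # e.g. --nproc_per_node=32 on a trn1.32xlarge

model_parallel_size = tensor_parallel_size * pipeline_parallel_size
assert world_size % model_parallel_size == 0, "world size must be a multiple of tp * pp"

data_parallel_size = world_size // model_parallel_size
print(f"tp={tensor_parallel_size}, pp={pipeline_parallel_size}, dp={data_parallel_size}")
# -> tp=2, pp=4, dp=4: four model replicas, each spanning 8 cores.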

Combining Parallelism Strategies

You can combine multiple parallelism strategies for maximum memory efficiency and performance. Here’s an example with all strategies combined:

Via the NeuronTrainer

from optimum.neuron import NeuronTrainingArguments, NeuronTrainer
from optimum.neuron.models.training import LlamaForCausalLM

# Example: Combine all parallelism strategies
training_args = NeuronTrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=32,
    bf16=True,
    gradient_checkpointing=True,
    
    # ZeRO-1
    zero_1=True,
    
    # Tensor parallelism
    tensor_parallel_size=4,
    disable_sequence_parallel=False,     # Enable sequence parallelism
    
    # Pipeline parallelism
    pipeline_parallel_size=2,
    pipeline_parallel_num_microbatches=8,
    
    # Additional optimizations
    fuse_qkv=True,                      # Fuse QKV projections for efficiency
    kv_size_multiplier=None,            # Auto-calculate optimal KV multiplier
)

# Load model using custom implementation
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    trn_config=training_args.trn_config
)

trainer = NeuronTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()

This configuration requires 4 * 2 = 8 processes per model replica:

  • Each tensor parallel group has 4 processes
  • Each pipeline stage runs on one tensor parallel group

We can then run the training script on the trn1.32xlarge instance with 32 Neuron cores, resulting in the following configuration: dp=4, tp=4, pp=2, which means 4 data-parallel groups, each with 4 tensor-parallel devices, and 2 pipeline stages.

Checkpoint consolidation

Since distributed training uses sharded checkpoints across different workers, you need to consolidate them to create a standard model checkpoint that can be shared and used outside of the specific training configuration.

The Optimum CLI provides an easy way to do that via the optimum-cli neuron consolidate command:

optimum-cli neuron consolidate --help

usage: optimum-cli neuron consolidate [-h] [-f {pytorch,safetensors}] checkpoint_dir output_dir

positional arguments:
  checkpoint_dir        The path to the directory containing the checkpoints.
  output_dir            The path to the output directory containing the consolidated checkpoint.

optional arguments:
  -h, --help            show this help message and exit
  -f {pytorch,safetensors}, --format {pytorch,safetensors}
                        The format used to save the consolidated checkpoint.

All you need to do is specify the sharded checkpoint directory and the output directory that will contain the consolidated checkpoint; the command takes care of the rest. You can also choose the output format; by default the checkpoint is exported to the safetensors format, which is the recommended format.

Example:

Suppose a distributed training run has just completed and its output directory is called my_training. The directory looks like the following:

my_training/
β”œβ”€β”€ README.md
β”œβ”€β”€ all_results.json 
β”œβ”€β”€ checkpoint-10 
β”‚   β”œβ”€β”€ config.json
β”‚   β”œβ”€β”€ scheduler.pt
β”‚   β”œβ”€β”€ special_tokens_map.json
β”‚   β”œβ”€β”€ shards/
β”‚   β”œβ”€β”€ tokenizer.json
β”‚   β”œβ”€β”€ tokenizer.model
β”‚   β”œβ”€β”€ tokenizer_config.json
β”‚   β”œβ”€β”€ trainer_state.json
β”‚   └── training_args.bin
β”œβ”€β”€ config.json
β”œβ”€β”€ special_tokens_map.json
β”œβ”€β”€ shards/
β”‚   β”œβ”€β”€ tp_rank_00_pp_rank_00
β”‚   β”œβ”€β”€ tp_rank_01_pp_rank_00
β”‚   β”œβ”€β”€ tp_rank_02_pp_rank_00
β”‚   β”œβ”€β”€ tp_rank_03_pp_rank_00
β”‚   β”œβ”€β”€ tp_rank_00_pp_rank_01
β”‚   β”œβ”€β”€ tp_rank_01_pp_rank_01
β”‚   β”œβ”€β”€ tp_rank_02_pp_rank_01
β”‚   └── tp_rank_03_pp_rank_01
β”œβ”€β”€ tokenizer.json
β”œβ”€β”€ tokenizer.model
β”œβ”€β”€ tokenizer_config.json
β”œβ”€β”€ train_results.json
β”œβ”€β”€ trainer_state.json
β”œβ”€β”€ training_args.bin
└── trn_config.json

You can consolidate the sharded checkpoints in my_training/shards, which correspond to the sharded checkpoints saved at the end of training, by running the following command:

optimum-cli neuron consolidate my_training my_training_consolidated_checkpoint

The sharded checkpoints are saved under a directory called shards. The optimum-cli neuron consolidate command accepts as input either a directory that contains a shards directory or the shards directory itself.
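Once consolidated, the result is a regular checkpoint that can be loaded with the standard Transformers API, for example (a sketch, assuming the default safetensors format; copy the config and tokenizer files from my_training/ if they are not present in the output directory):

from transformers import AutoModelForCausalLM

# Load the consolidated checkpoint like any other Transformers checkpoint.
model = AutoModelForCausalLM.from_pretrained("my_training_consolidated_checkpoint")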

Best Practices

Choosing Parallelism Strategy

  1. Start with Tensor Parallelism: Use the smallest tensor_parallel_size that fits your model in memory
  2. Add Pipeline Parallelism: For very large models, combine with pipeline parallelism
  3. Enable Sequence Parallelism: Always enable when using tensor parallelism for memory savings (set disable_sequence_parallel=False)
  4. Use ZeRO-1: Combine with any parallelism strategy for optimizer memory savings

Memory Optimization

  • Enable gradient_checkpointing for large models
  • Set appropriate pipeline_parallel_num_microbatches for pipeline parallelism

Troubleshooting

Common Issues

  1. Out of Memory: Reduce batch size, increase parallelism, or enable gradient checkpointing
  2. Model Not Supported: Ensure you’re using a model from optimum.neuron.models.training
  3. Pipeline Parallelism Fails: Check that the model supports pipeline parallelism
  4. Incorrect Process Count: Ensure nproc_per_node matches your parallelism configuration

Debugging Tips

  • Start with smaller models and parallelism sizes
  • Check that all processes can communicate properly
  • Verify checkpoint directories and permissions
  • Monitor Neuron device utilization