AWS Trainium & Inferentia documentation

Supported architectures

Training

Training on AWS Trainium instances (Trn1) enables large-scale model training with distributed parallelism strategies.

Requirements:

  • The model must be compatible with the Neuron SDK. If it is small enough to fit within 16GB, training is supported for any architecture that can be successfully compiled.
  • Memory constraint: Each accelerator has 16GB of memory for model weights, gradients, optimizer states, and activations.
  • For large models: Custom modeling implementation with tensor parallelism and/or pipeline parallelism support is required.
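To make the memory constraint concrete, here is a rough back-of-the-envelope sketch. It is not part of Optimum Neuron; the function names are hypothetical, and the byte counts are assumptions (bf16 weights and gradients at 2 bytes each, two fp32 Adam moments at 8 bytes per parameter, and "16GB" read as decimal gigabytes). Activation memory depends on batch size and sequence length, so it is passed in separately and defaults to zero.

```python
# Rough estimate of per-accelerator training memory.
# All byte counts below are assumptions for illustration, not Neuron SDK facts.

ACCELERATOR_MEMORY_BYTES = 16 * 10**9  # "16GB", read as decimal gigabytes (assumption)


def training_state_bytes(num_params, weight_bytes=2, grad_bytes=2, optimizer_bytes=8):
    """Bytes needed for weights, gradients, and optimizer states.

    Defaults assume bf16 weights and gradients (2 bytes each) and an
    Adam-style optimizer keeping two fp32 moments (8 bytes) per parameter.
    """
    return num_params * (weight_bytes + grad_bytes + optimizer_bytes)


def fits_on_one_accelerator(num_params, activation_bytes=0):
    """Return True if the estimated training state fits within 16GB."""
    total = training_state_bytes(num_params) + activation_bytes
    return total <= ACCELERATOR_MEMORY_BYTES


if __name__ == "__main__":
    for n in (1_000_000_000, 8_000_000_000):
        gb = training_state_bytes(n) / 10**9
        print(f"{n:,} params: ~{gb:.0f} GB of state, fits: {fits_on_one_accelerator(n)}")
```

Under these assumptions, a model with 1 billion parameters needs about 12 GB before activations, while an 8 billion parameter model needs about 96 GB, which is why larger models need an implementation with tensor and/or pipeline parallelism.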

The following architectures have custom modeling implementations with distributed training support:

Architecture Task Tensor Parallelism Pipeline Parallelism
Llama, Llama 2, Llama 3 text-generation
Qwen3 text-generation
Granite text-generation

If you need to add support for a custom model not listed above, see our guide on contributing for training to learn how to implement custom modeling with distributed training support. You can also open an issue in the Optimum Neuron GitHub repository to request support for it.

Inference

The following table lists the architectures and tasks that Optimum Neuron supports for inference on Amazon EC2 Inf2 instances.

If an LLM (that is, a model with a text-generation task) is listed, it is also supported by TGI.

Transformers

Architecture Task
ALBERT feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
AST feature-extraction, audio-classification
BERT feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
Beit feature-extraction, image-classification
CamemBERT feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
CLIP feature-extraction, image-classification
ConvBERT feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
ConvNext feature-extraction, image-classification
ConvNextV2 feature-extraction, image-classification
CvT feature-extraction, image-classification
DeBERTa (INF2 only) feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
DeBERTa-v2 (INF2 only) feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
Deit feature-extraction, image-classification
DistilBERT feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
DonutSwin feature-extraction
Dpt feature-extraction
ELECTRA feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
ESM feature-extraction, fill-mask, text-classification, token-classification
FlauBERT feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
Granite text-generation
Hubert feature-extraction, automatic-speech-recognition, audio-classification
Levit feature-extraction, image-classification
Llama, Llama 2, Llama 3 text-generation
Mixtral text-generation
MobileBERT feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
MobileNetV2 feature-extraction, image-classification, semantic-segmentation
MobileViT feature-extraction, image-classification, semantic-segmentation
ModernBERT feature-extraction, fill-mask, text-classification, token-classification
MPNet feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
Phi3 text-generation
Phi feature-extraction, text-classification, token-classification
Qwen2, Qwen3 text-generation
RoBERTa feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
RoFormer feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
Swin feature-extraction, image-classification
T5 text2text-generation
UniSpeech feature-extraction, automatic-speech-recognition, audio-classification
UniSpeech-SAT feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector
ViT feature-extraction, image-classification
Wav2Vec2 feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector
WavLM feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector
Whisper automatic-speech-recognition
XLM feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
XLM-RoBERTa feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification
Yolos feature-extraction, object-detection

Diffusers

Architecture Task
Stable Diffusion text-to-image, image-to-image, inpaint
Stable Diffusion XL Base text-to-image, image-to-image, inpaint
Stable Diffusion XL Refiner image-to-image, inpaint
SDXL Turbo text-to-image, image-to-image, inpaint
LCM text-to-image
PixArt-α text-to-image
PixArt-Σ text-to-image
Flux text-to-image

Sentence Transformers

Architecture Task
Transformer feature-extraction, sentence-similarity
CLIP feature-extraction, zero-shot-image-classification

To learn how to export a model for inference, you can check this guide.
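
As an illustration of what an export can look like, here is a minimal sketch, assuming the `optimum-neuron` package is installed on an instance with the Neuron SDK. The helper name `export_to_neuron`, its default shape values, and the model and directory names in the usage note are illustrative choices, not taken from this page. The guide remains the reference for the full set of options.

```python
def export_to_neuron(model_id, output_dir, batch_size=1, sequence_length=128):
    """Compile a text-classification checkpoint for Neuron and save it.

    Hypothetical helper for illustration; it only runs where the Neuron SDK
    and optimum-neuron are installed.
    """
    # Imported lazily so that defining this function does not require
    # the Neuron SDK to be present.
    from optimum.neuron import NeuronModelForSequenceClassification

    # Neuron compilation uses static input shapes, so the batch size and
    # sequence length are fixed at export time.
    model = NeuronModelForSequenceClassification.from_pretrained(
        model_id,
        export=True,
        batch_size=batch_size,
        sequence_length=sequence_length,
    )
    # Save the compiled artifacts so they can be reloaded without recompiling.
    model.save_pretrained(output_dir)
    return model
```

On an Inf2 instance, a call such as `export_to_neuron("distilbert-base-uncased-finetuned-sst-2-english", "distilbert_neuron/")` would compile a DistilBERT text-classification model, one of the architecture and task pairs listed in the Transformers table.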