Model Weight Transformation Specs
The transformation specs API defines how model weights are transformed between the original Transformers implementation and the custom implementation optimized for Neuron devices. This enables automatic weight conversion during model loading and checkpoint consolidation.
Base Classes
ModelWeightTransformationSpec
This class defines the interface for transforming model weights between the original Transformers implementation and the custom implementation for Neuron.
Adapts the PEFT config to match the custom modeling implementation.
adapt_state_dict(module_fully_qualified_name: str, named_parameters: dict[str, torch.nn.parameter.Parameter], orig_state_dict: dict[str, torch.Tensor], upstanding_sharded_params: dict[str, torch.Tensor], inplace: bool = False)
Transforms the state dict from the original Transformers model to match the custom modeling implementation.
Returns the set of parameter names that this spec would affect.
Guesses the PEFT type of the module associated with the spec.
Restores the PEFT config to the original one that matches the original Transformers implementation.
to_original_weights(module_fully_qualified_name: str, sharded_state_dicts: dict[str, list[torch.Tensor]], parameters_metadata: dict[str, dict[str, typing.Any]]) → tuple[dict[str, torch.Tensor], list[str]]
Parameters
- sharded_state_dicts (dict[str, list[torch.Tensor]]) — The sharded state dicts from the custom modeling implementation.
- parameters_metadata (dict[str, dict[str, Any]]) — Metadata about the parameters in the original model.
Returns
tuple[dict[str, torch.Tensor], list[str]]
A tuple containing the transformed weights and a list of the names of the parameters to remove from the final state dict.
Produces the weights associated with this transformation spec from the custom model so that they match the original Transformers weights.
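For a fused projection that is also sharded across tensor parallel (TP) ranks, the reconstruction roughly means splitting each rank's shard back into its per-linear pieces and concatenating those pieces across ranks. The sketch below is a plain-Python illustration, not the library's implementation: it uses lists of rows instead of `torch.Tensor`, assumes each rank stacks its local slice of every linear along rows in `linear_names` order, and the helper name `unfuse_sharded_fused_weight` is hypothetical.

```python
# Hypothetical helper: consolidates a column-fused weight sharded across TP ranks.
# Matrices are plain lists of rows standing in for torch.Tensor.
def unfuse_sharded_fused_weight(fused_name, linear_names, original_dims, rank_shards):
    """Rebuild each original linear weight from the per-rank fused shards.

    Assumes every rank holds, stacked along rows, its local slice of each
    individual linear, in the order given by `linear_names`.
    """
    tp_size = len(rank_shards)
    weights = {name: [] for name in linear_names}
    for shard in rank_shards:
        offset = 0
        for name, dim in zip(linear_names, original_dims):
            local_rows = dim // tp_size  # rows of this linear held by one rank
            weights[name].extend(shard[offset:offset + local_rows])
            offset += local_rows
    # The fused parameter does not exist in the original model, so list it for removal.
    return weights, [fused_name]


# Two ranks; q has 4 output rows, k and v have 2 each.
shards = [
    [["q0"], ["q1"], ["k0"], ["v0"]],  # rank 0
    [["q2"], ["q3"], ["k1"], ["v1"]],  # rank 1
]
weights, to_remove = unfuse_sharded_fused_weight(
    "qkv_proj", ["q_proj", "k_proj", "v_proj"], [4, 2, 2], shards
)
print(weights["q_proj"], to_remove)
```

The returned pair mirrors the documented `(transformed weights, names to remove)` shape.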
ModelWeightTransformationSpecs
class optimum.neuron.models.training.ModelWeightTransformationSpecs(module_fully_qualified_name: str | None = None, specs: optimum.neuron.models.training.transformations_utils.ModelWeightTransformationSpec | list[optimum.neuron.models.training.transformations_utils.ModelWeightTransformationSpec] = <factory>)
Defines a list of transformation specs for a given module of the model.
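The `<factory>` default in the signature above is how a dataclass field declared with `default_factory` is rendered. The following is a minimal sketch of a container with that shape, assuming (as the `ModelWeightTransformationSpec | list[...]` annotation suggests) that a single spec is normalized into a list. `SpecSketch` and `SpecsContainerSketch` are hypothetical stand-ins, not the library's classes.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Union


@dataclass
class SpecSketch:
    """Hypothetical stand-in for a single transformation spec."""
    name: str


@dataclass
class SpecsContainerSketch:
    module_fully_qualified_name: Optional[str] = None
    # default_factory gives each instance its own empty list (rendered as <factory>).
    specs: Union[SpecSketch, List[SpecSketch]] = field(default_factory=list)

    def __post_init__(self):
        # Assumption: a lone spec is wrapped so callers can always iterate over specs.
        if not isinstance(self.specs, list):
            self.specs = [self.specs]


container = SpecsContainerSketch("model.layers.0.self_attn", SpecSketch("qkv"))
print(container.specs)
```

Using a factory rather than a literal `[]` default avoids sharing one mutable list across every instance.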
CustomModule
This class marks a module as a custom module. It is used to identify the modules whose weights need to be transformed when loading and saving the model.
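A marker class like this lets loading and saving code locate the modules to transform with a plain `isinstance` check while walking the model. The sketch assumes that pattern and replaces `torch.nn.Module` with a tiny tree of plain Python objects; `CustomModuleSketch`, `Node`, `FusedAttention`, and `find_custom_modules` are hypothetical names.

```python
class CustomModuleSketch:
    """Marker mixin: carries no behavior, it only identifies modules to transform."""


class Node:
    """Minimal stand-in for a module tree (child name -> child)."""

    def __init__(self, **children):
        self.children = children


class FusedAttention(CustomModuleSketch, Node):
    """A module whose weights differ from the original layout."""


def find_custom_modules(root, prefix=""):
    """Yield the fully qualified name of every marked module, depth-first."""
    for name, child in root.children.items():
        qualified = f"{prefix}.{name}" if prefix else name
        if isinstance(child, CustomModuleSketch):
            yield qualified
        yield from find_custom_modules(child, qualified)


model = Node(layers=Node(**{"0": Node(self_attn=FusedAttention(), mlp=Node())}))
print(list(find_custom_modules(model)))
```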
Transformation Specifications
FusedLinearsSpec
class optimum.neuron.models.training.FusedLinearsSpec(fused_linear_name: str, linear_names: list[str], bias: bool, fuse_axis: typing.Union[typing.Literal[0], typing.Literal[1], typing.Literal['column'], typing.Literal['row']], original_dims: list[int], tp_size: int = <factory>)
Represents a transformation where multiple linear layers are fused into a single linear layer. It can handle the case where the fused linear layer is sharded across multiple tensor parallel ranks.
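Fusing along the column axis stacks the output rows of each linear, while fusing along the row axis places their input columns side by side; `original_dims` is what makes the split back possible. The sketch below shows that round trip on lists of rows, without tensor parallelism. It assumes `'column'` maps to dim 0 and `'row'` to dim 1 (the usual column/row-parallel convention for `(out_features, in_features)` weights); `fuse` and `split` are hypothetical helpers.

```python
# Assumed mapping of fuse_axis values to weight dimensions.
AXIS = {"column": 0, "row": 1, 0: 0, 1: 1}


def fuse(weights, fuse_axis):
    """Fuse several weights (lists of rows) into one along the given axis."""
    if AXIS[fuse_axis] == 0:
        # Stack output rows: [W_1; W_2; ...]
        return [row for w in weights for row in w]
    # Concatenate input columns row by row: [W_1 | W_2 | ...]
    return [[x for w in weights for x in w[i]] for i in range(len(weights[0]))]


def split(fused, fuse_axis, original_dims):
    """Undo `fuse` using the size of each original weight along the fused axis."""
    axis = AXIS[fuse_axis]
    parts, start = [], 0
    for dim in original_dims:
        if axis == 0:
            parts.append(fused[start:start + dim])
        else:
            parts.append([row[start:start + dim] for row in fused])
        start += dim
    return parts


q, k = [[1, 2], [3, 4]], [[5, 6]]
fused = fuse([q, k], "column")
print(fused, split(fused, "column", [2, 1]))
```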
GQAQKVColumnParallelLinearSpec
class optimum.neuron.models.training.GQAQKVColumnParallelLinearSpec(gqa_qkv_projection_name: str, query_projection_name: str, key_projection_name: str, value_projection_name: str, output_projection_name: str, num_attention_heads: int, num_key_value_heads: int, kv_size_multiplier: int, q_output_size_per_partition: int, kv_output_size_per_partition: int, fuse_qkv: bool, bias: bool, tp_size: int = <factory>)
Represents the transformation of separate query, key, and value projections into a single GQAQKVColumnParallelLinear projection.
compute_query_indices_for_rank(tp_size: int, tp_rank: int, num_attention_heads: int, num_key_value_heads: int, kv_size_multiplier: int)
Computes the permutation for the query weight for a given TP rank.
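A permutation is needed because, once KV heads are replicated, a rank must hold exactly the query heads that attend to its KV head. The sketch below assumes KV heads are replicated by tiling (`kv0, kv1, kv0, kv1, ...`) and that `tp_size == num_key_value_heads * kv_size_multiplier`, so each rank owns one KV head copy. It returns head indices rather than weight-row indices (multiply by the head dimension for rows). The function name is hypothetical and the ordering is an assumption, not necessarily the library's.

```python
def query_head_indices_for_rank(tp_size, tp_rank, num_attention_heads,
                                num_key_value_heads, kv_size_multiplier):
    """Return the query head indices that rank `tp_rank` should own."""
    if num_key_value_heads * kv_size_multiplier != tp_size:
        raise ValueError("sketch assumes tp_size == num_key_value_heads * kv_size_multiplier")
    group_size = num_attention_heads // num_key_value_heads  # query heads sharing one KV head
    heads_per_rank = num_attention_heads // tp_size
    kv_head = tp_rank % num_key_value_heads   # tiled replication: kv0, kv1, kv0, kv1, ...
    replica = tp_rank // num_key_value_heads  # which copy of that KV head this rank holds
    start = kv_head * group_size + replica * heads_per_rank
    return list(range(start, start + heads_per_rank))


# 8 query heads, 2 KV heads replicated twice over 4 ranks.
print([query_head_indices_for_rank(4, r, 8, 2, 2) for r in range(4)])
```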
create_kv_proj_local_weight_from_regular_weight(weight_data: Tensor, kv_size_multiplier: int, output_size_per_partition: int)
Creates the local version of the key or value projections weight for the given TP rank.
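Replicating the key/value weight lets more ranks than there are KV heads share the projection. The sketch assumes the full weight is tiled `kv_size_multiplier` times along the output dimension and that each rank takes a contiguous block of `output_size_per_partition` rows. The documented function takes no rank argument, so the explicit `tp_rank` parameter and the helper name `kv_local_weight` are assumptions made to keep the example self-contained.

```python
def kv_local_weight(weight_rows, kv_size_multiplier, output_size_per_partition, tp_rank):
    """Return the rows of the replicated KV weight owned by `tp_rank`."""
    # Tile the rows so that every KV head appears kv_size_multiplier times.
    replicated = weight_rows * kv_size_multiplier
    start = tp_rank * output_size_per_partition
    return replicated[start:start + output_size_per_partition]


# Two KV heads with head_dim 1, replicated twice across 4 ranks.
k_proj = [["k0"], ["k1"]]
print([kv_local_weight(k_proj, 2, 1, r) for r in range(4)])
```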
create_query_or_output_projection_local_weight_from_regular_weight(weight_data: Tensor, num_attention_heads: int, num_key_value_heads: int, kv_size_multiplier: int, query_or_output_proj: typing.Union[typing.Literal['query'], typing.Literal['output']])
Creates the local version of the query or output projections weight for the given TP rank.
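Given the query heads assigned to a rank, the query weight is sliced by rows and the output weight by columns, since the output projection consumes the concatenated head outputs as its input. The sketch assumes `(out_features, in_features)` weights stored as lists of rows and takes the head indices as an input; `gather_heads` is a hypothetical helper.

```python
def gather_heads(weight_rows, head_indices, head_dim, query_or_output_proj):
    """Select the blocks of a query or output weight belonging to `head_indices`."""
    positions = [h * head_dim + i for h in head_indices for i in range(head_dim)]
    if query_or_output_proj == "query":
        # Query weight: (num_heads * head_dim, hidden), so pick row blocks.
        return [weight_rows[p] for p in positions]
    if query_or_output_proj == "output":
        # Output weight: (hidden, num_heads * head_dim), so pick column blocks.
        return [[row[p] for p in positions] for row in weight_rows]
    raise ValueError("query_or_output_proj must be 'query' or 'output'")


q_weight = [[i] for i in range(8)]  # 4 heads, head_dim 2, hidden size 1
print(gather_heads(q_weight, [1, 3], 2, "query"))
```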
Utility Functions
Weight Creation Functions
optimum.neuron.models.training.transformations_utils.create_local_weight_with_padding(full_weight: Tensor, partition_dim: int, stride: int, out_weight: torch.Tensor | None = None)
Shards a tensor along a given axis and returns the slice corresponding to the current rank. If the tensor needs padding, the layer is rounded up to the next multiple before sharding.
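Padding guarantees that every rank receives a slice of the same size even when the dimension does not divide evenly. The sketch covers only `partition_dim=0` and `stride=1`, pads with zeros up to the next multiple of the TP size (an assumption about which multiple is used), and passes `tp_rank`/`tp_size` explicitly; `local_weight_with_padding` is a hypothetical name.

```python
import math


def local_weight_with_padding(rows, tp_rank, tp_size, pad_value=0.0):
    """Zero-pad `rows` to a multiple of tp_size, then return this rank's slice."""
    per_rank = math.ceil(len(rows) / tp_size)
    width = len(rows[0])
    padding = [[pad_value] * width for _ in range(per_rank * tp_size - len(rows))]
    padded = rows + padding
    return padded[tp_rank * per_rank:(tp_rank + 1) * per_rank]


weight = [[1.0], [2.0], [3.0]]  # 3 rows cannot be split evenly over 2 ranks
print([local_weight_with_padding(weight, r, 2) for r in range(2)])
```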
optimum.neuron.models.training.transformations_utils.create_local_fused_weight(tp_rank, tp_size, individual_weights, partition_dim, fuse_axis, out_weight=None)
Shards individual weights across the tensor parallel ranks and fuses them into a single weight.
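The order of operations matters: each individual weight is sharded first and the local shards are then fused, so every rank ends up with a slice of every projection. The sketch covers the case `partition_dim=0` and `fuse_axis=0`, assumes each weight divides evenly across ranks, and uses lists of rows; `local_fused_weight` is a hypothetical name.

```python
def local_fused_weight(tp_rank, tp_size, individual_weights):
    """Shard each weight along rows, then stack this rank's shards."""
    local = []
    for w in individual_weights:
        per_rank = len(w) // tp_size  # assumes exact divisibility
        local.extend(w[tp_rank * per_rank:(tp_rank + 1) * per_rank])
    return local


q = [["q0"], ["q1"], ["q2"], ["q3"]]
k = [["k0"], ["k1"]]
v = [["v0"], ["v1"]]
print(local_fused_weight(1, 2, [q, k, v]))
```

Sharding the already fused matrix instead would hand rank 0 only query rows and rank 1 only key/value rows.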
Model-level Functions
optimum.neuron.models.training.specialize_transformation_specs_for_model(model: Module)
optimum.neuron.models.training.adapt_peft_config_for_model(model: Module, peft_config: peft.config.PeftConfig | dict[str, peft.config.PeftConfig], inplace: bool = False)
optimum.neuron.models.training.to_original_peft_config_for_model(model: Module, peft_config: PeftConfig, inplace: bool = False)
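One plausible effect of adapting a PEFT config is renaming `target_modules` so that adapters attach to the fused module, and restoring reverses the renaming. The sketch uses a plain dict in place of a `PeftConfig` object, a hypothetical `FUSED` mapping, and hypothetical helpers `adapt_targets`/`restore_targets`; the `inplace` flag follows the documented parameter, copying the config unless it is `True`.

```python
import copy

# Hypothetical mapping from a fused module to the original modules it replaces.
FUSED = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}


def adapt_targets(peft_config, inplace=False):
    """Replace original target module names with their fused counterpart."""
    config = peft_config if inplace else copy.deepcopy(peft_config)
    targets = []
    for name in config["target_modules"]:
        fused = next((f for f, parts in FUSED.items() if name in parts), name)
        if fused not in targets:
            targets.append(fused)
    config["target_modules"] = targets
    return config


def restore_targets(peft_config, inplace=False):
    """Expand fused target module names back to the original ones."""
    config = peft_config if inplace else copy.deepcopy(peft_config)
    targets = []
    for name in config["target_modules"]:
        targets.extend(FUSED.get(name, [name]))
    config["target_modules"] = targets
    return config


cfg = {"r": 8, "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"]}
adapted = adapt_targets(cfg)
print(adapted["target_modules"], restore_targets(adapted)["target_modules"])
```

This sketch cannot tell whether only a subset of `q_proj`/`k_proj`/`v_proj` was originally targeted; restoring always expands to all three.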
State Dict Functions
optimum.neuron.models.training.adapt_state_dict(model: Module, state_dict: dict[str, torch.Tensor], upstanding_sharded_params: dict[str, torch.Tensor], inplace: bool = False, **peft_kwargs: Any)
Transforms the state dict from the original Transformers model to match the custom modeling implementation.
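At the state-dict level, a fusion spec typically removes the original keys and writes a single fused key. The sketch shows this for hypothetical `q_proj`/`k_proj`/`v_proj` keys fused into a `qkv_proj` key under each given module prefix, with lists of rows standing in for tensors. The key names and the helper `fuse_qkv_in_state_dict` are assumptions; only the `inplace` semantics mirror the documented parameter.

```python
def fuse_qkv_in_state_dict(state_dict, module_prefixes, inplace=False):
    """Replace q/k/v weights under each prefix with one row-stacked qkv weight."""
    sd = state_dict if inplace else dict(state_dict)  # shallow copy unless inplace
    for prefix in module_prefixes:
        parts = [sd.pop(f"{prefix}.{n}.weight") for n in ("q_proj", "k_proj", "v_proj")]
        sd[f"{prefix}.qkv_proj.weight"] = [row for w in parts for row in w]
    return sd


original = {
    "attn.q_proj.weight": [[1]],
    "attn.k_proj.weight": [[2]],
    "attn.v_proj.weight": [[3]],
    "attn.o_proj.weight": [[4]],
}
adapted_sd = fuse_qkv_in_state_dict(original, ["attn"])
print(sorted(adapted_sd))
```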
optimum.neuron.models.training.to_original_weights(transformations_specs: list[optimum.neuron.models.training.transformations_utils.ModelWeightTransformationSpecs], sharded_state_dicts: dict[str, list[torch.Tensor]], parameters_metadata: dict[str, dict[str, typing.Any]], **peft_kwargs: Any)
Consolidates the sharded state dicts produced by saving the custom model into a single state dict that matches the original Transformers model weights.
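For parameters that no spec transforms, consolidation reduces to concatenating the per-rank shards along the partition dimension recorded in the metadata. The sketch assumes metadata entries with hypothetical keys `tensor_model_parallel` and `partition_dim`, and lists of rows for tensors; `consolidate` is a hypothetical helper that ignores transformation specs and PEFT.

```python
def consolidate(sharded_state_dicts, parameters_metadata):
    """Merge per-rank shards into full weights using each parameter's metadata."""
    full = {}
    for name, shards in sharded_state_dicts.items():
        meta = parameters_metadata.get(name, {})
        if not meta.get("tensor_model_parallel", False):
            full[name] = shards[0]  # replicated: every rank holds the same copy
        elif meta["partition_dim"] == 0:
            full[name] = [row for shard in shards for row in shard]
        else:  # partition_dim == 1: join each row's column slices
            full[name] = [[x for shard in shards for x in shard[i]] for i in range(len(shards[0]))]
    return full


sharded = {
    "up.weight": [[[1], [2]], [[3], [4]]],   # split along rows
    "down.weight": [[[1, 2]], [[3, 4]]],     # split along columns
    "norm.weight": [[[9]], [[9]]],           # replicated
}
metadata = {
    "up.weight": {"tensor_model_parallel": True, "partition_dim": 0},
    "down.weight": {"tensor_model_parallel": True, "partition_dim": 1},
}
print(consolidate(sharded, metadata))
```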
Metadata Functions
Creates the metadata to be saved with the model weights to be able to reconstruct the original weights when consolidating the sharded state dicts.
optimum.neuron.models.training.transformations_utils.get_tensor_model_parallel_attributes(tensor: Tensor)
Returns the tensor model parallel attributes of a tensor.
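Tensor-parallel libraries commonly record sharding information as attributes set directly on the parameter. The sketch assumes Megatron-style attribute names (`tensor_model_parallel`, `partition_dim`, `partition_stride`) and the defaults shown; both are assumptions, and `SimpleNamespace` stands in for a parameter.

```python
from types import SimpleNamespace


def tensor_model_parallel_attributes(tensor):
    """Read sharding attributes from a parameter, with assumed defaults."""
    return {
        "tensor_model_parallel": getattr(tensor, "tensor_model_parallel", False),
        "partition_dim": getattr(tensor, "partition_dim", -1),
        "partition_stride": getattr(tensor, "partition_stride", 1),
    }


sharded_param = SimpleNamespace(tensor_model_parallel=True, partition_dim=0, partition_stride=1)
print(tensor_model_parallel_attributes(sharded_param))
```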
Helper Functions
optimum.neuron.models.training.transformations_utils.get_adapter_name(parameter_fully_qualified_name: str)
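PEFT LoRA layers store their weights in containers keyed by adapter name, so parameter names look like `...q_proj.lora_A.default.weight`. The sketch assumes the adapter name is the path component right after `lora_A` or `lora_B` and returns `None` otherwise; the helper name and that lookup rule are assumptions, not the documented behavior of `get_adapter_name`.

```python
LORA_MARKERS = ("lora_A", "lora_B")


def adapter_name_from_parameter(parameter_fully_qualified_name):
    """Return the component following a LoRA marker, or None if there is none."""
    parts = parameter_fully_qualified_name.split(".")
    for i, part in enumerate(parts[:-1]):
        if part in LORA_MARKERS:
            return parts[i + 1]
    return None


print(adapter_name_from_parameter("model.layers.0.self_attn.q_proj.lora_A.default.weight"))
```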