"""Configuration management module for the Dia model. | |
This module provides comprehensive configuration management for the Dia model, | |
utilizing Pydantic for validation. It defines configurations for data processing, | |
model architecture (encoder and decoder), and training settings. | |
Key components: | |
- DataConfig: Parameters for data loading and preprocessing. | |
- EncoderConfig: Architecture details for the encoder module. | |
- DecoderConfig: Architecture details for the decoder module. | |
- ModelConfig: Combined model architecture settings. | |
- TrainingConfig: Training hyperparameters and settings. | |
- DiaConfig: Master configuration combining all components. | |
""" | |
import os

from pydantic import BaseModel, Field

class EncoderConfig(BaseModel, frozen=True):
    """Configuration for the encoder component of the Dia model.

    Attributes:
        model_type: Type of the model, defaults to "dia_encoder".
        hidden_size: Size of the encoder layers, defaults to 1024.
        intermediate_size: Size of the "intermediate" (i.e., feed-forward) layer in the encoder, defaults to 4096.
        num_hidden_layers: Number of hidden layers in the encoder, defaults to 12.
        num_attention_heads: Number of attention heads in the encoder, defaults to 16.
        num_key_value_heads: Number of key-value heads in the encoder, defaults to 16.
        head_dim: Dimension of each attention head, defaults to 128.
        hidden_act: Activation function in the encoder, defaults to "silu".
        max_position_embeddings: Maximum number of position embeddings, defaults to 1024.
        initializer_range: Range for initializing weights, defaults to 0.02.
        norm_eps: Epsilon value for normalization layers, defaults to 1e-5.
        rope_theta: Theta value for RoPE, defaults to 10000.0.
        rope_scaling: Optional scaling factor for RoPE.
        vocab_size: Vocabulary size, defaults to 256.
    """

    head_dim: int = Field(default=128, gt=0)
    hidden_act: str = Field(default="silu")
    hidden_size: int = Field(default=1024, gt=0)
    initializer_range: float = Field(default=0.02)
    intermediate_size: int = Field(default=4096, gt=0)
    max_position_embeddings: int = Field(default=1024, gt=0)
    model_type: str = Field(default="dia_encoder")
    norm_eps: float = Field(default=1e-5)
    num_attention_heads: int = Field(default=16, gt=0)
    num_hidden_layers: int = Field(default=12, gt=0)
    num_key_value_heads: int = Field(default=16, gt=0)
    rope_scaling: float | None = Field(default=None)
    rope_theta: float = Field(default=10000.0)
    vocab_size: int = Field(default=256, gt=0)
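

# Illustrative sketch (not part of the original module): how the Pydantic
# constraints above behave in practice. Unset fields fall back to their
# defaults, gt=0 bounds reject non-positive sizes, and frozen=True makes
# instances immutable (mutation raises a ValidationError in Pydantic v2).
#
#     >>> enc = EncoderConfig(num_hidden_layers=24)
#     >>> enc.hidden_size                 # unset fields keep their defaults
#     1024
#     >>> EncoderConfig(hidden_size=0)    # violates gt=0
#     Traceback (most recent call last):
#         ...
#     pydantic_core._pydantic_core.ValidationError: ...
#     >>> enc.num_hidden_layers = 12      # frozen=True forbids mutation
#     Traceback (most recent call last):
#         ...
#     pydantic_core._pydantic_core.ValidationError: ...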

class DecoderConfig(BaseModel, frozen=True):
    """Configuration for the decoder component of the Dia model.

    Attributes:
        model_type: Type of the model, defaults to "dia_decoder".
        hidden_size: Size of the decoder layers, defaults to 2048.
        intermediate_size: Size of the "intermediate" (i.e., feed-forward) layer in the decoder, defaults to 8192.
        num_hidden_layers: Number of hidden layers in the decoder, defaults to 18.
        num_attention_heads: Number of attention heads in the decoder, defaults to 16.
        num_key_value_heads: Number of key-value heads in the decoder, defaults to 4.
        head_dim: Dimension of each attention head, defaults to 128.
        cross_hidden_size: Size of the cross-attention layers, defaults to 1024.
        cross_num_attention_heads: Number of attention heads in the cross-attention mechanism, defaults to 16.
        cross_num_key_value_heads: Number of key-value heads in the cross-attention mechanism, defaults to 16.
        cross_head_dim: Dimension of each cross-attention head, defaults to 128.
        hidden_act: Activation function in the decoder, defaults to "silu".
        max_position_embeddings: Maximum number of position embeddings in the decoder, defaults to 3072.
        initializer_range: Range for initializing weights in the decoder, defaults to 0.02.
        norm_eps: Epsilon value for normalization layers in the decoder, defaults to 1e-5.
        rope_theta: Theta value for RoPE in the decoder, defaults to 10000.0.
        rope_scaling: Optional scaling factor for RoPE in the decoder.
        vocab_size: Vocabulary size for the decoder, defaults to 1028.
        num_channels: Number of audio channels in the decoder, defaults to 9.
    """

    cross_head_dim: int = Field(default=128, gt=0)
    cross_hidden_size: int = Field(default=1024, gt=0)
    cross_num_attention_heads: int = Field(default=16, gt=0)
    cross_num_key_value_heads: int = Field(default=16, gt=0)
    head_dim: int = Field(default=128, gt=0)
    hidden_act: str = Field(default="silu")
    hidden_size: int = Field(default=2048, gt=0)
    initializer_range: float = Field(default=0.02)
    intermediate_size: int = Field(default=8192, gt=0)
    max_position_embeddings: int = Field(default=3072, gt=0)
    model_type: str = Field(default="dia_decoder")
    norm_eps: float = Field(default=1e-5)
    num_attention_heads: int = Field(default=16, gt=0)
    num_channels: int = Field(default=9, gt=0)
    num_hidden_layers: int = Field(default=18, gt=0)
    num_key_value_heads: int = Field(default=4, gt=0)
    rope_scaling: float | None = Field(default=None)
    rope_theta: float = Field(default=10000.0)
    vocab_size: int = Field(default=1028, gt=0)
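

# Illustrative note (not part of the original module): the decoder defaults
# pair 16 query heads with only 4 key-value heads, i.e. grouped-query
# attention with 4 query heads per KV head, while the cross-attention
# defaults keep the full 16 KV heads. A quick sanity check one might run:
#
#     >>> dec = DecoderConfig()
#     >>> dec.num_attention_heads % dec.num_key_value_heads   # must divide evenly
#     0
#     >>> dec.num_attention_heads // dec.num_key_value_heads  # GQA group size
#     4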

class DiaConfig(BaseModel, frozen=True):
    """Main configuration container for the Dia model architecture.

    Attributes:
        model_type: Type of the model, defaults to "dia".
        is_encoder_decoder: Flag indicating if the model is an encoder-decoder type, defaults to True.
        encoder_config: Configuration for the encoder component.
        decoder_config: Configuration for the decoder component.
        initializer_range: Range for initializing weights, defaults to 0.02.
        norm_eps: Epsilon value for normalization layers, defaults to 1e-5.
        torch_dtype: Data type for model weights in PyTorch, defaults to "float32".
        bos_token_id: Beginning-of-sequence token ID, defaults to 1026.
        eos_token_id: End-of-sequence token ID, defaults to 1024.
        pad_token_id: Padding token ID, defaults to 1025.
        transformers_version: Version of the transformers library, defaults to "4.53.0.dev0".
        architectures: List of model architectures, defaults to ["DiaForConditionalGeneration"].
        delay_pattern: List of delay values for each audio channel, defaults to [0, 8, 9, 10, 11, 12, 13, 14, 15].
    """

    architectures: list[str] = Field(
        default_factory=lambda: ["DiaForConditionalGeneration"]
    )
    bos_token_id: int = Field(default=1026)
    decoder_config: DecoderConfig
    delay_pattern: list[int] = Field(
        default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
    )
    encoder_config: EncoderConfig
    eos_token_id: int = Field(default=1024)
    initializer_range: float = Field(default=0.02)
    is_encoder_decoder: bool = Field(default=True)
    model_type: str = Field(default="dia")
    norm_eps: float = Field(default=1e-5)
    pad_token_id: int = Field(default=1025)
    torch_dtype: str = Field(default="float32")
    transformers_version: str = Field(default="4.53.0.dev0")

    def save(self, path: str) -> None:
        """Save the current configuration instance to a JSON file.

        Ensures the parent directory exists and the file has a .json extension.

        Args:
            path: The target file path to save the configuration.

        Raises:
            ValueError: If the path is not a file with a .json extension.
        """
        if not path.endswith(".json"):
            raise ValueError(f"Config path must have a .json extension: {path}")
        # os.path.dirname returns "" for a bare filename; fall back to "."
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        config_json = self.model_dump_json(indent=2)
        with open(path, "w") as f:
            f.write(config_json)

    @classmethod
    def load(cls, path: str) -> "DiaConfig | None":
        """Load and validate a Dia configuration from a JSON file.

        Args:
            path: The path to the configuration file.

        Returns:
            A validated DiaConfig instance if the file exists and is valid,
            otherwise None if the file is not found.

        Raises:
            pydantic.ValidationError: If the JSON content fails validation
                against the DiaConfig schema.
        """
        try:
            with open(path, "r") as f:
                content = f.read()
            return cls.model_validate_json(content)
        except FileNotFoundError:
            return None
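

# Illustrative round-trip sketch (not part of the original module). Builds a
# DiaConfig from default sub-configs, saves it to JSON, and loads it back; the
# reloaded instance compares equal because Pydantic models compare by field
# values. The "dia_config.json" filename is an arbitrary choice for this demo.
if __name__ == "__main__":
    config = DiaConfig(
        encoder_config=EncoderConfig(),
        decoder_config=DecoderConfig(),
    )
    config.save("dia_config.json")
    reloaded = DiaConfig.load("dia_config.json")
    assert reloaded == config
    print(reloaded.model_type, reloaded.delay_pattern)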