from typing import Dict, Optional, Tuple, Union

from transformers import PretrainedConfig
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
    Qwen2_5OmniTextConfig,
)


class DashengConfig(PretrainedConfig):
    model_type = "midashenglm_dasheng_encoder"

    def __init__(
        self,
        embed_dim: int = 768,
        outputdim: int = 527,
        patch_size: Union[int, Tuple[int, int]] = 16,
        patch_stride: Union[int, Tuple[int, int]] = 16,
        input_channels: int = 1,
        target_length: int = 1012,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        init_values: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        f_min: float = 0.0,
        f_max: float = 8000.0,
        center: bool = True,
        win_length: int = 512,
        hop_length: int = 160,
        sample_rate: int = 16000,
        n_fft: int = 512,
        n_mels: int = 64,
        **kwargs,
    ):
        # Transformer encoder hyperparameters (ViT-style patch encoder).
        self.embed_dim = embed_dim
        self.outputdim = outputdim
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.input_channels = input_channels
        self.target_length = target_length
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.init_values = init_values
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        # Mel-spectrogram front-end parameters.
        self.f_min = f_min
        self.f_max = f_max
        self.center = center
        self.win_length = win_length
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.n_mels = n_mels
        super().__init__(**kwargs)


class MiDashengLMConfig(PretrainedConfig):
    model_type = "midashenglm"

    def __init__(
        self,
        audio_encoder_config: Optional[Dict] = None,
        subsample_factor: int = 5,
        text_config: Optional[Dict] = None,
        audio_token_id: Optional[int] = None,
        **kwargs,
    ):
        # Use None defaults instead of mutable dict defaults, then fall
        # back to empty overrides when building the nested sub-configs.
        self.audio_encoder_config = DashengConfig(**(audio_encoder_config or {}))
        self.subsample_factor = subsample_factor
        self.text_config = (
            Qwen2_5OmniTextConfig(**text_config)
            if text_config
            else Qwen2_5OmniTextConfig()
        )
        self.audio_token_id = audio_token_id
        super().__init__(**kwargs)
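

# Usage sketch (illustrative only): how the composite config assembles its
# nested sub-configs. The override values below are hypothetical, chosen
# purely for demonstration; in practice, audio_token_id should be set to
# the tokenizer's actual audio placeholder token id.
if __name__ == "__main__":
    config = MiDashengLMConfig(
        audio_encoder_config={"embed_dim": 768, "depth": 12, "num_heads": 12},
        subsample_factor=5,
        audio_token_id=None,
    )
    # Dict overrides are promoted to typed sub-config objects.
    print(config.audio_encoder_config.model_type)  # midashenglm_dasheng_encoder
    print(type(config.text_config).__name__)       # Qwen2_5OmniTextConfig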