zhoukz committed
Commit 896c62c · 1 Parent(s): 05a65a3

Upload folder using huggingface_hub

config.json CHANGED
@@ -67,7 +67,7 @@
     "torch_dtype": "bfloat16",
     "use_cache": true,
     "use_sliding_window": false,
-    "vocab_size": 152064
+    "vocab_size": 151936
   },
   "torch_dtype": "float32",
   "transformers_version": "4.52.4"
configuration_midashenglm.py CHANGED
@@ -1,5 +1,4 @@
-from ast import Dict
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 from transformers import PretrainedConfig
 from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
@@ -66,7 +65,7 @@ class MiDashengLMConfig(PretrainedConfig):
         self,
         audio_encoder_config: Dict = {},
         subsample_factor: int = 5,
-        text_config: Dict = None,
+        text_config: Dict = {},
         **kwargs,
     ):
         self.audio_encoder_config = DashengConfig(**audio_encoder_config)
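
With `text_config` now defaulting to an empty dict rather than `None` (and `Dict` imported from `typing` instead of `ast`), the config can be built without arguments. A short sketch, assuming the module is importable from a local checkout and that both sub-config dicts are unpacked into config classes with usable defaults; the override keys shown are illustrative:

    from configuration_midashenglm import MiDashengLMConfig

    # No None-check is needed before the default dicts are unpacked.
    cfg = MiDashengLMConfig()

    # Overrides are still passed as plain dicts.
    cfg = MiDashengLMConfig(
        audio_encoder_config={"embed_dim": 768},
        text_config={"vocab_size": 151936},
    )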
model.safetensors.index.json CHANGED
@@ -1,6 +1,6 @@
 {
   "metadata": {
-    "total_size": 9385880844
+    "total_size": 9384832268
   },
   "weight_map": {
     "audio_encoder.blocks.0.attn.proj.bias": "model-00001-of-00002.safetensors",
modeling_midashenglm.py CHANGED
@@ -1,13 +1,14 @@
+import collections
 import collections.abc
 from dataclasses import dataclass
-from functools import partial
-from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
 import torchaudio.transforms as audio_transforms
 from torch import Tensor
 from transformers import GenerationMixin, PreTrainedModel
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
     Qwen2_5OmniTextConfig,
@@ -18,28 +19,33 @@ from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
 
 from .configuration_midashenglm import DashengConfig, MiDashengLMConfig
 
+_Tuple2 = Union[int, Tuple[int, int], Sequence[int]]
 
-def to_2tuple(x: Any) -> Tuple[Any, Any]:
-    if isinstance(x, collections.abc.Iterable):
-        return x
+
+def _resolve_tuple2(x: _Tuple2) -> Tuple[int, int]:
+    if isinstance(x, collections.abc.Sequence):
+        assert len(x) == 2, (
+            f"Expected a sequence of length 2, got {x} with length {len(x)}"
+        )
+        return cast(Tuple[int, int], tuple(x))
     return (x, x)
 
 
 class AudioPatchEmbed(nn.Module):
     def __init__(
         self,
-        input_size: Union[int, Tuple[int, int]] = 64,
-        patch_size: Union[int, Tuple[int, int]] = 16,
-        patch_stride: Union[int, Tuple[int, int]] = 16,
+        input_size: _Tuple2 = 64,
+        patch_size: _Tuple2 = 16,
+        patch_stride: _Tuple2 = 16,
         in_chans: int = 1,
         embed_dim: int = 768,
         norm_layer: Optional[Callable] = None,
         flatten: bool = False,
     ):
         super().__init__()
-        self.input_size = to_2tuple(input_size)
-        self.patch_size = to_2tuple(patch_size)
-        self.patch_stride = to_2tuple(patch_stride)
+        self.input_size = _resolve_tuple2(input_size)
+        self.patch_size = _resolve_tuple2(patch_size)
+        self.patch_stride = _resolve_tuple2(patch_stride)
         self.grid_size = (
             self.input_size[0] // self.patch_stride[0],
             self.input_size[1] // self.patch_stride[1],
@@ -48,7 +54,10 @@ class AudioPatchEmbed(nn.Module):
         self.flatten = flatten
 
         self.proj = nn.Conv2d(
-            in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride
+            in_chans,
+            embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_stride,
         )
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
 
@@ -78,14 +87,13 @@ class DashengMlp(nn.Module):
         in_features: int,
         hidden_features: Optional[int] = None,
         out_features: Optional[int] = None,
-        act_layer: Type[nn.Module] = nn.GELU,
         drop: float = 0.0,
     ):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         self.fc1 = nn.Linear(in_features, hidden_features)
-        self.act = act_layer()
+        self.act = nn.GELU()
         self.fc2 = nn.Linear(hidden_features, out_features)
         self.drop = nn.Dropout(drop)
 
@@ -173,13 +181,10 @@ class DashengBlock(nn.Module):
         drop: float = 0.0,
         attn_drop: float = 0.0,
         init_values: Optional[float] = None,
-        act_layer: Type[nn.Module] = nn.GELU,
-        norm_layer: Type[nn.Module] = nn.LayerNorm,
-        attention_type: Type[nn.Module] = DashengAttention,
     ):
         super().__init__()
-        self.norm1 = norm_layer(dim)
-        self.attn = attention_type(
+        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
+        self.attn = DashengAttention(
             dim,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
@@ -190,11 +195,10 @@ class DashengBlock(nn.Module):
             LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         )
 
-        self.norm2 = norm_layer(dim)
+        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
         self.mlp = DashengMlp(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
-            act_layer=act_layer,
             drop=drop,
         )
         self.ls2 = (
@@ -250,7 +254,6 @@ class DashengAudioTransformer(PreTrainedModel):
             torch.randn(1, config.embed_dim, self.patch_embed.grid_size[0], 1) * 0.02
         )
 
-        norm_layer = partial(nn.LayerNorm, eps=1e-6)
         self.pos_drop = nn.Dropout(p=config.drop_rate)
         self.blocks = nn.ModuleList(
             DashengBlock(
@@ -261,11 +264,10 @@ class DashengAudioTransformer(PreTrainedModel):
                 init_values=config.init_values,
                 drop=config.drop_rate,
                 attn_drop=config.attn_drop_rate,
-                norm_layer=norm_layer,
             )
             for i in range(config.depth)
         )
-        self.norm = norm_layer(config.embed_dim)
+        self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)
 
         self.post_init()
 
@@ -295,7 +297,7 @@ class DashengAudioTransformer(PreTrainedModel):
         self,
         x: torch.Tensor,
         x_length: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         x = self.front_end(x)
         target_length_in_patches = self.target_length // 4
         x = x.unsqueeze(1)
@@ -363,10 +365,10 @@ class AudioProjectorSubsample(nn.Module):
 
 @dataclass
 class Qwen25OmniTextModelOutput(ModelOutput):
-    logits: torch.FloatTensor = None
-    past_key_values: Optional[List[torch.FloatTensor]] = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    logits: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Cache] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
 
 
 class Qwen25OmniThinkerTextOnlyDecoder(PreTrainedModel, GenerationMixin):
@@ -390,10 +392,22 @@ class Qwen25OmniThinkerTextOnlyDecoder(PreTrainedModel, GenerationMixin):
 
     def forward(
         self,
+        attention_mask: Optional[Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
         **kwargs: Any,
-    ) -> Qwen25OmniTextModelOutput:
+    ) -> Union[Tuple, Qwen25OmniTextModelOutput]:
+        if attention_mask is not None and position_ids is None:
+            position_ids = (
+                attention_mask.long()
+                .cumsum(dim=-1)
+                .masked_fill_(attention_mask == 0, 1)
+                - 1
+            )
+
         outputs: BaseModelOutputWithPast = self.model(
+            attention_mask=attention_mask,
+            position_ids=position_ids,
             return_dict=True,
             **kwargs,
         )
@@ -463,23 +477,26 @@ class MiDashengLMModel(PreTrainedModel):
     def _prepare_with_input_ids(
         self,
         input_ids: torch.Tensor,
-        audio_embeddings: torch.Tensor,
-        audio_token_id: int,
+        audio_embeddings: Optional[torch.Tensor],
+        audio_token_id: Optional[int],
     ) -> torch.Tensor:
-        special_mask = input_ids == audio_token_id
-        assert audio_embeddings.shape[1] <= (special_mask.sum(-1)).max(), (
-            "Mask and audio embeddings seem to have different sizes: "
-            f"{audio_embeddings.shape=}, {special_mask=}, {input_ids=}, "
-            f"{audio_embeddings.shape[1]=} vs {(special_mask.sum(-1)).max()=}"
-        )
         input_embeddings = self.decoder.model.embed_tokens(input_ids)
-        audio_embeddings = audio_embeddings.to(input_embeddings.dtype)
+        if audio_embeddings is not None:
+            special_mask = input_ids == audio_token_id
+            assert audio_embeddings.shape[1] <= (special_mask.sum(-1)).max(), (
+                "Mask and audio embeddings seem to have different sizes: "
+                f"{audio_embeddings.shape=}, {special_mask=}, {input_ids=}, "
+                f"{audio_embeddings.shape[1]=} vs {(special_mask.sum(-1)).max()=}"
+            )
+            audio_embeddings = audio_embeddings.to(input_embeddings.dtype)
 
-        for i in range(len(special_mask)):
-            mask = special_mask[i]
-            number_of_tokens = mask.sum(-1)
-            input_embeddings[i, mask] = audio_embeddings[i, :number_of_tokens]
-        return input_embeddings
+            for i in range(len(special_mask)):
+                mask = special_mask[i]
+                number_of_tokens = mask.sum(-1)
+                input_embeddings[i, mask] = audio_embeddings[i, :number_of_tokens]
+            return input_embeddings
+        else:
+            return input_embeddings
 
     def forward(
         self,
@@ -487,7 +504,6 @@ class MiDashengLMModel(PreTrainedModel):
         input_values: Optional[Tensor] = None,
         inputs_embeds: Optional[Tensor] = None,
         audio_length: Optional[Iterable[int]] = None,
-        attention_mask: Optional[Tensor] = None,
         audio_token_id: Optional[int] = None,
         **kwargs: Any,
     ):
@@ -498,6 +514,11 @@ class MiDashengLMModel(PreTrainedModel):
         )
 
         if input_values is not None:
+            if audio_token_id is None:
+                raise ValueError(
+                    "If `input_values` is provided, `audio_token_id` must also be provided."
+                )
+
            input_values = input_values.to(self.device)
            audio_encoder_hidden_states = self._forward_audio_encoder(
                input_values, audio_length=audio_length
@@ -530,7 +551,6 @@ class MiDashengLMModel(PreTrainedModel):
         return self.decoder(
             input_ids=None,
             inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
             **kwargs,
         )
 
@@ -548,6 +568,7 @@ class MiDashengLMModel(PreTrainedModel):
             raise ValueError(
                 "Both `inputs_embeds` and `input_ids` are passed. Please pass only one of them."
             )
+        input_ids = input_ids.to(self.device)
 
         if input_values is not None:
             input_values = input_values.to(self.device)
@@ -555,15 +576,7 @@ class MiDashengLMModel(PreTrainedModel):
                 input_values, audio_length=audio_length
             )
         else:
-            batch, _ = input_ids.shape
-            input_values = torch.zeros(
-                batch,
-                0,
-                self.audio_encoder.embed_dim,
-                device=input_ids.device,
-            )
-
-        input_ids = input_ids.to(self.device)
+            audio_encoder_hidden_states = None
         inputs_embeds = self._prepare_with_input_ids(
             input_ids=input_ids,
             audio_embeddings=audio_encoder_hidden_states,
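
The decoder's forward now derives position_ids from a left-padded attention_mask when none are supplied. The toy example below replays that exact expression on a small mask to show the resulting indices:

    import torch

    # Left-padded batch: 0 marks padding, 1 marks real tokens.
    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])

    position_ids = (
        attention_mask.long()
        .cumsum(dim=-1)
        .masked_fill_(attention_mask == 0, 1)
        - 1
    )

    print(position_ids)
    # tensor([[0, 0, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])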
preprocessor_config.json CHANGED
@@ -1,13 +1,13 @@
 {
   "auto_map": {
-    "AutoProcessor": "processing_midashenglm.MiAudioLLMProcessor"
+    "AutoProcessor": "processing_midashenglm.MiDashengLMProcessor"
   },
   "do_normalize": false,
   "feature_extractor_type": "Wav2Vec2FeatureExtractor",
   "feature_size": 1,
   "padding_side": "right",
   "padding_value": 0.0,
-  "processor_class": "MiAudioLLMProcessor",
+  "processor_class": "MiDashengLMProcessor",
   "return_attention_mask": false,
   "sampling_rate": 16000
 }
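
With the auto_map entries repointed, remote-code loading resolves to the renamed processor class. A minimal sketch, again assuming a local checkout of the repository:

    from transformers import AutoProcessor

    # auto_map routes AutoProcessor to processing_midashenglm.MiDashengLMProcessor.
    processor = AutoProcessor.from_pretrained(".", trust_remote_code=True)
    print(type(processor).__name__)  # MiDashengLMProcessor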
processing_midashenglm.py CHANGED
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union, cast
 
 import numpy as np
 import torch
@@ -7,8 +7,8 @@ from transformers.feature_extraction_utils import BatchFeature
 from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
 
 
-class MiAudioLLMProcessorKwargs(ProcessingKwargs):
-    _defaults = {
+class MiDashengLMProcessorKwargs(ProcessingKwargs):
+    _defaults = {  # type: ignore
         "text_kwargs": {
             "padding": True,
             "padding_side": "left",
@@ -36,7 +36,7 @@ def calculate_mel_frames_dasheng(
     )
 
 
-class MiAudioLLMProcessor(ProcessorMixin):
+class MiDashengLMProcessor(ProcessorMixin):
     attributes = ["feature_extractor", "tokenizer"]
     valid_kwargs = [
         "chat_template",
@@ -49,15 +49,14 @@ class MiAudioLLMProcessor(ProcessorMixin):
 
     def __init__(
         self,
-        feature_extractor: Optional[Wav2Vec2FeatureExtractor] = None,
-        tokenizer: Optional[Union[Qwen2Tokenizer, Qwen2TokenizerFast]] = None,
+        feature_extractor: Wav2Vec2FeatureExtractor,
+        tokenizer: Union[Qwen2Tokenizer, Qwen2TokenizerFast],
         model_subsampling: int = 5,
-        chat_template: Optional[str] = None,
+        chat_template: Optional[Union[str, Dict[str, str]]] = None,
         audio_token: Optional[str] = None,
         audio_bos_token: Optional[str] = None,
         audio_eos_token: Optional[str] = None,
     ):
-        assert tokenizer is not None, "Tokenizer Needs to be passed"
         assert audio_token is not None or hasattr(tokenizer, "audio_token"), (
             "Either `audio_token` must be provided or tokenizer must have `audio_token` attribute."
         )
@@ -67,22 +66,62 @@ class MiAudioLLMProcessor(ProcessorMixin):
         assert audio_eos_token is not None or hasattr(tokenizer, "audio_eos_token"), (
             "Either `audio_eos_token` must be provided or tokenizer must have `audio_eos_token` attribute."
         )
+        assert not feature_extractor.do_normalize, (
+            "This model does not use normalization. Please set `do_normalize=False` in the feature extractor."
+        )
 
         if chat_template is None:
             chat_template = tokenizer.chat_template
 
-        self.audio_token: str = audio_token or tokenizer.audio_token
-        self.audio_bos_token = audio_bos_token or tokenizer.audio_bos_token
-        self.audio_eos_token = audio_eos_token or tokenizer.audio_eos_token
-        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
-        self.model_subsampling = model_subsampling
+        def get_token(token_name: str) -> str:
+            if not hasattr(tokenizer, token_name):
+                raise ValueError(
+                    f"Tokenizer does not have attribute `{token_name}`. "
+                    "Please provide it as an argument to the processor."
+                )
+            token = getattr(tokenizer, token_name)
+            if not isinstance(token, str):
+                raise TypeError(
+                    f"Expected token {token_name} to be a string, but got {type(token)}."
+                )
+            return token
 
-        if feature_extractor is not None:
-            assert not feature_extractor.do_normalize, (
-                "This model does not use normalization. Please set `do_normalize=False` in the feature extractor."
-            )
+        self.audio_token = audio_token or get_token("audio_token")
+        self.audio_bos_token = audio_bos_token or get_token("audio_bos_token")
+        self.audio_eos_token = audio_eos_token or get_token("audio_eos_token")
+
+        self.audio_token_id = cast(
+            int, tokenizer.convert_tokens_to_ids(self.audio_token)
+        )
+        self.model_subsampling = model_subsampling
+        self.sampling_rate = feature_extractor.sampling_rate
 
         super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
+        self.feature_extractor: Wav2Vec2FeatureExtractor
+        self.tokenizer: Union[Qwen2Tokenizer, Qwen2TokenizerFast]
+        self.chat_template: Optional[Union[str, Dict[str, str]]]
+
+    def _process_messages_for_chat_template(
+        self,
+        conversation,
+        batch_images,
+        batch_videos,
+        batch_video_metadata,
+        **mm_load_kwargs,
+    ):
+        if (sr := mm_load_kwargs.get("sampling_rate", None)) is not None:
+            if sr != self.sampling_rate:
+                raise ValueError(
+                    f"This model is trained with a sampling rate of {self.sampling_rate}, "
+                    f"but the sampling rate {sr} is used to load audio."
+                )
+        return super()._process_messages_for_chat_template(
+            conversation,
+            batch_images,
+            batch_videos,
+            batch_video_metadata,
+            **mm_load_kwargs,
+        )
 
     @classmethod
     def _validate_audio_sample(
@@ -117,7 +156,7 @@ class MiAudioLLMProcessor(ProcessorMixin):
         self,
         text: Optional[List[str]] = None,
         audio: Optional[Union[List[np.ndarray], List[torch.Tensor]]] = None,
-        **kwargs: Unpack[MiAudioLLMProcessorKwargs],
+        **kwargs: Unpack[MiDashengLMProcessorKwargs],
     ) -> BatchFeature:
         if text is None:
             raise ValueError("You need to specify `text` input to process.")
@@ -135,7 +174,7 @@ class MiAudioLLMProcessor(ProcessorMixin):
             raise ValueError("This model does not support images or videos.")
 
         output_kwargs = self._merge_kwargs(
-            MiAudioLLMProcessorKwargs,
+            MiDashengLMProcessorKwargs,  # type: ignore # Bad type hint in transformers
             tokenizer_init_kwargs=self.tokenizer.init_kwargs,
             **kwargs,
         )
@@ -157,7 +196,9 @@ class MiAudioLLMProcessor(ProcessorMixin):
 
         # + Padding
         audio_inputs = self.feature_extractor(
-            audio, **output_kwargs["audio_kwargs"]
+            audio,
+            sampling_rate=self.sampling_rate,
+            **output_kwargs["audio_kwargs"],
         )
 
         # remove attention mask, dasheng uses lengths
@@ -216,28 +257,17 @@ class MiAudioLLMProcessor(ProcessorMixin):
 
         return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", "pt")
         inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
-        if hasattr(self, "_check_special_mm_tokens"):
-            self._check_special_mm_tokens(text, inputs, modalities=["audio"])
+        self._check_special_mm_tokens(
+            text,
+            BatchFeature(inputs),  # type: ignore
+            modalities=["audio"],
+        )
 
         if audio is not None:
             inputs.update(audio_inputs)
 
         return BatchFeature(data={**inputs}, tensor_type=return_tensors)
 
-    def batch_decode(self, *args, **kwargs):
-        """
-        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
-        refer to the docstring of this method for more information.
-        """
-        return self.tokenizer.batch_decode(*args, **kwargs)
-
-    def decode(self, *args, **kwargs):
-        """
-        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
-        the docstring of this method for more information.
-        """
-        return self.tokenizer.decode(*args, **kwargs)
-
     @property
     def model_input_names(self):
         tokenizer_input_names = self.tokenizer.model_input_names
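
A usage sketch for the renamed processor, assuming 16 kHz mono audio and one <|AUDIO|> placeholder per clip; the prompt string is illustrative and not taken from the repository's chat template:

    import numpy as np

    # `processor` as loaded above; one second of silence at the expected 16 kHz.
    audio = [np.zeros(16000, dtype=np.float32)]
    text = ["<|audio_bos|><|AUDIO|><|audio_eos|>Describe the audio."]

    # The feature extractor must keep do_normalize=False (asserted in __init__),
    # and it is now called with sampling_rate=processor.sampling_rate (16000).
    batch = processor(text=text, audio=audio)
    print(batch.keys())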
processor_config.json CHANGED
@@ -3,8 +3,8 @@
   "audio_eos_token": "<|audio_eos|>",
   "audio_token": "<|AUDIO|>",
   "auto_map": {
-    "AutoProcessor": "processing_midashenglm.MiAudioLLMProcessor"
+    "AutoProcessor": "processing_midashenglm.MiDashengLMProcessor"
   },
   "model_subsampling": 5,
-  "processor_class": "MiAudioLLMProcessor"
+  "processor_class": "MiDashengLMProcessor"
 }
tokenizer_config.json CHANGED
@@ -337,7 +337,7 @@
   "audio_eos_token": "<|audio_eos|>",
   "audio_token": "<|AUDIO|>",
   "auto_map": {
-    "AutoProcessor": "processing_midashenglm.MiAudioLLMProcessor"
+    "AutoProcessor": "processing_midashenglm.MiDashengLMProcessor"
   },
   "bos_token": null,
   "clean_up_tokenization_spaces": false,
@@ -355,7 +355,7 @@
   "image_token": "<|IMAGE|>",
   "model_max_length": 32768,
   "pad_token": "<|endoftext|>",
-  "processor_class": "MiAudioLLMProcessor",
+  "processor_class": "MiDashengLMProcessor",
   "split_special_tokens": false,
   "tokenizer_class": "Qwen2Tokenizer",
   "unk_token": null,