zhoukz committed on
Commit
5e3b785
·
verified ·
1 Parent(s): 4c11a4f

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_midashenglm.py +15 -15
modeling_midashenglm.py CHANGED
@@ -394,12 +394,10 @@ class Qwen25OmniTextModelOutput(ModelOutput):
394
 
395
  class Qwen25OmniThinkerTextOnlyDecoder(PreTrainedModel, GenerationMixin):
396
  config_class = Qwen2_5OmniTextConfig
397
- _supports_flash_attn_2 = Qwen2_5OmniThinkerTextModel._supports_flash_attn_2
398
- _supports_sdpa = Qwen2_5OmniThinkerTextModel._supports_sdpa
399
- _supports_flex_attn = Qwen2_5OmniThinkerTextModel._supports_flex_attn
400
- _supports_cache_class = Qwen2_5OmniThinkerTextModel._supports_cache_class
401
- _supports_static_cache = Qwen2_5OmniThinkerTextModel._supports_static_cache
402
- _supports_quantized_cache = Qwen2_5OmniThinkerTextModel._supports_quantized_cache
403
 
404
  def __init__(self, config: Qwen2_5OmniTextConfig):
405
  super().__init__(config)
@@ -471,15 +469,11 @@ class Qwen25OmniThinkerTextOnlyDecoder(PreTrainedModel, GenerationMixin):
471
 
472
  class MiDashengLMModel(PreTrainedModel):
473
  config_class = MiDashengLMConfig
474
- _supports_flash_attn_2 = Qwen2_5OmniThinkerTextModel._supports_flash_attn_2
475
- _supports_sdpa = Qwen2_5OmniThinkerTextModel._supports_sdpa
476
- _supports_flex_attn = Qwen2_5OmniThinkerTextModel._supports_flex_attn
477
- _supports_cache_class = Qwen2_5OmniThinkerTextModel._supports_cache_class
478
- _supports_static_cache = Qwen2_5OmniThinkerTextModel._supports_static_cache
479
- _supports_quantized_cache = Qwen2_5OmniThinkerTextModel._supports_quantized_cache
480
- supports_gradient_checkpointing = (
481
- Qwen2_5OmniThinkerTextModel.supports_gradient_checkpointing
482
- )
483
 
484
  def __init__(self, config: MiDashengLMConfig):
485
  super().__init__(config)
@@ -501,6 +495,12 @@ class MiDashengLMModel(PreTrainedModel):
501
 
502
  self.post_init()
503
 
 
 
 
 
 
 
504
  def _forward_audio_encoder(
505
  self,
506
  audios: torch.Tensor,
 
394
 
395
  class Qwen25OmniThinkerTextOnlyDecoder(PreTrainedModel, GenerationMixin):
396
  config_class = Qwen2_5OmniTextConfig
397
+ _supports_flash_attn_2 = True
398
+ _supports_sdpa = True
399
+ _supports_cache_class = True
400
+ _supports_static_cache = True
 
 
401
 
402
  def __init__(self, config: Qwen2_5OmniTextConfig):
403
  super().__init__(config)
 
469
 
470
  class MiDashengLMModel(PreTrainedModel):
471
  config_class = MiDashengLMConfig
472
+ _supports_flash_attn_2 = True
473
+ _supports_sdpa = True
474
+ _supports_cache_class = True
475
+ _supports_static_cache = True
476
+ supports_gradient_checkpointing = True
 
 
 
 
477
 
478
  def __init__(self, config: MiDashengLMConfig):
479
  super().__init__(config)
 
495
 
496
  self.post_init()
497
 
498
+ def get_input_embeddings(self):
499
+ return self.decoder.model.embed_tokens
500
+
501
+ def get_output_embeddings(self):
502
+ return self.decoder.lm_head
503
+
504
  def _forward_audio_encoder(
505
  self,
506
  audios: torch.Tensor,