Upload folder using huggingface_hub
Browse files- modeling_midashenglm.py +15 -15
modeling_midashenglm.py
CHANGED
@@ -394,12 +394,10 @@ class Qwen25OmniTextModelOutput(ModelOutput):
|
|
394 |
|
395 |
class Qwen25OmniThinkerTextOnlyDecoder(PreTrainedModel, GenerationMixin):
|
396 |
config_class = Qwen2_5OmniTextConfig
|
397 |
-
_supports_flash_attn_2 =
|
398 |
-
_supports_sdpa =
|
399 |
-
|
400 |
-
|
401 |
-
_supports_static_cache = Qwen2_5OmniThinkerTextModel._supports_static_cache
|
402 |
-
_supports_quantized_cache = Qwen2_5OmniThinkerTextModel._supports_quantized_cache
|
403 |
|
404 |
def __init__(self, config: Qwen2_5OmniTextConfig):
|
405 |
super().__init__(config)
|
@@ -471,15 +469,11 @@ class Qwen25OmniThinkerTextOnlyDecoder(PreTrainedModel, GenerationMixin):
|
|
471 |
|
472 |
class MiDashengLMModel(PreTrainedModel):
|
473 |
config_class = MiDashengLMConfig
|
474 |
-
_supports_flash_attn_2 =
|
475 |
-
_supports_sdpa =
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
_supports_quantized_cache = Qwen2_5OmniThinkerTextModel._supports_quantized_cache
|
480 |
-
supports_gradient_checkpointing = (
|
481 |
-
Qwen2_5OmniThinkerTextModel.supports_gradient_checkpointing
|
482 |
-
)
|
483 |
|
484 |
def __init__(self, config: MiDashengLMConfig):
|
485 |
super().__init__(config)
|
@@ -501,6 +495,12 @@ class MiDashengLMModel(PreTrainedModel):
|
|
501 |
|
502 |
self.post_init()
|
503 |
|
|
|
|
|
|
|
|
|
|
|
|
|
504 |
def _forward_audio_encoder(
|
505 |
self,
|
506 |
audios: torch.Tensor,
|
|
|
394 |
|
395 |
class Qwen25OmniThinkerTextOnlyDecoder(PreTrainedModel, GenerationMixin):
|
396 |
config_class = Qwen2_5OmniTextConfig
|
397 |
+
_supports_flash_attn_2 = True
|
398 |
+
_supports_sdpa = True
|
399 |
+
_supports_cache_class = True
|
400 |
+
_supports_static_cache = True
|
|
|
|
|
401 |
|
402 |
def __init__(self, config: Qwen2_5OmniTextConfig):
|
403 |
super().__init__(config)
|
|
|
469 |
|
470 |
class MiDashengLMModel(PreTrainedModel):
|
471 |
config_class = MiDashengLMConfig
|
472 |
+
_supports_flash_attn_2 = True
|
473 |
+
_supports_sdpa = True
|
474 |
+
_supports_cache_class = True
|
475 |
+
_supports_static_cache = True
|
476 |
+
supports_gradient_checkpointing = True
|
|
|
|
|
|
|
|
|
477 |
|
478 |
def __init__(self, config: MiDashengLMConfig):
|
479 |
super().__init__(config)
|
|
|
495 |
|
496 |
self.post_init()
|
497 |
|
498 |
+
def get_input_embeddings(self):
|
499 |
+
return self.decoder.model.embed_tokens
|
500 |
+
|
501 |
+
def get_output_embeddings(self):
|
502 |
+
return self.decoder.lm_head
|
503 |
+
|
504 |
def _forward_audio_encoder(
|
505 |
self,
|
506 |
audios: torch.Tensor,
|