"""
Simplified MoR (Mixture-of-Recursions) model implementation for the Hugging Face Hub.

This provides basic inference capabilities while maintaining compatibility with
the full MoR framework.
"""

import torch
from typing import Optional, Tuple, Union

from transformers import LlamaForCausalLM, LlamaConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import add_start_docstrings_to_model_forward


class MoRConfig(LlamaConfig):
    """
    Configuration class for the MoR model.

    Extends LlamaConfig with MoR-specific parameters:
        mor_enabled: whether MoR behavior is enabled (kept for checkpoint compatibility).
        num_recursions: number of recursion steps used by the full MoR framework.
        routing_strategy: token routing strategy, e.g. "expert_choice".
        kv_sharing: KV cache sharing mode; None disables sharing.
    """

    # Must match the name passed to AutoConfig.register() below; otherwise
    # registration raises a ValueError (LlamaConfig declares "llama").
    model_type = "mor_llama"

    def __init__(
        self,
        mor_enabled=True,
        num_recursions=3,
        routing_strategy="expert_choice",
        kv_sharing=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.mor_enabled = mor_enabled
        self.num_recursions = num_recursions
        self.routing_strategy = routing_strategy
        self.kv_sharing = kv_sharing
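
# Illustrative sketch: constructing a MoRConfig directly. The values below are
# simply the defaults declared above, not tuned hyperparameters:
#
#   config = MoRConfig(
#       mor_enabled=True,
#       num_recursions=3,
#       routing_strategy="expert_choice",
#       kv_sharing=None,
#   )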


class MoRLlamaForCausalLM(LlamaForCausalLM):
    """
    Simplified MoR model for the Hugging Face Hub.

    This implementation provides basic inference capabilities while maintaining
    compatibility with the original MoR training framework. For full MoR features,
    including dynamic routing and recursion-wise KV caching, use the complete
    implementation from the original repository.
    """

    config_class = MoRConfig

    def __init__(self, config):
        super().__init__(config)
        self.mor_config = config

    @add_start_docstrings_to_model_forward("Standard forward pass with simplified MoR compatibility")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass for the simplified MoR model.

        For basic inference, this behaves like a standard LLaMA model.
        Advanced MoR features require the complete training framework.
        """
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load MoR model from pretrained checkpoint.

        This method handles loading the model weights while maintaining
        compatibility with both the simplified and full MoR implementations.
        """
        return super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs
        )

    def generate_with_mor(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_length: int = 100,
        temperature: float = 1.0,
        do_sample: bool = True,
        **kwargs
    ):
        """
        Generate text with MoR-aware settings.

        Convenience wrapper around `generate()` that enables sampling by default
        and pads with the EOS token.
        """
        return self.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            temperature=temperature,
            do_sample=do_sample,
            pad_token_id=self.config.eos_token_id,
            **kwargs
        )
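
# Illustrative sketch (an assumption, not part of the original API surface):
# computing a language-modeling loss with the standard forward pass. `model`
# and `inputs` are assumed to be set up as in the runnable demo at the bottom
# of this file; the underlying LlamaForCausalLM shifts the labels internally.
#
#   labels = inputs["input_ids"].clone()
#   outputs = model(
#       input_ids=inputs["input_ids"],
#       attention_mask=inputs["attention_mask"],
#       labels=labels,
#   )
#   print(outputs.loss)  # scalar cross-entropy loss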


# Register the custom classes with the Auto* factories so that AutoConfig and
# AutoModelForCausalLM can resolve the "mor_llama" model type.
try:
    from transformers import AutoConfig, AutoModelForCausalLM

    AutoConfig.register("mor_llama", MoRConfig)
    AutoModelForCausalLM.register(MoRConfig, MoRLlamaForCausalLM)
except (ImportError, AttributeError, ValueError):
    # Registration is best-effort: it can fail on transformers versions that
    # lack the register() API or if the model type is already registered.
    pass
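

if __name__ == "__main__":
    # Minimal usage sketch. "path/to/mor-checkpoint" is a placeholder, not a
    # real repo id; point it at an actual MoR checkpoint directory or Hugging
    # Face Hub repository before running.
    from transformers import AutoTokenizer

    checkpoint = "path/to/mor-checkpoint"  # placeholder (assumption)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = MoRLlamaForCausalLM.from_pretrained(checkpoint)
    model.eval()

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    output_ids = model.generate_with_mor(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=50,
        temperature=0.8,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))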