mixture-of-recursions-360m / modeling_mor.py
"""
Simplified MoR (Mixture-of-Recursions) model implementation for Hugging Face Hub.
This provides basic inference capabilities while maintaining compatibility with the full MoR framework.
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from transformers import LlamaForCausalLM, LlamaConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import add_start_docstrings_to_model_forward
class MoRConfig(LlamaConfig):
"""
Configuration class for MoR model.
Extends LlamaConfig with MoR-specific parameters.
"""
def __init__(
self,
mor_enabled=True,
num_recursions=3,
routing_strategy="expert_choice",
kv_sharing=None,
**kwargs
):
super().__init__(**kwargs)
# MoR-specific configurations
self.mor_enabled = mor_enabled
self.num_recursions = num_recursions
self.routing_strategy = routing_strategy
self.kv_sharing = kv_sharing
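# Example sketch: constructing a MoRConfig directly. The base-model sizes
# below are illustrative placeholders, not this checkpoint's actual values.
#
#     config = MoRConfig(
#         mor_enabled=True,
#         num_recursions=3,
#         routing_strategy="expert_choice",
#         kv_sharing=None,
#         hidden_size=1024,
#         num_hidden_layers=12,
#         vocab_size=32000,
#     )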
class MoRLlamaForCausalLM(LlamaForCausalLM):
"""
Simplified MoR model for Hugging Face Hub.
This implementation provides basic inference capabilities while maintaining
compatibility with the original MoR training framework. For full MoR features
including dynamic routing and recursion-wise KV caching, use the complete
implementation from the original repository.
"""
config_class = MoRConfig
def __init__(self, config):
super().__init__(config)
# Store MoR-specific config
self.mor_config = config
# For simplified inference, we'll use the standard forward pass
# Full MoR capabilities require the complete training framework
@add_start_docstrings_to_model_forward("Standard forward pass with simplified MoR compatibility")
def forward(
self,
        input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Forward pass for simplified MoR model.
For basic inference, this behaves like a standard LLaMA model.
Advanced MoR features require the complete training framework.
"""
# Use standard LLaMA forward pass for simplified inference
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs
)
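    # Usage sketch for the simplified forward pass (it behaves like a standard
    # LLaMA forward here). `tokenizer` and `model` are assumed to be loaded
    # already; see the loading sketches further below.
    #
    #     inputs = tokenizer("Hello, world!", return_tensors="pt")
    #     with torch.no_grad():
    #         outputs = model(**inputs)
    #     next_token_logits = outputs.logits[:, -1, :]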
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"""
Load MoR model from pretrained checkpoint.
This method handles loading the model weights while maintaining
compatibility with both the simplified and full MoR implementations.
"""
# Load the model using the parent class method
model = super().from_pretrained(
pretrained_model_name_or_path,
*model_args,
**kwargs
)
return model
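    # Loading sketch. The repo id is an assumption based on this repository's
    # name; substitute the actual Hub path or a local checkpoint directory.
    #
    #     model = MoRLlamaForCausalLM.from_pretrained(
    #         "mixture-of-recursions-360m", torch_dtype=torch.float32
    #     )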
def generate_with_mor(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
max_length: int = 100,
temperature: float = 1.0,
do_sample: bool = True,
**kwargs
):
"""
Generate text with MoR-aware settings.
This is a convenience method that provides optimized generation
settings for MoR models.
"""
return self.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
temperature=temperature,
do_sample=do_sample,
pad_token_id=self.config.eos_token_id,
**kwargs
)
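    # Generation sketch using the convenience wrapper. The tokenizer path is a
    # placeholder, not a confirmed repo id.
    #
    #     from transformers import AutoTokenizer
    #     tokenizer = AutoTokenizer.from_pretrained("mixture-of-recursions-360m")
    #     inputs = tokenizer("Once upon a time", return_tensors="pt")
    #     output_ids = model.generate_with_mor(
    #         input_ids=inputs["input_ids"],
    #         attention_mask=inputs["attention_mask"],
    #         max_length=64,
    #         temperature=0.8,
    #     )
    #     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))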
# Register the model for auto-loading
try:
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("mor_llama", MoRConfig)
AutoModelForCausalLM.register(MoRConfig, MoRLlamaForCausalLM)
except Exception:
# Registration may fail in some environments, but the model can still be used directly
pass
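# Auto-class sketch: once the registration above succeeds, the model can be
# resolved through the Auto API. When this file ships as custom code on the
# Hub, trust_remote_code=True is typically required; the repo id below is an
# assumption, not a confirmed path.
#
#     from transformers import AutoModelForCausalLM
#     model = AutoModelForCausalLM.from_pretrained(
#         "mixture-of-recursions-360m", trust_remote_code=True
#     )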