Upload MoR (Mixture-of-Recursions) model
- README.md +188 -0
- config.json +35 -0
- generation_config.json +6 -0
- modeling_mor.py +144 -0
- pytorch_model.bin +3 -0
- requirements.txt +6 -0
- tokenizer_config.json +7 -0
README.md
ADDED
@@ -0,0 +1,188 @@
---
license: mit
library_name: transformers
tags:
- mixture-of-recursions
- adaptive-computation
- early-exiting
- llama
- language-model
- efficient-inference
datasets:
- HuggingFaceTB/smollm-corpus
language:
- en
pipeline_tag: text-generation
model_type: llama
---

# Mixture-of-Recursions (MoR): Learning Dynamic Recursive Depths for Adaptive Token-Level Computation

<div align="center">

[Paper (arXiv:2507.10524)](https://arxiv.org/abs/2507.10524) ·
[Code (GitHub)](https://github.com/raymin0223/mixture_of_recursions) ·
[License: MIT](https://opensource.org/licenses/MIT)

</div>

## Model Description

This is a **Mixture-of-Recursions (MoR)** model that implements adaptive token-level computation through dynamic recursive depths. MoR addresses key bottlenecks of early-exiting techniques with a unified framework that tackles both the missing key-value (KV) cache problem and inefficient batched inference.

**Key Features:**
- 🚀 **Up to 2× greater inference throughput** than standard transformers at similar accuracy
- 🧠 **Dynamic routing mechanism** that assigns an optimal recursion depth to each token
- 💾 **Recursion-wise KV caching strategy** that reduces memory usage
- ⚡ **Efficient batched inference** through parameter sharing
- 🎯 **End-to-end trainable** architecture

### Model Details

- **Model Size**: 360M parameters
- **Architecture**: LLaMA-based with MoR modifications
- **Context Length**: 1024 tokens
- **Vocabulary Size**: 49,152 tokens
- **Hidden Size**: 960
- **Number of Layers**: 32
- **Attention Heads**: 15 (5 KV heads)
- **Training Data**: FineWeb-Edu deduplicated subset

## Quick Start

### Installation

```bash
pip install torch transformers accelerate
```

### Basic Usage

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model and tokenizer
model_name = "your-username/mixture-of-recursions-360m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,  # required: the MoR classes ship in modeling_mor.py via auto_map
)

# Generate text
prompt = "The key to artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,  # bound the generated continuation, not the total length
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
```

### Advanced Usage with Custom Recursion

```python
# For advanced users: access MoR-specific features.
# Note: full functionality requires the original MoR codebase.

from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# The model supports dynamic recursion depths through its routing mechanisms.
# See the original repository for complete MoR training and inference scripts.
```
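The MoR-specific fields that `modeling_mor.py` adds to the configuration can be read directly off the loaded object. A minimal sketch, assuming the defaults defined in `MoRConfig` in this repo:

```python
# Inspect the MoR-specific configuration fields added by MoRConfig.
print(config.mor_enabled)       # True
print(config.num_recursions)    # defaults to 3
print(config.routing_strategy)  # defaults to "expert_choice"
print(config.kv_sharing)        # None unless set at training time
```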

## Model Architecture

The MoR model introduces several key innovations over standard transformers:

### 1. Dynamic Routing Mechanism
- **Expert-choice routing**: Dynamically selects which tokens to process at each recursion depth (see the sketch below)
- **Token-choice routing**: Lets each token choose its optimal processing depth
- **Trainable routers**: Routing decisions are learned end-to-end
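To make the expert-choice idea concrete, here is a hypothetical sketch, not the repository's actual implementation: a linear scorer ranks tokens, the top fraction recurses further, and a sigmoid weighting keeps the routing decision differentiable. All names and the capacity value are illustrative assumptions.

```python
import torch
import torch.nn as nn

class ExpertChoiceRouter(nn.Module):
    """Illustrative expert-choice router: score every token, keep the top
    `capacity` fraction for another recursion step (a sketch, not the
    shipped MoR code)."""

    def __init__(self, hidden_size: int, capacity: float = 0.5):
        super().__init__()
        self.scorer = nn.Linear(hidden_size, 1)
        self.capacity = capacity  # fraction of tokens kept per recursion step

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (batch, seq_len, hidden_size)
        scores = self.scorer(hidden_states).squeeze(-1)   # (batch, seq_len)
        k = max(1, int(scores.size(-1) * self.capacity))
        top = torch.topk(scores, k, dim=-1)
        mask = torch.zeros_like(scores, dtype=torch.bool)
        mask.scatter_(-1, top.indices, True)              # tokens that recurse again
        weights = torch.sigmoid(scores) * mask            # keeps routing differentiable
        return mask, weights
```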

### 2. Recursion-wise KV Caching
- Solves the missing-KV-cache problem in early-exiting models
- Stores KV pairs selectively to cut memory use
- Enables efficient parallel decoding (a toy cache-update sketch follows)
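A toy sketch of the caching idea, assuming a plain dict keyed by recursion depth (the real cache in the MoR framework is more involved): only tokens the router kept active at a given depth write KV entries there, so shallow-exiting tokens never allocate deep-level cache memory.

```python
import torch

def update_recursion_cache(cache, depth, keys, values, active_mask):
    """Append KV pairs for this recursion depth, but only for tokens that
    are still active (active_mask: bool over the sequence dimension).
    cache maps depth -> (K, V); keys/values are (batch, seq, head_dim)."""
    k_sel = keys[:, active_mask]    # (batch, n_active, head_dim)
    v_sel = values[:, active_mask]
    if depth in cache:
        k_old, v_old = cache[depth]
        cache[depth] = (torch.cat([k_old, k_sel], dim=1),
                        torch.cat([v_old, v_sel], dim=1))
    else:
        cache[depth] = (k_sel, v_sel)
    return cache
```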

### 3. Parameter Sharing Strategies
- **Cycle sharing**: Reuses one layer stack at every recursion depth, so tokens at different depths can be processed together (sketched below)
- **Middle-cycle sharing**: Shares the middle of the stack across recursion levels to optimize parameter utilization
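A minimal sketch of cycle sharing under simplifying assumptions (each layer maps a tensor to a tensor; real transformer blocks also take attention masks and caches): the same weights are applied at every recursion depth, so effective depth grows with the recursion count while the parameter count stays fixed.

```python
import torch.nn as nn

class CycleSharedStack(nn.Module):
    """Run one shared layer stack num_recursions times (illustrative)."""

    def __init__(self, layers: nn.ModuleList, num_recursions: int = 3):
        super().__init__()
        self.layers = layers              # one set of weights, reused each cycle
        self.num_recursions = num_recursions

    def forward(self, hidden_states):
        for _ in range(self.num_recursions):
            for layer in self.layers:     # same parameters at every depth
                hidden_states = layer(hidden_states)
        return hidden_states
```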

## Training Details

- **Training Framework**: PyTorch with DeepSpeed/Accelerate
- **Hardware**: 4× H100/A100 GPUs
- **Optimization**: AdamW with a cosine learning-rate schedule (see the sketch after this list)
- **Mixed Precision**: bfloat16
- **Gradient Accumulation**: Multi-step accumulation for an effectively large batch size
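For reference, a minimal sketch of the stated optimizer setup; the learning rate, warmup, and step counts are illustrative placeholders, not the values used to train this checkpoint, and `model` is assumed to be loaded as in the Quick Start section.

```python
import torch
from transformers import get_cosine_schedule_with_warmup

# `model` as loaded in the Quick Start section above.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1_000,      # illustrative value
    num_training_steps=100_000,  # illustrative value
)
```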

## Performance

### Efficiency Gains
- **Throughput**: Up to 2× improvement over standard transformers
- **Memory**: Lower memory use through the recursion-wise KV caching described above
- **Training**: Fewer total FLOPs during training

### Accuracy Preservation
The model maintains competitive performance on standard benchmarks while delivering these efficiency gains.

## Use Cases

- **Efficient text generation**: Applications that need fast inference
- **Resource-constrained deployment**: Edge devices and mobile applications
- **Real-time applications**: Chat systems, interactive AI assistants
- **Research**: Adaptive computation and early-exiting studies

## Limitations

- The custom architecture requires `trust_remote_code=True` and specific handling for full MoR features
- Optimal performance requires the complete MoR training framework
- Deployment may need model-specific optimizations

## Citation

If you use this model in your research, please cite:

```bibtex
@misc{bae2025mixtureofrecursionslearningdynamicrecursive,
      title={Mixture-of-Recursions: Learning Dynamic Recursive Depths for Adaptive Token-Level Computation},
      author={Sangmin Bae and Yujin Kim and Reza Bayat and Sungnyun Kim and Jiyoun Ha and Tal Schuster and Adam Fisch and Hrayr Harutyunyan and Ziwei Ji and Aaron Courville and Se-Young Yun},
      year={2025},
      eprint={2507.10524},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2507.10524},
}
```

## License

This model is released under the MIT License. See the LICENSE file for details.

## Authors

**Sangmin Bae**, **Yujin Kim**, **Reza Bayat**, Sungnyun Kim, Jiyoun Ha, Tal Schuster, Adam Fisch, Hrayr Harutyunyan, Ziwei Ji, Aaron Courville, Se-Young Yun

*KAIST AI, Mila, Google Cloud, Google DeepMind, Google Research, Université de Montréal*

## Links

- 📄 [Paper](https://arxiv.org/abs/2507.10524)
- 💻 [GitHub Repository](https://github.com/raymin0223/mixture_of_recursions)
- 🤗 [Hugging Face Model](https://huggingface.co/your-username/mixture-of-recursions-360m)

---

*For complete training scripts, evaluation code, and advanced MoR features, please visit the [official GitHub repository](https://github.com/raymin0223/mixture_of_recursions).*
config.json
ADDED
@@ -0,0 +1,35 @@
{
  "architectures": [
    "MoRLlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "eos_token_id": 0,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 960,
  "initializer_range": 0.02,
  "intermediate_size": 2560,
  "max_position_embeddings": 1024,
  "mlp_bias": false,
  "model_type": "mor_llama",
  "num_attention_heads": 15,
  "num_hidden_layers": 32,
  "num_key_value_heads": 5,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.50.0",
  "use_cache": true,
  "vocab_size": 49152,
  "auto_map": {
    "AutoConfig": "modeling_mor.MoRConfig",
    "AutoModelForCausalLM": "modeling_mor.MoRLlamaForCausalLM"
  },
  "custom_model": true,
  "mor_enabled": true
}
generation_config.json
ADDED
@@ -0,0 +1,6 @@
{
  "_from_model_config": true,
  "bos_token_id": 0,
  "eos_token_id": 0,
  "transformers_version": "4.50.0"
}
modeling_mor.py
ADDED
@@ -0,0 +1,144 @@
"""
Simplified MoR (Mixture-of-Recursions) model implementation for the Hugging Face Hub.
Provides basic inference while remaining compatible with the full MoR framework.
"""

from typing import Optional, Tuple, Union

import torch
from transformers import LlamaForCausalLM, LlamaConfig
from transformers.modeling_outputs import CausalLMOutputWithPast


class MoRConfig(LlamaConfig):
    """
    Configuration class for the MoR model.
    Extends LlamaConfig with MoR-specific parameters.
    """

    def __init__(
        self,
        mor_enabled=True,
        num_recursions=3,
        routing_strategy="expert_choice",
        kv_sharing=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # MoR-specific configuration
        self.mor_enabled = mor_enabled
        self.num_recursions = num_recursions      # recursion depths available to the router
        self.routing_strategy = routing_strategy  # e.g. "expert_choice" or "token_choice"
        self.kv_sharing = kv_sharing              # optional KV-sharing mode across recursions


class MoRLlamaForCausalLM(LlamaForCausalLM):
    """
    Simplified MoR model for the Hugging Face Hub.

    This implementation provides basic inference while staying compatible
    with the original MoR training framework. For full MoR features,
    including dynamic routing and recursion-wise KV caching, use the
    complete implementation from the original repository.
    """

    config_class = MoRConfig

    def __init__(self, config):
        super().__init__(config)

        # Keep the MoR-specific config around for downstream tooling.
        # Simplified inference uses the standard LLaMA forward pass;
        # full MoR capabilities require the complete training framework.
        self.mor_config = config

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass for the simplified MoR model.

        For basic inference this behaves like a standard LLaMA model;
        advanced MoR features require the complete training framework.
        """
        # Use the standard LLaMA forward pass for simplified inference.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load a MoR model from a pretrained checkpoint.

        Thin wrapper over the parent loader, kept for compatibility with
        both the simplified and the full MoR implementations.
        """
        return super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs,
        )

    def generate_with_mor(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_length: int = 100,
        temperature: float = 1.0,
        do_sample: bool = True,
        **kwargs,
    ):
        """
        Convenience wrapper around `generate` with MoR-friendly defaults
        (EOS reused as the pad token, sampling enabled).
        """
        return self.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            temperature=temperature,
            do_sample=do_sample,
            pad_token_id=self.config.eos_token_id,
            **kwargs,
        )


# Register the model for auto-loading via AutoConfig / AutoModelForCausalLM.
try:
    from transformers import AutoConfig, AutoModelForCausalLM

    AutoConfig.register("mor_llama", MoRConfig)
    AutoModelForCausalLM.register(MoRConfig, MoRLlamaForCausalLM)
except Exception:
    # Registration can fail in some environments (e.g., the type is already
    # registered), but the classes can still be used directly.
    pass
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dbbae3e59491e4bfb15dda78c6fb70372bed99e460761d8ce2dae5f5f3ced538
size 523115322
requirements.txt
ADDED
@@ -0,0 +1,6 @@
torch>=2.0.0
transformers>=4.35.0
huggingface_hub>=0.17.0
accelerate>=0.20.0
safetensors>=0.3.0
tokenizers>=0.14.0
tokenizer_config.json
ADDED
@@ -0,0 +1,7 @@
{
  "tokenizer_class": "LlamaTokenizer",
  "bos_token": "<s>",
  "eos_token": "</s>",
  "unk_token": "<unk>",
  "model_max_length": 1024
}