noman007 committed on
Commit 13a325b · verified · 1 Parent(s): 71417a6

Upload 17 files

text-generation, text-generation-inference

CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,9 @@
1
+ # Microsoft Open Source Code of Conduct
2
+
3
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4
+
5
+ Resources:
6
+
7
+ - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8
+ - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9
+ - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
LICENSE ADDED
@@ -0,0 +1,22 @@
1
+ Microsoft.
2
+ Copyright (c) Microsoft Corporation.
3
+
4
+ MIT License
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
NOTICE.md ADDED
@@ -0,0 +1,38 @@
1
+ NOTICES AND INFORMATION
2
+ Do Not Translate or Localize
3
+
4
+ This software incorporates material from third parties.
5
+
6
+ **Component.** https://github.com/Dao-AILab/flash-attention
7
+
8
+ **Open Source License/Copyright Notice.**
9
+
10
+ BSD 3-Clause License
11
+
12
+ Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
13
+ All rights reserved.
14
+
15
+ Redistribution and use in source and binary forms, with or without
16
+ modification, are permitted provided that the following conditions are met:
17
+
18
+ * Redistributions of source code must retain the above copyright notice, this
19
+ list of conditions and the following disclaimer.
20
+
21
+ * Redistributions in binary form must reproduce the above copyright notice,
22
+ this list of conditions and the following disclaimer in the documentation
23
+ and/or other materials provided with the distribution.
24
+
25
+ * Neither the name of the copyright holder nor the names of its
26
+ contributors may be used to endorse or promote products derived from
27
+ this software without specific prior written permission.
28
+
29
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
30
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
31
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
32
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
33
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
34
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
35
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
36
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
37
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
38
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md ADDED
@@ -0,0 +1,235 @@
1
+ ---
2
+ language:
3
+ - en
4
+ library_name: transformers
5
+ license: mit
6
+ license_link: https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/resolve/main/LICENSE
7
+ pipeline_tag: text-generation
8
+ tags:
9
+ - nlp
10
+ - math
11
+ - code
12
+ widget:
13
+ - messages:
14
+ - role: user
15
+ content: How to solve 3*x^2+4*x+5=1?
16
+ ---
17
+
18
+ ## Model Summary
19
+
20
+ Phi-4-mini-flash-reasoning is a lightweight open model built upon synthetic data with a focus on high-quality, reasoning-dense data, further fine-tuned for more advanced math reasoning capabilities.
21
+ The model belongs to the Phi-4 model family and supports 64K token context length.
22
+
23
+ 📰 [Phi-4-mini-flash-reasoning Blog](https://azure.microsoft.com/en-us/blog/reasoning-reimagined-introducing-phi-4-mini-flash-reasoning/) <br>
24
+ 📖 [Phi-4-mini-flash-reasoning Paper](https://aka.ms/flashreasoning-paper) | [HF Paper](https://huggingface.co/papers/2507.06607) <br>
25
+ 📚 [Training Codebase](https://github.com/microsoft/ArchScale) <br>
26
+ 👩‍🍳 [Phi Cookbook](https://github.com/microsoft/PhiCookBook) <br>
27
+ 🏡 [Phi Portal](https://azure.microsoft.com/en-us/products/phi) <br>
28
+ 🚀 [vLLM Inference](https://github.com/vllm-project/vllm/pull/20702) <br>
29
+ 🖥️ Try It [Azure](https://ai.azure.com/explore/models/Phi-4-mini-flash-reasoning/version/1/registry/azureml-phi-prod) <br>
30
+
31
+
32
+ 🎉**Phi-4 models**: [[Phi-4-mini-reasoning](https://huggingface.co/microsoft/Phi-4-mini-reasoning)] | [[Phi-4-reasoning](https://huggingface.co/microsoft/Phi-4-reasoning)] | [[multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) | [onnx](https://huggingface.co/microsoft/Phi-4-multimodal-instruct-onnx)];
33
+ [[mini-instruct](https://huggingface.co/microsoft/Phi-4-mini-instruct) | [onnx](https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx)]
34
+
35
+ ## Intended Uses
36
+
37
+ ### Primary Use Cases
38
+
39
+ Phi-4-mini-flash-reasoning is designed for multi-step, logic-intensive mathematical problem-solving tasks in memory/compute-constrained environments and latency-bound scenarios.
40
+ Some of the use cases include formal proof generation, symbolic computation, advanced word problems, and a wide range of mathematical reasoning scenarios.
41
+ These models excel at maintaining context across steps, applying structured logic, and delivering accurate, reliable solutions in domains that require deep analytical thinking.
42
+
43
+ ### Use Case Considerations
44
+
45
+ This model is designed and tested for math reasoning only. It is not specifically designed or evaluated for all downstream purposes.
46
+ Developers should consider common limitations of language models, as well as performance differences across languages, as they select use cases, and evaluate and mitigate for accuracy, safety, and fairness before using the model within a specific downstream use case, particularly for high-risk scenarios.
47
+ Developers should be aware of and adhere to applicable laws or regulations (including but not limited to privacy, trade compliance laws, etc.) that are relevant to their use case.
48
+
49
+ ***Nothing contained in this Model Card should be interpreted as or deemed a restriction or modification to the license the model is released under.***
50
+
51
+ ## Release Notes
52
+
53
+ This release of Phi-4-mini-flash-reasoning addresses user feedback and market demand for a compact reasoning model.
54
+ It is a compact transformer-based language model optimized for mathematical reasoning, built to deliver high-quality, step-by-step problem solving in environments where computing or latency is constrained.
55
+ The model is fine-tuned with synthetic math data from a more capable model (much larger, smarter, more accurate, and better at following instructions), which has resulted in enhanced reasoning performance.
56
+ Phi-4-mini-flash-reasoning balances reasoning ability with efficiency, making it potentially suitable for educational applications, embedded tutoring, and lightweight deployment on edge or mobile systems.
57
+ If a critical issue is identified with Phi-4-mini-flash-reasoning, it should be promptly reported through the MSRC Researcher Portal or secure@microsoft.com.
58
+
59
+ ### Model Quality
60
+
61
+ To understand its capabilities, the 3.8B-parameter Phi-4-mini-flash-reasoning model was compared with a set of models over a variety of reasoning benchmarks.
62
+ We use a more accurate evaluation where Pass@1 accuracy is averaged over 64 samples for AIME24/25 and 8 samples for Math500 and GPQA Diamond. A high-level overview of the model quality is as follows:
63
+
64
+ | **Model** | **AIME24** | **AIME25** | **Math500** | **GPQA Diamond** |
65
+ | :----------------------------------- | :--------- | :--------- | :---------- | :--------------- |
66
+ | DeepSeek-R1-Distill-Qwen-1.5B | 29.58 | 20.78 | 84.50 | 37.69 |
67
+ | DeepSeek-R1-Distill-Qwen-7B | 53.70 | 35.94 | 93.03 | 47.85 |
68
+ | DeepSeek-R1-Distill-Llama-8B | 43.96 | 27.34 | 87.48 | 45.83 |
69
+ | Bespoke-Stratos-7B | 21.51 | 18.28 | 80.73 | 38.51 |
70
+ | OpenThinker-7B | 29.69 | 24.32 | 87.25 | 41.60 |
71
+ | Phi4-mini-Reasoning (3.8B) | 48.13 | 31.77 | 91.20 | 44.51 |
72
+ | **Phi4-mini-Flash-Reasoning (3.8B)** | **52.29** | **33.59** | **92.45** | **45.08** |
73
+
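As an illustration of the evaluation protocol described above (not the actual evaluation harness), Pass@1 averaged over k samples can be computed roughly as follows; the `sample_answer` callable is a hypothetical stand-in for one model generation plus answer extraction:

```python
# Illustrative only: Pass@1 averaged over k sampled generations per problem,
# as described above (k = 64 for AIME24/25, k = 8 for Math500 and GPQA Diamond).
from typing import Callable, Dict, List

def averaged_pass_at_1(
    problems: List[Dict],
    sample_answer: Callable[[Dict], str],  # stand-in: one sampled generation + answer extraction
    k: int,
) -> float:
    """Fraction of correct samples per problem, averaged over the benchmark (in %)."""
    per_problem = []
    for problem in problems:
        correct = sum(sample_answer(problem) == problem["reference"] for _ in range(k))
        per_problem.append(correct / k)
    return 100.0 * sum(per_problem) / len(per_problem)
```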
74
+ Overall, with only 3.8B parameters, the model achieves a level of math and science reasoning ability similar to that of much larger models.
75
+ However, it is still fundamentally limited by its size for certain tasks. The model simply does not have the capacity to store extensive factual knowledge, so users may experience factual inaccuracies. This weakness may be mitigated by augmenting Phi-4-mini-flash-reasoning with a search engine, particularly when using the model under RAG settings.
76
+
77
+ ### Model Efficiency
78
+
79
+ The two figures below compare the latency and throughput performance of the Phi-4-mini-reasoning and Phi-4-mini-flash-reasoning models under the vLLM inference framework. All evaluations were performed on a single NVIDIA A100-80GB GPU with tensor parallelism disabled (TP = 1). The Phi-4-mini-flash-reasoning model, which incorporates a decoder-hybrid-decoder architecture combining attention and a state space model (SSM), exhibits significantly greater computational efficiency, achieving up to a 10× improvement in throughput when processing user requests with a 2K prompt length and 32K generation length. Furthermore, Phi-4-mini-flash-reasoning demonstrates near-linear growth in latency with respect to the number of tokens generated (up to 32K), in contrast to the quadratic growth observed in Phi-4-mini-reasoning. These findings indicate that Phi-4-mini-flash-reasoning is more scalable and better suited for long-sequence generation tasks.
80
+
81
+ <div align="left">
82
+ <img src="lat.png" width="300"/>
83
+ <img src="thr_lat.png" width="298"/>
84
+ </div>
85
+ Figure 1. The first plot shows average inference latency as a function of generation length, while the second plot illustrates how inference latency varies with throughput. Both experiments were conducted using the vLLM inference framework on a single A100-80GB GPU over varying concurrency levels of user requests.
86
+
87
+ ## Usage
88
+
89
+ ### Tokenizer
90
+
91
+ Phi-4-mini-flash-reasoning supports a vocabulary size of up to `200064` tokens. The [tokenizer files](https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/blob/main/added_tokens.json) already provide placeholder tokens that can be used for downstream fine-tuning, but they can also be extended up to the model's vocabulary size.
92
+
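As an illustration of extending the placeholder tokens mentioned above, here is a minimal sketch; the token names `<|my_tool|>` and `<|my_tag|>` are hypothetical examples, not part of the released vocabulary:

```python
# Hypothetical sketch: adding custom placeholder tokens for downstream
# fine-tuning. The token names below are illustrative, not part of the release.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/Phi-4-mini-flash-reasoning"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

new_tokens = ["<|my_tool|>", "<|my_tag|>"]  # hypothetical placeholder names
num_added = tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})

# Keep the embedding matrix in sync with the enlarged tokenizer; the model's
# vocab_size (200064) already leaves headroom beyond the token ids in active use.
if num_added > 0 and len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))
```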
93
+ ### Input Formats
94
+
95
+ Given the nature of the training data, the Phi-4-mini-flash-reasoning
96
+ model is best suited for prompts using this specific chat format:
97
+
98
+ ```yaml
99
+ <|user|>How to solve 3*x^2+4*x+5=1?<|end|><|assistant|>
100
+ ```
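For reference, a minimal sketch of producing this prompt string programmatically, assuming the repository's bundled chat template renders the format shown above:

```python
# Minimal sketch: render the chat format shown above from a messages list,
# assuming the bundled chat template produces the <|user|>...<|end|><|assistant|> layout.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-mini-flash-reasoning")
messages = [{"role": "user", "content": "How to solve 3*x^2+4*x+5=1?"}]

prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return the raw prompt string
    add_generation_prompt=True,  # append the assistant turn opener
)
print(prompt)
```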
101
+ ### Inference with transformers
102
+ List of required packages:
103
+
104
+ ```
105
+ flash_attn==2.7.4.post1
106
+ torch==2.6.0
107
+ mamba-ssm==2.2.4 --no-build-isolation
108
+ causal-conv1d==1.5.0.post8
109
+ transformers==4.46.1
110
+ accelerate==1.4.0
111
+ ```
112
+
113
+ Phi-4-mini-flash-reasoning is also available in [Azure AI Foundry](https://ai.azure.com/explore/models/Phi-4-mini-flash-reasoning/version/1/registry/azureml-phi-prod)
114
+
115
+ #### Example
116
+
117
+ After obtaining the Phi-4-mini-flash-reasoning model checkpoints, users can use this sample code for inference.
118
+
119
+ ```python
120
+ import torch
121
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
122
+ torch.random.manual_seed(0)
123
+
124
+ model_id = "microsoft/Phi-4-mini-flash-reasoning"
125
+ model = AutoModelForCausalLM.from_pretrained(
126
+ model_id,
127
+ device_map="cuda",
128
+ torch_dtype="auto",
129
+ trust_remote_code=True,
130
+ )
131
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
132
+
133
+ messages = [{
134
+ "role": "user",
135
+ "content": "How to solve 3*x^2+4*x+5=1?"
136
+ }]
137
+ inputs = tokenizer.apply_chat_template(
138
+ messages,
139
+ add_generation_prompt=True,
140
+ return_dict=True,
141
+ return_tensors="pt",
142
+ )
143
+
144
+ outputs = model.generate(
145
+ **inputs.to(model.device),
146
+ max_new_tokens=32768,
147
+ temperature=0.6,
148
+ top_p=0.95,
149
+ do_sample=True,
150
+ )
151
+ outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
152
+
153
+ print(outputs[0])
154
+ ```
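Since a vLLM integration is linked above, the following is a rough sketch of offline generation with vLLM; it assumes an installed vLLM build that includes the linked Phi-4-mini-flash-reasoning support, and exact arguments may differ across vLLM versions:

```python
# Rough sketch (not from the model card): offline generation with vLLM,
# assuming an installed vLLM build that supports this architecture.
from vllm import LLM, SamplingParams

llm = LLM(model="microsoft/Phi-4-mini-flash-reasoning", trust_remote_code=True)
sampling = SamplingParams(temperature=0.6, top_p=0.95, max_tokens=4096)

messages = [{"role": "user", "content": "How to solve 3*x^2+4*x+5=1?"}]
outputs = llm.chat(messages, sampling)  # applies the tokenizer's chat template
print(outputs[0].outputs[0].text)
```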
155
+
156
+ ## Training
157
+
158
+ ### Model
159
+
160
+ + **Architecture:** Phi-4-mini-flash-reasoning adopts a hybrid SambaY architecture with Differential Attention, featuring 3.8 billion parameters and a 200K vocabulary. It incorporates state space models, grouped-query attention, a gated memory sharing mechanism, a shared key-value cache with a single global attention layer, and shared input-output embeddings.<br>
161
+ + **Inputs:** Text. It is best suited for prompts using the chat format.<br>
162
+ + **Context length:** 64K tokens<br>
163
+ + **GPUs:** Pre-training: 1024 A100-80G; Reasoning training: 128 H100-80G <br>
164
+ + **Training time:** Pre-training: 14 days; Reasoning training: 2 days <br>
165
+ + **Training data:** Pre-training: 5T tokens; Reasoning training: 150B tokens<br>
166
+ + **Outputs:** Generated text<br>
167
+ + **Dates:** Trained in May 2025 <br>
168
+ + **Status:** This is a static model trained on offline datasets with the cutoff date of February 2025 for publicly available data.<br>
169
+ + **Supported languages:** English<br>
170
+ + **Release date:** June 2025<br>
171
+
172
+ ### Training Datasets
173
+
174
+ The training data for Phi-4-mini-flash-reasoning consists exclusively of synthetic mathematical content generated by a stronger and more advanced reasoning model, Deepseek-R1.
175
+ The objective is to distill knowledge from this model. This synthetic dataset comprises over one million diverse math problems spanning multiple levels of difficulty (from middle school to Ph.D. level).
176
+ For each problem in the synthetic dataset, eight distinct solutions (rollouts) were sampled, and only those verified as correct were retained, resulting in approximately 30 billion tokens of math content. A sketch of this filtering step follows the list below.
177
+ The dataset integrates three primary components:
178
+ 1) a curated selection of high-quality, publicly available math questions and part of the SFT (Supervised Fine-Tuning) data that was used to train the base Phi-4-mini-flash model;
179
+ 2) an extensive collection of synthetic math data generated by the Deepseek-R1 model, designed specifically for high-quality supervised fine-tuning and model distillation; and
180
+ 3) a balanced set of correct and incorrect answers used to construct preference data aimed at enhancing Phi-4-mini-flash-reasoning's reasoning capabilities by learning more effective reasoning trajectories.
181
+
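The rollout filtering described above (sample eight solutions per problem, keep only those verified as correct) can be sketched as follows; the generator and verifier interfaces are hypothetical stand-ins, not the actual data pipeline:

```python
# Hypothetical sketch of the rollout filtering described above: sample eight
# candidate solutions per problem and keep only those a verifier accepts.
from typing import Callable, Dict, List

def filter_rollouts(
    problems: List[Dict],
    generate: Callable[[str], str],       # stand-in for the teacher model (DeepSeek-R1)
    verify: Callable[[Dict, str], bool],  # stand-in for the answer checker
    n_rollouts: int = 8,
) -> List[Dict]:
    kept = []
    for problem in problems:
        for _ in range(n_rollouts):
            solution = generate(problem["question"])
            if verify(problem, solution):
                kept.append({"question": problem["question"], "solution": solution})
    return kept
```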
182
+ ## Software
183
+ * [PyTorch](https://github.com/pytorch/pytorch)
184
+ * [Transformers](https://github.com/huggingface/transformers)
185
+ * [Flash-Attention](https://github.com/HazyResearch/flash-attention)
186
+ * [Mamba](https://github.com/state-spaces/mamba)
187
+ * [Causal-Conv1d](https://github.com/Dao-AILab/causal-conv1d)
188
+
189
+ ## Hardware
190
+ Note that by default, the Phi-4-mini-flash-reasoning model uses flash attention, which requires certain types of GPU hardware to run. We have tested on the following GPU types:
191
+ * NVIDIA A100
192
+ * NVIDIA H100
193
+
194
+ ## Safety Evaluation and Red-Teaming
195
+
196
+ The Phi-4 family of models has adopted a robust safety post-training approach. This approach leverages a variety of both open-source and in-house generated datasets. The overall technique employed to do the safety alignment is a combination of SFT, DPO (Direct Preference Optimization), and RLHF (Reinforcement Learning from Human Feedback) approaches by utilizing human-labeled and synthetic English-language datasets, including publicly available datasets focusing on helpfulness and harmlessness, as well as various questions and answers targeted to multiple safety categories.
197
+
198
+ Phi-4-Mini-Flash-Reasoning was developed in accordance with Microsoft's responsible AI principles. Potential safety risks in the model’s responses were assessed using the Azure AI Foundry’s Risk and Safety Evaluation framework, focusing on harmful content, direct jailbreak, and model groundedness. The Phi-4-Mini-Flash-Reasoning Model Card contains additional information about our approach to safety and responsible AI considerations that developers should be aware of when using this model.
199
+
200
+ ## Responsible AI Considerations
201
+
202
+ Like other language models, the Phi family of models can potentially behave in ways that are unfair, unreliable, or offensive. Some of the limiting behaviors to be aware of include:
203
+
204
+ + Quality of Service: The Phi models are trained primarily on English text and some additional multilingual text. Languages other than English will experience worse performance, as well as performance disparities across non-English languages. English language varieties with less representation in the training data might experience worse performance than standard American English.
205
+ + Multilingual performance and safety gaps: We believe it is important to make language models more widely available across different languages, but the Phi 4 models still exhibit challenges common across multilingual releases. As with any deployment of LLMs, developers will be better positioned to test for performance or safety gaps for their linguistic and cultural context and customize the model with additional fine-tuning and appropriate safeguards.
206
+ + Representation of Harms & Perpetuation of Stereotypes: These models can over- or under-represent groups of people, erase representation of some groups, or reinforce demeaning or negative stereotypes. Despite safety post-training, these limitations may still be present due to differing levels of representation of different groups, cultural contexts, or prevalence of examples of negative stereotypes in training data that reflect real-world patterns and societal biases.
207
+ + Inappropriate or Offensive Content: These models may produce other types of inappropriate or offensive content, which may make it inappropriate to deploy for sensitive contexts without additional mitigations that are specific to the case.
208
+ + Information Reliability: Language models can generate nonsensical content or fabricate content that might sound reasonable but is inaccurate or outdated.
209
+ + Election Information Reliability: The model has an elevated defect rate when responding to election-critical queries, which may result in incorrect or unauthoritative election-critical information being presented. We are working to improve the model's performance in this area. Users should verify information related to elections with the election authority in their region.
210
+ + Limited Scope for Code: The majority of Phi 4 training data is based in Python and uses common packages such as "typing, math, random, collections, datetime, itertools". If the model generates Python scripts that utilize other packages or scripts in other languages, it is strongly recommended that users manually verify all API uses.
211
+ + Long Conversation: Phi 4 models, like other models, can in some cases generate responses that are repetitive, unhelpful, or inconsistent in very long chat sessions in both English and non-English languages. Developers are encouraged to place appropriate mitigations, like limiting conversation turns to account for the possible conversational drift.
212
+
213
+ Developers should apply responsible AI best practices, including mapping, measuring, and mitigating risks associated with their specific use case and cultural and linguistic context. The Phi 4 family of models are general-purpose models. As developers plan to deploy these models for specific use cases, they are encouraged to fine-tune the models for their use case and leverage the models as part of broader AI systems with language-specific safeguards in place. Important areas for consideration include:
214
+
215
+ + Allocation: Models may not be suitable for scenarios that could have consequential impact on legal status or the allocation of resources or life opportunities (ex: housing, employment, credit, etc.) without further assessments and additional debiasing techniques.
216
+ + High-Risk Scenarios: Developers should assess the suitability of using models in high-risk scenarios where unfair, unreliable or offensive outputs might be extremely costly or lead to harm. This includes providing advice in sensitive or expert domains where accuracy and reliability are critical (ex: legal or health advice). Additional safeguards should be implemented at the application level according to the deployment context.
217
+ + Misinformation: Models may produce inaccurate information. Developers should follow transparency best practices and inform end-users they are interacting with an AI system. At the application level, developers can build feedback mechanisms and pipelines to ground responses in use-case specific, contextual information, a technique known as Retrieval Augmented Generation (RAG).
218
+ + Generation of Harmful Content: Developers should assess outputs for their context and use available safety classifiers or custom solutions appropriate for their use case.
219
+ + Misuse: Other forms of misuse such as fraud, spam, or malware production may be possible, and developers should ensure that their applications do not violate applicable laws and regulations.
220
+
221
+ ## License
222
+ The model is licensed under the [MIT license](./LICENSE).
223
+
224
+ ## Trademarks
225
+ This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft’s Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party’s policies.
226
+
227
+
228
+ ## Appendix A: Benchmark Methodology
229
+
230
+ We include a brief word on methodology here, and in particular on how we think about optimizing prompts. In an ideal world, we would never change any prompts in our benchmarks, to ensure an apples-to-apples comparison when comparing different models. Indeed, this is our default approach, and it is the case for the vast majority of models we have run to date. For all benchmarks, we use the same generation configuration, such as the same max sequence length (32768) and the same temperature, for a fair comparison.
231
+ **Benchmark datasets**
232
+ We evaluate the model with three of the most popular math benchmarks on which the strongest reasoning models compete. Specifically:
233
+ + Math-500: This benchmark consists of 500 challenging math problems designed to test the model's ability to perform complex mathematical reasoning and problem-solving.
234
+ + AIME 2024/AIME 2025: The American Invitational Mathematics Examination (AIME) is a highly regarded math competition that features a series of difficult problems aimed at assessing advanced mathematical skills and logical reasoning. We evaluate the models on the problems from both the 2024 and 2025 examinations.
235
+ + GPQA Diamond: The Graduate-Level Google-Proof Q&A (GPQA) Diamond benchmark focuses on evaluating the model's ability to understand and answer challenging graduate-level science questions, including both straightforward and more intricate problem-solving tasks.
SECURITY.md ADDED
@@ -0,0 +1,41 @@
1
+ <!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->
2
+
3
+ ## Security
4
+
5
+ Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6
+
7
+ If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8
+
9
+ ## Reporting Security Issues
10
+
11
+ **Please do not report security vulnerabilities through public GitHub issues.**
12
+
13
+ Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14
+
15
+ If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16
+
17
+ You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18
+
19
+ Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20
+
21
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22
+ * Full paths of source file(s) related to the manifestation of the issue
23
+ * The location of the affected source code (tag/branch/commit or direct URL)
24
+ * Any special configuration required to reproduce the issue
25
+ * Step-by-step instructions to reproduce the issue
26
+ * Proof-of-concept or exploit code (if possible)
27
+ * Impact of the issue, including how an attacker might exploit the issue
28
+
29
+ This information will help us triage your report more quickly.
30
+
31
+ If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32
+
33
+ ## Preferred Languages
34
+
35
+ We prefer all communications to be in English.
36
+
37
+ ## Policy
38
+
39
+ Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40
+
41
+ <!-- END MICROSOFT SECURITY.MD BLOCK -->
added_tokens.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "<|/tool_call|>": 200026,
3
+ "<|/tool|>": 200024,
4
+ "<|assistant|>": 200019,
5
+ "<|end|>": 200020,
6
+ "<|system|>": 200022,
7
+ "<|tag|>": 200028,
8
+ "<|tool_call|>": 200025,
9
+ "<|tool_response|>": 200027,
10
+ "<|tool|>": 200023,
11
+ "<|user|>": 200021
12
+ }
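As a quick illustration, the mapping above can be checked against the released tokenizer; a minimal sketch, assuming the tokenizer files in this upload:

```python
# Quick check that the special tokens above resolve to the listed ids.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-mini-flash-reasoning")
for token in ["<|user|>", "<|assistant|>", "<|end|>"]:
    print(token, tokenizer.convert_tokens_to_ids(token))
# expected per added_tokens.json: 200021, 200019, 200020
```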
config.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "architectures": [
3
+ "Phi4FlashForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_phi4flash.Phi4FlashConfig",
8
+ "AutoModelForCausalLM": "modeling_phi4flash.Phi4FlashForCausalLM",
9
+ "AutoTokenizer": "Xenova/gpt-4o"
10
+ },
11
+ "pad_token_id": 199999,
12
+ "bos_token_id": 199999,
13
+ "embd_pdrop": 0.0,
14
+ "eos_token_id": 199999,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 2560,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 10240,
19
+ "layer_norm_eps": 1e-5,
20
+ "max_position_embeddings": 262144,
21
+ "_attn_implementation": "flash_attention_2",
22
+ "mb_per_layer": 2,
23
+ "model_type": "phi4flash",
24
+ "num_attention_heads": 40,
25
+ "num_hidden_layers": 32,
26
+ "num_key_value_heads": 20,
27
+ "resid_pdrop": 0.0,
28
+ "sliding_window": 512,
29
+ "torch_dtype": "bfloat16",
30
+ "tie_word_embeddings": true,
31
+ "transformers_version": "4.46.1",
32
+ "use_cache": true,
33
+ "mlp_bias": false,
34
+ "lm_head_bias": false,
35
+ "vocab_size": 200064
36
+ }
37
+
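A minimal sketch for inspecting this configuration through the custom config class registered in `auto_map` (requires `trust_remote_code=True`):

```python
# Minimal sketch: load the uploaded configuration through the custom
# Phi4FlashConfig registered in auto_map (requires trust_remote_code=True).
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "microsoft/Phi-4-mini-flash-reasoning",
    trust_remote_code=True,
)
print(config.model_type)         # "phi4flash"
print(config.hidden_size)        # 2560
print(config.num_hidden_layers)  # 32
print(config.sliding_window)     # expanded to a per-layer list by Phi4FlashConfig
```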
configuration_phi4flash.py ADDED
@@ -0,0 +1,173 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Phi4Flash model configuration"""
17
+
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+ import math
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class Phi4FlashConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Phi4FlashModel`]. It is used to instantiate an Phi4Flash
28
+ model according to the specified arguments, defining the model architecture.
29
+
30
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
+ documentation from [`PretrainedConfig`] for more information.
32
+
33
+ Args:
34
+ vocab_size (`int`, *optional*, defaults to 51200):
35
+ Vocabulary size of the Phi4Flash model. Defines the number of different tokens that can be represented by the
36
+ `inputs_ids` passed when calling [`Phi4FlashModel`].
37
+ hidden_size (`int`, *optional*, defaults to 2048):
38
+ Dimension of the hidden representations.
39
+ intermediate_size (`int`, *optional*, defaults to 8192):
40
+ Dimension of the MLP representations.
41
+ num_hidden_layers (`int`, *optional*, defaults to 24):
42
+ Number of hidden layers in the Transformer decoder.
43
+ num_attention_heads (`int`, *optional*, defaults to 32):
44
+ Number of attention heads for each attention layer in the Transformer decoder.
45
+ num_key_value_heads (`int`, *optional*):
46
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
47
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
48
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
49
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
50
+ by meanpooling all the original heads within that group. For more details checkout [this
51
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
52
+ `num_attention_heads`.
53
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
54
+ Dropout probability for mlp outputs.
55
+ embd_pdrop (`int`, *optional*, defaults to 0.0):
56
+ The dropout ratio for the embeddings.
57
+ attention_dropout (`float`, *optional*, defaults to 0.0):
58
+ The dropout ratio after computing the attention scores.
59
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
60
+ The non-linear activation function (function or string) in the decoder.
61
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
62
+ The maximum sequence length that this model might ever be used with. Phi-1 and Phi-1.5 supports up to 2048
63
+ tokens.
64
+ initializer_range (`float`, *optional*, defaults to 0.02):
65
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
66
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
67
+ The epsilon used by the rms normalization layers.
68
+ use_cache (`bool`, *optional*, defaults to `True`):
69
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
70
+ relevant if `config.is_decoder=True`.
71
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
72
+ Whether to tie weight embeddings
73
+ rope_theta (`float`, *optional*, defaults to 10000.0):
74
+ The base period of the RoPE embeddings.
75
+
76
+ Example:
77
+
78
+ ```python
79
+ >>> from transformers import Phi4FlashModel, Phi4FlashConfig
80
+
81
+ >>> # Initializing a Phi4Flash style configuration
82
+ >>> configuration = Phi4FlashConfig.from_pretrained("microsoft/Phi-4-mini-flash-reasoning")
83
+
84
+ >>> # Initializing a model from the configuration
85
+ >>> model = Phi4FlashModel(configuration)
86
+
87
+ >>> # Accessing the model configuration
88
+ >>> configuration = model.config
89
+ ```"""
90
+
91
+ model_type = "phi4flash"
92
+ keys_to_ignore_at_inference = ["past_key_values"]
93
+
94
+ def __init__(
95
+ self,
96
+ vocab_size=51200,
97
+ hidden_size=2560,
98
+ intermediate_size=9216,
99
+ num_hidden_layers=32,
100
+ num_attention_heads=40,
101
+ num_key_value_heads=4,
102
+ resid_pdrop=0.0,
103
+ embd_pdrop=0.0,
104
+ attention_dropout=0.0,
105
+ hidden_act="silu",
106
+ max_position_embeddings=4096,
107
+ initializer_range=0.02,
108
+ layer_norm_eps=1e-5,
109
+ use_cache=True,
110
+ tie_word_embeddings=True,
111
+ rope_theta=10000.0,
112
+ bos_token_id=1,
113
+ eos_token_id=2,
114
+ sliding_window=2047,
115
+ mb_per_layer= 2,
116
+ mamba_d_state=16,
117
+ mamba_d_conv=4,
118
+ mamba_expand=2,
119
+ mamba_dt_rank="auto",
120
+ mamba_conv_bias=True,
121
+ mamba_proj_bias=False,
122
+ **kwargs,
123
+ ):
124
+ self.vocab_size = vocab_size
125
+ self.hidden_size = hidden_size
126
+ self.intermediate_size = intermediate_size
127
+ self.num_hidden_layers = num_hidden_layers
128
+ self.num_attention_heads = num_attention_heads
129
+
130
+ if num_key_value_heads is None:
131
+ num_key_value_heads = num_attention_heads
132
+
133
+ self.num_key_value_heads = num_key_value_heads
134
+ self.resid_pdrop = resid_pdrop
135
+ self.embd_pdrop = embd_pdrop
136
+ self.attention_dropout = attention_dropout
137
+ self.hidden_act = hidden_act
138
+ self.max_position_embeddings = max_position_embeddings
139
+ self.initializer_range = initializer_range
140
+ self.layer_norm_eps = layer_norm_eps
141
+ self.use_cache = use_cache
142
+ self.rope_theta = rope_theta
143
+ self.mb_per_layer = mb_per_layer
144
+ self.sliding_window = [
145
+ sliding_window if layer_idx < num_hidden_layers // 2 and layer_idx % 2 == 1 else None
146
+ for layer_idx in range(num_hidden_layers)
147
+ ]
148
+
149
+ self.mamba_d_state = mamba_d_state
150
+ self.mamba_d_conv = mamba_d_conv
151
+ self.mamba_expand = mamba_expand
152
+ self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
153
+ self.mamba_conv_bias = mamba_conv_bias
154
+ self.mamba_proj_bias = mamba_proj_bias
155
+
156
+ super().__init__(
157
+ bos_token_id=bos_token_id,
158
+ eos_token_id=eos_token_id,
159
+ tie_word_embeddings=tie_word_embeddings,
160
+ **kwargs,
161
+ )
162
+
163
+
164
+ @property
165
+ def layers_block_type(self):
166
+ layer_block_types = []
167
+ for i in range(self.num_hidden_layers):
168
+ if i % 2 == 1:
169
+ layer_block_type = "attention" if i <= (self.num_hidden_layers //2 +1) else "shared_attention"
170
+ else:
171
+ layer_block_type = "mamba"
172
+ layer_block_types.append(layer_block_type)
173
+ return layer_block_types
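For reference, the `layers_block_type` property above produces an alternating Mamba/attention pattern, with the odd-indexed layers in the upper half of the stack switching to `shared_attention`; a standalone sketch of the same logic for the default 32-layer configuration:

```python
# Standalone restatement of the layers_block_type logic above for the default
# 32-layer stack: even indices are Mamba (SSM) blocks; odd indices are attention
# blocks, switching to "shared_attention" past the midpoint of the stack.
num_hidden_layers = 32

layers_block_type = [
    ("attention" if i <= (num_hidden_layers // 2 + 1) else "shared_attention")
    if i % 2 == 1
    else "mamba"
    for i in range(num_hidden_layers)
]
print(layers_block_type)
# ['mamba', 'attention', 'mamba', 'attention', ..., 'mamba', 'shared_attention']
```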
generation_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 199999,
4
+ "eos_token_id": [
5
+ 200020,
6
+ 199999
7
+ ],
8
+ "pad_token_id": 199999,
9
+ "transformers_version": "4.45.0"
10
+ }
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58cb8678cd1495c42afd4b2d9ccd26bbace3e54e94905db8c346e7b39ce7a956
3
+ size 135
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:848f66d812125dbf83052af44cd1e010d9b0cf92b5b532b7df61ce1e65935c3e
3
+ size 135
model.safetensors.index.json ADDED
@@ -0,0 +1,441 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 7706608640
4
+ },
5
+ "weight_map": {
6
+ "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
7
+ "model.final_layernorm.bias": "model-00002-of-00002.safetensors",
8
+ "model.final_layernorm.weight": "model-00002-of-00002.safetensors",
9
+ "model.layers.0.attn.A_log": "model-00001-of-00002.safetensors",
10
+ "model.layers.0.attn.D": "model-00001-of-00002.safetensors",
11
+ "model.layers.0.attn.conv1d.bias": "model-00001-of-00002.safetensors",
12
+ "model.layers.0.attn.conv1d.weight": "model-00001-of-00002.safetensors",
13
+ "model.layers.0.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
14
+ "model.layers.0.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
15
+ "model.layers.0.attn.in_proj.weight": "model-00001-of-00002.safetensors",
16
+ "model.layers.0.attn.out_proj.weight": "model-00001-of-00002.safetensors",
17
+ "model.layers.0.attn.x_proj.weight": "model-00001-of-00002.safetensors",
18
+ "model.layers.0.input_layernorm.bias": "model-00001-of-00002.safetensors",
19
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
20
+ "model.layers.0.mlp.fc1.weight": "model-00001-of-00002.safetensors",
21
+ "model.layers.0.mlp.fc2.weight": "model-00001-of-00002.safetensors",
22
+ "model.layers.0.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
23
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
24
+ "model.layers.1.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
25
+ "model.layers.1.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
26
+ "model.layers.1.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
27
+ "model.layers.1.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
28
+ "model.layers.1.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
29
+ "model.layers.1.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
30
+ "model.layers.1.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
31
+ "model.layers.1.attn.out_proj.bias": "model-00001-of-00002.safetensors",
32
+ "model.layers.1.attn.out_proj.weight": "model-00001-of-00002.safetensors",
33
+ "model.layers.1.input_layernorm.bias": "model-00001-of-00002.safetensors",
34
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
35
+ "model.layers.1.mlp.fc1.weight": "model-00001-of-00002.safetensors",
36
+ "model.layers.1.mlp.fc2.weight": "model-00001-of-00002.safetensors",
37
+ "model.layers.1.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
38
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
39
+ "model.layers.10.attn.A_log": "model-00001-of-00002.safetensors",
40
+ "model.layers.10.attn.D": "model-00001-of-00002.safetensors",
41
+ "model.layers.10.attn.conv1d.bias": "model-00001-of-00002.safetensors",
42
+ "model.layers.10.attn.conv1d.weight": "model-00001-of-00002.safetensors",
43
+ "model.layers.10.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
44
+ "model.layers.10.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
45
+ "model.layers.10.attn.in_proj.weight": "model-00001-of-00002.safetensors",
46
+ "model.layers.10.attn.out_proj.weight": "model-00001-of-00002.safetensors",
47
+ "model.layers.10.attn.x_proj.weight": "model-00001-of-00002.safetensors",
48
+ "model.layers.10.input_layernorm.bias": "model-00001-of-00002.safetensors",
49
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
50
+ "model.layers.10.mlp.fc1.weight": "model-00001-of-00002.safetensors",
51
+ "model.layers.10.mlp.fc2.weight": "model-00001-of-00002.safetensors",
52
+ "model.layers.10.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
53
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
54
+ "model.layers.11.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
55
+ "model.layers.11.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
56
+ "model.layers.11.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
57
+ "model.layers.11.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
58
+ "model.layers.11.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
59
+ "model.layers.11.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
60
+ "model.layers.11.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
61
+ "model.layers.11.attn.out_proj.bias": "model-00001-of-00002.safetensors",
62
+ "model.layers.11.attn.out_proj.weight": "model-00001-of-00002.safetensors",
63
+ "model.layers.11.input_layernorm.bias": "model-00001-of-00002.safetensors",
64
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
65
+ "model.layers.11.mlp.fc1.weight": "model-00001-of-00002.safetensors",
66
+ "model.layers.11.mlp.fc2.weight": "model-00001-of-00002.safetensors",
67
+ "model.layers.11.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
68
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
69
+ "model.layers.12.attn.A_log": "model-00001-of-00002.safetensors",
70
+ "model.layers.12.attn.D": "model-00001-of-00002.safetensors",
71
+ "model.layers.12.attn.conv1d.bias": "model-00001-of-00002.safetensors",
72
+ "model.layers.12.attn.conv1d.weight": "model-00001-of-00002.safetensors",
73
+ "model.layers.12.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
74
+ "model.layers.12.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
75
+ "model.layers.12.attn.in_proj.weight": "model-00001-of-00002.safetensors",
76
+ "model.layers.12.attn.out_proj.weight": "model-00001-of-00002.safetensors",
77
+ "model.layers.12.attn.x_proj.weight": "model-00001-of-00002.safetensors",
78
+ "model.layers.12.input_layernorm.bias": "model-00001-of-00002.safetensors",
79
+ "model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
80
+ "model.layers.12.mlp.fc1.weight": "model-00001-of-00002.safetensors",
81
+ "model.layers.12.mlp.fc2.weight": "model-00001-of-00002.safetensors",
82
+ "model.layers.12.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
83
+ "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
84
+ "model.layers.13.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
85
+ "model.layers.13.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
86
+ "model.layers.13.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
87
+ "model.layers.13.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
88
+ "model.layers.13.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
89
+ "model.layers.13.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
90
+ "model.layers.13.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
91
+ "model.layers.13.attn.out_proj.bias": "model-00001-of-00002.safetensors",
92
+ "model.layers.13.attn.out_proj.weight": "model-00001-of-00002.safetensors",
93
+ "model.layers.13.input_layernorm.bias": "model-00001-of-00002.safetensors",
94
+ "model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
95
+ "model.layers.13.mlp.fc1.weight": "model-00001-of-00002.safetensors",
96
+ "model.layers.13.mlp.fc2.weight": "model-00001-of-00002.safetensors",
97
+ "model.layers.13.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
98
+ "model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
99
+ "model.layers.14.attn.A_log": "model-00001-of-00002.safetensors",
100
+ "model.layers.14.attn.D": "model-00001-of-00002.safetensors",
101
+ "model.layers.14.attn.conv1d.bias": "model-00001-of-00002.safetensors",
102
+ "model.layers.14.attn.conv1d.weight": "model-00001-of-00002.safetensors",
103
+ "model.layers.14.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
104
+ "model.layers.14.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
105
+ "model.layers.14.attn.in_proj.weight": "model-00001-of-00002.safetensors",
106
+ "model.layers.14.attn.out_proj.weight": "model-00001-of-00002.safetensors",
107
+ "model.layers.14.attn.x_proj.weight": "model-00001-of-00002.safetensors",
108
+ "model.layers.14.input_layernorm.bias": "model-00001-of-00002.safetensors",
109
+ "model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
110
+ "model.layers.14.mlp.fc1.weight": "model-00001-of-00002.safetensors",
111
+ "model.layers.14.mlp.fc2.weight": "model-00001-of-00002.safetensors",
112
+ "model.layers.14.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
113
+ "model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
114
+ "model.layers.15.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
115
+ "model.layers.15.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
116
+ "model.layers.15.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
117
+ "model.layers.15.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
118
+ "model.layers.15.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
119
+ "model.layers.15.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
120
+ "model.layers.15.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
121
+ "model.layers.15.attn.out_proj.bias": "model-00001-of-00002.safetensors",
122
+ "model.layers.15.attn.out_proj.weight": "model-00001-of-00002.safetensors",
123
+ "model.layers.15.input_layernorm.bias": "model-00001-of-00002.safetensors",
124
+ "model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
125
+ "model.layers.15.mlp.fc1.weight": "model-00001-of-00002.safetensors",
126
+ "model.layers.15.mlp.fc2.weight": "model-00001-of-00002.safetensors",
127
+ "model.layers.15.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
128
+ "model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
129
+ "model.layers.16.attn.A_log": "model-00001-of-00002.safetensors",
130
+ "model.layers.16.attn.D": "model-00001-of-00002.safetensors",
131
+ "model.layers.16.attn.conv1d.bias": "model-00001-of-00002.safetensors",
132
+ "model.layers.16.attn.conv1d.weight": "model-00001-of-00002.safetensors",
133
+ "model.layers.16.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
134
+ "model.layers.16.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
135
+ "model.layers.16.attn.in_proj.weight": "model-00001-of-00002.safetensors",
136
+ "model.layers.16.attn.out_proj.weight": "model-00001-of-00002.safetensors",
137
+ "model.layers.16.attn.x_proj.weight": "model-00001-of-00002.safetensors",
138
+ "model.layers.16.input_layernorm.bias": "model-00001-of-00002.safetensors",
139
+ "model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
140
+ "model.layers.16.mlp.fc1.weight": "model-00001-of-00002.safetensors",
141
+ "model.layers.16.mlp.fc2.weight": "model-00001-of-00002.safetensors",
142
+ "model.layers.16.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
143
+ "model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
144
+ "model.layers.17.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
145
+ "model.layers.17.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
146
+ "model.layers.17.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
147
+ "model.layers.17.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
148
+ "model.layers.17.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
149
+ "model.layers.17.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
150
+ "model.layers.17.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
151
+ "model.layers.17.attn.out_proj.bias": "model-00001-of-00002.safetensors",
152
+ "model.layers.17.attn.out_proj.weight": "model-00001-of-00002.safetensors",
153
+ "model.layers.17.input_layernorm.bias": "model-00001-of-00002.safetensors",
154
+ "model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
155
+ "model.layers.17.mlp.fc1.weight": "model-00001-of-00002.safetensors",
156
+ "model.layers.17.mlp.fc2.weight": "model-00001-of-00002.safetensors",
157
+ "model.layers.17.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
158
+ "model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
159
+ "model.layers.18.attn.in_proj.weight": "model-00002-of-00002.safetensors",
160
+ "model.layers.18.attn.out_proj.weight": "model-00002-of-00002.safetensors",
161
+ "model.layers.18.input_layernorm.bias": "model-00002-of-00002.safetensors",
162
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00002.safetensors",
163
+ "model.layers.18.mlp.fc1.weight": "model-00002-of-00002.safetensors",
164
+ "model.layers.18.mlp.fc2.weight": "model-00002-of-00002.safetensors",
165
+ "model.layers.18.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
166
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
167
+ "model.layers.19.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
168
+ "model.layers.19.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
169
+ "model.layers.19.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
170
+ "model.layers.19.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
171
+ "model.layers.19.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
172
+ "model.layers.19.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
173
+ "model.layers.19.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
174
+ "model.layers.19.attn.out_proj.bias": "model-00002-of-00002.safetensors",
175
+ "model.layers.19.attn.out_proj.weight": "model-00002-of-00002.safetensors",
176
+ "model.layers.19.input_layernorm.bias": "model-00002-of-00002.safetensors",
177
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00002.safetensors",
178
+ "model.layers.19.mlp.fc1.weight": "model-00002-of-00002.safetensors",
179
+ "model.layers.19.mlp.fc2.weight": "model-00002-of-00002.safetensors",
180
+ "model.layers.19.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
181
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
182
+ "model.layers.2.attn.A_log": "model-00001-of-00002.safetensors",
183
+ "model.layers.2.attn.D": "model-00001-of-00002.safetensors",
184
+ "model.layers.2.attn.conv1d.bias": "model-00001-of-00002.safetensors",
185
+ "model.layers.2.attn.conv1d.weight": "model-00001-of-00002.safetensors",
186
+ "model.layers.2.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
187
+ "model.layers.2.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
188
+ "model.layers.2.attn.in_proj.weight": "model-00001-of-00002.safetensors",
189
+ "model.layers.2.attn.out_proj.weight": "model-00001-of-00002.safetensors",
190
+ "model.layers.2.attn.x_proj.weight": "model-00001-of-00002.safetensors",
191
+ "model.layers.2.input_layernorm.bias": "model-00001-of-00002.safetensors",
192
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
193
+ "model.layers.2.mlp.fc1.weight": "model-00001-of-00002.safetensors",
194
+ "model.layers.2.mlp.fc2.weight": "model-00001-of-00002.safetensors",
195
+ "model.layers.2.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
196
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
197
+ "model.layers.20.attn.in_proj.weight": "model-00002-of-00002.safetensors",
198
+ "model.layers.20.attn.out_proj.weight": "model-00002-of-00002.safetensors",
199
+ "model.layers.20.input_layernorm.bias": "model-00002-of-00002.safetensors",
200
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00002.safetensors",
201
+ "model.layers.20.mlp.fc1.weight": "model-00002-of-00002.safetensors",
202
+ "model.layers.20.mlp.fc2.weight": "model-00002-of-00002.safetensors",
203
+ "model.layers.20.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
204
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
205
+ "model.layers.21.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
206
+ "model.layers.21.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
207
+ "model.layers.21.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
208
+ "model.layers.21.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
209
+ "model.layers.21.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
210
+ "model.layers.21.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
211
+ "model.layers.21.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
212
+ "model.layers.21.attn.out_proj.bias": "model-00002-of-00002.safetensors",
213
+ "model.layers.21.attn.out_proj.weight": "model-00002-of-00002.safetensors",
214
+ "model.layers.21.input_layernorm.bias": "model-00002-of-00002.safetensors",
215
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00002.safetensors",
216
+ "model.layers.21.mlp.fc1.weight": "model-00002-of-00002.safetensors",
217
+ "model.layers.21.mlp.fc2.weight": "model-00002-of-00002.safetensors",
218
+ "model.layers.21.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
219
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
220
+ "model.layers.22.attn.in_proj.weight": "model-00002-of-00002.safetensors",
221
+ "model.layers.22.attn.out_proj.weight": "model-00002-of-00002.safetensors",
222
+ "model.layers.22.input_layernorm.bias": "model-00002-of-00002.safetensors",
223
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00002.safetensors",
224
+ "model.layers.22.mlp.fc1.weight": "model-00002-of-00002.safetensors",
225
+ "model.layers.22.mlp.fc2.weight": "model-00002-of-00002.safetensors",
226
+ "model.layers.22.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
227
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
228
+ "model.layers.23.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
229
+ "model.layers.23.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
230
+ "model.layers.23.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
231
+ "model.layers.23.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
232
+ "model.layers.23.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
233
+ "model.layers.23.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
234
+ "model.layers.23.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
235
+ "model.layers.23.attn.out_proj.bias": "model-00002-of-00002.safetensors",
236
+ "model.layers.23.attn.out_proj.weight": "model-00002-of-00002.safetensors",
237
+ "model.layers.23.input_layernorm.bias": "model-00002-of-00002.safetensors",
238
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00002.safetensors",
239
+ "model.layers.23.mlp.fc1.weight": "model-00002-of-00002.safetensors",
240
+ "model.layers.23.mlp.fc2.weight": "model-00002-of-00002.safetensors",
241
+ "model.layers.23.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
242
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
243
+ "model.layers.24.attn.in_proj.weight": "model-00002-of-00002.safetensors",
244
+ "model.layers.24.attn.out_proj.weight": "model-00002-of-00002.safetensors",
245
+ "model.layers.24.input_layernorm.bias": "model-00002-of-00002.safetensors",
246
+ "model.layers.24.input_layernorm.weight": "model-00002-of-00002.safetensors",
247
+ "model.layers.24.mlp.fc1.weight": "model-00002-of-00002.safetensors",
248
+ "model.layers.24.mlp.fc2.weight": "model-00002-of-00002.safetensors",
249
+ "model.layers.24.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
250
+ "model.layers.24.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
251
+ "model.layers.25.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
252
+ "model.layers.25.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
253
+ "model.layers.25.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
254
+ "model.layers.25.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
255
+ "model.layers.25.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
256
+ "model.layers.25.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
257
+ "model.layers.25.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
258
+ "model.layers.25.attn.out_proj.bias": "model-00002-of-00002.safetensors",
259
+ "model.layers.25.attn.out_proj.weight": "model-00002-of-00002.safetensors",
260
+ "model.layers.25.input_layernorm.bias": "model-00002-of-00002.safetensors",
261
+ "model.layers.25.input_layernorm.weight": "model-00002-of-00002.safetensors",
262
+ "model.layers.25.mlp.fc1.weight": "model-00002-of-00002.safetensors",
263
+ "model.layers.25.mlp.fc2.weight": "model-00002-of-00002.safetensors",
264
+ "model.layers.25.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
265
+ "model.layers.25.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
266
+ "model.layers.26.attn.in_proj.weight": "model-00002-of-00002.safetensors",
267
+ "model.layers.26.attn.out_proj.weight": "model-00002-of-00002.safetensors",
268
+ "model.layers.26.input_layernorm.bias": "model-00002-of-00002.safetensors",
269
+ "model.layers.26.input_layernorm.weight": "model-00002-of-00002.safetensors",
270
+ "model.layers.26.mlp.fc1.weight": "model-00002-of-00002.safetensors",
271
+ "model.layers.26.mlp.fc2.weight": "model-00002-of-00002.safetensors",
272
+ "model.layers.26.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
273
+ "model.layers.26.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
274
+ "model.layers.27.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
275
+ "model.layers.27.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
276
+ "model.layers.27.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
277
+ "model.layers.27.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
278
+ "model.layers.27.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
279
+ "model.layers.27.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
280
+ "model.layers.27.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
281
+ "model.layers.27.attn.out_proj.bias": "model-00002-of-00002.safetensors",
282
+ "model.layers.27.attn.out_proj.weight": "model-00002-of-00002.safetensors",
283
+ "model.layers.27.input_layernorm.bias": "model-00002-of-00002.safetensors",
284
+ "model.layers.27.input_layernorm.weight": "model-00002-of-00002.safetensors",
285
+ "model.layers.27.mlp.fc1.weight": "model-00002-of-00002.safetensors",
286
+ "model.layers.27.mlp.fc2.weight": "model-00002-of-00002.safetensors",
287
+ "model.layers.27.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
288
+ "model.layers.27.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
289
+ "model.layers.28.attn.in_proj.weight": "model-00002-of-00002.safetensors",
290
+ "model.layers.28.attn.out_proj.weight": "model-00002-of-00002.safetensors",
291
+ "model.layers.28.input_layernorm.bias": "model-00002-of-00002.safetensors",
292
+ "model.layers.28.input_layernorm.weight": "model-00002-of-00002.safetensors",
293
+ "model.layers.28.mlp.fc1.weight": "model-00002-of-00002.safetensors",
294
+ "model.layers.28.mlp.fc2.weight": "model-00002-of-00002.safetensors",
295
+ "model.layers.28.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
296
+ "model.layers.28.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
297
+ "model.layers.29.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
298
+ "model.layers.29.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
299
+ "model.layers.29.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
300
+ "model.layers.29.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
301
+ "model.layers.29.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
302
+ "model.layers.29.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
303
+ "model.layers.29.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
304
+ "model.layers.29.attn.out_proj.bias": "model-00002-of-00002.safetensors",
305
+ "model.layers.29.attn.out_proj.weight": "model-00002-of-00002.safetensors",
306
+ "model.layers.29.input_layernorm.bias": "model-00002-of-00002.safetensors",
307
+ "model.layers.29.input_layernorm.weight": "model-00002-of-00002.safetensors",
308
+ "model.layers.29.mlp.fc1.weight": "model-00002-of-00002.safetensors",
309
+ "model.layers.29.mlp.fc2.weight": "model-00002-of-00002.safetensors",
310
+ "model.layers.29.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
311
+ "model.layers.29.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
312
+ "model.layers.3.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
313
+ "model.layers.3.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
314
+ "model.layers.3.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
315
+ "model.layers.3.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
316
+ "model.layers.3.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
317
+ "model.layers.3.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
318
+ "model.layers.3.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
319
+ "model.layers.3.attn.out_proj.bias": "model-00001-of-00002.safetensors",
320
+ "model.layers.3.attn.out_proj.weight": "model-00001-of-00002.safetensors",
321
+ "model.layers.3.input_layernorm.bias": "model-00001-of-00002.safetensors",
322
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
323
+ "model.layers.3.mlp.fc1.weight": "model-00001-of-00002.safetensors",
324
+ "model.layers.3.mlp.fc2.weight": "model-00001-of-00002.safetensors",
325
+ "model.layers.3.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
326
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
327
+ "model.layers.30.attn.in_proj.weight": "model-00002-of-00002.safetensors",
328
+ "model.layers.30.attn.out_proj.weight": "model-00002-of-00002.safetensors",
329
+ "model.layers.30.input_layernorm.bias": "model-00002-of-00002.safetensors",
330
+ "model.layers.30.input_layernorm.weight": "model-00002-of-00002.safetensors",
331
+ "model.layers.30.mlp.fc1.weight": "model-00002-of-00002.safetensors",
332
+ "model.layers.30.mlp.fc2.weight": "model-00002-of-00002.safetensors",
333
+ "model.layers.30.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
334
+ "model.layers.30.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
335
+ "model.layers.31.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
336
+ "model.layers.31.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
337
+ "model.layers.31.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
338
+ "model.layers.31.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
339
+ "model.layers.31.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
340
+ "model.layers.31.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
341
+ "model.layers.31.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
342
+ "model.layers.31.attn.out_proj.bias": "model-00002-of-00002.safetensors",
343
+ "model.layers.31.attn.out_proj.weight": "model-00002-of-00002.safetensors",
344
+ "model.layers.31.input_layernorm.bias": "model-00002-of-00002.safetensors",
345
+ "model.layers.31.input_layernorm.weight": "model-00002-of-00002.safetensors",
346
+ "model.layers.31.mlp.fc1.weight": "model-00002-of-00002.safetensors",
347
+ "model.layers.31.mlp.fc2.weight": "model-00002-of-00002.safetensors",
348
+ "model.layers.31.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
349
+ "model.layers.31.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
350
+ "model.layers.4.attn.A_log": "model-00001-of-00002.safetensors",
351
+ "model.layers.4.attn.D": "model-00001-of-00002.safetensors",
352
+ "model.layers.4.attn.conv1d.bias": "model-00001-of-00002.safetensors",
353
+ "model.layers.4.attn.conv1d.weight": "model-00001-of-00002.safetensors",
354
+ "model.layers.4.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
355
+ "model.layers.4.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
356
+ "model.layers.4.attn.in_proj.weight": "model-00001-of-00002.safetensors",
357
+ "model.layers.4.attn.out_proj.weight": "model-00001-of-00002.safetensors",
358
+ "model.layers.4.attn.x_proj.weight": "model-00001-of-00002.safetensors",
359
+ "model.layers.4.input_layernorm.bias": "model-00001-of-00002.safetensors",
360
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
361
+ "model.layers.4.mlp.fc1.weight": "model-00001-of-00002.safetensors",
362
+ "model.layers.4.mlp.fc2.weight": "model-00001-of-00002.safetensors",
363
+ "model.layers.4.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
364
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
365
+ "model.layers.5.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
366
+ "model.layers.5.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
367
+ "model.layers.5.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
368
+ "model.layers.5.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
369
+ "model.layers.5.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
370
+ "model.layers.5.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
371
+ "model.layers.5.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
372
+ "model.layers.5.attn.out_proj.bias": "model-00001-of-00002.safetensors",
373
+ "model.layers.5.attn.out_proj.weight": "model-00001-of-00002.safetensors",
374
+ "model.layers.5.input_layernorm.bias": "model-00001-of-00002.safetensors",
375
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
376
+ "model.layers.5.mlp.fc1.weight": "model-00001-of-00002.safetensors",
377
+ "model.layers.5.mlp.fc2.weight": "model-00001-of-00002.safetensors",
378
+ "model.layers.5.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
379
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
380
+ "model.layers.6.attn.A_log": "model-00001-of-00002.safetensors",
381
+ "model.layers.6.attn.D": "model-00001-of-00002.safetensors",
382
+ "model.layers.6.attn.conv1d.bias": "model-00001-of-00002.safetensors",
383
+ "model.layers.6.attn.conv1d.weight": "model-00001-of-00002.safetensors",
384
+ "model.layers.6.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
385
+ "model.layers.6.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
386
+ "model.layers.6.attn.in_proj.weight": "model-00001-of-00002.safetensors",
387
+ "model.layers.6.attn.out_proj.weight": "model-00001-of-00002.safetensors",
388
+ "model.layers.6.attn.x_proj.weight": "model-00001-of-00002.safetensors",
389
+ "model.layers.6.input_layernorm.bias": "model-00001-of-00002.safetensors",
390
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
391
+ "model.layers.6.mlp.fc1.weight": "model-00001-of-00002.safetensors",
392
+ "model.layers.6.mlp.fc2.weight": "model-00001-of-00002.safetensors",
393
+ "model.layers.6.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
394
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
395
+ "model.layers.7.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
396
+ "model.layers.7.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
397
+ "model.layers.7.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
398
+ "model.layers.7.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
399
+ "model.layers.7.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
400
+ "model.layers.7.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
401
+ "model.layers.7.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
402
+ "model.layers.7.attn.out_proj.bias": "model-00001-of-00002.safetensors",
403
+ "model.layers.7.attn.out_proj.weight": "model-00001-of-00002.safetensors",
404
+ "model.layers.7.input_layernorm.bias": "model-00001-of-00002.safetensors",
405
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
406
+ "model.layers.7.mlp.fc1.weight": "model-00001-of-00002.safetensors",
407
+ "model.layers.7.mlp.fc2.weight": "model-00001-of-00002.safetensors",
408
+ "model.layers.7.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
409
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
410
+ "model.layers.8.attn.A_log": "model-00001-of-00002.safetensors",
411
+ "model.layers.8.attn.D": "model-00001-of-00002.safetensors",
412
+ "model.layers.8.attn.conv1d.bias": "model-00001-of-00002.safetensors",
413
+ "model.layers.8.attn.conv1d.weight": "model-00001-of-00002.safetensors",
414
+ "model.layers.8.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
415
+ "model.layers.8.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
416
+ "model.layers.8.attn.in_proj.weight": "model-00001-of-00002.safetensors",
417
+ "model.layers.8.attn.out_proj.weight": "model-00001-of-00002.safetensors",
418
+ "model.layers.8.attn.x_proj.weight": "model-00001-of-00002.safetensors",
419
+ "model.layers.8.input_layernorm.bias": "model-00001-of-00002.safetensors",
420
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
421
+ "model.layers.8.mlp.fc1.weight": "model-00001-of-00002.safetensors",
422
+ "model.layers.8.mlp.fc2.weight": "model-00001-of-00002.safetensors",
423
+ "model.layers.8.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
424
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
425
+ "model.layers.9.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
426
+ "model.layers.9.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
427
+ "model.layers.9.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
428
+ "model.layers.9.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
429
+ "model.layers.9.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
430
+ "model.layers.9.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
431
+ "model.layers.9.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
432
+ "model.layers.9.attn.out_proj.bias": "model-00001-of-00002.safetensors",
433
+ "model.layers.9.attn.out_proj.weight": "model-00001-of-00002.safetensors",
434
+ "model.layers.9.input_layernorm.bias": "model-00001-of-00002.safetensors",
435
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
436
+ "model.layers.9.mlp.fc1.weight": "model-00001-of-00002.safetensors",
437
+ "model.layers.9.mlp.fc2.weight": "model-00001-of-00002.safetensors",
438
+ "model.layers.9.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
439
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors"
440
+ }
441
+ }
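The `weight_map` above is what shard-aware loaders consult to locate the file that stores a given parameter. A minimal sketch of that lookup, assuming the index and both shards sit in the current directory (the tensor name is taken from the map above; everything else is illustrative):

import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "model.layers.21.attn.Wqkv.weight"
shard = index["weight_map"][name]  # "model-00002-of-00002.safetensors" according to the map above
with safe_open(shard, framework="pt", device="cpu") as st:
    tensor = st.get_tensor(name)  # reads only this tensor, not the whole shard
print(name, tuple(tensor.shape))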
modeling_phi4flash.py ADDED
@@ -0,0 +1,2098 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ PyTorch Phi4Flash model."""
17
+
18
+
19
+ import inspect
20
+ import math
21
+ import warnings
22
+ from typing import List, Optional, Tuple, Union, Dict, Any
23
+ import copy
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+ from transformers.activations import ACT2FN
30
+ from transformers.cache_utils import Cache, DynamicCache
31
+ from transformers.utils import is_torchdynamo_compiling
32
+ from transformers.modeling_outputs import (
33
+ BaseModelOutputWithPast,
34
+ CausalLMOutputWithPast,
35
+ SequenceClassifierOutputWithPast,
36
+ TokenClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.generation import GenerationMixin
40
+ from transformers.utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ is_flash_attn_greater_or_equal_2_10,
45
+ logging,
46
+ replace_return_docstrings,
47
+ )
48
+ from einops import rearrange, repeat
49
+
50
+ from .configuration_phi4flash import Phi4FlashConfig
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
55
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
56
+
57
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
58
+
59
+ if not _flash_supports_window_size:
60
+ raise ValueError("Please update flash-attention to support window size.")
61
+
62
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
63
+ import causal_conv1d_cuda
64
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
65
+
66
+ from torch.amp import custom_bwd, custom_fwd
67
+ import selective_scan_cuda
68
+
69
+ _CHECKPOINT_FOR_DOC = "microsoft/Phi-4-mini-flash-reasoning"
70
+ _CONFIG_FOR_DOC = "Phi4FlashConfig"
71
+
72
+ # monkey patch to add support for our cache
73
+ def _prepare_cache_for_generation(
74
+ self,
75
+ generation_config,
76
+ model_kwargs: Dict,
77
+ assistant_model: "PreTrainedModel",
78
+ batch_size: int,
79
+ max_cache_length: int,
80
+ device: torch.device,
81
+ ) -> bool:
82
+ """
83
+ Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
84
+ instantiated, writes it to `model_kwargs`, under the name expected by the model.
85
+ """
86
+
87
+ cache_name = "past_key_values"
88
+
89
+ # Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in
90
+ # `generation_config.validate()`)
91
+ if generation_config.use_cache is False:
92
+ return
93
+
94
+ # Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation`
95
+
96
+ # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
97
+ # which is only supported in dynamic caches atm
98
+ if assistant_model is not None:
99
+ logger.warning_once(
100
+ "An assistant model is provided, using a dynamic cache instead of a cache of type="
101
+ f"'{generation_config.cache_implementation}'."
102
+ )
103
+ model_kwargs[cache_name] = DynamicCache()
104
+ return
105
+
106
+ model_kwargs[cache_name] = self._get_cache(
107
+ cache_implementation="sambay",
108
+ batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
109
+ max_cache_len=max_cache_length,
110
+ device=device,
111
+ model_kwargs=model_kwargs,
112
+ )
113
+
114
+ def _get_cache(
115
+ self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
116
+ ) -> Cache:
117
+ """
118
+ Sets a cache for `generate` that will persist across calls. A new cache will only be initialized if a
119
+ new `generate` call requires a larger cache or uses a different batch size.
120
+
121
+ Returns the resulting cache object.
122
+ """
123
+ cache_cls: Cache = SambaYCache
124
+ requires_cross_attention_cache = (
125
+ self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
126
+ )
127
+
128
+ if hasattr(self, "_cache"):
129
+ cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache
130
+
131
+ if cache_implementation == "sliding_window":
132
+ max_cache_len = min(self.config.sliding_window[1], max_cache_len)
133
+
134
+ need_new_cache = (
135
+ not hasattr(self, "_cache")
136
+ or (not isinstance(cache_to_check, cache_cls))
137
+ or cache_to_check.batch_size != batch_size
138
+ )
139
+ if cache_implementation != "mamba":
140
+ need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
141
+
142
+ if requires_cross_attention_cache and hasattr(self, "_cache"):
143
+ need_new_cache = (
144
+ need_new_cache
145
+ or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
146
+ )
147
+
148
+ if need_new_cache:
149
+ if hasattr(self.config, "_pre_quantization_dtype"):
150
+ cache_dtype = self.config._pre_quantization_dtype
151
+ else:
152
+ if not is_torchdynamo_compiling():
153
+ cache_dtype = self.dtype
154
+ else:
155
+ # NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`.
156
+ # Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative
157
+ # models. May cause troubles with non-text modalities.
158
+ cache_dtype = self.get_output_embeddings().weight.dtype
159
+
160
+ def get_layer_device_map(execution_device_map: Optional[dict] = None):
161
+ if execution_device_map is None:
162
+ return None
163
+ elif len(execution_device_map) == 1 and "" in execution_device_map:
164
+ return {idx: execution_device_map[""] for idx in range(self.config.num_hidden_layers)}
165
+ layer_device_map = {}
166
+ for layer in execution_device_map:
167
+ for idx in range(self.config.num_hidden_layers):
168
+ if f".{idx}." in f"{layer}.":
169
+ layer_device_map[idx] = execution_device_map[layer]
170
+ break
171
+ for idx in range(self.config.num_hidden_layers):
172
+ if idx not in layer_device_map:
173
+ raise RuntimeError(f"layer {idx} has not been mapped to a device.")
174
+ return layer_device_map
175
+
176
+ execution_device_map = None
177
+ # Taken from dispatch_model from accelerate.
178
+ # This is needed here if we don't want to make changes in accelerate in order to save execution_device
179
+ # For offloaded case, we need to get the execution device, not just the device where it is offloaded
180
+ if hasattr(self, "hf_device_map"):
181
+ main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
182
+ execution_device_map = {
183
+ name: main_device if device in ["cpu", "disk"] else device
184
+ for name, device in self.hf_device_map.items()
185
+ }
186
+ layer_device_map = get_layer_device_map(execution_device_map)
187
+
188
+ cache_kwargs = {
189
+ "config": self.config.get_text_config(),
190
+ "batch_size": batch_size,
191
+ "max_cache_len": max_cache_len,
192
+ "device": device,
193
+ "dtype": cache_dtype,
194
+ "layer_device_map": layer_device_map,
195
+ }
196
+ self._cache = cache_cls(**cache_kwargs)
197
+ else:
198
+ self._cache.reset()
199
+ return self._cache
200
+
201
+ GenerationMixin._prepare_cache_for_generation = _prepare_cache_for_generation
202
+ GenerationMixin._get_cache = _get_cache
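# The two assignments above patch GenerationMixin so that a plain generate() call allocates the
# hybrid SambaYCache defined below instead of a DynamicCache. A usage sketch under assumed settings
# (requires the flash-attn / causal-conv1d / mamba-ssm kernels imported at the top of this file;
# prompt, dtype and generation arguments are illustrative):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def _example_generate():
    model_id = "microsoft/Phi-4-mini-flash-reasoning"  # _CHECKPOINT_FOR_DOC above
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="cuda"
    )
    inputs = tokenizer("Solve 3x + 7 = 19 step by step.", return_tensors="pt").to(model.device)
    # use_cache=True routes through the patched _prepare_cache_for_generation above.
    output = model.generate(**inputs, max_new_tokens=128, use_cache=True)
    print(tokenizer.decode(output[0], skip_special_tokens=True))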
203
+
204
+ class SambaYCache(Cache):
205
+ """
206
+ A dynamic cache that can handle the sliding-window attention cache, one layer of full-attention cache, and the mamba cache
207
+ (which has a constant shape regardless of seq_len).
208
+
209
+ """
210
+
211
+ def __init__(self,
212
+ config: Phi4FlashConfig,
213
+ batch_size: int = None,
214
+ max_cache_len: int = None,
215
+ device: Union[torch.device, str] = "cuda",
216
+ dtype: torch.dtype = torch.float16,
217
+ max_batch_size: Optional[int] = None,
218
+ layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
219
+ ) -> None:
220
+ super().__init__()
221
+ self.dtype = dtype
222
+ self.has_previous_state = False # only used by mamba
223
+ intermediate_size = config.mamba_expand * config.hidden_size
224
+ ssm_state_size = config.mamba_d_state
225
+ conv_kernel_size = config.mamba_d_conv
226
+ self.conv_kernel_size = conv_kernel_size
227
+
228
+ if batch_size is not None:
229
+ logger.warning_once(
230
+ f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
231
+ "v4.49. Use the more precisely named 'max_batch_size' argument instead."
232
+ )
233
+
234
+ self.max_cache_len = max_cache_len
235
+ self.max_batch_size = batch_size or max_batch_size
236
+ # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
237
+ self.head_dim = config.hidden_size // config.num_attention_heads
238
+ self.num_key_value_heads = config.num_key_value_heads
239
+ self.global_attn_idx = config.num_hidden_layers//2 + 1
240
+ self.key_cache: List[torch.Tensor] = []
241
+ self.value_cache: List[torch.Tensor] = []
242
+ global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
243
+ sliding_cache_shape = (
244
+ self.max_batch_size,
245
+ self.num_key_value_heads,
246
+ min(config.sliding_window[1], max_cache_len),
247
+ self.head_dim,
248
+ )
249
+ conv_cache_shape = (self.max_batch_size, intermediate_size, conv_kernel_size)
250
+ ssm_cache_shape = (self.max_batch_size, intermediate_size, ssm_state_size)
251
+ for i in range(config.num_hidden_layers//2 + 2):
252
+ if layer_device_map is not None:
253
+ layer_device = layer_device_map[i]
254
+ else:
255
+ layer_device = device
256
+ # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
257
+ # breaks when updating the cache.
258
+ if i == self.global_attn_idx:
259
+ key_cache_shape = value_cache_shape = global_cache_shape
260
+ elif i % 2 == 0:
261
+ key_cache_shape = conv_cache_shape
262
+ value_cache_shape = ssm_cache_shape
263
+ else:
264
+ key_cache_shape = value_cache_shape = sliding_cache_shape
265
+ new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=layer_device)
266
+ new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=layer_device)
267
+ torch._dynamo.mark_static_address(new_layer_key_cache)
268
+ torch._dynamo.mark_static_address(new_layer_value_cache)
269
+ self.key_cache.append(new_layer_key_cache)
270
+ self.value_cache.append(new_layer_value_cache)
271
+
272
+ def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
273
+ if cache_position.shape[0] > max_cache_len:
274
+ k_out = key_states[:, :, -max_cache_len:, :]
275
+ v_out = value_states[:, :, -max_cache_len:, :]
276
+ # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
277
+ self.key_cache[layer_idx] += k_out
278
+ self.value_cache[layer_idx] += v_out
279
+ # we should return the full states instead of k_out, v_out to take the whole prompt
280
+ # into consideration when building the kv cache instead of just throwing away tokens outside of the window
281
+ return key_states, value_states
282
+
283
+ slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
284
+ cache_position = cache_position.clamp(0, max_cache_len - 1)
285
+ to_shift = cache_position >= max_cache_len - 1
286
+ indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
287
+ k_out = k_out[:, :, indices]
288
+ v_out = v_out[:, :, indices]
289
+
290
+ k_out[:, :, cache_position] = key_states
291
+ v_out[:, :, cache_position] = value_states
292
+ # `.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
293
+ self.key_cache[layer_idx].zero_()
294
+ self.value_cache[layer_idx].zero_()
295
+
296
+ self.key_cache[layer_idx] += k_out
297
+ self.value_cache[layer_idx] += v_out
298
+ return k_out, v_out
299
+
300
+ def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
301
+ k_out[:, :, cache_position] = key_states
302
+ v_out[:, :, cache_position] = value_states
303
+
304
+ self.key_cache[layer_idx] = k_out
305
+ self.value_cache[layer_idx] = v_out
306
+ return k_out, v_out
307
+
308
+ def update(
309
+ self,
310
+ key_states: torch.Tensor,
311
+ value_states: torch.Tensor,
312
+ layer_idx: int,
313
+ cache_kwargs: Optional[Dict[str, Any]] = None,
314
+ ) -> Tuple[torch.Tensor]:
315
+ cache_position = cache_kwargs.get("cache_position")
316
+ k_out = self.key_cache[layer_idx]
317
+ v_out = self.value_cache[layer_idx]
318
+ if layer_idx == self.global_attn_idx:
319
+ update_fn = self._static_update
320
+ elif layer_idx % 2 == 1:
321
+ update_fn = self._sliding_update
322
+
323
+ return update_fn(
324
+ cache_position,
325
+ layer_idx,
326
+ key_states,
327
+ value_states,
328
+ k_out,
329
+ v_out,
330
+ k_out.shape[2],
331
+ )
332
+
333
+ def get_max_cache_shape(self) -> Optional[int]:
334
+ return self.max_cache_len
335
+
336
+ def get_seq_length(self, layer_idx: Optional[int] = 0):
337
+ # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
338
+ # limit the check to the first batch member and head dimension.
339
+ # TODO: deprecate this function in favor of `cache_position`
340
+ return (self.key_cache[self.global_attn_idx][0, 0].any(dim=-1)).sum()
341
+
342
+ def reset(self):
343
+ """Resets the cache values while preserving the objects"""
344
+ for layer_idx in range(len(self.key_cache)):
345
+ # In-place ops prevent breaking the static address
346
+ self.key_cache[layer_idx].zero_()
347
+ self.value_cache[layer_idx].zero_()
348
+
349
+ @property
350
+ def batch_size(self):
351
+ logger.warning_once(
352
+ f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
353
+ "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
354
+ )
355
+ return self.max_batch_size
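# For orientation: the constructor above allocates one of three cache shapes per slot. An
# illustrative restatement of that branch, assuming the 32 hidden layers implied by the weight
# index earlier in this commit (so slots 0..17 are built and global_attn_idx == 17):
def _cache_kind(i: int, num_hidden_layers: int = 32) -> str:
    global_attn_idx = num_hidden_layers // 2 + 1
    if i == global_attn_idx:
        return "full-attention KV, length max_cache_len"
    if i % 2 == 0:
        return "mamba conv1d + ssm state, constant size"
    return "sliding-window KV, length min(sliding_window, max_cache_len)"

for i in range(32 // 2 + 2):  # same range as the allocation loop in __init__
    print(i, _cache_kind(i))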
356
+
357
+
358
+
359
+
360
+ swiglu_fwd_codestring = """
361
+ template <typename T> T swiglu_fwd(T x, T y) {
362
+ return float(x) * float(y) / (1.0f + ::exp(-float(x)));
363
+ }
364
+ """
365
+ swiglu_bwd_codestring = """
366
+ template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
367
+ float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
368
+ dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
369
+ dy = float(x) * x_sigmoid * float(g);
370
+ }
371
+ """
372
+ swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
373
+ swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
374
+
375
+
376
+ class SwiGLUFunction(torch.autograd.Function):
377
+
378
+ @staticmethod
379
+ def forward(ctx, x, y):
380
+ ctx.save_for_backward(x, y)
381
+ return swiglu_fwd(x, y)
382
+
383
+ @staticmethod
384
+ def backward(ctx, dout):
385
+ x, y = ctx.saved_tensors
386
+ return swiglu_bwd(x, y, dout)
387
+
388
+ swiglu = SwiGLUFunction.apply
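# The jiterator kernels above fuse the elementwise SwiGLU gate: swiglu_fwd(x, y) equals
# y * x * sigmoid(x), i.e. y * silu(x). A plain-PyTorch reference that could serve as a CUDA
# sanity check (a sketch; shapes and tolerances are illustrative):
import torch
import torch.nn.functional as F

def _swiglu_ref(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # float(x) * float(y) / (1 + exp(-x)) == y * x * sigmoid(x) == y * silu(x)
    return y * F.silu(x)

if torch.cuda.is_available():
    x = torch.randn(4, 16, device="cuda", dtype=torch.float16)
    y = torch.randn_like(x)
    # swiglu_fwd is the jiterator kernel defined above; jiterator ops require CUDA tensors.
    torch.testing.assert_close(swiglu_fwd(x, y), _swiglu_ref(x, y), rtol=1e-2, atol=1e-2)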
389
+
390
+
391
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->SambaY
392
+ class SambaYRMSNorm(nn.Module):
393
+ def __init__(self, hidden_size, eps=1e-5):
394
+ """
395
+ SambaYRMSNorm is equivalent to T5LayerNorm
396
+ """
397
+ super().__init__()
398
+ self.weight = nn.Parameter(torch.ones(hidden_size))
399
+ self.variance_epsilon = eps
400
+
401
+ def forward(self, hidden_states):
402
+ input_dtype = hidden_states.dtype
403
+ hidden_states = hidden_states.to(torch.float32)
404
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
405
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
406
+ return self.weight * hidden_states.to(input_dtype)
407
+
408
+
409
+ PHI_NORM_CLASS = nn.LayerNorm
410
+
411
+
412
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
413
+ def _get_unpad_data(attention_mask):
414
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
415
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
416
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
417
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
418
+ return (
419
+ indices,
420
+ cu_seqlens,
421
+ max_seqlen_in_batch,
422
+ )
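# Worked example for _get_unpad_data above, which turns a padding mask into the flattened token
# indices and cumulative sequence lengths expected by the varlen flash-attention path (values
# are spelled out for a tiny illustrative batch):
import torch

mask = torch.tensor([[1, 1, 0],
                     [1, 1, 1]], dtype=torch.int32)  # first sequence has one padded position
indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
# indices    -> tensor([0, 1, 3, 4, 5])  flat positions of the real tokens
# cu_seqlens -> tensor([0, 2, 5])        cumulative per-sequence lengths
# max_seqlen -> 3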
423
+
424
+
425
+ class SambaYMLP(nn.Module):
426
+ """Gated Linear Unit.
427
+
428
+ Reference:
429
+ Language Modeling with Gated Convolutional Networks.
430
+ https://arxiv.org/pdf/1612.08083v3.pdf.
431
+
432
+ """
433
+
434
+ def __init__(self, config):
435
+ super().__init__()
436
+
437
+ self.config = config
438
+ self.fc1 = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
439
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
440
+
441
+ self.activation_fn = ACT2FN[config.hidden_act]
442
+
443
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
444
+ y = self.fc1(hidden_states)
445
+
446
+ # Special case for SwiGLU
447
+ if self.config.hidden_act == "silu" and swiglu is not None:
448
+ gate, y = y.chunk(2, dim=-1)
449
+ y = swiglu(gate, y)
450
+ else:
451
+ gate, y = y.chunk(2, dim=-1)
452
+ y = y * self.activation_fn(gate)
453
+
454
+ return self.fc2(y)
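# The same computation spelled out without the module, with illustrative sizes (the real
# hidden_size / intermediate_size come from the config): fc1 packs gate and value into one
# projection, and the gate half is the one passed through the activation.
import torch
import torch.nn.functional as F

hidden_size, intermediate_size, seq_len = 1024, 4096, 8  # illustrative
fc1 = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
fc2 = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(1, seq_len, hidden_size)
gate, value = fc1(x).chunk(2, dim=-1)  # each (1, seq_len, intermediate_size)
out = fc2(F.silu(gate) * value)        # back to (1, seq_len, hidden_size)
print(out.shape)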
455
+
456
+
457
+ class SambaYAttention(nn.Module):
458
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
459
+
460
+ def __init__(self, config: Phi4FlashConfig, layer_idx: Optional[int] = None, yoco_cross: bool = False):
461
+ super().__init__()
462
+ self.config = config
463
+ self.layer_idx = layer_idx
464
+ if layer_idx is None:
465
+ logger.warning_once(
466
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
467
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
468
+ "when creating this class."
469
+ )
470
+
471
+ self.attention_dropout = config.attention_dropout
472
+ self.hidden_size = config.hidden_size
473
+ self.num_heads = config.num_attention_heads
474
+ self.head_dim = self.hidden_size // self.num_heads
475
+ self.num_key_value_heads = config.num_key_value_heads
476
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
477
+ self.max_position_embeddings = config.max_position_embeddings
478
+ self.is_causal = True
479
+ self.yoco_cross = yoco_cross
480
+
481
+ if (self.head_dim * self.num_heads) != self.hidden_size:
482
+ raise ValueError(
483
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
484
+ f" and `num_heads`: {self.num_heads})."
485
+ )
486
+
487
+ op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
488
+ self.out_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
489
+ if yoco_cross:
490
+ self.Wqkv = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
491
+ else:
492
+ self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
493
+
494
+ self.inner_cross_attn = FlashDiffCustomAttention(self.head_dim, self.layer_idx,)
495
+
496
+
497
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
498
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
499
+
500
+ def forward(
501
+ self,
502
+ hidden_states: torch.Tensor,
503
+ attention_mask: Optional[torch.Tensor] = None,
504
+ position_ids: Optional[torch.LongTensor] = None,
505
+ past_key_value: Optional[Cache] = None,
506
+ output_attentions: bool = False,
507
+ use_cache: bool = False,
508
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
509
+ raise NotImplementedError("SambaYAttention only supports flash attention")
510
+
511
+
512
+ class SambaYFlashAttention2(SambaYAttention):
513
+ """
514
+ SambaY flash attention module. This module inherits from `SambaYAttention` as the weights of the module stay
515
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
516
+ flash attention and deal with padding tokens in case the input contains any of them.
517
+ """
518
+
519
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
520
+ def __init__(self, *args, **kwargs):
521
+ super().__init__(*args, **kwargs)
522
+
523
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
524
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
525
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
526
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
527
+
528
+
529
+
530
+ def forward(
531
+ self,
532
+ hidden_states: torch.Tensor,
533
+ attention_mask: Optional[torch.LongTensor] = None,
534
+ position_ids: Optional[torch.LongTensor] = None,
535
+ past_key_value: Optional[Cache] = None,
536
+ output_attentions: bool = False,
537
+ use_cache: bool = False,
538
+ cache_position: Optional[torch.LongTensor] = None,
539
+ yoco_key_values: Optional[torch.Tensor] = None,
540
+ **kwargs,
541
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
542
+ # SambaYFlashAttention2 attention does not support output_attentions
543
+
544
+ output_attentions = False
545
+ if "padding_mask" in kwargs:
546
+ warnings.warn(
547
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
548
+ )
549
+
550
+ # overwrite attention_mask with padding_mask
551
+ attention_mask = kwargs.pop("padding_mask")
552
+
553
+ bsz, q_len, _ = hidden_states.size()
554
+ if self.yoco_cross:
555
+ q = self.Wqkv(hidden_states)
556
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim).transpose(1,2)
557
+ key_states, value_states = yoco_key_values
558
+ query_states = q
559
+
560
+ use_sliding_windows = False
561
+ else:
562
+
563
+ qkv = self.Wqkv(hidden_states)
564
+ query_pos = self.num_heads * self.head_dim
565
+ query_states = qkv[..., :query_pos]
566
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
567
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
568
+
569
+ # Flash attention requires the input to have the shape
570
+ # batch_size x seq_length x num_heads x head_dim
571
+ # therefore we just need to keep the original shape
572
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
573
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
574
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
575
+
576
+ use_sliding_windows = self.config.sliding_window is not None and self.config.sliding_window[self.layer_idx] is not None
577
+
578
+ if past_key_value is not None:
579
+
580
+ cache_kwargs = {"cache_position": cache_position}  # Specific to RoPE models
581
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
582
+
583
+
584
+ yoco_key_values = key_states, value_states
585
+
586
+ attn_dropout = self.attention_dropout if self.training else 0.0
587
+
588
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
589
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
590
+ # cast them back to the correct dtype just to be sure everything works as expected.
591
+ # This might slow down training & inference so it is recommended not to cast the LayerNorms
592
+ # in fp32.
593
+
594
+ if query_states.dtype == torch.float32:
595
+ if torch.is_autocast_enabled():
596
+ target_dtype = torch.get_autocast_gpu_dtype()
597
+ # Handle the case where the model is quantized
598
+ elif hasattr(self.config, "_pre_quantization_dtype"):
599
+ target_dtype = self.config._pre_quantization_dtype
600
+ else:
601
+ target_dtype = self.Wqkv.weight.dtype
602
+
603
+ logger.warning_once(
604
+ f"The input hidden states seem to be silently cast to float32, this might be related to"
605
+ f" the fact you have upcast embedding or layer norm layers to float32. We will cast the input back to"
606
+ f" {target_dtype}."
607
+ )
608
+
609
+ query_states = query_states.to(target_dtype)
610
+ key_states = key_states.to(target_dtype)
611
+ value_states = value_states.to(target_dtype)
612
+
613
+ # Reshape to the expected shape for Flash Attention
614
+ # -> b,q,h,d
615
+ query_states = query_states.transpose(1, 2)
616
+ key_states = key_states.transpose(1, 2)
617
+ value_states = value_states.transpose(1, 2)
618
+ if attention_mask is not None:
619
+ key_states = key_states[:, :attention_mask.shape[-1]]
620
+ value_states = value_states[:, :attention_mask.shape[-1]]
621
+ attn_output = self._flash_attention_forward(
622
+ query_states,
623
+ key_states,
624
+ value_states,
625
+ attention_mask,
626
+ q_len,
627
+ dropout=attn_dropout,
628
+ use_sliding_windows=use_sliding_windows,
629
+ )
630
+
631
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
632
+ attn_output = self.out_proj(attn_output)
633
+
634
+ if not output_attentions:
635
+ attn_weights = None
636
+
637
+ return attn_output, attn_weights, yoco_key_values
638
+
639
+ def _flash_attention_forward(
640
+ self,
641
+ query_states,
642
+ key_states,
643
+ value_states,
644
+ attention_mask,
645
+ query_length,
646
+ dropout=0.0,
647
+ softmax_scale=None,
648
+ use_sliding_windows=False,
649
+ ):
650
+ """
651
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
652
+ first unpad the input, then compute the attention scores, and finally pad the attention output back.
653
+
654
+ Args:
655
+ query_states (`torch.Tensor`):
656
+ Input query states to be passed to Flash Attention API
657
+ key_states (`torch.Tensor`):
658
+ Input key states to be passed to Flash Attention API
659
+ value_states (`torch.Tensor`):
660
+ Input value states to be passed to Flash Attention API
661
+ attention_mask (`torch.Tensor`):
662
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
663
+ position of padding tokens and 1 for the position of non-padding tokens.
664
+ dropout (`float`):
665
+ Attention dropout
666
+ softmax_scale (`float`, *optional*):
667
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
668
+ use_sliding_windows (`bool`, *optional*):
669
+ Whether to activate sliding window attention.
670
+ """
671
+ causal = self.is_causal
672
+ # Contains at least one padding token in the sequence
673
+ if attention_mask is not None:
674
+ batch_size = query_states.shape[0]
675
+ (
676
+ query_states,
677
+ key_states,
678
+ value_states,
679
+ indices_q,
680
+ cu_seq_lens,
681
+ max_seq_lens,
682
+ ) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)
683
+
684
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
685
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
686
+
687
+ if not use_sliding_windows:
688
+ attn_output_unpad = self.inner_cross_attn(
689
+ query_states,
690
+ key_states,
691
+ value_states,
692
+ cu_seqlens_q=cu_seqlens_q,
693
+ cu_seqlens_k=cu_seqlens_k,
694
+ max_seqlen_q=max_seqlen_in_batch_q,
695
+ max_seqlen_k=max_seqlen_in_batch_k,
696
+ dropout_p=dropout,
697
+ softmax_scale=softmax_scale,
698
+ causal=causal,
699
+ )
700
+ else:
701
+ attn_output_unpad = self.inner_cross_attn(
702
+ query_states,
703
+ key_states,
704
+ value_states,
705
+ cu_seqlens_q=cu_seqlens_q,
706
+ cu_seqlens_k=cu_seqlens_k,
707
+ max_seqlen_q=max_seqlen_in_batch_q,
708
+ max_seqlen_k=max_seqlen_in_batch_k,
709
+ dropout_p=dropout,
710
+ softmax_scale=softmax_scale,
711
+ causal=causal,
712
+ window_size=(
713
+ self.config.sliding_window[self.layer_idx] -1,
714
+ self.config.sliding_window[self.layer_idx] -1,
715
+ ),
716
+ )
717
+
718
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
719
+ else:
720
+ if not use_sliding_windows:
721
+ attn_output = self.inner_cross_attn(
722
+ query_states,
723
+ key_states,
724
+ value_states,
725
+ dropout_p=dropout,
726
+ softmax_scale=softmax_scale,
727
+ causal=causal,
728
+ )
729
+ else:
730
+ attn_output = self.inner_cross_attn(
731
+ query_states,
732
+ key_states,
733
+ value_states,
734
+ dropout_p=dropout,
735
+ softmax_scale=softmax_scale,
736
+ causal=causal,
737
+ window_size=(
738
+ self.config.sliding_window[self.layer_idx] -1,
739
+ self.config.sliding_window[self.layer_idx] -1,
740
+ ),
741
+ )
742
+
743
+ return attn_output
744
+
745
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
746
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
747
+
748
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
749
+
750
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
751
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
752
+
753
+ if query_length == kv_seq_len:
754
+ query_layer = index_first_axis(
755
+ query_layer.reshape(batch_size * kv_seq_len, -1, head_dim),
756
+ indices_k,
757
+ )
758
+ cu_seqlens_q = cu_seqlens_k
759
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
760
+ indices_q = indices_k
761
+ elif query_length == 1:
762
+ max_seqlen_in_batch_q = 1
763
+ cu_seqlens_q = torch.arange(
764
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
765
+ ) # There is a memcpy here, that is very bad.
766
+ indices_q = cu_seqlens_q[:-1]
767
+ query_layer = query_layer.squeeze(1)
768
+ else:
769
+ # The -q_len: slice assumes left padding.
770
+ attention_mask = attention_mask[:, -query_length:]
771
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
772
+
773
+ return (
774
+ query_layer,
775
+ key_layer,
776
+ value_layer,
777
+ indices_q,
778
+ (cu_seqlens_q, cu_seqlens_k),
779
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
780
+ )
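# For the sliding-window branches above, the underlying flash-attn primitive is the window_size
# argument whose presence is checked at import time (_flash_supports_window_size). A minimal
# direct call, as a sketch (shapes, dtype and window length are illustrative; the model itself
# routes this through FlashDiffCustomAttention rather than calling flash_attn_func directly):
import torch
from flash_attn import flash_attn_func

if torch.cuda.is_available():
    bsz, seqlen, n_heads, head_dim, window = 1, 1024, 8, 128, 512
    q = torch.randn(bsz, seqlen, n_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # causal attention restricted to the previous `window` positions, matching window_size=(w - 1, w - 1) above
    out = flash_attn_func(q, k, v, causal=True, window_size=(window - 1, window - 1))
    print(out.shape)  # (bsz, seqlen, n_heads, head_dim)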
781
+
782
+
783
+
784
+ class Phi3Mamba(nn.Module):
785
+ def __init__(
786
+ self,
787
+ d_model,
788
+ d_state=16,
789
+ d_conv=4,
790
+ expand=2,
791
+ dt_rank="auto",
792
+ conv_bias=True,
793
+ bias=False,
794
+ use_fast_path=True, # Fused kernel options
795
+ layer_idx=None,
796
+ yoco_cross=False,
797
+ yoco_kv=False,
798
+ dtype=None,
799
+ ):
800
+ factory_kwargs = {"dtype": dtype}
801
+ super().__init__()
802
+ self.d_model = d_model
803
+ self.d_state = d_state
804
+ self.d_conv = d_conv
805
+ self.expand = expand
806
+ self.d_inner = int(self.expand * self.d_model)
807
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
808
+ self.use_fast_path = use_fast_path
809
+ self.layer_idx = layer_idx
810
+
811
+ self.yoco_cross = yoco_cross
812
+ self.yoco_kv = yoco_kv
813
+ if self.yoco_cross:
814
+ self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
815
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
816
+ else:
817
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
818
+
819
+ self.conv1d = nn.Conv1d(
820
+ in_channels=self.d_inner,
821
+ out_channels=self.d_inner,
822
+ bias=conv_bias,
823
+ kernel_size=d_conv,
824
+ groups=self.d_inner,
825
+ padding=d_conv - 1,
826
+ **factory_kwargs,
827
+ )
828
+
829
+ self.activation = "silu"
830
+ self.act = nn.SiLU()
831
+
832
+ self.x_proj = nn.Linear(
833
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
834
+ )
835
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
836
+
837
+ # S4D real initialization
838
+ A = repeat(
839
+ torch.arange(1, self.d_state + 1, dtype=torch.float32),
840
+ "n -> d n",
841
+ d=self.d_inner,
842
+ ).contiguous()
843
+ A_log = torch.log(A) # Keep A_log in fp32
844
+ self.A_log = nn.Parameter(A_log)
845
+
846
+ # D "skip" parameter
847
+ self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32
848
+
849
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
850
+
851
+ def forward(self, hidden_states, inference_params=None, mask= None, yoco_key_values = None, cache_position = None):
852
+ """
853
+ hidden_states: (B, L, D)
854
+ Returns: same shape as hidden_states
855
+ """
856
+
857
+ if self.yoco_cross:
858
+ out = self.in_proj(hidden_states)
859
+ out = swiglu(out, yoco_key_values)
860
+ out = self.out_proj(out)
861
+ return out, yoco_key_values
862
+
863
+ batch, seqlen, _ = hidden_states.shape
864
+ conv_state, ssm_state = None, None
865
+ if inference_params is not None:
866
+ conv_state, ssm_state = self._get_states_from_cache(inference_params)
867
+ if cache_position[0] > 0: #inference_params.get_seq_length(self.layer_idx) > 0:
868
+ # The states are updated inplace
869
+ out, _, _, yoco_key_values = self.step(hidden_states, conv_state, ssm_state, yoco_key_values)
870
+ return out, yoco_key_values
871
+
872
+ # We do matmul and transpose BLH -> HBL at the same time
873
+ xz = rearrange(
874
+ self.in_proj.weight @ rearrange(hidden_states.to(dtype=self.in_proj.weight.dtype), "b l d -> d (b l)"),
875
+ "d (b l) -> b d l",
876
+ l=seqlen,
877
+ )
878
+ if self.in_proj.bias is not None:
879
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
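+ # Up to the explicit dtype cast, the fused matmul above is equivalent to the simpler (but
+ # extra-copy) form: xz = rearrange(self.in_proj(hidden_states), "b l d -> b d l").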
880
+
881
+
882
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
883
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
884
+ if (not self.yoco_kv) and self.use_fast_path and inference_params is None: # Doesn't support outputting the states
885
+ out = mamba_inner_fn(
886
+ xz,
887
+ self.conv1d.weight,
888
+ self.conv1d.bias,
889
+ self.x_proj.weight,
890
+ self.dt_proj.weight,
891
+ self.out_proj.weight,
892
+ self.out_proj.bias,
893
+ A,
894
+ None, # input-dependent B
895
+ None, # input-dependent C
896
+ self.D.float(),
897
+ delta_bias=self.dt_proj.bias.float(),
898
+ mask=mask,
899
+ delta_softplus=True,
900
+ )
901
+ else:
902
+ x, z = xz.chunk(2, dim=1)
903
+ if self.yoco_kv:
904
+ z = z.transpose(-1, -2).contiguous()
905
+ if mask is not None:
906
+ x = x * mask.unsqueeze(1)
907
+ # Compute short convolution
908
+ if conv_state is not None:
909
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
910
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
911
+ conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
912
+ if causal_conv1d_fn is None:
913
+ x = self.act(self.conv1d(x)[..., :seqlen])
914
+ else:
915
+ assert self.activation in ["silu", "swish"]
916
+ x = causal_conv1d_fn(
917
+ x=x,
918
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
919
+ bias=self.conv1d.bias,
920
+ activation=self.activation,
921
+ )
922
+ if mask is not None:
923
+ x = x * mask.unsqueeze(1)
924
+ # We're careful here about the layout, to avoid extra transposes.
925
+ # We want dt to have d as the slowest moving dimension
926
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
927
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
928
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
929
+ dt = self.dt_proj.weight @ dt.t()
930
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
931
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
932
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
933
+ assert self.activation in ["silu", "swish"]
934
+ y = selective_scan_fn(
935
+ x,
936
+ dt,
937
+ A,
938
+ B,
939
+ C,
940
+ self.D.float(),
941
+ z=None if self.yoco_kv else z,
942
+ delta_bias=self.dt_proj.bias.float(),
943
+ delta_softplus=True,
944
+ return_last_state=ssm_state is not None,
945
+ )
946
+ if ssm_state is not None:
947
+ y, last_state = y
948
+ ssm_state.copy_(last_state)
949
+ y = rearrange(y, "b d l -> b l d")
950
+ if self.yoco_kv:
951
+ yoco_key_values = y
952
+ y = swiglu(z, y)
953
+ out = self.out_proj(y)
954
+ return out, yoco_key_values
955
+
956
+ def step(self, hidden_states, conv_state, ssm_state, yoco_key_values):
957
+ dtype = hidden_states.dtype
958
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
959
+ xz = self.in_proj(hidden_states.to(dtype=self.in_proj.weight.dtype).squeeze(1)) # (B 2D)
960
+ x, z = xz.chunk(2, dim=-1) # (B D)
961
+
962
+ # Conv step
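+ # (Without the fused causal_conv1d_update kernel, conv_state is treated as a rolling buffer of
+ # the last d_conv inputs, so the depthwise causal conv for this single decoding step reduces to
+ # a dot product along the window dimension.)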
963
+ if causal_conv1d_update is None:
964
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
965
+ conv_state[:, :, -1] = x
966
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
967
+ if self.conv1d.bias is not None:
968
+ x = x + self.conv1d.bias
969
+ x = self.act(x).to(dtype=dtype)
970
+ else:
971
+ x = causal_conv1d_update(
972
+ x,
973
+ conv_state,
974
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
975
+ self.conv1d.bias,
976
+ self.activation,
977
+ )
978
+
979
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
980
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
981
+ # Don't add dt_bias here
982
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
983
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
984
+
985
+ # SSM step
986
+ if selective_state_update is None:
987
+ # Discretize A and B
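+ # Per channel d and state n, the step below implements the discretized recurrence
+ #   dt  = softplus(dt + dt_bias)
+ #   h_t = exp(dt * A) * h_{t-1} + (dt * B_t) * x_t     (zero-order hold for A, Euler for B)
+ #   y_t = C_t . h_t + D * x_t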
988
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
989
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
990
+ dB = torch.einsum("bd,bn->bdn", dt, B)
991
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
992
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
993
+ y = y + self.D.to(dtype) * x
994
+ y = y * self.act(z) # (B D)
995
+ else:
996
+ y = selective_state_update(
997
+ ssm_state, x, dt, A, B, C, self.D, z=None if self.yoco_kv else z, dt_bias=self.dt_proj.bias, dt_softplus=True
998
+ )
999
+ if self.yoco_kv:
1000
+ yoco_key_values = y.unsqueeze(1)
1001
+ y = swiglu(z, y)
1002
+ out = self.out_proj(y)
1003
+ return out.unsqueeze(1), conv_state, ssm_state, yoco_key_values
1004
+
1005
+ def _get_states_from_cache(self, inference_params):
1006
+ conv_state, ssm_state = inference_params.key_cache[self.layer_idx], inference_params.value_cache[self.layer_idx]
1007
+ return conv_state, ssm_state
1008
+
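+ # A minimal shape-check sketch for Phi3Mamba (a hedged illustration: it assumes a CUDA device
+ # and that the mamba-ssm / causal-conv1d kernels used by mamba_inner_fn are installed; d_model=64
+ # is an arbitrary example value):
+ #
+ #   block = Phi3Mamba(d_model=64, layer_idx=0).to("cuda", torch.bfloat16)
+ #   x = torch.randn(2, 128, 64, device="cuda", dtype=torch.bfloat16)
+ #   y, _ = block(x)                  # forward returns (output, yoco_key_values)
+ #   assert y.shape == x.shape        # (B, L, D) in -> (B, L, D) out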
1009
+
1010
+
1011
+
1012
+ class SambaYDecoderLayer(nn.Module):
1013
+ def __init__(self, config: Phi4FlashConfig, layer_idx: int):
1014
+ super().__init__()
1015
+
1016
+ self.mlp = SambaYMLP(config)
1017
+ self.input_layernorm = PHI_NORM_CLASS(config.hidden_size, eps=config.layer_norm_eps)
1018
+
1019
+ self.yoco_kv = False
1020
+ self.yoco_cross = False
1021
+ self.yoco_mb = False
1022
+ self.layer_idx = layer_idx
1023
+ assert config.num_hidden_layers % 4 == 0, 'num_hidden_layers should be divisible by 4 for SambaY'
1024
+ if layer_idx >= config.num_hidden_layers//2:
1025
+ self.yoco_mb = True
1026
+ self.yoco_kv = (layer_idx >= (config.num_hidden_layers//2 +1))
1027
+ self.yoco_cross = (layer_idx >= (config.num_hidden_layers//2 +2))
1028
+ if (layer_idx >= (config.num_hidden_layers//2 +1)):
1029
+ config = copy.deepcopy(config)
1030
+ config.sliding_window = None
1031
+ self.config = config
1032
+
1033
+ self.use_mamba = config.mb_per_layer > 0 and layer_idx % config.mb_per_layer == 0
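+ # Illustrative layout for a hypothetical config with num_hidden_layers=8 and mb_per_layer=2:
+ # layers 0-3 form the self-decoder (Mamba on even layers, sliding-window attention on odd ones);
+ # layers 4-7 set yoco_mb, layers >= 5 drop the sliding window and set yoco_kv, and layers >= 6
+ # set yoco_cross so they reuse the single shared KV / SSM state (YOCO).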
1034
+ if self.use_mamba:
1035
+ factory_kwargs = {"d_conv": config.mamba_d_conv, "d_state": config.mamba_d_state, "expand": config.mamba_expand , "dtype": None}
1036
+ self.attn = Phi3Mamba(config.hidden_size, layer_idx=layer_idx, yoco_cross=self.yoco_cross, yoco_kv=self.yoco_mb, **factory_kwargs)
1037
+ else:
1038
+ self.attn = SambaYFlashAttention2(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross)
1039
+
1040
+ self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
1041
+ self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
1042
+ self.post_attention_layernorm = PHI_NORM_CLASS(config.hidden_size, eps=config.layer_norm_eps)
1043
+
1044
+ def forward(
1045
+ self,
1046
+ hidden_states: torch.Tensor,
1047
+ attention_mask: Optional[torch.Tensor] = None,
1048
+ position_ids: Optional[torch.LongTensor] = None,
1049
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1050
+ output_attentions: Optional[bool] = False,
1051
+ use_cache: Optional[bool] = False,
1052
+ cache_position: Optional[torch.LongTensor] = None,
1053
+ ssm_output: Optional[torch.Tensor] = None,
1054
+ yoco_key_values: Optional[torch.Tensor] = None,
1055
+ **kwargs,
1056
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1057
+ """
1058
+ Args:
1059
+ hidden_states (`torch.FloatTensor`):
1060
+ input to the layer of shape `(batch, seq_len, embed_dim)`
1061
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1062
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
1063
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
1064
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
1065
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
1066
+ output_attentions (`bool`, *optional*):
1067
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1068
+ returned tensors for more detail.
1069
+ use_cache (`bool`, *optional*):
1070
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1071
+ (see `past_key_values`).
1072
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1073
+ """
1074
+
1075
+ residual = hidden_states
1076
+
1077
+ hidden_states = self.input_layernorm(hidden_states.to(dtype=self.input_layernorm.weight.dtype))
1078
+
1079
+ if self.use_mamba:
1080
+ attn_outputs, ssm_output = self.attn(
1081
+ hidden_states, inference_params=past_key_value,
1082
+ mask=attention_mask, yoco_key_values=ssm_output,
1083
+ cache_position=cache_position,
1084
+ )
1085
+ residual = residual.to(torch.float32)
1086
+ self_attn_weights = None
1087
+ else:
1088
+ if self.config.sliding_window is not None and self.config.sliding_window[self.layer_idx] is not None and attention_mask is not None: # efficient SDPA and no padding
1089
+ if past_key_value is not None and cache_position[0] > 0: # when decoding
1090
+ attention_mask = attention_mask[:, -self.config.sliding_window[self.layer_idx]:]
1091
+ #hidden_states = self.input_layernorm2(hidden_states.to(dtype=self.input_layernorm2.weight.dtype))
1092
+ # Self Attention
1093
+ attn_outputs, self_attn_weights, yoco_key_values = self.attn(
1094
+ hidden_states=hidden_states,
1095
+ attention_mask=attention_mask,
1096
+ position_ids=position_ids,
1097
+ past_key_value=past_key_value,
1098
+ output_attentions=output_attentions,
1099
+ use_cache=use_cache,
1100
+ cache_position=cache_position,
1101
+ yoco_key_values=yoco_key_values,
1102
+ )
1103
+
1104
+ hidden_states = residual + self.resid_attn_dropout(attn_outputs)
1105
+
1106
+ residual = hidden_states
1107
+ hidden_states = self.post_attention_layernorm(hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype))
1108
+ hidden_states = self.mlp(hidden_states)
1109
+ hidden_states = residual + self.resid_mlp_dropout(hidden_states)
1110
+
1111
+ outputs = (hidden_states,)
1112
+ outputs += (ssm_output,)
1113
+ outputs += (yoco_key_values,)
1114
+ if output_attentions:
1115
+ outputs += (self_attn_weights,)
1116
+
1117
+ return outputs
1118
+
1119
+
1120
+ PHI_START_DOCSTRING = r"""
1121
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1122
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1123
+ etc.)
1124
+
1125
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1126
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1127
+ and behavior.
1128
+
1129
+ Parameters:
1130
+ config ([`Phi4FlashConfig`]):
1131
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1132
+ load the weights associated with the model, only the configuration. Check out the
1133
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1134
+ """
1135
+
1136
+
1137
+ @add_start_docstrings(
1138
+ "The bare Phi4Flash Model outputting raw hidden-states without any specific head on top.",
1139
+ PHI_START_DOCSTRING,
1140
+ )
1141
+ class Phi4FlashPreTrainedModel(PreTrainedModel):
1142
+ config_class = Phi4FlashConfig
1143
+ base_model_prefix = "model"
1144
+ supports_gradient_checkpointing = True
1145
+ _no_split_modules = ["SambaYDecoderLayer"]
1146
+ _skip_keys_device_placement = "past_key_values"
1147
+ _supports_flash_attn_2 = True
1148
+ _supports_sdpa = False
1149
+ _supports_cache_class = True
1150
+
1151
+ def _init_weights(self, module):
1152
+ std = self.config.initializer_range
1153
+ if isinstance(module, nn.Linear):
1154
+ module.weight.data.normal_(mean=0.0, std=std)
1155
+ if module.bias is not None:
1156
+ module.bias.data.zero_()
1157
+ elif isinstance(module, nn.Embedding):
1158
+ module.weight.data.normal_(mean=0.0, std=std)
1159
+ if module.padding_idx is not None:
1160
+ module.weight.data[module.padding_idx].zero_()
1161
+
1162
+
1163
+ PHI_INPUTS_DOCSTRING = r"""
1164
+ Args:
1165
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1166
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1167
+ it.
1168
+
1169
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1170
+ [`PreTrainedTokenizer.__call__`] for details.
1171
+
1172
+ [What are input IDs?](../glossary#input-ids)
1173
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1174
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1175
+
1176
+ - 1 for tokens that are **not masked**,
1177
+ - 0 for tokens that are **masked**.
1178
+
1179
+ [What are attention masks?](../glossary#attention-mask)
1180
+
1181
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1182
+ [`PreTrainedTokenizer.__call__`] for details.
1183
+
1184
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1185
+ `past_key_values`).
1186
+
1187
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1188
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1189
+ information on the default strategy.
1190
+
1191
+ - 1 indicates the head is **not masked**,
1192
+ - 0 indicates the head is **masked**.
1193
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1194
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1195
+ config.n_positions - 1]`.
1196
+
1197
+ [What are position IDs?](../glossary#position-ids)
1198
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1199
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1200
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1201
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1202
+
1203
+ Two formats are allowed:
1204
+ - a [`~cache_utils.Cache`] instance;
1205
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1206
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1207
+ cache format.
1208
+
1209
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1210
+ legacy cache format will be returned.
1211
+
1212
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1213
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1214
+ of shape `(batch_size, sequence_length)`.
1215
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1216
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1217
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1218
+ model's internal embedding lookup matrix.
1219
+ use_cache (`bool`, *optional*):
1220
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1221
+ `past_key_values`).
1222
+ output_attentions (`bool`, *optional*):
1223
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1224
+ tensors for more detail.
1225
+ output_hidden_states (`bool`, *optional*):
1226
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1227
+ more detail.
1228
+ return_dict (`bool`, *optional*):
1229
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1230
+ """
1231
+
1232
+
1233
+ @add_start_docstrings(
1234
+ "The bare Phi4Flash Model outputting raw hidden-states without any specific head on top.",
1235
+ PHI_START_DOCSTRING,
1236
+ )
1237
+ class Phi4FlashModel(Phi4FlashPreTrainedModel):
1238
+ """
1239
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SambaYDecoderLayer`]
1240
+
1241
+ Args:
1242
+ config: Phi4FlashConfig
1243
+ """
1244
+
1245
+ def __init__(self, config: Phi4FlashConfig):
1246
+ super().__init__(config)
1247
+ self.padding_idx = config.pad_token_id
1248
+ self.vocab_size = config.vocab_size
1249
+
1250
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1251
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
1252
+ self.layers = nn.ModuleList(
1253
+ [SambaYDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1254
+ )
1255
+ self.final_layernorm = PHI_NORM_CLASS(config.hidden_size, eps=config.layer_norm_eps)
1256
+
1257
+ self._attn_implementation = config._attn_implementation
1258
+
1259
+ self.gradient_checkpointing = False
1260
+ # Initialize weights and apply final processing
1261
+ self.post_init()
1262
+
1263
+ def get_input_embeddings(self):
1264
+ return self.embed_tokens
1265
+
1266
+ def set_input_embeddings(self, value):
1267
+ self.embed_tokens = value
1268
+
1269
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1270
+ def forward(
1271
+ self,
1272
+ input_ids: torch.LongTensor = None,
1273
+ attention_mask: Optional[torch.Tensor] = None,
1274
+ position_ids: Optional[torch.LongTensor] = None,
1275
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1276
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1277
+ use_cache: Optional[bool] = None,
1278
+ output_attentions: Optional[bool] = None,
1279
+ output_hidden_states: Optional[bool] = None,
1280
+ return_dict: Optional[bool] = None,
1281
+ cache_position: Optional[torch.LongTensor] = None,
1282
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1283
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1284
+ output_hidden_states = (
1285
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1286
+ )
1287
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1288
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1289
+
1290
+ # retrieve input_ids and inputs_embeds
1291
+ if input_ids is not None and inputs_embeds is not None:
1292
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1293
+ elif input_ids is not None:
1294
+ batch_size, seq_length = input_ids.shape[:2]
1295
+ elif inputs_embeds is not None:
1296
+ batch_size, seq_length = inputs_embeds.shape[:2]
1297
+ else:
1298
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1299
+
1300
+
1301
+ if self.gradient_checkpointing and self.training:
1302
+ if use_cache:
1303
+ logger.warning_once(
1304
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1305
+ )
1306
+ use_cache = False
1307
+
1308
+ if inputs_embeds is None:
1309
+ inputs_embeds = self.embed_tokens(input_ids)
1310
+
1311
+ if use_cache and past_key_values is None and not self.training:
1312
+ batch_size, seq_len, _ = inputs_embeds.shape
1313
+ past_key_values = SambaYCache(
1314
+ self.config,
1315
+ max_batch_size=batch_size,
1316
+ max_cache_len=seq_len,
1317
+ device=self.device,
1318
+ dtype=inputs_embeds.dtype,
1319
+ )
1320
+
1321
+
1322
+ if cache_position is None:
1323
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1324
+ cache_position = torch.arange(
1325
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1326
+ )
1327
+
1328
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache and not self.training:
1329
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1330
+ if is_padding_right:
1331
+ raise ValueError(
1332
+ "You are attempting to perform batched generation with padding_side='right'"
1333
+ " this may lead to unexpected behaviour for Flash Attention version of Phi4Flash. Make sure to "
1334
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1335
+ )
1336
+
1337
+ hidden_states = inputs_embeds
1338
+
1339
+ # decoder layers
1340
+ all_hidden_states = () if output_hidden_states else None
1341
+ all_self_attns = () if output_attentions else None
1342
+ ssm_output = None
1343
+ yoco_key_values = None
1344
+ for decoder_layer in self.layers: # TODO: only the first half of the layers needs to run during pre-fill
1345
+ if output_hidden_states:
1346
+ all_hidden_states += (hidden_states,)
1347
+
1348
+ if self.gradient_checkpointing and self.training:
1349
+ layer_outputs = self._gradient_checkpointing_func(
1350
+ decoder_layer.__call__,
1351
+ hidden_states,
1352
+ attention_mask,
1353
+ position_ids,
1354
+ past_key_values,
1355
+ output_attentions,
1356
+ use_cache,
1357
+ cache_position,
1358
+ ssm_output,
1359
+ yoco_key_values,
1360
+ )
1361
+ else:
1362
+ layer_outputs = decoder_layer(
1363
+ hidden_states,
1364
+ attention_mask=attention_mask,
1365
+ position_ids=position_ids,
1366
+ past_key_value=past_key_values,
1367
+ output_attentions=output_attentions,
1368
+ use_cache=use_cache,
1369
+ cache_position=cache_position,
1370
+ ssm_output=ssm_output,
1371
+ yoco_key_values=yoco_key_values,
1372
+ )
1373
+
1374
+ hidden_states = layer_outputs[0]
1375
+ ssm_output = layer_outputs[1]
1376
+ yoco_key_values = layer_outputs[2]
1377
+
1378
+ if output_attentions:
1379
+ all_self_attns += (layer_outputs[3],)
1380
+
1381
+ hidden_states = self.final_layernorm(hidden_states.to(dtype=self.final_layernorm.weight.dtype))
1382
+
1383
+ # add hidden states from the last decoder layer
1384
+ if output_hidden_states:
1385
+ all_hidden_states += (hidden_states,)
1386
+
1387
+ output = BaseModelOutputWithPast(
1388
+ last_hidden_state=hidden_states,
1389
+ past_key_values=past_key_values,
1390
+ hidden_states=all_hidden_states,
1391
+ attentions=all_self_attns,
1392
+ )
1393
+ return output if return_dict else output.to_tuple()
1394
+
1395
+
1396
+
1397
+ class Phi4FlashForCausalLM(Phi4FlashPreTrainedModel, GenerationMixin):
1398
+ _tied_weights_keys = ["lm_head.weight"]
1399
+
1400
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi4Flash,bias=False->bias=True
1401
+ def __init__(self, config):
1402
+ super().__init__(config)
1403
+ self.model = Phi4FlashModel(config)
1404
+ self.vocab_size = config.vocab_size
1405
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1406
+
1407
+ # Initialize weights and apply final processing
1408
+ self.post_init()
1409
+
1410
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1411
+ def get_input_embeddings(self):
1412
+ return self.model.embed_tokens
1413
+
1414
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1415
+ def set_input_embeddings(self, value):
1416
+ self.model.embed_tokens = value
1417
+
1418
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1419
+ def get_output_embeddings(self):
1420
+ return self.lm_head
1421
+
1422
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1423
+ def set_output_embeddings(self, new_embeddings):
1424
+ self.lm_head = new_embeddings
1425
+
1426
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1427
+ def set_decoder(self, decoder):
1428
+ self.model = decoder
1429
+
1430
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1431
+ def get_decoder(self):
1432
+ return self.model
1433
+
1434
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1435
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1436
+ def forward(
1437
+ self,
1438
+ input_ids: torch.LongTensor = None,
1439
+ attention_mask: Optional[torch.Tensor] = None,
1440
+ position_ids: Optional[torch.LongTensor] = None,
1441
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1442
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1443
+ labels: Optional[torch.LongTensor] = None,
1444
+ use_cache: Optional[bool] = None,
1445
+ output_attentions: Optional[bool] = None,
1446
+ output_hidden_states: Optional[bool] = None,
1447
+ return_dict: Optional[bool] = None,
1448
+ cache_position: Optional[torch.LongTensor] = None,
1449
+ num_logits_to_keep: int = 0,
1450
+ **loss_kwargs,
1451
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1452
+ r"""
1453
+ Args:
1454
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1455
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1456
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1457
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1458
+
1459
+ Returns:
1460
+
1461
+ Example:
1462
+
1463
+ ```python
1464
+ >>> from transformers import AutoTokenizer, Phi4FlashForCausalLM
1465
+
1466
+ >>> model = Phi4FlashForCausalLM.from_pretrained("microsoft/Phi4-mini-flash-reasoning")
1467
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi4-mini-flash-reasoning")
1468
+
1469
+ >>> prompt = "This is an example script ."
1470
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1471
+
1472
+ >>> # Generate
1473
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1474
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1475
+ 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
1476
+ ```"""
1477
+
1478
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1479
+ output_hidden_states = (
1480
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1481
+ )
1482
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1483
+
1484
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1485
+ outputs = self.model(
1486
+ input_ids=input_ids,
1487
+ attention_mask=attention_mask,
1488
+ position_ids=position_ids,
1489
+ past_key_values=past_key_values,
1490
+ inputs_embeds=inputs_embeds,
1491
+ use_cache=use_cache,
1492
+ output_attentions=output_attentions,
1493
+ output_hidden_states=output_hidden_states,
1494
+ return_dict=return_dict,
1495
+ cache_position=cache_position,
1496
+ )
1497
+
1498
+ hidden_states = outputs[0]
1499
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
1500
+
1501
+ loss = None
1502
+ if labels is not None:
1503
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1504
+
1505
+ if not return_dict:
1506
+ output = (logits,) + outputs[1:]
1507
+ return (loss,) + output if loss is not None else output
1508
+
1509
+ return CausalLMOutputWithPast(
1510
+ loss=loss,
1511
+ logits=logits,
1512
+ past_key_values=outputs.past_key_values,
1513
+ hidden_states=outputs.hidden_states,
1514
+ attentions=outputs.attentions,
1515
+ )
1516
+
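+ # A hedged batched-generation sketch (assumes the checkpoint name from the docstring above and
+ # that the repository's auto_map registers this class, hence trust_remote_code=True):
+ #
+ #   from transformers import AutoModelForCausalLM, AutoTokenizer
+ #   tok = AutoTokenizer.from_pretrained("microsoft/Phi4-mini-flash-reasoning")
+ #   tok.padding_side = "left"  # required by the flash-attention path (see the check in Phi4FlashModel)
+ #   model = AutoModelForCausalLM.from_pretrained(
+ #       "microsoft/Phi4-mini-flash-reasoning", trust_remote_code=True, torch_dtype="auto"
+ #   ).cuda()
+ #   batch = tok(["Hello", "Prove that 2 + 2 = 4."], return_tensors="pt", padding=True).to("cuda")
+ #   out = model.generate(**batch, max_new_tokens=32)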
1517
+
1518
+ @add_start_docstrings(
1519
+ """
1520
+ The Phi4FlashModel with a sequence classification head on top (linear layer).
1521
+
1522
+ [`Phi4FlashForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1523
+ (e.g. GPT-2) do.
1524
+
1525
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1526
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1527
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1528
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1529
+ each row of the batch).
1530
+ """,
1531
+ PHI_START_DOCSTRING,
1532
+ )
1533
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi4Flash with self.transformer->self.model, transformer_outputs->model_outputs
1534
+ class Phi4FlashForSequenceClassification(Phi4FlashPreTrainedModel):
1535
+ def __init__(self, config):
1536
+ super().__init__(config)
1537
+ self.num_labels = config.num_labels
1538
+ self.model = Phi4FlashModel(config)
1539
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1540
+
1541
+ # Initialize weights and apply final processing
1542
+ self.post_init()
1543
+
1544
+ def get_input_embeddings(self):
1545
+ return self.model.embed_tokens
1546
+
1547
+ def set_input_embeddings(self, value):
1548
+ self.model.embed_tokens = value
1549
+
1550
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1551
+ def forward(
1552
+ self,
1553
+ input_ids: torch.LongTensor = None,
1554
+ attention_mask: Optional[torch.Tensor] = None,
1555
+ position_ids: Optional[torch.LongTensor] = None,
1556
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1557
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1558
+ labels: Optional[torch.LongTensor] = None,
1559
+ use_cache: Optional[bool] = None,
1560
+ output_attentions: Optional[bool] = None,
1561
+ output_hidden_states: Optional[bool] = None,
1562
+ return_dict: Optional[bool] = None,
1563
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1564
+ r"""
1565
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1566
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1567
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1568
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1569
+ """
1570
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1571
+
1572
+ model_outputs = self.model(
1573
+ input_ids,
1574
+ attention_mask=attention_mask,
1575
+ position_ids=position_ids,
1576
+ past_key_values=past_key_values,
1577
+ inputs_embeds=inputs_embeds,
1578
+ use_cache=use_cache,
1579
+ output_attentions=output_attentions,
1580
+ output_hidden_states=output_hidden_states,
1581
+ return_dict=return_dict,
1582
+ )
1583
+ hidden_states = model_outputs[0]
1584
+ logits = self.score(hidden_states)
1585
+
1586
+ if input_ids is not None:
1587
+ batch_size = input_ids.shape[0]
1588
+ else:
1589
+ batch_size = inputs_embeds.shape[0]
1590
+
1591
+ if self.config.pad_token_id is None and batch_size != 1:
1592
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1593
+ if self.config.pad_token_id is None:
1594
+ sequence_lengths = -1
1595
+ else:
1596
+ if input_ids is not None:
1597
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1598
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1599
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
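+ # Example: with pad token p and a row [t0, t1, p, p], eq(pad) = [0, 0, 1, 1] and argmax gives 2,
+ # so sequence_lengths = 1 (the last non-pad position). If a row contains no padding, argmax
+ # returns 0 and (0 - 1) % seq_len wraps to seq_len - 1, i.e. the last position.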
1600
+ sequence_lengths = sequence_lengths.to(logits.device)
1601
+ else:
1602
+ sequence_lengths = -1
1603
+
1604
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1605
+
1606
+ loss = None
1607
+ if labels is not None:
1608
+ labels = labels.to(logits.device)
1609
+ if self.config.problem_type is None:
1610
+ if self.num_labels == 1:
1611
+ self.config.problem_type = "regression"
1612
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1613
+ self.config.problem_type = "single_label_classification"
1614
+ else:
1615
+ self.config.problem_type = "multi_label_classification"
1616
+
1617
+ if self.config.problem_type == "regression":
1618
+ loss_fct = MSELoss()
1619
+ if self.num_labels == 1:
1620
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1621
+ else:
1622
+ loss = loss_fct(pooled_logits, labels)
1623
+ elif self.config.problem_type == "single_label_classification":
1624
+ loss_fct = CrossEntropyLoss()
1625
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1626
+ elif self.config.problem_type == "multi_label_classification":
1627
+ loss_fct = BCEWithLogitsLoss()
1628
+ loss = loss_fct(pooled_logits, labels)
1629
+ if not return_dict:
1630
+ output = (pooled_logits,) + model_outputs[1:]
1631
+ return ((loss,) + output) if loss is not None else output
1632
+
1633
+ return SequenceClassifierOutputWithPast(
1634
+ loss=loss,
1635
+ logits=pooled_logits,
1636
+ past_key_values=model_outputs.past_key_values,
1637
+ hidden_states=model_outputs.hidden_states,
1638
+ attentions=model_outputs.attentions,
1639
+ )
1640
+
1641
+
1642
+ @add_start_docstrings(
1643
+ """
1644
+ Phi4FlashModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1645
+ Named-Entity-Recognition (NER) tasks.
1646
+ """,
1647
+ PHI_START_DOCSTRING,
1648
+ )
1649
+ # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi4Flash,self.transformer->self.model,transformer_outputs->model_outputs
1650
+ class Phi4FlashForTokenClassification(Phi4FlashPreTrainedModel):
1651
+ def __init__(self, config: Phi4FlashConfig):
1652
+ super().__init__(config)
1653
+ self.num_labels = config.num_labels
1654
+
1655
+ self.model = Phi4FlashModel(config)
1656
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1657
+ classifier_dropout = config.classifier_dropout
1658
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1659
+ classifier_dropout = config.hidden_dropout
1660
+ else:
1661
+ classifier_dropout = 0.1
1662
+ self.dropout = nn.Dropout(classifier_dropout)
1663
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1664
+
1665
+ # Initialize weights and apply final processing
1666
+ self.post_init()
1667
+
1668
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1669
+ @add_code_sample_docstrings(
1670
+ checkpoint=_CHECKPOINT_FOR_DOC,
1671
+ output_type=TokenClassifierOutput,
1672
+ config_class=_CONFIG_FOR_DOC,
1673
+ )
1674
+ def forward(
1675
+ self,
1676
+ input_ids: Optional[torch.LongTensor] = None,
1677
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1678
+ attention_mask: Optional[torch.Tensor] = None,
1679
+ inputs_embeds: Optional[torch.Tensor] = None,
1680
+ labels: Optional[torch.Tensor] = None,
1681
+ use_cache: Optional[bool] = None,
1682
+ output_attentions: Optional[bool] = None,
1683
+ output_hidden_states: Optional[bool] = None,
1684
+ return_dict: Optional[bool] = None,
1685
+ **deprecated_arguments,
1686
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1687
+ r"""
1688
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1689
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1690
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1691
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1692
+ """
1693
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1694
+
1695
+ model_outputs = self.model(
1696
+ input_ids,
1697
+ past_key_values=past_key_values,
1698
+ attention_mask=attention_mask,
1699
+ inputs_embeds=inputs_embeds,
1700
+ use_cache=use_cache,
1701
+ output_attentions=output_attentions,
1702
+ output_hidden_states=output_hidden_states,
1703
+ return_dict=return_dict,
1704
+ )
1705
+
1706
+ hidden_states = model_outputs[0]
1707
+ hidden_states = self.dropout(hidden_states)
1708
+ logits = self.classifier(hidden_states)
1709
+
1710
+ loss = None
1711
+ if labels is not None:
1712
+ # move labels to correct device to enable model parallelism
1713
+ labels = labels.to(logits.device)
1714
+ batch_size, seq_length = labels.shape
1715
+ loss_fct = CrossEntropyLoss()
1716
+ loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
1717
+
1718
+ if not return_dict:
1719
+ output = (logits,) + model_outputs[2:]
1720
+ return ((loss,) + output) if loss is not None else output
1721
+
1722
+ return TokenClassifierOutput(
1723
+ loss=loss,
1724
+ logits=logits,
1725
+ hidden_states=model_outputs.hidden_states,
1726
+ attentions=model_outputs.attentions,
1727
+ )
1728
+
1729
+ ## support batched generation
1730
+
1731
+ class SelectiveScanFn(torch.autograd.Function):
1732
+
1733
+ @staticmethod
1734
+ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
1735
+ return_last_state=False):
1736
+ if u.stride(-1) != 1:
1737
+ u = u.contiguous()
1738
+ if delta.stride(-1) != 1:
1739
+ delta = delta.contiguous()
1740
+ if D is not None:
1741
+ D = D.contiguous()
1742
+ if B.stride(-1) != 1:
1743
+ B = B.contiguous()
1744
+ if C.stride(-1) != 1:
1745
+ C = C.contiguous()
1746
+ if z is not None and z.stride(-1) != 1:
1747
+ z = z.contiguous()
1748
+ if B.dim() == 3:
1749
+ B = rearrange(B, "b dstate l -> b 1 dstate l")
1750
+ ctx.squeeze_B = True
1751
+ if C.dim() == 3:
1752
+ C = rearrange(C, "b dstate l -> b 1 dstate l")
1753
+ ctx.squeeze_C = True
1754
+ out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
1755
+ ctx.delta_softplus = delta_softplus
1756
+ ctx.has_z = z is not None
1757
+ last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
1758
+ if not ctx.has_z:
1759
+ ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
1760
+ return out if not return_last_state else (out, last_state)
1761
+ else:
1762
+ ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
1763
+ out_z = rest[0]
1764
+ return out_z if not return_last_state else (out_z, last_state)
1765
+
1766
+ @staticmethod
1767
+ def backward(ctx, dout, *args):
1768
+ if not ctx.has_z:
1769
+ u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
1770
+ z = None
1771
+ out = None
1772
+ else:
1773
+ u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
1774
+ if dout.stride(-1) != 1:
1775
+ dout = dout.contiguous()
1776
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
1777
+ # backward of selective_scan_cuda with the backward of chunk).
1778
+ # Here we just pass in None and dz will be allocated in the C++ code.
1779
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
1780
+ u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
1781
+ False # option to recompute out_z, not used here
1782
+ )
1783
+ dz = rest[0] if ctx.has_z else None
1784
+ dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
1785
+ dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
1786
+ return (du, ddelta, dA, dB, dC,
1787
+ dD if D is not None else None,
1788
+ dz,
1789
+ ddelta_bias if delta_bias is not None else None,
1790
+ None,
1791
+ None)
1792
+
1793
+
1794
+ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
1795
+ return_last_state=False):
1796
+ """if return_last_state is True, returns (out, last_state)
1797
+ last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
1798
+ not considered in the backward pass.
1799
+ """
1800
+ return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
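+ # Shape sketch for the call made in Phi3Mamba.forward (requires the selective_scan_cuda kernel):
+ #   u, delta, z : (batch, d_inner, seqlen)      A : (d_inner, d_state)
+ #   B, C        : (batch, d_state, seqlen)      D : (d_inner,)
+ #   out         : (batch, d_inner, seqlen)      last_state (if requested) : (batch, d_inner, d_state)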
1801
+
1802
+
1803
+ class MambaInnerFn(torch.autograd.Function):
1804
+
1805
+ @staticmethod
1806
+ @custom_fwd(device_type="cuda")
1807
+ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
1808
+ out_proj_weight, out_proj_bias,
1809
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
1810
+ C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=1,):
1811
+ """
1812
+ xz: (batch, dim, seqlen)
1813
+ """
1814
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
1815
+ assert checkpoint_lvl in [0, 1]
1816
+ L = xz.shape[-1]
1817
+ delta_rank = delta_proj_weight.shape[1]
1818
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
1819
+ if torch.is_autocast_enabled():
1820
+ x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
1821
+ delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
1822
+ out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
1823
+ out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
1824
+ if out_proj_bias is not None else None)
1825
+ if xz.stride(-1) != 1:
1826
+ xz = xz.contiguous()
1827
+ conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
1828
+ x, z = xz.chunk(2, dim=1)
1829
+ if mask is not None:
1830
+ x = x * mask.unsqueeze(1)
1831
+ conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
1832
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
1833
+ x, conv1d_weight, conv1d_bias, None, None, None, True
1834
+ )
1835
+ if mask is not None:
1836
+ conv1d_out = conv1d_out * mask.unsqueeze(1)
1837
+ # We're being very careful here about the layout, to avoid extra transposes.
1838
+ # We want delta to have d as the slowest moving dimension
1839
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
1840
+ x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
1841
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
1842
+ ctx.is_variable_B = B is None
1843
+ ctx.is_variable_C = C is None
1844
+ ctx.B_proj_bias_is_None = B_proj_bias is None
1845
+ ctx.C_proj_bias_is_None = C_proj_bias is None
1846
+ if B is None: # variable B
1847
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
1848
+ if B_proj_bias is not None:
1849
+ B = B + B_proj_bias.to(dtype=B.dtype)
1850
+ if not A.is_complex():
1851
+ # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
1852
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
1853
+ else:
1854
+ B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
1855
+ else:
1856
+ if B.stride(-1) != 1:
1857
+ B = B.contiguous()
1858
+ if C is None: # variable C
1859
+ C = x_dbl[:, -d_state:] # (bl dstate)
1860
+ if C_proj_bias is not None:
1861
+ C = C + C_proj_bias.to(dtype=C.dtype)
1862
+ if not A.is_complex():
1863
+ # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
1864
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
1865
+ else:
1866
+ C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
1867
+ else:
1868
+ if C.stride(-1) != 1:
1869
+ C = C.contiguous()
1870
+ if D is not None:
1871
+ D = D.contiguous()
1872
+ out, scan_intermediates, out_z = selective_scan_cuda.fwd(
1873
+ conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
1874
+ )
1875
+ ctx.delta_softplus = delta_softplus
1876
+ ctx.out_proj_bias_is_None = out_proj_bias is None
1877
+ ctx.checkpoint_lvl = checkpoint_lvl
1878
+ if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
1879
+ conv1d_out, delta = None, None
1880
+ ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
1881
+ delta_proj_weight, out_proj_weight, conv1d_out, delta,
1882
+ A, B, C, D, delta_bias, scan_intermediates, out)
1883
+ return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
1884
+
1885
+ @staticmethod
1886
+ @custom_bwd(device_type="cuda")
1887
+ def backward(ctx, dout):
1888
+ # dout: (batch, seqlen, dim)
1889
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
1890
+ (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
1891
+ conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
1892
+ L = xz.shape[-1]
1893
+ delta_rank = delta_proj_weight.shape[1]
1894
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
1895
+ x, z = xz.chunk(2, dim=1)
1896
+ if dout.stride(-1) != 1:
1897
+ dout = dout.contiguous()
1898
+ if ctx.checkpoint_lvl == 1:
1899
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
1900
+ x, conv1d_weight, conv1d_bias, None, None, None, True
1901
+ )
1902
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
1903
+ "d (b l) -> b d l", l = L)
1904
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
1905
+ # backward of selective_scan_cuda with the backward of chunk).
1906
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen)
1907
+ dx, dz = dxz.chunk(2, dim=1)
1908
+ dout = rearrange(dout, "b l e -> e (b l)")
1909
+ dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
1910
+ dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
1911
+ conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
1912
+ ctx.delta_softplus,
1913
+ True # option to recompute out_z
1914
+ )
1915
+ dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
1916
+ dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
1917
+ dD = dD if D is not None else None
1918
+ dx_dbl = torch.empty_like(x_dbl)
1919
+ dB_proj_bias = None
1920
+ if ctx.is_variable_B:
1921
+ if not A.is_complex():
1922
+ dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
1923
+ else:
1924
+ dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
1925
+ dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
1926
+ dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
1927
+ dB = None
1928
+ dC_proj_bias = None
1929
+ if ctx.is_variable_C:
1930
+ if not A.is_complex():
1931
+ dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
1932
+ else:
1933
+ dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
1934
+ dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
1935
+ dx_dbl[:, -d_state:] = dC # (bl d)
1936
+ dC = None
1937
+ ddelta = rearrange(ddelta, "b d l -> d (b l)")
1938
+ ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
1939
+ dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
1940
+ dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
1941
+ dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
1942
+ dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
1943
+ dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
1944
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
1945
+ # backward of conv1d with the backward of chunk).
1946
+ dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
1947
+ x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
1948
+ )
1949
+ dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
1950
+ dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
1951
+ return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
1952
+ dout_proj_weight, dout_proj_bias,
1953
+ dA, dB, dC, dD,
1954
+ ddelta_bias if delta_bias is not None else None,
1955
+ dB_proj_bias, dC_proj_bias, None, None)
1956
+
1957
+
1958
+ def mamba_inner_fn(
1959
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
1960
+ out_proj_weight, out_proj_bias,
1961
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
1962
+ C_proj_bias=None, mask=None, delta_softplus=True
1963
+ ):
1964
+ return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
1965
+ out_proj_weight, out_proj_bias,
1966
+ A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, mask, delta_softplus)
1967
+
1968
+
1969
+ def lambda_init_fn(depth):
1970
+ return 0.8 - 0.6 * math.exp(-0.3 * depth)
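+ # Depth-dependent initialization for the differential-attention lambda: depth 0 gives
+ # 0.8 - 0.6 * exp(0) = 0.2, and the value approaches 0.8 for deep layers.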
1971
+
1972
+
1973
+ def split_heads(x):
1974
+ # Split along the head dimension; the striped pattern is friendly to tensor parallelism.
1975
+ x = rearrange(x, "... (H two) D -> ... H two D", two=2)
1976
+ x1 = x[..., 0, :]
1977
+ x2 = x[..., 1, :]
1978
+ return x1, x2
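+ # Example: for x of shape (..., 2*H, D), split_heads returns x1 = heads 0, 2, 4, ... and
+ # x2 = heads 1, 3, 5, ..., each of shape (..., H, D).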
1979
+
1980
+ class FlashDiffCustomAttention(nn.Module):
1981
+ """Implement the scaled dot product attention with softmax.
1982
+ Arguments
1983
+ ---------
1984
+ head_dim: The dimension of the heads.
1985
+ depth: The layer id, starting from 0.
1986
+ """
1987
+
1988
+ def __init__(
1989
+ self,
1990
+ head_dim,
1991
+ depth,
1992
+ fa_og=True,
1993
+ ):
1994
+ super().__init__()
1995
+ assert flash_attn_varlen_func is not None, "FlashAttention is not installed"
1996
+ assert flash_attn_func is not None, "FlashAttention is not installed"
1997
+ self.head_dim = head_dim
1998
+ self.fa_og = fa_og # setting this to False requires the customized flash attention from https://github.com/xiayuqing0622/flex_head_fa
1999
+ self.lambda_init = lambda_init_fn(depth)
2000
+ self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
2001
+ self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
2002
+ self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
2003
+ self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
2004
+
2005
+ self.subln = SambaYRMSNorm(2 * self.head_dim, eps=1e-5)
2006
+
2007
+ def forward(
2008
+ self,
2009
+ q,
2010
+ k,
2011
+ v,
2012
+ dropout_p=0.0,
2013
+ cu_seqlens_q=None,
2014
+ max_seqlen_q=None,
2015
+ cu_seqlens_k=None,
2016
+ max_seqlen_k=None,
2017
+ softmax_scale=None,
2018
+ window_size=(-1, -1),
2019
+ causal=None,
2020
+ ):
2021
+ """Implements the multihead softmax attention.
2022
+ Arguments
2023
+ ---------
2024
+ q, k, v: The tensors containing the query, key, and value.
2025
+ If cu_seqlens is None and max_seqlen is None, then each has shape (B, S, H, D).
2026
+ If cu_seqlens is not None and max_seqlen is not None, then each has shape
2027
+ (total, H, D), where total is the sum of the sequence lengths in the batch.
2028
+ causal: if passed, will override self.causal
2029
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
2030
+ of the sequences in the batch, used to index into qkv.
2031
+ max_seqlen: int. Maximum sequence length in the batch.
2032
+ Returns
2033
+ -------
2034
+ out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
2035
+ else (B, S, H, D).
2036
+ """
2037
+ q = q.to(torch.bfloat16)
2038
+ k = k.to(torch.bfloat16)
2039
+ v = v.to(torch.bfloat16)
2040
+
2041
+ assert q.dtype in [torch.float16, torch.bfloat16]
2042
+ assert q.is_cuda and k.is_cuda and v.is_cuda
2043
+ #causal = self.causal if causal is None else causal
2044
+ unpadded = cu_seqlens_q is not None
2045
+ q1, q2 = split_heads(q)
2046
+ k1, k2 = split_heads(k)
2047
+ if self.fa_og:
2048
+ v1, v2 = split_heads(v)
2049
+ else:
2050
+ v = rearrange(v, "... (H two) D -> ... H (two D)", two=2)
2051
+
2052
+ kwargs = {
2053
+ "dropout_p": dropout_p,
2054
+ "softmax_scale": softmax_scale,
2055
+ "causal": causal,
2056
+ "window_size": window_size,
2057
+ }
2058
+
2059
+ if unpadded:
2060
+ assert cu_seqlens_q.dtype == torch.int32
2061
+ assert max_seqlen_q is not None
2062
+ assert isinstance(max_seqlen_q, int)
2063
+ assert cu_seqlens_k is not None
2064
+ assert cu_seqlens_k.dtype == torch.int32
2065
+ assert max_seqlen_k is not None
2066
+ assert isinstance(max_seqlen_k, int)
2067
+
2068
+ kwargs.update({
2069
+ "cu_seqlens_q": cu_seqlens_q,
2070
+ "max_seqlen_q": max_seqlen_q,
2071
+ "cu_seqlens_k": cu_seqlens_k,
2072
+ "max_seqlen_k": max_seqlen_k,
2073
+ })
2074
+ attn_func = flash_attn_varlen_func
2075
+ else:
2076
+ attn_func = flash_attn_func
2077
+
2078
+ if self.fa_og:
2079
+ attn11 = attn_func(q1, k1, v1, **kwargs)
2080
+ attn12 = attn_func(q1, k1, v2, **kwargs)
2081
+ attn1 = torch.cat([attn11, attn12], dim=-1)
2082
+ attn21 = attn_func(q2, k2, v1, **kwargs)
2083
+ attn22 = attn_func(q2, k2, v2, **kwargs)
2084
+ attn2 = torch.cat([attn21, attn22], dim=-1)
2085
+ else:
2086
+ attn1 = attn_func(q1, k1, v, **kwargs)
2087
+ attn2 = attn_func(q2, k2, v, **kwargs)
2088
+
2089
+ lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
2090
+ lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
2091
+ lambda_full = lambda_1 - lambda_2 + self.lambda_init
2092
+
2093
+ attn = attn1 - lambda_full * attn2
2094
+ attn = self.subln(attn)
2095
+ attn = attn * (1 - self.lambda_init)
2096
+ # reshape back to 2 * num_head
2097
+ attn = rearrange(attn, "... H (two D) -> ... (H two) D", two=2)
2098
+ return attn
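+ # A minimal usage sketch (hedged: assumes flash-attn is installed and a CUDA device; the shapes
+ # are illustrative). Over the two head groups produced by split_heads, the module computes
+ # attn = subln(attn1 - lambda_full * attn2) * (1 - lambda_init):
+ #
+ #   diff_attn = FlashDiffCustomAttention(head_dim=64, depth=0).to("cuda", torch.bfloat16)
+ #   q = torch.randn(2, 128, 8, 64, device="cuda")   # (B, S, 2*H_half, D) with H_half=4
+ #   k = torch.randn(2, 128, 8, 64, device="cuda")
+ #   v = torch.randn(2, 128, 8, 64, device="cuda")
+ #   out = diff_attn(q, k, v, causal=True)           # (B, S, 8, 64)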
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "199999": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "200018": {
15
+ "content": "<|endofprompt|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "200019": {
23
+ "content": "<|assistant|>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": true,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "200020": {
31
+ "content": "<|end|>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": true,
35
+ "single_word": false,
36
+ "special": true
37
+ },
38
+ "200021": {
39
+ "content": "<|user|>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": true,
43
+ "single_word": false,
44
+ "special": true
45
+ },
46
+ "200022": {
47
+ "content": "<|system|>",
48
+ "lstrip": false,
49
+ "normalized": false,
50
+ "rstrip": true,
51
+ "single_word": false,
52
+ "special": true
53
+ },
54
+ "200023": {
55
+ "content": "<|tool|>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": true,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "200024": {
63
+ "content": "<|/tool|>",
64
+ "lstrip": false,
65
+ "normalized": false,
66
+ "rstrip": true,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "200025": {
71
+ "content": "<|tool_call|>",
72
+ "lstrip": false,
73
+ "normalized": false,
74
+ "rstrip": true,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "200026": {
79
+ "content": "<|/tool_call|>",
80
+ "lstrip": false,
81
+ "normalized": false,
82
+ "rstrip": true,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "200027": {
87
+ "content": "<|tool_response|>",
88
+ "lstrip": false,
89
+ "normalized": false,
90
+ "rstrip": true,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "200028": {
95
+ "content": "<|tag|>",
96
+ "lstrip": false,
97
+ "normalized": false,
98
+ "rstrip": true,
99
+ "single_word": false,
100
+ "special": true
101
+ }
102
+ },
103
+ "bos_token": "<|endoftext|>",
104
+ "chat_template": "{% for message in messages %}{% if message['role'] == 'system' and 'tools' in message and message['tools'] is not none %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|tool|>' + message['tools'] + '<|/tool|>' + '<|end|>' }}{% else %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|end|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}",
105
+ "clean_up_tokenization_spaces": false,
106
+ "eos_token": "<|endoftext|>",
107
+ "model_max_length": 65536,
108
+ "pad_token": "<|endoftext|>",
109
+ "tokenizer_class": "GPT2Tokenizer",
110
+ "unk_token": "<|endoftext|>"
111
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff