Upload 17 files
Browse filestext-generation, text-generation-inference
- CODE_OF_CONDUCT.md +9 -0
- LICENSE +22 -0
- NOTICE.md +38 -0
- README.md +235 -0
- SECURITY.md +41 -0
- added_tokens.json +12 -0
- config.json +37 -0
- configuration_phi4flash.py +173 -0
- generation_config.json +10 -0
- model-00001-of-00002.safetensors +3 -0
- model-00002-of-00002.safetensors +3 -0
- model.safetensors.index.json +441 -0
- modeling_phi4flash.py +2098 -0
- special_tokens_map.json +30 -0
- tokenizer.json +0 -0
- tokenizer_config.json +111 -0
- vocab.json +0 -0
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Microsoft Open Source Code of Conduct
|
2 |
+
|
3 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
4 |
+
|
5 |
+
Resources:
|
6 |
+
|
7 |
+
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
8 |
+
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
9 |
+
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
|
LICENSE
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Microsoft.
|
2 |
+
Copyright (c) Microsoft Corporation.
|
3 |
+
|
4 |
+
MIT License
|
5 |
+
|
6 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
7 |
+
of this software and associated documentation files (the "Software"), to deal
|
8 |
+
in the Software without restriction, including without limitation the rights
|
9 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10 |
+
copies of the Software, and to permit persons to whom the Software is
|
11 |
+
furnished to do so, subject to the following conditions:
|
12 |
+
|
13 |
+
The above copyright notice and this permission notice shall be included in all
|
14 |
+
copies or substantial portions of the Software.
|
15 |
+
|
16 |
+
THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
+
SOFTWARE.
|
NOTICE.md
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
NOTICES AND INFORMATION
|
2 |
+
Do Not Translate or Localize
|
3 |
+
|
4 |
+
This software incorporates material from third parties.
|
5 |
+
|
6 |
+
**Component.** https://github.com/Dao-AILab/flash-attention
|
7 |
+
|
8 |
+
**Open Source License/Copyright Notice.**
|
9 |
+
|
10 |
+
BSD 3-Clause License
|
11 |
+
|
12 |
+
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
|
13 |
+
All rights reserved.
|
14 |
+
|
15 |
+
Redistribution and use in source and binary forms, with or without
|
16 |
+
modification, are permitted provided that the following conditions are met:
|
17 |
+
|
18 |
+
* Redistributions of source code must retain the above copyright notice, this
|
19 |
+
list of conditions and the following disclaimer.
|
20 |
+
|
21 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
22 |
+
this list of conditions and the following disclaimer in the documentation
|
23 |
+
and/or other materials provided with the distribution.
|
24 |
+
|
25 |
+
* Neither the name of the copyright holder nor the names of its
|
26 |
+
contributors may be used to endorse or promote products derived from
|
27 |
+
this software without specific prior written permission.
|
28 |
+
|
29 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
30 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
31 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
32 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
33 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
34 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
35 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
36 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
37 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
38 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
README.md
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- en
|
4 |
+
library_name: transformers
|
5 |
+
license: mit
|
6 |
+
license_link: https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/resolve/main/LICENSE
|
7 |
+
pipeline_tag: text-generation
|
8 |
+
tags:
|
9 |
+
- nlp
|
10 |
+
- math
|
11 |
+
- code
|
12 |
+
widget:
|
13 |
+
- messages:
|
14 |
+
- role: user
|
15 |
+
content: How to solve 3*x^2+4*x+5=1?
|
16 |
+
---
|
17 |
+
|
18 |
+
## Model Summary
|
19 |
+
|
20 |
+
Phi-4-mini-flash-reasoning is a lightweight open model built upon synthetic data with a focus on high-quality, reasoning dense data further finetuned for more advanced math reasoning capabilities.
|
21 |
+
The model belongs to the Phi-4 model family and supports 64K token context length.
|
22 |
+
|
23 |
+
📰 [Phi-4-mini-flash-reasoning Blog](https://azure.microsoft.com/en-us/blog/reasoning-reimagined-introducing-phi-4-mini-flash-reasoning/) <br>
|
24 |
+
📖 [Phi-4-mini-flash-reasoning Paper](https://aka.ms/flashreasoning-paper) | [HF Paper](https://huggingface.co/papers/2507.06607) <br>
|
25 |
+
📚 [Training Codebase](https://github.com/microsoft/ArchScale) <br>
|
26 |
+
👩🍳 [Phi Cookbook](https://github.com/microsoft/PhiCookBook) <br>
|
27 |
+
🏡 [Phi Portal](https://azure.microsoft.com/en-us/products/phi) <br>
|
28 |
+
🚀 [vLLM Inference](https://github.com/vllm-project/vllm/pull/20702) <br>
|
29 |
+
🖥️ Try It [Azure](https://ai.azure.com/explore/models/Phi-4-mini-flash-reasoning/version/1/registry/azureml-phi-prod) <br>
|
30 |
+
|
31 |
+
|
32 |
+
🎉**Phi-4 models**: [[Phi-4-mini-reasoning](https://huggingface.co/microsoft/Phi-4-mini-reasoning)] | [[Phi-4-reasoning](https://huggingface.co/microsoft/Phi-4-reasoning)] | [[multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) | [onnx](https://huggingface.co/microsoft/Phi-4-multimodal-instruct-onnx)];
|
33 |
+
[[mini-instruct](https://huggingface.co/microsoft/Phi-4-mini-instruct) | [onnx](https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx)]
|
34 |
+
|
35 |
+
## Intended Uses
|
36 |
+
|
37 |
+
### Primary Use Cases
|
38 |
+
|
39 |
+
Phi-4-mini-flash-reasoning is designed for multi-step, logic-intensive mathematical problem-solving tasks under memory/compute constrained environments and latency bound scenarios.
|
40 |
+
Some of the use cases include formal proof generation, symbolic computation, advanced word problems, and a wide range of mathematical reasoning scenarios.
|
41 |
+
These models excel at maintaining context across steps, applying structured logic, and delivering accurate, reliable solutions in domains that require deep analytical thinking.
|
42 |
+
|
43 |
+
### Use Case Considerations
|
44 |
+
|
45 |
+
This model is designed and tested for math reasoning only. It is not specifically designed or evaluated for all downstream purposes.
|
46 |
+
Developers should consider common limitations of language models, as well as performance difference across languages, as they select use cases, and evaluate and mitigate for accuracy, safety, and fairness before using within a specific downstream use case, particularly for high-risk scenarios.
|
47 |
+
Developers should be aware of and adhere to applicable laws or regulations (including but not limited to privacy, trade compliance laws, etc.) that are relevant to their use case.
|
48 |
+
|
49 |
+
***Nothing contained in this Model Card should be interpreted as or deemed a restriction or modification to the license the model is released under.***
|
50 |
+
|
51 |
+
## Release Notes
|
52 |
+
|
53 |
+
This release of Phi-4-mini-flash-reasoning addresses user feedback and market demand for a compact reasoning model.
|
54 |
+
It is a compact transformer-based language model optimized for mathematical reasoning, built to deliver high-quality, step-by-step problem solving in environments where computing or latency is constrained.
|
55 |
+
The model is fine-tuned with synthetic math data from a more capable model (much larger, smarter, more accurate, and better at following instructions), which has resulted in enhanced reasoning performance.
|
56 |
+
Phi-4-mini-flash-reasoning balances reasoning ability with efficiency, making it potentially suitable for educational applications, embedded tutoring, and lightweight deployment on edge or mobile systems.
|
57 |
+
If a critical issue is identified with Phi-4-mini-flash-reasoning, it should be promptly reported through the MSRC Researcher Portal or secure@microsoft.com
|
58 |
+
|
59 |
+
### Model Quality
|
60 |
+
|
61 |
+
To understand the capabilities, the 3.8B parameters Phi-4-mini-flash-reasoning model was compared with a set of models over a variety of reasoning benchmarks.
|
62 |
+
We use a more accurate evaluation where Pass@1 accuracy is averaged over 64 samples for AIME24/25 and 8 samples for Math500 and GPQA Diamond. A high-level overview of the model quality is as follows:
|
63 |
+
|
64 |
+
| **Model** | **AIME24** | **AIME25** | **Math500** | **GPQA Diamond** |
|
65 |
+
| :----------------------------------- | :--------- | :--------- | :---------- | :--------------- |
|
66 |
+
| DeepSeek-R1-Distill-Qwen-1.5B | 29.58 | 20.78 | 84.50 | 37.69 |
|
67 |
+
| DeepSeek-R1-Distill-Qwen-7B | 53.70 | 35.94 | 93.03 | 47.85 |
|
68 |
+
| DeepSeek-R1-Distill-Llama-8B | 43.96 | 27.34 | 87.48 | 45.83 |
|
69 |
+
| Bespoke-Stratos-7B | 21.51 | 18.28 | 80.73 | 38.51 |
|
70 |
+
| OpenThinker-7B | 29.69 | 24.32 | 87.25 | 41.60 |
|
71 |
+
| Phi4-mini-Reasoning (3.8B) | 48.13 | 31.77 | 91.20 | 44.51 |
|
72 |
+
| **Phi4-mini-Flash-Reasoning (3.8B)** | **52.29** | **33.59** | **92.45** | **45.08** |
|
73 |
+
|
74 |
+
Overall, the model with only 3.8B-param achieves a similar level of math and science reasoning ability as much larger models.
|
75 |
+
However, it is still fundamentally limited by its size for certain tasks. The model simply does not have the capacity to store too much factual knowledge, therefore, users may experience factual incorrectness. However, it may be possible to resolve such weakness by augmenting Phi-4-mini-flash-reasoning with a search engine, particularly when using the model under RAG settings.
|
76 |
+
|
77 |
+
### Model Efficiency
|
78 |
+
|
79 |
+
The two figures below compare the latency and throughput performance of the Phi-4-mini-reasoning and Phi-4-mini-flash-reasoning models under the vLLM inference framework. All evaluations were performed on a single NVIDIA A100-80GB GPU with tensor parallelism disabled (TP = 1). The Phi-4-mini-flash-reasoning model, which incorporates a decoder-hybrid-decoder architecture with attention and state space model (SSM), exhibits significantly greater computational efficiency—achieving up-to a 10× improvement in throughput when processing user requests with 2K prompt length and 32K generation length. Furthermore, Phi-4-mini-flash-reasoning demonstrates near-linear growth in latency with respect to the number of tokens generated (up to 32k), in contrast to the quadratic growth observed in Phi-4-mini-reasoning. These findings indicate that Phi-4-mini-flash-reasoning is more scalable and better suited for long-sequence generation tasks.
|
80 |
+
|
81 |
+
<div align="left">
|
82 |
+
<img src="lat.png" width="300"/>
|
83 |
+
<img src="thr_lat.png" width="298"/>
|
84 |
+
</div>
|
85 |
+
Figure 1. The first plot shows average inference latency as a function of generation length, while the second plot illustrates how inference latency varies with throughput. Both experiments were conducted using the vLLM inference framework on a single A100-80GB GPU over varying concurrency levels of user requests.
|
86 |
+
|
87 |
+
## Usage
|
88 |
+
|
89 |
+
### Tokenizer
|
90 |
+
|
91 |
+
Phi-4-mini-flash-reasoning supports a vocabulary size of up to `200064` tokens. The [tokenizer files](https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/blob/main/added_tokens.json) already provide placeholder tokens that can be used for downstream fine-tuning, but they can also be extended up to the model's vocabulary size.
|
92 |
+
|
93 |
+
### Input Formats
|
94 |
+
|
95 |
+
Given the nature of the training data, the Phi-4-mini-flash-reasoning
|
96 |
+
model is best suited for prompts using this specific chat format:
|
97 |
+
|
98 |
+
```yaml
|
99 |
+
<|user|>How to solve 3*x^2+4*x+5=1?<|end|><|assistant|>
|
100 |
+
```
|
101 |
+
### Inference with transformers
|
102 |
+
List of required packages:
|
103 |
+
|
104 |
+
```
|
105 |
+
flash_attn==2.7.4.post1
|
106 |
+
torch==2.6.0
|
107 |
+
mamba-ssm==2.2.4 --no-build-isolation
|
108 |
+
causal-conv1d==1.5.0.post8
|
109 |
+
transformers==4.46.1
|
110 |
+
accelerate==1.4.0
|
111 |
+
```
|
112 |
+
|
113 |
+
Phi-4-mini-flash-reasoning is also available in [Azure AI Foundry](https://ai.azure.com/explore/models/Phi-4-mini-flash-reasoning/version/1/registry/azureml-phi-prod)
|
114 |
+
|
115 |
+
#### Example
|
116 |
+
|
117 |
+
After obtaining the Phi-4-mini-flash-reasoning model checkpoints, users can use this sample code for inference.
|
118 |
+
|
119 |
+
```python
|
120 |
+
import torch
|
121 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
122 |
+
torch.random.manual_seed(0)
|
123 |
+
|
124 |
+
model_id = "microsoft/Phi-4-mini-flash-reasoning"
|
125 |
+
model = AutoModelForCausalLM.from_pretrained(
|
126 |
+
model_id,
|
127 |
+
device_map="cuda",
|
128 |
+
torch_dtype="auto",
|
129 |
+
trust_remote_code=True,
|
130 |
+
)
|
131 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
132 |
+
|
133 |
+
messages = [{
|
134 |
+
"role": "user",
|
135 |
+
"content": "How to solve 3*x^2+4*x+5=1?"
|
136 |
+
}]
|
137 |
+
inputs = tokenizer.apply_chat_template(
|
138 |
+
messages,
|
139 |
+
add_generation_prompt=True,
|
140 |
+
return_dict=True,
|
141 |
+
return_tensors="pt",
|
142 |
+
)
|
143 |
+
|
144 |
+
outputs = model.generate(
|
145 |
+
**inputs.to(model.device),
|
146 |
+
max_new_tokens=32768,
|
147 |
+
temperature=0.6,
|
148 |
+
top_p=0.95,
|
149 |
+
do_sample=True,
|
150 |
+
)
|
151 |
+
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
|
152 |
+
|
153 |
+
print(outputs[0])
|
154 |
+
```
|
155 |
+
|
156 |
+
## Training
|
157 |
+
|
158 |
+
### Model
|
159 |
+
|
160 |
+
+ **Architecture:** Phi-4-mini-flash-reasoning adopts a hybrid SambaY architecture with Differential Attention, featuring 3.8 billion parameters and a 200K vocabulary. It incorporates state space models, grouped-query attention, a gated memory sharing mechanism, a shared key-value cache with a single global attention layer, and shared input-output embeddings.<br>
|
161 |
+
+ **Inputs:** Text. It is best suited for prompts using the chat format.<br>
|
162 |
+
+ **Context length:** 64K tokens<br>
|
163 |
+
+ **GPUs:** Pre-training: 1024 A100-80G; Reasoning training: 128 H100-80G <br>
|
164 |
+
+ **Training time:** Pre-training: 14 days; Reasoning training: 2days <br>
|
165 |
+
+ **Training data:** Pre-training: 5T tokens; Reasoning training: 150B tokens<br>
|
166 |
+
+ **Outputs:** Generated text<br>
|
167 |
+
+ **Dates:** Trained in May 2025 <br>
|
168 |
+
+ **Status:** This is a static model trained on offline datasets with the cutoff date of February 2025 for publicly available data.<br>
|
169 |
+
+ **Supported languages:** English<br>
|
170 |
+
+ **Release date:** June 2025<br>
|
171 |
+
|
172 |
+
### Training Datasets
|
173 |
+
|
174 |
+
The training data for Phi-4-mini-flash-reasoning consists exclusively of synthetic mathematical content generated by a stronger and more advanced reasoning model, Deepseek-R1.
|
175 |
+
The objective is to distill knowledge from this model. This synthetic dataset comprises over one million diverse math problems spanning multiple levels of difficulty (from middle school to Ph.D. level).
|
176 |
+
For each problem in the synthetic dataset, eight distinct solutions (rollouts) were sampled, and only those verified as correct were retained, resulting in approximately 30 billion tokens of math content.
|
177 |
+
The dataset integrates three primary components:
|
178 |
+
1) a curated selection of high-quality, publicly available math questions and a part of the SFT(Supervised Fine-Tuning) data that was used to train the base Phi-4-mini-flash model;
|
179 |
+
2) an extensive collection of synthetic math data generated by the Deepseek-R1 model, designed specifically for high-quality supervised fine-tuning and model distillation; and
|
180 |
+
3) a balanced set of correct and incorrect answers used to construct preference data aimed at enhancing Phi-4-mini-flash-reasoning's reasoning capabilities by learning more effective reasoning trajectories
|
181 |
+
|
182 |
+
## Software
|
183 |
+
* [PyTorch](https://github.com/pytorch/pytorch)
|
184 |
+
* [Transformers](https://github.com/huggingface/transformers)
|
185 |
+
* [Flash-Attention](https://github.com/HazyResearch/flash-attention)
|
186 |
+
* [Mamba](https://github.com/state-spaces/mamba)
|
187 |
+
* [Causal-Conv1d](https://github.com/Dao-AILab/causal-conv1d)
|
188 |
+
|
189 |
+
## Hardware
|
190 |
+
Note that by default, the Phi-4-mini-flash-reasoning model uses flash attention, which requires certain types of GPU hardware to run. We have tested on the following GPU types:
|
191 |
+
* NVIDIA A100
|
192 |
+
* NVIDIA H100
|
193 |
+
|
194 |
+
## Safety Evaluation and Red-Teaming
|
195 |
+
|
196 |
+
The Phi-4 family of models has adopted a robust safety post-training approach. This approach leverages a variety of both open-source and in-house generated datasets. The overall technique employed to do the safety alignment is a combination of SFT, DPO (Direct Preference Optimization), and RLHF (Reinforcement Learning from Human Feedback) approaches by utilizing human-labeled and synthetic English-language datasets, including publicly available datasets focusing on helpfulness and harmlessness, as well as various questions and answers targeted to multiple safety categories.
|
197 |
+
|
198 |
+
Phi-4-Mini-Flash-Reasoning was developed in accordance with Microsoft's responsible AI principles. Potential safety risks in the model’s responses were assessed using the Azure AI Foundry’s Risk and Safety Evaluation framework, focusing on harmful content, direct jailbreak, and model groundedness. The Phi-4-Mini-Flash-Reasoning Model Card contains additional information about our approach to safety and responsible AI considerations that developers should be aware of when using this model.
|
199 |
+
|
200 |
+
## Responsible AI Considerations
|
201 |
+
|
202 |
+
Like other language models, the Phi family of models can potentially behave in ways that are unfair, unreliable, or offensive. Some of the limiting behaviors to be aware of include:
|
203 |
+
|
204 |
+
+ Quality of Service: The Phi models are trained primarily on English text and some additional multilingual text. Languages other than English will experience worse performance as well as performance disparities across non-English. English language varieties with less representation in the training data might experience worse performance than standard American English.
|
205 |
+
+ Multilingual performance and safety gaps: We believe it is important to make language models more widely available across different languages, but the Phi 4 models still exhibit challenges common across multilingual releases. As with any deployment of LLMs, developers will be better positioned to test for performance or safety gaps for their linguistic and cultural context and customize the model with additional fine-tuning and appropriate safeguards.
|
206 |
+
+ Representation of Harms & Perpetuation of Stereotypes: These models can over- or under-represent groups of people, erase representation of some groups, or reinforce demeaning or negative stereotypes. Despite safety post-training, these limitations may still be present due to differing levels of representation of different groups, cultural contexts, or prevalence of examples of negative stereotypes in training data that reflect real-world patterns and societal biases.
|
207 |
+
+ Inappropriate or Offensive Content: These models may produce other types of inappropriate or offensive content, which may make it inappropriate to deploy for sensitive contexts without additional mitigations that are specific to the case.
|
208 |
+
+ Information Reliability: Language models can generate nonsensical content or fabricate content that might sound reasonable but is inaccurate or outdated.
|
209 |
+
+ Election Information Reliability : The model has an elevated defect rate when responding to election-critical queries, which may result in incorrect or unauthoritative election critical information being presented. We are working to improve the model's performance in this area. Users should verify information related to elections with the election authority in their region.
|
210 |
+
+ Limited Scope for Code: The majority of Phi 4 training data is based in Python and uses common packages such as "typing, math, random, collections, datetime, itertools". If the model generates Python scripts that utilize other packages or scripts in other languages, it is strongly recommended that users manually verify all API uses.
|
211 |
+
+ Long Conversation: Phi 4 models, like other models, can in some cases generate responses that are repetitive, unhelpful, or inconsistent in very long chat sessions in both English and non-English languages. Developers are encouraged to place appropriate mitigations, like limiting conversation turns to account for the possible conversational drift.
|
212 |
+
|
213 |
+
Developers should apply responsible AI best practices, including mapping, measuring, and mitigating risks associated with their specific use case and cultural, linguistic context. Phi 4 family of models are general purpose models. As developers plan to deploy these models for specific use cases, they are encouraged to fine-tune the models for their use case and leverage the models as part of broader AI systems with language-specific safeguards in place. Important areas for consideration include:
|
214 |
+
|
215 |
+
+ Allocation: Models may not be suitable for scenarios that could have consequential impact on legal status or the allocation of resources or life opportunities (ex: housing, employment, credit, etc.) without further assessments and additional debiasing techniques.
|
216 |
+
+ High-Risk Scenarios: Developers should assess the suitability of using models in high-risk scenarios where unfair, unreliable or offensive outputs might be extremely costly or lead to harm. This includes providing advice in sensitive or expert domains where accuracy and reliability are critical (ex: legal or health advice). Additional safeguards should be implemented at the application level according to the deployment context.
|
217 |
+
+ Misinformation: Models may produce inaccurate information. Developers should follow transparency best practices and inform end-users they are interacting with an AI system. At the application level, developers can build feedback mechanisms and pipelines to ground responses in use-case specific, contextual information, a technique known as Retrieval Augmented Generation (RAG).
|
218 |
+
+ Generation of Harmful Content: Developers should assess outputs for their context and use available safety classifiers or custom solutions appropriate for their use case.
|
219 |
+
+ Misuse: Other forms of misuse such as fraud, spam, or malware production may be possible, and developers should ensure that their applications do not violate applicable laws and regulations.
|
220 |
+
|
221 |
+
## License
|
222 |
+
The model is licensed under the [MIT license](./LICENSE).
|
223 |
+
|
224 |
+
## Trademarks
|
225 |
+
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft’s Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party’s policies.
|
226 |
+
|
227 |
+
|
228 |
+
## Appendix A: Benchmark Methodology
|
229 |
+
|
230 |
+
We include a brief word on methodology here - and in particular, how we think about optimizing prompts. In an ideal world, we would never change any prompts in our benchmarks to ensure it is always an apples-to-apples comparison when comparing different models. Indeed, this is our default approach, and is the case in the vast majority of models we have run to date. For all benchmarks, we consider using the same generation configuration such as max sequence length (32768), the same temperature for the fair comparison.
|
231 |
+
Benchmark datasets
|
232 |
+
We evaluate the model with three of the most popular math benchmarks where the strongest reasoning models are competing together. Specifically:
|
233 |
+
+ Math-500: This benchmark consists of 500 challenging math problems designed to test the model's ability to perform complex mathematical reasoning and problem-solving.
|
234 |
+
+ AIME 2024/AIME 2025: The American Invitational Mathematics Examination (AIME) is a highly regarded math competition that features a series of difficult problems aimed at assessing advanced mathematical skills and logical reasoning. We evaluate the models on the problems from both 2024 and the year 2025 examinations.
|
235 |
+
+ GPQA Diamond: The Graduate-Level Google-Proof Q&A (GPQA) Diamond benchmark focuses on evaluating the model's ability to understand and solve a wide range of mathematical questions, including both straightforward calculations and more intricate problem-solving tasks.
|
SECURITY.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->
|
2 |
+
|
3 |
+
## Security
|
4 |
+
|
5 |
+
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
|
6 |
+
|
7 |
+
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
|
8 |
+
|
9 |
+
## Reporting Security Issues
|
10 |
+
|
11 |
+
**Please do not report security vulnerabilities through public GitHub issues.**
|
12 |
+
|
13 |
+
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
|
14 |
+
|
15 |
+
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
|
16 |
+
|
17 |
+
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
|
18 |
+
|
19 |
+
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
20 |
+
|
21 |
+
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
22 |
+
* Full paths of source file(s) related to the manifestation of the issue
|
23 |
+
* The location of the affected source code (tag/branch/commit or direct URL)
|
24 |
+
* Any special configuration required to reproduce the issue
|
25 |
+
* Step-by-step instructions to reproduce the issue
|
26 |
+
* Proof-of-concept or exploit code (if possible)
|
27 |
+
* Impact of the issue, including how an attacker might exploit the issue
|
28 |
+
|
29 |
+
This information will help us triage your report more quickly.
|
30 |
+
|
31 |
+
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
|
32 |
+
|
33 |
+
## Preferred Languages
|
34 |
+
|
35 |
+
We prefer all communications to be in English.
|
36 |
+
|
37 |
+
## Policy
|
38 |
+
|
39 |
+
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
|
40 |
+
|
41 |
+
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
added_tokens.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"<|/tool_call|>": 200026,
|
3 |
+
"<|/tool|>": 200024,
|
4 |
+
"<|assistant|>": 200019,
|
5 |
+
"<|end|>": 200020,
|
6 |
+
"<|system|>": 200022,
|
7 |
+
"<|tag|>": 200028,
|
8 |
+
"<|tool_call|>": 200025,
|
9 |
+
"<|tool_response|>": 200027,
|
10 |
+
"<|tool|>": 200023,
|
11 |
+
"<|user|>": 200021
|
12 |
+
}
|
config.json
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"Phi4FlashForCausalLM"
|
4 |
+
],
|
5 |
+
"attention_dropout": 0.0,
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "configuration_phi4flash.Phi4FlashConfig",
|
8 |
+
"AutoModelForCausalLM": "modeling_phi4flash.Phi4FlashForCausalLM",
|
9 |
+
"AutoTokenizer": "Xenova/gpt-4o"
|
10 |
+
},
|
11 |
+
"pad_token_id": 199999,
|
12 |
+
"bos_token_id": 199999,
|
13 |
+
"embd_pdrop": 0.0,
|
14 |
+
"eos_token_id": 199999,
|
15 |
+
"hidden_act": "silu",
|
16 |
+
"hidden_size": 2560,
|
17 |
+
"initializer_range": 0.02,
|
18 |
+
"intermediate_size": 10240,
|
19 |
+
"layer_norm_eps": 1e-5,
|
20 |
+
"max_position_embeddings": 262144,
|
21 |
+
"_attn_implementation": "flash_attention_2",
|
22 |
+
"mb_per_layer": 2,
|
23 |
+
"model_type": "phi4flash",
|
24 |
+
"num_attention_heads": 40,
|
25 |
+
"num_hidden_layers": 32,
|
26 |
+
"num_key_value_heads": 20,
|
27 |
+
"resid_pdrop": 0.0,
|
28 |
+
"sliding_window": 512,
|
29 |
+
"torch_dtype": "bfloat16",
|
30 |
+
"tie_word_embeddings": true,
|
31 |
+
"transformers_version": "4.46.1",
|
32 |
+
"use_cache": true,
|
33 |
+
"mlp_bias": false,
|
34 |
+
"lm_head_bias": false,
|
35 |
+
"vocab_size": 200064
|
36 |
+
}
|
37 |
+
|
configuration_phi4flash.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
""" Phi4Flash model configuration"""
|
17 |
+
|
18 |
+
|
19 |
+
from transformers.configuration_utils import PretrainedConfig
|
20 |
+
from transformers.utils import logging
|
21 |
+
import math
|
22 |
+
logger = logging.get_logger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class Phi4FlashConfig(PretrainedConfig):
|
26 |
+
r"""
|
27 |
+
This is the configuration class to store the configuration of a [`Phi4FlashModel`]. It is used to instantiate an Phi4Flash
|
28 |
+
model according to the specified arguments, defining the model architecture.
|
29 |
+
|
30 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
31 |
+
documentation from [`PretrainedConfig`] for more information.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
vocab_size (`int`, *optional*, defaults to 51200):
|
35 |
+
Vocabulary size of the Phi4Flash model. Defines the number of different tokens that can be represented by the
|
36 |
+
`inputs_ids` passed when calling [`Phi4FlashModel`].
|
37 |
+
hidden_size (`int`, *optional*, defaults to 2048):
|
38 |
+
Dimension of the hidden representations.
|
39 |
+
intermediate_size (`int`, *optional*, defaults to 8192):
|
40 |
+
Dimension of the MLP representations.
|
41 |
+
num_hidden_layers (`int`, *optional*, defaults to 24):
|
42 |
+
Number of hidden layers in the Transformer decoder.
|
43 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
44 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
45 |
+
num_key_value_heads (`int`, *optional*):
|
46 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
47 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
48 |
+
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
49 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
50 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
51 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
52 |
+
`num_attention_heads`.
|
53 |
+
resid_pdrop (`float`, *optional*, defaults to 0.0):
|
54 |
+
Dropout probability for mlp outputs.
|
55 |
+
embd_pdrop (`int`, *optional*, defaults to 0.0):
|
56 |
+
The dropout ratio for the embeddings.
|
57 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
58 |
+
The dropout ratio after computing the attention scores.
|
59 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
|
60 |
+
The non-linear activation function (function or string) in the decoder.
|
61 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
62 |
+
The maximum sequence length that this model might ever be used with. Phi-1 and Phi-1.5 supports up to 2048
|
63 |
+
tokens.
|
64 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
65 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
66 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
67 |
+
The epsilon used by the rms normalization layers.
|
68 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
69 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
70 |
+
relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not.
|
71 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
72 |
+
Whether to tie weight embeddings
|
73 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
74 |
+
The base period of the RoPE embeddings.
|
75 |
+
|
76 |
+
Example:
|
77 |
+
|
78 |
+
```python
|
79 |
+
>>> from transformers import Phi4FlashModel, Phi4FlashConfig
|
80 |
+
|
81 |
+
>>> # Initializing a Phi4Flash style configuration
|
82 |
+
>>> configuration = Phi4FlashConfig.from_pretrained("microsoft/Phi4-mini-flash-reasoning")
|
83 |
+
|
84 |
+
>>> # Initializing a model from the configuration
|
85 |
+
>>> model = Phi4FlashModel(configuration)
|
86 |
+
|
87 |
+
>>> # Accessing the model configuration
|
88 |
+
>>> configuration = model.config
|
89 |
+
```"""
|
90 |
+
|
91 |
+
model_type = "phi4flash"
|
92 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
93 |
+
|
94 |
+
def __init__(
|
95 |
+
self,
|
96 |
+
vocab_size=51200,
|
97 |
+
hidden_size=2560,
|
98 |
+
intermediate_size=9216,
|
99 |
+
num_hidden_layers=32,
|
100 |
+
num_attention_heads=40,
|
101 |
+
num_key_value_heads=4,
|
102 |
+
resid_pdrop=0.0,
|
103 |
+
embd_pdrop=0.0,
|
104 |
+
attention_dropout=0.0,
|
105 |
+
hidden_act="silu",
|
106 |
+
max_position_embeddings=4096,
|
107 |
+
initializer_range=0.02,
|
108 |
+
layer_norm_eps=1e-5,
|
109 |
+
use_cache=True,
|
110 |
+
tie_word_embeddings=True,
|
111 |
+
rope_theta=10000.0,
|
112 |
+
bos_token_id=1,
|
113 |
+
eos_token_id=2,
|
114 |
+
sliding_window=2047,
|
115 |
+
mb_per_layer= 2,
|
116 |
+
mamba_d_state=16,
|
117 |
+
mamba_d_conv=4,
|
118 |
+
mamba_expand=2,
|
119 |
+
mamba_dt_rank="auto",
|
120 |
+
mamba_conv_bias=True,
|
121 |
+
mamba_proj_bias=False,
|
122 |
+
**kwargs,
|
123 |
+
):
|
124 |
+
self.vocab_size = vocab_size
|
125 |
+
self.hidden_size = hidden_size
|
126 |
+
self.intermediate_size = intermediate_size
|
127 |
+
self.num_hidden_layers = num_hidden_layers
|
128 |
+
self.num_attention_heads = num_attention_heads
|
129 |
+
|
130 |
+
if num_key_value_heads is None:
|
131 |
+
num_key_value_heads = num_attention_heads
|
132 |
+
|
133 |
+
self.num_key_value_heads = num_key_value_heads
|
134 |
+
self.resid_pdrop = resid_pdrop
|
135 |
+
self.embd_pdrop = embd_pdrop
|
136 |
+
self.attention_dropout = attention_dropout
|
137 |
+
self.hidden_act = hidden_act
|
138 |
+
self.max_position_embeddings = max_position_embeddings
|
139 |
+
self.initializer_range = initializer_range
|
140 |
+
self.layer_norm_eps = layer_norm_eps
|
141 |
+
self.use_cache = use_cache
|
142 |
+
self.rope_theta = rope_theta
|
143 |
+
self.mb_per_layer = mb_per_layer
|
144 |
+
self.sliding_window = [
|
145 |
+
sliding_window if layer_idx < num_hidden_layers // 2 and layer_idx % 2 == 1 else None
|
146 |
+
for layer_idx in range(num_hidden_layers)
|
147 |
+
]
|
148 |
+
|
149 |
+
self.mamba_d_state = mamba_d_state
|
150 |
+
self.mamba_d_conv = mamba_d_conv
|
151 |
+
self.mamba_expand = mamba_expand
|
152 |
+
self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
|
153 |
+
self.mamba_conv_bias = mamba_conv_bias
|
154 |
+
self.mamba_proj_bias = mamba_proj_bias
|
155 |
+
|
156 |
+
super().__init__(
|
157 |
+
bos_token_id=bos_token_id,
|
158 |
+
eos_token_id=eos_token_id,
|
159 |
+
tie_word_embeddings=tie_word_embeddings,
|
160 |
+
**kwargs,
|
161 |
+
)
|
162 |
+
|
163 |
+
|
164 |
+
@property
|
165 |
+
def layers_block_type(self):
|
166 |
+
layer_block_types = []
|
167 |
+
for i in range(self.num_hidden_layers):
|
168 |
+
if i % 2 == 1:
|
169 |
+
layer_block_type = "attention" if i <= (self.num_hidden_layers //2 +1) else "shared_attention"
|
170 |
+
else:
|
171 |
+
layer_block_type = "mamba"
|
172 |
+
layer_block_types.append(layer_block_type)
|
173 |
+
return layer_block_types
|
generation_config.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 199999,
|
4 |
+
"eos_token_id": [
|
5 |
+
200020,
|
6 |
+
199999
|
7 |
+
],
|
8 |
+
"pad_token_id": 199999,
|
9 |
+
"transformers_version": "4.45.0"
|
10 |
+
}
|
model-00001-of-00002.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:58cb8678cd1495c42afd4b2d9ccd26bbace3e54e94905db8c346e7b39ce7a956
|
3 |
+
size 135
|
model-00002-of-00002.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:848f66d812125dbf83052af44cd1e010d9b0cf92b5b532b7df61ce1e65935c3e
|
3 |
+
size 135
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 7706608640
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"model.embed_tokens.weight": "model-00001-of-00002.safetensors",
|
7 |
+
"model.final_layernorm.bias": "model-00002-of-00002.safetensors",
|
8 |
+
"model.final_layernorm.weight": "model-00002-of-00002.safetensors",
|
9 |
+
"model.layers.0.attn.A_log": "model-00001-of-00002.safetensors",
|
10 |
+
"model.layers.0.attn.D": "model-00001-of-00002.safetensors",
|
11 |
+
"model.layers.0.attn.conv1d.bias": "model-00001-of-00002.safetensors",
|
12 |
+
"model.layers.0.attn.conv1d.weight": "model-00001-of-00002.safetensors",
|
13 |
+
"model.layers.0.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
|
14 |
+
"model.layers.0.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
|
15 |
+
"model.layers.0.attn.in_proj.weight": "model-00001-of-00002.safetensors",
|
16 |
+
"model.layers.0.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
17 |
+
"model.layers.0.attn.x_proj.weight": "model-00001-of-00002.safetensors",
|
18 |
+
"model.layers.0.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
19 |
+
"model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
20 |
+
"model.layers.0.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
21 |
+
"model.layers.0.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
22 |
+
"model.layers.0.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
23 |
+
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
24 |
+
"model.layers.1.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
|
25 |
+
"model.layers.1.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
|
26 |
+
"model.layers.1.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
|
27 |
+
"model.layers.1.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
|
28 |
+
"model.layers.1.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
|
29 |
+
"model.layers.1.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
|
30 |
+
"model.layers.1.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
|
31 |
+
"model.layers.1.attn.out_proj.bias": "model-00001-of-00002.safetensors",
|
32 |
+
"model.layers.1.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
33 |
+
"model.layers.1.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
34 |
+
"model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
35 |
+
"model.layers.1.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
36 |
+
"model.layers.1.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
37 |
+
"model.layers.1.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
38 |
+
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
39 |
+
"model.layers.10.attn.A_log": "model-00001-of-00002.safetensors",
|
40 |
+
"model.layers.10.attn.D": "model-00001-of-00002.safetensors",
|
41 |
+
"model.layers.10.attn.conv1d.bias": "model-00001-of-00002.safetensors",
|
42 |
+
"model.layers.10.attn.conv1d.weight": "model-00001-of-00002.safetensors",
|
43 |
+
"model.layers.10.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
|
44 |
+
"model.layers.10.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
|
45 |
+
"model.layers.10.attn.in_proj.weight": "model-00001-of-00002.safetensors",
|
46 |
+
"model.layers.10.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
47 |
+
"model.layers.10.attn.x_proj.weight": "model-00001-of-00002.safetensors",
|
48 |
+
"model.layers.10.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
49 |
+
"model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
50 |
+
"model.layers.10.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
51 |
+
"model.layers.10.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
52 |
+
"model.layers.10.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
53 |
+
"model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
54 |
+
"model.layers.11.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
|
55 |
+
"model.layers.11.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
|
56 |
+
"model.layers.11.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
|
57 |
+
"model.layers.11.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
|
58 |
+
"model.layers.11.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
|
59 |
+
"model.layers.11.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
|
60 |
+
"model.layers.11.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
|
61 |
+
"model.layers.11.attn.out_proj.bias": "model-00001-of-00002.safetensors",
|
62 |
+
"model.layers.11.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
63 |
+
"model.layers.11.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
64 |
+
"model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
65 |
+
"model.layers.11.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
66 |
+
"model.layers.11.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
67 |
+
"model.layers.11.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
68 |
+
"model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
69 |
+
"model.layers.12.attn.A_log": "model-00001-of-00002.safetensors",
|
70 |
+
"model.layers.12.attn.D": "model-00001-of-00002.safetensors",
|
71 |
+
"model.layers.12.attn.conv1d.bias": "model-00001-of-00002.safetensors",
|
72 |
+
"model.layers.12.attn.conv1d.weight": "model-00001-of-00002.safetensors",
|
73 |
+
"model.layers.12.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
|
74 |
+
"model.layers.12.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
|
75 |
+
"model.layers.12.attn.in_proj.weight": "model-00001-of-00002.safetensors",
|
76 |
+
"model.layers.12.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
77 |
+
"model.layers.12.attn.x_proj.weight": "model-00001-of-00002.safetensors",
|
78 |
+
"model.layers.12.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
79 |
+
"model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
80 |
+
"model.layers.12.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
81 |
+
"model.layers.12.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
82 |
+
"model.layers.12.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
83 |
+
"model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
84 |
+
"model.layers.13.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
|
85 |
+
"model.layers.13.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
|
86 |
+
"model.layers.13.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
|
87 |
+
"model.layers.13.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
|
88 |
+
"model.layers.13.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
|
89 |
+
"model.layers.13.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
|
90 |
+
"model.layers.13.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
|
91 |
+
"model.layers.13.attn.out_proj.bias": "model-00001-of-00002.safetensors",
|
92 |
+
"model.layers.13.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
93 |
+
"model.layers.13.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
94 |
+
"model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
95 |
+
"model.layers.13.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
96 |
+
"model.layers.13.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
97 |
+
"model.layers.13.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
98 |
+
"model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
99 |
+
"model.layers.14.attn.A_log": "model-00001-of-00002.safetensors",
|
100 |
+
"model.layers.14.attn.D": "model-00001-of-00002.safetensors",
|
101 |
+
"model.layers.14.attn.conv1d.bias": "model-00001-of-00002.safetensors",
|
102 |
+
"model.layers.14.attn.conv1d.weight": "model-00001-of-00002.safetensors",
|
103 |
+
"model.layers.14.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
|
104 |
+
"model.layers.14.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
|
105 |
+
"model.layers.14.attn.in_proj.weight": "model-00001-of-00002.safetensors",
|
106 |
+
"model.layers.14.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
107 |
+
"model.layers.14.attn.x_proj.weight": "model-00001-of-00002.safetensors",
|
108 |
+
"model.layers.14.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
109 |
+
"model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
110 |
+
"model.layers.14.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
111 |
+
"model.layers.14.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
112 |
+
"model.layers.14.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
113 |
+
"model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
114 |
+
"model.layers.15.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
|
115 |
+
"model.layers.15.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
|
116 |
+
"model.layers.15.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
|
117 |
+
"model.layers.15.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
|
118 |
+
"model.layers.15.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
|
119 |
+
"model.layers.15.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
|
120 |
+
"model.layers.15.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
|
121 |
+
"model.layers.15.attn.out_proj.bias": "model-00001-of-00002.safetensors",
|
122 |
+
"model.layers.15.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
123 |
+
"model.layers.15.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
124 |
+
"model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
125 |
+
"model.layers.15.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
126 |
+
"model.layers.15.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
127 |
+
"model.layers.15.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
128 |
+
"model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
129 |
+
"model.layers.16.attn.A_log": "model-00001-of-00002.safetensors",
|
130 |
+
"model.layers.16.attn.D": "model-00001-of-00002.safetensors",
|
131 |
+
"model.layers.16.attn.conv1d.bias": "model-00001-of-00002.safetensors",
|
132 |
+
"model.layers.16.attn.conv1d.weight": "model-00001-of-00002.safetensors",
|
133 |
+
"model.layers.16.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
|
134 |
+
"model.layers.16.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
|
135 |
+
"model.layers.16.attn.in_proj.weight": "model-00001-of-00002.safetensors",
|
136 |
+
"model.layers.16.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
137 |
+
"model.layers.16.attn.x_proj.weight": "model-00001-of-00002.safetensors",
|
138 |
+
"model.layers.16.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
139 |
+
"model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
140 |
+
"model.layers.16.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
141 |
+
"model.layers.16.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
142 |
+
"model.layers.16.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
143 |
+
"model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
144 |
+
"model.layers.17.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
|
145 |
+
"model.layers.17.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
|
146 |
+
"model.layers.17.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
|
147 |
+
"model.layers.17.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
|
148 |
+
"model.layers.17.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
|
149 |
+
"model.layers.17.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
|
150 |
+
"model.layers.17.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
|
151 |
+
"model.layers.17.attn.out_proj.bias": "model-00001-of-00002.safetensors",
|
152 |
+
"model.layers.17.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
153 |
+
"model.layers.17.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
154 |
+
"model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
155 |
+
"model.layers.17.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
156 |
+
"model.layers.17.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
157 |
+
"model.layers.17.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
158 |
+
"model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
159 |
+
"model.layers.18.attn.in_proj.weight": "model-00002-of-00002.safetensors",
|
160 |
+
"model.layers.18.attn.out_proj.weight": "model-00002-of-00002.safetensors",
|
161 |
+
"model.layers.18.input_layernorm.bias": "model-00002-of-00002.safetensors",
|
162 |
+
"model.layers.18.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
163 |
+
"model.layers.18.mlp.fc1.weight": "model-00002-of-00002.safetensors",
|
164 |
+
"model.layers.18.mlp.fc2.weight": "model-00002-of-00002.safetensors",
|
165 |
+
"model.layers.18.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
|
166 |
+
"model.layers.18.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
167 |
+
"model.layers.19.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
|
168 |
+
"model.layers.19.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
|
169 |
+
"model.layers.19.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
|
170 |
+
"model.layers.19.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
|
171 |
+
"model.layers.19.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
|
172 |
+
"model.layers.19.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
|
173 |
+
"model.layers.19.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
|
174 |
+
"model.layers.19.attn.out_proj.bias": "model-00002-of-00002.safetensors",
|
175 |
+
"model.layers.19.attn.out_proj.weight": "model-00002-of-00002.safetensors",
|
176 |
+
"model.layers.19.input_layernorm.bias": "model-00002-of-00002.safetensors",
|
177 |
+
"model.layers.19.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
178 |
+
"model.layers.19.mlp.fc1.weight": "model-00002-of-00002.safetensors",
|
179 |
+
"model.layers.19.mlp.fc2.weight": "model-00002-of-00002.safetensors",
|
180 |
+
"model.layers.19.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
|
181 |
+
"model.layers.19.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
182 |
+
"model.layers.2.attn.A_log": "model-00001-of-00002.safetensors",
|
183 |
+
"model.layers.2.attn.D": "model-00001-of-00002.safetensors",
|
184 |
+
"model.layers.2.attn.conv1d.bias": "model-00001-of-00002.safetensors",
|
185 |
+
"model.layers.2.attn.conv1d.weight": "model-00001-of-00002.safetensors",
|
186 |
+
"model.layers.2.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
|
187 |
+
"model.layers.2.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
|
188 |
+
"model.layers.2.attn.in_proj.weight": "model-00001-of-00002.safetensors",
|
189 |
+
"model.layers.2.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
190 |
+
"model.layers.2.attn.x_proj.weight": "model-00001-of-00002.safetensors",
|
191 |
+
"model.layers.2.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
192 |
+
"model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
193 |
+
"model.layers.2.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
194 |
+
"model.layers.2.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
195 |
+
"model.layers.2.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
196 |
+
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
197 |
+
"model.layers.20.attn.in_proj.weight": "model-00002-of-00002.safetensors",
|
198 |
+
"model.layers.20.attn.out_proj.weight": "model-00002-of-00002.safetensors",
|
199 |
+
"model.layers.20.input_layernorm.bias": "model-00002-of-00002.safetensors",
|
200 |
+
"model.layers.20.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
201 |
+
"model.layers.20.mlp.fc1.weight": "model-00002-of-00002.safetensors",
|
202 |
+
"model.layers.20.mlp.fc2.weight": "model-00002-of-00002.safetensors",
|
203 |
+
"model.layers.20.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
|
204 |
+
"model.layers.20.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
205 |
+
"model.layers.21.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
|
206 |
+
"model.layers.21.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
|
207 |
+
"model.layers.21.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
|
208 |
+
"model.layers.21.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
|
209 |
+
"model.layers.21.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
|
210 |
+
"model.layers.21.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
|
211 |
+
"model.layers.21.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
|
212 |
+
"model.layers.21.attn.out_proj.bias": "model-00002-of-00002.safetensors",
|
213 |
+
"model.layers.21.attn.out_proj.weight": "model-00002-of-00002.safetensors",
|
214 |
+
"model.layers.21.input_layernorm.bias": "model-00002-of-00002.safetensors",
|
215 |
+
"model.layers.21.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
216 |
+
"model.layers.21.mlp.fc1.weight": "model-00002-of-00002.safetensors",
|
217 |
+
"model.layers.21.mlp.fc2.weight": "model-00002-of-00002.safetensors",
|
218 |
+
"model.layers.21.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
|
219 |
+
"model.layers.21.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
220 |
+
"model.layers.22.attn.in_proj.weight": "model-00002-of-00002.safetensors",
|
221 |
+
"model.layers.22.attn.out_proj.weight": "model-00002-of-00002.safetensors",
|
222 |
+
"model.layers.22.input_layernorm.bias": "model-00002-of-00002.safetensors",
|
223 |
+
"model.layers.22.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
224 |
+
"model.layers.22.mlp.fc1.weight": "model-00002-of-00002.safetensors",
|
225 |
+
"model.layers.22.mlp.fc2.weight": "model-00002-of-00002.safetensors",
|
226 |
+
"model.layers.22.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
|
227 |
+
"model.layers.22.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
228 |
+
"model.layers.23.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
|
229 |
+
"model.layers.23.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
|
230 |
+
"model.layers.23.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
|
231 |
+
"model.layers.23.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
|
232 |
+
"model.layers.23.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
|
233 |
+
"model.layers.23.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
|
234 |
+
"model.layers.23.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
|
235 |
+
"model.layers.23.attn.out_proj.bias": "model-00002-of-00002.safetensors",
|
236 |
+
"model.layers.23.attn.out_proj.weight": "model-00002-of-00002.safetensors",
|
237 |
+
"model.layers.23.input_layernorm.bias": "model-00002-of-00002.safetensors",
|
238 |
+
"model.layers.23.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
239 |
+
"model.layers.23.mlp.fc1.weight": "model-00002-of-00002.safetensors",
|
240 |
+
"model.layers.23.mlp.fc2.weight": "model-00002-of-00002.safetensors",
|
241 |
+
"model.layers.23.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
|
242 |
+
"model.layers.23.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
243 |
+
"model.layers.24.attn.in_proj.weight": "model-00002-of-00002.safetensors",
|
244 |
+
"model.layers.24.attn.out_proj.weight": "model-00002-of-00002.safetensors",
|
245 |
+
"model.layers.24.input_layernorm.bias": "model-00002-of-00002.safetensors",
|
246 |
+
"model.layers.24.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
247 |
+
"model.layers.24.mlp.fc1.weight": "model-00002-of-00002.safetensors",
|
248 |
+
"model.layers.24.mlp.fc2.weight": "model-00002-of-00002.safetensors",
|
249 |
+
"model.layers.24.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
|
250 |
+
"model.layers.24.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
251 |
+
"model.layers.25.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
|
252 |
+
"model.layers.25.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
|
253 |
+
"model.layers.25.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
|
254 |
+
"model.layers.25.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
|
255 |
+
"model.layers.25.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
|
256 |
+
"model.layers.25.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
|
257 |
+
"model.layers.25.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
|
258 |
+
"model.layers.25.attn.out_proj.bias": "model-00002-of-00002.safetensors",
|
259 |
+
"model.layers.25.attn.out_proj.weight": "model-00002-of-00002.safetensors",
|
260 |
+
"model.layers.25.input_layernorm.bias": "model-00002-of-00002.safetensors",
|
261 |
+
"model.layers.25.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
262 |
+
"model.layers.25.mlp.fc1.weight": "model-00002-of-00002.safetensors",
|
263 |
+
"model.layers.25.mlp.fc2.weight": "model-00002-of-00002.safetensors",
|
264 |
+
"model.layers.25.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
|
265 |
+
"model.layers.25.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
266 |
+
"model.layers.26.attn.in_proj.weight": "model-00002-of-00002.safetensors",
|
267 |
+
"model.layers.26.attn.out_proj.weight": "model-00002-of-00002.safetensors",
|
268 |
+
"model.layers.26.input_layernorm.bias": "model-00002-of-00002.safetensors",
|
269 |
+
"model.layers.26.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
270 |
+
"model.layers.26.mlp.fc1.weight": "model-00002-of-00002.safetensors",
|
271 |
+
"model.layers.26.mlp.fc2.weight": "model-00002-of-00002.safetensors",
|
272 |
+
"model.layers.26.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
|
273 |
+
"model.layers.26.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
274 |
+
"model.layers.27.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
|
275 |
+
"model.layers.27.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
|
276 |
+
"model.layers.27.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
|
277 |
+
"model.layers.27.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
|
278 |
+
"model.layers.27.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
|
279 |
+
"model.layers.27.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
|
280 |
+
"model.layers.27.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
|
281 |
+
"model.layers.27.attn.out_proj.bias": "model-00002-of-00002.safetensors",
|
282 |
+
"model.layers.27.attn.out_proj.weight": "model-00002-of-00002.safetensors",
|
283 |
+
"model.layers.27.input_layernorm.bias": "model-00002-of-00002.safetensors",
|
284 |
+
"model.layers.27.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
285 |
+
"model.layers.27.mlp.fc1.weight": "model-00002-of-00002.safetensors",
|
286 |
+
"model.layers.27.mlp.fc2.weight": "model-00002-of-00002.safetensors",
|
287 |
+
"model.layers.27.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
|
288 |
+
"model.layers.27.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
289 |
+
"model.layers.28.attn.in_proj.weight": "model-00002-of-00002.safetensors",
|
290 |
+
"model.layers.28.attn.out_proj.weight": "model-00002-of-00002.safetensors",
|
291 |
+
"model.layers.28.input_layernorm.bias": "model-00002-of-00002.safetensors",
|
292 |
+
"model.layers.28.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
293 |
+
"model.layers.28.mlp.fc1.weight": "model-00002-of-00002.safetensors",
|
294 |
+
"model.layers.28.mlp.fc2.weight": "model-00002-of-00002.safetensors",
|
295 |
+
"model.layers.28.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
|
296 |
+
"model.layers.28.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
297 |
+
"model.layers.29.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
|
298 |
+
"model.layers.29.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
|
299 |
+
"model.layers.29.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
|
300 |
+
"model.layers.29.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
|
301 |
+
"model.layers.29.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
|
302 |
+
"model.layers.29.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
|
303 |
+
"model.layers.29.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
|
304 |
+
"model.layers.29.attn.out_proj.bias": "model-00002-of-00002.safetensors",
|
305 |
+
"model.layers.29.attn.out_proj.weight": "model-00002-of-00002.safetensors",
|
306 |
+
"model.layers.29.input_layernorm.bias": "model-00002-of-00002.safetensors",
|
307 |
+
"model.layers.29.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
308 |
+
"model.layers.29.mlp.fc1.weight": "model-00002-of-00002.safetensors",
|
309 |
+
"model.layers.29.mlp.fc2.weight": "model-00002-of-00002.safetensors",
|
310 |
+
"model.layers.29.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
|
311 |
+
"model.layers.29.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
312 |
+
"model.layers.3.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
|
313 |
+
"model.layers.3.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
|
314 |
+
"model.layers.3.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
|
315 |
+
"model.layers.3.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
|
316 |
+
"model.layers.3.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
|
317 |
+
"model.layers.3.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
|
318 |
+
"model.layers.3.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
|
319 |
+
"model.layers.3.attn.out_proj.bias": "model-00001-of-00002.safetensors",
|
320 |
+
"model.layers.3.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
321 |
+
"model.layers.3.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
322 |
+
"model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
323 |
+
"model.layers.3.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
324 |
+
"model.layers.3.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
325 |
+
"model.layers.3.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
326 |
+
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
327 |
+
"model.layers.30.attn.in_proj.weight": "model-00002-of-00002.safetensors",
|
328 |
+
"model.layers.30.attn.out_proj.weight": "model-00002-of-00002.safetensors",
|
329 |
+
"model.layers.30.input_layernorm.bias": "model-00002-of-00002.safetensors",
|
330 |
+
"model.layers.30.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
331 |
+
"model.layers.30.mlp.fc1.weight": "model-00002-of-00002.safetensors",
|
332 |
+
"model.layers.30.mlp.fc2.weight": "model-00002-of-00002.safetensors",
|
333 |
+
"model.layers.30.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
|
334 |
+
"model.layers.30.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
335 |
+
"model.layers.31.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
|
336 |
+
"model.layers.31.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
|
337 |
+
"model.layers.31.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
|
338 |
+
"model.layers.31.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
|
339 |
+
"model.layers.31.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
|
340 |
+
"model.layers.31.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
|
341 |
+
"model.layers.31.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
|
342 |
+
"model.layers.31.attn.out_proj.bias": "model-00002-of-00002.safetensors",
|
343 |
+
"model.layers.31.attn.out_proj.weight": "model-00002-of-00002.safetensors",
|
344 |
+
"model.layers.31.input_layernorm.bias": "model-00002-of-00002.safetensors",
|
345 |
+
"model.layers.31.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
346 |
+
"model.layers.31.mlp.fc1.weight": "model-00002-of-00002.safetensors",
|
347 |
+
"model.layers.31.mlp.fc2.weight": "model-00002-of-00002.safetensors",
|
348 |
+
"model.layers.31.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
|
349 |
+
"model.layers.31.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
350 |
+
"model.layers.4.attn.A_log": "model-00001-of-00002.safetensors",
|
351 |
+
"model.layers.4.attn.D": "model-00001-of-00002.safetensors",
|
352 |
+
"model.layers.4.attn.conv1d.bias": "model-00001-of-00002.safetensors",
|
353 |
+
"model.layers.4.attn.conv1d.weight": "model-00001-of-00002.safetensors",
|
354 |
+
"model.layers.4.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
|
355 |
+
"model.layers.4.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
|
356 |
+
"model.layers.4.attn.in_proj.weight": "model-00001-of-00002.safetensors",
|
357 |
+
"model.layers.4.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
358 |
+
"model.layers.4.attn.x_proj.weight": "model-00001-of-00002.safetensors",
|
359 |
+
"model.layers.4.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
360 |
+
"model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
361 |
+
"model.layers.4.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
362 |
+
"model.layers.4.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
363 |
+
"model.layers.4.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
364 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
365 |
+
"model.layers.5.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
|
366 |
+
"model.layers.5.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
|
367 |
+
"model.layers.5.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
|
368 |
+
"model.layers.5.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
|
369 |
+
"model.layers.5.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
|
370 |
+
"model.layers.5.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
|
371 |
+
"model.layers.5.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
|
372 |
+
"model.layers.5.attn.out_proj.bias": "model-00001-of-00002.safetensors",
|
373 |
+
"model.layers.5.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
374 |
+
"model.layers.5.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
375 |
+
"model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
376 |
+
"model.layers.5.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
377 |
+
"model.layers.5.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
378 |
+
"model.layers.5.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
379 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
380 |
+
"model.layers.6.attn.A_log": "model-00001-of-00002.safetensors",
|
381 |
+
"model.layers.6.attn.D": "model-00001-of-00002.safetensors",
|
382 |
+
"model.layers.6.attn.conv1d.bias": "model-00001-of-00002.safetensors",
|
383 |
+
"model.layers.6.attn.conv1d.weight": "model-00001-of-00002.safetensors",
|
384 |
+
"model.layers.6.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
|
385 |
+
"model.layers.6.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
|
386 |
+
"model.layers.6.attn.in_proj.weight": "model-00001-of-00002.safetensors",
|
387 |
+
"model.layers.6.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
388 |
+
"model.layers.6.attn.x_proj.weight": "model-00001-of-00002.safetensors",
|
389 |
+
"model.layers.6.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
390 |
+
"model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
391 |
+
"model.layers.6.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
392 |
+
"model.layers.6.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
393 |
+
"model.layers.6.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
394 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
395 |
+
"model.layers.7.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
|
396 |
+
"model.layers.7.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
|
397 |
+
"model.layers.7.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
|
398 |
+
"model.layers.7.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
|
399 |
+
"model.layers.7.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
|
400 |
+
"model.layers.7.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
|
401 |
+
"model.layers.7.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
|
402 |
+
"model.layers.7.attn.out_proj.bias": "model-00001-of-00002.safetensors",
|
403 |
+
"model.layers.7.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
404 |
+
"model.layers.7.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
405 |
+
"model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
406 |
+
"model.layers.7.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
407 |
+
"model.layers.7.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
408 |
+
"model.layers.7.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
409 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
410 |
+
"model.layers.8.attn.A_log": "model-00001-of-00002.safetensors",
|
411 |
+
"model.layers.8.attn.D": "model-00001-of-00002.safetensors",
|
412 |
+
"model.layers.8.attn.conv1d.bias": "model-00001-of-00002.safetensors",
|
413 |
+
"model.layers.8.attn.conv1d.weight": "model-00001-of-00002.safetensors",
|
414 |
+
"model.layers.8.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
|
415 |
+
"model.layers.8.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
|
416 |
+
"model.layers.8.attn.in_proj.weight": "model-00001-of-00002.safetensors",
|
417 |
+
"model.layers.8.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
418 |
+
"model.layers.8.attn.x_proj.weight": "model-00001-of-00002.safetensors",
|
419 |
+
"model.layers.8.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
420 |
+
"model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
421 |
+
"model.layers.8.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
422 |
+
"model.layers.8.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
423 |
+
"model.layers.8.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
424 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
425 |
+
"model.layers.9.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
|
426 |
+
"model.layers.9.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
|
427 |
+
"model.layers.9.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
|
428 |
+
"model.layers.9.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
|
429 |
+
"model.layers.9.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
|
430 |
+
"model.layers.9.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
|
431 |
+
"model.layers.9.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
|
432 |
+
"model.layers.9.attn.out_proj.bias": "model-00001-of-00002.safetensors",
|
433 |
+
"model.layers.9.attn.out_proj.weight": "model-00001-of-00002.safetensors",
|
434 |
+
"model.layers.9.input_layernorm.bias": "model-00001-of-00002.safetensors",
|
435 |
+
"model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
436 |
+
"model.layers.9.mlp.fc1.weight": "model-00001-of-00002.safetensors",
|
437 |
+
"model.layers.9.mlp.fc2.weight": "model-00001-of-00002.safetensors",
|
438 |
+
"model.layers.9.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
|
439 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors"
|
440 |
+
}
|
441 |
+
}
|
modeling_phi4flash.py
ADDED
@@ -0,0 +1,2098 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
""" PyTorch Phi4Flash model."""
|
17 |
+
|
18 |
+
|
19 |
+
import inspect
|
20 |
+
import math
|
21 |
+
import warnings
|
22 |
+
from typing import List, Optional, Tuple, Union, Dict, Any
|
23 |
+
import copy
|
24 |
+
import torch
|
25 |
+
import torch.nn.functional as F
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from torch import nn
|
28 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
29 |
+
from transformers.activations import ACT2FN
|
30 |
+
from transformers.cache_utils import Cache, DynamicCache
|
31 |
+
from transformers.utils import is_torchdynamo_compiling
|
32 |
+
from transformers.modeling_outputs import (
|
33 |
+
BaseModelOutputWithPast,
|
34 |
+
CausalLMOutputWithPast,
|
35 |
+
SequenceClassifierOutputWithPast,
|
36 |
+
TokenClassifierOutput,
|
37 |
+
)
|
38 |
+
from transformers.modeling_utils import PreTrainedModel
|
39 |
+
from transformers.generation import GenerationMixin
|
40 |
+
from transformers.utils import (
|
41 |
+
add_code_sample_docstrings,
|
42 |
+
add_start_docstrings,
|
43 |
+
add_start_docstrings_to_model_forward,
|
44 |
+
is_flash_attn_greater_or_equal_2_10,
|
45 |
+
logging,
|
46 |
+
replace_return_docstrings,
|
47 |
+
)
|
48 |
+
from einops import rearrange, repeat
|
49 |
+
|
50 |
+
from .configuration_phi4flash import Phi4FlashConfig
|
51 |
+
|
52 |
+
logger = logging.get_logger(__name__)
|
53 |
+
|
54 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
55 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
56 |
+
|
57 |
+
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
58 |
+
|
59 |
+
if not _flash_supports_window_size:
|
60 |
+
raise ValueError("Please update flash-attention to support window size.")
|
61 |
+
|
62 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
63 |
+
import causal_conv1d_cuda
|
64 |
+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
65 |
+
|
66 |
+
from torch.amp import custom_bwd, custom_fwd
|
67 |
+
import selective_scan_cuda
|
68 |
+
|
69 |
+
_CHECKPOINT_FOR_DOC = "microsoft/Phi-4-mini-flash-reasoning"
|
70 |
+
_CONFIG_FOR_DOC = "Phi4FlashConfig"
|
71 |
+
|
72 |
+
# monkey patch to add support for our cache
|
73 |
+
def _prepare_cache_for_generation(
|
74 |
+
self,
|
75 |
+
generation_config,
|
76 |
+
model_kwargs: Dict,
|
77 |
+
assistant_model: "PreTrainedModel",
|
78 |
+
batch_size: int,
|
79 |
+
max_cache_length: int,
|
80 |
+
device: torch.device,
|
81 |
+
) -> bool:
|
82 |
+
"""
|
83 |
+
Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
|
84 |
+
instantiated, writes it to `model_kwargs`, under the name expected by the model.
|
85 |
+
"""
|
86 |
+
|
87 |
+
cache_name = "past_key_values"
|
88 |
+
|
89 |
+
# Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in
|
90 |
+
# `generation_config.validate()`)
|
91 |
+
if generation_config.use_cache is False:
|
92 |
+
return
|
93 |
+
|
94 |
+
# Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation`
|
95 |
+
|
96 |
+
# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
|
97 |
+
# which is only supported in dynamic caches atm
|
98 |
+
if assistant_model is not None:
|
99 |
+
logger.warning_once(
|
100 |
+
"An assistant model is provided, using a dynamic cache instead of a cache of type="
|
101 |
+
f"'{generation_config.cache_implementation}'."
|
102 |
+
)
|
103 |
+
model_kwargs[cache_name] = DynamicCache()
|
104 |
+
return
|
105 |
+
|
106 |
+
model_kwargs[cache_name] = self._get_cache(
|
107 |
+
cache_implementation="sambay",
|
108 |
+
batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
|
109 |
+
max_cache_len=max_cache_length,
|
110 |
+
device=device,
|
111 |
+
model_kwargs=model_kwargs,
|
112 |
+
)
|
113 |
+
|
114 |
+
def _get_cache(
|
115 |
+
self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
|
116 |
+
) -> Cache:
|
117 |
+
"""
|
118 |
+
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
|
119 |
+
new `generate` call requires a larger cache or uses a different batch size.
|
120 |
+
|
121 |
+
Returns the resulting cache object.
|
122 |
+
"""
|
123 |
+
cache_cls: Cache = SambaYCache
|
124 |
+
requires_cross_attention_cache = (
|
125 |
+
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
126 |
+
)
|
127 |
+
|
128 |
+
if hasattr(self, "_cache"):
|
129 |
+
cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache
|
130 |
+
|
131 |
+
if cache_implementation == "sliding_window":
|
132 |
+
max_cache_len = min(self.config.sliding_window[1], max_cache_len)
|
133 |
+
|
134 |
+
need_new_cache = (
|
135 |
+
not hasattr(self, "_cache")
|
136 |
+
or (not isinstance(cache_to_check, cache_cls))
|
137 |
+
or cache_to_check.batch_size != batch_size
|
138 |
+
)
|
139 |
+
if cache_implementation != "mamba":
|
140 |
+
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
|
141 |
+
|
142 |
+
if requires_cross_attention_cache and hasattr(self, "_cache"):
|
143 |
+
need_new_cache = (
|
144 |
+
need_new_cache
|
145 |
+
or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
|
146 |
+
)
|
147 |
+
|
148 |
+
if need_new_cache:
|
149 |
+
if hasattr(self.config, "_pre_quantization_dtype"):
|
150 |
+
cache_dtype = self.config._pre_quantization_dtype
|
151 |
+
else:
|
152 |
+
if not is_torchdynamo_compiling():
|
153 |
+
cache_dtype = self.dtype
|
154 |
+
else:
|
155 |
+
# NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`.
|
156 |
+
# Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative
|
157 |
+
# models. May cause trobles with non-text modalities.
|
158 |
+
cache_dtype = self.get_output_embeddings().weight.dtype
|
159 |
+
|
160 |
+
def get_layer_device_map(execution_device_map: Optional[dict] = None):
|
161 |
+
if execution_device_map is None:
|
162 |
+
return None
|
163 |
+
elif len(execution_device_map) == 1 and "" in execution_device_map:
|
164 |
+
return {idx: execution_device_map[""] for idx in range(self.config.num_hidden_layers)}
|
165 |
+
layer_device_map = {}
|
166 |
+
for layer in execution_device_map:
|
167 |
+
for idx in range(self.config.num_hidden_layers):
|
168 |
+
if f".{idx}." in f"{layer}.":
|
169 |
+
layer_device_map[idx] = execution_device_map[layer]
|
170 |
+
break
|
171 |
+
for idx in range(self.config.num_hidden_layers):
|
172 |
+
if idx not in layer_device_map:
|
173 |
+
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
|
174 |
+
return layer_device_map
|
175 |
+
|
176 |
+
execution_device_map = None
|
177 |
+
# Taken from dispatch_model from accelerate.
|
178 |
+
# This is needed here if we don't want to make changes in accelerate in order to save execution_device
|
179 |
+
# For offloaded case, we need to get the execution device, not just the device where it is offloaded
|
180 |
+
if hasattr(self, "hf_device_map"):
|
181 |
+
main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
|
182 |
+
execution_device_map = {
|
183 |
+
name: main_device if device in ["cpu", "disk"] else device
|
184 |
+
for name, device in self.hf_device_map.items()
|
185 |
+
}
|
186 |
+
layer_device_map = get_layer_device_map(execution_device_map)
|
187 |
+
|
188 |
+
cache_kwargs = {
|
189 |
+
"config": self.config.get_text_config(),
|
190 |
+
"batch_size": batch_size,
|
191 |
+
"max_cache_len": max_cache_len,
|
192 |
+
"device": device,
|
193 |
+
"dtype": cache_dtype,
|
194 |
+
"layer_device_map": layer_device_map,
|
195 |
+
}
|
196 |
+
self._cache = cache_cls(**cache_kwargs)
|
197 |
+
else:
|
198 |
+
self._cache.reset()
|
199 |
+
return self._cache
|
200 |
+
|
201 |
+
GenerationMixin._prepare_cache_for_generation = _prepare_cache_for_generation
|
202 |
+
GenerationMixin._get_cache = _get_cache
|
203 |
+
|
204 |
+
class SambaYCache(Cache):
|
205 |
+
"""
|
206 |
+
A dynamic cache that can handle the sliding window attention cache, one layer of full attention cache and the mamba cache
|
207 |
+
(which has a constant shape regardless of seq_len).
|
208 |
+
|
209 |
+
"""
|
210 |
+
|
211 |
+
def __init__(self,
|
212 |
+
config: Phi4FlashConfig,
|
213 |
+
batch_size: int = None,
|
214 |
+
max_cache_len: int = None,
|
215 |
+
device: Union[torch.device, str] = "cuda",
|
216 |
+
dtype: torch.dtype = torch.float16,
|
217 |
+
max_batch_size: Optional[int] = None,
|
218 |
+
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
219 |
+
) -> None:
|
220 |
+
super().__init__()
|
221 |
+
self.dtype = dtype
|
222 |
+
self.has_previous_state = False # only used by mamba
|
223 |
+
intermediate_size = config.mamba_expand * config.hidden_size
|
224 |
+
ssm_state_size = config.mamba_d_state
|
225 |
+
conv_kernel_size = config.mamba_d_conv
|
226 |
+
self.conv_kernel_size = conv_kernel_size
|
227 |
+
|
228 |
+
if batch_size is not None:
|
229 |
+
logger.warning_once(
|
230 |
+
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
|
231 |
+
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
|
232 |
+
)
|
233 |
+
|
234 |
+
self.max_cache_len = max_cache_len
|
235 |
+
self.max_batch_size = batch_size or max_batch_size
|
236 |
+
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
237 |
+
self.head_dim = config.hidden_size // config.num_attention_heads
|
238 |
+
self.num_key_value_heads = config.num_key_value_heads
|
239 |
+
self.global_attn_idx = config.num_hidden_layers//2 + 1
|
240 |
+
self.key_cache: List[torch.Tensor] = []
|
241 |
+
self.value_cache: List[torch.Tensor] = []
|
242 |
+
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
|
243 |
+
sliding_cache_shape = (
|
244 |
+
self.max_batch_size,
|
245 |
+
self.num_key_value_heads,
|
246 |
+
min(config.sliding_window[1], max_cache_len),
|
247 |
+
self.head_dim,
|
248 |
+
)
|
249 |
+
conv_cache_shape = (self.max_batch_size, intermediate_size, conv_kernel_size)
|
250 |
+
ssm_cache_shape = (self.max_batch_size, intermediate_size, ssm_state_size)
|
251 |
+
for i in range(config.num_hidden_layers//2 + 2):
|
252 |
+
if layer_device_map is not None:
|
253 |
+
layer_device = layer_device_map[i]
|
254 |
+
else:
|
255 |
+
layer_device = device
|
256 |
+
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
|
257 |
+
# breaks when updating the cache.
|
258 |
+
if i == self.global_attn_idx:
|
259 |
+
key_cache_shape = value_cache_shape = global_cache_shape
|
260 |
+
elif i % 2 == 0:
|
261 |
+
key_cache_shape = conv_cache_shape
|
262 |
+
value_cache_shape = ssm_cache_shape
|
263 |
+
else:
|
264 |
+
key_cache_shape = value_cache_shape = sliding_cache_shape
|
265 |
+
new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=layer_device)
|
266 |
+
new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=layer_device)
|
267 |
+
torch._dynamo.mark_static_address(new_layer_key_cache)
|
268 |
+
torch._dynamo.mark_static_address(new_layer_value_cache)
|
269 |
+
self.key_cache.append(new_layer_key_cache)
|
270 |
+
self.value_cache.append(new_layer_value_cache)
|
271 |
+
|
272 |
+
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
|
273 |
+
if cache_position.shape[0] > max_cache_len:
|
274 |
+
k_out = key_states[:, :, -max_cache_len:, :]
|
275 |
+
v_out = value_states[:, :, -max_cache_len:, :]
|
276 |
+
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
|
277 |
+
self.key_cache[layer_idx] += k_out
|
278 |
+
self.value_cache[layer_idx] += v_out
|
279 |
+
# we should return the whole states instead of k_out, v_out to take the whole prompt
|
280 |
+
# into consideration when building kv cache instead of just throwing away tokens outside of the window
|
281 |
+
return key_states, value_states
|
282 |
+
|
283 |
+
slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
|
284 |
+
cache_position = cache_position.clamp(0, max_cache_len - 1)
|
285 |
+
to_shift = cache_position >= max_cache_len - 1
|
286 |
+
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
|
287 |
+
k_out = k_out[:, :, indices]
|
288 |
+
v_out = v_out[:, :, indices]
|
289 |
+
|
290 |
+
k_out[:, :, cache_position] = key_states
|
291 |
+
v_out[:, :, cache_position] = value_states
|
292 |
+
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
|
293 |
+
self.key_cache[layer_idx].zero_()
|
294 |
+
self.value_cache[layer_idx].zero_()
|
295 |
+
|
296 |
+
self.key_cache[layer_idx] += k_out
|
297 |
+
self.value_cache[layer_idx] += v_out
|
298 |
+
return k_out, v_out
|
299 |
+
|
300 |
+
def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
|
301 |
+
k_out[:, :, cache_position] = key_states
|
302 |
+
v_out[:, :, cache_position] = value_states
|
303 |
+
|
304 |
+
self.key_cache[layer_idx] = k_out
|
305 |
+
self.value_cache[layer_idx] = v_out
|
306 |
+
return k_out, v_out
|
307 |
+
|
308 |
+
def update(
|
309 |
+
self,
|
310 |
+
key_states: torch.Tensor,
|
311 |
+
value_states: torch.Tensor,
|
312 |
+
layer_idx: int,
|
313 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
314 |
+
) -> Tuple[torch.Tensor]:
|
315 |
+
cache_position = cache_kwargs.get("cache_position")
|
316 |
+
k_out = self.key_cache[layer_idx]
|
317 |
+
v_out = self.value_cache[layer_idx]
|
318 |
+
if layer_idx == self.global_attn_idx:
|
319 |
+
update_fn = self._static_update
|
320 |
+
elif layer_idx % 2 == 1:
|
321 |
+
update_fn = self._sliding_update
|
322 |
+
|
323 |
+
return update_fn(
|
324 |
+
cache_position,
|
325 |
+
layer_idx,
|
326 |
+
key_states,
|
327 |
+
value_states,
|
328 |
+
k_out,
|
329 |
+
v_out,
|
330 |
+
k_out.shape[2],
|
331 |
+
)
|
332 |
+
|
333 |
+
def get_max_cache_shape(self) -> Optional[int]:
|
334 |
+
return self.max_cache_len
|
335 |
+
|
336 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0):
|
337 |
+
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
|
338 |
+
# limit the check to the first batch member and head dimension.
|
339 |
+
# TODO: deprecate this function in favor of `cache_position`
|
340 |
+
return (self.key_cache[self.global_attn_idx][0, 0].any(dim=-1)).sum()
|
341 |
+
|
342 |
+
def reset(self):
|
343 |
+
"""Resets the cache values while preserving the objects"""
|
344 |
+
for layer_idx in range(len(self.key_cache)):
|
345 |
+
# In-place ops prevent breaking the static address
|
346 |
+
self.key_cache[layer_idx].zero_()
|
347 |
+
self.value_cache[layer_idx].zero_()
|
348 |
+
|
349 |
+
@property
|
350 |
+
def batch_size(self):
|
351 |
+
logger.warning_once(
|
352 |
+
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
|
353 |
+
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
|
354 |
+
)
|
355 |
+
return self.max_batch_size
|
356 |
+
|
357 |
+
|
358 |
+
|
359 |
+
|
360 |
+
swiglu_fwd_codestring = """
|
361 |
+
template <typename T> T swiglu_fwd(T x, T y) {
|
362 |
+
return float(x) * float(y) / (1.0f + ::exp(-float(x)));
|
363 |
+
}
|
364 |
+
"""
|
365 |
+
swiglu_bwd_codestring = """
|
366 |
+
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
|
367 |
+
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
|
368 |
+
dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
|
369 |
+
dy = float(x) * x_sigmoid * float(g);
|
370 |
+
}
|
371 |
+
"""
|
372 |
+
swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
|
373 |
+
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
|
374 |
+
|
375 |
+
|
376 |
+
class SwiGLUFunction(torch.autograd.Function):
|
377 |
+
|
378 |
+
@staticmethod
|
379 |
+
def forward(ctx, x, y):
|
380 |
+
ctx.save_for_backward(x, y)
|
381 |
+
return swiglu_fwd(x, y)
|
382 |
+
|
383 |
+
@staticmethod
|
384 |
+
def backward(ctx, dout):
|
385 |
+
x, y = ctx.saved_tensors
|
386 |
+
return swiglu_bwd(x, y, dout)
|
387 |
+
|
388 |
+
swiglu = SwiGLUFunction.apply
|
389 |
+
|
390 |
+
|
391 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->SambaY
|
392 |
+
class SambaYRMSNorm(nn.Module):
|
393 |
+
def __init__(self, hidden_size, eps=1e-5):
|
394 |
+
"""
|
395 |
+
SambaYRMSNorm is equivalent to T5LayerNorm
|
396 |
+
"""
|
397 |
+
super().__init__()
|
398 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
399 |
+
self.variance_epsilon = eps
|
400 |
+
|
401 |
+
def forward(self, hidden_states):
|
402 |
+
input_dtype = hidden_states.dtype
|
403 |
+
hidden_states = hidden_states.to(torch.float32)
|
404 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
405 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
406 |
+
return self.weight * hidden_states.to(input_dtype)
|
407 |
+
|
408 |
+
|
409 |
+
PHI_NORM_CLASS = nn.LayerNorm
|
410 |
+
|
411 |
+
|
412 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
413 |
+
def _get_unpad_data(attention_mask):
|
414 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
415 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
416 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
417 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
418 |
+
return (
|
419 |
+
indices,
|
420 |
+
cu_seqlens,
|
421 |
+
max_seqlen_in_batch,
|
422 |
+
)
|
423 |
+
|
424 |
+
|
425 |
+
class SambaYMLP(nn.Module):
|
426 |
+
"""Gated Linear Unit.
|
427 |
+
|
428 |
+
Reference:
|
429 |
+
Language Modeling with Gated Convolutional Networks.
|
430 |
+
https://arxiv.org/pdf/1612.08083v3.pdf.
|
431 |
+
|
432 |
+
"""
|
433 |
+
|
434 |
+
def __init__(self, config):
|
435 |
+
super().__init__()
|
436 |
+
|
437 |
+
self.config = config
|
438 |
+
self.fc1 = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
|
439 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
440 |
+
|
441 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
442 |
+
|
443 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
444 |
+
y = self.fc1(hidden_states)
|
445 |
+
|
446 |
+
# Special case for SwiGLU
|
447 |
+
if self.config.hidden_act == "silu" and swiglu is not None:
|
448 |
+
gate, y = y.chunk(2, dim=-1)
|
449 |
+
y = swiglu(gate, y)
|
450 |
+
else:
|
451 |
+
gate, y = y.chunk(2, dim=-1)
|
452 |
+
y = y * self.activation_fn(gate)
|
453 |
+
|
454 |
+
return self.fc2(y)
|
455 |
+
|
456 |
+
|
457 |
+
class SambaYAttention(nn.Module):
|
458 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
459 |
+
|
460 |
+
def __init__(self, config: Phi4FlashConfig, layer_idx: Optional[int] = None, yoco_cross: bool = False):
|
461 |
+
super().__init__()
|
462 |
+
self.config = config
|
463 |
+
self.layer_idx = layer_idx
|
464 |
+
if layer_idx is None:
|
465 |
+
logger.warning_once(
|
466 |
+
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
467 |
+
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
468 |
+
"when creating this class."
|
469 |
+
)
|
470 |
+
|
471 |
+
self.attention_dropout = config.attention_dropout
|
472 |
+
self.hidden_size = config.hidden_size
|
473 |
+
self.num_heads = config.num_attention_heads
|
474 |
+
self.head_dim = self.hidden_size // self.num_heads
|
475 |
+
self.num_key_value_heads = config.num_key_value_heads
|
476 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
477 |
+
self.max_position_embeddings = config.max_position_embeddings
|
478 |
+
self.is_causal = True
|
479 |
+
self.yoco_cross = yoco_cross
|
480 |
+
|
481 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
482 |
+
raise ValueError(
|
483 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
484 |
+
f" and `num_heads`: {self.num_heads})."
|
485 |
+
)
|
486 |
+
|
487 |
+
op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
|
488 |
+
self.out_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
|
489 |
+
if yoco_cross:
|
490 |
+
self.Wqkv = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
491 |
+
else:
|
492 |
+
self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
|
493 |
+
|
494 |
+
self.inner_cross_attn = FlashDiffCustomAttention(self.head_dim, self.layer_idx,)
|
495 |
+
|
496 |
+
|
497 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
498 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
499 |
+
|
500 |
+
def forward(
|
501 |
+
self,
|
502 |
+
hidden_states: torch.Tensor,
|
503 |
+
attention_mask: Optional[torch.Tensor] = None,
|
504 |
+
position_ids: Optional[torch.LongTensor] = None,
|
505 |
+
past_key_value: Optional[Cache] = None,
|
506 |
+
output_attentions: bool = False,
|
507 |
+
use_cache: bool = False,
|
508 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
509 |
+
raise NotImplementedError("SambaYAttention only support flash attention")
|
510 |
+
|
511 |
+
|
512 |
+
class SambaYFlashAttention2(SambaYAttention):
|
513 |
+
"""
|
514 |
+
SambaY flash attention module. This module inherits from `SambaYAttention` as the weights of the module stays
|
515 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
516 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
517 |
+
"""
|
518 |
+
|
519 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
520 |
+
def __init__(self, *args, **kwargs):
|
521 |
+
super().__init__(*args, **kwargs)
|
522 |
+
|
523 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
524 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
525 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
526 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
527 |
+
|
528 |
+
|
529 |
+
|
530 |
+
def forward(
|
531 |
+
self,
|
532 |
+
hidden_states: torch.Tensor,
|
533 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
534 |
+
position_ids: Optional[torch.LongTensor] = None,
|
535 |
+
past_key_value: Optional[Cache] = None,
|
536 |
+
output_attentions: bool = False,
|
537 |
+
use_cache: bool = False,
|
538 |
+
cache_position: Optional[torch.LongTensor] = None,
|
539 |
+
yoco_key_values: Optional[torch.Tensor] = None,
|
540 |
+
**kwargs,
|
541 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
542 |
+
# SambaYFlashAttention2 attention does not support output_attentions
|
543 |
+
|
544 |
+
output_attentions = False
|
545 |
+
if "padding_mask" in kwargs:
|
546 |
+
warnings.warn(
|
547 |
+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
548 |
+
)
|
549 |
+
|
550 |
+
# overwrite attention_mask with padding_mask
|
551 |
+
attention_mask = kwargs.pop("padding_mask")
|
552 |
+
|
553 |
+
bsz, q_len, _ = hidden_states.size()
|
554 |
+
if self.yoco_cross:
|
555 |
+
q = self.Wqkv(hidden_states)
|
556 |
+
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim).transpose(1,2)
|
557 |
+
key_states, value_states = yoco_key_values
|
558 |
+
query_states = q
|
559 |
+
|
560 |
+
use_sliding_windows = False
|
561 |
+
else:
|
562 |
+
|
563 |
+
qkv = self.Wqkv(hidden_states)
|
564 |
+
query_pos = self.num_heads * self.head_dim
|
565 |
+
query_states = qkv[..., :query_pos]
|
566 |
+
key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
|
567 |
+
value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
|
568 |
+
|
569 |
+
# Flash attention requires the input to have the shape
|
570 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
571 |
+
# therefore we just need to keep the original shape
|
572 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
573 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
574 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
575 |
+
|
576 |
+
use_sliding_windows = self.config.sliding_window is not None and self.config.sliding_window[self.layer_idx] is not None
|
577 |
+
|
578 |
+
if past_key_value is not None:
|
579 |
+
|
580 |
+
cache_kwargs = {"cache_position": cache_position}# Specific to RoPE models
|
581 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
582 |
+
|
583 |
+
|
584 |
+
yoco_key_values = key_states, value_states
|
585 |
+
|
586 |
+
attn_dropout = self.attention_dropout if self.training else 0.0
|
587 |
+
|
588 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
589 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
590 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
591 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
592 |
+
# in fp32.
|
593 |
+
|
594 |
+
if query_states.dtype == torch.float32:
|
595 |
+
if torch.is_autocast_enabled():
|
596 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
597 |
+
# Handle the case where the model is quantized
|
598 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
599 |
+
target_dtype = self.config._pre_quantization_dtype
|
600 |
+
else:
|
601 |
+
target_dtype = self.Wqkv.weight.dtype
|
602 |
+
|
603 |
+
logger.warning_once(
|
604 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
605 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
606 |
+
f" {target_dtype}."
|
607 |
+
)
|
608 |
+
|
609 |
+
query_states = query_states.to(target_dtype)
|
610 |
+
key_states = key_states.to(target_dtype)
|
611 |
+
value_states = value_states.to(target_dtype)
|
612 |
+
|
613 |
+
# Reashape to the expected shape for Flash Attention
|
614 |
+
# -> b,q,h,d
|
615 |
+
query_states = query_states.transpose(1, 2)
|
616 |
+
key_states = key_states.transpose(1, 2)
|
617 |
+
value_states = value_states.transpose(1, 2)
|
618 |
+
if attention_mask is not None:
|
619 |
+
key_states = key_states[:, :attention_mask.shape[-1]]
|
620 |
+
value_states = value_states[:, :attention_mask.shape[-1]]
|
621 |
+
attn_output = self._flash_attention_forward(
|
622 |
+
query_states,
|
623 |
+
key_states,
|
624 |
+
value_states,
|
625 |
+
attention_mask,
|
626 |
+
q_len,
|
627 |
+
dropout=attn_dropout,
|
628 |
+
use_sliding_windows=use_sliding_windows,
|
629 |
+
)
|
630 |
+
|
631 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
632 |
+
attn_output = self.out_proj(attn_output)
|
633 |
+
|
634 |
+
if not output_attentions:
|
635 |
+
attn_weights = None
|
636 |
+
|
637 |
+
return attn_output, attn_weights, yoco_key_values
|
638 |
+
|
639 |
+
def _flash_attention_forward(
|
640 |
+
self,
|
641 |
+
query_states,
|
642 |
+
key_states,
|
643 |
+
value_states,
|
644 |
+
attention_mask,
|
645 |
+
query_length,
|
646 |
+
dropout=0.0,
|
647 |
+
softmax_scale=None,
|
648 |
+
use_sliding_windows=False,
|
649 |
+
):
|
650 |
+
"""
|
651 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
652 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
653 |
+
|
654 |
+
Args:
|
655 |
+
query_states (`torch.Tensor`):
|
656 |
+
Input query states to be passed to Flash Attention API
|
657 |
+
key_states (`torch.Tensor`):
|
658 |
+
Input key states to be passed to Flash Attention API
|
659 |
+
value_states (`torch.Tensor`):
|
660 |
+
Input value states to be passed to Flash Attention API
|
661 |
+
attention_mask (`torch.Tensor`):
|
662 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
663 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
664 |
+
dropout (`float`):
|
665 |
+
Attention dropout
|
666 |
+
softmax_scale (`float`, *optional*):
|
667 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
668 |
+
use_sliding_windows (`bool`, *optional*):
|
669 |
+
Whether to activate sliding window attention.
|
670 |
+
"""
|
671 |
+
causal = self.is_causal
|
672 |
+
# Contains at least one padding token in the sequence
|
673 |
+
if attention_mask is not None:
|
674 |
+
batch_size = query_states.shape[0]
|
675 |
+
(
|
676 |
+
query_states,
|
677 |
+
key_states,
|
678 |
+
value_states,
|
679 |
+
indices_q,
|
680 |
+
cu_seq_lens,
|
681 |
+
max_seq_lens,
|
682 |
+
) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)
|
683 |
+
|
684 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
685 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
686 |
+
|
687 |
+
if not use_sliding_windows:
|
688 |
+
attn_output_unpad = self.inner_cross_attn(
|
689 |
+
query_states,
|
690 |
+
key_states,
|
691 |
+
value_states,
|
692 |
+
cu_seqlens_q=cu_seqlens_q,
|
693 |
+
cu_seqlens_k=cu_seqlens_k,
|
694 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
695 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
696 |
+
dropout_p=dropout,
|
697 |
+
softmax_scale=softmax_scale,
|
698 |
+
causal=causal,
|
699 |
+
)
|
700 |
+
else:
|
701 |
+
attn_output_unpad = self.inner_cross_attn(
|
702 |
+
query_states,
|
703 |
+
key_states,
|
704 |
+
value_states,
|
705 |
+
cu_seqlens_q=cu_seqlens_q,
|
706 |
+
cu_seqlens_k=cu_seqlens_k,
|
707 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
708 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
709 |
+
dropout_p=dropout,
|
710 |
+
softmax_scale=softmax_scale,
|
711 |
+
causal=causal,
|
712 |
+
window_size=(
|
713 |
+
self.config.sliding_window[self.layer_idx] -1,
|
714 |
+
self.config.sliding_window[self.layer_idx] -1,
|
715 |
+
),
|
716 |
+
)
|
717 |
+
|
718 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
719 |
+
else:
|
720 |
+
if not use_sliding_windows:
|
721 |
+
attn_output = self.inner_cross_attn(
|
722 |
+
query_states,
|
723 |
+
key_states,
|
724 |
+
value_states,
|
725 |
+
dropout_p=dropout,
|
726 |
+
softmax_scale=softmax_scale,
|
727 |
+
causal=causal,
|
728 |
+
)
|
729 |
+
else:
|
730 |
+
attn_output = self.inner_cross_attn(
|
731 |
+
query_states,
|
732 |
+
key_states,
|
733 |
+
value_states,
|
734 |
+
dropout_p=dropout,
|
735 |
+
softmax_scale=softmax_scale,
|
736 |
+
causal=causal,
|
737 |
+
window_size=(
|
738 |
+
self.config.sliding_window[self.layer_idx] -1,
|
739 |
+
self.config.sliding_window[self.layer_idx] -1,
|
740 |
+
),
|
741 |
+
)
|
742 |
+
|
743 |
+
return attn_output
|
744 |
+
|
745 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
746 |
+
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
|
747 |
+
|
748 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
749 |
+
|
750 |
+
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
|
751 |
+
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
|
752 |
+
|
753 |
+
if query_length == kv_seq_len:
|
754 |
+
query_layer = index_first_axis(
|
755 |
+
query_layer.reshape(batch_size * kv_seq_len, -1, head_dim),
|
756 |
+
indices_k,
|
757 |
+
)
|
758 |
+
cu_seqlens_q = cu_seqlens_k
|
759 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
760 |
+
indices_q = indices_k
|
761 |
+
elif query_length == 1:
|
762 |
+
max_seqlen_in_batch_q = 1
|
763 |
+
cu_seqlens_q = torch.arange(
|
764 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
765 |
+
) # There is a memcpy here, that is very bad.
|
766 |
+
indices_q = cu_seqlens_q[:-1]
|
767 |
+
query_layer = query_layer.squeeze(1)
|
768 |
+
else:
|
769 |
+
# The -q_len: slice assumes left padding.
|
770 |
+
attention_mask = attention_mask[:, -query_length:]
|
771 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
772 |
+
|
773 |
+
return (
|
774 |
+
query_layer,
|
775 |
+
key_layer,
|
776 |
+
value_layer,
|
777 |
+
indices_q,
|
778 |
+
(cu_seqlens_q, cu_seqlens_k),
|
779 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
780 |
+
)
|
781 |
+
|
782 |
+
|
783 |
+
|
784 |
+
class Phi3Mamba(nn.Module):
|
785 |
+
def __init__(
|
786 |
+
self,
|
787 |
+
d_model,
|
788 |
+
d_state=16,
|
789 |
+
d_conv=4,
|
790 |
+
expand=2,
|
791 |
+
dt_rank="auto",
|
792 |
+
conv_bias=True,
|
793 |
+
bias=False,
|
794 |
+
use_fast_path=True, # Fused kernel options
|
795 |
+
layer_idx=None,
|
796 |
+
yoco_cross=False,
|
797 |
+
yoco_kv=False,
|
798 |
+
dtype=None,
|
799 |
+
):
|
800 |
+
factory_kwargs = {"dtype": dtype}
|
801 |
+
super().__init__()
|
802 |
+
self.d_model = d_model
|
803 |
+
self.d_state = d_state
|
804 |
+
self.d_conv = d_conv
|
805 |
+
self.expand = expand
|
806 |
+
self.d_inner = int(self.expand * self.d_model)
|
807 |
+
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
808 |
+
self.use_fast_path = use_fast_path
|
809 |
+
self.layer_idx = layer_idx
|
810 |
+
|
811 |
+
self.yoco_cross = yoco_cross
|
812 |
+
self.yoco_kv = yoco_kv
|
813 |
+
if self.yoco_cross:
|
814 |
+
self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
|
815 |
+
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
816 |
+
else:
|
817 |
+
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
|
818 |
+
|
819 |
+
self.conv1d = nn.Conv1d(
|
820 |
+
in_channels=self.d_inner,
|
821 |
+
out_channels=self.d_inner,
|
822 |
+
bias=conv_bias,
|
823 |
+
kernel_size=d_conv,
|
824 |
+
groups=self.d_inner,
|
825 |
+
padding=d_conv - 1,
|
826 |
+
**factory_kwargs,
|
827 |
+
)
|
828 |
+
|
829 |
+
self.activation = "silu"
|
830 |
+
self.act = nn.SiLU()
|
831 |
+
|
832 |
+
self.x_proj = nn.Linear(
|
833 |
+
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
834 |
+
)
|
835 |
+
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
|
836 |
+
|
837 |
+
# S4D real initialization
|
838 |
+
A = repeat(
|
839 |
+
torch.arange(1, self.d_state + 1, dtype=torch.float32),
|
840 |
+
"n -> d n",
|
841 |
+
d=self.d_inner,
|
842 |
+
).contiguous()
|
843 |
+
A_log = torch.log(A) # Keep A_log in fp32
|
844 |
+
self.A_log = nn.Parameter(A_log)
|
845 |
+
|
846 |
+
# D "skip" parameter
|
847 |
+
self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32
|
848 |
+
|
849 |
+
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
850 |
+
|
851 |
+
def forward(self, hidden_states, inference_params=None, mask= None, yoco_key_values = None, cache_position = None):
|
852 |
+
"""
|
853 |
+
hidden_states: (B, L, D)
|
854 |
+
Returns: same shape as hidden_states
|
855 |
+
"""
|
856 |
+
|
857 |
+
if self.yoco_cross:
|
858 |
+
out = self.in_proj(hidden_states)
|
859 |
+
out = swiglu(out, yoco_key_values)
|
860 |
+
out = self.out_proj(out)
|
861 |
+
return out, yoco_key_values
|
862 |
+
|
863 |
+
batch, seqlen, _ = hidden_states.shape
|
864 |
+
conv_state, ssm_state = None, None
|
865 |
+
if inference_params is not None:
|
866 |
+
conv_state, ssm_state = self._get_states_from_cache(inference_params)
|
867 |
+
if cache_position[0] > 0: #inference_params.get_seq_length(self.layer_idx) > 0:
|
868 |
+
# The states are updated inplace
|
869 |
+
out, _, _, yoco_key_values = self.step(hidden_states, conv_state, ssm_state, yoco_key_values)
|
870 |
+
return out, yoco_key_values
|
871 |
+
|
872 |
+
# We do matmul and transpose BLH -> HBL at the same time
|
873 |
+
xz = rearrange(
|
874 |
+
self.in_proj.weight @ rearrange(hidden_states.to(dtype = self.in_proj.weight.dtype), "b l d -> d (b l)"),
|
875 |
+
"d (b l) -> b d l",
|
876 |
+
l=seqlen,
|
877 |
+
)
|
878 |
+
if self.in_proj.bias is not None:
|
879 |
+
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
|
880 |
+
|
881 |
+
|
882 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
883 |
+
# In the backward pass we write dx and dz next to each other to avoid torch.cat
|
884 |
+
if (not self.yoco_kv) and self.use_fast_path and inference_params is None: # Doesn't support outputting the states
|
885 |
+
out = mamba_inner_fn(
|
886 |
+
xz,
|
887 |
+
self.conv1d.weight,
|
888 |
+
self.conv1d.bias,
|
889 |
+
self.x_proj.weight,
|
890 |
+
self.dt_proj.weight,
|
891 |
+
self.out_proj.weight,
|
892 |
+
self.out_proj.bias,
|
893 |
+
A,
|
894 |
+
None, # input-dependent B
|
895 |
+
None, # input-dependent C
|
896 |
+
self.D.float(),
|
897 |
+
delta_bias=self.dt_proj.bias.float(),
|
898 |
+
mask=mask,
|
899 |
+
delta_softplus=True,
|
900 |
+
)
|
901 |
+
else:
|
902 |
+
x, z = xz.chunk(2, dim=1)
|
903 |
+
if self.yoco_kv:
|
904 |
+
z = z.transpose(-1,-2).contiguous()
|
905 |
+
if mask is not None:
|
906 |
+
x = x * mask.unsqueeze(1)
|
907 |
+
# Compute short convolution
|
908 |
+
if conv_state is not None:
|
909 |
+
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
910 |
+
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
911 |
+
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
|
912 |
+
if causal_conv1d_fn is None:
|
913 |
+
x = self.act(self.conv1d(x)[..., :seqlen])
|
914 |
+
else:
|
915 |
+
assert self.activation in ["silu", "swish"]
|
916 |
+
x = causal_conv1d_fn(
|
917 |
+
x=x,
|
918 |
+
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
919 |
+
bias=self.conv1d.bias,
|
920 |
+
activation=self.activation,
|
921 |
+
)
|
922 |
+
if mask is not None:
|
923 |
+
x = x * mask.unsqueeze(1)
|
924 |
+
# We're careful here about the layout, to avoid extra transposes.
|
925 |
+
# We want dt to have d as the slowest moving dimension
|
926 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
927 |
+
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
928 |
+
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
929 |
+
dt = self.dt_proj.weight @ dt.t()
|
930 |
+
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
931 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
932 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
933 |
+
assert self.activation in ["silu", "swish"]
|
934 |
+
y = selective_scan_fn(
|
935 |
+
x,
|
936 |
+
dt,
|
937 |
+
A,
|
938 |
+
B,
|
939 |
+
C,
|
940 |
+
self.D.float(),
|
941 |
+
z= None if self.yoco_kv else z,
|
942 |
+
delta_bias=self.dt_proj.bias.float(),
|
943 |
+
delta_softplus=True,
|
944 |
+
return_last_state=ssm_state is not None,
|
945 |
+
)
|
946 |
+
if ssm_state is not None:
|
947 |
+
y, last_state = y
|
948 |
+
ssm_state.copy_(last_state)
|
949 |
+
y = rearrange(y, "b d l -> b l d")
|
950 |
+
if self.yoco_kv:
|
951 |
+
yoco_key_values = y
|
952 |
+
y = swiglu(z, y)
|
953 |
+
out = self.out_proj(y)
|
954 |
+
return out, yoco_key_values
|
955 |
+
|
956 |
+
def step(self, hidden_states, conv_state, ssm_state, yoco_key_values):
|
957 |
+
dtype = hidden_states.dtype
|
958 |
+
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
959 |
+
xz = self.in_proj(hidden_states.to(dtype = self.in_proj.weight.dtype).squeeze(1)) # (B 2D)
|
960 |
+
x, z = xz.chunk(2, dim=-1) # (B D)
|
961 |
+
|
962 |
+
# Conv step
|
963 |
+
if causal_conv1d_update is None:
|
964 |
+
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
965 |
+
conv_state[:, :, -1] = x
|
966 |
+
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
967 |
+
if self.conv1d.bias is not None:
|
968 |
+
x = x + self.conv1d.bias
|
969 |
+
x = self.act(x).to(dtype=dtype)
|
970 |
+
else:
|
971 |
+
x = causal_conv1d_update(
|
972 |
+
x,
|
973 |
+
conv_state,
|
974 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
975 |
+
self.conv1d.bias,
|
976 |
+
self.activation,
|
977 |
+
)
|
978 |
+
|
979 |
+
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
|
980 |
+
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
981 |
+
# Don't add dt_bias here
|
982 |
+
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
|
983 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
984 |
+
|
985 |
+
# SSM step
|
986 |
+
if selective_state_update is None:
|
987 |
+
# Discretize A and B
|
988 |
+
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
|
989 |
+
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
|
990 |
+
dB = torch.einsum("bd,bn->bdn", dt, B)
|
991 |
+
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
|
992 |
+
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
|
993 |
+
y = y + self.D.to(dtype) * x
|
994 |
+
y = y * self.act(z) # (B D)
|
995 |
+
else:
|
996 |
+
y = selective_state_update(
|
997 |
+
ssm_state, x, dt, A, B, C, self.D, z= None if self.yoco_kv else z, dt_bias=self.dt_proj.bias, dt_softplus=True
|
998 |
+
)
|
999 |
+
if self.yoco_kv:
|
1000 |
+
yoco_key_values = y.unsqueeze(1)
|
1001 |
+
y = swiglu(z, y)
|
1002 |
+
out = self.out_proj(y)
|
1003 |
+
return out.unsqueeze(1), conv_state, ssm_state, yoco_key_values
|
1004 |
+
|
1005 |
+
def _get_states_from_cache(self, inference_params):
|
1006 |
+
conv_state, ssm_state = inference_params.key_cache[self.layer_idx], inference_params.value_cache[self.layer_idx]
|
1007 |
+
return conv_state, ssm_state
|
1008 |
+
|
1009 |
+
|
1010 |
+
|
1011 |
+
|
1012 |
+
class SambaYDecoderLayer(nn.Module):
|
1013 |
+
def __init__(self, config: Phi4FlashConfig, layer_idx: int):
|
1014 |
+
super().__init__()
|
1015 |
+
|
1016 |
+
self.mlp = SambaYMLP(config)
|
1017 |
+
self.input_layernorm = PHI_NORM_CLASS(config.hidden_size, eps=config.layer_norm_eps)
|
1018 |
+
|
1019 |
+
self.yoco_kv = False
|
1020 |
+
self.yoco_cross = False
|
1021 |
+
self.yoco_mb = False
|
1022 |
+
self.layer_idx = layer_idx
|
1023 |
+
assert config.num_hidden_layers % 4 == 0, 'n_layer should be divisible by 4 for SambaY '
|
1024 |
+
if layer_idx >= config.num_hidden_layers//2:
|
1025 |
+
self.yoco_mb = True
|
1026 |
+
self.yoco_kv = (layer_idx >= (config.num_hidden_layers//2 +1))
|
1027 |
+
self.yoco_cross = (layer_idx >= (config.num_hidden_layers//2 +2))
|
1028 |
+
if (layer_idx >= (config.num_hidden_layers//2 +1)):
|
1029 |
+
config = copy.deepcopy(config)
|
1030 |
+
config.sliding_window = None
|
1031 |
+
self.config= config
|
1032 |
+
|
1033 |
+
self.use_mamba = config.mb_per_layer > 0 and layer_idx % config.mb_per_layer == 0
|
1034 |
+
if self.use_mamba:
|
1035 |
+
factory_kwargs = {"d_conv": config.mamba_d_conv, "d_state": config.mamba_d_state, "expand": config.mamba_expand , "dtype": None}
|
1036 |
+
self.attn = Phi3Mamba(config.hidden_size, layer_idx=layer_idx, yoco_cross=self.yoco_cross, yoco_kv=self.yoco_mb, **factory_kwargs)
|
1037 |
+
else:
|
1038 |
+
self.attn = SambaYFlashAttention2(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross)
|
1039 |
+
|
1040 |
+
self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
|
1041 |
+
self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
|
1042 |
+
self.post_attention_layernorm = PHI_NORM_CLASS(config.hidden_size, eps=config.layer_norm_eps)
|
1043 |
+
|
1044 |
+
def forward(
|
1045 |
+
self,
|
1046 |
+
hidden_states: torch.Tensor,
|
1047 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1048 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1049 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
1050 |
+
output_attentions: Optional[bool] = False,
|
1051 |
+
use_cache: Optional[bool] = False,
|
1052 |
+
cache_position: Optional[torch.LongTensor] = None,
|
1053 |
+
ssm_output: Optional[torch.Tensor] = None,
|
1054 |
+
yoco_key_values: Optional[torch.Tensor] = None,
|
1055 |
+
**kwargs,
|
1056 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
1057 |
+
"""
|
1058 |
+
Args:
|
1059 |
+
hidden_states (`torch.FloatTensor`):
|
1060 |
+
input to the layer of shape `(batch, seq_len, embed_dim)`
|
1061 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
1062 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
1063 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
1064 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
|
1065 |
+
`[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
1066 |
+
output_attentions (`bool`, *optional*):
|
1067 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
1068 |
+
returned tensors for more detail.
|
1069 |
+
use_cache (`bool`, *optional*):
|
1070 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
1071 |
+
(see `past_key_values`).
|
1072 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
1073 |
+
"""
|
1074 |
+
|
1075 |
+
residual = hidden_states
|
1076 |
+
|
1077 |
+
hidden_states = self.input_layernorm(hidden_states.to(dtype=self.input_layernorm.weight.dtype))
|
1078 |
+
|
1079 |
+
if self.use_mamba:
|
1080 |
+
attn_outputs, ssm_output = self.attn(
|
1081 |
+
hidden_states, inference_params=past_key_value,
|
1082 |
+
mask = attention_mask, yoco_key_values = ssm_output,
|
1083 |
+
cache_position=cache_position,
|
1084 |
+
)
|
1085 |
+
residual = residual.to(torch.float32)
|
1086 |
+
self_attn_weights = None
|
1087 |
+
else:
|
1088 |
+
if self.config.sliding_window is not None and self.config.sliding_window[self.layer_idx] is not None and attention_mask is not None: # efficient SDPA and no padding
|
1089 |
+
if past_key_value is not None and cache_position[0] > 0: # when decoding
|
1090 |
+
attention_mask = attention_mask[:, -self.config.sliding_window[self.layer_idx]:]
|
1091 |
+
#hidden_states = self.input_layernorm2(hidden_states.to(dtype=self.input_layernorm2.weight.dtype))
|
1092 |
+
# Self Attention
|
1093 |
+
attn_outputs, self_attn_weights, yoco_key_values = self.attn(
|
1094 |
+
hidden_states=hidden_states,
|
1095 |
+
attention_mask=attention_mask,
|
1096 |
+
position_ids=position_ids,
|
1097 |
+
past_key_value=past_key_value,
|
1098 |
+
output_attentions=output_attentions,
|
1099 |
+
use_cache=use_cache,
|
1100 |
+
cache_position=cache_position,
|
1101 |
+
yoco_key_values = yoco_key_values,
|
1102 |
+
)
|
1103 |
+
|
1104 |
+
hidden_states = residual + self.resid_attn_dropout(attn_outputs)
|
1105 |
+
|
1106 |
+
residual = hidden_states
|
1107 |
+
hidden_states = self.post_attention_layernorm(hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype))
|
1108 |
+
hidden_states = self.mlp(hidden_states)
|
1109 |
+
hidden_states = residual + self.resid_mlp_dropout(hidden_states)
|
1110 |
+
|
1111 |
+
outputs = (hidden_states,)
|
1112 |
+
outputs += (ssm_output,)
|
1113 |
+
outputs += (yoco_key_values,)
|
1114 |
+
if output_attentions:
|
1115 |
+
outputs += (self_attn_weights,)
|
1116 |
+
|
1117 |
+
return outputs
|
1118 |
+
|
1119 |
+
|
1120 |
+
PHI_START_DOCSTRING = r"""
|
1121 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
1122 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
1123 |
+
etc.)
|
1124 |
+
|
1125 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
1126 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
1127 |
+
and behavior.
|
1128 |
+
|
1129 |
+
Parameters:
|
1130 |
+
config ([`Phi4FlashConfig`]):
|
1131 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
1132 |
+
load the weights associated with the model, only the configuration. Check out the
|
1133 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
1134 |
+
"""
|
1135 |
+
|
1136 |
+
|
1137 |
+
@add_start_docstrings(
|
1138 |
+
"The bare Phi4Flash Model outputting raw hidden-states without any specific head on top.",
|
1139 |
+
PHI_START_DOCSTRING,
|
1140 |
+
)
|
1141 |
+
class Phi4FlashPreTrainedModel(PreTrainedModel):
|
1142 |
+
config_class = Phi4FlashConfig
|
1143 |
+
base_model_prefix = "model"
|
1144 |
+
supports_gradient_checkpointing = True
|
1145 |
+
_no_split_modules = ["SambaYDecoderLayer"]
|
1146 |
+
_skip_keys_device_placement = "past_key_values"
|
1147 |
+
_supports_flash_attn_2 = True
|
1148 |
+
_supports_sdpa = False
|
1149 |
+
_supports_cache_class = True
|
1150 |
+
|
1151 |
+
def _init_weights(self, module):
|
1152 |
+
std = self.config.initializer_range
|
1153 |
+
if isinstance(module, nn.Linear):
|
1154 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
1155 |
+
if module.bias is not None:
|
1156 |
+
module.bias.data.zero_()
|
1157 |
+
elif isinstance(module, nn.Embedding):
|
1158 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
1159 |
+
if module.padding_idx is not None:
|
1160 |
+
module.weight.data[module.padding_idx].zero_()
|
1161 |
+
|
1162 |
+
|
1163 |
+
PHI_INPUTS_DOCSTRING = r"""
|
1164 |
+
Args:
|
1165 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
1166 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
1167 |
+
it.
|
1168 |
+
|
1169 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
1170 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
1171 |
+
|
1172 |
+
[What are input IDs?](../glossary#input-ids)
|
1173 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1174 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
1175 |
+
|
1176 |
+
- 1 for tokens that are **not masked**,
|
1177 |
+
- 0 for tokens that are **masked**.
|
1178 |
+
|
1179 |
+
[What are attention masks?](../glossary#attention-mask)
|
1180 |
+
|
1181 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
1182 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
1183 |
+
|
1184 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
1185 |
+
`past_key_values`).
|
1186 |
+
|
1187 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
1188 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
1189 |
+
information on the default strategy.
|
1190 |
+
|
1191 |
+
- 1 indicates the head is **not masked**,
|
1192 |
+
- 0 indicates the head is **masked**.
|
1193 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1194 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
1195 |
+
config.n_positions - 1]`.
|
1196 |
+
|
1197 |
+
[What are position IDs?](../glossary#position-ids)
|
1198 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
1199 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
1200 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
1201 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
1202 |
+
|
1203 |
+
Two formats are allowed:
|
1204 |
+
- a [`~cache_utils.Cache`] instance;
|
1205 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
1206 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
1207 |
+
cache format.
|
1208 |
+
|
1209 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
1210 |
+
legacy cache format will be returned.
|
1211 |
+
|
1212 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
1213 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
1214 |
+
of shape `(batch_size, sequence_length)`.
|
1215 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
1216 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
1217 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
1218 |
+
model's internal embedding lookup matrix.
|
1219 |
+
use_cache (`bool`, *optional*):
|
1220 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
1221 |
+
`past_key_values`).
|
1222 |
+
output_attentions (`bool`, *optional*):
|
1223 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
1224 |
+
tensors for more detail.
|
1225 |
+
output_hidden_states (`bool`, *optional*):
|
1226 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
1227 |
+
more detail.
|
1228 |
+
return_dict (`bool`, *optional*):
|
1229 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
1230 |
+
"""
|
1231 |
+
|
1232 |
+
|
1233 |
+
@add_start_docstrings(
|
1234 |
+
"The bare Phi4Flash Model outputting raw hidden-states without any specific head on top.",
|
1235 |
+
PHI_START_DOCSTRING,
|
1236 |
+
)
|
1237 |
+
class Phi4FlashModel(Phi4FlashPreTrainedModel):
|
1238 |
+
"""
|
1239 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SambaYDecoderLayer`]
|
1240 |
+
|
1241 |
+
Args:
|
1242 |
+
config: Phi4FlashConfig
|
1243 |
+
"""
|
1244 |
+
|
1245 |
+
def __init__(self, config: Phi4FlashConfig):
|
1246 |
+
super().__init__(config)
|
1247 |
+
self.padding_idx = config.pad_token_id
|
1248 |
+
self.vocab_size = config.vocab_size
|
1249 |
+
|
1250 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
1251 |
+
self.embed_dropout = nn.Dropout(config.embd_pdrop)
|
1252 |
+
self.layers = nn.ModuleList(
|
1253 |
+
[SambaYDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
1254 |
+
)
|
1255 |
+
self.final_layernorm = PHI_NORM_CLASS(config.hidden_size, eps=config.layer_norm_eps)
|
1256 |
+
|
1257 |
+
self._attn_implementation = config._attn_implementation
|
1258 |
+
|
1259 |
+
self.gradient_checkpointing = False
|
1260 |
+
# Initialize weights and apply final processing
|
1261 |
+
self.post_init()
|
1262 |
+
|
1263 |
+
def get_input_embeddings(self):
|
1264 |
+
return self.embed_tokens
|
1265 |
+
|
1266 |
+
def set_input_embeddings(self, value):
|
1267 |
+
self.embed_tokens = value
|
1268 |
+
|
1269 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
1270 |
+
def forward(
|
1271 |
+
self,
|
1272 |
+
input_ids: torch.LongTensor = None,
|
1273 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1274 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1275 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1276 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1277 |
+
use_cache: Optional[bool] = None,
|
1278 |
+
output_attentions: Optional[bool] = None,
|
1279 |
+
output_hidden_states: Optional[bool] = None,
|
1280 |
+
return_dict: Optional[bool] = None,
|
1281 |
+
cache_position: Optional[torch.LongTensor] = None,
|
1282 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
1283 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1284 |
+
output_hidden_states = (
|
1285 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1286 |
+
)
|
1287 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1288 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1289 |
+
|
1290 |
+
# retrieve input_ids and inputs_embeds
|
1291 |
+
if input_ids is not None and inputs_embeds is not None:
|
1292 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
1293 |
+
elif input_ids is not None:
|
1294 |
+
batch_size, seq_length = input_ids.shape[:2]
|
1295 |
+
elif inputs_embeds is not None:
|
1296 |
+
batch_size, seq_length = inputs_embeds.shape[:2]
|
1297 |
+
else:
|
1298 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
1299 |
+
|
1300 |
+
|
1301 |
+
if self.gradient_checkpointing and self.training:
|
1302 |
+
if use_cache:
|
1303 |
+
logger.warning_once(
|
1304 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
1305 |
+
)
|
1306 |
+
use_cache = False
|
1307 |
+
|
1308 |
+
if inputs_embeds is None:
|
1309 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
1310 |
+
|
1311 |
+
if use_cache and past_key_values is None and not self.training:
|
1312 |
+
batch_size, seq_len, _ = inputs_embeds.shape
|
1313 |
+
past_key_values = SambaYCache(
|
1314 |
+
self.config,
|
1315 |
+
max_batch_size=batch_size,
|
1316 |
+
max_cache_len=seq_len,
|
1317 |
+
device=self.device,
|
1318 |
+
dtype=inputs_embeds.dtype,
|
1319 |
+
)
|
1320 |
+
|
1321 |
+
|
1322 |
+
if cache_position is None:
|
1323 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
1324 |
+
cache_position = torch.arange(
|
1325 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
1326 |
+
)
|
1327 |
+
|
1328 |
+
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache and not self.training:
|
1329 |
+
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
1330 |
+
if is_padding_right:
|
1331 |
+
raise ValueError(
|
1332 |
+
"You are attempting to perform batched generation with padding_side='right'"
|
1333 |
+
" this may lead to unexpected behaviour for Flash Attention version of Phi4Flash. Make sure to "
|
1334 |
+
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
1335 |
+
)
|
1336 |
+
|
1337 |
+
hidden_states = inputs_embeds
|
1338 |
+
|
1339 |
+
# decoder layers
|
1340 |
+
all_hidden_states = () if output_hidden_states else None
|
1341 |
+
all_self_attns = () if output_attentions else None
|
1342 |
+
ssm_output = None
|
1343 |
+
yoco_key_values = None
|
1344 |
+
for decoder_layer in self.layers: # TODO: only need to inference the first half of the layers during pre-fill
|
1345 |
+
if output_hidden_states:
|
1346 |
+
all_hidden_states += (hidden_states,)
|
1347 |
+
|
1348 |
+
if self.gradient_checkpointing and self.training:
|
1349 |
+
layer_outputs = self._gradient_checkpointing_func(
|
1350 |
+
decoder_layer.__call__,
|
1351 |
+
hidden_states,
|
1352 |
+
attention_mask,
|
1353 |
+
position_ids,
|
1354 |
+
past_key_values,
|
1355 |
+
output_attentions,
|
1356 |
+
use_cache,
|
1357 |
+
cache_position,
|
1358 |
+
ssm_output,
|
1359 |
+
yoco_key_values,
|
1360 |
+
)
|
1361 |
+
else:
|
1362 |
+
layer_outputs = decoder_layer(
|
1363 |
+
hidden_states,
|
1364 |
+
attention_mask=attention_mask,
|
1365 |
+
position_ids=position_ids,
|
1366 |
+
past_key_value=past_key_values,
|
1367 |
+
output_attentions=output_attentions,
|
1368 |
+
use_cache=use_cache,
|
1369 |
+
cache_position = cache_position,
|
1370 |
+
ssm_output = ssm_output,
|
1371 |
+
yoco_key_values = yoco_key_values,
|
1372 |
+
)
|
1373 |
+
|
1374 |
+
hidden_states = layer_outputs[0]
|
1375 |
+
ssm_output = layer_outputs[1]
|
1376 |
+
yoco_key_values = layer_outputs[2]
|
1377 |
+
|
1378 |
+
if output_attentions:
|
1379 |
+
all_self_attns += (layer_outputs[3],)
|
1380 |
+
|
1381 |
+
hidden_states = self.final_layernorm(hidden_states.to(dtype=self.final_layernorm.weight.dtype))
|
1382 |
+
|
1383 |
+
# add hidden states from the last decoder layer
|
1384 |
+
if output_hidden_states:
|
1385 |
+
all_hidden_states += (hidden_states,)
|
1386 |
+
|
1387 |
+
output = BaseModelOutputWithPast(
|
1388 |
+
last_hidden_state=hidden_states,
|
1389 |
+
past_key_values=past_key_values,
|
1390 |
+
hidden_states=all_hidden_states,
|
1391 |
+
attentions=all_self_attns,
|
1392 |
+
)
|
1393 |
+
return output if return_dict else output.to_tuple()
|
1394 |
+
|
1395 |
+
|
1396 |
+
|
1397 |
+
class Phi4FlashForCausalLM(Phi4FlashPreTrainedModel, GenerationMixin):
|
1398 |
+
_tied_weights_keys = ["lm_head.weight"]
|
1399 |
+
|
1400 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi4Flash,bias=False->bias=True
|
1401 |
+
def __init__(self, config):
|
1402 |
+
super().__init__(config)
|
1403 |
+
self.model = Phi4FlashModel(config)
|
1404 |
+
self.vocab_size = config.vocab_size
|
1405 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
1406 |
+
|
1407 |
+
# Initialize weights and apply final processing
|
1408 |
+
self.post_init()
|
1409 |
+
|
1410 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
|
1411 |
+
def get_input_embeddings(self):
|
1412 |
+
return self.model.embed_tokens
|
1413 |
+
|
1414 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
|
1415 |
+
def set_input_embeddings(self, value):
|
1416 |
+
self.model.embed_tokens = value
|
1417 |
+
|
1418 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
|
1419 |
+
def get_output_embeddings(self):
|
1420 |
+
return self.lm_head
|
1421 |
+
|
1422 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
|
1423 |
+
def set_output_embeddings(self, new_embeddings):
|
1424 |
+
self.lm_head = new_embeddings
|
1425 |
+
|
1426 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
|
1427 |
+
def set_decoder(self, decoder):
|
1428 |
+
self.model = decoder
|
1429 |
+
|
1430 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
|
1431 |
+
def get_decoder(self):
|
1432 |
+
return self.model
|
1433 |
+
|
1434 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
1435 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1436 |
+
def forward(
|
1437 |
+
self,
|
1438 |
+
input_ids: torch.LongTensor = None,
|
1439 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1440 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1441 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1442 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1443 |
+
labels: Optional[torch.LongTensor] = None,
|
1444 |
+
use_cache: Optional[bool] = None,
|
1445 |
+
output_attentions: Optional[bool] = None,
|
1446 |
+
output_hidden_states: Optional[bool] = None,
|
1447 |
+
return_dict: Optional[bool] = None,
|
1448 |
+
cache_position: Optional[torch.LongTensor] = None,
|
1449 |
+
num_logits_to_keep: int = 0,
|
1450 |
+
**loss_kwargs,
|
1451 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
1452 |
+
r"""
|
1453 |
+
Args:
|
1454 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1455 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1456 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1457 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1458 |
+
|
1459 |
+
Returns:
|
1460 |
+
|
1461 |
+
Example:
|
1462 |
+
|
1463 |
+
```python
|
1464 |
+
>>> from transformers import AutoTokenizer, Phi4FlashForCausalLM
|
1465 |
+
|
1466 |
+
>>> model = Phi4FlashForCausalLM.from_pretrained("microsoft/Phi4-mini-flash-reasoning")
|
1467 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi4-mini-flash-reasoning")
|
1468 |
+
|
1469 |
+
>>> prompt = "This is an example script ."
|
1470 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
1471 |
+
|
1472 |
+
>>> # Generate
|
1473 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1474 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1475 |
+
'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
|
1476 |
+
```"""
|
1477 |
+
|
1478 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1479 |
+
output_hidden_states = (
|
1480 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1481 |
+
)
|
1482 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1483 |
+
|
1484 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1485 |
+
outputs = self.model(
|
1486 |
+
input_ids=input_ids,
|
1487 |
+
attention_mask=attention_mask,
|
1488 |
+
position_ids=position_ids,
|
1489 |
+
past_key_values=past_key_values,
|
1490 |
+
inputs_embeds=inputs_embeds,
|
1491 |
+
use_cache=use_cache,
|
1492 |
+
output_attentions=output_attentions,
|
1493 |
+
output_hidden_states=output_hidden_states,
|
1494 |
+
return_dict=return_dict,
|
1495 |
+
cache_position = cache_position,
|
1496 |
+
)
|
1497 |
+
|
1498 |
+
hidden_states = outputs[0]
|
1499 |
+
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
1500 |
+
|
1501 |
+
loss = None
|
1502 |
+
if labels is not None:
|
1503 |
+
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
1504 |
+
|
1505 |
+
if not return_dict:
|
1506 |
+
output = (logits,) + outputs[1:]
|
1507 |
+
return (loss,) + output if loss is not None else output
|
1508 |
+
|
1509 |
+
return CausalLMOutputWithPast(
|
1510 |
+
loss=loss,
|
1511 |
+
logits=logits,
|
1512 |
+
past_key_values=outputs.past_key_values,
|
1513 |
+
hidden_states=outputs.hidden_states,
|
1514 |
+
attentions=outputs.attentions,
|
1515 |
+
)
|
1516 |
+
|
1517 |
+
|
1518 |
+
@add_start_docstrings(
|
1519 |
+
"""
|
1520 |
+
The Phi4FlashModel with a sequence classification head on top (linear layer).
|
1521 |
+
|
1522 |
+
[`Phi4FlashForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1523 |
+
(e.g. GPT-2) do.
|
1524 |
+
|
1525 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1526 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1527 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1528 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1529 |
+
each row of the batch).
|
1530 |
+
""",
|
1531 |
+
PHI_START_DOCSTRING,
|
1532 |
+
)
|
1533 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi4Flash with self.transformer->self.model, transformer_outputs->model_outputs
|
1534 |
+
class Phi4FlashForSequenceClassification(Phi4FlashPreTrainedModel):
|
1535 |
+
def __init__(self, config):
|
1536 |
+
super().__init__(config)
|
1537 |
+
self.num_labels = config.num_labels
|
1538 |
+
self.model = Phi4FlashModel(config)
|
1539 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1540 |
+
|
1541 |
+
# Initialize weights and apply final processing
|
1542 |
+
self.post_init()
|
1543 |
+
|
1544 |
+
def get_input_embeddings(self):
|
1545 |
+
return self.model.embed_tokens
|
1546 |
+
|
1547 |
+
def set_input_embeddings(self, value):
|
1548 |
+
self.model.embed_tokens = value
|
1549 |
+
|
1550 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
1551 |
+
def forward(
|
1552 |
+
self,
|
1553 |
+
input_ids: torch.LongTensor = None,
|
1554 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1555 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1556 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1557 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1558 |
+
labels: Optional[torch.LongTensor] = None,
|
1559 |
+
use_cache: Optional[bool] = None,
|
1560 |
+
output_attentions: Optional[bool] = None,
|
1561 |
+
output_hidden_states: Optional[bool] = None,
|
1562 |
+
return_dict: Optional[bool] = None,
|
1563 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1564 |
+
r"""
|
1565 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1566 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1567 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1568 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1569 |
+
"""
|
1570 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1571 |
+
|
1572 |
+
model_outputs = self.model(
|
1573 |
+
input_ids,
|
1574 |
+
attention_mask=attention_mask,
|
1575 |
+
position_ids=position_ids,
|
1576 |
+
past_key_values=past_key_values,
|
1577 |
+
inputs_embeds=inputs_embeds,
|
1578 |
+
use_cache=use_cache,
|
1579 |
+
output_attentions=output_attentions,
|
1580 |
+
output_hidden_states=output_hidden_states,
|
1581 |
+
return_dict=return_dict,
|
1582 |
+
)
|
1583 |
+
hidden_states = model_outputs[0]
|
1584 |
+
logits = self.score(hidden_states)
|
1585 |
+
|
1586 |
+
if input_ids is not None:
|
1587 |
+
batch_size = input_ids.shape[0]
|
1588 |
+
else:
|
1589 |
+
batch_size = inputs_embeds.shape[0]
|
1590 |
+
|
1591 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
1592 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
1593 |
+
if self.config.pad_token_id is None:
|
1594 |
+
sequence_lengths = -1
|
1595 |
+
else:
|
1596 |
+
if input_ids is not None:
|
1597 |
+
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
1598 |
+
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
1599 |
+
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
1600 |
+
sequence_lengths = sequence_lengths.to(logits.device)
|
1601 |
+
else:
|
1602 |
+
sequence_lengths = -1
|
1603 |
+
|
1604 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
1605 |
+
|
1606 |
+
loss = None
|
1607 |
+
if labels is not None:
|
1608 |
+
labels = labels.to(logits.device)
|
1609 |
+
if self.config.problem_type is None:
|
1610 |
+
if self.num_labels == 1:
|
1611 |
+
self.config.problem_type = "regression"
|
1612 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1613 |
+
self.config.problem_type = "single_label_classification"
|
1614 |
+
else:
|
1615 |
+
self.config.problem_type = "multi_label_classification"
|
1616 |
+
|
1617 |
+
if self.config.problem_type == "regression":
|
1618 |
+
loss_fct = MSELoss()
|
1619 |
+
if self.num_labels == 1:
|
1620 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1621 |
+
else:
|
1622 |
+
loss = loss_fct(pooled_logits, labels)
|
1623 |
+
elif self.config.problem_type == "single_label_classification":
|
1624 |
+
loss_fct = CrossEntropyLoss()
|
1625 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
1626 |
+
elif self.config.problem_type == "multi_label_classification":
|
1627 |
+
loss_fct = BCEWithLogitsLoss()
|
1628 |
+
loss = loss_fct(pooled_logits, labels)
|
1629 |
+
if not return_dict:
|
1630 |
+
output = (pooled_logits,) + model_outputs[1:]
|
1631 |
+
return ((loss,) + output) if loss is not None else output
|
1632 |
+
|
1633 |
+
return SequenceClassifierOutputWithPast(
|
1634 |
+
loss=loss,
|
1635 |
+
logits=pooled_logits,
|
1636 |
+
past_key_values=model_outputs.past_key_values,
|
1637 |
+
hidden_states=model_outputs.hidden_states,
|
1638 |
+
attentions=model_outputs.attentions,
|
1639 |
+
)
|
1640 |
+
|
1641 |
+
|
1642 |
+
@add_start_docstrings(
|
1643 |
+
"""
|
1644 |
+
Phi4FlashModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
1645 |
+
Named-Entity-Recognition (NER) tasks.
|
1646 |
+
""",
|
1647 |
+
PHI_START_DOCSTRING,
|
1648 |
+
)
|
1649 |
+
# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi4Flash,self.transformer->self.model,transformer_outputs->model_outputs
|
1650 |
+
class Phi4FlashForTokenClassification(Phi4FlashPreTrainedModel):
|
1651 |
+
def __init__(self, config: Phi4FlashConfig):
|
1652 |
+
super().__init__(config)
|
1653 |
+
self.num_labels = config.num_labels
|
1654 |
+
|
1655 |
+
self.model = Phi4FlashModel(config)
|
1656 |
+
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
|
1657 |
+
classifier_dropout = config.classifier_dropout
|
1658 |
+
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
|
1659 |
+
classifier_dropout = config.hidden_dropout
|
1660 |
+
else:
|
1661 |
+
classifier_dropout = 0.1
|
1662 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
1663 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1664 |
+
|
1665 |
+
# Initialize weights and apply final processing
|
1666 |
+
self.post_init()
|
1667 |
+
|
1668 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
1669 |
+
@add_code_sample_docstrings(
|
1670 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1671 |
+
output_type=TokenClassifierOutput,
|
1672 |
+
config_class=_CONFIG_FOR_DOC,
|
1673 |
+
)
|
1674 |
+
def forward(
|
1675 |
+
self,
|
1676 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1677 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
1678 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1679 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1680 |
+
labels: Optional[torch.Tensor] = None,
|
1681 |
+
use_cache: Optional[bool] = None,
|
1682 |
+
output_attentions: Optional[bool] = None,
|
1683 |
+
output_hidden_states: Optional[bool] = None,
|
1684 |
+
return_dict: Optional[bool] = None,
|
1685 |
+
**deprecated_arguments,
|
1686 |
+
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
|
1687 |
+
r"""
|
1688 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1689 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1690 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1691 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1692 |
+
"""
|
1693 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1694 |
+
|
1695 |
+
model_outputs = self.model(
|
1696 |
+
input_ids,
|
1697 |
+
past_key_values=past_key_values,
|
1698 |
+
attention_mask=attention_mask,
|
1699 |
+
inputs_embeds=inputs_embeds,
|
1700 |
+
use_cache=use_cache,
|
1701 |
+
output_attentions=output_attentions,
|
1702 |
+
output_hidden_states=output_hidden_states,
|
1703 |
+
return_dict=return_dict,
|
1704 |
+
)
|
1705 |
+
|
1706 |
+
hidden_states = model_outputs[0]
|
1707 |
+
hidden_states = self.dropout(hidden_states)
|
1708 |
+
logits = self.classifier(hidden_states)
|
1709 |
+
|
1710 |
+
loss = None
|
1711 |
+
if labels is not None:
|
1712 |
+
# move labels to correct device to enable model parallelism
|
1713 |
+
labels = labels.to(logits.device)
|
1714 |
+
batch_size, seq_length = labels.shape
|
1715 |
+
loss_fct = CrossEntropyLoss()
|
1716 |
+
loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
|
1717 |
+
|
1718 |
+
if not return_dict:
|
1719 |
+
output = (logits,) + model_outputs[2:]
|
1720 |
+
return ((loss,) + output) if loss is not None else output
|
1721 |
+
|
1722 |
+
return TokenClassifierOutput(
|
1723 |
+
loss=loss,
|
1724 |
+
logits=logits,
|
1725 |
+
hidden_states=model_outputs.hidden_states,
|
1726 |
+
attentions=model_outputs.attentions,
|
1727 |
+
)
|
1728 |
+
|
1729 |
+
## support batched generation
|
1730 |
+
|
1731 |
+
class SelectiveScanFn(torch.autograd.Function):
|
1732 |
+
|
1733 |
+
@staticmethod
|
1734 |
+
def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
1735 |
+
return_last_state=False):
|
1736 |
+
if u.stride(-1) != 1:
|
1737 |
+
u = u.contiguous()
|
1738 |
+
if delta.stride(-1) != 1:
|
1739 |
+
delta = delta.contiguous()
|
1740 |
+
if D is not None:
|
1741 |
+
D = D.contiguous()
|
1742 |
+
if B.stride(-1) != 1:
|
1743 |
+
B = B.contiguous()
|
1744 |
+
if C.stride(-1) != 1:
|
1745 |
+
C = C.contiguous()
|
1746 |
+
if z is not None and z.stride(-1) != 1:
|
1747 |
+
z = z.contiguous()
|
1748 |
+
if B.dim() == 3:
|
1749 |
+
B = rearrange(B, "b dstate l -> b 1 dstate l")
|
1750 |
+
ctx.squeeze_B = True
|
1751 |
+
if C.dim() == 3:
|
1752 |
+
C = rearrange(C, "b dstate l -> b 1 dstate l")
|
1753 |
+
ctx.squeeze_C = True
|
1754 |
+
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
|
1755 |
+
ctx.delta_softplus = delta_softplus
|
1756 |
+
ctx.has_z = z is not None
|
1757 |
+
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
|
1758 |
+
if not ctx.has_z:
|
1759 |
+
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
1760 |
+
return out if not return_last_state else (out, last_state)
|
1761 |
+
else:
|
1762 |
+
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
|
1763 |
+
out_z = rest[0]
|
1764 |
+
return out_z if not return_last_state else (out_z, last_state)
|
1765 |
+
|
1766 |
+
@staticmethod
|
1767 |
+
def backward(ctx, dout, *args):
|
1768 |
+
if not ctx.has_z:
|
1769 |
+
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
1770 |
+
z = None
|
1771 |
+
out = None
|
1772 |
+
else:
|
1773 |
+
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
|
1774 |
+
if dout.stride(-1) != 1:
|
1775 |
+
dout = dout.contiguous()
|
1776 |
+
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
1777 |
+
# backward of selective_scan_cuda with the backward of chunk).
|
1778 |
+
# Here we just pass in None and dz will be allocated in the C++ code.
|
1779 |
+
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
|
1780 |
+
u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
|
1781 |
+
False # option to recompute out_z, not used here
|
1782 |
+
)
|
1783 |
+
dz = rest[0] if ctx.has_z else None
|
1784 |
+
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
|
1785 |
+
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
|
1786 |
+
return (du, ddelta, dA, dB, dC,
|
1787 |
+
dD if D is not None else None,
|
1788 |
+
dz,
|
1789 |
+
ddelta_bias if delta_bias is not None else None,
|
1790 |
+
None,
|
1791 |
+
None)
|
1792 |
+
|
1793 |
+
|
1794 |
+
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
1795 |
+
return_last_state=False):
|
1796 |
+
"""if return_last_state is True, returns (out, last_state)
|
1797 |
+
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
|
1798 |
+
not considered in the backward pass.
|
1799 |
+
"""
|
1800 |
+
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
|
1801 |
+
|
1802 |
+
|
1803 |
+
class MambaInnerFn(torch.autograd.Function):
|
1804 |
+
|
1805 |
+
@staticmethod
|
1806 |
+
@custom_fwd(device_type="cuda")
|
1807 |
+
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
1808 |
+
out_proj_weight, out_proj_bias,
|
1809 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
1810 |
+
C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=1,):
|
1811 |
+
"""
|
1812 |
+
xz: (batch, dim, seqlen)
|
1813 |
+
"""
|
1814 |
+
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
|
1815 |
+
assert checkpoint_lvl in [0, 1]
|
1816 |
+
L = xz.shape[-1]
|
1817 |
+
delta_rank = delta_proj_weight.shape[1]
|
1818 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
1819 |
+
if torch.is_autocast_enabled():
|
1820 |
+
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
1821 |
+
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
1822 |
+
out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
1823 |
+
out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
|
1824 |
+
if out_proj_bias is not None else None)
|
1825 |
+
if xz.stride(-1) != 1:
|
1826 |
+
xz = xz.contiguous()
|
1827 |
+
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
|
1828 |
+
x, z = xz.chunk(2, dim=1)
|
1829 |
+
if mask is not None:
|
1830 |
+
x = x * mask.unsqueeze(1)
|
1831 |
+
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
|
1832 |
+
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
|
1833 |
+
x, conv1d_weight, conv1d_bias, None, None, None, True
|
1834 |
+
)
|
1835 |
+
if mask is not None:
|
1836 |
+
conv1d_out = conv1d_out * mask.unsqueeze(1)
|
1837 |
+
# We're being very careful here about the layout, to avoid extra transposes.
|
1838 |
+
# We want delta to have d as the slowest moving dimension
|
1839 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
1840 |
+
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
1841 |
+
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
|
1842 |
+
ctx.is_variable_B = B is None
|
1843 |
+
ctx.is_variable_C = C is None
|
1844 |
+
ctx.B_proj_bias_is_None = B_proj_bias is None
|
1845 |
+
ctx.C_proj_bias_is_None = C_proj_bias is None
|
1846 |
+
if B is None: # variable B
|
1847 |
+
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
|
1848 |
+
if B_proj_bias is not None:
|
1849 |
+
B = B + B_proj_bias.to(dtype=B.dtype)
|
1850 |
+
if not A.is_complex():
|
1851 |
+
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
1852 |
+
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
1853 |
+
else:
|
1854 |
+
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
1855 |
+
else:
|
1856 |
+
if B.stride(-1) != 1:
|
1857 |
+
B = B.contiguous()
|
1858 |
+
if C is None: # variable C
|
1859 |
+
C = x_dbl[:, -d_state:] # (bl dstate)
|
1860 |
+
if C_proj_bias is not None:
|
1861 |
+
C = C + C_proj_bias.to(dtype=C.dtype)
|
1862 |
+
if not A.is_complex():
|
1863 |
+
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
1864 |
+
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
1865 |
+
else:
|
1866 |
+
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
1867 |
+
else:
|
1868 |
+
if C.stride(-1) != 1:
|
1869 |
+
C = C.contiguous()
|
1870 |
+
if D is not None:
|
1871 |
+
D = D.contiguous()
|
1872 |
+
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
|
1873 |
+
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
|
1874 |
+
)
|
1875 |
+
ctx.delta_softplus = delta_softplus
|
1876 |
+
ctx.out_proj_bias_is_None = out_proj_bias is None
|
1877 |
+
ctx.checkpoint_lvl = checkpoint_lvl
|
1878 |
+
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
|
1879 |
+
conv1d_out, delta = None, None
|
1880 |
+
ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
|
1881 |
+
delta_proj_weight, out_proj_weight, conv1d_out, delta,
|
1882 |
+
A, B, C, D, delta_bias, scan_intermediates, out)
|
1883 |
+
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
1884 |
+
|
1885 |
+
@staticmethod
|
1886 |
+
@custom_bwd(device_type="cuda")
|
1887 |
+
def backward(ctx, dout):
|
1888 |
+
# dout: (batch, seqlen, dim)
|
1889 |
+
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
|
1890 |
+
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
|
1891 |
+
conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
|
1892 |
+
L = xz.shape[-1]
|
1893 |
+
delta_rank = delta_proj_weight.shape[1]
|
1894 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
1895 |
+
x, z = xz.chunk(2, dim=1)
|
1896 |
+
if dout.stride(-1) != 1:
|
1897 |
+
dout = dout.contiguous()
|
1898 |
+
if ctx.checkpoint_lvl == 1:
|
1899 |
+
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
|
1900 |
+
x, conv1d_weight, conv1d_bias, None, None, None, True
|
1901 |
+
)
|
1902 |
+
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
|
1903 |
+
"d (b l) -> b d l", l = L)
|
1904 |
+
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
1905 |
+
# backward of selective_scan_cuda with the backward of chunk).
|
1906 |
+
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
|
1907 |
+
dx, dz = dxz.chunk(2, dim=1)
|
1908 |
+
dout = rearrange(dout, "b l e -> e (b l)")
|
1909 |
+
dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
|
1910 |
+
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
|
1911 |
+
conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
|
1912 |
+
ctx.delta_softplus,
|
1913 |
+
True # option to recompute out_z
|
1914 |
+
)
|
1915 |
+
dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
|
1916 |
+
dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
|
1917 |
+
dD = dD if D is not None else None
|
1918 |
+
dx_dbl = torch.empty_like(x_dbl)
|
1919 |
+
dB_proj_bias = None
|
1920 |
+
if ctx.is_variable_B:
|
1921 |
+
if not A.is_complex():
|
1922 |
+
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
|
1923 |
+
else:
|
1924 |
+
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
1925 |
+
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
|
1926 |
+
dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
|
1927 |
+
dB = None
|
1928 |
+
dC_proj_bias = None
|
1929 |
+
if ctx.is_variable_C:
|
1930 |
+
if not A.is_complex():
|
1931 |
+
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
|
1932 |
+
else:
|
1933 |
+
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
1934 |
+
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
|
1935 |
+
dx_dbl[:, -d_state:] = dC # (bl d)
|
1936 |
+
dC = None
|
1937 |
+
ddelta = rearrange(ddelta, "b d l -> d (b l)")
|
1938 |
+
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
|
1939 |
+
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
|
1940 |
+
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
|
1941 |
+
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
|
1942 |
+
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
|
1943 |
+
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
|
1944 |
+
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
1945 |
+
# backward of conv1d with the backward of chunk).
|
1946 |
+
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
|
1947 |
+
x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
|
1948 |
+
)
|
1949 |
+
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
|
1950 |
+
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
|
1951 |
+
return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
|
1952 |
+
dout_proj_weight, dout_proj_bias,
|
1953 |
+
dA, dB, dC, dD,
|
1954 |
+
ddelta_bias if delta_bias is not None else None,
|
1955 |
+
dB_proj_bias, dC_proj_bias, None, None)
|
1956 |
+
|
1957 |
+
|
1958 |
+
def mamba_inner_fn(
|
1959 |
+
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
1960 |
+
out_proj_weight, out_proj_bias,
|
1961 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
1962 |
+
C_proj_bias=None, mask=None, delta_softplus=True
|
1963 |
+
):
|
1964 |
+
return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
1965 |
+
out_proj_weight, out_proj_bias,
|
1966 |
+
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, mask, delta_softplus)
|
1967 |
+
|
1968 |
+
|
1969 |
+
def lambda_init_fn(depth):
|
1970 |
+
return 0.8 - 0.6 * math.exp(-0.3 * depth)
|
1971 |
+
|
1972 |
+
|
1973 |
+
def split_heads(x):
|
1974 |
+
# split by num_heads, the stripe pattern is friendly to tensor parallel.
|
1975 |
+
x = rearrange(x, "... (H two) D -> ... H two D", two=2)
|
1976 |
+
x1 = x[..., 0, :]
|
1977 |
+
x2 = x[..., 1, :]
|
1978 |
+
return x1, x2
|
1979 |
+
|
1980 |
+
class FlashDiffCustomAttention(nn.Module):
|
1981 |
+
"""Implement the scaled dot product attention with softmax.
|
1982 |
+
Arguments
|
1983 |
+
---------
|
1984 |
+
head_dim: The dimension of the heads.
|
1985 |
+
depth: The layer id, starting from 0.
|
1986 |
+
"""
|
1987 |
+
|
1988 |
+
def __init__(
|
1989 |
+
self,
|
1990 |
+
head_dim,
|
1991 |
+
depth,
|
1992 |
+
fa_og = True,
|
1993 |
+
):
|
1994 |
+
super().__init__()
|
1995 |
+
assert flash_attn_varlen_func is not None, "FlashAttention is not installed"
|
1996 |
+
assert flash_attn_func is not None, "FlashAttention is not installed"
|
1997 |
+
self.head_dim = head_dim
|
1998 |
+
self.fa_og = fa_og # turning it to false needs customized flash attention https://github.com/xiayuqing0622/flex_head_fa
|
1999 |
+
self.lambda_init = lambda_init_fn(depth)
|
2000 |
+
self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
|
2001 |
+
self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
|
2002 |
+
self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
|
2003 |
+
self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
|
2004 |
+
|
2005 |
+
self.subln = SambaYRMSNorm(2 * self.head_dim, eps=1e-5)
|
2006 |
+
|
2007 |
+
def forward(
|
2008 |
+
self,
|
2009 |
+
q,
|
2010 |
+
k,
|
2011 |
+
v,
|
2012 |
+
dropout_p = 0.0,
|
2013 |
+
cu_seqlens_q=None,
|
2014 |
+
max_seqlen_q=None,
|
2015 |
+
cu_seqlens_k=None,
|
2016 |
+
max_seqlen_k=None,
|
2017 |
+
softmax_scale=None,
|
2018 |
+
window_size=(-1, -1),
|
2019 |
+
causal=None,
|
2020 |
+
):
|
2021 |
+
"""Implements the multihead softmax attention.
|
2022 |
+
Arguments
|
2023 |
+
---------
|
2024 |
+
q, k, v: The tensors containing the query, key, and value.
|
2025 |
+
If cu_seqlens is None and max_seqlen is None, then each has shape (B, S, H, D).
|
2026 |
+
If cu_seqlens is not None and max_seqlen is not None, then each has shape
|
2027 |
+
(total, H, D), where total is the sum of the sequence lengths in the batch.
|
2028 |
+
causal: if passed, will override self.causal
|
2029 |
+
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
2030 |
+
of the sequences in the batch, used to index into qkv.
|
2031 |
+
max_seqlen: int. Maximum sequence length in the batch.
|
2032 |
+
Returns:
|
2033 |
+
--------
|
2034 |
+
out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
|
2035 |
+
else (B, S, H, D).
|
2036 |
+
"""
|
2037 |
+
q = q.to(torch.bfloat16)
|
2038 |
+
k = k.to(torch.bfloat16)
|
2039 |
+
v = v.to(torch.bfloat16)
|
2040 |
+
|
2041 |
+
assert q.dtype in [torch.float16, torch.bfloat16]
|
2042 |
+
assert q.is_cuda and k.is_cuda and v.is_cuda
|
2043 |
+
#causal = self.causal if causal is None else causal
|
2044 |
+
unpadded = cu_seqlens_q is not None
|
2045 |
+
q1, q2 = split_heads(q)
|
2046 |
+
k1, k2 = split_heads(k)
|
2047 |
+
if self.fa_og:
|
2048 |
+
v1, v2 = split_heads(v)
|
2049 |
+
else:
|
2050 |
+
v = rearrange(v, "... (H two) D -> ... H (two D)", two=2)
|
2051 |
+
|
2052 |
+
kwargs = {
|
2053 |
+
"dropout_p": dropout_p,
|
2054 |
+
"softmax_scale": softmax_scale,
|
2055 |
+
"causal": causal,
|
2056 |
+
"window_size": window_size,
|
2057 |
+
}
|
2058 |
+
|
2059 |
+
if unpadded:
|
2060 |
+
assert cu_seqlens_q.dtype == torch.int32
|
2061 |
+
assert max_seqlen_q is not None
|
2062 |
+
assert isinstance(max_seqlen_q, int)
|
2063 |
+
assert cu_seqlens_k is not None
|
2064 |
+
assert cu_seqlens_k.dtype == torch.int32
|
2065 |
+
assert max_seqlen_k is not None
|
2066 |
+
assert isinstance(max_seqlen_k, int)
|
2067 |
+
|
2068 |
+
kwargs.update({
|
2069 |
+
"cu_seqlens_q": cu_seqlens_q,
|
2070 |
+
"max_seqlen_q": max_seqlen_q,
|
2071 |
+
"cu_seqlens_k": cu_seqlens_k,
|
2072 |
+
"max_seqlen_k": max_seqlen_k,
|
2073 |
+
})
|
2074 |
+
attn_func = flash_attn_varlen_func
|
2075 |
+
else:
|
2076 |
+
attn_func = flash_attn_func
|
2077 |
+
|
2078 |
+
if self.fa_og:
|
2079 |
+
attn11 = attn_func(q1, k1, v1, **kwargs)
|
2080 |
+
attn12 = attn_func(q1, k1, v2, **kwargs)
|
2081 |
+
attn1 = torch.cat([attn11, attn12], dim=-1)
|
2082 |
+
attn21 = attn_func(q2, k2, v1, **kwargs)
|
2083 |
+
attn22 = attn_func(q2, k2, v2, **kwargs)
|
2084 |
+
attn2 = torch.cat([attn21, attn22], dim=-1)
|
2085 |
+
else:
|
2086 |
+
attn1 = attn_func(q1, k1, v, **kwargs)
|
2087 |
+
attn2 = attn_func(q2, k2, v, **kwargs)
|
2088 |
+
|
2089 |
+
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
|
2090 |
+
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
|
2091 |
+
lambda_full = lambda_1 - lambda_2 + self.lambda_init
|
2092 |
+
|
2093 |
+
attn = attn1 - lambda_full * attn2
|
2094 |
+
attn = self.subln(attn)
|
2095 |
+
attn = attn * (1 - self.lambda_init)
|
2096 |
+
# reshape back to 2 * num_head
|
2097 |
+
attn = rearrange(attn, "... H (two D) -> ... (H two) D", two=2)
|
2098 |
+
return attn
|
special_tokens_map.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<|endoftext|>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": {
|
10 |
+
"content": "<|endoftext|>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"pad_token": {
|
17 |
+
"content": "<|endoftext|>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
},
|
23 |
+
"unk_token": {
|
24 |
+
"content": "<|endoftext|>",
|
25 |
+
"lstrip": false,
|
26 |
+
"normalized": false,
|
27 |
+
"rstrip": false,
|
28 |
+
"single_word": false
|
29 |
+
}
|
30 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": false,
|
3 |
+
"add_eos_token": false,
|
4 |
+
"add_prefix_space": false,
|
5 |
+
"added_tokens_decoder": {
|
6 |
+
"199999": {
|
7 |
+
"content": "<|endoftext|>",
|
8 |
+
"lstrip": false,
|
9 |
+
"normalized": false,
|
10 |
+
"rstrip": false,
|
11 |
+
"single_word": false,
|
12 |
+
"special": true
|
13 |
+
},
|
14 |
+
"200018": {
|
15 |
+
"content": "<|endofprompt|>",
|
16 |
+
"lstrip": false,
|
17 |
+
"normalized": false,
|
18 |
+
"rstrip": false,
|
19 |
+
"single_word": false,
|
20 |
+
"special": true
|
21 |
+
},
|
22 |
+
"200019": {
|
23 |
+
"content": "<|assistant|>",
|
24 |
+
"lstrip": false,
|
25 |
+
"normalized": false,
|
26 |
+
"rstrip": true,
|
27 |
+
"single_word": false,
|
28 |
+
"special": true
|
29 |
+
},
|
30 |
+
"200020": {
|
31 |
+
"content": "<|end|>",
|
32 |
+
"lstrip": false,
|
33 |
+
"normalized": false,
|
34 |
+
"rstrip": true,
|
35 |
+
"single_word": false,
|
36 |
+
"special": true
|
37 |
+
},
|
38 |
+
"200021": {
|
39 |
+
"content": "<|user|>",
|
40 |
+
"lstrip": false,
|
41 |
+
"normalized": false,
|
42 |
+
"rstrip": true,
|
43 |
+
"single_word": false,
|
44 |
+
"special": true
|
45 |
+
},
|
46 |
+
"200022": {
|
47 |
+
"content": "<|system|>",
|
48 |
+
"lstrip": false,
|
49 |
+
"normalized": false,
|
50 |
+
"rstrip": true,
|
51 |
+
"single_word": false,
|
52 |
+
"special": true
|
53 |
+
},
|
54 |
+
"200023": {
|
55 |
+
"content": "<|tool|>",
|
56 |
+
"lstrip": false,
|
57 |
+
"normalized": false,
|
58 |
+
"rstrip": true,
|
59 |
+
"single_word": false,
|
60 |
+
"special": false
|
61 |
+
},
|
62 |
+
"200024": {
|
63 |
+
"content": "<|/tool|>",
|
64 |
+
"lstrip": false,
|
65 |
+
"normalized": false,
|
66 |
+
"rstrip": true,
|
67 |
+
"single_word": false,
|
68 |
+
"special": false
|
69 |
+
},
|
70 |
+
"200025": {
|
71 |
+
"content": "<|tool_call|>",
|
72 |
+
"lstrip": false,
|
73 |
+
"normalized": false,
|
74 |
+
"rstrip": true,
|
75 |
+
"single_word": false,
|
76 |
+
"special": false
|
77 |
+
},
|
78 |
+
"200026": {
|
79 |
+
"content": "<|/tool_call|>",
|
80 |
+
"lstrip": false,
|
81 |
+
"normalized": false,
|
82 |
+
"rstrip": true,
|
83 |
+
"single_word": false,
|
84 |
+
"special": false
|
85 |
+
},
|
86 |
+
"200027": {
|
87 |
+
"content": "<|tool_response|>",
|
88 |
+
"lstrip": false,
|
89 |
+
"normalized": false,
|
90 |
+
"rstrip": true,
|
91 |
+
"single_word": false,
|
92 |
+
"special": false
|
93 |
+
},
|
94 |
+
"200028": {
|
95 |
+
"content": "<|tag|>",
|
96 |
+
"lstrip": false,
|
97 |
+
"normalized": false,
|
98 |
+
"rstrip": true,
|
99 |
+
"single_word": false,
|
100 |
+
"special": true
|
101 |
+
}
|
102 |
+
},
|
103 |
+
"bos_token": "<|endoftext|>",
|
104 |
+
"chat_template": "{% for message in messages %}{% if message['role'] == 'system' and 'tools' in message and message['tools'] is not none %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|tool|>' + message['tools'] + '<|/tool|>' + '<|end|>' }}{% else %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|end|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}",
|
105 |
+
"clean_up_tokenization_spaces": false,
|
106 |
+
"eos_token": "<|endoftext|>",
|
107 |
+
"model_max_length": 65536,
|
108 |
+
"pad_token": "<|endoftext|>",
|
109 |
+
"tokenizer_class": "GPT2Tokenizer",
|
110 |
+
"unk_token": "<|endoftext|>"
|
111 |
+
}
|
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|