itazap HF Staff commited on
Commit
724be6e
·
verified ·
1 Parent(s): df03219

Upload BLT model converted

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. backup_blt_modellike/__init__.py +28 -0
  2. backup_blt_modellike/__pycache__/__init__.cpython-312.pyc +0 -0
  3. backup_blt_modellike/__pycache__/tokenization_blt.cpython-312.pyc +0 -0
  4. backup_blt_modellike/configuration_blt.py +225 -0
  5. backup_blt_modellike/convert_blt_weights_to_hf.py +397 -0
  6. backup_blt_modellike/modeling_blt.py +971 -0
  7. backup_blt_modellike/tokenization_blt.py +412 -0
  8. backup_blt_wip copy/__init__.py +0 -0
  9. backup_blt_wip copy/__pycache__/__init__.cpython-312.pyc +0 -0
  10. backup_blt_wip copy/__pycache__/blt_args.cpython-312.pyc +0 -0
  11. backup_blt_wip copy/__pycache__/blt_one_file.cpython-312.pyc +0 -0
  12. backup_blt_wip copy/__pycache__/configuration_blt.cpython-312.pyc +0 -0
  13. backup_blt_wip copy/__pycache__/configuration_blt_og.cpython-312.pyc +0 -0
  14. backup_blt_wip copy/__pycache__/modeling_blt.cpython-312.pyc +0 -0
  15. backup_blt_wip copy/__pycache__/modeling_blt_dev.cpython-312.pyc +0 -0
  16. backup_blt_wip copy/__pycache__/modeling_blt_modellike.cpython-312.pyc +0 -0
  17. backup_blt_wip copy/__pycache__/modeling_blt_old.cpython-312.pyc +0 -0
  18. backup_blt_wip copy/__pycache__/modeling_blt_wip.cpython-312.pyc +0 -0
  19. backup_blt_wip copy/__pycache__/modeling_blt_wip_backup.cpython-312.pyc +0 -0
  20. backup_blt_wip copy/__pycache__/tokenization_blt.cpython-312.pyc +0 -0
  21. backup_blt_wip copy/configuration_blt.py +390 -0
  22. backup_blt_wip copy/configuration_blt_og.py +608 -0
  23. backup_blt_wip copy/modeling_blt.py +1287 -0
  24. backup_blt_wip copy/modeling_blt_old.py +1602 -0
  25. backup_blt_wip copy/modular_blt.py +1180 -0
  26. backup_blt_wip copy/tokenization_blt.py +271 -0
  27. backup_blt_wip_backup/__pycache__/blt_args.cpython-312.pyc +0 -0
  28. backup_blt_wip_backup/__pycache__/blt_one_file.cpython-312.pyc +0 -0
  29. backup_blt_wip_backup/__pycache__/configuration_blt.cpython-312.pyc +0 -0
  30. backup_blt_wip_backup/__pycache__/modeling_blt_wip.cpython-312.pyc +0 -0
  31. backup_blt_wip_backup/__pycache__/modeling_blt_wip_backup.cpython-312.pyc +0 -0
  32. backup_blt_wip_backup/__pycache__/tokenization_blt.cpython-312.pyc +0 -0
  33. backup_blt_wip_backup/blt_args.py +187 -0
  34. backup_blt_wip_backup/configuration_blt.py +590 -0
  35. backup_blt_wip_backup/convert_hf_blt_original_to_unified.py +540 -0
  36. backup_blt_wip_backup/modeling_blt_wip.py +1836 -0
  37. backup_blt_wip_backup/modeling_blt_wip_backup.py +2166 -0
  38. backup_blt_wip_backup/tokenization_blt.py +273 -0
  39. backup_blt_wip_backup/tokenizers/__init__.py +1 -0
  40. backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-312.pyc +0 -0
  41. backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-39.pyc +0 -0
  42. backup_blt_wip_backup/tokenizers/__pycache__/abstract_tokenizer.cpython-312.pyc +0 -0
  43. backup_blt_wip_backup/tokenizers/__pycache__/blt_tokenizer.cpython-312.pyc +0 -0
  44. backup_blt_wip_backup/tokenizers/__pycache__/build_tokenizer.cpython-312.pyc +0 -0
  45. backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-312.pyc +0 -0
  46. backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-39.pyc +0 -0
  47. backup_blt_wip_backup/tokenizers/__pycache__/sentence_piece_tokenizer.cpython-312.pyc +0 -0
  48. backup_blt_wip_backup/tokenizers/__pycache__/tiktoken_tokenizer.cpython-312.pyc +0 -0
  49. backup_blt_wip_backup/tokenizers/abstract_tokenizer.py +21 -0
  50. backup_blt_wip_backup/tokenizers/blt_tokenizer.py +143 -0
backup_blt_modellike/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_blt import *
22
+ from .modeling_blt import *
23
+ from .tokenization_blt import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
backup_blt_modellike/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (667 Bytes). View file
 
backup_blt_modellike/__pycache__/tokenization_blt.cpython-312.pyc ADDED
Binary file (20.5 kB). View file
 
backup_blt_modellike/configuration_blt.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """BLT model configuration"""
21
+
22
+ from ...configuration_utils import PretrainedConfig
23
+ from ...modeling_rope_utils import rope_config_validation
24
+
25
+
26
+ class BLTConfig(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of a [`BLTModel`]. It is used to instantiate an BLT
29
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
30
+ defaults will yield a similar configuration to that of the BLT-7B.
31
+ e.g. [meta-blt/BLT-2-7b-hf](https://huggingface.co/meta-blt/BLT-2-7b-hf)
32
+
33
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
+ documentation from [`PretrainedConfig`] for more information.
35
+
36
+
37
+ Args:
38
+ vocab_size (`int`, *optional*, defaults to 32000):
39
+ Vocabulary size of the BLT model. Defines the number of different tokens that can be represented by the
40
+ `inputs_ids` passed when calling [`BLTModel`]
41
+ hidden_size (`int`, *optional*, defaults to 4096):
42
+ Dimension of the hidden representations.
43
+ intermediate_size (`int`, *optional*, defaults to 11008):
44
+ Dimension of the MLP representations.
45
+ num_hidden_layers (`int`, *optional*, defaults to 32):
46
+ Number of hidden layers in the Transformer decoder.
47
+ num_attention_heads (`int`, *optional*, defaults to 32):
48
+ Number of attention heads for each attention layer in the Transformer decoder.
49
+ num_key_value_heads (`int`, *optional*):
50
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
51
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
52
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
53
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
54
+ by meanpooling all the original heads within that group. For more details checkout [this
55
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
56
+ `num_attention_heads`.
57
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
58
+ The non-linear activation function (function or string) in the decoder.
59
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
60
+ The maximum sequence length that this model might ever be used with. BLT 1 supports up to 2048 tokens,
61
+ BLT 2 up to 4096, CodeBLT up to 16384.
62
+ initializer_range (`float`, *optional*, defaults to 0.02):
63
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
65
+ The epsilon used by the rms normalization layers.
66
+ use_cache (`bool`, *optional*, defaults to `True`):
67
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
68
+ relevant if `config.is_decoder=True`.
69
+ pad_token_id (`int`, *optional*):
70
+ Padding token id.
71
+ bos_token_id (`int`, *optional*, defaults to 1):
72
+ Beginning of stream token id.
73
+ eos_token_id (`int`, *optional*, defaults to 2):
74
+ End of stream token id.
75
+ pretraining_tp (`int`, *optional*, defaults to 1):
76
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
77
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
78
+ understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
79
+ results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
80
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
81
+ Whether to tie weight embeddings
82
+ rope_theta (`float`, *optional*, defaults to 10000.0):
83
+ The base period of the RoPE embeddings.
84
+ rope_scaling (`Dict`, *optional*):
85
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
86
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
87
+ accordingly.
88
+ Expected contents:
89
+ `rope_type` (`str`):
90
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
91
+ 'blt3'], with 'default' being the original RoPE implementation.
92
+ `factor` (`float`, *optional*):
93
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
94
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
95
+ original maximum pre-trained length.
96
+ `original_max_position_embeddings` (`int`, *optional*):
97
+ Used with 'dynamic', 'longrope' and 'blt3'. The original max position embeddings used during
98
+ pretraining.
99
+ `attention_factor` (`float`, *optional*):
100
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
101
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
102
+ `factor` field to infer the suggested value.
103
+ `beta_fast` (`float`, *optional*):
104
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
105
+ ramp function. If unspecified, it defaults to 32.
106
+ `beta_slow` (`float`, *optional*):
107
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
108
+ ramp function. If unspecified, it defaults to 1.
109
+ `short_factor` (`List[float]`, *optional*):
110
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
111
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
112
+ size divided by the number of attention heads divided by 2
113
+ `long_factor` (`List[float]`, *optional*):
114
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
115
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
116
+ size divided by the number of attention heads divided by 2
117
+ `low_freq_factor` (`float`, *optional*):
118
+ Only used with 'blt3'. Scaling factor applied to low frequency components of the RoPE
119
+ `high_freq_factor` (`float`, *optional*):
120
+ Only used with 'blt3'. Scaling factor applied to high frequency components of the RoPE
121
+ attention_bias (`bool`, *optional*, defaults to `False`):
122
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
123
+ attention_dropout (`float`, *optional*, defaults to 0.0):
124
+ The dropout ratio for the attention probabilities.
125
+ mlp_bias (`bool`, *optional*, defaults to `False`):
126
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
127
+ head_dim (`int`, *optional*):
128
+ The attention head dimension. If None, it will default to hidden_size // num_attention_heads
129
+
130
+ ```python
131
+ >>> from transformers import BLTModel, BLTConfig
132
+
133
+ >>> # Initializing a BLT blt-7b style configuration
134
+ >>> configuration = BLTConfig()
135
+
136
+ >>> # Initializing a model from the blt-7b style configuration
137
+ >>> model = BLTModel(configuration)
138
+
139
+ >>> # Accessing the model configuration
140
+ >>> configuration = model.config
141
+ ```"""
142
+
143
+ model_type = "blt"
144
+ keys_to_ignore_at_inference = ["past_key_values"]
145
+ # Default tensor parallel plan for base model `BLTModel`
146
+ base_model_tp_plan = {
147
+ "layers.*.self_attn.q_proj": "colwise",
148
+ "layers.*.self_attn.k_proj": "colwise",
149
+ "layers.*.self_attn.v_proj": "colwise",
150
+ "layers.*.self_attn.o_proj": "rowwise",
151
+ "layers.*.mlp.gate_proj": "colwise",
152
+ "layers.*.mlp.up_proj": "colwise",
153
+ "layers.*.mlp.down_proj": "rowwise",
154
+ }
155
+ base_model_pp_plan = {
156
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
157
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
158
+ "norm": (["hidden_states"], ["hidden_states"]),
159
+ }
160
+
161
+ def __init__(
162
+ self,
163
+ vocab_size=32000,
164
+ hidden_size=4096,
165
+ intermediate_size=11008,
166
+ num_hidden_layers=32,
167
+ num_attention_heads=32,
168
+ num_key_value_heads=None,
169
+ hidden_act="silu",
170
+ max_position_embeddings=2048,
171
+ initializer_range=0.02,
172
+ rms_norm_eps=1e-6,
173
+ use_cache=True,
174
+ pad_token_id=None,
175
+ bos_token_id=1,
176
+ eos_token_id=2,
177
+ pretraining_tp=1,
178
+ tie_word_embeddings=False,
179
+ rope_theta=10000.0,
180
+ rope_scaling=None,
181
+ attention_bias=False,
182
+ attention_dropout=0.0,
183
+ mlp_bias=False,
184
+ head_dim=None,
185
+ **kwargs,
186
+ ):
187
+ self.vocab_size = vocab_size
188
+ self.max_position_embeddings = max_position_embeddings
189
+ self.hidden_size = hidden_size
190
+ self.intermediate_size = intermediate_size
191
+ self.num_hidden_layers = num_hidden_layers
192
+ self.num_attention_heads = num_attention_heads
193
+
194
+ # for backward compatibility
195
+ if num_key_value_heads is None:
196
+ num_key_value_heads = num_attention_heads
197
+
198
+ self.num_key_value_heads = num_key_value_heads
199
+ self.hidden_act = hidden_act
200
+ self.initializer_range = initializer_range
201
+ self.rms_norm_eps = rms_norm_eps
202
+ self.pretraining_tp = pretraining_tp
203
+ self.use_cache = use_cache
204
+ self.rope_theta = rope_theta
205
+ self.rope_scaling = rope_scaling
206
+ self.attention_bias = attention_bias
207
+ self.attention_dropout = attention_dropout
208
+ self.mlp_bias = mlp_bias
209
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
210
+ # Validate the correctness of rotary position embeddings parameters
211
+ # BC: if there is a 'type' field, copy it it to 'rope_type'.
212
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
213
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
214
+ rope_config_validation(self)
215
+
216
+ super().__init__(
217
+ pad_token_id=pad_token_id,
218
+ bos_token_id=bos_token_id,
219
+ eos_token_id=eos_token_id,
220
+ tie_word_embeddings=tie_word_embeddings,
221
+ **kwargs,
222
+ )
223
+
224
+
225
+ __all__ = ["BLTConfig"]
backup_blt_modellike/convert_blt_weights_to_hf.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ from typing import Any, Dict, Optional
6
+
7
+ import torch
8
+ from huggingface_hub import hf_hub_download, upload_folder
9
+ from safetensors.torch import load_file, save_file
10
+
11
+ from transformers.models.blt_wip.configuration_blt import BLTConfig
12
+ from transformers.models.blt_wip.modeling_blt import BLTModel
13
+ from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM
14
+ from transformers.utils import logging as transformers_logging
15
+
16
+
17
+ logger = transformers_logging.get_logger(__name__)
18
+ transformers_logging.set_verbosity_info()
19
+
20
+
21
+ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]:
22
+ logger.info("Merging configurations")
23
+
24
+ with open(config_path, "r") as f:
25
+ main_config = json.load(f)
26
+
27
+ with open(entropy_params_path, "r") as f:
28
+ entropy_data = json.load(f)
29
+
30
+ entropy_model_params = entropy_data.get("entropy_model", {})
31
+ patcher_args = entropy_data.get("data", {}).get("patcher_args", {})
32
+
33
+ unified_config = main_config.copy()["args"]
34
+
35
+ for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]:
36
+ if key in unified_config and not isinstance(unified_config[key], int):
37
+ unified_config[key] = int(unified_config[key])
38
+
39
+ patch_size = patcher_args.get("patch_size", 8)
40
+ if isinstance(patch_size, float):
41
+ patch_size = int(patch_size)
42
+
43
+ # Create patcher config
44
+ patcher_hidden_size = int(entropy_model_params.get("dim", 512))
45
+ patcher_multiple_of = int(entropy_model_params.get("multiple_of", 256))
46
+ patcher_intermediate_size = patcher_multiple_of * ((int(8 * patcher_hidden_size / 3) + patcher_multiple_of - 1) // patcher_multiple_of)
47
+
48
+ patcher_config = {
49
+ "vocab_size": int(entropy_model_params.get("vocab_size", 256)),
50
+ "hidden_size": patcher_hidden_size,
51
+ "num_hidden_layers": int(entropy_model_params.get("n_layers", 8)),
52
+ "num_attention_heads": int(entropy_model_params.get("n_heads", 8)),
53
+ "num_key_value_heads": int(entropy_model_params.get("n_kv_heads"))
54
+ if entropy_model_params.get("n_kv_heads") is not None
55
+ else None,
56
+ "max_position_embeddings": int(entropy_model_params.get("max_seqlen", 1024)),
57
+ "norm_eps": entropy_model_params.get("norm_eps", 1e-5),
58
+ "dropout": entropy_model_params.get("dropout", 0.0),
59
+ "rope_theta": entropy_model_params.get("rope_theta", 10000.0),
60
+ "attn_impl": entropy_model_params.get("attn_impl", "sdpa"),
61
+ "attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"),
62
+ "intermediate_size": patcher_intermediate_size,
63
+ }
64
+
65
+ # Create encoder config
66
+ encoder_hidden_size = unified_config.get("dim_local_encoder", 1024)
67
+ encoder_multiple_of = unified_config.get("multiple_of", 256)
68
+ encoder_intermediate_size = encoder_multiple_of * ((int(8 * encoder_hidden_size / 3) + encoder_multiple_of - 1) // encoder_multiple_of)
69
+
70
+ encoder_config = {
71
+ "vocab_size": unified_config.get("vocab_size", 256),
72
+ "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_encoder", False),
73
+ "cross_attn_k": unified_config.get("cross_attn_k", 2),
74
+ "hidden_size_global": unified_config.get("hidden_size_global", 2048),
75
+ "pm_size": unified_config.get("pm_size", 0),
76
+ "hidden_size": encoder_hidden_size,
77
+ "num_attention_heads": unified_config.get("n_heads_local_encoder", 16),
78
+ "num_key_value_heads": unified_config.get("n_kv_heads"),
79
+ "num_hidden_layers": unified_config.get("n_layers_local_encoder", 1),
80
+ "norm_eps": unified_config.get("norm_eps", 1e-5),
81
+ "dropout": unified_config.get("dropout", 0.0),
82
+ "max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024),
83
+ "rope_theta": unified_config.get("rope_theta", 10000.0),
84
+ "rope_scaling": {"rope_type": "default"},
85
+ "hidden_act": unified_config.get("hidden_act", "silu"),
86
+ "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
87
+ "intermediate_size": encoder_intermediate_size,
88
+ }
89
+
90
+ # Create decoder config
91
+ decoder_hidden_size = unified_config.get("dim_local_decoder", 1024)
92
+ decoder_multiple_of = unified_config.get("multiple_of", 256)
93
+ decoder_intermediate_size = decoder_multiple_of * ((int(8 * decoder_hidden_size / 3) + decoder_multiple_of - 1) // decoder_multiple_of)
94
+
95
+ decoder_config = {
96
+ "vocab_size": unified_config.get("vocab_size", 256),
97
+ "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_decoder", False),
98
+ "cross_attn_k": unified_config.get("cross_attn_k", 2),
99
+ "hidden_size_global": unified_config.get("hidden_size_global", 2048),
100
+ "hidden_size": decoder_hidden_size,
101
+ "num_attention_heads": unified_config.get("n_heads_local_decoder", 16),
102
+ "num_key_value_heads": unified_config.get("n_kv_heads"),
103
+ "num_hidden_layers": unified_config.get("n_layers_local_decoder", 9),
104
+ "norm_eps": unified_config.get("norm_eps", 1e-5),
105
+ "dropout": unified_config.get("dropout", 0.0),
106
+ "max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024),
107
+ "rope_theta": unified_config.get("rope_theta", 10000.0),
108
+ "rope_scaling": {"rope_type": "default"},
109
+ "hidden_act": unified_config.get("hidden_act", "silu"),
110
+ "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
111
+ "intermediate_size": decoder_intermediate_size,
112
+ }
113
+
114
+ # Create global transformer config
115
+ global_hidden_size = unified_config.get("dim_global", 2048)
116
+ global_multiple_of = unified_config.get("multiple_of", 256)
117
+ global_intermediate_size = global_multiple_of * ((int(8 * global_hidden_size / 3) + global_multiple_of - 1) // global_multiple_of)
118
+
119
+ global_config = {
120
+ "hidden_size": global_hidden_size,
121
+ "num_attention_heads": unified_config.get("n_heads_global", 16),
122
+ "num_key_value_heads": unified_config.get("n_kv_heads_global"),
123
+ "num_hidden_layers": unified_config.get("n_layers_global", 25),
124
+ "norm_eps": unified_config.get("norm_eps", 1e-5),
125
+ "dropout": unified_config.get("dropout", 0.0),
126
+ "max_position_embeddings": unified_config.get("max_seqlen", 1024),
127
+ "rope_theta": unified_config.get("rope_theta", 10000.0),
128
+ "rope_scaling": {"rope_type": "default"},
129
+ "hidden_act": unified_config.get("hidden_act", "silu"),
130
+ "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
131
+ "intermediate_size": global_intermediate_size,
132
+ }
133
+
134
+ # Create main config with sub-configs
135
+ main_config_dict = {
136
+ "model_type": "blt",
137
+ "vocab_size": unified_config.get("vocab_size", 256),
138
+ "max_position_embeddings": unified_config.get("max_seqlen", 1024),
139
+ "patch_in_forward": True,
140
+ "realtime_patching": True,
141
+ "patching_mode": "entropy",
142
+ "patch_size": patch_size,
143
+ "patching_threshold": patcher_args.get("threshold", 0.5),
144
+ "patching_threshold_add": patcher_args.get("threshold_add", 0.0),
145
+ "max_patch_length": patcher_args.get("max_patch_length"),
146
+ "patching_batch_size": patcher_args.get("patching_batch_size", 1),
147
+ "patching_device": patcher_args.get("patching_device", "cuda"),
148
+ "monotonicity": patcher_args.get("monotonicity", False),
149
+ "cross_attn_k": unified_config.get("cross_attn_k", 2),
150
+ "encoder_hash_byte_group_size": unified_config.get("encoder_hash_byte_group_size"),
151
+ "encoder_hash_byte_group_vocab": unified_config.get("encoder_hash_byte_group_vocab", 30000),
152
+ "encoder_hash_byte_group_nb_functions": unified_config.get("encoder_hash_byte_group_nb_functions", 3),
153
+ "pm_size": unified_config.get("pm_size", 0),
154
+ "patcher_config": patcher_config,
155
+ "encoder_config": encoder_config,
156
+ "decoder_config": decoder_config,
157
+ "global_config": global_config,
158
+ }
159
+
160
+ main_config_dict["tie_word_embeddings"] = False
161
+
162
+ logger.info(f"Merged configuration with {len(main_config_dict)} parameters")
163
+ return main_config_dict
164
+
165
+
166
+ def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
167
+ component_mappings = {
168
+ ".attention.": ".self_attn.",
169
+ ".feed_forward.": ".mlp.",
170
+ ".attention_norm.": ".input_layernorm.",
171
+ ".ffn_norm.": ".post_attention_layernorm.",
172
+ ".tok_embeddings.": ".embed_tokens.",
173
+ ".cross_attn_norm_q.": ".q_norm.",
174
+ ".cross_attn_norm_kv.": ".k_norm.",
175
+ ".w1.": ".gate_proj.",
176
+ ".w2.": ".down_proj.",
177
+ ".w3.": ".up_proj.",
178
+ ".wq.": ".q_proj.",
179
+ ".wk.": ".k_proj.",
180
+ ".wv.": ".v_proj.",
181
+ ".wo.": ".o_proj.",
182
+ ".output.": ".lm_head.",
183
+ }
184
+
185
+ new_state_dict = {}
186
+
187
+ for old_key, tensor in state_dict.items():
188
+ new_key = old_key
189
+
190
+ for old_pattern, new_pattern in component_mappings.items():
191
+ if old_pattern in new_key:
192
+ new_key = new_key.replace(old_pattern, new_pattern)
193
+
194
+ new_state_dict[new_key] = tensor
195
+
196
+ return new_state_dict
197
+
198
+
199
+ def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]:
200
+ main_weights = load_file(weights_path)
201
+
202
+ entropy_weights = torch.load(entropy_weights_path, map_location="cpu", weights_only=True)
203
+
204
+ if "model" in entropy_weights:
205
+ entropy_weights = entropy_weights["model"]
206
+ elif "state_dict" in entropy_weights:
207
+ entropy_weights = entropy_weights["state_dict"]
208
+
209
+ unified_weights = main_weights.copy()
210
+
211
+ for key, tensor in entropy_weights.items():
212
+ patcher_key = f"patcher.{key}"
213
+ unified_weights[patcher_key] = tensor
214
+
215
+ unified_weights = apply_weight_mapping(unified_weights)
216
+
217
+ decoder_lm_head_key = "local_decoder.lm_head.weight"
218
+ top_lm_head_key = "lm_head.weight"
219
+ unified_weights[top_lm_head_key] = unified_weights[decoder_lm_head_key]
220
+ del unified_weights[decoder_lm_head_key]
221
+
222
+ prefixed_weights = {}
223
+ for key, tensor in unified_weights.items():
224
+ if key == top_lm_head_key:
225
+ prefixed_weights[key] = tensor
226
+ elif not key.startswith("model."):
227
+ prefixed_weights[f"model.{key}"] = tensor
228
+ else:
229
+ prefixed_weights[key] = tensor
230
+
231
+ unified_weights = prefixed_weights
232
+
233
+ return unified_weights
234
+
235
+
236
+ def create_tokenizer_config(output_dir: str, config: Dict[str, Any]):
237
+ tokenizer_config = {
238
+ "tokenizer_class": "BltTokenizer",
239
+ "vocab_size": config.get("vocab_size", 256),
240
+ "model_max_length": config.get("max_seqlen", 1024),
241
+ "add_bos_token": True,
242
+ "add_eos_token": True,
243
+ "bos_token": "<s>",
244
+ "eos_token": "</s>",
245
+ "pad_token": "<pad>",
246
+ "unk_token": "<unk>",
247
+ }
248
+
249
+ tokenizer_path = os.path.join(output_dir, "tokenizer_config.json")
250
+ with open(tokenizer_path, "w") as f:
251
+ json.dump(tokenizer_config, f, indent=2)
252
+
253
+
254
+ def push_to_hub(
255
+ local_dir: str,
256
+ repo_id: str,
257
+ commit_message: str = "Upload converted BLT model",
258
+ private: bool = False,
259
+ token: Optional[str] = None,
260
+ ) -> None:
261
+ try:
262
+ upload_folder(
263
+ folder_path=local_dir,
264
+ repo_id=repo_id,
265
+ commit_message=commit_message,
266
+ repo_type="model",
267
+ token=token,
268
+ )
269
+ logger.info(f"Successfully pushed model to {repo_id}")
270
+
271
+ except Exception as e:
272
+ logger.error(f"Failed to push model to Hub: {e}")
273
+ raise
274
+
275
+
276
+ def convert_hf_blt_to_unified(
277
+ model_id: str,
278
+ output_dir: str,
279
+ config_name: str = "config.json",
280
+ weights_name: str = "model.bin",
281
+ cache_dir: Optional[str] = None,
282
+ push_to_hub_repo: Optional[str] = None,
283
+ hub_private: bool = False,
284
+ hub_token: Optional[str] = None,
285
+ ) -> None:
286
+ # Download model files
287
+ config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir)
288
+ weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir)
289
+ entropy_params_path = hf_hub_download(repo_id=model_id, filename="entropy_model/params.json", cache_dir=cache_dir)
290
+ entropy_weights_path = hf_hub_download(
291
+ repo_id=model_id, filename="entropy_model/consolidated.pth", cache_dir=cache_dir
292
+ )
293
+
294
+ unified_config = merge_configurations(config_path, entropy_params_path)
295
+ unified_weights = merge_weights(weights_path, entropy_weights_path)
296
+
297
+ os.makedirs(output_dir, exist_ok=True)
298
+
299
+ config_path = os.path.join(output_dir, config_name)
300
+ with open(config_path, "w") as f:
301
+ json.dump(unified_config, f, indent=2)
302
+
303
+ if weights_name.endswith(".bin"):
304
+ weights_name = weights_name.replace(".bin", ".safetensors")
305
+
306
+ weights_path = os.path.join(output_dir, weights_name)
307
+ save_file(unified_weights, weights_path)
308
+
309
+ create_tokenizer_config(output_dir, unified_config)
310
+
311
+ logger.info(f"Conversion completed, model saved to: {output_dir}")
312
+
313
+ if push_to_hub_repo:
314
+ push_to_hub(
315
+ local_dir=output_dir,
316
+ repo_id=push_to_hub_repo,
317
+ commit_message="Upload BLT model converted",
318
+ private=hub_private,
319
+ token=hub_token,
320
+ )
321
+
322
+
323
+ def main():
324
+ parser = argparse.ArgumentParser(
325
+ description="Convert BLT models from HuggingFace Hub format to unified format",
326
+ formatter_class=argparse.RawDescriptionHelpFormatter,
327
+ )
328
+
329
+ parser.add_argument(
330
+ "--model_id",
331
+ type=str,
332
+ default="facebook/blt-1b",
333
+ )
334
+ parser.add_argument(
335
+ "--output_dir",
336
+ type=str,
337
+ default="./blt_converted",
338
+ )
339
+ parser.add_argument(
340
+ "--config_name",
341
+ type=str,
342
+ default="config.json",
343
+ )
344
+ parser.add_argument(
345
+ "--weights_name",
346
+ type=str,
347
+ default="model.bin",
348
+ )
349
+ parser.add_argument(
350
+ "--cache_dir",
351
+ type=str,
352
+ default=None,
353
+ )
354
+ parser.add_argument(
355
+ "--debug",
356
+ action="store_true",
357
+ default=True,
358
+ )
359
+ parser.add_argument(
360
+ "--push_to_hub",
361
+ type=str,
362
+ default=None,
363
+ )
364
+ parser.add_argument(
365
+ "--hub_private",
366
+ action="store_true",
367
+ default=False,
368
+ )
369
+ parser.add_argument(
370
+ "--hub_token",
371
+ type=str,
372
+ default="hf_token",
373
+ )
374
+
375
+ args = parser.parse_args()
376
+
377
+ transformers_logging.set_verbosity_debug()
378
+ logging.basicConfig(level=logging.DEBUG)
379
+
380
+ try:
381
+ convert_hf_blt_to_unified(
382
+ model_id=args.model_id,
383
+ output_dir=args.output_dir,
384
+ config_name=args.config_name,
385
+ weights_name=args.weights_name,
386
+ cache_dir=args.cache_dir,
387
+ push_to_hub_repo=args.push_to_hub,
388
+ hub_private=args.hub_private,
389
+ hub_token=args.hub_token,
390
+ )
391
+ except Exception as e:
392
+ logger.error(f"Conversion failed: {e}")
393
+ raise
394
+
395
+
396
+ if __name__ == "__main__":
397
+ main()
backup_blt_modellike/modeling_blt.py ADDED
@@ -0,0 +1,971 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ from typing import Callable, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+
26
+ from ...activations import ACT2FN
27
+ from ...cache_utils import Cache, DynamicCache
28
+ from ...generation import GenerationMixin
29
+ from ...modeling_attn_mask_utils import AttentionMaskConverter
30
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
31
+ from ...modeling_layers import GradientCheckpointingLayer
32
+ from ...modeling_outputs import (
33
+ BaseModelOutputWithPast,
34
+ CausalLMOutputWithPast,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutputWithPast,
37
+ TokenClassifierOutput,
38
+ )
39
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
40
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
41
+ from ...processing_utils import Unpack
42
+ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
43
+ from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
44
+ from .configuration_blt import BLTConfig
45
+
46
+
47
+ if is_torch_flex_attn_available():
48
+ from torch.nn.attention.flex_attention import BlockMask
49
+
50
+ from ...integrations.flex_attention import make_flex_block_causal_mask
51
+
52
+ from ...integrations import use_kernel_forward_from_hub
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+
58
+ @use_kernel_forward_from_hub("RMSNorm")
59
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->BLT
60
+ class BLTRMSNorm(nn.Module):
61
+ def __init__(self, hidden_size, eps=1e-6):
62
+ """
63
+ BLTRMSNorm is equivalent to T5LayerNorm
64
+ """
65
+ super().__init__()
66
+ self.weight = nn.Parameter(torch.ones(hidden_size))
67
+ self.variance_epsilon = eps
68
+
69
+ def forward(self, hidden_states):
70
+ input_dtype = hidden_states.dtype
71
+ hidden_states = hidden_states.to(torch.float32)
72
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
73
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
74
+ return self.weight * hidden_states.to(input_dtype)
75
+
76
+ def extra_repr(self):
77
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
78
+
79
+
80
+ ALL_LAYERNORM_LAYERS.append(BLTRMSNorm)
81
+
82
+
83
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->BLT
84
+ class BLTRotaryEmbedding(nn.Module):
85
+ def __init__(self, config: BLTConfig, device=None):
86
+ super().__init__()
87
+ # BC: "rope_type" was originally "type"
88
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
89
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
90
+ else:
91
+ self.rope_type = "default"
92
+ self.max_seq_len_cached = config.max_position_embeddings
93
+ self.original_max_seq_len = config.max_position_embeddings
94
+
95
+ self.config = config
96
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
97
+
98
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
99
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
100
+ self.original_inv_freq = self.inv_freq
101
+
102
+ @torch.no_grad()
103
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
104
+ def forward(self, x, position_ids):
105
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
106
+ position_ids_expanded = position_ids[:, None, :].float()
107
+
108
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
109
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
110
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
111
+ emb = torch.cat((freqs, freqs), dim=-1)
112
+ cos = emb.cos() * self.attention_scaling
113
+ sin = emb.sin() * self.attention_scaling
114
+
115
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
116
+
117
+
118
+ def rotate_half(x):
119
+ """Rotates half the hidden dims of the input."""
120
+ x1 = x[..., : x.shape[-1] // 2]
121
+ x2 = x[..., x.shape[-1] // 2 :]
122
+ return torch.cat((-x2, x1), dim=-1)
123
+
124
+
125
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
126
+ """Applies Rotary Position Embedding to the query and key tensors.
127
+
128
+ Args:
129
+ q (`torch.Tensor`): The query tensor.
130
+ k (`torch.Tensor`): The key tensor.
131
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
132
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
133
+ position_ids (`torch.Tensor`, *optional*):
134
+ Deprecated and unused.
135
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
136
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
137
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
138
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
139
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
140
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
141
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
142
+ Returns:
143
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
144
+ """
145
+ cos = cos.unsqueeze(unsqueeze_dim)
146
+ sin = sin.unsqueeze(unsqueeze_dim)
147
+ q_embed = (q * cos) + (rotate_half(q) * sin)
148
+ k_embed = (k * cos) + (rotate_half(k) * sin)
149
+ return q_embed, k_embed
150
+
151
+
152
+ # Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->BLT
153
+ class BLTMLP(nn.Module):
154
+ def __init__(self, config):
155
+ super().__init__()
156
+ self.config = config
157
+ self.hidden_size = config.hidden_size
158
+ self.intermediate_size = config.intermediate_size
159
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
160
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
161
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
162
+ self.act_fn = ACT2FN[config.hidden_act]
163
+
164
+ def forward(self, x):
165
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
166
+ return down_proj
167
+
168
+
169
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
170
+ """
171
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
172
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
173
+ """
174
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
175
+ if n_rep == 1:
176
+ return hidden_states
177
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
178
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
179
+
180
+
181
+ def eager_attention_forward(
182
+ module: nn.Module,
183
+ query: torch.Tensor,
184
+ key: torch.Tensor,
185
+ value: torch.Tensor,
186
+ attention_mask: Optional[torch.Tensor],
187
+ scaling: float,
188
+ dropout: float = 0.0,
189
+ **kwargs,
190
+ ):
191
+ key_states = repeat_kv(key, module.num_key_value_groups)
192
+ value_states = repeat_kv(value, module.num_key_value_groups)
193
+
194
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
195
+ if attention_mask is not None:
196
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
197
+ attn_weights = attn_weights + causal_mask
198
+
199
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
200
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
201
+ attn_output = torch.matmul(attn_weights, value_states)
202
+ attn_output = attn_output.transpose(1, 2).contiguous()
203
+
204
+ return attn_output, attn_weights
205
+
206
+
207
+ # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->BLT
208
+ class BLTAttention(nn.Module):
209
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
210
+
211
+ def __init__(self, config: BLTConfig, layer_idx: int):
212
+ super().__init__()
213
+ self.config = config
214
+ self.layer_idx = layer_idx
215
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
216
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
217
+ self.scaling = self.head_dim**-0.5
218
+ self.attention_dropout = config.attention_dropout
219
+ self.is_causal = True
220
+
221
+ self.q_proj = nn.Linear(
222
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
223
+ )
224
+ self.k_proj = nn.Linear(
225
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
226
+ )
227
+ self.v_proj = nn.Linear(
228
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
229
+ )
230
+ self.o_proj = nn.Linear(
231
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
232
+ )
233
+
234
+ def forward(
235
+ self,
236
+ hidden_states: torch.Tensor,
237
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
238
+ attention_mask: Optional[torch.Tensor],
239
+ past_key_value: Optional[Cache] = None,
240
+ cache_position: Optional[torch.LongTensor] = None,
241
+ **kwargs: Unpack[FlashAttentionKwargs],
242
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
243
+ input_shape = hidden_states.shape[:-1]
244
+ hidden_shape = (*input_shape, -1, self.head_dim)
245
+
246
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
247
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
248
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
249
+
250
+ cos, sin = position_embeddings
251
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
252
+
253
+ if past_key_value is not None:
254
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
255
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
256
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
257
+
258
+ attention_interface: Callable = eager_attention_forward
259
+
260
+ if self.config._attn_implementation != "eager":
261
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
262
+ logger.warning_once(
263
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
264
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
265
+ )
266
+ else:
267
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
268
+
269
+ attn_output, attn_weights = attention_interface(
270
+ self,
271
+ query_states,
272
+ key_states,
273
+ value_states,
274
+ attention_mask,
275
+ dropout=0.0 if not self.training else self.attention_dropout,
276
+ scaling=self.scaling,
277
+ **kwargs,
278
+ )
279
+
280
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
281
+ attn_output = self.o_proj(attn_output)
282
+ return attn_output, attn_weights
283
+
284
+
285
+ # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->BLT
286
+ class BLTDecoderLayer(GradientCheckpointingLayer):
287
+ def __init__(self, config: BLTConfig, layer_idx: int):
288
+ super().__init__()
289
+ self.hidden_size = config.hidden_size
290
+
291
+ self.self_attn = BLTAttention(config=config, layer_idx=layer_idx)
292
+
293
+ self.mlp = BLTMLP(config)
294
+ self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
295
+ self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
296
+
297
+ def forward(
298
+ self,
299
+ hidden_states: torch.Tensor,
300
+ attention_mask: Optional[torch.Tensor] = None,
301
+ position_ids: Optional[torch.LongTensor] = None,
302
+ past_key_value: Optional[Cache] = None,
303
+ output_attentions: Optional[bool] = False,
304
+ use_cache: Optional[bool] = False,
305
+ cache_position: Optional[torch.LongTensor] = None,
306
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
307
+ **kwargs: Unpack[FlashAttentionKwargs],
308
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
309
+ residual = hidden_states
310
+ hidden_states = self.input_layernorm(hidden_states)
311
+
312
+ # Self Attention
313
+ hidden_states, self_attn_weights = self.self_attn(
314
+ hidden_states=hidden_states,
315
+ attention_mask=attention_mask,
316
+ position_ids=position_ids,
317
+ past_key_value=past_key_value,
318
+ output_attentions=output_attentions,
319
+ use_cache=use_cache,
320
+ cache_position=cache_position,
321
+ position_embeddings=position_embeddings,
322
+ **kwargs,
323
+ )
324
+ hidden_states = residual + hidden_states
325
+
326
+ # Fully Connected
327
+ residual = hidden_states
328
+ hidden_states = self.post_attention_layernorm(hidden_states)
329
+ hidden_states = self.mlp(hidden_states)
330
+ hidden_states = residual + hidden_states
331
+
332
+ outputs = (hidden_states,)
333
+ if output_attentions:
334
+ outputs += (self_attn_weights,)
335
+
336
+ return outputs
337
+
338
+
339
+ @auto_docstring
340
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->BLT
341
+ class BLTPreTrainedModel(PreTrainedModel):
342
+ config_class = BLTConfig
343
+ base_model_prefix = "model"
344
+ supports_gradient_checkpointing = True
345
+ _no_split_modules = ["BLTDecoderLayer"]
346
+ _skip_keys_device_placement = ["past_key_values"]
347
+ _supports_flash_attn_2 = True
348
+ _supports_sdpa = True
349
+ _supports_flex_attn = True
350
+ _supports_cache_class = True
351
+ _supports_quantized_cache = True
352
+ _supports_static_cache = True
353
+ _supports_attention_backend = True
354
+
355
+ def _init_weights(self, module):
356
+ std = self.config.initializer_range
357
+ if isinstance(module, nn.Linear):
358
+ module.weight.data.normal_(mean=0.0, std=std)
359
+ if module.bias is not None:
360
+ module.bias.data.zero_()
361
+ elif isinstance(module, nn.Embedding):
362
+ module.weight.data.normal_(mean=0.0, std=std)
363
+ if module.padding_idx is not None:
364
+ module.weight.data[module.padding_idx].zero_()
365
+ elif isinstance(module, BLTRMSNorm):
366
+ module.weight.data.fill_(1.0)
367
+
368
+
369
+ @auto_docstring
370
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->BLT
371
+ class BLTModel(BLTPreTrainedModel):
372
+ def __init__(self, config: BLTConfig):
373
+ super().__init__(config)
374
+ self.padding_idx = config.pad_token_id
375
+ self.vocab_size = config.vocab_size
376
+
377
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
378
+ self.layers = nn.ModuleList(
379
+ [BLTDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
380
+ )
381
+ self.norm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
382
+ self.rotary_emb = BLTRotaryEmbedding(config=config)
383
+ self.gradient_checkpointing = False
384
+
385
+ # Initialize weights and apply final processing
386
+ self.post_init()
387
+
388
+ def get_input_embeddings(self):
389
+ return self.embed_tokens
390
+
391
+ def set_input_embeddings(self, value):
392
+ self.embed_tokens = value
393
+
394
+ @can_return_tuple
395
+ @auto_docstring
396
+ def forward(
397
+ self,
398
+ input_ids: Optional[torch.LongTensor] = None,
399
+ attention_mask: Optional[torch.Tensor] = None,
400
+ position_ids: Optional[torch.LongTensor] = None,
401
+ past_key_values: Optional[Cache] = None,
402
+ inputs_embeds: Optional[torch.FloatTensor] = None,
403
+ use_cache: Optional[bool] = None,
404
+ output_attentions: Optional[bool] = None,
405
+ output_hidden_states: Optional[bool] = None,
406
+ cache_position: Optional[torch.LongTensor] = None,
407
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
408
+ ) -> BaseModelOutputWithPast:
409
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
410
+ output_hidden_states = (
411
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
412
+ )
413
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
414
+
415
+ if (input_ids is None) ^ (inputs_embeds is not None):
416
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
417
+
418
+ if self.gradient_checkpointing and self.training and use_cache:
419
+ logger.warning_once(
420
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
421
+ )
422
+ use_cache = False
423
+
424
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
425
+ if not isinstance(past_key_values, (type(None), Cache)):
426
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
427
+
428
+ if inputs_embeds is None:
429
+ inputs_embeds = self.embed_tokens(input_ids)
430
+
431
+ if use_cache and past_key_values is None:
432
+ past_key_values = DynamicCache()
433
+
434
+ if cache_position is None:
435
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
436
+ cache_position = torch.arange(
437
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
438
+ )
439
+
440
+ if position_ids is None:
441
+ position_ids = cache_position.unsqueeze(0)
442
+
443
+ causal_mask = self._update_causal_mask(
444
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
445
+ )
446
+
447
+ hidden_states = inputs_embeds
448
+
449
+ # create position embeddings to be shared across the decoder layers
450
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
451
+
452
+ # decoder layers
453
+ all_hidden_states = () if output_hidden_states else None
454
+ all_self_attns = () if output_attentions else None
455
+
456
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
457
+ if output_hidden_states:
458
+ all_hidden_states += (hidden_states,)
459
+
460
+ layer_outputs = decoder_layer(
461
+ hidden_states,
462
+ attention_mask=causal_mask,
463
+ position_ids=position_ids,
464
+ past_key_value=past_key_values,
465
+ output_attentions=output_attentions,
466
+ use_cache=use_cache,
467
+ cache_position=cache_position,
468
+ position_embeddings=position_embeddings,
469
+ **flash_attn_kwargs,
470
+ )
471
+
472
+ hidden_states = layer_outputs[0]
473
+
474
+ if output_attentions:
475
+ all_self_attns += (layer_outputs[1],)
476
+
477
+ hidden_states = self.norm(hidden_states)
478
+
479
+ # add hidden states from the last decoder layer
480
+ if output_hidden_states:
481
+ all_hidden_states += (hidden_states,)
482
+
483
+ return BaseModelOutputWithPast(
484
+ last_hidden_state=hidden_states,
485
+ past_key_values=past_key_values if use_cache else None,
486
+ hidden_states=all_hidden_states,
487
+ attentions=all_self_attns,
488
+ )
489
+
490
+ def _update_causal_mask(
491
+ self,
492
+ attention_mask: Union[torch.Tensor, "BlockMask"],
493
+ input_tensor: torch.Tensor,
494
+ cache_position: torch.Tensor,
495
+ past_key_values: Cache,
496
+ output_attentions: bool = False,
497
+ ):
498
+ if self.config._attn_implementation == "flash_attention_2":
499
+ if attention_mask is not None and (attention_mask == 0.0).any():
500
+ return attention_mask
501
+ return None
502
+ if self.config._attn_implementation == "flex_attention":
503
+ if isinstance(attention_mask, torch.Tensor):
504
+ attention_mask = make_flex_block_causal_mask(attention_mask)
505
+ return attention_mask
506
+
507
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
508
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
509
+ # to infer the attention mask.
510
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
511
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
512
+
513
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
514
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
515
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
516
+ attention_mask,
517
+ inputs_embeds=input_tensor,
518
+ past_key_values_length=past_seen_tokens,
519
+ is_training=self.training,
520
+ ):
521
+ return None
522
+
523
+ dtype = input_tensor.dtype
524
+ sequence_length = input_tensor.shape[1]
525
+ if using_compilable_cache:
526
+ target_length = past_key_values.get_max_cache_shape()
527
+ else:
528
+ target_length = (
529
+ attention_mask.shape[-1]
530
+ if isinstance(attention_mask, torch.Tensor)
531
+ else past_seen_tokens + sequence_length + 1
532
+ )
533
+
534
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
535
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
536
+ attention_mask,
537
+ sequence_length=sequence_length,
538
+ target_length=target_length,
539
+ dtype=dtype,
540
+ cache_position=cache_position,
541
+ batch_size=input_tensor.shape[0],
542
+ )
543
+
544
+ if (
545
+ self.config._attn_implementation == "sdpa"
546
+ and attention_mask is not None
547
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
548
+ and not output_attentions
549
+ ):
550
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
551
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
552
+ # Details: https://github.com/pytorch/pytorch/issues/110213
553
+ min_dtype = torch.finfo(dtype).min
554
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
555
+
556
+ return causal_mask
557
+
558
+ @staticmethod
559
+ def _prepare_4d_causal_attention_mask_with_cache_position(
560
+ attention_mask: torch.Tensor,
561
+ sequence_length: int,
562
+ target_length: int,
563
+ dtype: torch.dtype,
564
+ cache_position: torch.Tensor,
565
+ batch_size: int,
566
+ **kwargs,
567
+ ):
568
+ """
569
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
570
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
571
+
572
+ Args:
573
+ attention_mask (`torch.Tensor`):
574
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
575
+ `(batch_size, 1, query_length, key_value_length)`.
576
+ sequence_length (`int`):
577
+ The sequence length being processed.
578
+ target_length (`int`):
579
+ The target length: when generating with static cache, the mask should be as long as the static cache,
580
+ to account for the 0 padding, the part of the cache that is not filled yet.
581
+ dtype (`torch.dtype`):
582
+ The dtype to use for the 4D attention mask.
583
+ cache_position (`torch.Tensor`):
584
+ Indices depicting the position of the input sequence tokens in the sequence.
585
+ batch_size (`int`):
586
+ Batch size.
587
+ """
588
+ if attention_mask is not None and attention_mask.dim() == 4:
589
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
590
+ causal_mask = attention_mask
591
+ else:
592
+ min_dtype = torch.finfo(dtype).min
593
+ causal_mask = torch.full(
594
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
595
+ )
596
+ if sequence_length != 1:
597
+ causal_mask = torch.triu(causal_mask, diagonal=1)
598
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
599
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
600
+ if attention_mask is not None:
601
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
602
+ mask_length = attention_mask.shape[-1]
603
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
604
+ causal_mask.device
605
+ )
606
+ padding_mask = padding_mask == 0
607
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
608
+ padding_mask, min_dtype
609
+ )
610
+
611
+ return causal_mask
612
+
613
+
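For reference, a minimal standalone sketch (not part of the converted files) of the masking recipe implemented by `_prepare_4d_causal_attention_mask_with_cache_position` above, using toy shapes and assuming no static cache and no 2D padding mask:

```python
import torch

# Toy setup: 3 new query tokens appended after 2 cached tokens -> 5 key/value slots.
sequence_length, target_length, dtype = 3, 5, torch.float32
cache_position = torch.arange(2, 2 + sequence_length)  # absolute positions 2, 3, 4
min_dtype = torch.finfo(dtype).min

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :]  # (batch=1, 1, query_len, kv_len)

# 1 = attendable, 0 = masked: row i may attend to every key position <= cache_position[i].
print((causal_mask == 0).int().squeeze())
```

With a 2D padding mask, the method above additionally folds the inverted padding mask into the first `mask_length` columns, as shown in the diff.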
614
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
615
+
616
+
617
+ @auto_docstring
618
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->BLT,llama->blt
619
+ class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin):
620
+ _tied_weights_keys = ["lm_head.weight"]
621
+ _tp_plan = {"lm_head": "colwise_rep"}
622
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
623
+
624
+ def __init__(self, config):
625
+ super().__init__(config)
626
+ self.model = BLTModel(config)
627
+ self.vocab_size = config.vocab_size
628
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
629
+
630
+ # Initialize weights and apply final processing
631
+ self.post_init()
632
+
633
+ def get_input_embeddings(self):
634
+ return self.model.embed_tokens
635
+
636
+ def set_input_embeddings(self, value):
637
+ self.model.embed_tokens = value
638
+
639
+ def get_output_embeddings(self):
640
+ return self.lm_head
641
+
642
+ def set_output_embeddings(self, new_embeddings):
643
+ self.lm_head = new_embeddings
644
+
645
+ def set_decoder(self, decoder):
646
+ self.model = decoder
647
+
648
+ def get_decoder(self):
649
+ return self.model
650
+
651
+ @can_return_tuple
652
+ @auto_docstring
653
+ def forward(
654
+ self,
655
+ input_ids: Optional[torch.LongTensor] = None,
656
+ attention_mask: Optional[torch.Tensor] = None,
657
+ position_ids: Optional[torch.LongTensor] = None,
658
+ past_key_values: Optional[Cache] = None,
659
+ inputs_embeds: Optional[torch.FloatTensor] = None,
660
+ labels: Optional[torch.LongTensor] = None,
661
+ use_cache: Optional[bool] = None,
662
+ output_attentions: Optional[bool] = None,
663
+ output_hidden_states: Optional[bool] = None,
664
+ cache_position: Optional[torch.LongTensor] = None,
665
+ logits_to_keep: Union[int, torch.Tensor] = 0,
666
+ **kwargs: Unpack[KwargsForCausalLM],
667
+ ) -> CausalLMOutputWithPast:
668
+ r"""
669
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
670
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
671
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
672
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
673
+
674
+ Example:
675
+
676
+ ```python
677
+ >>> from transformers import AutoTokenizer, BLTForCausalLM
678
+
679
+ >>> model = BLTForCausalLM.from_pretrained("meta-blt/BLT-2-7b-hf")
680
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-blt/BLT-2-7b-hf")
681
+
682
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
683
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
684
+
685
+ >>> # Generate
686
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
687
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
688
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
689
+ ```"""
690
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
691
+ output_hidden_states = (
692
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
693
+ )
694
+
695
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
696
+ outputs: BaseModelOutputWithPast = self.model(
697
+ input_ids=input_ids,
698
+ attention_mask=attention_mask,
699
+ position_ids=position_ids,
700
+ past_key_values=past_key_values,
701
+ inputs_embeds=inputs_embeds,
702
+ use_cache=use_cache,
703
+ output_attentions=output_attentions,
704
+ output_hidden_states=output_hidden_states,
705
+ cache_position=cache_position,
706
+ **kwargs,
707
+ )
708
+
709
+ hidden_states = outputs.last_hidden_state
710
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
711
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
712
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
713
+
714
+ loss = None
715
+ if labels is not None:
716
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
717
+
718
+ return CausalLMOutputWithPast(
719
+ loss=loss,
720
+ logits=logits,
721
+ past_key_values=outputs.past_key_values,
722
+ hidden_states=outputs.hidden_states,
723
+ attentions=outputs.attentions,
724
+ )
725
+
726
+
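As a side note (illustrative only, not from the commit), the `logits_to_keep` handling in the forward above can be pictured with toy tensors; the hidden size, vocab size, and values below are made up:

```python
import torch

hidden_states = torch.randn(2, 10, 512)          # (batch, seq_len, hidden_size), dummy values
lm_head = torch.nn.Linear(512, 256, bias=False)  # toy byte-level vocab of 256

logits_to_keep = 1                               # keep only the last position (enough for decoding)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])
print(logits.shape)  # torch.Size([2, 1, 256]); logits_to_keep=0 would keep all 10 positions
```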
727
+ @auto_docstring(
728
+ custom_intro="""
729
+ The BLT Model transformer with a sequence classification head on top (linear layer).
730
+
731
+ [`BLTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
732
+ (e.g. GPT-2) do.
733
+
734
+ Since it does classification on the last token, it requires to know the position of the last token. If a
735
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
736
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
737
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
738
+ each row of the batch).
739
+ """
740
+ )
741
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->BLT
742
+ class BLTForSequenceClassification(BLTPreTrainedModel):
743
+ def __init__(self, config):
744
+ super().__init__(config)
745
+ self.num_labels = config.num_labels
746
+ self.model = BLTModel(config)
747
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
748
+
749
+ # Initialize weights and apply final processing
750
+ self.post_init()
751
+
752
+ def get_input_embeddings(self):
753
+ return self.model.embed_tokens
754
+
755
+ def set_input_embeddings(self, value):
756
+ self.model.embed_tokens = value
757
+
758
+ @can_return_tuple
759
+ @auto_docstring
760
+ def forward(
761
+ self,
762
+ input_ids: Optional[torch.LongTensor] = None,
763
+ attention_mask: Optional[torch.Tensor] = None,
764
+ position_ids: Optional[torch.LongTensor] = None,
765
+ past_key_values: Optional[Cache] = None,
766
+ inputs_embeds: Optional[torch.FloatTensor] = None,
767
+ labels: Optional[torch.LongTensor] = None,
768
+ use_cache: Optional[bool] = None,
769
+ output_attentions: Optional[bool] = None,
770
+ output_hidden_states: Optional[bool] = None,
771
+ ) -> SequenceClassifierOutputWithPast:
772
+ r"""
773
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
774
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
775
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
776
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
777
+ """
778
+
779
+ transformer_outputs: BaseModelOutputWithPast = self.model(
780
+ input_ids,
781
+ attention_mask=attention_mask,
782
+ position_ids=position_ids,
783
+ past_key_values=past_key_values,
784
+ inputs_embeds=inputs_embeds,
785
+ use_cache=use_cache,
786
+ output_attentions=output_attentions,
787
+ output_hidden_states=output_hidden_states,
788
+ )
789
+ hidden_states = transformer_outputs.last_hidden_state
790
+ logits = self.score(hidden_states)
791
+
792
+ if input_ids is not None:
793
+ batch_size = input_ids.shape[0]
794
+ else:
795
+ batch_size = inputs_embeds.shape[0]
796
+
797
+ if self.config.pad_token_id is None and batch_size != 1:
798
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
799
+ if self.config.pad_token_id is None:
800
+ last_non_pad_token = -1
801
+ elif input_ids is not None:
802
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
803
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
804
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
805
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
806
+ else:
807
+ last_non_pad_token = -1
808
+ logger.warning_once(
809
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
810
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
811
+ )
812
+
813
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
814
+
815
+ loss = None
816
+ if labels is not None:
817
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
818
+
819
+ return SequenceClassifierOutputWithPast(
820
+ loss=loss,
821
+ logits=pooled_logits,
822
+ past_key_values=transformer_outputs.past_key_values,
823
+ hidden_states=transformer_outputs.hidden_states,
824
+ attentions=transformer_outputs.attentions,
825
+ )
826
+
827
+
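The right-most non-pad-token selection above handles both left and right padding; here is a tiny self-contained sketch (hypothetical ids, assuming `pad_token_id=0`) of the same trick:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor(
    [[5, 7, 9, 0, 0],   # right padding -> last real token at index 2
     [0, 0, 3, 4, 6]]   # left padding  -> last real token at index 4
)

non_pad_mask = (input_ids != pad_token_id).to(torch.int32)
token_indices = torch.arange(input_ids.shape[-1], dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([2, 4])
```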
828
+ @auto_docstring
829
+ # Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->BLT
830
+ class BLTForQuestionAnswering(BLTPreTrainedModel):
831
+ base_model_prefix = "transformer"
832
+
833
+ def __init__(self, config):
834
+ super().__init__(config)
835
+ self.transformer = BLTModel(config)
836
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
837
+
838
+ # Initialize weights and apply final processing
839
+ self.post_init()
840
+
841
+ def get_input_embeddings(self):
842
+ return self.transformer.embed_tokens
843
+
844
+ def set_input_embeddings(self, value):
845
+ self.transformer.embed_tokens = value
846
+
847
+ @can_return_tuple
848
+ @auto_docstring
849
+ def forward(
850
+ self,
851
+ input_ids: Optional[torch.LongTensor] = None,
852
+ attention_mask: Optional[torch.Tensor] = None,
853
+ position_ids: Optional[torch.LongTensor] = None,
854
+ past_key_values: Optional[Cache] = None,
855
+ inputs_embeds: Optional[torch.FloatTensor] = None,
856
+ start_positions: Optional[torch.LongTensor] = None,
857
+ end_positions: Optional[torch.LongTensor] = None,
858
+ output_attentions: Optional[bool] = None,
859
+ output_hidden_states: Optional[bool] = None,
860
+ **kwargs,
861
+ ) -> QuestionAnsweringModelOutput:
862
+ outputs: BaseModelOutputWithPast = self.transformer(
863
+ input_ids,
864
+ attention_mask=attention_mask,
865
+ position_ids=position_ids,
866
+ past_key_values=past_key_values,
867
+ inputs_embeds=inputs_embeds,
868
+ output_attentions=output_attentions,
869
+ output_hidden_states=output_hidden_states,
870
+ )
871
+
872
+ sequence_output = outputs.last_hidden_state
873
+
874
+ logits = self.qa_outputs(sequence_output)
875
+ start_logits, end_logits = logits.split(1, dim=-1)
876
+ start_logits = start_logits.squeeze(-1).contiguous()
877
+ end_logits = end_logits.squeeze(-1).contiguous()
878
+
879
+ loss = None
880
+ if start_positions is not None and end_positions is not None:
881
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
882
+
883
+ return QuestionAnsweringModelOutput(
884
+ loss=loss,
885
+ start_logits=start_logits,
886
+ end_logits=end_logits,
887
+ hidden_states=outputs.hidden_states,
888
+ attentions=outputs.attentions,
889
+ )
890
+
891
+
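For completeness, a toy illustration (not part of the commit) of how the 2-unit QA head output is split into start/end logits as in the forward above:

```python
import torch

sequence_output = torch.randn(2, 8, 512)             # (batch, seq_len, hidden_size), dummy values
qa_outputs = torch.nn.Linear(512, 2)

logits = qa_outputs(sequence_output)                  # (2, 8, 2)
start_logits, end_logits = logits.split(1, dim=-1)    # two (2, 8, 1) tensors
start_logits = start_logits.squeeze(-1).contiguous()  # (2, 8)
end_logits = end_logits.squeeze(-1).contiguous()      # (2, 8)
print(start_logits.shape, end_logits.shape)
```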
892
+ @auto_docstring
893
+ # Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->BLT
894
+ class BLTForTokenClassification(BLTPreTrainedModel):
895
+ def __init__(self, config):
896
+ super().__init__(config)
897
+ self.num_labels = config.num_labels
898
+ self.model = BLTModel(config)
899
+ if getattr(config, "classifier_dropout", None) is not None:
900
+ classifier_dropout = config.classifier_dropout
901
+ elif getattr(config, "hidden_dropout", None) is not None:
902
+ classifier_dropout = config.hidden_dropout
903
+ else:
904
+ classifier_dropout = 0.1
905
+ self.dropout = nn.Dropout(classifier_dropout)
906
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
907
+
908
+ # Initialize weights and apply final processing
909
+ self.post_init()
910
+
911
+ def get_input_embeddings(self):
912
+ return self.model.embed_tokens
913
+
914
+ def set_input_embeddings(self, value):
915
+ self.model.embed_tokens = value
916
+
917
+ @can_return_tuple
918
+ @auto_docstring
919
+ def forward(
920
+ self,
921
+ input_ids: Optional[torch.LongTensor] = None,
922
+ attention_mask: Optional[torch.Tensor] = None,
923
+ position_ids: Optional[torch.LongTensor] = None,
924
+ past_key_values: Optional[Cache] = None,
925
+ inputs_embeds: Optional[torch.FloatTensor] = None,
926
+ labels: Optional[torch.LongTensor] = None,
927
+ use_cache: Optional[bool] = None,
928
+ output_attentions: Optional[bool] = None,
929
+ output_hidden_states: Optional[bool] = None,
930
+ ) -> TokenClassifierOutput:
931
+ r"""
932
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
933
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
934
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
935
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
936
+ """
937
+
938
+ outputs: BaseModelOutputWithPast = self.model(
939
+ input_ids,
940
+ attention_mask=attention_mask,
941
+ position_ids=position_ids,
942
+ past_key_values=past_key_values,
943
+ inputs_embeds=inputs_embeds,
944
+ use_cache=use_cache,
945
+ output_attentions=output_attentions,
946
+ output_hidden_states=output_hidden_states,
947
+ )
948
+ sequence_output = outputs.last_hidden_state
949
+ sequence_output = self.dropout(sequence_output)
950
+ logits = self.score(sequence_output)
951
+
952
+ loss = None
953
+ if labels is not None:
954
+ loss = self.loss_function(logits, labels, self.config)
955
+
956
+ return TokenClassifierOutput(
957
+ loss=loss,
958
+ logits=logits,
959
+ hidden_states=outputs.hidden_states,
960
+ attentions=outputs.attentions,
961
+ )
962
+
963
+
964
+ __all__ = [
965
+ "BLTForCausalLM",
966
+ "BLTModel",
967
+ "BLTPreTrainedModel",
968
+ "BLTForSequenceClassification",
969
+ "BLTForQuestionAnswering",
970
+ "BLTForTokenClassification",
971
+ ]
backup_blt_modellike/tokenization_blt.py ADDED
@@ -0,0 +1,412 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ """Tokenization classes for BLT."""
22
+
23
+ import os
24
+ from shutil import copyfile
25
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
26
+
27
+ import sentencepiece as spm
28
+
29
+ from ...convert_slow_tokenizer import import_protobuf
30
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
31
+ from ...utils import logging
32
+ from ...utils.import_utils import requires
33
+
34
+
35
+ if TYPE_CHECKING:
36
+ from ...tokenization_utils_base import TextInput
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
41
+
42
+ SPIECE_UNDERLINE = "▁"
43
+
44
+ B_INST, E_INST = "[INST]", "[/INST]"
45
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
46
+
47
+ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
48
+ answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
49
+ that your responses are socially unbiased and positive in nature.
50
+
51
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
52
+ correct. If you don't know the answer to a question, please don't share false information.""" # fmt: skip
53
+
54
+
55
+ @requires(backends=("sentencepiece",))
56
+ class BLTTokenizer(PreTrainedTokenizer):
57
+ """
58
+ Construct a BLT tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
59
+ no padding token in the original model.
60
+
61
+ Args:
62
+ vocab_file (`str`):
63
+ Path to the vocabulary file.
64
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
65
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
66
+ token instead.
67
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
68
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
69
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
70
+ The end of sequence token.
71
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*):
72
+ A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
73
+ attention mechanisms or loss computation.
74
+ sp_model_kwargs (`Dict[str, Any]`, *optional*):
75
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
76
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
77
+ to set:
78
+
79
+ - `enable_sampling`: Enable subword regularization.
80
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
81
+
82
+ - `nbest_size = {0,1}`: No sampling is performed.
83
+ - `nbest_size > 1`: samples from the nbest_size results.
84
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
85
+ using forward-filtering-and-backward-sampling algorithm.
86
+
87
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
88
+ BPE-dropout.
89
+
90
+ add_bos_token (`bool`, *optional*, defaults to `True`):
91
+ Whether or not to add a `bos_token` at the start of sequences.
92
+ add_eos_token (`bool`, *optional*, defaults to `False`):
93
+ Whether or not to add an `eos_token` at the end of sequences.
94
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
95
+ Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
96
+ extra spaces.
97
+ use_default_system_prompt (`bool`, *optional*, defaults to `False`):
98
+ Whether or not the default system prompt for BLT should be used.
99
+ spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
100
+ Whether or not to add spaces between special tokens.
101
+ legacy (`bool`, *optional*):
102
+ Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
103
+ and #25224 which includes fixes to properly handle tokens that appear after special tokens.
104
+ Make sure to also set `from_slow` to `True`.
105
+ A simple example:
106
+
107
+ - `legacy=True`:
108
+ ```python
109
+ >>> from transformers import BLTTokenizerFast
110
+
111
+ >>> tokenizer = BLTTokenizerFast.from_pretrained("huggyblt/blt-7b", legacy=True, from_slow=True)
112
+ >>> tokenizer.encode("Hello <s>.") # 869 is '▁.'
113
+ [1, 15043, 29871, 1, 869]
114
+ ```
115
+ - `legacy=False`:
116
+ ```python
117
+ >>> from transformers import BLTTokenizerFast
118
+
119
+ >>> tokenizer = BLTTokenizerFast.from_pretrained("huggyblt/blt-7b", legacy=False, from_slow=True)
120
+ >>> tokenizer.encode("Hello <s>.") # 29889 is '.'
121
+ [1, 15043, 29871, 1, 29889]
122
+ ```
123
+ Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
124
+ add_prefix_space (`bool`, *optional*, defaults to `True`):
125
+ Whether or not to add an initial space to the input. This lets the leading word be treated just like any
126
+ other word. Again, this should be set with `from_slow=True` to make sure it's taken into account.
127
+ """
128
+
129
+ vocab_files_names = VOCAB_FILES_NAMES
130
+ model_input_names = ["input_ids", "attention_mask"]
131
+
132
+ def __init__(
133
+ self,
134
+ vocab_file,
135
+ unk_token="<unk>",
136
+ bos_token="<s>",
137
+ eos_token="</s>",
138
+ pad_token=None,
139
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
140
+ add_bos_token=True,
141
+ add_eos_token=False,
142
+ clean_up_tokenization_spaces=False,
143
+ use_default_system_prompt=False,
144
+ spaces_between_special_tokens=False,
145
+ legacy=None,
146
+ add_prefix_space=True,
147
+ **kwargs,
148
+ ):
149
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
150
+ bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
151
+ eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
152
+ unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
153
+ pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
154
+
155
+ if legacy is None:
156
+ logger.warning_once(
157
+ f"You are using the default legacy behaviour of the {self.__class__}. This is"
158
+ " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
159
+ " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
160
+ " means, and thoroughly read the reason why this was added as explained in"
161
+ " https://github.com/huggingface/transformers/pull/24565 - if you loaded a blt tokenizer from a GGUF file"
162
+ " you can ignore this message"
163
+ )
164
+ legacy = True
165
+
166
+ self.legacy = legacy
167
+ self.vocab_file = vocab_file
168
+ self.add_bos_token = add_bos_token
169
+ self.add_eos_token = add_eos_token
170
+ self.use_default_system_prompt = use_default_system_prompt
171
+ self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
172
+ self.add_prefix_space = add_prefix_space
173
+
174
+ super().__init__(
175
+ bos_token=bos_token,
176
+ eos_token=eos_token,
177
+ unk_token=unk_token,
178
+ pad_token=pad_token,
179
+ add_bos_token=add_bos_token,
180
+ add_eos_token=add_eos_token,
181
+ sp_model_kwargs=self.sp_model_kwargs,
182
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
183
+ use_default_system_prompt=use_default_system_prompt,
184
+ spaces_between_special_tokens=spaces_between_special_tokens,
185
+ legacy=legacy,
186
+ add_prefix_space=add_prefix_space,
187
+ **kwargs,
188
+ )
189
+
190
+ @property
191
+ def unk_token_length(self):
192
+ return len(self.sp_model.encode(str(self.unk_token)))
193
+
194
+ def get_spm_processor(self, from_slow=False):
195
+ tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
196
+ if self.legacy or from_slow: # no dependency on protobuf
197
+ tokenizer.Load(self.vocab_file)
198
+ return tokenizer
199
+
200
+ with open(self.vocab_file, "rb") as f:
201
+ sp_model = f.read()
202
+ model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
203
+ model = model_pb2.ModelProto.FromString(sp_model)
204
+ normalizer_spec = model_pb2.NormalizerSpec()
205
+ normalizer_spec.add_dummy_prefix = False
206
+ model.normalizer_spec.MergeFrom(normalizer_spec)
207
+ sp_model = model.SerializeToString()
208
+ tokenizer.LoadFromSerializedProto(sp_model)
209
+ return tokenizer
210
+
211
+ def __getstate__(self):
212
+ state = self.__dict__.copy()
213
+ state["sp_model"] = None
214
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
215
+ return state
216
+
217
+ def __setstate__(self, d):
218
+ self.__dict__.update(d)
219
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
220
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
221
+
222
+ @property
223
+ def vocab_size(self):
224
+ """Returns vocab size"""
225
+ return self.sp_model.get_piece_size()
226
+
227
+ def get_vocab(self):
228
+ """Returns vocab as a dict"""
229
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
230
+ vocab.update(self.added_tokens_encoder)
231
+ return vocab
232
+
233
+ def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
234
+ """
235
+ Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
236
+ first token is special.
237
+ """
238
+ if self.legacy or len(text) == 0:
239
+ return super().tokenize(text, **kwargs)
240
+
241
+ text = text.replace(SPIECE_UNDERLINE, " ")
242
+ if self.add_prefix_space:
243
+ text = SPIECE_UNDERLINE + text
244
+
245
+ tokens = super().tokenize(text, **kwargs)
246
+
247
+ if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
248
+ tokens = tokens[1:]
249
+ return tokens
250
+
251
+ def _tokenize(self, text, **kwargs):
252
+ """
253
+ Returns a tokenized string.
254
+
255
+ We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
256
+ SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
257
+ `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
258
+ `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
259
+ `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
260
+ """
261
+ if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
262
+ return self.sp_model.encode(text, out_type=str)
263
+
264
+ # 1. Encode string + prefix ex: "<unk> Hey"
265
+ tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
266
+ # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
267
+ return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
268
+
269
+ def _convert_token_to_id(self, token):
270
+ """Converts a token (str) in an id using the vocab."""
271
+ return self.sp_model.piece_to_id(token)
272
+
273
+ def _convert_id_to_token(self, index):
274
+ """Converts an index (integer) in a token (str) using the vocab."""
275
+ token = self.sp_model.IdToPiece(index)
276
+ return token
277
+
278
+ def convert_tokens_to_string(self, tokens):
279
+ """Converts a sequence of tokens (string) in a single string."""
280
+ # since we manually add the prefix space, we have to remove it when decoding
281
+ if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
282
+ tokens[0] = tokens[0][1:]
283
+
284
+ current_sub_tokens = []
285
+ out_string = ""
286
+ prev_is_special = False
287
+ for i, token in enumerate(tokens):
288
+ # make sure that special tokens are not decoded using sentencepiece model
289
+ if token in self.all_special_tokens:
290
+ if not prev_is_special and i != 0 and self.legacy:
291
+ out_string += " "
292
+ out_string += self.sp_model.decode(current_sub_tokens) + token
293
+ prev_is_special = True
294
+ current_sub_tokens = []
295
+ else:
296
+ if prev_is_special and i == 1 and self.add_prefix_space and not token.startswith(SPIECE_UNDERLINE):
297
+ out_string += " "
298
+ current_sub_tokens.append(token)
299
+ prev_is_special = False
300
+ out_string += self.sp_model.decode(current_sub_tokens)
301
+ return out_string
302
+
303
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
304
+ """
305
+ Save the vocabulary and special tokens file to a directory.
306
+
307
+ Args:
308
+ save_directory (`str`):
309
+ The directory in which to save the vocabulary.
310
+
311
+ Returns:
312
+ `Tuple(str)`: Paths to the files saved.
313
+ """
314
+ if not os.path.isdir(save_directory):
315
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
316
+ return
317
+ out_vocab_file = os.path.join(
318
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
319
+ )
320
+
321
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
322
+ copyfile(self.vocab_file, out_vocab_file)
323
+ elif not os.path.isfile(self.vocab_file):
324
+ with open(out_vocab_file, "wb") as fi:
325
+ content_spiece_model = self.sp_model.serialized_model_proto()
326
+ fi.write(content_spiece_model)
327
+
328
+ return (out_vocab_file,)
329
+
330
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
331
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
332
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
333
+
334
+ output = bos_token_id + token_ids_0 + eos_token_id
335
+
336
+ if token_ids_1 is not None:
337
+ output = output + bos_token_id + token_ids_1 + eos_token_id
338
+
339
+ return output
340
+
341
+ def get_special_tokens_mask(
342
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
343
+ ) -> List[int]:
344
+ """
345
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
346
+ special tokens using the tokenizer `prepare_for_model` method.
347
+
348
+ Args:
349
+ token_ids_0 (`List[int]`):
350
+ List of IDs.
351
+ token_ids_1 (`List[int]`, *optional*):
352
+ Optional second list of IDs for sequence pairs.
353
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
354
+ Whether or not the token list is already formatted with special tokens for the model.
355
+
356
+ Returns:
357
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
358
+ """
359
+ if already_has_special_tokens:
360
+ return super().get_special_tokens_mask(
361
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
362
+ )
363
+
364
+ bos_token_id = [1] if self.add_bos_token else []
365
+ eos_token_id = [1] if self.add_eos_token else []
366
+
367
+ if token_ids_1 is None:
368
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
369
+ return (
370
+ bos_token_id
371
+ + ([0] * len(token_ids_0))
372
+ + eos_token_id
373
+ + bos_token_id
374
+ + ([0] * len(token_ids_1))
375
+ + eos_token_id
376
+ )
377
+
378
+ def create_token_type_ids_from_sequences(
379
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
380
+ ) -> List[int]:
381
+ """
382
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
383
+ sequence pair mask has the following format:
384
+
385
+ ```
386
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
387
+ | first sequence | second sequence |
388
+ ```
389
+
390
+ if token_ids_1 is None, only returns the first portion of the mask (0s).
391
+
392
+ Args:
393
+ token_ids_0 (`List[int]`):
394
+ List of ids.
395
+ token_ids_1 (`List[int]`, *optional*):
396
+ Optional second list of IDs for sequence pairs.
397
+
398
+ Returns:
399
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
400
+ """
401
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
402
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
403
+
404
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
405
+
406
+ if token_ids_1 is not None:
407
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
408
+
409
+ return output
410
+
411
+
412
+ #__all__ = ["BLTTokenizer"]
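A short pure-Python restatement (illustrative only; ids are hypothetical) of how `build_inputs_with_special_tokens` and `get_special_tokens_mask` above behave with the default `add_bos_token=True` / `add_eos_token=False`:

```python
# build_inputs_with_special_tokens: prepend the BOS id, leave the end untouched.
bos_ids, eos_ids = [1], []               # assuming bos_token_id == 1, no EOS appended
token_ids_0 = [104, 101, 121]            # hypothetical piece ids

output = bos_ids + token_ids_0 + eos_ids
print(output)                            # [1, 104, 101, 121]

# get_special_tokens_mask: 1 flags special-token positions, 0 flags sequence tokens.
special_tokens_mask = [1] * len(bos_ids) + [0] * len(token_ids_0) + [1] * len(eos_ids)
print(special_tokens_mask)               # [1, 0, 0, 0]
```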
backup_blt_wip copy/__init__.py ADDED
File without changes
backup_blt_wip copy/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (179 Bytes).
 
backup_blt_wip copy/__pycache__/blt_args.cpython-312.pyc ADDED
Binary file (7.05 kB).
 
backup_blt_wip copy/__pycache__/blt_one_file.cpython-312.pyc ADDED
Binary file (96.4 kB).
 
backup_blt_wip copy/__pycache__/configuration_blt.cpython-312.pyc ADDED
Binary file (12.7 kB).
 
backup_blt_wip copy/__pycache__/configuration_blt_og.cpython-312.pyc ADDED
Binary file (22.4 kB).
 
backup_blt_wip copy/__pycache__/modeling_blt.cpython-312.pyc ADDED
Binary file (64 kB).
 
backup_blt_wip copy/__pycache__/modeling_blt_dev.cpython-312.pyc ADDED
Binary file (64 kB).
 
backup_blt_wip copy/__pycache__/modeling_blt_modellike.cpython-312.pyc ADDED
Binary file (67.4 kB).
 
backup_blt_wip copy/__pycache__/modeling_blt_old.cpython-312.pyc ADDED
Binary file (68.8 kB).
 
backup_blt_wip copy/__pycache__/modeling_blt_wip.cpython-312.pyc ADDED
Binary file (78.7 kB).
 
backup_blt_wip copy/__pycache__/modeling_blt_wip_backup.cpython-312.pyc ADDED
Binary file (89.6 kB).
 
backup_blt_wip copy/__pycache__/tokenization_blt.cpython-312.pyc ADDED
Binary file (11.6 kB).
 
backup_blt_wip copy/configuration_blt.py ADDED
@@ -0,0 +1,390 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Facebook Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """BLT model configuration"""
16
+
17
+ from enum import Enum
18
+ from typing import Union
19
+
20
+ from ...configuration_utils import PretrainedConfig
21
+ from ...utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ class BLTLocalEncoderConfig(PretrainedConfig):
27
+ """
28
+ Configuration class for the BLT Local Encoder component.
29
+ """
30
+
31
+ model_type = "blt_local_encoder"
32
+
33
+ def __init__(
34
+ self,
35
+ vocab_size=256,
36
+ cross_attn_all_layers=True,
37
+ cross_attn_k=2,
38
+ hidden_size_global=2048,
39
+ hidden_size=512,
40
+ num_attention_heads=8,
41
+ num_key_value_heads=None,
42
+ num_hidden_layers=8,
43
+ norm_eps=1e-5,
44
+ dropout=0.0,
45
+ max_position_embeddings=1024,
46
+ rope_theta=10000.0,
47
+ rope_scaling=None,
48
+ hidden_act="silu",
49
+ intermediate_size=None,
50
+ _attn_implementation="sdpa",
51
+ **kwargs,
52
+ ):
53
+ self.vocab_size = vocab_size
54
+ self.cross_attn_all_layers = cross_attn_all_layers
55
+ self.cross_attn_k = cross_attn_k
56
+ self.hidden_size_global = hidden_size_global
57
+ self.hidden_size = hidden_size
58
+ self.num_attention_heads = num_attention_heads
59
+ self.num_key_value_heads = num_key_value_heads or num_attention_heads
60
+ self.head_dim = hidden_size // num_attention_heads
61
+ self.intermediate_size = intermediate_size or int(8 * hidden_size / 3)
62
+ self.num_hidden_layers = num_hidden_layers
63
+ self.norm_eps = norm_eps
64
+ self.dropout = dropout
65
+ self.max_position_embeddings = max_position_embeddings
66
+ self.rope_theta = rope_theta
67
+ self.rope_scaling = rope_scaling or {"rope_type": "default"}
68
+ self.hidden_act = hidden_act
69
+ self._attn_implementation = _attn_implementation
70
+
71
+ super().__init__(**kwargs)
72
+
73
+ class BLTLocalDecoderConfig(PretrainedConfig):
74
+ """
75
+ Configuration class for the BLT Local Decoder component.
76
+ """
77
+
78
+ model_type = "blt_local_decoder"
79
+
80
+ def __init__(
81
+ self,
82
+ vocab_size=256,
83
+ cross_attn_all_layers=True,
84
+ cross_attn_k=2,
85
+ hidden_size_global=2048,
86
+ hidden_size=512,
87
+ num_attention_heads=8,
88
+ num_key_value_heads=None,
89
+ num_hidden_layers=8,
90
+ norm_eps=1e-5,
91
+ dropout=0.0,
92
+ max_position_embeddings=1024,
93
+ rope_theta=10000.0,
94
+ rope_scaling=None,
95
+ hidden_act="silu",
96
+ intermediate_size=None,
97
+ _attn_implementation="sdpa",
98
+ **kwargs,
99
+ ):
100
+ self.vocab_size = vocab_size
101
+ self.cross_attn_all_layers = cross_attn_all_layers
102
+ self.cross_attn_k = cross_attn_k
103
+ self.hidden_size_global = hidden_size_global
104
+ self.hidden_size = hidden_size
105
+ self.num_attention_heads = num_attention_heads
106
+ self.num_key_value_heads = num_key_value_heads or num_attention_heads
107
+ self.head_dim = hidden_size // num_attention_heads
108
+ self.intermediate_size = intermediate_size or int(8 * hidden_size / 3)
109
+ self.num_hidden_layers = num_hidden_layers
110
+ self.norm_eps = norm_eps
111
+ self.dropout = dropout
112
+ self.max_position_embeddings = max_position_embeddings
113
+ self.rope_theta = rope_theta
114
+ self.rope_scaling = rope_scaling or {"rope_type": "default"}
115
+ self.hidden_act = hidden_act
116
+ self._attn_implementation = _attn_implementation
117
+
118
+ super().__init__(**kwargs)
119
+
120
+
121
+ class BLTGlobalTransformerConfig(PretrainedConfig):
122
+ """
123
+ Configuration class for the BLT Global Transformer component.
124
+ """
125
+
126
+ model_type = "blt_global_transformer"
127
+
128
+ def __init__(
129
+ self,
130
+ hidden_size=512,
131
+ num_attention_heads=8,
132
+ num_key_value_heads=None,
133
+ num_hidden_layers=8,
134
+ norm_eps=1e-5,
135
+ dropout=0.0,
136
+ max_position_embeddings=1024,
137
+ rope_theta=10000.0,
138
+ rope_scaling=None,
139
+ hidden_act="silu",
140
+ intermediate_size=None,
141
+ _attn_implementation="sdpa",
142
+ **kwargs,
143
+ ):
144
+ self.hidden_size = hidden_size
145
+ self.num_attention_heads = num_attention_heads
146
+ self.num_key_value_heads = num_key_value_heads or num_attention_heads
147
+ self.head_dim = hidden_size // num_attention_heads
148
+ self.intermediate_size = intermediate_size or int(8 * hidden_size / 3)
149
+ self.num_hidden_layers = num_hidden_layers
150
+ self.norm_eps = norm_eps
151
+ self.dropout = dropout
152
+ self.max_position_embeddings = max_position_embeddings
153
+ self.rope_theta = rope_theta
154
+ self.rope_scaling = rope_scaling or {"rope_type": "default"}
155
+ self.hidden_act = hidden_act
156
+ self._attn_implementation = _attn_implementation
157
+
158
+ super().__init__(**kwargs)
159
+
160
+
161
+ class BLTPatcherConfig(PretrainedConfig):
162
+ r"""
163
+ Configuration class for the BLT Patcher/Entropy model component.
164
+
165
+ Args:
166
+ vocab_size (`int`, *optional*, defaults to 256):
167
+ Vocabulary size for the entropy model used in patching.
168
+ hidden_size (`int`, *optional*, defaults to 512):
169
+ Hidden dimension for the entropy model.
170
+ num_hidden_layers (`int`, *optional*, defaults to 8):
171
+ Number of layers in the entropy model.
172
+ num_attention_heads (`int`, *optional*, defaults to 8):
173
+ Number of attention heads in the entropy model.
174
+ head_dim (`int`, *optional*):
175
+ Dimension of each attention head in the entropy model.
176
+ num_key_value_heads (`int`, *optional*):
177
+ Number of key-value heads in the entropy model.
178
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
179
+ Maximum sequence length for the entropy model.
180
+ norm_eps (`float`, *optional*, defaults to 1e-5):
181
+ Layer normalization epsilon for the entropy model.
182
+ dropout (`float`, *optional*, defaults to 0.0):
183
+ Dropout probability for the entropy model.
184
+ ffn_dim_multiplier (`float`, *optional*):
185
+ Feedforward dimension multiplier for the entropy model.
186
+ multiple_of (`int`, *optional*, defaults to 256):
187
+ Make feedforward dimension multiple of this for the entropy model.
188
+ rope_theta (`float`, *optional*, defaults to 10000.0):
189
+ RoPE theta parameter for the entropy model.
190
+ attn_impl (`str`, *optional*, defaults to "sdpa"):
191
+ Attention implementation for the entropy model.
192
+ attn_bias_type (`str`, *optional*, defaults to "causal"):
193
+ Attention bias type for the entropy model.
194
+ """
195
+
196
+ model_type = "blt_patcher"
197
+
198
+ def __init__(
199
+ self,
200
+ vocab_size=256,
201
+ hidden_size=512,
202
+ num_hidden_layers=8,
203
+ num_attention_heads=8,
204
+ num_key_value_heads=None,
205
+ max_position_embeddings=1024,
206
+ norm_eps=1e-5,
207
+ dropout=0.0,
208
+ rope_theta=10000.0,
209
+ attn_impl="sdpa",
210
+ attn_bias_type="causal",
211
+ intermediate_size=None,
212
+ **kwargs,
213
+ ):
214
+ self.vocab_size = vocab_size
215
+ self.hidden_size = hidden_size
216
+ self.num_hidden_layers = num_hidden_layers
217
+ self.num_attention_heads = num_attention_heads
218
+ self.head_dim = hidden_size // num_attention_heads
219
+ self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
220
+ self.max_position_embeddings = max_position_embeddings
221
+ self.norm_eps = norm_eps
222
+ self.dropout = dropout
223
+ self.rope_theta = rope_theta
224
+ self.attn_impl = attn_impl
225
+ self.attn_bias_type = attn_bias_type
226
+ self.hidden_act = "silu" # BLT uses silu activation
227
+ self.intermediate_size = intermediate_size or int(8 * self.hidden_size / 3)
228
+ self.rope_scaling = {"rope_type": "default"}
229
+ super().__init__(**kwargs)
230
+
231
+
232
+ class BLTConfig(PretrainedConfig):
233
+ r"""
234
+ This is the configuration class to store the configuration of a [`BLTModel`]. It is used to instantiate a
235
+ BLT model according to the specified arguments, defining the model architecture.
236
+
237
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
238
+ documentation from [`PretrainedConfig`] for more information.
239
+
240
+ Args:
241
+ vocab_size (`int`, *optional*, defaults to 256):
242
+ Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented.
243
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
244
+ The maximum sequence length that this model can handle.
245
+
246
+ # Patching configuration
247
+ patch_in_forward (`bool`, *optional*, defaults to False):
248
+ Whether to perform patching during forward pass.
249
+ patch_size (`float`, *optional*):
250
+ Size of patches for static patching.
251
+ patching_mode (`str`, *optional*):
252
+ Mode for patching ("entropy", "static", etc.).
253
+ patching_threshold (`float`, *optional*):
254
+ Threshold for entropy-based patching.
255
+ patching_batch_size (`int`, *optional*, defaults to 1):
256
+ Batch size for patching operations.
257
+ patching_device (`str`, *optional*, defaults to "cuda"):
258
+ Device to use for patching operations.
259
+ max_patch_length (`int`, *optional*):
260
+ Maximum length of patches.
261
+
262
+ # Cross attention configurations
263
+ cross_attn_k (`int`, *optional*):
264
+ Number of cross attention components.
265
+
266
+ # Encoder configurations
267
+ encoder_hash_byte_group_size (`Any`, *optional*):
268
+ Hash byte group size for encoder.
269
+ encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 30000):
270
+ Vocabulary size for hash byte groups.
271
+ encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 3):
272
+ Number of hash functions for byte groups.
273
+
274
+ # Component configurations
275
+ patcher_config (`Union[BLTPatcherConfig, dict]`, *optional*):
276
+ Configuration for the BLT patcher/entropy model component.
277
+ encoder_config (`Union[BLTLocalEncoderConfig, dict]`, *optional*):
278
+ Configuration for the BLT local encoder component.
279
+ decoder_config (`Union[BLTLocalDecoderConfig, dict]`, *optional*):
280
+ Configuration for the BLT local decoder component.
281
+ global_config (`Union[BLTGlobalTransformerConfig, dict]`, *optional*):
282
+ Configuration for the BLT global transformer component.
283
+
284
+ ```python
285
+ >>> from transformers import BLTModel, BLTConfig
286
+
287
+ >>> # Initializing a BLT configuration
288
+ >>> configuration = BLTConfig()
289
+
290
+ >>> # Initializing a model from the configuration
291
+ >>> model = BLTModel(configuration)
292
+
293
+ >>> # Accessing the model configuration
294
+ >>> configuration = model.config
295
+ ```"""
296
+
297
+ model_type = "blt"
298
+ keys_to_ignore_at_inference = ["past_key_values"]
299
+ sub_configs = {
300
+ "patcher_config": BLTPatcherConfig,
301
+ "encoder_config": BLTLocalEncoderConfig,
302
+ "decoder_config": BLTLocalDecoderConfig,
303
+ "global_config": BLTGlobalTransformerConfig
304
+ }
305
+
306
+ def __init__(
307
+ self,
308
+ vocab_size=256,
309
+ max_position_embeddings=1024,
310
+ patch_in_forward=False,
311
+ patch_size=None,
312
+ patching_mode=None,
313
+ patching_threshold=None,
314
+ patching_batch_size=1,
315
+ max_patch_length=None,
316
+ cross_attn_k=2,
317
+ encoder_hash_byte_group_size=None,
318
+ encoder_hash_byte_group_vocab=30000,
319
+ encoder_hash_byte_group_nb_functions=3,
320
+ patcher_config=None,
321
+ encoder_config=None,
322
+ decoder_config=None,
323
+ global_config=None,
324
+ tie_word_embeddings=False,
325
+ **kwargs,
326
+ ):
327
+
328
+ # Basic model configuration
329
+ self.tie_word_embeddings = tie_word_embeddings
330
+ self.vocab_size = vocab_size
331
+ self.max_position_embeddings = max_position_embeddings
332
+
333
+ # Patching configuration
334
+ self.patch_in_forward = patch_in_forward
335
+ self.patch_size = patch_size
336
+ self.patching_mode = patching_mode
337
+ self.patching_threshold = patching_threshold
338
+ self.patching_batch_size = patching_batch_size
339
+ self.max_patch_length = max_patch_length
340
+
341
+ # Cross attention configurations
342
+ self.cross_attn_k = cross_attn_k
343
+
344
+ # Encoder configurations
345
+ self.encoder_hash_byte_group_size = encoder_hash_byte_group_size or [2, 3, 4]
346
+ self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab
347
+ self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions
348
+
349
+ # Initialize component configurations
350
+ if patcher_config is None:
351
+ self.patcher_config = BLTPatcherConfig()
352
+ logger.info("patcher_config is None, using default BLT patcher config")
353
+ elif isinstance(patcher_config, dict):
354
+ self.patcher_config = BLTPatcherConfig(**patcher_config)
355
+ elif isinstance(patcher_config, BLTPatcherConfig):
356
+ self.patcher_config = patcher_config
357
+
358
+ if encoder_config is None:
359
+ self.encoder_config = BLTLocalEncoderConfig()
360
+ logger.info("encoder_config is None, using default BLT encoder config")
361
+ elif isinstance(encoder_config, dict):
362
+ self.encoder_config = BLTLocalEncoderConfig(**encoder_config)
363
+ elif isinstance(encoder_config, BLTLocalEncoderConfig):
364
+ self.encoder_config = encoder_config
365
+
366
+ if decoder_config is None:
367
+ self.decoder_config = BLTLocalDecoderConfig()
368
+ logger.info("decoder_config is None, using default BLT decoder config")
369
+ elif isinstance(decoder_config, dict):
370
+ self.decoder_config = BLTLocalDecoderConfig(**decoder_config)
371
+ elif isinstance(decoder_config, BLTLocalDecoderConfig):
372
+ self.decoder_config = decoder_config
373
+
374
+ if global_config is None:
375
+ self.global_config = BLTGlobalTransformerConfig()
376
+ logger.info("global_config is None, using default BLT global config")
377
+ elif isinstance(global_config, dict):
378
+ self.global_config = BLTGlobalTransformerConfig(**global_config)
379
+ elif isinstance(global_config, BLTGlobalTransformerConfig):
380
+ self.global_config = global_config
381
+
382
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
383
+
384
+ __all__ = [
385
+ "BLTConfig",
386
+ "BLTPatcherConfig",
387
+ "BLTLocalEncoderConfig",
388
+ "BLTLocalDecoderConfig",
389
+ "BLTGlobalTransformerConfig",
390
+ ]
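A brief sketch (hypothetical values; assumes the classes defined above are importable) of building the nested configuration: dicts passed for the sub-configs are promoted to their dedicated config classes in `__init__`:

```python
# Hypothetical sizes, for illustration only.
config = BLTConfig(
    vocab_size=256,
    patch_in_forward=True,
    patching_mode="entropy",
    encoder_config={"hidden_size": 768, "num_hidden_layers": 4},
    decoder_config={"hidden_size": 768, "num_hidden_layers": 4},
    global_config={"hidden_size": 1536, "num_hidden_layers": 12},
)
print(type(config.encoder_config).__name__)  # BLTLocalEncoderConfig
```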
backup_blt_wip copy/configuration_blt_og.py ADDED
@@ -0,0 +1,608 @@
1
+ # old config
2
+
3
+ # coding=utf-8
4
+ # Copyright 2024 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """BLT (Byte Latent Transformer) model configuration"""
18
+
19
+ from enum import Enum
20
+
21
+ from ...configuration_utils import PretrainedConfig
22
+ from ...utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class InitStdFactor(str, Enum):
29
+ DISABLED = "disabled" # Init std is divided by 1.0
30
+ CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)
31
+
32
+
33
+ class PatchingModeEnum(str, Enum):
34
+ entropy = "entropy"
35
+ bpe = "bpe"
36
+ bpe_patcher = "bpe_patcher"
37
+ space = "space"
38
+ static = "static"
39
+ byte = "byte"
40
+
41
+
42
+ class BLTConfig(PretrainedConfig):
43
+ r"""
44
+ This is the configuration class to store the configuration of a [`ByteLatentTransformer`]. It is used to instantiate a
45
+ BLT model according to the specified arguments, defining the model architecture.
46
+
47
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
48
+ documentation from [`PretrainedConfig`] for more information.
49
+
50
+ Args:
51
+ vocab_size (`int`, *optional*, defaults to 256):
52
+ Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented.
53
+ max_seqlen (`int`, *optional*, defaults to 1024):
54
+ The maximum sequence length that this model can handle.
55
+
56
+ # Main architecture dimensions
57
+ dim (`int`, *optional*, defaults to 512):
58
+ Main dimension of the model.
59
+ n_layers (`int`, *optional*, defaults to 8):
60
+ Number of layers in the main transformer.
61
+ n_heads (`int`, *optional*, defaults to 8):
62
+ Number of attention heads in the main transformer.
63
+ head_dim (`int`, *optional*):
64
+ Dimension of each attention head. If not specified, computed as dim // n_heads.
65
+ n_kv_heads (`int`, *optional*):
66
+ Number of key-value heads for grouped query attention. If not specified, defaults to n_heads.
67
+
68
+ # Component-specific dimensions
69
+ dim_global (`int`, *optional*, defaults to 512):
70
+ Dimension of the global transformer component.
71
+ dim_local_decoder (`int`, *optional*, defaults to 512):
72
+ Dimension of the local decoder component.
73
+ dim_local_encoder (`int`, *optional*, defaults to 512):
74
+ Dimension of the local encoder component.
75
+ n_layers_global (`int`, *optional*, defaults to 8):
76
+ Number of layers in the global transformer.
77
+ n_layers_local_decoder (`int`, *optional*, defaults to 8):
78
+ Number of layers in the local decoder.
79
+ n_layers_local_encoder (`int`, *optional*, defaults to 8):
80
+ Number of layers in the local encoder.
81
+ n_heads_global (`int`, *optional*, defaults to 8):
82
+ Number of attention heads in the global transformer.
83
+ n_heads_local_decoder (`int`, *optional*, defaults to 8):
84
+ Number of attention heads in the local decoder.
85
+ n_heads_local_encoder (`int`, *optional*, defaults to 8):
86
+ Number of attention heads in the local encoder.
87
+ n_kv_heads_global (`int`, *optional*):
88
+ Number of key-value heads in the global transformer.
89
+
90
+ # Transformer configuration
91
+ norm_eps (`float`, *optional*, defaults to 1e-5):
92
+ The epsilon used by the layer normalization layers.
93
+ dropout (`float`, *optional*, defaults to 0.0):
94
+ The dropout probability for all fully connected layers.
95
+ ffn_dim_multiplier (`float`, *optional*, defaults to 1.0):
96
+ Multiplier for the feedforward network dimension.
97
+ multiple_of (`int`, *optional*, defaults to 256):
98
+ Make feedforward network dimension multiple of this value.
99
+
100
+ # Positional encoding
101
+ rope_theta (`float`, *optional*, defaults to 10000.0):
102
+ The base period of the RoPE embeddings.
103
+ rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False):
104
+ Whether to use fp32 in RoPE outer product computation.
105
+
106
+ # Attention configuration
107
+ attn_impl (`str`, *optional*, defaults to "sdpa"):
108
+ Attention implementation to use ("sdpa" or "flex_attention").
109
+ attn_bias_type (`str`, *optional*, defaults to "causal"):
110
+ Type of attention bias to apply.
111
+ local_attention_window_len (`int`, *optional*):
112
+ Window length for local attention.
113
+ use_rope (`bool`, *optional*, defaults to True):
114
+ Whether to use rotary position embeddings.
115
+
116
+ # Initialization
117
+ init_base_std (`float`, *optional*):
118
+ Base standard deviation for weight initialization.
119
+ init_std_factor (`str`, *optional*, defaults to "disabled"):
120
+ Factor for adjusting initialization standard deviation.
121
+
122
+ # Embedding dimensions
123
+ dim_token_emb (`int`, *optional*):
124
+ Token embedding dimension.
125
+ dim_token (`int`, *optional*):
126
+ Token dimension.
127
+
128
+ # Patching configuration
129
+ patch_in_forward (`bool`, *optional*, defaults to False):
130
+ Whether to perform patching during forward pass.
131
+ realtime_patching (`bool`, *optional*, defaults to True):
132
+ Whether to use realtime patching.
133
+ patch_size (`float`, *optional*):
134
+ Size of patches for static patching.
135
+ patching_mode (`str`, *optional*):
136
+ Mode for patching ("entropy", "static", etc.).
137
+ patching_threshold (`float`, *optional*):
138
+ Threshold for entropy-based patching.
139
+ patching_threshold_add (`float`, *optional*):
140
+ Additional threshold parameter for patching.
141
+ monotonicity (`bool`, *optional*, defaults to False):
142
+ Whether to enforce monotonicity in patching.
143
+ patching_batch_size (`int`, *optional*, defaults to 1):
144
+ Batch size for patching operations.
145
+ patching_device (`str`, *optional*, defaults to "cuda"):
146
+ Device to use for patching operations.
147
+ max_patch_length (`int`, *optional*):
148
+ Maximum length of patches.
149
+ entropy_model_checkpoint_dir (`str`, *optional*):
150
+ Directory containing entropy model checkpoint.
151
+
152
+ # Cross attention configurations
153
+ cross_attn_encoder (`bool`, *optional*, defaults to False):
154
+ Whether to use cross attention in encoder.
155
+ cross_attn_decoder (`bool`, *optional*, defaults to False):
156
+ Whether to use cross attention in decoder.
157
+ cross_attn_window_encoder (`int`, *optional*):
158
+ Cross attention window for encoder.
159
+ cross_attn_window_decoder (`int`, *optional*):
160
+ Cross attention window for decoder.
161
+ cross_attn_k (`int`, *optional*):
162
+ Number of cross attention components.
163
+ cross_attn_nheads (`int`, *optional*):
164
+ Number of heads for cross attention.
165
+ cross_attn_all_layers_decoder (`bool`, *optional*, defaults to False):
166
+ Whether to apply cross attention to all decoder layers.
167
+ cross_attn_all_layers_encoder (`bool`, *optional*, defaults to False):
168
+ Whether to apply cross attention to all encoder layers.
169
+ cross_attn_use_flex_attention (`bool`, *optional*, defaults to True):
170
+ Whether to use flexible attention for cross attention.
171
+ cross_attn_init_by_pooling (`bool`, *optional*, defaults to False):
172
+ Whether to initialize cross attention by pooling.
173
+
174
+ # Encoder configurations
175
+ use_local_encoder_transformer (`bool`, *optional*, defaults to False):
176
+ Whether to use transformer in local encoder.
177
+ max_encoder_seq_length (`int`, *optional*):
178
+ Maximum sequence length for encoder.
179
+ encoder_hash_byte_group_size (`Any`, *optional*):
180
+ Hash byte group size for encoder.
181
+ encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 30000):
182
+ Vocabulary size for hash byte groups.
183
+ encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 3):
184
+ Number of hash functions for byte groups.
185
+ encoder_enable_byte_ngrams (`bool`, *optional*, defaults to False):
186
+ Whether to enable byte n-grams in encoder.
187
+ encoder_ngram_to_size_str (`str`, *optional*):
188
+ String defining n-gram sizes.
189
+ downsampling_by_pooling (`str`, *optional*):
190
+ Type of pooling for downsampling.
191
+
192
+ # Model behavior
193
+ share_encoder_decoder_emb (`bool`, *optional*, defaults to True):
194
+ Whether to share encoder and decoder embeddings.
195
+ weight_tying (`bool`, *optional*, defaults to False):
196
+ Whether to tie input and output embeddings.
197
+
198
+ # Performance optimization
199
+ sequence_parallel (`bool`, *optional*, defaults to False):
200
+ Whether to use sequence parallelism.
201
+ loss_parallel (`bool`, *optional*, defaults to False):
202
+ Whether to use loss parallelism.
203
+ fuse_sequence_parallel (`bool`, *optional*, defaults to False):
204
+ Whether to fuse sequence parallel operations.
205
+ use_fsdp (`bool`, *optional*, defaults to True):
206
+ Whether to use fully sharded data parallel.
207
+
208
+ # Parameter mixing
209
+ pm_size (`int`, *optional*, defaults to 0):
210
+ Parameter mixing size.
211
+
212
+ # Special tokens
213
+ bos_token_id (`int`, *optional*, defaults to 1):
214
+ The id of the "beginning-of-sequence" token.
215
+ eos_token_id (`int`, *optional*, defaults to 2):
216
+ The id of the "end-of-sequence" token.
217
+ pad_token_id (`int`, *optional*, defaults to -1):
218
+ The id of the padding token.
219
+
220
+ # Patcher/Entropy model configuration
221
+ patcher_vocab_size (`int`, *optional*, defaults to 256):
222
+ Vocabulary size for the entropy model used in patching.
223
+ patcher_dim (`int`, *optional*, defaults to 512):
224
+ Hidden dimension for the entropy model.
225
+ patcher_n_layers (`int`, *optional*, defaults to 8):
226
+ Number of layers in the entropy model.
227
+ patcher_n_heads (`int`, *optional*, defaults to 8):
228
+ Number of attention heads in the entropy model.
229
+ patcher_head_dim (`int`, *optional*):
230
+ Dimension of each attention head in the entropy model.
231
+ patcher_n_kv_heads (`int`, *optional*):
232
+ Number of key-value heads in the entropy model.
233
+ patcher_max_seqlen (`int`, *optional*, defaults to 1024):
234
+ Maximum sequence length for the entropy model.
235
+ patcher_norm_eps (`float`, *optional*, defaults to 1e-5):
236
+ Layer normalization epsilon for the entropy model.
237
+ patcher_dropout (`float`, *optional*, defaults to 0.0):
238
+ Dropout probability for the entropy model.
239
+ patcher_sliding_window (`int`, *optional*):
240
+ Sliding window size for the entropy model attention.
241
+ patcher_ffn_dim_multiplier (`float`, *optional*):
242
+ Feedforward dimension multiplier for the entropy model.
243
+ patcher_multiple_of (`int`, *optional*, defaults to 256):
244
+ Make feedforward dimension multiple of this for the entropy model.
245
+ patcher_rope_theta (`float`, *optional*, defaults to 10000.0):
246
+ RoPE theta parameter for the entropy model.
247
+ patcher_rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False):
248
+ Whether to use fp32 in RoPE outer product for the entropy model.
249
+ patcher_attn_impl (`str`, *optional*, defaults to "sdpa"):
250
+ Attention implementation for the entropy model.
251
+ patcher_attn_bias_type (`str`, *optional*, defaults to "causal"):
252
+ Attention bias type for the entropy model.
253
+ patcher_init_base_std (`float`, *optional*):
254
+ Base initialization standard deviation for the entropy model.
255
+ patcher_init_std_factor (`str`, *optional*, defaults to "disabled"):
256
+ Initialization std factor for the entropy model.
257
+ patcher_dim_token_emb (`int`, *optional*):
258
+ Token embedding dimension for the entropy model.
259
+ patcher_weight_tying (`bool`, *optional*, defaults to False):
260
+ Whether to tie embeddings in the entropy model.
261
+ patcher_bos_token_id (`int`, *optional*, defaults to 1):
262
+ Beginning of sequence token id for the entropy model.
263
+ patcher_eos_token_id (`int`, *optional*, defaults to 2):
264
+ End of sequence token id for the entropy model.
265
+
266
+ ```python
267
+ >>> from transformers import ByteLatentTransformer, BLTConfig
268
+
269
+ >>> # Initializing a BLT configuration
270
+ >>> configuration = BLTConfig()
271
+
272
+ >>> # Initializing a model from the configuration
273
+ >>> model = ByteLatentTransformer(configuration)
274
+
275
+ >>> # Accessing the model configuration
276
+ >>> configuration = model.config
277
+ ```"""
278
+
279
+ model_type = "blt"
280
+ keys_to_ignore_at_inference = ["past_key_values"]
281
+
282
+ def __init__(
283
+ self,
284
+ vocab_size=256,
285
+ max_seqlen=1024,
286
+ # Main architecture dimensions
287
+ dim=512,
288
+ n_layers=8,
289
+ n_heads=8,
290
+ head_dim=None,
291
+ n_kv_heads=None,
292
+ # Component-specific dimensions
293
+ dim_global=512,
294
+ dim_local_decoder=512,
295
+ dim_local_encoder=512,
296
+ n_layers_global=8,
297
+ n_layers_local_decoder=8,
298
+ n_layers_local_encoder=8,
299
+ n_heads_global=8,
300
+ n_heads_local_decoder=8,
301
+ n_heads_local_encoder=8,
302
+ n_kv_heads_global=None,
303
+ # Transformer configuration
304
+ norm_eps=1e-5,
305
+ dropout=0.0,
306
+ ffn_dim_multiplier=1.0,
307
+ multiple_of=256,
308
+ # Positional encoding
309
+ rope_theta=10000.0,
310
+ rope_use_fp32_in_outer_product=False,
311
+ # Attention configuration
312
+ attn_impl="sdpa",
313
+ attn_bias_type="causal",
314
+ local_attention_window_len=None,
315
+ use_rope=True,
316
+ # Initialization
317
+ init_base_std=None,
318
+ init_std_factor="disabled",
319
+ # Embedding dimensions
320
+ dim_token_emb=None,
321
+ dim_token=None,
322
+ # Patching configuration
323
+ patch_in_forward=False,
324
+ realtime_patching=True,
325
+ patch_size=None,
326
+ patching_mode=None,
327
+ patching_threshold=None,
328
+ patching_threshold_add=None,
329
+ monotonicity=False,
330
+ patching_batch_size=1,
331
+ patching_device="cuda",
332
+ max_patch_length=None,
333
+ entropy_model_checkpoint_dir=None,
334
+ # Cross attention configurations
335
+ cross_attn_encoder=False,
336
+ cross_attn_decoder=False,
337
+ cross_attn_window_encoder=None,
338
+ cross_attn_window_decoder=None,
339
+ cross_attn_k=None,
340
+ cross_attn_nheads=None,
341
+ cross_attn_all_layers_decoder=False,
342
+ cross_attn_all_layers_encoder=False,
343
+ cross_attn_use_flex_attention=True,
344
+ cross_attn_init_by_pooling=False,
345
+ # Encoder configurations
346
+ use_local_encoder_transformer=False,
347
+ max_encoder_seq_length=None,
348
+ encoder_hash_byte_group_size=None,
349
+ encoder_hash_byte_group_vocab=30000,
350
+ encoder_hash_byte_group_nb_functions=3,
351
+ encoder_enable_byte_ngrams=False,
352
+ encoder_ngram_to_size_str=None,
353
+ downsampling_by_pooling=None,
354
+ # Model behavior
355
+ share_encoder_decoder_emb=True,
356
+ weight_tying=False,
357
+ # Performance optimization
358
+ sequence_parallel=False,
359
+ loss_parallel=False,
360
+ fuse_sequence_parallel=False,
361
+ use_fsdp=True,
362
+ # Parameter mixing
363
+ pm_size=0,
364
+ # Special tokens
365
+ bos_token_id=1,
366
+ eos_token_id=2,
367
+ pad_token_id=-1,
368
+ # Patcher/Entropy model configuration
369
+ patcher_vocab_size=256,
370
+ patcher_dim=512,
371
+ patcher_n_layers=8,
372
+ patcher_n_heads=8,
373
+ patcher_head_dim=None,
374
+ patcher_n_kv_heads=None,
375
+ patcher_max_seqlen=1024,
376
+ patcher_norm_eps=1e-5,
377
+ patcher_dropout=0.0,
378
+ patcher_sliding_window=None,
379
+ patcher_ffn_dim_multiplier=None,
380
+ patcher_multiple_of=256,
381
+ patcher_rope_theta=10000.0,
382
+ patcher_rope_use_fp32_in_outer_product=False,
383
+ patcher_attn_impl="sdpa",
384
+ patcher_attn_bias_type="causal",
385
+ patcher_init_base_std=None,
386
+ patcher_init_std_factor="disabled",
387
+ patcher_dim_token_emb=None,
388
+ patcher_weight_tying=False,
389
+ patcher_bos_token_id=1,
390
+ patcher_eos_token_id=2,
391
+ # Inherited
392
+ **kwargs,
393
+ ):
394
+
395
+ self.sliding_window = None
396
+ # Basic model configuration
397
+ self.vocab_size = vocab_size
398
+ self.max_seqlen = max_seqlen
399
+
400
+ # Main architecture dimensions
401
+ self.dim = dim
402
+ self.n_layers = n_layers
403
+ self.n_heads = n_heads
404
+ self.head_dim = head_dim
405
+ self.n_kv_heads = n_kv_heads
406
+
407
+ # Component-specific dimensions
408
+ self.dim_global = dim_global
409
+ self.dim_local_decoder = dim_local_decoder
410
+ self.dim_local_encoder = dim_local_encoder
411
+ self.n_layers_global = n_layers_global
412
+ self.n_layers_local_decoder = n_layers_local_decoder
413
+ self.n_layers_local_encoder = n_layers_local_encoder
414
+ self.n_heads_global = n_heads_global
415
+ self.n_heads_local_decoder = n_heads_local_decoder
416
+ self.n_heads_local_encoder = n_heads_local_encoder
417
+ self.n_kv_heads_global = n_kv_heads_global
418
+
419
+ # Transformer configuration
420
+ self.norm_eps = norm_eps
421
+ self.dropout = dropout
422
+ self.ffn_dim_multiplier = ffn_dim_multiplier
423
+ self.multiple_of = multiple_of
424
+
425
+ # Positional encoding
426
+ self.rope_theta = rope_theta
427
+ self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
428
+
429
+ # Attention configuration
430
+ self.attn_impl = attn_impl
431
+ self.attn_bias_type = attn_bias_type
432
+ self.local_attention_window_len = local_attention_window_len
433
+ self.use_rope = use_rope
434
+
435
+ # Initialization
436
+ self.init_base_std = init_base_std
437
+ self.init_std_factor = InitStdFactor(init_std_factor)
438
+
439
+ # Embedding dimensions
440
+ self.dim_token_emb = dim_token_emb
441
+ self.dim_token = dim_token
442
+
443
+ # Patching configuration
444
+ self.patch_in_forward = patch_in_forward
445
+ self.realtime_patching = realtime_patching
446
+ self.patch_size = patch_size
447
+ self.patching_mode = patching_mode
448
+ self.patching_threshold = patching_threshold
449
+ self.patching_threshold_add = patching_threshold_add
450
+ self.monotonicity = monotonicity
451
+ self.patching_batch_size = patching_batch_size
452
+ self.patching_device = patching_device
453
+ self.max_patch_length = max_patch_length
454
+ self.entropy_model_checkpoint_dir = entropy_model_checkpoint_dir
455
+
456
+ # Cross attention configurations
457
+ self.cross_attn_encoder = cross_attn_encoder
458
+ self.cross_attn_decoder = cross_attn_decoder
459
+ self.cross_attn_window_encoder = cross_attn_window_encoder
460
+ self.cross_attn_window_decoder = cross_attn_window_decoder
461
+ self.cross_attn_k = cross_attn_k
462
+ self.cross_attn_nheads = cross_attn_nheads
463
+ self.cross_attn_all_layers_decoder = cross_attn_all_layers_decoder
464
+ self.cross_attn_all_layers_encoder = cross_attn_all_layers_encoder
465
+ self.cross_attn_use_flex_attention = cross_attn_use_flex_attention
466
+ self.cross_attn_init_by_pooling = cross_attn_init_by_pooling
467
+
468
+ # Encoder configurations
469
+ self.use_local_encoder_transformer = use_local_encoder_transformer
470
+ self.max_encoder_seq_length = max_encoder_seq_length
471
+ self.encoder_hash_byte_group_size = encoder_hash_byte_group_size
472
+ self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab
473
+ self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions
474
+ self.encoder_enable_byte_ngrams = encoder_enable_byte_ngrams
475
+ self.encoder_ngram_to_size_str = encoder_ngram_to_size_str
476
+ self.downsampling_by_pooling = downsampling_by_pooling
477
+
478
+ # Model behavior
479
+ self.share_encoder_decoder_emb = share_encoder_decoder_emb
480
+ self.weight_tying = weight_tying
481
+
482
+ # Performance optimization
483
+ self.sequence_parallel = sequence_parallel
484
+ self.loss_parallel = loss_parallel
485
+ self.fuse_sequence_parallel = fuse_sequence_parallel
486
+ self.use_fsdp = use_fsdp
487
+
488
+ # Parameter mixing
489
+ self.pm_size = pm_size
490
+
491
+ # Patcher/Entropy model configuration
492
+ self.patcher_vocab_size = patcher_vocab_size
493
+ self.patcher_dim = patcher_dim
494
+ self.patcher_n_layers = patcher_n_layers
495
+ self.patcher_n_heads = patcher_n_heads
496
+ self.patcher_head_dim = patcher_head_dim
497
+ self.patcher_n_kv_heads = patcher_n_kv_heads
498
+ self.patcher_max_seqlen = patcher_max_seqlen
499
+ self.patcher_norm_eps = patcher_norm_eps
500
+ self.patcher_dropout = patcher_dropout
501
+ self.patcher_sliding_window = patcher_sliding_window
502
+ self.patcher_ffn_dim_multiplier = patcher_ffn_dim_multiplier
503
+ self.patcher_multiple_of = patcher_multiple_of
504
+ self.patcher_rope_theta = patcher_rope_theta
505
+ self.patcher_rope_use_fp32_in_outer_product = patcher_rope_use_fp32_in_outer_product
506
+ self.patcher_attn_impl = patcher_attn_impl
507
+ self.patcher_attn_bias_type = patcher_attn_bias_type
508
+ self.patcher_init_base_std = patcher_init_base_std
509
+ self.patcher_init_std_factor = InitStdFactor(patcher_init_std_factor)
510
+ self.patcher_dim_token_emb = patcher_dim_token_emb
511
+ self.patcher_weight_tying = patcher_weight_tying
512
+ self.patcher_bos_token_id = patcher_bos_token_id
513
+ self.patcher_eos_token_id = patcher_eos_token_id
514
+
515
+ # Handle hash byte group size validation
516
+ if isinstance(self.encoder_hash_byte_group_size, str):
517
+ self.encoder_hash_byte_group_size = [
518
+ int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0
519
+ ]
520
+
521
+ # Rope
522
+ self.rope_scaling = {
523
+ "type": "dynamic",
524
+ "factor": 2.0,
525
+ "rope_type": "dynamic"
526
+ }
527
+
528
+ self.num_key_value_heads = n_heads_local_encoder
529
+ self.max_position_embeddings = max_seqlen
530
+ self.hidden_size = dim_local_encoder
531
+ self.num_attention_heads = n_heads_local_encoder
532
+ # self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
533
+
534
+ super().__init__(
535
+ bos_token_id=bos_token_id,
536
+ eos_token_id=eos_token_id,
537
+ pad_token_id=pad_token_id,
538
+ **kwargs,
539
+ )
540
+
541
+ @property
542
+ def encoder_dim_token_emb(self):
543
+ """Compute encoder token embedding dimension."""
544
+ if self.dim_token is not None:
545
+ return self.dim_token
546
+ elif self.use_local_encoder_transformer:
547
+ return self.dim_local_encoder
548
+ else:
549
+ # Use default patch_size of 8 if not set
550
+ patch_size = self.patch_size if self.patch_size is not None else 8
551
+ return self.dim_global // patch_size
552
+
553
+ @property
554
+ def encoder_dim_patch_emb(self):
555
+ """Compute encoder patch embedding dimension."""
556
+ if self.cross_attn_encoder:
557
+ if self.cross_attn_init_by_pooling:
558
+ return self.dim_local_encoder
559
+ else:
560
+ return self.dim_global
561
+ return None
562
+
563
+ @property
564
+ def global_dim_patch_emb(self):
565
+ """Compute global patch embedding dimension."""
566
+ dim_token_emb = self.encoder_dim_token_emb
567
+ if self.cross_attn_encoder:
568
+ cross_attn_k = self.cross_attn_k if self.cross_attn_k is not None else 1
569
+ return dim_token_emb * cross_attn_k
570
+ elif (
571
+ self.downsampling_by_pooling is None
572
+ or not self.downsampling_by_pooling
573
+ or len(self.downsampling_by_pooling) == 0
574
+ ):
575
+ # Use default patch_size of 8 if not set
576
+ patch_size = self.patch_size if self.patch_size is not None else 8
577
+ return dim_token_emb * patch_size
578
+ else:
579
+ return dim_token_emb * sum([pooling in self.downsampling_by_pooling for pooling in ["avg", "min", "max"]])
580
+
581
+ @property
582
+ def decoder_dim_token_emb(self):
583
+ """Compute decoder token embedding dimension."""
584
+ if self.share_encoder_decoder_emb:
585
+ return self.encoder_dim_token_emb
586
+ elif self.dim_token is not None:
587
+ return self.dim_token
588
+ else:
589
+ return self.dim_local_decoder
590
+
591
+ def get_init_std_factor(self, depth: int) -> float:
592
+ """
593
+ Calculate the initialization standard deviation scaling factor for a given layer depth.
594
+
595
+ Args:
596
+ depth: Current layer depth (0-indexed)
597
+
598
+ Returns:
599
+ Scaling factor to divide the base initialization std by
600
+ """
601
+ if self.init_std_factor == InitStdFactor.CURRENT_DEPTH:
602
+ return (2 * (depth + 1)) ** 0.5
603
+ else: # DISABLED
604
+ return 1.0
605
+
606
+
607
+ __all__ = ["BLTConfig", "InitStdFactor", "PatchingModeEnum"]
608
+
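# Illustrative sketch (not part of the diff): the scaling applied by get_init_std_factor()
# above for the two InitStdFactor modes, using an assumed base std of 0.02.
base_std, depth = 0.02, 3                             # depth is 0-indexed, as in the docstring
disabled = base_std / 1.0                             # InitStdFactor.DISABLED -> 0.02
current_depth = base_std / (2 * (depth + 1)) ** 0.5   # CURRENT_DEPTH -> 0.02 / sqrt(8) ~= 0.00707
print(disabled, current_depth)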
backup_blt_wip copy/modeling_blt.py ADDED
@@ -0,0 +1,1287 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """BLT model."""
16
+
17
+ from ...utils import is_torch_flex_attn_available, logging
18
+ from typing import Callable, List, Optional, Tuple, Union
19
+
20
+ from enum import Enum
21
+
22
+ from ...cache_utils import Cache
23
+ from ...activations import ACT2FN
24
+
25
+ import torch
26
+ import torch.distributions
27
+ import torch.nn
28
+ import torch.nn as nn
29
+ from torch.nn import functional as F
30
+
31
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
32
+
33
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
34
+ from .configuration_blt import (
35
+ BLTConfig,
36
+ BLTLocalEncoderConfig,
37
+ BLTLocalDecoderConfig,
38
+ BLTGlobalTransformerConfig,
39
+ BLTPatcherConfig,
40
+ )
41
+
42
+ from ...generation.utils import GenerationMixin
43
+ from ...modeling_outputs import CausalLMOutputWithPast
44
+
45
+ if is_torch_flex_attn_available():
46
+ from torch.nn.attention.flex_attention import BlockMask
47
+ from ...integrations.flex_attention import make_flex_block_causal_mask
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ class PatchingModeEnum(str, Enum):
53
+ entropy = "entropy"
54
+ bpe = "bpe"
55
+ bpe_patcher = "bpe_patcher"
56
+ space = "space"
57
+ static = "static"
58
+ byte = "byte"
59
+
60
+
61
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
62
+ """
63
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
64
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
65
+ """
66
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
67
+ if n_rep == 1:
68
+ return hidden_states
69
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
70
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
71
+
72
+
73
+ def eager_attention_forward(
74
+ module: nn.Module,
75
+ query: torch.Tensor,
76
+ key: torch.Tensor,
77
+ value: torch.Tensor,
78
+ attention_mask: Optional[torch.Tensor],
79
+ scaling: float,
80
+ dropout: float = 0.0,
81
+ **kwargs,
82
+ ):
83
+ key_states = repeat_kv(key, module.num_key_value_groups)
84
+ value_states = repeat_kv(value, module.num_key_value_groups)
85
+
86
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
87
+ if attention_mask is not None:
88
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
89
+ attn_weights = attn_weights + causal_mask
90
+
91
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
92
+ attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
93
+ attn_output = torch.matmul(attn_weights, value_states)
94
+ attn_output = attn_output.transpose(1, 2).contiguous()
95
+
96
+ return attn_output, attn_weights
97
+
98
+
99
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
100
+ # TODO: not exactly equivalent to other transformers implementations; needs feedback
101
+ # Extract first head_dim//2 elements which correspond to the unique frequencies
102
+ # This matches the original BLT approach which uses head_dim//2 frequency pairs
103
+ head_dim = q.shape[-1]
104
+ cos_freqs = cos[..., :head_dim//2] # [B, S, D/2]
105
+ sin_freqs = sin[..., :head_dim//2] # [B, S, D/2]
106
+
107
+ # Expand cos/sin to match query/key tensor format [B, H, S, D/2]
108
+ cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2]
109
+ sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2]
110
+
111
+ # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ...
112
+ q_pairs = q.view(*q.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2]
113
+ k_pairs = k.view(*k.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2]
114
+
115
+ # Extract real and imaginary parts
116
+ q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2]
117
+ k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2]
118
+
119
+ # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag]
120
+ q_real_rot = cos_freqs * q_real - sin_freqs * q_imag
121
+ q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag
122
+ k_real_rot = cos_freqs * k_real - sin_freqs * k_imag
123
+ k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag
124
+
125
+ # Recombine pairs and reshape back to original format
126
+ q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D]
127
+ k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D]
128
+
129
+ return q_rot.type_as(q), k_rot.type_as(k)
130
+
131
+
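# Illustrative sketch (not part of the diff): sanity check of the pairwise rotation above.
# With cos=1 and sin=0 the rotation is the identity, so q/k come back unchanged.
# Shapes follow the comments in apply_rotary_pos_emb: q/k are [B, H, S, D], cos/sin are [B, S, D].
import torch

B, H, S, D = 1, 2, 4, 8
q, k = torch.randn(B, H, S, D), torch.randn(B, H, S, D)
cos, sin = torch.ones(B, S, D), torch.zeros(B, S, D)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert torch.allclose(q_rot, q) and torch.allclose(k_rot, k)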
132
+ class BLTMLP(nn.Module):
133
+ def __init__(self, config):
134
+ super().__init__()
135
+ self.hidden_size = config.hidden_size
136
+ self.intermediate_size = config.intermediate_size
137
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
138
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
139
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
140
+ self.act_fn = ACT2FN[config.hidden_act]
141
+
142
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
143
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
144
+ return down_proj
145
+
146
+
147
+ class BLTRMSNorm(nn.Module):
148
+ def __init__(self, hidden_size, eps=1e-6):
149
+ """
150
+ BLTRMSNorm is equivalent to T5LayerNorm
151
+ """
152
+ super().__init__()
153
+ self.weight = nn.Parameter(torch.ones(hidden_size))
154
+ self.variance_epsilon = eps
155
+
156
+ def forward(self, hidden_states):
157
+ input_dtype = hidden_states.dtype
158
+ hidden_states = hidden_states.to(torch.float32)
159
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
160
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
161
+ return self.weight * hidden_states.to(input_dtype)
162
+
163
+
164
+ class BLTTransformerLayer(nn.Module):
165
+ def __init__(self, config, layer_idx: int):
166
+ super().__init__()
167
+ self.hidden_size = config.hidden_size
168
+ self.layer_idx = layer_idx
169
+
170
+ self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx)
171
+ self.mlp = BLTMLP(config)
172
+ self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps)
173
+ self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps)
174
+
175
+ def forward(
176
+ self,
177
+ hidden_states: torch.Tensor,
178
+ attention_mask: Optional[torch.Tensor] = None,
179
+ position_ids: Optional[torch.LongTensor] = None,
180
+ past_key_value: Optional[Cache] = None,
181
+ output_attentions: Optional[bool] = False,
182
+ use_cache: Optional[bool] = False,
183
+ cache_position: Optional[torch.LongTensor] = None,
184
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
185
+ **kwargs,
186
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
187
+ """
188
+ Args:
189
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
190
+ attention_mask (`torch.FloatTensor`, *optional*):
191
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
192
+ query_sequence_length, key_sequence_length)` if default attention is used.
193
+ position_ids (`torch.LongTensor`, *optional*):
194
+ Position indices of tokens in the sequence for RoPE computation.
195
+ past_key_value (`Cache`, *optional*): cached past key and value projection states
196
+ output_attentions (`bool`, *optional*):
197
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
198
+ returned tensors for more detail.
199
+ use_cache (`bool`, *optional*):
200
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
201
+ (see `past_key_values`).
202
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
203
+ Indices depicting the position of the input sequence tokens in the sequence
204
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
205
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
206
+ with `head_dim` being the embedding dimension of each attention head.
207
+ kwargs (`dict`, *optional*):
208
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
209
+ into the model
210
+ """
211
+ residual = hidden_states
212
+ hidden_states = self.input_layernorm(hidden_states)
213
+
214
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
215
+ hidden_states=hidden_states,
216
+ attention_mask=attention_mask,
217
+ position_ids=position_ids,
218
+ past_key_value=past_key_value,
219
+ output_attentions=output_attentions,
220
+ use_cache=use_cache,
221
+ cache_position=cache_position,
222
+ position_embeddings=position_embeddings,
223
+ **kwargs,
224
+ )
225
+ hidden_states = residual + hidden_states
226
+
227
+ residual = hidden_states
228
+ hidden_states = self.post_attention_layernorm(hidden_states)
229
+ hidden_states = self.mlp(hidden_states)
230
+ hidden_states = residual + hidden_states
231
+
232
+ outputs = (hidden_states,)
233
+
234
+ if output_attentions:
235
+ outputs += (self_attn_weights,)
236
+
237
+ if use_cache:
238
+ outputs += (present_key_value,)
239
+
240
+ return outputs
241
+
242
+
243
+ class BLTSelfAttention(nn.Module):
244
+ def __init__(self, config, layer_idx: int):
245
+ super().__init__()
246
+ self.config = config
247
+ self.num_heads = config.num_attention_heads
248
+ self.dropout = config.dropout
249
+ self.hidden_size = config.hidden_size
250
+ self.num_key_value_heads = config.num_key_value_heads
251
+ self.head_dim = config.hidden_size // self.num_heads
252
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
253
+ self.scaling = None
254
+ self.rope_theta = config.rope_theta
255
+ self.layer_idx = layer_idx
256
+
257
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
258
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
259
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
260
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
261
+
262
+ def forward(
263
+ self,
264
+ hidden_states: torch.Tensor,
265
+ attention_mask: torch.Tensor,
266
+ position_embeddings: torch.Tensor,
267
+ output_attentions: bool = False,
268
+ use_cache: bool = False,
269
+ past_key_value=None,
270
+ cache_position=None,
271
+ **kwargs,
272
+ ):
273
+ bsz, q_len, _ = hidden_states.size()
274
+
275
+ query_states = self.q_proj(hidden_states)
276
+ key_states = self.k_proj(hidden_states)
277
+ value_states = self.v_proj(hidden_states)
278
+
279
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
280
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
281
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
282
+
283
+ cos, sin = position_embeddings
284
+
285
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
286
+
287
+ if past_key_value is not None:
288
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
289
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
290
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
291
+
292
+ attention_interface: Callable = eager_attention_forward
293
+ output_attentions = False
294
+ self.config._attn_implementation = "sdpa"
295
+ if self.config._attn_implementation != "eager":
296
+ if self.config._attn_implementation == "sdpa" and output_attentions:
297
+ logger.warning_once(
298
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
299
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
300
+ )
301
+ else:
302
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
303
+
304
+ attn_output, attn_weights = attention_interface(
305
+ self,
306
+ query_states,
307
+ key_states,
308
+ value_states,
309
+ attention_mask,
310
+ dropout=0.0 if not self.training else self.dropout,
311
+ scaling=self.scaling,
312
+ **kwargs,
313
+ )
314
+
315
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
316
+ attn_output = self.o_proj(attn_output)
317
+
318
+ if not output_attentions:
319
+ attn_weights = None
320
+
321
+ return attn_output, attn_weights, past_key_value
322
+
323
+
324
+ def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0):
325
+ primes = [
326
+ 1000000007, 5915587277, 1500450271, 3267000013, 5754853343,
327
+ 4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313,
328
+ ]
329
+ prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device)
330
+ powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device)
331
+ prime_powers = prime ** powers
332
+ return torch.sum(token_tensor * prime_powers, dim=-1)
333
+
334
+
335
+ def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000):
336
+ """Hash token groups and map to range [0, max_hash]."""
337
+ with torch.no_grad():
338
+ batch_size, seq_len = token_ids.shape
339
+ # Add padding for sliding window
340
+ padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
341
+ padded_tokens = torch.cat([padding, token_ids], dim=1)
342
+
343
+ # Create sliding windows and compute hashes
344
+ windows = padded_tokens.unfold(1, group_size, 1)
345
+ hashes = rolling_polynomial_hash(windows, hash_func_nb)
346
+ hash_values = hashes % max_hash
347
+
348
+ hash_values.requires_grad = False
349
+ return hash_values
350
+
351
+
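# Illustrative sketch (not part of the diff): what the two hash helpers above produce for a
# toy byte sequence. Every sliding window of `group_size` bytes is hashed with a rolling
# polynomial hash, and the resulting bucket ids lie in [0, max_hash).
import torch

token_ids = torch.tensor([[72, 101, 108, 108, 111]])   # one sequence of 5 byte ids
hash_ids = byte_group_hash_function(token_ids, group_size=3, hash_func_nb=0, max_hash=30000)
print(hash_ids.shape)            # torch.Size([1, 5]) -- one bucket id per position
print(hash_ids.max() < 30000)    # tensor(True)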
352
+ def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list):
353
+ """Initialize hash-based token embeddings for the BLT encoder."""
354
+ num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size)
355
+ embeddings = [
356
+ nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim)
357
+ for _ in range(num_embeddings)
358
+ ]
359
+ return nn.ModuleList(embeddings)
360
+
361
+
362
+ def compute_hash_embeddings(
363
+ local_encoder_tokens: torch.Tensor,
364
+ local_encoder,
365
+ encoder_hash_tok_embedding: nn.ModuleList,
366
+ encoder_hash_byte_group_nb_functions: int,
367
+ encoder_hash_byte_group_size: list,
368
+ encoder_hash_byte_group_vocab: int,
369
+ ) -> torch.Tensor:
370
+ """Compute token embeddings enhanced with hash-based embeddings."""
371
+ embeddings = local_encoder.embed_tokens(local_encoder_tokens)
372
+ embedding_idx = 0
373
+ for func_nb in range(encoder_hash_byte_group_nb_functions):
374
+ for group_size in encoder_hash_byte_group_size:
375
+ hash_ids = byte_group_hash_function(
376
+ local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab
377
+ )
378
+ embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids)
379
+ embedding_idx += 1
380
+
381
+ return embeddings
382
+
383
+
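# Illustrative sketch (not part of the diff): the number of hash-embedding tables created by
# init_hash_embeddings() above and consumed in the same order by compute_hash_embeddings():
# one table per (hash function, group size) pair. The concrete values below are assumptions.
encoder_hash_byte_group_nb_functions = 3           # assumed, matches the old config default
encoder_hash_byte_group_size = [3, 4, 5, 6, 7, 8]  # assumed list of group sizes
num_tables = encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size)
print(num_tables)   # 18 tables, each of shape [encoder_hash_byte_group_vocab, hidden_size]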
384
+ def _prepare_patch_cross_attention_mask(
385
+ patch_ids: torch.Tensor,
386
+ num_patches: int,
387
+ sequence_length: int,
388
+ patches_as_queries: bool = False,
389
+ cross_attn_k: int = 1,
390
+ dtype: torch.dtype = torch.float32,
391
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
392
+ """
393
+ Prepare cross-attention mask for patch-based attention, following mllama's robust approach.
394
+
395
+ This function creates masks that control which patches can attend to which other patches,
396
+ with support for query/key role swapping and cross-attention multipliers.
397
+
398
+ Args:
399
+ patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids.
400
+ num_patches (int): Total number of patches.
401
+ sequence_length (int): Length of the sequence.
402
+ patches_as_queries (bool): If True, patches are used as queries, otherwise as keys.
403
+ cross_attn_k (int): Cross-attention multiplier for repeating patches.
404
+ dtype (torch.dtype): Data type for the output mask.
405
+
406
+ Returns:
407
+ Tuple[torch.Tensor, torch.Tensor]:
408
+ - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len]
409
+ - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows
410
+ """
411
+ batch_size, seq_len = patch_ids.shape
412
+ device = patch_ids.device
413
+
414
+ # Determine query and key lengths based on configuration
415
+ if patches_as_queries:
416
+ q_len = num_patches * cross_attn_k
417
+ kv_len = sequence_length
418
+ # Create patch-to-sequence mapping
419
+ q_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(-1).expand(
420
+ batch_size, num_patches, seq_len
421
+ )
422
+ kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len)
423
+ else:
424
+ q_len = sequence_length
425
+ kv_len = num_patches * cross_attn_k
426
+ # Create sequence-to-patch mapping
427
+ q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches)
428
+ kv_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(
429
+ batch_size, seq_len, num_patches
430
+ )
431
+
432
+ # Create base attention mask - boolean mask where True means "should attend"
433
+ # Exact patch matching
434
+ cross_attention_mask = q_patch_ids == kv_patch_ids
435
+
436
+ # Handle cross_attn_k multiplier by repeating along appropriate dimension
437
+ repeat_dim = 1 if patches_as_queries else -1
438
+ cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim)
439
+
440
+ # Validate dimensions
441
+ expected_shape = (batch_size, q_len, kv_len)
442
+ if cross_attention_mask.shape != expected_shape:
443
+ raise ValueError(f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}")
444
+
445
+ # Reshape so it can be used by attn module - add head dimension
446
+ cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len]
447
+
448
+ # Invert the mask (following mllama pattern exactly)
449
+ # True -> 0.0 (attend), False -> 1.0 (will become -inf)
450
+ inverted_cross_attn_mask = (1.0 - cross_attention_mask.to(dtype))
451
+ cross_attention_mask = inverted_cross_attn_mask.masked_fill(
452
+ inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
453
+ )
454
+
455
+ # Apply full-row bias (following mllama pattern exactly)
456
+ # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's
457
+ # last dimension contains negative infinity values, otherwise it's 1
458
+ negative_inf_value = torch.finfo(dtype).min
459
+ full_text_row_masked_out_mask = (
460
+ (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
461
+ )
462
+ cross_attention_mask *= full_text_row_masked_out_mask
463
+
464
+ return cross_attention_mask, full_text_row_masked_out_mask
465
+
466
+
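# Illustrative sketch (not part of the diff): shapes produced by the mask helper above for a
# toy case where a 6-byte sequence is split into 3 patches of 2 bytes each.
import torch

patch_ids = torch.tensor([[0, 0, 1, 1, 2, 2]])   # [batch_size=1, seq_len=6]
cross_mask, full_row_mask = _prepare_patch_cross_attention_mask(
    patch_ids, num_patches=3, sequence_length=6, patches_as_queries=True, cross_attn_k=2
)
print(cross_mask.shape)      # torch.Size([1, 1, 6, 6]) -- [B, 1, num_patches * cross_attn_k, seq_len]
print(full_row_mask.shape)   # torch.Size([1, 1, 6, 1])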
467
+ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor:
468
+ """
469
+ Splits patch lengths into smaller segments if they exceed `max_patch_length`.
470
+ Pads the result to uniform length across the batch.
471
+
472
+ Args:
473
+ patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
474
+ max_patch_length (int, optional): Maximum allowed length per patch.
475
+
476
+ Returns:
477
+ torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
478
+ """
479
+ if max_patch_length is None:
480
+ return patch_lengths
481
+
482
+ batch_size = patch_lengths.size(0)
483
+ processed = []
484
+
485
+ for seq in patch_lengths:
486
+ splits = []
487
+ for length in seq[seq > 0]:
488
+ length = length.item()
489
+ full_chunks, remainder = divmod(length, max_patch_length)
490
+ splits.extend([max_patch_length] * full_chunks)
491
+ if remainder:
492
+ splits.append(remainder)
493
+ processed.append(splits)
494
+
495
+ # Find max length to pad to
496
+ max_len = max(len(splits) for splits in processed)
497
+ padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
498
+
499
+ for i, splits in enumerate(processed):
500
+ if splits:
501
+ padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)
502
+
503
+ # Trim zero columns
504
+ if (padded != 0).any(dim=0).sum() < padded.shape[1]:
505
+ last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
506
+ padded = padded[:, :last_nonzero]
507
+
508
+ return padded
509
+
510
+
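# Illustrative sketch (not part of the diff): how process_patch_lengths() above splits
# oversized patches. A patch of length 5 with max_patch_length=2 becomes [2, 2, 1], and the
# batch is right-padded with zeros to the longest resulting row.
import torch

patch_lengths = torch.tensor([[5, 3], [2, 0]])
print(process_patch_lengths(patch_lengths, max_patch_length=2))
# tensor([[2, 2, 1, 2, 1],
#         [2, 0, 0, 0, 0]])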
511
+ class BLTRotaryEmbedding(nn.Module):
512
+ def __init__(self, config, device=None):
513
+ super().__init__()
514
+ self.rope_type = config.rope_scaling["rope_type"]
515
+ self.max_seq_len_cached = config.max_position_embeddings
516
+ self.original_max_seq_len = config.max_position_embeddings
517
+
518
+ self.config = config
519
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
520
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
521
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
522
+ self.original_inv_freq = self.inv_freq
523
+
524
+ @torch.no_grad()
525
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
526
+ def forward(self, x, position_ids):
527
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
528
+ position_ids_expanded = position_ids[:, None, :].float()
529
+
530
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
531
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
532
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
533
+ emb = torch.cat((freqs, freqs), dim=-1)
534
+ cos = emb.cos() * self.attention_scaling
535
+ sin = emb.sin() * self.attention_scaling
536
+
537
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
538
+
539
+
540
+ class BLTLocalEncoder(nn.Module):
541
+ def __init__(self, config: BLTLocalEncoderConfig):
542
+ super().__init__()
543
+
544
+ self.config = config
545
+
546
+ self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
547
+
548
+ self.rotary_emb = BLTRotaryEmbedding(config=config)
549
+
550
+ self.patch_embedding_projection = nn.Linear(
551
+ in_features=config.hidden_size,
552
+ out_features=config.hidden_size * config.cross_attn_k,
553
+ bias=False,
554
+ )
555
+
556
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
557
+
558
+ self.cross_attn_layers = torch.nn.ModuleList()
559
+ layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1
560
+ for layer_idx in range(layers_to_add):
561
+ self.cross_attn_layers.append(
562
+ BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
563
+ )
564
+
565
+ def forward(
566
+ self,
567
+ input_ids: torch.Tensor,
568
+ input_embeds: Optional[torch.Tensor] = None,
569
+ patch_embeds: Optional[torch.Tensor] = None,
570
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
571
+ cross_mask: Optional[torch.Tensor] = None,
572
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
573
+ num_patches: Optional[int] = None,
574
+ patch_ids: Optional[torch.Tensor] = None,
575
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
576
+ ):
577
+ """ """
578
+ if input_embeds is None:
579
+ input_embeds = self.embed_tokens(input_ids)
580
+
581
+ batch_size, _, _ = input_embeds.shape
582
+
583
+ hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training)
584
+
585
+ position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
586
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
587
+
588
+ hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
589
+
590
+ for idx, layer in enumerate(self.layers):
591
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None)
592
+ hidden_states = layer_outputs[0]
593
+
594
+ if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers:
595
+ patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids)
596
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
597
+ patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size)
598
+
599
+ layer_idx = idx if self.config.cross_attn_all_layers else 0
600
+ cross_attention_output, _, _ = self.cross_attn_layers[layer_idx](
601
+ hidden_states=patch_embeds,
602
+ cross_attention_states=hidden_states,
603
+ attention_mask=cross_mask,
604
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
605
+ output_attentions=False,
606
+ use_cache=False,
607
+ cache_position=None,
608
+ )
609
+ patch_embeds = patch_embeds + cross_attention_output
610
+
611
+ encoder_cross_states = patch_embeds
612
+ return hidden_states, encoder_cross_states
613
+
614
+ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids):
615
+ """
616
+ Reduce variable length patches to single embedding per patch
617
+ Note: this works with variable number of patches for different sequences in the batch
618
+ It handles variable length patches by assuming that patch_lengths will be 0 for any
619
+ extra patches on the *right*, so the number of patches can differ across
620
+ sequences in the batch.
621
+ Any embeddings on the right that are not allocated to a patch
622
+ (i.e. if the sum(patch_lengths[i]) < seq_len for any i)
623
+ will be sent to a dummy patch, which is trimmed before returning.
624
+ """
625
+ batch_size, _, embedding_dim = hidden_states.shape
626
+
627
+ patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
628
+
629
+ reduced_embeddings = torch.zeros((batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device)
630
+ reduced_embeddings = reduced_embeddings.scatter_reduce(
631
+ src=hidden_states,
632
+ dim=1,
633
+ index=patch_ids,
634
+ reduce=reduction,
635
+ include_self=False,
636
+ )
637
+ reduced_embeddings = reduced_embeddings[:, :max_num_patches, :]
638
+
639
+ return reduced_embeddings
640
+
641
+
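# Illustrative sketch (not part of the diff): the max-pooling reduction used by
# patch_reduce() above, reproduced standalone. Each byte position's hidden state is
# scattered into its patch slot and reduced with "amax".
import torch

B, S, D, num_patches = 1, 6, 4, 3
hidden_states = torch.randn(B, S, D)
patch_ids = torch.tensor([[0, 0, 1, 1, 2, 2]])          # byte -> patch assignment
index = patch_ids.unsqueeze(-1).expand(-1, -1, D)
patch_embeds = torch.zeros(B, num_patches, D).scatter_reduce(
    src=hidden_states, dim=1, index=index, reduce="amax", include_self=False
)
print(patch_embeds.shape)   # torch.Size([1, 3, 4]) -- one embedding per patch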
642
+ class BLTLocalDecoder(nn.Module):
643
+ def __init__(self, config: BLTLocalDecoderConfig):
644
+ super().__init__()
645
+
646
+ # Extract config values to instance attributes
647
+ self.config = config
648
+ self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove
649
+
650
+ self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
651
+
652
+ self.rotary_emb = BLTRotaryEmbedding(config=config)
653
+
654
+ self.patch_embedding_projection = nn.Linear(
655
+ in_features=config.hidden_size_global,
656
+ out_features=config.hidden_size * config.cross_attn_k,
657
+ bias=False,
658
+ )
659
+
660
+ self.norm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps)
661
+
662
+ self.cross_attn_layers = torch.nn.ModuleList()
663
+ layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1
664
+ for layer_idx in range(layers_to_add):
665
+ self.cross_attn_layers.append(
666
+ BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
667
+ )
668
+
669
+ # self.lm_head = nn.Linear(
670
+ # config.hidden_size,
671
+ # config.vocab_size,
672
+ # bias=False,
673
+ # )
674
+
675
+
676
+ def forward(
677
+ self,
678
+ tokens: torch.Tensor,
679
+ embeds: Optional[torch.Tensor],
680
+ patch_embeds: Optional[torch.Tensor] = None,
681
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
682
+ cross_mask: Optional[torch.Tensor] = None,
683
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
684
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
685
+ ):
686
+ batch_size, _, _ = embeds.shape
687
+
688
+ hidden_states = embeds
689
+
690
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
691
+ patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size)
692
+
693
+ if patch_embeds is not None and not self.cross_attn_decoder:
694
+ hidden_states = hidden_states + patch_embeds
695
+
696
+ position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1)
697
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
698
+
699
+ hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
700
+ for i, layer in enumerate(self.layers):
701
+ if i == 0 or self.config.cross_attn_all_layers:
702
+ # Use cross attention to extract info from patch_embeds into hidden_states
703
+ cross_attention_output, _, _ = self.cross_attn_layers[i](
704
+ hidden_states=hidden_states,
705
+ cross_attention_states=patch_embeds,
706
+ attention_mask=cross_mask,
707
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
708
+ output_attentions=False,
709
+ use_cache=False,
710
+ cache_position=None,
711
+ )
712
+ hidden_states = hidden_states + cross_attention_output
713
+
714
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None)
715
+ hidden_states = layer_outputs[0]
716
+
717
+ logits = self.norm(hidden_states)
718
+ # logits = self.lm_head(logits)
719
+ return logits, cache
720
+
721
+
722
+ class BLTCrossAttention(nn.Module):
723
+ """Cross-attention module for BLT, following transformers style"""
724
+
725
+ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None):
726
+ super().__init__()
727
+ self.config = config
728
+ self.layer_idx = layer_idx
729
+ # Use provided hidden_size or fallback to encoder dimension
730
+ self.hidden_size = hidden_size or config.encoder_config.hidden_size
731
+ self.num_heads = config.num_attention_heads
732
+ self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention
733
+ self.head_dim = self.hidden_size // self.num_heads
734
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
735
+ self.scaling = None  # None lets SDPA fall back to its default 1/sqrt(head_dim) scaling
736
+ self.dropout = config.dropout
737
+
738
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
739
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
740
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
741
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
742
+
743
+ self.q_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps)
744
+ self.k_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps)
745
+
746
+ def forward(
747
+ self,
748
+ hidden_states: torch.Tensor,
749
+ cross_attention_states: Optional[torch.Tensor] = None,
750
+ past_key_value: Optional[Cache] = None,
751
+ attention_mask: Optional[torch.Tensor] = None,
752
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
753
+ output_attentions: bool = False,
754
+ use_cache: Optional[bool] = None,
755
+ cache_position: Optional[torch.LongTensor] = None,
756
+ **kwargs,
757
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
758
+ """Input shape: Batch x Time x Channel"""
759
+ bsz, q_len, _ = hidden_states.size()
760
+
761
+ query_states = self.q_norm(hidden_states) # BLT normalizes first
762
+ query_states = self.q_proj(query_states)
763
+
764
+ if cross_attention_states is not None:
765
+ cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first
766
+ key_states = self.k_proj(cross_attention_states)
767
+ value_states = self.v_proj(cross_attention_states)
768
+ if past_key_value is not None:
769
+ # if we have new cross-attention states plus new tokens, key/value states were computed only on those new cross-attention states;
770
+ # we still update the cached cross key/value states with them and use the merged cache.
771
+ key_states, value_states = past_key_value.update(
772
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
773
+ )
774
+ elif cache_position is not None and cache_position[0] != 0:
775
+ key_states, value_states = (
776
+ past_key_value.key_cache[self.layer_idx],
777
+ past_key_value.value_cache[self.layer_idx],
778
+ )
779
+ else:
780
+ if cross_attention_states is None:
781
+ raise ValueError(
782
+ "Cross attention layer can't find neither `cross_attention_states` nor cached values for key/values!"
783
+ )
784
+
785
+ attention_interface: Callable = eager_attention_forward
786
+
787
+ self.config._attn_implementation = "sdpa"  # force SDPA here, overriding any user-selected attention implementation
788
+ if self.config._attn_implementation != "eager":
789
+ if self.config._attn_implementation == "sdpa" and output_attentions:
790
+ logger.warning_once(
791
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
792
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
793
+ )
794
+ else:
795
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
796
+
797
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
798
+ key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
799
+ value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
800
+
801
+ attn_output, attn_weights = attention_interface(
802
+ self,
803
+ query_states,
804
+ key_states,
805
+ value_states,
806
+ attention_mask,
807
+ dropout=0.0,
808
+ scaling=self.scaling,
809
+ **kwargs,
810
+ )
811
+
812
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
813
+ attn_output = self.o_proj(attn_output)
814
+
815
+ if full_text_row_masked_out_mask is not None:
816
+ attn_output = full_text_row_masked_out_mask[:, 0] * attn_output
817
+
818
+ attn_output = attn_output + hidden_states
819
+
820
+ if not output_attentions:
821
+ attn_weights = None
822
+
823
+ return attn_output, attn_weights, past_key_value
824
+
825
+
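# Standalone shape sketch (illustrative, dummy sizes) of the cross-attention pattern in
# BLTCrossAttention: byte-level queries attend over patch-level keys/values, so the query
# length and key/value length can differ.
import torch
import torch.nn.functional as F

bsz, q_len, kv_len, num_heads, head_dim = 2, 16, 4, 8, 32
hidden = num_heads * head_dim

query = torch.randn(bsz, q_len, hidden).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key = torch.randn(bsz, kv_len, hidden).view(bsz, kv_len, num_heads, head_dim).transpose(1, 2)
value = torch.randn(bsz, kv_len, hidden).view(bsz, kv_len, num_heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(query, key, value)  # (bsz, num_heads, q_len, head_dim)
out = out.transpose(1, 2).reshape(bsz, q_len, hidden)    # back to (bsz, q_len, hidden)
print(out.shape)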
826
+ class BLTGlobalTransformer(nn.Module):
827
+ def __init__(self, config: BLTGlobalTransformerConfig):
828
+ super().__init__()
829
+
830
+ self.config = config
831
+
832
+ self.layers = nn.ModuleList()
833
+ for layer_idx in range(config.num_hidden_layers):
834
+ self.layers.append(BLTTransformerLayer(config, layer_idx))
835
+
836
+ self.rotary_emb = BLTRotaryEmbedding(config=config)
837
+
838
+
839
+ def forward(
840
+ self,
841
+ input_embeds: torch.Tensor,
842
+ mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
843
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
844
+ ):
845
+ batch_size, seq_len, _ = input_embeds.shape
846
+
847
+ hidden_states = input_embeds
848
+
849
+ hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
850
+
851
+ position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
852
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
853
+
854
+ for i, layer in enumerate(self.layers):
855
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None)
856
+ hidden_states = layer_outputs[0]
857
+
858
+ return hidden_states, cache
859
+
860
+
861
+
862
+
863
+ class BLTPreTrainedModel(PreTrainedModel):
864
+ config_class = BLTConfig
865
+ base_model_prefix = "model"
866
+ supports_gradient_checkpointing = True
867
+ _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"]
868
+ _skip_keys_device_placement = ["past_key_values"]
869
+ _supports_flash_attn_2 = False # BLT uses its own attention implementation
870
+ _supports_sdpa = True
871
+ _supports_cache_class = False
872
+
873
+ def _init_weights(self, module):
874
+ if isinstance(module, nn.Linear):
875
+ std = getattr(module, '_custom_std', module.in_features ** (-0.5))
876
+ nn.init.trunc_normal_(
877
+ module.weight,
878
+ mean=0.0,
879
+ std=std,
880
+ a=-3 * std,
881
+ b=3 * std,
882
+ )
883
+ if module.bias is not None:
884
+ nn.init.zeros_(module.bias)
885
+
886
+ elif isinstance(module, nn.Embedding):
887
+ std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5))
888
+ nn.init.trunc_normal_(
889
+ module.weight,
890
+ mean=0.0,
891
+ std=std,
892
+ a=-3 * std,
893
+ b=3 * std,
894
+ )
895
+
896
+ elif isinstance(module, BLTModel):
897
+ if module.encoder_hash_tok_embedding is not None:
898
+ emb_std = module.config.encoder_config.hidden_size ** (-0.5)
899
+ for emb in module.encoder_hash_tok_embedding:
900
+ emb._custom_std = emb_std
901
+
902
+ elif isinstance(module, BLTLocalEncoder):
903
+ if module.patch_embedding_projection is not None:
904
+ module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5)
905
+
906
+ elif isinstance(module, BLTLocalDecoder):
907
+ if module.patch_embedding_projection is not None:
908
+ module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5)
909
+
910
+ elif isinstance(module, BLTPatcher):
911
+ emb_std = module.config.hidden_size ** (-0.5)
912
+ module.embed_tokens._custom_std = emb_std
913
+ module.lm_head._custom_std = emb_std
914
+
915
+ elif isinstance(module, BLTForCausalLM):
916
+ if module.lm_head is not None:
917
+ module.lm_head._custom_std = module.config.decoder_config.hidden_size ** (-0.5)
918
+
919
+
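# Minimal sketch (illustrative, dummy layer sizes) of the initialization rule applied in
# BLTPreTrainedModel._init_weights: truncated normal with std = in_features ** -0.5,
# truncated at +/- 3 standard deviations.
import torch.nn as nn

linear = nn.Linear(512, 2048, bias=False)
std = linear.in_features ** (-0.5)
nn.init.trunc_normal_(linear.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
print(linear.weight.std().item())  # roughly 512 ** -0.5 ≈ 0.044 (slightly smaller due to truncation)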
920
+ class BLTModel(BLTPreTrainedModel):
921
+ def __init__(self, config: BLTConfig):
922
+ super().__init__(config)
923
+ self.config = config
924
+ self.local_encoder = BLTLocalEncoder(config.encoder_config)
925
+ self.global_transformer = BLTGlobalTransformer(config.global_config)
926
+ self.local_decoder = BLTLocalDecoder(config.decoder_config)
927
+ self.encoder_hash_tok_embedding = init_hash_embeddings(
928
+ config,
929
+ local_encoder_dim=config.encoder_config.hidden_size,
930
+ encoder_hash_byte_group_size=config.encoder_hash_byte_group_size,
931
+ )
932
+ if self.config.patch_in_forward:
933
+ self.patcher = BLTPatcher(config.patcher_config)
934
+ self.patcher.eval()
935
+ for param in self.patcher.parameters():
936
+ param.requires_grad = False
937
+ else:
938
+ self.patcher = None
939
+
940
+ def forward(
941
+ self,
942
+ tokens: torch.Tensor,
943
+ patch_lengths: Optional[torch.Tensor] = None,
944
+ attention_mask=None,
945
+ position_ids=None,
946
+ past_key_values=None,
947
+ inputs_embeds=None,
948
+ use_cache=None,
949
+ output_attentions=None,
950
+ output_hidden_states=None,
951
+ return_dict=None,
952
+ cache_position=None,
953
+ **kwargs,
954
+ ):
955
+ """
956
+ Args:
957
+ tokens (torch.Tensor): Input token ids.
958
+ patch_lengths (Optional[torch.Tensor]): Patch lengths for patching.
959
+ attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Ignored, for compatibility.
960
+ Returns:
961
+ torch.Tensor: Final hidden states from the local decoder.
962
+ """
963
+ batch_size, sequence_length = tokens.shape
964
+ # Handle patching
965
+ if patch_lengths is None:
966
+ if self.config.patching_mode == PatchingModeEnum.entropy:
967
+ _, patch_lengths, _ = self.patcher(
968
+ tokens,
969
+ patch_size=self.config.patch_size,
970
+ threshold=self.config.patching_threshold,
971
+ max_patch_length=self.config.max_patch_length,
972
+ patching_batch_size=self.config.patching_batch_size,
973
+ device=tokens.device,
974
+ )
975
+ else:
976
+ patch_lengths = process_patch_lengths(
977
+ torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device),
978
+ self.config.max_patch_length
979
+ )
980
+ patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length)
981
+ cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask(
982
+ patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32
983
+ )
984
+ encoder_embeds = compute_hash_embeddings(
985
+ tokens, self.local_encoder, self.encoder_hash_tok_embedding,
986
+ self.config.encoder_hash_byte_group_nb_functions,
987
+ self.config.encoder_hash_byte_group_size,
988
+ self.config.encoder_hash_byte_group_vocab,
989
+ )
990
+ encoder_hidden_states, encoder_cross_states = self.local_encoder(
991
+ input_ids=tokens,
992
+ input_embeds=encoder_embeds,
993
+ patch_embeds=None,
994
+ cross_mask=cross_attn_mask_enc,
995
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc,
996
+ num_patches=patch_lengths.shape[1],
997
+ patch_ids=patch_ids,
998
+ )
999
+ global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
1000
+ global_hidden_states, _ = self.global_transformer(
1001
+ input_embeds=global_hidden_states,
1002
+ )
1003
+ decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length)
1004
+ cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask(
1005
+ decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32
1006
+ )
1007
+ output, _ = self.local_decoder(
1008
+ tokens=tokens,
1009
+ embeds=encoder_hidden_states,
1010
+ patch_embeds=global_hidden_states,
1011
+ mask=None,
1012
+ cross_mask=cross_attn_mask_dec,
1013
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec,
1014
+ )
1015
+ if output_hidden_states or output_attentions:
1016
+ if return_dict:
1017
+ return {"last_hidden_state": output, "hidden_states": None, "attentions": None}
1018
+ else:
1019
+ return (output, None, None)
1020
+ return output
1021
+
1022
+ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
1023
+ """Convert patch lengths to patch IDs for each token position."""
1024
+ batch_size = patch_lengths.shape[0]
1025
+ patch_starts = torch.cat([
1026
+ torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
1027
+ patch_lengths.cumsum(dim=-1)[:, :-1]
1028
+ ], dim=-1)
1029
+
1030
+ token_positions = torch.arange(seq_len, device=patch_lengths.device)
1031
+ return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1
1032
+
1033
+
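# Standalone worked example (illustrative only, not part of the uploaded file) of the mapping
# implemented by BLTModel._patch_ids_from_lengths: each token position is assigned the index of
# the patch it falls into, based on cumulative patch lengths.
import torch

def patch_ids_from_lengths(patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
    batch_size = patch_lengths.shape[0]
    patch_starts = torch.cat(
        [
            torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
            patch_lengths.cumsum(dim=-1)[:, :-1],
        ],
        dim=-1,
    )
    token_positions = torch.arange(seq_len, device=patch_lengths.device)
    return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1

# Patches of lengths [2, 3, 1] over a 6-token sequence -> tokens map to patches [0, 0, 1, 1, 1, 2]
print(patch_ids_from_lengths(torch.tensor([[2, 3, 1]]), seq_len=6))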
1034
+ class BLTPatcher(BLTPreTrainedModel):
1035
+ def __init__(self, config: BLTPatcherConfig):
1036
+ super().__init__(config)
1037
+
1038
+ self.rotary_emb = BLTRotaryEmbedding(config=self.config)
1039
+
1040
+ self.layers = nn.ModuleList()
1041
+
1042
+ for layer_idx in range(self.config.num_hidden_layers):
1043
+ self.layers.append(BLTTransformerLayer(self.config, layer_idx))
1044
+
1045
+
1046
+ self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
1047
+
1048
+ self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps)
1049
+
1050
+ self.lm_head = nn.Linear(
1051
+ self.config.hidden_size,
1052
+ self.config.vocab_size,
1053
+ bias=False,
1054
+ )
1055
+
1056
+ def forward(
1057
+ self,
1058
+ token_values: torch.Tensor,
1059
+ patch_size: Optional[int] = None,
1060
+ threshold: Optional[float] = None,
1061
+ max_patch_length: Optional[int] = None,
1062
+ patching_batch_size: int = 1,
1063
+ device: Optional[str] = None,
1064
+ ):
1065
+
1066
+ # Handle chunked processing for entropy calculation
1067
+ entropies = []
1068
+ predictions = []
1069
+ max_length = self.config.max_position_embeddings
1070
+ batch_numel = max_length * patching_batch_size
1071
+ splits = torch.split(token_values.flatten(), batch_numel)
1072
+
1073
+ for split in splits:
1074
+ pad_size = (max_length - (split.numel() % max_length)) % max_length
1075
+ pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False)
1076
+ split = torch.cat((split, pad), dim=0)
1077
+ split = split.reshape(-1, max_length)
1078
+ if device is not None:
1079
+ split = split.to(device)
1080
+
1081
+ # Process chunk: embeddings -> layers -> output
1082
+ batch_size, sequence_length = split.shape
1083
+ input_embeds = self.embed_tokens(split)
1084
+
1085
+ hidden_states = input_embeds
1086
+
1087
+ batch_size, _, _ = input_embeds.shape
1088
+
1089
+ position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
1090
+
1091
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1092
+
1093
+ for i, layer in enumerate(self.layers):
1094
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None)
1095
+ hidden_states = layer_outputs[0]
1096
+
1097
+ logits = self.lm_head(self.norm(hidden_states))
1098
+ logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :]
1099
+ predictions.append(logits)
1100
+ prediction_entropies = torch.distributions.Categorical(logits=logits).entropy()
1101
+ entropies.append(prediction_entropies)
1102
+
1103
+ concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape)
1104
+ concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1)
1105
+
1106
+ # Always compute patch lengths from concatenated entropies
1107
+ batch_size, sequence_length = token_values.shape
1108
+
1109
+ # Find patch start IDs based on entropy
1110
+ if patch_size is not None:
1111
+ patch_lengths = self.patch_lengths_from_entropies(
1112
+ entropies=concat_entropies,
1113
+ sequence_length=sequence_length,
1114
+ patch_size=patch_size,
1115
+ threshold=threshold,
1116
+ )
1117
+ else:
1118
+ # Default to byte-level patching
1119
+ patch_lengths = torch.ones((batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device)
1120
+ patch_lengths = process_patch_lengths(patch_lengths, max_patch_length)
1121
+ return concat_entropies, patch_lengths, concat_predictions
1122
+
1123
+ @staticmethod
1124
+ def patch_lengths_from_entropies(
1125
+ entropies,
1126
+ sequence_length,
1127
+ patch_size=None,
1128
+ threshold=None,
1129
+ ):
1130
+ """
1131
+ Computes patch lengths from token entropies.
1132
+
1133
+ Depending on whether a threshold is provided, the function uses either:
1134
+ - Top-k selection based on entropy (when `threshold` is None), or
1135
+ - Thresholding the entropy values (when `threshold` is set).
1136
+ """
1137
+
1138
+ batch_size = entropies.shape[0]
1139
+
1140
+ # Always include tokens 0 and 1 as patch starting positions
1141
+ init_tokens = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1)
1142
+ offset = init_tokens.shape[1]
1143
+
1144
+ # Ignore first token entropy (BOS)
1145
+ entropies = entropies[:, 1:]
1146
+
1147
+ if threshold is None:
1148
+ # Use top-k entropy values to define patch start points
1149
+ num_patches = sequence_length // patch_size
1150
+ topk_indices = entropies.topk(num_patches - 2, dim=1).indices
1151
+ patch_starts = topk_indices.sort(dim=1).values
1152
+ else:
1153
+ # Threshold the entropy values to define patch start points
1154
+ patch_mask = entropies > threshold
1155
+
1156
+ seq_len = patch_mask.shape[1]
1157
+
1158
+ # Create patch IDs (token indices), and add a sentinel to ensure alignment
1159
+ token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1)
1160
+ sentinel = torch.full_like(token_indices, seq_len)
1161
+ padded_indices = torch.cat([token_indices, sentinel], dim=1)
1162
+
1163
+ # Pad mask with inverse to align sentinel correctly
1164
+ padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1)
1165
+
1166
+ # Select indices where mask is True
1167
+ patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len)
1168
+ max_valid_patches = patch_mask.sum(dim=1).max()
1169
+ patch_starts = patch_starts[:, :max_valid_patches]
1170
+
1171
+ # Offset patch starts to account for the two initial tokens
1172
+ patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1)
1173
+
1174
+ # Compute patch end positions by shifting start positions
1175
+ last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1)
1176
+ patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1)
1177
+
1178
+ patch_lengths = patch_ends - patch_start_ids + 1
1179
+
1180
+ return patch_lengths
1181
+
1182
+
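# Standalone worked example (illustrative only, not part of the uploaded file) of the
# thresholding branch in BLTPatcher.patch_lengths_from_entropies: tokens 0 and 1 always start
# patches, the first (BOS) entropy is skipped, and every position whose entropy exceeds the
# threshold opens a new patch at the following token.
import torch

entropies = torch.tensor([[0.1, 2.5, 0.3, 0.2, 3.0, 0.4, 0.1, 0.2]])  # (batch=1, seq_len=8)
threshold = 1.0
seq_len = entropies.shape[1]

init_starts = [0, 1]
offset = len(init_starts)
mask = entropies[0, 1:] > threshold                        # skip the first (BOS) entropy
extra_starts = (torch.nonzero(mask).squeeze(-1) + offset).tolist()
starts = init_starts + extra_starts                        # [0, 1, 2, 5]
ends = [s - 1 for s in starts[1:]] + [seq_len - 1]         # [0, 1, 4, 7]
patch_lengths = [e - s + 1 for s, e in zip(starts, ends)]
print(starts, patch_lengths)                               # [0, 1, 2, 5] [1, 1, 3, 3]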
1183
+ class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin):
1184
+ config_class = BLTConfig
1185
+ base_model_prefix = "model"
1186
+ supports_gradient_checkpointing = True
1187
+ _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"]
1188
+
1189
+ def __init__(self, config):
1190
+ super().__init__(config)
1191
+ self.model = BLTModel(config)
1192
+ self.vocab_size = config.vocab_size
1193
+ self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False)
1194
+ self.post_init()
1195
+
1196
+ def get_input_embeddings(self):
1197
+ return self.model.local_encoder.embed_tokens
1198
+
1199
+ def set_input_embeddings(self, value):
1200
+ self.model.local_encoder.embed_tokens = value
1201
+
1202
+ def get_output_embeddings(self):
1203
+ return self.lm_head
1204
+
1205
+ def set_output_embeddings(self, new_embeddings):
1206
+ self.lm_head = new_embeddings
1207
+
1208
+ def set_decoder(self, decoder):
1209
+ self.model = decoder
1210
+
1211
+ def get_decoder(self):
1212
+ return self.model
1213
+
1214
+ def forward(
1215
+ self,
1216
+ input_ids=None,
1217
+ attention_mask=None,
1218
+ position_ids=None,
1219
+ past_key_values=None,
1220
+ inputs_embeds=None,
1221
+ labels=None,
1222
+ use_cache=None,
1223
+ output_attentions=None,
1224
+ output_hidden_states=None,
1225
+ return_dict=None,
1226
+ cache_position=None,
1227
+ **kwargs,
1228
+ ):
1229
+ """
1230
+ Args:
1231
+ input_ids (torch.LongTensor): Input token ids.
1232
+ attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Standard transformers arguments.
1233
+ labels (torch.LongTensor, optional): Labels for language modeling loss.
1234
+ Returns:
1235
+ CausalLMOutputWithPast or tuple: Standard transformers output.
1236
+ """
1237
+ # Route only input_ids to BLTModel (as tokens)
1238
+ hidden_states = self.model(
1239
+ input_ids,
1240
+ attention_mask=attention_mask,
1241
+ position_ids=position_ids,
1242
+ past_key_values=past_key_values,
1243
+ inputs_embeds=inputs_embeds,
1244
+ use_cache=use_cache,
1245
+ output_attentions=output_attentions,
1246
+ output_hidden_states=output_hidden_states,
1247
+ return_dict=return_dict,
1248
+ cache_position=cache_position,
1249
+ **kwargs,
1250
+ )
1251
+ if isinstance(hidden_states, dict):
1252
+ sequence_output = hidden_states["last_hidden_state"]
1253
+ elif isinstance(hidden_states, tuple):
1254
+ sequence_output = hidden_states[0]
1255
+ else:
1256
+ sequence_output = hidden_states
1257
+ logits = self.lm_head(sequence_output)
1258
+ loss = None
1259
+ if labels is not None:
1260
+ # Shift so that tokens < n predict n
1261
+ shift_logits = logits[..., :-1, :].contiguous()
1262
+ shift_labels = labels[..., 1:].contiguous()
1263
+ loss_fct = torch.nn.CrossEntropyLoss()
1264
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1265
+ if not return_dict:
1266
+ output = (logits,)
1267
+ if loss is not None:
1268
+ output = (loss,) + output
1269
+ return output
1270
+ return CausalLMOutputWithPast(
1271
+ loss=loss,
1272
+ logits=logits,
1273
+ past_key_values=None,
1274
+ hidden_states=None,
1275
+ attentions=None,
1276
+ )
1277
+
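# Minimal illustration (dummy tensors, not the real BLT config) of the shifted cross-entropy
# computed in BLTForCausalLM.forward: logits at position t are scored against the label at
# position t + 1.
import torch

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = torch.nn.CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss.item())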
1278
+ __all__ = [
1279
+ "BLTPreTrainedModel",
1280
+ "BLTModel",
1281
+ "BLTPatcher",
1282
+ "BLTLocalEncoder",
1283
+ "BLTLocalDecoder",
1284
+ "BLTGlobalTransformer",
1285
+ "BLTTransformerLayer",
1286
+ "BLTForCausalLM",
1287
+ ]
backup_blt_wip copy/modeling_blt_old.py ADDED
@@ -0,0 +1,1602 @@
1
+ # BLT old implementation (kept for reference)
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+
5
+ import logging
6
+ import os
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn
11
+ import torch.nn as nn
12
+ from torch.nn import functional as F
13
+ from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
14
+
15
+ from ...modeling_utils import PreTrainedModel
16
+ from .configuration_blt_og import (
17
+ BLTConfig,
18
+ PatchingModeEnum,
19
+ )
20
+
21
+ RMSNorm = nn.RMSNorm
22
+
23
+ logger = logging.getLogger()
24
+
25
+ flex_attention_comp = flex_attention
26
+
27
+
28
+ def causal_mask(b, h, q_idx, kv_idx):
29
+ return q_idx >= kv_idx
30
+
31
+
32
+ def create_causal_mask(
33
+ seqlen,
34
+ attn_impl: str,
35
+ attn_bias_type: str | None,
36
+ *,
37
+ eos_id: int | None = None,
38
+ tokens: torch.Tensor | None = None,
39
+ sliding_window: int | None = None,
40
+ ):
41
+ if attn_impl == "sdpa":
42
+ BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))
43
+
44
+ if attn_bias_type == "causal":
45
+ return "causal"
46
+
47
+ if BLT_SUPPRESS_ATTN_ERROR == 1:
48
+ return "causal"
49
+ else:
50
+ raise ValueError(
51
+ "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
52
+ )
53
+ elif attn_impl == "flex_attention":
54
+ return create_block_mask(causal_mask, None, None, seqlen, seqlen)
55
+ else:
56
+ raise NotImplementedError(f"Attention {attn_impl} with {sliding_window} sliding window not implemented")
57
+
58
+
59
+ def cross_entropy(pred, target, **kwargs):
60
+ return F.nll_loss(
61
+ F.log_softmax(pred.flatten(end_dim=-2).float(), -1),
62
+ target.flatten(end_dim=-1),
63
+ **kwargs,
64
+ )
65
+
66
+
67
+ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
68
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
69
+ assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
70
+ bs, slen, n_kv_heads, head_dim = x.shape
71
+ if n_rep == 1:
72
+ return x
73
+ return (
74
+ x[:, :, :, None, :]
75
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
76
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
77
+ )
78
+
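# Quick standalone check (illustrative, dummy shapes) that the expand/reshape trick in repeat_kv
# matches torch.repeat_interleave along the key/value-head dimension (dim=2).
import torch

x = torch.randn(1, 3, 2, 4)  # (bs, seq, n_kv_heads, head_dim)
n_rep = 3
expanded = x[:, :, :, None, :].expand(1, 3, 2, n_rep, 4).reshape(1, 3, 2 * n_rep, 4)
print(torch.equal(expanded, torch.repeat_interleave(x, repeats=n_rep, dim=2)))  # True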
79
+
80
+ def precompute_freqs_cis(
81
+ dim: int,
82
+ end: int,
83
+ theta: float = 10000.0,
84
+ rope_use_fp32_in_outer_product: bool = False,
85
+ ):
86
+ """
87
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
88
+
89
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
90
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
91
+ The returned tensor contains complex values in complex64 data type.
92
+
93
+ Args:
94
+ dim (int): Dimension of the frequency tensor.
95
+ end (int): End index for precomputing frequencies.
96
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
97
+
98
+ Returns:
99
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
100
+ """
101
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
102
+ t = torch.arange(end, device=freqs.device)
103
+ if rope_use_fp32_in_outer_product:
104
+ t = t.to(torch.float32)
105
+
106
+ freqs = torch.outer(t, freqs).float()
107
+
108
+ cos, sin = freqs.cos(), freqs.sin()
109
+
110
+ return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
111
+
112
+
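# Standalone check (illustrative, small dummy sizes) of the layout produced by
# precompute_freqs_cis: for each (position, frequency) pair it stores the 2x2 rotation matrix
# [[cos, -sin], [sin, cos]], so the output shape is (end, dim // 2, 2, 2).
import torch

dim, end, theta = 8, 4, 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
angles = torch.outer(torch.arange(end, dtype=torch.float32), freqs)
cos, sin = angles.cos(), angles.sin()
freqs_cis = torch.stack((cos, -sin, sin, cos), dim=-1).view(*angles.size(), 2, 2)
print(freqs_cis.shape)  # torch.Size([4, 4, 2, 2])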
113
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
114
+ """
115
+ Reshape frequency tensor for broadcasting it with another tensor.
116
+
117
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
118
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
119
+
120
+ Args:
121
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
122
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
123
+ seq_dim (int): Sequence dimension index.
124
+
125
+ Returns:
126
+ torch.Tensor: Reshaped frequency tensor.
127
+ """
128
+ ndim = x.ndim
129
+ assert 0 <= seq_dim < ndim
130
+ assert freqs_cis.shape == (
131
+ x.shape[seq_dim],
132
+ x.shape[-3],
133
+ 2,
134
+ 2,
135
+ ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
136
+ shape = [d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])] + [2, 2]
137
+ return freqs_cis.view(*shape)
138
+
139
+
140
+ def apply_rotary_emb(
141
+ xq: torch.Tensor,
142
+ xk: torch.Tensor,
143
+ seq_dim: int,
144
+ freqs_cis: torch.Tensor,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
146
+
147
+ xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
148
+ xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
149
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, seq_dim).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
150
+ xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
151
+ xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
152
+ return xq_out.type_as(xq), xk_out.type_as(xk)
153
+
154
+
155
+ # Rotary embedding as in xformer; see if the torchtrain implementation is better. It might also be useful to make it work with batch*seqlen collapsed.
156
+ class RotaryEmbedding(torch.nn.Module):
157
+ """
158
+ RotaryEmbedding Module
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ theta: float,
164
+ head_dim: int,
165
+ max_seqlen: int = 1024,
166
+ rope_use_fp32_in_outer_product: bool = False,
167
+ ):
168
+ super().__init__()
169
+
170
+ self.theta = theta
171
+ self.head_dim = head_dim
172
+ self.max_seqlen = max_seqlen
173
+ self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
174
+
175
+ self.register_buffer(
176
+ "freqs_cis",
177
+ precompute_freqs_cis(
178
+ dim=head_dim,
179
+ end=max_seqlen,
180
+ theta=theta,
181
+ rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
182
+ ),
183
+ persistent=False,
184
+ )
185
+
186
+
187
+ def forward(self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None):
188
+ """
189
+ Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions
190
+ Args:
191
+ seqlen (int): Contiguous sequence length
192
+ tok_idx (torch.Tensor[int]): Position indices of each token; overrides seqlen when provided
193
+
194
+ Returns:
195
+ torch.Tensor: freqs_cis values for the requested positions
196
+ """
197
+ test = (seqlen is not None) or (tok_idx is not None)
198
+ assert test, "Should provide at least one of seqlen or tok_idx"
199
+ if tok_idx is not None:
200
+ return self.freqs_cis[tok_idx]
201
+ elif seqlen is not None:
202
+ return self.freqs_cis[0:seqlen]
203
+
204
+
205
+ class BLTSelfAttention(nn.Module):
206
+ def __init__(
207
+ self,
208
+ dim: int,
209
+ head_dim: int,
210
+ n_heads: int,
211
+ n_kv_heads: int,
212
+ rope_theta: float,
213
+ ):
214
+ super().__init__()
215
+
216
+ self.dim = dim
217
+ self.head_dim = head_dim
218
+ self.rope_theta = rope_theta
219
+
220
+ self.n_heads = n_heads
221
+ self.n_kv_heads = n_kv_heads
222
+ self.heads_per_group = self.n_heads // self.n_kv_heads
223
+
224
+ self.wq = nn.Linear(
225
+ dim,
226
+ n_heads * head_dim,
227
+ bias=False,
228
+ )
229
+ self.wk = nn.Linear(
230
+ dim,
231
+ n_kv_heads * head_dim,
232
+ bias=False,
233
+ )
234
+ self.wv = nn.Linear(
235
+ dim,
236
+ n_kv_heads * head_dim,
237
+ bias=False,
238
+ )
239
+
240
+ self.wo = nn.Linear(
241
+ n_heads * head_dim,
242
+ dim,
243
+ bias=False,
244
+ )
245
+
246
+ def forward(
247
+ self,
248
+ x: torch.Tensor,
249
+ freq_cis: torch.Tensor,
250
+ tok_idx: Optional[torch.Tensor] = None,
251
+ mask: Optional[Union[BlockMask, str]] = None,
252
+ attn_impl: str = "sdpa",
253
+ ) -> torch.Tensor:
254
+ # B S D
255
+ bsz, seq_len, dim = x.shape
256
+
257
+ xq = self.wq(x.view_as(x))
258
+ xk = self.wk(x.view_as(x))
259
+ xv = self.wv(x.view_as(x))
260
+
261
+ output_shape = xq.shape
262
+ # B S D -> B S H D
263
+ xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
264
+ xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
265
+ xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
266
+
267
+ xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
268
+
269
+ # This condition helps us be easily compatible
270
+ # with inference by adding a pluggable KVCache
271
+ if hasattr(self, "kv_cache"):
272
+ xk, xv = self.kv_cache.update(xk, xv, tok_idx)
273
+
274
+ xk = repeat_kv(xk, self.heads_per_group, dim=2)
275
+ xv = repeat_kv(xv, self.heads_per_group, dim=2)
276
+
277
+ if attn_impl == "flex_attention":
278
+ assert mask is None or isinstance(mask, BlockMask)
279
+ xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
280
+ output = flex_attention_comp(xq, xk, xv, block_mask=mask)
281
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
282
+
283
+ elif attn_impl == "sdpa":
284
+ xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
285
+ assert mask is None or isinstance(mask, (str, torch.Tensor))
286
+ is_causal = (mask == "causal") if isinstance(mask, str) else False
287
+ mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None
288
+ output = F.scaled_dot_product_attention(
289
+ xq,
290
+ xk,
291
+ xv,
292
+ is_causal=is_causal,
293
+ attn_mask=mask,
294
+ )
295
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
296
+ else:
297
+ raise NotImplementedError(f"Attention implementation {attn_impl} not supported")
298
+
299
+ output_reshaped = output.reshape(output_shape)
300
+
301
+ output = self.wo(output_reshaped)
302
+
303
+ return output
304
+
305
+
306
+ class BLTMLP(nn.Module):
307
+ def __init__(
308
+ self,
309
+ dim: int,
310
+ hidden_dim: int,
311
+ multiple_of: int,
312
+ ffn_dim_multiplier: Optional[float],
313
+ mp_size: int = 1,
314
+ ):
315
+ super().__init__()
316
+
317
+ hidden_dim = int(2 * hidden_dim / 3)
318
+ if ffn_dim_multiplier is not None:
319
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
320
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
321
+ assert hidden_dim % mp_size == 0
322
+
323
+ self.dim = dim
324
+ self.hidden_dim = hidden_dim
325
+
326
+ self.w1 = nn.Linear(
327
+ dim,
328
+ hidden_dim,
329
+ bias=False,
330
+ )
331
+ self.w3 = nn.Linear(
332
+ dim,
333
+ hidden_dim,
334
+ bias=False,
335
+ )
336
+ self.w2 = nn.Linear(
337
+ hidden_dim,
338
+ dim,
339
+ bias=False,
340
+ )
341
+
342
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
343
+ # B S D
344
+ x1 = self.w1(x.view_as(x))
345
+ x3 = self.w3(x.view_as(x))
346
+ output = self.w2(F.silu(x1) * x3)
347
+ return output
348
+
349
+
350
+
351
+
352
+ class BLTTransformerLayer(nn.Module):
353
+ def __init__(self, args):
354
+ super().__init__()
355
+
356
+ # Extract parameters from dictionary
357
+ dim = args["dim"]
358
+ n_heads = args["n_heads"]
359
+ head_dim = args["head_dim"]
360
+ n_kv_heads = args["n_kv_heads"]
361
+ rope_theta = args["rope_theta"]
362
+ multiple_of = args["multiple_of"]
363
+ ffn_dim_multiplier = args["ffn_dim_multiplier"]
364
+ norm_eps = args["norm_eps"]
365
+
366
+ assert (head_dim is not None) or (n_heads is not None), "Should specify at least head_dim or n_heads"
367
+ self.head_dim = head_dim or dim // n_heads
368
+ self.n_heads = n_heads or dim // head_dim
369
+ self.n_kv_heads = n_kv_heads or self.n_heads
370
+
371
+ assert n_heads % self.n_kv_heads == 0
372
+ assert dim % n_heads == 0
373
+
374
+ self.attention = BLTSelfAttention(
375
+ dim=dim,
376
+ head_dim=self.head_dim,
377
+ n_heads=self.n_heads,
378
+ n_kv_heads=self.n_kv_heads,
379
+ rope_theta=rope_theta,
380
+ )
381
+ self.feed_forward = BLTMLP(
382
+ dim=dim,
383
+ hidden_dim=4 * dim,
384
+ multiple_of=multiple_of,
385
+ ffn_dim_multiplier=ffn_dim_multiplier,
386
+ )
387
+ self.attention_norm = RMSNorm(dim, eps=norm_eps)
388
+ self.ffn_norm = RMSNorm(dim, eps=norm_eps)
389
+
390
+ def forward(
391
+ self,
392
+ x: torch.Tensor,
393
+ freq_cis: torch.Tensor,
394
+ tok_idx: Optional[torch.Tensor] = None,
395
+ mask: Optional[Union[BlockMask, str]] = None,
396
+ attn_impl: str = "sdpa",
397
+ ) -> torch.Tensor:
398
+ norm_x = self.attention_norm(x)
399
+ attn_out = self.attention(
400
+ norm_x,
401
+ freq_cis,
402
+ tok_idx=tok_idx,
403
+ mask=mask,
404
+ attn_impl=attn_impl,
405
+ )
406
+ h = x + attn_out
407
+ h_norm = self.ffn_norm(h)
408
+ out = h + self.feed_forward(h_norm)
409
+ return out
410
+
411
+ def check_non_zero_after_zero(tensor):
412
+ zero_mask = tensor == 0
413
+ shifted_mask = torch.cat(
414
+ [
415
+ torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
416
+ zero_mask[:, :-1],
417
+ ],
418
+ dim=1,
419
+ )
420
+ non_zero_after_zero = (tensor != 0) & shifted_mask
421
+ return non_zero_after_zero.any()
422
+
423
+ def rolling_polynomial_hash(t, hash_func_nb: int = 0):
424
+ primes = [
425
+ 1000000007,
426
+ 5915587277,
427
+ 1500450271,
428
+ 3267000013,
429
+ 5754853343,
430
+ 4093082899,
431
+ 9576890767,
432
+ 3628273133,
433
+ 2860486313,
434
+ 5463458053,
435
+ 3367900313,
436
+ ]
437
+ prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
438
+ prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
439
+ return torch.sum(t * prime_powers, dim=-1)
440
+
441
+
442
+ def byte_group_hash_function(x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000):
443
+ """
444
+ Returns a hash of the input x and maps it to a value in the range [0, max_hash].
445
+
446
+ expects: x of shape (batch_size, seq_len) with values as ids in the token vocab.
447
+ returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash].
448
+
449
+ Note: max hash can make a big difference on the number of collisions.
450
+ """
451
+ with torch.no_grad():
452
+ bs, seq_len = x.shape
453
+ prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
454
+ x = torch.cat([prefix, x], dim=1)
455
+ windows = x.unfold(1, group_size, 1)
456
+ # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
457
+ hashes = rolling_polynomial_hash(windows, hash_func_nb)
458
+ hash_values_range = hashes % max_hash
459
+ hash_values_range.requires_grad = False
460
+ return hash_values_range
461
+
462
+
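# Quick standalone illustration (not part of the file; toy byte ids) of the rolling polynomial
# hash used by byte_group_hash_function: each window of `group_size` byte ids is hashed and
# reduced modulo `max_hash` to index an auxiliary hash embedding table.
import torch

x = torch.tensor([[72, 101, 108, 108, 111]])  # byte ids for "Hello"
group_size, max_hash = 2, 30000
prefix = torch.zeros(x.shape[0], group_size - 1, dtype=torch.int64)
windows = torch.cat([prefix, x], dim=1).unfold(1, group_size, 1)  # (1, 5, 2)

prime = torch.tensor(1000000007, dtype=torch.int64)
prime_powers = torch.stack([prime**i for i in range(group_size)])
hashes = (windows * prime_powers).sum(dim=-1) % max_hash
print(hashes)  # one hash id per position, shape (1, 5)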
463
+ def create_patch_mask_from_ids(patch_ids, num_patches, window=None, patches_as_queries=False):
464
+ """
465
+ Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k)
466
+ is True if the patch id at position (i, j) is less than or equal to k.
467
+ Args:
468
+ patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids.
469
+ num_patches (int): Total number of patches.
470
+ window (int): If not None, only considers patches within a window of size window.
471
+ patches_as_queries (bool): If True, the patches are used as queries
472
+ Returns:
473
+ torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask.
474
+ """
475
+ bs, seq_len = patch_ids.shape
476
+ if not patches_as_queries:
477
+ q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
478
+ kv_ids = (
479
+ torch.arange(num_patches, device=patch_ids.device)
480
+ .unsqueeze(0)
481
+ .unsqueeze(0)
482
+ .expand(bs, seq_len, num_patches)
483
+ )
484
+ else:
485
+ kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
486
+ q_ids = (
487
+ torch.arange(num_patches, device=patch_ids.device)
488
+ .unsqueeze(0)
489
+ .unsqueeze(-1)
490
+ .expand(bs, num_patches, seq_len)
491
+ )
492
+ if window is None:
493
+ mask = q_ids == kv_ids
494
+ else:
495
+ mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window)
496
+ return mask
497
+
498
+
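# Standalone worked example (illustrative, toy sizes) of create_patch_mask_from_ids with
# patches_as_queries=False and no window: entry (i, j, k) is True when token j belongs to patch k.
import torch

patch_ids = torch.tensor([[0, 0, 1, 2]])  # 4 tokens assigned to 3 patches
num_patches = 3
q_ids = patch_ids.unsqueeze(-1).expand(1, 4, num_patches)
kv_ids = torch.arange(num_patches).view(1, 1, num_patches).expand(1, 4, num_patches)
print((q_ids == kv_ids).int())
# tensor([[[1, 0, 0],
#          [1, 0, 0],
#          [0, 1, 0],
#          [0, 0, 1]]])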
499
+ def cross_attn_mask(
500
+ patch_ids,
501
+ patch_lengths,
502
+ N,
503
+ patches_as_queries=False,
504
+ cross_attn_k=1,
505
+ window=None,
506
+ block_mask=True,
507
+ ):
508
+ bs = patch_ids.shape[0]
509
+ with torch.no_grad():
510
+ # Create the patch mask
511
+ cross_mask = create_patch_mask_from_ids(
512
+ patch_ids,
513
+ patch_lengths.shape[1],
514
+ window=window,
515
+ patches_as_queries=patches_as_queries,
516
+ ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
517
+ q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N
518
+ kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k
519
+ assert cross_mask.shape == (
520
+ bs,
521
+ q_len,
522
+ kv_len,
523
+ ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
524
+ # block_mask = None
525
+ if block_mask:
526
+
527
+ def patch_mask(b, h, q_idx, kv_idx):
528
+ return cross_mask[b, q_idx, kv_idx]
529
+
530
+ block_mask = create_block_mask(
531
+ patch_mask,
532
+ B=bs,
533
+ H=None,
534
+ Q_LEN=q_len,
535
+ KV_LEN=kv_len,
536
+ _compile=True,
537
+ )
538
+ return block_mask
539
+ else:
540
+ return torch.where(cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))).unsqueeze(
541
+ 1
542
+ ) # [bs, 1, q_len, kv_len]
543
+
544
+
545
+ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor:
546
+ if max_patch_length is None:
547
+ return patch_lengths
548
+
549
+ batch_size = patch_lengths.size(0)
550
+ split_all = []
551
+ max_len = 0
552
+
553
+ for seq in patch_lengths:
554
+ splits = []
555
+ for length in seq[seq > 0]:
556
+ # Split long patches into max_patch_length chunks
557
+ full, rem = divmod(length.item(), max_patch_length)
558
+ splits.extend([max_patch_length] * full + ([rem] if rem else []))
559
+ split_all.append(splits)
560
+ max_len = max(max_len, len(splits))
561
+
562
+ # Pad sequences to the maximum length
563
+ padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
564
+ for i, splits in enumerate(split_all):
565
+ if splits:
566
+ padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)
567
+
568
+ # Trim trailing columns that are all zeros
569
+ last_non_zero = (padded != 0).flip(1).int().argmax(1).min()
570
+ if last_non_zero < padded.shape[1]:
571
+ padded = padded[:, :padded.shape[1] - last_non_zero]
572
+
573
+ return padded
574
+
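# Standalone example (illustrative, toy values) of the splitting rule in process_patch_lengths:
# patches longer than max_patch_length are broken into max_patch_length-sized chunks plus a
# remainder, and rows are then right-padded with zeros to a common length.
import torch

def split_lengths(seq, max_len):
    out = []
    for length in seq:
        if length <= 0:
            continue
        full, rem = divmod(int(length), max_len)
        out.extend([max_len] * full + ([rem] if rem else []))
    return out

print(split_lengths(torch.tensor([7, 2, 0]), 4))  # [4, 3, 2]
print(split_lengths(torch.tensor([3, 3, 3]), 4))  # [3, 3, 3]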
575
+ class BLTLocalModelBase(nn.Module):
576
+ def __init__(self, config: BLTConfig, component_type: str = "encoder"):
577
+ super().__init__()
578
+
579
+ self.config = config
580
+
581
+ if component_type == "encoder":
582
+ self.dim = config.dim_local_encoder
583
+ self.n_layers = config.n_layers_local_encoder
584
+ self.n_heads = config.n_heads_local_encoder
585
+ self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen
586
+ self.attn_bias_type = "local_block_causal"
587
+ self.sliding_window = config.local_attention_window_len
588
+ elif component_type == "decoder":
589
+ self.dim = config.dim_local_decoder
590
+ self.n_layers = config.n_layers_local_decoder
591
+ self.n_heads = config.n_heads_local_decoder
592
+ self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen
593
+ self.attn_bias_type = "local_block_causal"
594
+ self.sliding_window = config.local_attention_window_len
595
+ else:
596
+ raise ValueError(f"Unknown component_type: {component_type}")
597
+
598
+ self.dropout = config.dropout
599
+ self.vocab_size = config.vocab_size + config.pm_size
600
+ self.patch_size = config.patch_size
601
+
602
+ self.attn_impl = config.attn_impl
603
+ self.use_rope = config.use_rope
604
+ self.init_std_factor = config.init_std_factor
605
+ self.init_base_std = config.init_base_std
606
+ self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None)
607
+ self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None)
608
+ self.cross_attn_k = getattr(config, "cross_attn_k", None)
609
+ self.eos_id = config.eos_token_id
610
+
611
+ self.boe_id = config.boe_id
612
+
613
+ # Initialize cross attention layers as None (will be set by subclasses if needed)
614
+ self.cross_attn_layers = None
615
+
616
+ # Create parameter dict for BLTTransformerLayers
617
+ layer_params = {
618
+ "dim": self.dim,
619
+ "n_heads": self.n_heads,
620
+ "head_dim": config.head_dim,
621
+ "n_kv_heads": getattr(config, "n_kv_heads", None),
622
+ "rope_theta": config.rope_theta,
623
+ "multiple_of": getattr(config, "multiple_of", 256),
624
+ "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None),
625
+ "norm_eps": config.norm_eps,
626
+ }
627
+
628
+ self.layers = nn.ModuleList([BLTTransformerLayer(layer_params) for _ in range(self.n_layers)])
629
+
630
+ if not self.use_rope:
631
+ self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length
632
+ else:
633
+ self.rope = RotaryEmbedding(
634
+ theta=config.rope_theta,
635
+ head_dim=config.head_dim or self.dim // self.n_heads,
636
+ max_seqlen=self.max_seqlen,
637
+ rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product,
638
+ )
639
+ self.pos_embeddings = None
640
+
641
+ # Set dimension-specific embedding dimensions
642
+ if component_type == "encoder":
643
+ self.dim_token_emb = config.encoder_dim_token_emb
644
+ self.dim_patch_emb = config.encoder_dim_patch_emb
645
+ elif component_type == "decoder":
646
+ self.dim_token_emb = config.decoder_dim_token_emb
647
+ self.dim_patch_emb = config.dim_global
648
+
649
+ self.token_embedding_projection = (
650
+ nn.Linear(self.dim_token_emb, self.dim, bias=False)
651
+ if self.dim_token_emb is not None and self.dim_token_emb != self.dim
652
+ else None
653
+ )
654
+
655
+ self.patch_embedding_projection = self._create_patch_projection(config)
656
+
657
+ def _should_create_patch_projection(self, config: BLTConfig):
658
+ dimension_mismatch = self.dim_patch_emb is not None and self.dim_patch_emb != self.dim
659
+
660
+ # Check cross attention conditions
661
+ cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or (
662
+ config.cross_attn_decoder and config.cross_attn_init_by_pooling
663
+ )
664
+
665
+ return dimension_mismatch or cross_attn_conditions
666
+
667
+ def _create_patch_projection(self, config):
668
+ if not self._should_create_patch_projection(config):
669
+ return None
670
+
671
+ output_dim = self.dim_token_emb * (self.cross_attn_k or 1)
672
+
673
+ return nn.Linear(
674
+ in_features=self.dim_patch_emb,
675
+ out_features=output_dim,
676
+ bias=False,
677
+ )
678
+
679
+ def apply_embedding(self, tokens, embeds):
680
+ if embeds is not None:
681
+ return embeds
682
+ else:
683
+ return self.tok_embeddings(tokens)
684
+
685
+
686
+ class BLTLocalEncoder(BLTLocalModelBase):
687
+ def __init__(self, config: BLTConfig):
688
+ super().__init__(config, component_type="encoder")
689
+
690
+ self.apply_transformer = config.use_local_encoder_transformer
691
+ self.downsampling_by_pooling = config.downsampling_by_pooling
692
+ self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None
693
+ self.cross_attn_encoder = config.cross_attn_encoder
694
+ self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder
695
+ self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling
696
+ self.cross_attn_nheads = config.cross_attn_nheads
697
+
698
+ self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim)
699
+
700
+ if self.cross_attn_encoder:
701
+ self.cross_attn_layers = torch.nn.ModuleList()
702
+ layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1
703
+ for _ in range(layers_to_add):
704
+ self.cross_attn_layers.append(
705
+ BLTCrossAttention(
706
+ dim=self.dim,
707
+ head_dim=self.dim // self.cross_attn_nheads,
708
+ n_heads=self.cross_attn_nheads,
709
+ n_kv_heads=self.cross_attn_nheads,
710
+ norm_eps=config.norm_eps,
711
+ )
712
+ )
713
+
714
+ def apply_embedding(self, tokens, embeds):
715
+ if embeds is not None:
716
+ assert self.expects_hash_embeddings, "Not expecting embeddings to be passed."
717
+ return embeds
718
+ else:
719
+ return self.tok_embeddings(tokens)
720
+
721
+ def forward(
722
+ self,
723
+ tokens: torch.Tensor,
724
+ embeds: Optional[torch.Tensor] = None,
725
+ patch_embeds: Optional[torch.Tensor] = None,
726
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
727
+ cross_mask: Optional[torch.Tensor] = None,
728
+ num_patches: Optional[int] = None,
729
+ patch_ids: Optional[torch.Tensor] = None,
730
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
731
+ ):
732
+ """ """
733
+ bs, seqlen = tokens.shape
734
+ if mask is None:
735
+ mask = create_causal_mask(
736
+ seqlen,
737
+ self.attn_impl,
738
+ "local_block_causal",
739
+ sliding_window=self.sliding_window,
740
+ tokens=tokens,
741
+ eos_id=self.eos_id,
742
+ )
743
+
744
+ h = self.apply_embedding(tokens, embeds)
745
+
746
+
747
+ freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
748
+
749
+
750
+ h = F.dropout(h, p=self.dropout, training=self.training)
751
+
752
+ for i, layer in enumerate(self.layers):
753
+ h = layer(h, freqs_cis, tok_idx=None, mask=mask, attn_impl=self.attn_impl)
754
+ # check if cross attention should be applied to either all layer or only the last layer
755
+ if self.cross_attn_encoder and (i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder):
756
+ # apply pooling and project
757
+ if self.cross_attn_init_by_pooling and patch_embeds is None:
758
+ patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids)
759
+ if self.patch_embedding_projection is not None:
760
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
761
+ patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim)
762
+
763
+ layer_idx = i if self.cross_attn_all_layers_encoder else 0
764
+ patch_embeds_cross = self.cross_attn_layers[layer_idx](
765
+ x=patch_embeds,
766
+ kv=h,
767
+ mask=cross_mask,
768
+ )
769
+ patch_embeds = patch_embeds + patch_embeds_cross
770
+
771
+ h_residual = patch_embeds if self.cross_attn_encoder else None
772
+ return (h, h_residual), cache
773
+
774
+ def patch_reduce(self, h, max_num_patches, reduction, patch_ids):
775
+ """
776
+ Reduce variable length patches to single embedding per patch
777
+ Note: this works with variable number of patches for different sequences in the batch
778
+ It handles variable length patches by assuming that patch_lengths will be 0 for any
779
+ extra patches on the *right*. Since there can be a variable number of patches
780
+ this function also return the number of patches for each sequence in the batch.
781
+ Any embeddings on the right that are not allocated to a patch
782
+ (i.e. if the sum(patch_lengths[i]) < seq_len for any i)
783
+ will be sent to a dummy patch, which is trimmed before returning.
784
+ """
785
+ bs, seq_len, emb_dim = h.shape
786
+
787
+ patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
788
+
789
+ reduced_embs = torch.zeros((bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device)
790
+ reduced_embs = reduced_embs.scatter_reduce(
791
+ src=h,
792
+ dim=1,
793
+ index=patch_ids,
794
+ reduce=reduction,
795
+ include_self=False,
796
+ )
797
+ reduced_embs = reduced_embs[:, :max_num_patches, :]
798
+
799
+ return reduced_embs
800
+
801
+
802
+ class BLTLocalDecoder(BLTLocalModelBase):
803
+ def __init__(self, config: BLTConfig):
804
+ super().__init__(config, component_type="decoder")
805
+
806
+ # Model configuration flags
807
+ self.cross_attn_decoder = config.cross_attn_decoder
808
+ self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder
809
+ self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling
810
+ self.cross_attn_nheads = config.cross_attn_nheads
811
+
812
+ self.norm = RMSNorm(self.dim, eps=config.norm_eps)
813
+
814
+ if self.cross_attn_decoder:
815
+ self.cross_attn_layers = torch.nn.ModuleList()
816
+ layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1
817
+ for _ in range(layers_to_add):
818
+ self.cross_attn_layers.append(
819
+ BLTCrossAttention(
820
+ dim=self.dim,
821
+ head_dim=self.dim // self.cross_attn_nheads,
822
+ n_heads=self.cross_attn_nheads,
823
+ n_kv_heads=self.cross_attn_nheads,
824
+ norm_eps=config.norm_eps,
825
+ )
826
+ )
827
+
828
+ self.output = nn.Linear(
829
+ self.dim,
830
+ config.vocab_size,
831
+ bias=False,
832
+ )
833
+
834
+ def forward(
835
+ self,
836
+ tokens: torch.Tensor,
837
+ embeds: Optional[torch.Tensor],
838
+ patch_embeds: Optional[torch.Tensor] = None,
839
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
840
+ cross_mask: Optional[torch.Tensor] = None,
841
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
842
+ ):
843
+ bs, seqlen = tokens.shape
844
+ assert embeds is not None, "Embeddings must be provided"
845
+
846
+ if mask is None:
847
+ mask = create_causal_mask(
848
+ seqlen,
849
+ self.attn_impl,
850
+ "local_block_causal",
851
+ sliding_window=self.sliding_window,
852
+ tokens=tokens,
853
+ eos_id=self.eos_id,
854
+ )
855
+
856
+ h = embeds
857
+
858
+ if self.patch_embedding_projection is not None:
859
+ assert patch_embeds is not None, "Patch embeddings must be passed."
860
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
861
+ if self.cross_attn_k is not None:
862
+ patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim)
863
+
864
+ if patch_embeds is not None and not self.cross_attn_decoder:
865
+ h = h + patch_embeds
866
+
867
+ freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
868
+
869
+ h = F.dropout(h, p=self.dropout, training=self.training)
870
+ for i, layer in enumerate(self.layers):
871
+ if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder):
872
+ # Use cross attention to extract info from patch_embeds into h
873
+ h_cross = self.cross_attn_layers[i](
874
+ x=h,
875
+ kv=patch_embeds,
876
+ mask=cross_mask,
877
+ )
878
+ h = h + h_cross
879
+
880
+ h = layer(h, freqs_cis, tok_idx=None, mask=mask, attn_impl=self.attn_impl)
881
+
882
+ h_preds = self.norm(h)
883
+ h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
884
+ h_preds = self.output(h_preds)
885
+ h_preds = h_preds.float()
886
+ return h_preds, cache
887
+
888
+
889
+ class BLTCrossAttention(nn.Module):
890
+ def __init__(
891
+ self,
892
+ dim: int,
893
+ head_dim: int,
894
+ n_heads: int,
895
+ n_kv_heads: int,
896
+ norm_eps: float,
897
+ ):
898
+ super().__init__()
899
+
900
+ self.dim = dim
901
+ self.head_dim = head_dim
902
+
903
+ self.n_heads = n_heads
904
+ self.n_kv_heads = n_kv_heads
905
+ self.heads_per_group = self.n_heads // self.n_kv_heads
906
+
907
+ self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
908
+ self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
909
+
910
+ self.wq = nn.Linear(
911
+ dim,
912
+ n_heads * head_dim,
913
+ bias=False,
914
+ )
915
+ self.wk = nn.Linear(
916
+ dim,
917
+ n_kv_heads * head_dim,
918
+ bias=False,
919
+ )
920
+ self.wv = nn.Linear(
921
+ dim,
922
+ n_kv_heads * head_dim,
923
+ bias=False,
924
+ )
925
+
926
+ self.wo = nn.Linear(
927
+ n_heads * head_dim,
928
+ dim,
929
+ bias=False,
930
+ )
931
+
932
+ def forward(
933
+ self,
934
+ x: torch.Tensor,
935
+ kv: torch.Tensor,
936
+ mask: Optional[Union[BlockMask, str]] = None,
937
+ ) -> torch.Tensor:
938
+ # B S D
939
+ bsz, seq_len, _ = x.shape
940
+ _, slen_kv, _ = kv.shape
941
+ x_norm = self.cross_attn_norm_q(x)
942
+ kv = self.cross_attn_norm_kv(kv)
943
+
944
+ xq = self.wq(x_norm)
945
+ xk = self.wk(kv)
946
+ xv = self.wv(kv)
947
+
948
+ output_shape = xq.shape
949
+ # B S D -> B S H D
950
+ xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
951
+ xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
952
+ xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
953
+
954
+ xk = repeat_kv(xk, self.heads_per_group, dim=2)
955
+ xv = repeat_kv(xv, self.heads_per_group, dim=2)
956
+
957
+ # assert mask is None or isinstance(mask, BlockMask)
958
+ xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
959
+ # output = flex_attention_comp(xq, xk, xv, block_mask=mask)
960
+ is_causal = (mask == "causal") if isinstance(mask, str) else False
961
+ mask = mask if isinstance(mask, torch.Tensor) else None
962
+ mask = mask.to(dtype=xq.dtype, device=xq.device) if mask is not None else None
963
+ output = F.scaled_dot_product_attention(
964
+ xq,
965
+ xk,
966
+ xv,
967
+ is_causal=is_causal,
968
+ attn_mask=mask,
969
+ )
970
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
971
+
972
+ output = self.wo(output.reshape(output_shape))
973
+
974
+ return x + output
975
+
976
+
977
+ class BLTGlobalTransformer(nn.Module):
978
+ def __init__(self, config):
979
+ super().__init__()
980
+
981
+ self.config = config
982
+
983
+ self.dim = config.dim_global
984
+ self.rope_embeddings = RotaryEmbedding(
985
+ theta=config.rope_theta,
986
+ head_dim=config.head_dim or self.config.dim_global // config.n_heads_global,
987
+ max_seqlen=config.max_seqlen,
988
+ rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product,
989
+ )
990
+ # Handle both eos_id and eos_token_id for compatibility
991
+ self.eos_id = getattr(config, "eos_id", getattr(config, "eos_token_id", 2))
992
+
993
+ # Create parameter dict for BLTTransformerLayers
994
+ layer_params = {
995
+ "dim": self.dim,
996
+ "n_heads": config.n_heads_global,
997
+ "head_dim": config.head_dim,
998
+ "n_kv_heads": getattr(config, "n_kv_heads_global", None),
999
+ "rope_theta": config.rope_theta,
1000
+ "multiple_of": getattr(config, "multiple_of", 256),
1001
+ "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None),
1002
+ "norm_eps": config.norm_eps,
1003
+ }
1004
+
1005
+ self.layers = nn.ModuleList()
1006
+ for _ in range(config.n_layers_global):
1007
+ self.layers.append(BLTTransformerLayer(layer_params))
1008
+
1009
+ self.token_embedding_projection = None
1010
+ if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim:
1011
+ self.token_embedding_projection = nn.Linear(
1012
+ config.global_dim_patch_emb,
1013
+ config.dim_global,
1014
+ bias=False,
1015
+ )
1016
+
1017
+ def forward(
1018
+ self,
1019
+ tokens: torch.Tensor,
1020
+ tok_idx: Optional[torch.Tensor] = None,
1021
+ embeds: Optional[torch.Tensor] = None,
1022
+ mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
1023
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
1024
+ ):
1025
+ bs, seqlen = tokens.shape
1026
+
1027
+ h = embeds
1028
+
1029
+ mask = (
1030
+ mask
1031
+ if mask is not None
1032
+ else create_causal_mask(
1033
+ seqlen,
1034
+ self.config.attn_impl,
1035
+ self.config.attn_bias_type,
1036
+ tokens=tokens,
1037
+ eos_id=self.eos_id,
1038
+ )
1039
+ )
1040
+
1041
+ if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
1042
+ h = self.token_embedding_projection(h)
1043
+
1044
+ h = F.dropout(h, p=self.config.dropout, training=self.training)
1045
+ freq_cis = self.rope_embeddings(seqlen=self.config.max_seqlen, tok_idx=tok_idx)
1046
+
1047
+ for i, layer in enumerate(self.layers):
1048
+ h = layer(h, freq_cis, tok_idx=None, mask=mask, attn_impl=self.config.attn_impl)
1049
+
1050
+ return h, cache
1051
+
1052
+
1053
+ def compute_hash_embeddings(
1054
+ local_encoder_tokens: torch.Tensor,
1055
+ local_encoder,
1056
+ encoder_hash_tok_embedding: nn.ModuleList,
1057
+ encoder_hash_byte_group_nb_functions: int,
1058
+ encoder_hash_byte_group_size: list,
1059
+ encoder_hash_byte_group_vocab: int,
1060
+ ) -> torch.Tensor:
1061
+ """
1062
+ Compute embeddings using hash token embeddings.
1063
+
1064
+ Args:
1065
+ local_encoder_tokens: Input tokens tensor
1066
+ local_encoder: Encoder object with tok_embeddings method
1067
+ encoder_hash_tok_embedding: ModuleList of hash token embeddings
1068
+ encoder_hash_byte_group_nb_functions: Number of hash functions
1069
+ encoder_hash_byte_group_size: List of byte group sizes
1070
+ encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings
1071
+
1072
+ Returns:
1073
+ torch.Tensor: Combined embeddings
1074
+ """
1075
+ if encoder_hash_tok_embedding is None:
1076
+ return None
1077
+
1078
+ local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens)
1079
+
1080
+ i = 0
1081
+ for func_nb in range(encoder_hash_byte_group_nb_functions):
1082
+ for byte_group_size in encoder_hash_byte_group_size:
1083
+ hash_ids = byte_group_hash_function(
1084
+ local_encoder_tokens,
1085
+ byte_group_size,
1086
+ hash_func_nb=func_nb,
1087
+ max_hash=encoder_hash_byte_group_vocab,
1088
+ )
1089
+ hash_tok_embedding = encoder_hash_tok_embedding[i]
1090
+ local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
1091
+ i += 1
1092
+
1093
+ assert i == len(encoder_hash_tok_embedding)
1094
+ return local_encoder_embeds
1095
+
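+ # Illustrative usage sketch (not part of the converted checkpoint; names and sizes are
+ # hypothetical). With 2 hash functions and byte-group sizes [3, 4], the ModuleList must
+ # hold 2 * 2 = 4 embeddings, iterated in (function, group_size) order, the same order
+ # produced by init_hash_embeddings below:
+ #   tokens = torch.randint(0, 260, (1, 16))
+ #   hash_embs = nn.ModuleList(nn.Embedding(500002, 512) for _ in range(4))
+ #   embeds = compute_hash_embeddings(
+ #       local_encoder_tokens=tokens,
+ #       local_encoder=model.local_encoder,            # any module with tok_embeddings
+ #       encoder_hash_tok_embedding=hash_embs,
+ #       encoder_hash_byte_group_nb_functions=2,
+ #       encoder_hash_byte_group_size=[3, 4],
+ #       encoder_hash_byte_group_vocab=500002,
+ #   )                                                 # -> shape [1, 16, 512]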
1096
+
1097
+ class BLTPreTrainedModel(PreTrainedModel):
1098
+ config_class = BLTConfig
1099
+ base_model_prefix = "model"
1100
+ supports_gradient_checkpointing = True
1101
+ _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"]
1102
+ _skip_keys_device_placement = ["past_key_values"]
1103
+ _supports_flash_attn_2 = False # BLT uses its own attention implementation
1104
+ _supports_sdpa = True
1105
+ _supports_cache_class = False
1106
+
1107
+ def _init_weights(self, module):
1108
+ if isinstance(module, nn.Linear):
1109
+ std = getattr(module, '_custom_std', module.in_features ** (-0.5))
1110
+
1111
+ nn.init.trunc_normal_(
1112
+ module.weight,
1113
+ mean=0.0,
1114
+ std=std,
1115
+ a=-3 * std,
1116
+ b=3 * std,
1117
+ )
1118
+ if module.bias is not None:
1119
+ nn.init.zeros_(module.bias)
1120
+
1121
+ elif isinstance(module, nn.Embedding):
1122
+ std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5))
1123
+
1124
+ nn.init.trunc_normal_(
1125
+ module.weight,
1126
+ mean=0.0,
1127
+ std=std,
1128
+ a=-3 * std,
1129
+ b=3 * std,
1130
+ )
1131
+
1132
+ elif isinstance(module, (nn.RMSNorm, nn.LayerNorm)):
1133
+ nn.init.ones_(module.weight)
1134
+ if module.bias is not None:
1135
+ nn.init.zeros_(module.bias)
1136
+
1137
+ elif isinstance(module, RotaryEmbedding):
1138
+ module.freqs_cis[...] = precompute_freqs_cis(
1139
+ dim=module.head_dim,
1140
+ end=module.max_seqlen,
1141
+ theta=module.theta,
1142
+ rope_use_fp32_in_outer_product=module.rope_use_fp32_in_outer_product,
1143
+ )
1144
+
1145
+ elif isinstance(module, BLTModel):
1146
+ if module.encoder_hash_tok_embedding is not None:
1147
+ emb_std = module.local_encoder.dim ** (-0.5)
1148
+ for emb in module.encoder_hash_tok_embedding:
1149
+ emb._custom_std = emb_std
1150
+
1151
+ elif isinstance(module, (BLTLocalEncoder, BLTLocalDecoder)):
1152
+ if module.token_embedding_projection is not None:
1153
+ module.token_embedding_projection._custom_std = module.dim ** (-0.5)
1154
+
1155
+ if module.patch_embedding_projection is not None:
1156
+ module.patch_embedding_projection._custom_std = module.dim_patch_emb ** (-0.5)
1157
+
1158
+ elif isinstance(module, BLTGlobalTransformer):
1159
+ if module.token_embedding_projection is not None:
1160
+ module.token_embedding_projection._custom_std = module.dim ** (-0.5)
1161
+
1162
+ elif isinstance(module, BLTPatcher):
1163
+ emb_std = module.config.patcher_dim ** (-0.5)
1164
+ module.tok_embeddings._custom_std = emb_std
1165
+ module.output._custom_std = emb_std
1166
+
1167
+
1168
+ class BLTModel(BLTPreTrainedModel):
1169
+ def __init__(self, config: BLTConfig):
1170
+ super().__init__(config)
1171
+
1172
+ self.config = config
1173
+ self.local_encoder = BLTLocalEncoder(config)
1174
+ self.global_transformer = BLTGlobalTransformer(config)
1175
+ self.local_decoder = BLTLocalDecoder(config)
1176
+
1177
+ self.encoder_hash_tok_embedding = init_hash_embeddings(
1178
+ config,
1179
+ local_encoder_dim=self.local_encoder.dim,
1180
+ encoder_hash_byte_group_size=config.encoder_hash_byte_group_size,
1181
+ )
1182
+
1183
+ if config.patch_in_forward:
1184
+ self.patcher = BLTPatcher(config)
1185
+ self.patcher.eval()
1186
+ for param in self.patcher.parameters():
1187
+ param.requires_grad = False
1188
+ else:
1189
+ self.patcher = None
1190
+
1191
+ def forward(
1192
+ self,
1193
+ tokens: torch.Tensor,
1194
+ patch_lengths: Optional[torch.Tensor] = None,
1195
+ ):
1196
+ # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings
1197
+ # are no longer used in the final BLT model
1198
+
1199
+ bs, N = tokens.shape # Batch size and sequence length
1200
+
1201
+ local_encoder_tokens, local_decoder_tokens = tokens, tokens
1202
+
1203
+ # Patching
1204
+ if patch_lengths is None:
1205
+ # assert (
1206
+ # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward
1207
+ # ), "Patch in forward not enabled and no patch_lengths passed."
1208
+
1209
+ # PATCHER MODEL DEFINED
1210
+ if self.config.patching_mode == PatchingModeEnum.entropy:
1211
+ _, patch_lengths, _ = self.patcher(
1212
+ local_encoder_tokens,
1213
+ patch_size=self.config.patch_size,
1214
+ include_next_token=True,
1215
+ threshold=self.config.patching_threshold,
1216
+ max_patch_length=self.config.max_patch_length,
1217
+ patching_batch_size=self.config.patching_batch_size,
1218
+ device=self.config.patching_device,
1219
+ )
1220
+ else:
1221
+ # self.config.patching_mode == PatchingModeEnum.byte
1222
+ bs, seq_len = local_encoder_tokens.shape
1223
+ seq_len_next_tok = seq_len + 1 # include_next_token=True
1224
+ patch_lengths = torch.ones(
1225
+ (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device
1226
+ )
1227
+
1228
+ patch_lengths = process_patch_lengths(patch_lengths, self.config.max_patch_length)
1229
+
1230
+ #assert torch.min(patch_lengths) >= 0
1231
+ # Generate patch IDs from patch_lengths
1232
+ patch_ids = self._patch_ids_from_lengths(patch_lengths, local_encoder_tokens.shape[-1])
1233
+ # assert torch.max(patch_ids) + 1 <= torch.max((patch_lengths != 0).sum(dim=-1)), (
1234
+ # f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"
1235
+ # )
1236
+
1237
+ cross_attn_mask_enc = None
1238
+ # Cross-attention encoder
1239
+ if self.config.cross_attn_encoder:
1240
+ cross_attn_mask_enc = cross_attn_mask(
1241
+ patch_ids,
1242
+ patch_lengths,
1243
+ N,
1244
+ patches_as_queries=True,
1245
+ cross_attn_k=self.config.cross_attn_k,
1246
+ window=self.config.cross_attn_window_encoder,
1247
+ block_mask=self.config.cross_attn_use_flex_attention,
1248
+ )
1249
+
1250
+ # Hashing and embedding
1251
+ local_encoder_embeds = compute_hash_embeddings(
1252
+ local_encoder_tokens=local_encoder_tokens,
1253
+ local_encoder=self.local_encoder,
1254
+ encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
1255
+ encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions,
1256
+ encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size,
1257
+ encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab,
1258
+ )
1259
+
1260
+ # NOTE: Frequency-based n-gram embeddings removed as per paper
1261
+ # The final BLT model uses only hash-based n-gram embeddings
1262
+
1263
+ # Local encoder
1264
+ (h_encoder, h_cross), cache_encoder = self.local_encoder(
1265
+ tokens=local_encoder_tokens,
1266
+ embeds=local_encoder_embeds,
1267
+ patch_embeds=None,
1268
+ cross_mask=cross_attn_mask_enc,
1269
+ num_patches=patch_lengths.shape[1],
1270
+ patch_ids=patch_ids,
1271
+ )
1272
+
1273
+ # Downsampling
1274
+ h = h_cross.view(bs, patch_lengths.shape[1], -1)
1275
+
1276
+ # Global transformer
1277
+ global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.config.boe_id)
1278
+ rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id)
1279
+ eos_patch_ids = patch_ids[rows, cols]
1280
+ global_tokens[rows, eos_patch_ids] = self.config.eos_token_id
1281
+
1282
+ h, _ = self.global_transformer(
1283
+ embeds=h,
1284
+ tokens=global_tokens,
1285
+ )
1286
+
1287
+ # Unpatching
1288
+
1289
+ dec_embeds = h_encoder
1290
+
1291
+ # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens), so we need to map decoder positions to the remaining patches.
1292
+ decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], local_decoder_tokens.shape[-1])
1293
+ # assert torch.max(decoder_patch_ids) + 1 <= h.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
1294
+ # assert decoder_patch_ids.shape[1] == dec_embeds.shape[1], (
1295
+ # f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"
1296
+ # )
1297
+
1298
+ # Cross-attention decoder
1299
+ if not self.config.cross_attn_decoder:
1300
+ h = torch.gather(h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]))
1301
+ cross_attn_mask_dec = None
1302
+ # assert local_decoder_tokens.shape == h.shape[:-1]
1303
+ else:
1304
+ cross_attn_mask_dec = cross_attn_mask(
1305
+ decoder_patch_ids,
1306
+ patch_lengths,
1307
+ N,
1308
+ patches_as_queries=False,
1309
+ cross_attn_k=self.config.cross_attn_k,
1310
+ window=self.config.cross_attn_window_decoder,
1311
+ block_mask=self.config.cross_attn_use_flex_attention,
1312
+ )
1313
+
1314
+ # Local decoder
1315
+ output, _ = self.local_decoder(
1316
+ embeds=dec_embeds,
1317
+ patch_embeds=h,
1318
+ tokens=local_decoder_tokens,
1319
+ cross_mask=cross_attn_mask_dec,
1320
+ )
1321
+ return output
1322
+
1323
+ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
1324
+ """
1325
+ Convert patch lengths to patch IDs for each token position.
1326
+ For each token position in the sequence, determines which patch it belongs to.
1327
+
1328
+ Args:
1329
+ patch_lengths: [batch_size, num_patches] - length of each patch
1330
+ seq_len: total sequence length
1331
+
1332
+ Returns:
1333
+ patch_ids: [batch_size, seq_len] - patch index for each token position
1334
+
1335
+ Example:
1336
+ patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1
1337
+ seq_len = 10
1338
+ Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]]
1339
+ # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3
1340
+ """
1341
+ batch_size, num_patches = patch_lengths.shape
1342
+
1343
+ # Create patch start positions: [0, 3, 5, 9] for the example above
1344
+ patch_starts = torch.cat(
1345
+ [
1346
+ torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
1347
+ patch_lengths.cumsum(dim=-1)[:, :-1], # cumsum without the final total
1348
+ ],
1349
+ dim=-1,
1350
+ )
1351
+
1352
+ # For each token position, find which patch it belongs to
1353
+ # by finding the rightmost patch start that's <= the position
1354
+ token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1]
1355
+
1356
+ # Broadcasting: patch_starts[batch, patch] <= token_positions[position]
1357
+ # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t
1358
+ position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)
1359
+
1360
+ # Count how many patch starts are <= each position, then subtract 1 to get patch index
1361
+ patch_ids = position_ge_patch_start.sum(dim=-1) - 1
1362
+
1363
+ return patch_ids
1364
+
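+ # Equivalence note (illustrative, values hypothetical): the same mapping can be obtained
+ # with torch.searchsorted over the cumulative patch starts, e.g.
+ #   patch_lengths = torch.tensor([[3, 2, 4, 1]])
+ #   starts = torch.cat(
+ #       [torch.zeros(1, 1, dtype=torch.long), patch_lengths.cumsum(-1)[:, :-1]], dim=-1
+ #   )
+ #   torch.searchsorted(starts[0], torch.arange(10), right=True) - 1
+ #   # -> tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3]), matching the docstring example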
1365
+
1366
+ class BLTPatcher(BLTPreTrainedModel):
1367
+ def __init__(self, config):
1368
+ super().__init__(config)
1369
+
1370
+ self.rope_embeddings = RotaryEmbedding(
1371
+ theta=config.patcher_rope_theta,
1372
+ head_dim=config.patcher_head_dim or config.patcher_dim // config.patcher_n_heads,
1373
+ max_seqlen=config.patcher_max_seqlen,
1374
+ rope_use_fp32_in_outer_product=config.patcher_rope_use_fp32_in_outer_product,
1375
+ )
1376
+
1377
+ self.layers = nn.ModuleList()
1378
+ for _ in range(config.patcher_n_layers):
1379
+ self.layers.append(
1380
+ BLTTransformerLayer(
1381
+ {
1382
+ "dim": config.patcher_dim,
1383
+ "n_heads": config.patcher_n_heads,
1384
+ "head_dim": config.patcher_head_dim,
1385
+ "n_kv_heads": config.patcher_n_kv_heads,
1386
+ "rope_theta": config.patcher_rope_theta,
1387
+ "multiple_of": config.patcher_multiple_of,
1388
+ "ffn_dim_multiplier": config.patcher_ffn_dim_multiplier,
1389
+ "norm_eps": config.patcher_norm_eps,
1390
+ }
1391
+ )
1392
+ )
1393
+
1394
+ #assert config.patcher_vocab_size > 0
1395
+
1396
+ self.tok_embeddings = torch.nn.Embedding(config.patcher_vocab_size, config.patcher_dim)
1397
+
1398
+ self.norm = RMSNorm(config.patcher_dim, eps=config.patcher_norm_eps)
1399
+
1400
+ self.output = nn.Linear(
1401
+ config.patcher_dim,
1402
+ config.patcher_vocab_size,
1403
+ bias=False,
1404
+ )
1405
+
1406
+ def forward(
1407
+ self,
1408
+ token_values: torch.Tensor,
1409
+ patch_size: Optional[int] = None,
1410
+ include_next_token: bool = True,
1411
+ threshold: Optional[float] = None,
1412
+ max_patch_length: Optional[int] = None,
1413
+ patching_batch_size: int = 1,
1414
+ device: Optional[str] = None,
1415
+ ):
1416
+
1417
+ # Handle chunked processing for entropy calculation
1418
+ entropies = []
1419
+ preds = []
1420
+ max_length = self.config.patcher_max_seqlen
1421
+ batch_numel = max_length * patching_batch_size
1422
+ splits = torch.split(token_values.flatten(), batch_numel)
1423
+
1424
+ for split in splits:
1425
+ pad_size = (max_length - (split.numel() % max_length)) % max_length
1426
+ pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False)
1427
+ split = torch.cat((split, pad), dim=0)
1428
+ split = split.reshape(-1, max_length)
1429
+ if device is not None:
1430
+ split = split.to(device)
1431
+
1432
+ # Process chunk: embeddings -> layers -> output
1433
+ bsz, seqlen = split.shape
1434
+ h = self.tok_embeddings(split)
1435
+ chunk_mask = create_causal_mask(
1436
+ seqlen,
1437
+ self.config.patcher_attn_impl,
1438
+ self.config.patcher_attn_bias_type,
1439
+ sliding_window=self.config.patcher_sliding_window,
1440
+ tokens=split,
1441
+ eos_id=self.config.eos_id,
1442
+ )
1443
+
1444
+ freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None)
1445
+
1446
+ for i, layer in enumerate(self.layers):
1447
+ h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=self.config.patcher_attn_impl)
1448
+
1449
+ pred = self.output(self.norm(h))
1450
+ pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab]
1451
+ preds.append(pred)
1452
+ pred_entropies = self.entropy(pred)
1453
+ entropies.append(pred_entropies)
1454
+
1455
+ concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape)
1456
+ concat_preds = torch.cat(preds, dim=0).reshape(token_values.shape[0], -1)
1457
+
1458
+ # Always compute patch lengths from concatenated entropies
1459
+ bs, seq_len = token_values.shape
1460
+ seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
1461
+
1462
+ # Find patch start IDs based on entropy
1463
+ if patch_size is not None:
1464
+ patch_start_ids = self.find_entropy_patch_start_ids(
1465
+ concat_entropies,
1466
+ patch_size,
1467
+ include_next_token=include_next_token,
1468
+ threshold=threshold
1469
+ )
1470
+ patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok)
1471
+ else:
1472
+ # Default to byte-level patching
1473
+ patch_lengths = torch.ones((bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device)
1474
+
1475
+ patch_lengths = process_patch_lengths(patch_lengths, max_patch_length)
1476
+ return concat_entropies, patch_lengths, concat_preds
1477
+
1478
+
1479
+ @staticmethod
1480
+ def entropy(scores):
1481
+ """
1482
+ scores: [bs, seq_len, vocab]
1483
+ returns [bs, seq_len]
1484
+
1485
+ Computes the entropy for each token in the batch.
1486
+ Note: uses natural log.
1487
+ """
1488
+ log_probs = F.log_softmax(scores, dim=-1)
1489
+ probs = torch.exp(log_probs)
1490
+ p_log_p = log_probs * probs
1491
+ entropy = -p_log_p.sum(dim=-1)
1492
+ return entropy
1493
+
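+ # Sanity-check sketch (illustrative only): uniform logits give entropy ln(vocab_size),
+ # e.g. BLTPatcher.entropy(torch.zeros(1, 4, 256)) is ~ ln(256) ~ 5.545 at every position.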
1494
+ @staticmethod
1495
+ def patch_start_ids_from_patch_start_mask(patch_start_mask):
1496
+ bs, trunc_seq_len = patch_start_mask.shape
1497
+ max_patches = patch_start_mask.sum(dim=1).max()
1498
+ if max_patches == 0:
1499
+ patch_start_ids = torch.full(
1500
+ (bs, trunc_seq_len),
1501
+ trunc_seq_len,
1502
+ dtype=torch.long,
1503
+ device=patch_start_mask.device,
1504
+ )
1505
+ else:
1506
+ patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(bs, 1)
1507
+ extra_patch_ids = torch.full(
1508
+ (bs, trunc_seq_len),
1509
+ trunc_seq_len,
1510
+ dtype=torch.long,
1511
+ device=patch_start_mask.device,
1512
+ )
1513
+ all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
1514
+ patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)
1515
+ patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, trunc_seq_len)[:, :max_patches]
1516
+ return patch_start_ids
1517
+
1518
+ @staticmethod
1519
+ def patch_lengths_from_start_ids(patch_start_ids, seq_len):
1520
+ """
1521
+ Calculate patch lengths from start ids.
1522
+ start ids, e.g. [0, 1, 7, 7, 7, 7, 7]: the real patch start ids (here 0 and 1), with the
1523
+ remaining slots padded with the seq len.
1524
+ seq_len: ex: 7 length of the sequence
1525
+
1526
+ returns the patch lengths:
1527
+ [1, 6, 0, 0, 0, 0, 0] for the above example (the trailing zeros come from the padded start ids).
1528
+ """
1529
+ last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
1530
+ patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
1531
+ patch_lengths = patch_end_ids - patch_start_ids + 1
1532
+ assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
1533
+ assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
1534
+ return patch_lengths
1535
+
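+ # Worked example (illustrative): for the docstring inputs,
+ #   BLTPatcher.patch_lengths_from_start_ids(torch.tensor([[0, 1, 7, 7, 7, 7, 7]]), 7)
+ #   # -> tensor([[1, 6, 0, 0, 0, 0, 0]])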
1536
+ @staticmethod
1537
+ def find_entropy_patch_start_ids(
1538
+ entropies,
1539
+ patch_size=None,
1540
+ threshold=None,
1541
+ include_next_token=True,
1542
+ ):
1543
+ """
1544
+ Use entropies to find the start ids of each patch.
1545
+ Use patch_size or threshold to figure out the total number of patches to allocate.
1546
+
1547
+ When threshold is not None the number of patches is not constant between
1548
+ different sequences, but patches can be identified incrementally rather than
1549
+ decided globally using the entire sequence.
1550
+ """
1551
+ bs, seq_len = entropies.shape[:2]
1552
+
1553
+ first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(bs, 1)
1554
+ preds_truncation_len = first_ids.shape[1]  # drop the first preds because those positions are forced patch starts.
1555
+ entropies = entropies[:, 1:]
1556
+ if threshold is None:
1557
+ num_patches = seq_len // patch_size
1558
+ patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices
1559
+ patch_start_ids = patch_start_ids.sort(dim=1).values
1560
+ else:
1561
+ patch_start_mask = entropies > threshold
1562
+ if not include_next_token:
1563
+ patch_start_mask = patch_start_mask[:, :-1]
1564
+ # patch_start_mask[1:] |= tokens[:-1] < OFFSET
1565
+ patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask)
1566
+
1567
+ patch_start_ids = torch.cat((first_ids, patch_start_ids + preds_truncation_len), dim=1)
1568
+ return patch_start_ids
1569
+
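+ # Threshold-mode sketch (illustrative values): positions whose entropy exceeds the
+ # threshold become patch starts, always prefixed by the forced starts [0, 1]:
+ #   ent = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.1, 0.1]])
+ #   BLTPatcher.find_entropy_patch_start_ids(ent, threshold=0.5)
+ #   # -> tensor([[0, 1, 2, 4]])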
1570
+ def init_hash_embeddings(
1571
+ config,
1572
+ local_encoder_dim: int,
1573
+ encoder_hash_byte_group_size: list,
1574
+ ):
1575
+ """Initialize hash-based token embeddings for the BLT encoder."""
1576
+ if config.encoder_hash_byte_group_size is None:
1577
+ return None
1578
+
1579
+ embeddings = []
1580
+ emb_dim = local_encoder_dim
1581
+ encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab
1582
+
1583
+ for _ in range(config.encoder_hash_byte_group_nb_functions):
1584
+ for _ in encoder_hash_byte_group_size:
1585
+ embeddings.append(
1586
+ nn.Embedding(
1587
+ encoder_hash_byte_group_vocab,
1588
+ emb_dim,
1589
+ )
1590
+ )
1591
+
1592
+ return nn.ModuleList(embeddings)
1593
+
1594
+
1595
+ __all__ = [
1596
+ "BLTPreTrainedModel",
1597
+ "BLTModel",
1598
+ "BLTPatcher",
1599
+ "BLTLocalEncoder",
1600
+ "BLTLocalDecoder",
1601
+ "BLTGlobalTransformer",
1602
+ ]
backup_blt_wip copy/modular_blt.py ADDED
@@ -0,0 +1,1180 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """BLT model."""
16
+
17
+ from ...utils import is_torch_flex_attn_available, logging
18
+ from typing import Callable, List, Optional, Tuple, Union
19
+
20
+ from ...cache_utils import Cache
21
+ from ...activations import ACT2FN
22
+
23
+ import torch
24
+ import torch.distributions
25
+ import torch.nn
26
+ import torch.nn as nn
27
+ from torch.nn import functional as F
28
+
29
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
30
+
31
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
32
+ from .configuration_blt import (
33
+ BLTConfig,
34
+ BLTLocalEncoderConfig,
35
+ BLTLocalDecoderConfig,
36
+ BLTGlobalTransformerConfig,
37
+ BLTPatcherConfig,
38
+ PatchingModeEnum,
39
+ )
40
+
41
+ if is_torch_flex_attn_available():
42
+ from torch.nn.attention.flex_attention import BlockMask
43
+ from ...integrations.flex_attention import make_flex_block_causal_mask
44
+
45
+ from ..mllama.modeling_mllama import repeat_kv, eager_attention_forward, MllamaRotaryEmbedding, MllamaTextRMSNorm, MllamaCrossAttentionDecoderLayer, MllamaTextCrossAttention, MllamaTextSelfAttention
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
51
+ """
52
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
53
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
54
+ """
55
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
56
+ if n_rep == 1:
57
+ return hidden_states
58
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
59
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
60
+
61
+
62
+ def eager_attention_forward(
63
+ module: nn.Module,
64
+ query: torch.Tensor,
65
+ key: torch.Tensor,
66
+ value: torch.Tensor,
67
+ attention_mask: Optional[torch.Tensor],
68
+ scaling: float,
69
+ dropout: float = 0.0,
70
+ **kwargs,
71
+ ):
72
+ key_states = repeat_kv(key, module.num_key_value_groups)
73
+ value_states = repeat_kv(value, module.num_key_value_groups)
74
+
75
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
76
+ if attention_mask is not None:
77
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
78
+ attn_weights = attn_weights + causal_mask
79
+
80
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
81
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
82
+ attn_output = torch.matmul(attn_weights, value_states)
83
+ attn_output = attn_output.transpose(1, 2).contiguous()
84
+
85
+ return attn_output, attn_weights
86
+
87
+
88
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
89
+ # TODO: not exactly equivalent to other transformers implementations, need feedback
90
+ # Extract first head_dim//2 elements which correspond to the unique frequencies
91
+ # This matches the original BLT approach which uses head_dim//2 frequency pairs
92
+ head_dim = q.shape[-1]
93
+ cos_freqs = cos[..., :head_dim//2] # [B, S, D/2]
94
+ sin_freqs = sin[..., :head_dim//2] # [B, S, D/2]
95
+
96
+ # Expand cos/sin to match query/key tensor format [B, H, S, D/2]
97
+ cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2]
98
+ sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2]
99
+
100
+ # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ...
101
+ q_pairs = q.view(*q.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2]
102
+ k_pairs = k.view(*k.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2]
103
+
104
+ # Extract real and imaginary parts
105
+ q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2]
106
+ k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2]
107
+
108
+ # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag]
109
+ q_real_rot = cos_freqs * q_real - sin_freqs * q_imag
110
+ q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag
111
+ k_real_rot = cos_freqs * k_real - sin_freqs * k_imag
112
+ k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag
113
+
114
+ # Recombine pairs and reshape back to original format
115
+ q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D]
116
+ k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D]
117
+
118
+ return q_rot.type_as(q), k_rot.type_as(k)
119
+
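+ # Note (illustrative, not part of the original code): unlike the usual rotate_half
+ # convention, rotation here is applied to interleaved (even, odd) channel pairs, which
+ # matches the original BLT complex-multiplication RoPE. Quick identity check with
+ # hypothetical tensors:
+ #   q = torch.randn(1, 2, 5, 8); k = torch.randn(1, 2, 5, 8)
+ #   cos = torch.ones(1, 5, 8); sin = torch.zeros(1, 5, 8)
+ #   q2, k2 = apply_rotary_pos_emb(q, k, cos, sin)
+ #   torch.allclose(q2, q) and torch.allclose(k2, k)   # -> True (zero rotation angle)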
120
+
121
+ class BLTMLP(nn.Module):
122
+ def __init__(self, config):
123
+ super().__init__()
124
+ self.config = config
125
+ self.hidden_size = config.hidden_size
126
+ self.intermediate_size = config.intermediate_size
127
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
128
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
129
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
130
+ self.act_fn = ACT2FN[config.hidden_act]
131
+
132
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
133
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
134
+ return down_proj
135
+
136
+
137
+ class BLTRMSNorm(nn.Module):
138
+ def __init__(self, hidden_size, eps=1e-6):
139
+ """
140
+ BLTRMSNorm is equivalent to T5LayerNorm
141
+ """
142
+ super().__init__()
143
+ self.weight = nn.Parameter(torch.ones(hidden_size))
144
+ self.variance_epsilon = eps
145
+
146
+ def forward(self, hidden_states):
147
+ input_dtype = hidden_states.dtype
148
+ hidden_states = hidden_states.to(torch.float32)
149
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
150
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
151
+ return self.weight * hidden_states.to(input_dtype)
152
+
153
+ def extra_repr(self):
154
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
155
+
156
+
157
+ class BLTTransformerLayer(nn.Module):
158
+ def __init__(self, config, layer_idx: int):
159
+ super().__init__()
160
+ self.hidden_size = config.hidden_size
161
+ self.layer_idx = layer_idx
162
+
163
+ self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx)
164
+ self.mlp = BLTMLP(config)
165
+ self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps)
166
+ self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps)
167
+
168
+ def forward(
169
+ self,
170
+ hidden_states: torch.Tensor,
171
+ attention_mask: Optional[torch.Tensor] = None,
172
+ position_ids: Optional[torch.LongTensor] = None,
173
+ past_key_value: Optional[Cache] = None,
174
+ output_attentions: Optional[bool] = False,
175
+ use_cache: Optional[bool] = False,
176
+ cache_position: Optional[torch.LongTensor] = None,
177
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
178
+ **kwargs,
179
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
180
+ """
181
+ Args:
182
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
183
+ attention_mask (`torch.FloatTensor`, *optional*):
184
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
185
+ query_sequence_length, key_sequence_length)` if default attention is used.
186
+ position_ids (`torch.LongTensor`, *optional*):
187
+ Position indices of tokens in the sequence for RoPE computation.
188
+ past_key_value (`Cache`, *optional*): cached past key and value projection states
189
+ output_attentions (`bool`, *optional*):
190
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
191
+ returned tensors for more detail.
192
+ use_cache (`bool`, *optional*):
193
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
194
+ (see `past_key_values`).
195
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
196
+ Indices depicting the position of the input sequence tokens in the sequence
197
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
198
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
199
+ with `head_dim` being the embedding dimension of each attention head.
200
+ kwargs (`dict`, *optional*):
201
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
202
+ into the model
203
+ """
204
+ residual = hidden_states
205
+ hidden_states = self.input_layernorm(hidden_states)
206
+
207
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
208
+ hidden_states=hidden_states,
209
+ attention_mask=attention_mask,
210
+ position_ids=position_ids,
211
+ past_key_value=past_key_value,
212
+ output_attentions=output_attentions,
213
+ use_cache=use_cache,
214
+ cache_position=cache_position,
215
+ position_embeddings=position_embeddings,
216
+ **kwargs,
217
+ )
218
+ hidden_states = residual + hidden_states
219
+
220
+ residual = hidden_states
221
+ hidden_states = self.post_attention_layernorm(hidden_states)
222
+ hidden_states = self.mlp(hidden_states)
223
+ hidden_states = residual + hidden_states
224
+
225
+ outputs = (hidden_states,)
226
+
227
+ if output_attentions:
228
+ outputs += (self_attn_weights,)
229
+
230
+ if use_cache:
231
+ outputs += (present_key_value,)
232
+
233
+ return outputs
234
+
235
+
236
+ class BLTSelfAttention(nn.Module):
237
+ def __init__(self, config, layer_idx: int):
238
+ super().__init__()
239
+ self.config = config
240
+ self.num_heads = config.num_attention_heads
241
+ self.dropout = config.dropout
242
+ self.hidden_size = config.hidden_size
243
+ self.num_key_value_heads = config.num_key_value_heads
244
+ self.head_dim = config.hidden_size // self.num_heads
245
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
246
+ self.scaling = None
247
+ self.rope_theta = config.rope_theta
248
+ self.layer_idx = layer_idx
249
+
250
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
251
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
252
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
253
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
254
+
255
+ def forward(
256
+ self,
257
+ hidden_states: torch.Tensor,
258
+ attention_mask: torch.Tensor,
259
+ position_embeddings: torch.Tensor,
260
+ output_attentions: bool = False,
261
+ use_cache: bool = False,
262
+ past_key_value=None,
263
+ cache_position=None,
264
+ **kwargs,
265
+ ):
266
+ bsz, q_len, _ = hidden_states.size()
267
+
268
+ query_states = self.q_proj(hidden_states)
269
+ key_states = self.k_proj(hidden_states)
270
+ value_states = self.v_proj(hidden_states)
271
+
272
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
273
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
274
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
275
+
276
+ cos, sin = position_embeddings
277
+
278
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
279
+
280
+ if past_key_value is not None:
281
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
282
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
283
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
284
+
285
+ attention_interface: Callable = eager_attention_forward
286
+ output_attentions = False
287
+ # self.config._attn_implementation = "sdpa"
288
+ # self.scaling = None
289
+ if self.config._attn_implementation != "eager":
290
+ if self.config._attn_implementation == "sdpa" and output_attentions:
291
+ logger.warning_once(
292
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
293
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
294
+ )
295
+ else:
296
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
297
+
298
+ attn_output, attn_weights = attention_interface(
299
+ self,
300
+ query_states,
301
+ key_states,
302
+ value_states,
303
+ attention_mask,
304
+ dropout=0.0 if not self.training else self.dropout,
305
+ scaling=self.scaling,
306
+ **kwargs,
307
+ )
308
+
309
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
310
+ attn_output = self.o_proj(attn_output)
311
+
312
+ if not output_attentions:
313
+ attn_weights = None
314
+
315
+ return attn_output, attn_weights, past_key_value
316
+
317
+
318
+ def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0):
319
+ primes = [
320
+ 1000000007, 5915587277, 1500450271, 3267000013, 5754853343,
321
+ 4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313,
322
+ ]
323
+ prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device)
324
+ powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device)
325
+ prime_powers = prime ** powers
326
+ return torch.sum(token_tensor * prime_powers, dim=-1)
327
+
328
+
329
+ def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000):
330
+ """Hash token groups and map to range [0, max_hash]."""
331
+ with torch.no_grad():
332
+ batch_size, seq_len = token_ids.shape
333
+ # Add padding for sliding window
334
+ padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
335
+ padded_tokens = torch.cat([padding, token_ids], dim=1)
336
+
337
+ # Create sliding windows and compute hashes
338
+ windows = padded_tokens.unfold(1, group_size, 1)
339
+ hashes = rolling_polynomial_hash(windows, hash_func_nb)
340
+ hash_values = hashes % max_hash
341
+
342
+ hash_values.requires_grad = False
343
+ return hash_values
344
+
345
+
346
+ def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list):
347
+ """Initialize hash-based token embeddings for the BLT encoder."""
348
+ num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size)
349
+ embeddings = [
350
+ nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim)
351
+ for _ in range(num_embeddings)
352
+ ]
353
+ return nn.ModuleList(embeddings)
354
+
355
+
356
+ def compute_hash_embeddings(
357
+ local_encoder_tokens: torch.Tensor,
358
+ local_encoder,
359
+ encoder_hash_tok_embedding: nn.ModuleList,
360
+ encoder_hash_byte_group_nb_functions: int,
361
+ encoder_hash_byte_group_size: list,
362
+ encoder_hash_byte_group_vocab: int,
363
+ ) -> torch.Tensor:
364
+ """Compute token embeddings enhanced with hash-based embeddings."""
365
+ embeddings = local_encoder.embed_tokens(local_encoder_tokens)
366
+ embedding_idx = 0
367
+ for func_nb in range(encoder_hash_byte_group_nb_functions):
368
+ for group_size in encoder_hash_byte_group_size:
369
+ hash_ids = byte_group_hash_function(
370
+ local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab
371
+ )
372
+ embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids)
373
+ embedding_idx += 1
374
+
375
+ return embeddings
376
+
377
+
378
+ def _prepare_patch_cross_attention_mask(
379
+ patch_ids: torch.Tensor,
380
+ num_patches: int,
381
+ sequence_length: int,
382
+ patches_as_queries: bool = False,
383
+ cross_attn_k: int = 1,
384
+ dtype: torch.dtype = torch.float32,
385
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
386
+ """
387
+ Prepare cross-attention mask for patch-based attention, following mllama's robust approach.
388
+
389
+ This function creates masks that control which patches can attend to which other patches,
390
+ with support for query/key role swapping and cross-attention multipliers.
391
+
392
+ Args:
393
+ patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids.
394
+ num_patches (int): Total number of patches.
395
+ sequence_length (int): Length of the sequence.
396
+ patches_as_queries (bool): If True, patches are used as queries, otherwise as keys.
397
+ cross_attn_k (int): Cross-attention multiplier for repeating patches.
398
+ dtype (torch.dtype): Data type for the output mask.
399
+
400
+ Returns:
401
+ Tuple[torch.Tensor, torch.Tensor]:
402
+ - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len]
403
+ - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows
404
+ """
405
+ batch_size, seq_len = patch_ids.shape
406
+ device = patch_ids.device
407
+
408
+ # Determine query and key lengths based on configuration
409
+ if patches_as_queries:
410
+ q_len = num_patches * cross_attn_k
411
+ kv_len = sequence_length
412
+ # Create patch-to-sequence mapping
413
+ q_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(-1).expand(
414
+ batch_size, num_patches, seq_len
415
+ )
416
+ kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len)
417
+ else:
418
+ q_len = sequence_length
419
+ kv_len = num_patches * cross_attn_k
420
+ # Create sequence-to-patch mapping
421
+ q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches)
422
+ kv_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(
423
+ batch_size, seq_len, num_patches
424
+ )
425
+
426
+ # Create base attention mask - boolean mask where True means "should attend"
427
+ # Exact patch matching
428
+ cross_attention_mask = q_patch_ids == kv_patch_ids
429
+
430
+ # Handle cross_attn_k multiplier by repeating along appropriate dimension
431
+ repeat_dim = 1 if patches_as_queries else -1
432
+ cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim)
433
+
434
+ # Validate dimensions
435
+ expected_shape = (batch_size, q_len, kv_len)
436
+ if cross_attention_mask.shape != expected_shape:
437
+ raise ValueError(f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}")
438
+
439
+ # Reshape so it can be used by attn module - add head dimension
440
+ cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len]
441
+
442
+ # Invert the mask (following mllama pattern exactly)
443
+ # True -> 0.0 (attend), False -> 1.0 (will become -inf)
444
+ inverted_cross_attn_mask = (1.0 - cross_attention_mask.to(dtype))
445
+ cross_attention_mask = inverted_cross_attn_mask.masked_fill(
446
+ inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
447
+ )
448
+
449
+ # Apply full-row bias (following mllama pattern exactly)
450
+ # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's
451
+ # last dimension contains negative infinity values, otherwise it's 1
452
+ negative_inf_value = torch.finfo(dtype).min
453
+ full_text_row_masked_out_mask = (
454
+ (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
455
+ )
456
+ cross_attention_mask *= full_text_row_masked_out_mask
457
+
458
+ return cross_attention_mask, full_text_row_masked_out_mask
459
+
460
+
461
+ def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor:
462
+ """
463
+ Splits patch lengths into smaller segments if they exceed `max_patch_length`.
464
+ Pads the result to uniform length across the batch.
465
+
466
+ Args:
467
+ patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
468
+ max_patch_length (int, optional): Maximum allowed length per patch.
469
+
470
+ Returns:
471
+ torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
472
+ """
473
+ if max_patch_length is None:
474
+ return patch_lengths
475
+
476
+ batch_size = patch_lengths.size(0)
477
+ processed = []
478
+
479
+ for seq in patch_lengths:
480
+ splits = []
481
+ for length in seq[seq > 0]:
482
+ length = length.item()
483
+ full_chunks, remainder = divmod(length, max_patch_length)
484
+ splits.extend([max_patch_length] * full_chunks)
485
+ if remainder:
486
+ splits.append(remainder)
487
+ processed.append(splits)
488
+
489
+ # Find max length to pad to
490
+ max_len = max(len(splits) for splits in processed)
491
+ padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
492
+
493
+ for i, splits in enumerate(processed):
494
+ if splits:
495
+ padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)
496
+
497
+ # Trim zero columns
498
+ if (padded != 0).any(dim=0).sum() < padded.shape[1]:
499
+ last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
500
+ padded = padded[:, :last_nonzero]
501
+
502
+ return padded
503
+
504
+
505
+ class BLTRotaryEmbedding(nn.Module):
506
+ def __init__(self, config, device=None):
507
+ super().__init__()
508
+ self.rope_type = config.rope_scaling["rope_type"]
509
+ self.max_seq_len_cached = config.max_position_embeddings
510
+ self.original_max_seq_len = config.max_position_embeddings
511
+
512
+ self.config = config
513
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
514
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
515
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
516
+ self.original_inv_freq = self.inv_freq
517
+
518
+ @torch.no_grad()
519
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
520
+ def forward(self, x, position_ids):
521
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
522
+ position_ids_expanded = position_ids[:, None, :].float()
523
+
524
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
525
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
526
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
527
+ emb = torch.cat((freqs, freqs), dim=-1)
528
+ cos = emb.cos() * self.attention_scaling
529
+ sin = emb.sin() * self.attention_scaling
530
+
531
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
532
+
533
+
534
+ class BLTLocalEncoder(nn.Module):
535
+ def __init__(self, config: BLTLocalEncoderConfig):
536
+ super().__init__()
537
+
538
+ self.hidden_size = config.hidden_size
539
+ self.vocab_size = config.vocab_size
540
+ self.num_hidden_layers = config.num_hidden_layers
541
+ self.dropout = config.dropout
542
+ self.cross_attn_all_layers = config.cross_attn_all_layers
543
+ self.cross_attn_k = config.cross_attn_k
544
+
545
+ self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(self.num_hidden_layers)])
546
+
547
+ self.rotary_emb = BLTRotaryEmbedding(config=config)
548
+
549
+ self.patch_embedding_projection = nn.Linear(
550
+ in_features=config.encoder_dim_patch_emb,
551
+ out_features=config.encoder_dim_token_emb * config.cross_attn_k,
552
+ bias=False,
553
+ )
554
+
555
+ self.embed_tokens = nn.Embedding(self.vocab_size + config.pm_size, self.hidden_size)
556
+
557
+ self.cross_attn_layers = torch.nn.ModuleList()
558
+ layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1
559
+ for layer_idx in range(layers_to_add):
560
+ self.cross_attn_layers.append(
561
+ BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size)
562
+ )
563
+
564
+ def forward(
565
+ self,
566
+ input_ids: torch.Tensor,
567
+ input_embeds: Optional[torch.Tensor] = None,
568
+ patch_embeds: Optional[torch.Tensor] = None,
569
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
570
+ cross_mask: Optional[torch.Tensor] = None,
571
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
572
+ num_patches: Optional[int] = None,
573
+ patch_ids: Optional[torch.Tensor] = None,
574
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
575
+ ):
576
+ """ """
577
+ if input_embeds is None:
578
+ input_embeds = self.embed_tokens(input_ids)
579
+
580
+ batch_size, _, _ = input_embeds.shape
581
+
582
+ hidden_states = nn.functional.dropout(input_embeds, p=self.dropout, training=self.training)
583
+
584
+ position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
585
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
586
+
587
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
588
+
589
+ for idx, layer in enumerate(self.layers):
590
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None)
591
+ hidden_states = layer_outputs[0]
592
+
593
+ if idx == len(self.layers) - 1 or self.cross_attn_all_layers:
594
+ patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids)
595
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
596
+ patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size)
597
+
598
+ layer_idx = idx if self.cross_attn_all_layers else 0
599
+ cross_attention_output, _, _ = self.cross_attn_layers[layer_idx](
600
+ hidden_states=patch_embeds,
601
+ cross_attention_states=hidden_states,
602
+ attention_mask=cross_mask,
603
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
604
+ output_attentions=False,
605
+ use_cache=False,
606
+ cache_position=None,
607
+ )
608
+ patch_embeds = patch_embeds + cross_attention_output
609
+
610
+ encoder_cross_states = patch_embeds
611
+ return hidden_states, encoder_cross_states
612
+
613
+ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids):
614
+ """
615
+ Reduce variable length patches to single embedding per patch
616
+ Note: this works with variable number of patches for different sequences in the batch
617
+ It handles variable length patches by assuming that patch_lengths will be 0 for any
618
+ extra patches on the *right*. Since there can be a variable number of patches
619
+ this function also return the number of patches for each sequence in the batch.
620
+ Any embeddings on the right that are not allocated to a patch
621
+ (i.e. if the sum(patch_lengths[i]) < seq_len for any i)
622
+ will be sent to a dummy patch, which is trimmed before returning.
623
+ """
624
+ batch_size, _, embedding_dim = hidden_states.shape
625
+
626
+ patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
627
+
628
+ reduced_embeddings = torch.zeros((batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device)
629
+ reduced_embeddings = reduced_embeddings.scatter_reduce(
630
+ src=hidden_states,
631
+ dim=1,
632
+ index=patch_ids,
633
+ reduce=reduction,
634
+ include_self=False,
635
+ )
636
+ reduced_embeddings = reduced_embeddings[:, :max_num_patches, :]
637
+
638
+ return reduced_embeddings
639
+
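+ # Reduction sketch (hypothetical encoder instance): with reduction="amax", each patch
+ # embedding is the element-wise max over the byte positions assigned to that patch:
+ #   h = torch.tensor([[[1., 4.], [3., 2.], [5., 0.]]])
+ #   ids = torch.tensor([[0, 0, 1]])
+ #   encoder.patch_reduce(h, max_num_patches=2, reduction="amax", patch_ids=ids)
+ #   # -> tensor([[[3., 4.], [5., 0.]]])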
640
+
641
+ class BLTLocalDecoder(nn.Module):
642
+ def __init__(self, config: BLTLocalDecoderConfig):
643
+ super().__init__()
644
+
645
+ # Extract config values to instance attributes
646
+ self.hidden_size = config.hidden_size
647
+ self.vocab_size = config.vocab_size
648
+ self.num_hidden_layers = config.num_hidden_layers
649
+ self.dropout = config.dropout
650
+ self.cross_attn_decoder = True  # config.cross_attn_decoder  # TODO: maybe remove
651
+ self.cross_attn_all_layers = config.cross_attn_all_layers
652
+ self.cross_attn_k = config.cross_attn_k
653
+
654
+ self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(self.num_hidden_layers)])
655
+
656
+ self.rotary_emb = BLTRotaryEmbedding(config=config)
657
+
658
+ self.patch_embedding_projection = nn.Linear(
659
+ in_features=config.hidden_size_global,
660
+ out_features=config.decoder_dim_token_emb * config.cross_attn_k,
661
+ bias=False,
662
+ )
663
+
664
+ self.norm = BLTRMSNorm(self.hidden_size, eps=config.norm_eps)
665
+
666
+ self.cross_attn_layers = torch.nn.ModuleList()
667
+ layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1
668
+ for layer_idx in range(layers_to_add):
669
+ self.cross_attn_layers.append(
670
+ BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size)
671
+ )
672
+
673
+ self.lm_head = nn.Linear(
674
+ self.hidden_size,
675
+ self.vocab_size,
676
+ bias=False,
677
+ )
678
+
679
+
680
+ def forward(
681
+ self,
682
+ tokens: torch.Tensor,
683
+ embeds: Optional[torch.Tensor],
684
+ patch_embeds: Optional[torch.Tensor] = None,
685
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
686
+ cross_mask: Optional[torch.Tensor] = None,
687
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
688
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
689
+ ):
690
+ batch_size, _, _ = embeds.shape
691
+
692
+ hidden_states = embeds
693
+
694
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
695
+ patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size)
696
+
697
+ if patch_embeds is not None and not self.cross_attn_decoder:
698
+ hidden_states = hidden_states + patch_embeds
699
+
700
+ position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1)
701
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
702
+
703
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
704
+ for i, layer in enumerate(self.layers):
705
+ if i == 0 or self.cross_attn_all_layers:
706
+ # Use cross attention to extract info from patch_embeds into hidden_states
707
+ cross_attention_output, _, _ = self.cross_attn_layers[i](
708
+ hidden_states=hidden_states,
709
+ cross_attention_states=patch_embeds,
710
+ attention_mask=cross_mask,
711
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
712
+ output_attentions=False,
713
+ use_cache=False,
714
+ cache_position=None,
715
+ )
716
+ hidden_states = hidden_states + cross_attention_output
717
+
718
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None)
719
+ hidden_states = layer_outputs[0]
720
+
721
+ logits = self.lm_head(self.norm(hidden_states))
722
+ return logits, cache
723
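
A shape-only sketch of the patch_embedding_projection step above, assuming illustrative dimensions (hidden_size_global=768, decoder hidden_size=512, cross_attn_k=2) and that decoder_dim_token_emb equals the decoder hidden size:

import torch
from torch import nn

batch_size, num_patches = 2, 5
hidden_size_global, hidden_size, cross_attn_k = 768, 512, 2  # illustrative values

proj = nn.Linear(hidden_size_global, hidden_size * cross_attn_k, bias=False)
patch_embeds = torch.randn(batch_size, num_patches, hidden_size_global)

projected = proj(patch_embeds)                                                      # (2, 5, 1024)
expanded = projected.reshape(batch_size, num_patches * cross_attn_k, hidden_size)   # (2, 10, 512)
# Each patch now contributes cross_attn_k key/value slots for the decoder's cross-attention.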
+
724
+
725
+ class BLTCrossAttention(nn.Module):
726
+ """Cross-attention module for BLT, following transformers style"""
727
+
728
+ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None):
729
+ super().__init__()
730
+ self.config = config
731
+ self.layer_idx = layer_idx
732
+ # Use provided hidden_size or fallback to encoder dimension
733
+ self.hidden_size = hidden_size or config.hidden_size_local_encoder
734
+ self.num_heads = config.num_attention_heads
735
+ self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention
736
+ self.head_dim = self.hidden_size // self.num_heads
737
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
738
+ self.scaling = None
739
+ self.dropout = config.dropout
740
+
741
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
742
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
743
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
744
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
745
+
746
+ self.q_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps)
747
+ self.k_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps)
748
+
749
+ def forward(
750
+ self,
751
+ hidden_states: torch.Tensor,
752
+ cross_attention_states: Optional[torch.Tensor] = None,
753
+ past_key_value: Optional[Cache] = None,
754
+ attention_mask: Optional[torch.Tensor] = None,
755
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
756
+ output_attentions: bool = False,
757
+ use_cache: Optional[bool] = None,
758
+ cache_position: Optional[torch.LongTensor] = None,
759
+ **kwargs,
760
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
761
+ """Input shape: Batch x Time x Channel"""
762
+ bsz, q_len, _ = hidden_states.size()
763
+
764
+ query_states = self.q_norm(hidden_states) # BLT normalizes first
765
+ query_states = self.q_proj(query_states)
766
+
767
+ if cross_attention_states is not None:
768
+ cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first
769
+ key_states = self.k_proj(cross_attention_states)
770
+ value_states = self.v_proj(cross_attention_states)
771
+ if past_key_value is not None:
772
+ # if new cross-attention states arrive together with new tokens, the key/value states above were
+ # computed only on those new states; update the cached cross key/value states and use the result
774
+ key_states, value_states = past_key_value.update(
775
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
776
+ )
777
+ elif cache_position is not None and cache_position[0] != 0:
778
+ key_states, value_states = (
779
+ past_key_value.key_cache[self.layer_idx],
780
+ past_key_value.value_cache[self.layer_idx],
781
+ )
782
+ else:
783
+ if cross_attention_states is None:
784
+ raise ValueError(
785
+ "Cross attention layer can't find neither `cross_attention_states` nor cached values for key/values!"
786
+ )
787
+
788
+ attention_interface: Callable = eager_attention_forward
789
+
790
+ # self.config._attn_implementation = "sdpa"
791
+ # attn = "sdpa"
792
+ if self.config._attn_implementation != "eager":
793
+ if self.config._attn_implementation == "sdpa" and output_attentions:
794
+ logger.warning_once(
795
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
796
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
797
+ )
798
+ else:
799
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
800
+
801
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
802
+ key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
803
+ value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
804
+
805
+ attn_output, attn_weights = attention_interface(
806
+ self,
807
+ query_states,
808
+ key_states,
809
+ value_states,
810
+ attention_mask,
811
+ dropout=0.0, #if not self.training else self.dropout,
812
+ scaling=self.scaling,
813
+ **kwargs,
814
+ )
815
+
816
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
817
+ attn_output = self.o_proj(attn_output)
818
+
819
+ # Apply full row masking if provided (following mllama pattern)
820
+ if full_text_row_masked_out_mask is not None:
821
+ attn_output = full_text_row_masked_out_mask[:, 0] * attn_output
822
+
823
+ attn_output = attn_output + hidden_states
824
+
825
+ if not output_attentions:
826
+ attn_weights = None
827
+
828
+ return attn_output, attn_weights, past_key_value
829
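
A minimal sketch of the normalize-then-project cross-attention pattern used above, with scaled_dot_product_attention standing in for the configurable attention interface; all dimensions are illustrative:

import torch
import torch.nn.functional as F
from torch import nn

bsz, q_len, kv_len, hidden_size, num_heads = 2, 16, 4, 512, 8
head_dim = hidden_size // num_heads

q_norm, k_norm = nn.RMSNorm(hidden_size), nn.RMSNorm(hidden_size)
q_proj, k_proj, v_proj, o_proj = (nn.Linear(hidden_size, hidden_size, bias=False) for _ in range(4))

hidden_states = torch.randn(bsz, q_len, hidden_size)   # query stream (e.g. byte positions)
cross_states = torch.randn(bsz, kv_len, hidden_size)   # key/value stream (e.g. patch embeddings)

q = q_proj(q_norm(hidden_states)).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
k = k_proj(k_norm(cross_states)).view(bsz, kv_len, num_heads, head_dim).transpose(1, 2)
v = v_proj(cross_states).view(bsz, kv_len, num_heads, head_dim).transpose(1, 2)

attn = F.scaled_dot_product_attention(q, k, v)          # (bsz, num_heads, q_len, head_dim)
out = o_proj(attn.transpose(1, 2).reshape(bsz, q_len, hidden_size)) + hidden_states  # residual add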
+
830
+
831
+ class BLTGlobalTransformer(nn.Module):
832
+ def __init__(self, config: BLTGlobalTransformerConfig):
833
+ super().__init__()
834
+
835
+ self.hidden_size = config.hidden_size
836
+ self.num_hidden_layers = config.num_hidden_layers
837
+ self.dropout = config.dropout
838
+
839
+ self.layers = nn.ModuleList()
840
+ for layer_idx in range(self.num_hidden_layers):
841
+ self.layers.append(BLTTransformerLayer(config, layer_idx))
842
+
843
+ self.rotary_emb = BLTRotaryEmbedding(config=config)
844
+
845
+
846
+ def forward(
847
+ self,
848
+ input_embeds: torch.Tensor,
849
+ mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
850
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
851
+ ):
852
+ batch_size, seq_len, _ = input_embeds.shape
853
+
854
+ hidden_states = input_embeds
855
+
856
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
857
+
858
+ position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
859
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
860
+
861
+ for i, layer in enumerate(self.layers):
862
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None)
863
+ hidden_states = layer_outputs[0]
864
+
865
+ return hidden_states, cache
866
+
867
+
868
+
869
+
870
+ class BLTPreTrainedModel(PreTrainedModel):
871
+ config_class = BLTConfig
872
+ base_model_prefix = "model"
873
+ supports_gradient_checkpointing = True
874
+ _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"]
875
+ _skip_keys_device_placement = ["past_key_values"]
876
+ _supports_flash_attn_2 = False # BLT uses its own attention implementation
877
+ _supports_sdpa = True
878
+ _supports_cache_class = False
879
+
880
+ def _init_weights(self, module):
881
+ if isinstance(module, nn.Linear):
882
+ std = getattr(module, '_custom_std', module.in_features ** (-0.5))
883
+ nn.init.trunc_normal_(
884
+ module.weight,
885
+ mean=0.0,
886
+ std=std,
887
+ a=-3 * std,
888
+ b=3 * std,
889
+ )
890
+ if module.bias is not None:
891
+ nn.init.zeros_(module.bias)
892
+
893
+ elif isinstance(module, nn.Embedding):
894
+ std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5))
895
+ nn.init.trunc_normal_(
896
+ module.weight,
897
+ mean=0.0,
898
+ std=std,
899
+ a=-3 * std,
900
+ b=3 * std,
901
+ )
902
+
903
+ elif isinstance(module, BLTModel):
904
+ if module.encoder_hash_tok_embedding is not None:
905
+ emb_std = module.config.hidden_size_local_encoder ** (-0.5)
906
+ for emb in module.encoder_hash_tok_embedding:
907
+ emb._custom_std = emb_std
908
+
909
+ elif isinstance(module, BLTLocalEncoder):
910
+ if module.patch_embedding_projection is not None:
911
+ module.patch_embedding_projection._custom_std = module.config.encoder_dim_patch_emb ** (-0.5)
912
+
913
+ elif isinstance(module, BLTLocalDecoder):
914
+ if module.patch_embedding_projection is not None:
915
+ module.patch_embedding_projection._custom_std = module.config.hidden_size_global ** (-0.5)
916
+
917
+ elif isinstance(module, BLTPatcher):
918
+ emb_std = module.config.hidden_size ** (-0.5)
919
+ module.embed_tokens._custom_std = emb_std
920
+ module.lm_head._custom_std = emb_std
921
+
922
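
As a small illustration of the default path in _init_weights, a Linear layer without a _custom_std attribute is initialized with a truncated normal at std = in_features ** -0.5, clipped to three standard deviations:

import torch
from torch import nn

linear = nn.Linear(512, 512, bias=False)
std = linear.in_features ** (-0.5)
nn.init.trunc_normal_(linear.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
print(bool(linear.weight.abs().max() <= 3 * std))  # True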
+
923
+ class BLTModel(BLTPreTrainedModel):
924
+ def __init__(self, config: BLTConfig):
925
+ super().__init__(config)
926
+
927
+ self.config = config
928
+
929
+ self.local_encoder = BLTLocalEncoder(config.encoder_config)
930
+ self.global_transformer = BLTGlobalTransformer(config.global_config)
931
+ self.local_decoder = BLTLocalDecoder(config.decoder_config)
932
+
933
+ self.encoder_hash_tok_embedding = init_hash_embeddings(
934
+ config,
935
+ local_encoder_dim=config.hidden_size_local_encoder,
936
+ encoder_hash_byte_group_size=config.encoder_hash_byte_group_size,
937
+ )
938
+
939
+ if self.config.patch_in_forward:
940
+ self.patcher = BLTPatcher(config.patcher_config)
941
+ self.patcher.eval()
942
+ for param in self.patcher.parameters():
943
+ param.requires_grad = False
944
+ else:
945
+ self.patcher = None
946
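
init_hash_embeddings and compute_hash_embeddings are defined elsewhere in the package; as a rough, hypothetical sketch of the general idea (hash byte n-grams into an auxiliary vocabulary and add the looked-up embeddings to the plain byte embeddings), not the actual helpers:

import torch
from torch import nn

hash_vocab, hidden, byte_vocab = 30000, 512, 260   # illustrative sizes
byte_embedding = nn.Embedding(byte_vocab, hidden)
ngram_embedding = nn.Embedding(hash_vocab, hidden)

byte_ids = torch.randint(4, byte_vocab, (1, 16))
prev_ids = torch.roll(byte_ids, shifts=1, dims=1)
two_gram_hash = (byte_ids * 31 + prev_ids) % hash_vocab   # toy hash, not the real scheme

enriched = byte_embedding(byte_ids) + ngram_embedding(two_gram_hash)  # (1, 16, 512)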
+
947
+ def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = None):
948
+ batch_size, sequence_length = tokens.shape
949
+
950
+ # Handle patching
951
+ if patch_lengths is None:
952
+ if self.config.patching_mode == PatchingModeEnum.entropy:
953
+ _, patch_lengths, _ = self.patcher(
954
+ tokens,
955
+ patch_size=self.config.patch_size,
956
+ threshold=self.config.patching_threshold,
957
+ max_patch_length=self.config.max_patch_length,
958
+ patching_batch_size=self.config.patching_batch_size,
959
+ device=self.config.patching_device,
960
+ )
961
+ else:
962
+ # Default to byte-level patching
963
+ patch_lengths = process_patch_lengths(
964
+ torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device),
965
+ self.config.max_patch_length
966
+ )
967
+
968
+ patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length)
969
+ cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask(
970
+ patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32
971
+ )
972
+
973
+ encoder_embeds = compute_hash_embeddings(
974
+ tokens, self.local_encoder, self.encoder_hash_tok_embedding,
975
+ self.config.encoder_hash_byte_group_nb_functions,
976
+ self.config.encoder_hash_byte_group_size,
977
+ self.config.encoder_hash_byte_group_vocab,
978
+ )
979
+
980
+ encoder_hidden_states, encoder_cross_states = self.local_encoder(
981
+ input_ids=tokens,
982
+ input_embeds=encoder_embeds,
983
+ patch_embeds=None,
984
+ cross_mask=cross_attn_mask_enc,
985
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc,
986
+ num_patches=patch_lengths.shape[1],
987
+ patch_ids=patch_ids,
988
+ )
989
+
990
+ global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
991
+
992
+ global_hidden_states, _ = self.global_transformer(
993
+ input_embeds=global_hidden_states,
994
+ )
995
+
996
+ decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length)
997
+ cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask(
998
+ decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32
999
+ )
1000
+
1001
+ output, _ = self.local_decoder(
1002
+ tokens=tokens,
1003
+ embeds=encoder_hidden_states,
1004
+ patch_embeds=global_hidden_states,
1005
+ mask=None,
1006
+ cross_mask=cross_attn_mask_dec,
1007
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec,
1008
+ )
1009
+
1010
+ return output
1011
+
1012
+ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
1013
+ """Convert patch lengths to patch IDs for each token position."""
1014
+ batch_size = patch_lengths.shape[0]
1015
+ patch_starts = torch.cat([
1016
+ torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
1017
+ patch_lengths.cumsum(dim=-1)[:, :-1]
1018
+ ], dim=-1)
1019
+
1020
+ token_positions = torch.arange(seq_len, device=patch_lengths.device)
1021
+ return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1
1022
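
A self-contained worked example of what _patch_ids_from_lengths computes, using an illustrative patch_lengths of [3, 2, 4]:

import torch

patch_lengths = torch.tensor([[3, 2, 4]])   # one sequence split into three patches
seq_len = int(patch_lengths.sum())          # 9 byte positions

patch_starts = torch.cat(
    [torch.zeros(1, 1, dtype=patch_lengths.dtype), patch_lengths.cumsum(dim=-1)[:, :-1]], dim=-1
)                                           # tensor([[0, 3, 5]])
token_positions = torch.arange(seq_len)
patch_ids = (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1
# patch_ids -> tensor([[0, 0, 0, 1, 1, 2, 2, 2, 2]])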
+
1023
+
1024
+ class BLTPatcher(BLTPreTrainedModel):
1025
+ def __init__(self, config: BLTPatcherConfig):
1026
+ super().__init__(config)
1027
+
1028
+ self.rotary_emb = BLTRotaryEmbedding(config=self.config)
1029
+
1030
+ self.layers = nn.ModuleList()
1031
+ # Create transformer layers using the patcher config
1032
+ for layer_idx in range(self.config.num_hidden_layers):
1033
+ self.layers.append(BLTTransformerLayer(self.config, layer_idx))
1034
+
1035
+
1036
+ self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.hidden_size)
1037
+
1038
+ self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps)
1039
+
1040
+ self.lm_head = nn.Linear(
1041
+ self.config.hidden_size,
1042
+ self.config.vocab_size,
1043
+ bias=False,
1044
+ )
1045
+
1046
+ def forward(
1047
+ self,
1048
+ token_values: torch.Tensor,
1049
+ patch_size: Optional[int] = None,
1050
+ threshold: Optional[float] = None,
1051
+ max_patch_length: Optional[int] = None,
1052
+ patching_batch_size: int = 1,
1053
+ device: Optional[str] = None,
1054
+ ):
1055
+
1056
+ # Handle chunked processing for entropy calculation
1057
+ entropies = []
1058
+ predictions = []
1059
+ max_length = self.config.max_position_embeddings
1060
+ batch_numel = max_length * patching_batch_size
1061
+ splits = torch.split(token_values.flatten(), batch_numel)
1062
+
1063
+ for split in splits:
1064
+ pad_size = (max_length - (split.numel() % max_length)) % max_length
1065
+ pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False)
1066
+ split = torch.cat((split, pad), dim=0)
1067
+ split = split.reshape(-1, max_length)
1068
+ if device is not None:
1069
+ split = split.to(device)
1070
+
1071
+ # Process chunk: embeddings -> layers -> output
1072
+ batch_size, sequence_length = split.shape
1073
+ input_embeds = self.embed_tokens(split)
1074
+
1075
+ hidden_states = input_embeds
1076
+
1077
+ batch_size, _, _ = input_embeds.shape
1078
+
1079
+ position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
1080
+
1081
+ position_embeddings = self.rotary_emb(hidden_states, position_ids) # = BLT self.rope
1082
+
1083
+ for i, layer in enumerate(self.layers):
1084
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) #, attn_impl=self.config.patcher_attn_impl )
1085
+ hidden_states = layer_outputs[0]
1086
+
1087
+ logits = self.lm_head(self.norm(hidden_states))
1088
+ logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab]
1089
+ predictions.append(logits)
1090
+ prediction_entropies = torch.distributions.Categorical(logits=logits).entropy()
1091
+ entropies.append(prediction_entropies)
1092
+
1093
+ concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape)
1094
+ concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1)
1095
+
1096
+ # Always compute patch lengths from concatenated entropies
1097
+ batch_size, sequence_length = token_values.shape
1098
+
1099
+ # Find patch start IDs based on entropy
1100
+ if patch_size is not None:
1101
+ patch_lengths = self.patch_lengths_from_entropies(
1102
+ entropies=concat_entropies,
1103
+ sequence_length=sequence_length,
1104
+ patch_size=patch_size,
1105
+ threshold=threshold,
1106
+ )
1107
+ else:
1108
+ # Default to byte-level patching
1109
+ patch_lengths = torch.ones((batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device)
1110
+ patch_lengths = process_patch_lengths(patch_lengths, max_patch_length)
1111
+ return concat_entropies, patch_lengths, concat_predictions
1112
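
The per-position entropies above are plain categorical entropies over the byte vocabulary; a tiny standalone check of the formula:

import torch

logits = torch.zeros(1, 8)                                       # uniform over 8 symbols
print(torch.distributions.Categorical(logits=logits).entropy())  # tensor([2.0794]) == ln(8)

peaked = torch.tensor([[10.0, -10.0, -10.0, -10.0]])             # near-deterministic
print(torch.distributions.Categorical(logits=peaked).entropy())  # close to 0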
+
1113
+ @staticmethod
1114
+ def patch_lengths_from_entropies(
1115
+ entropies,
1116
+ sequence_length,
1117
+ patch_size=None,
1118
+ threshold=None,
1119
+ ):
1120
+ """
1121
+ Computes patch lengths from token entropies.
1122
+
1123
+ Depending on whether a threshold is provided, the function uses either:
1124
+ - Top-k selection based on entropy (when `threshold` is None), or
1125
+ - Thresholding the entropy values (when `threshold` is set).
1126
+ """
1127
+
1128
+ batch_size = entropies.shape[0]
1129
+
1130
+ # Always include tokens 0 and 1 as starting tokens
1131
+ init_tokens = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1)
1132
+ offset = init_tokens.shape[1]
1133
+
1134
+ # Ignore first token entropy (BOS)
1135
+ entropies = entropies[:, 1:]
1136
+
1137
+ if threshold is None:
1138
+ # Use top-k entropy values to define patch start points
1139
+ num_patches = sequence_length // patch_size
1140
+ topk_indices = entropies.topk(num_patches - 2, dim=1).indices
1141
+ patch_starts = topk_indices.sort(dim=1).values
1142
+ else:
1143
+ # Threshold the entropy values to define patch start points
1144
+ patch_mask = entropies > threshold
1145
+
1146
+ seq_len = patch_mask.shape[1]
1147
+
1148
+ # Create patch IDs (token indices), and add a sentinel to ensure alignment
1149
+ token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1)
1150
+ sentinel = torch.full_like(token_indices, seq_len)
1151
+ padded_indices = torch.cat([token_indices, sentinel], dim=1)
1152
+
1153
+ # Pad mask with inverse to align sentinel correctly
1154
+ padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1)
1155
+
1156
+ # Select indices where mask is True
1157
+ patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len)
1158
+ max_valid_patches = patch_mask.sum(dim=1).max()
1159
+ patch_starts = patch_starts[:, :max_valid_patches]
1160
+
1161
+ # Offset patch starts to account for the two initial tokens
1162
+ patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1)
1163
+
1164
+ # Compute patch end positions by shifting start positions
1165
+ last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1)
1166
+ patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1)
1167
+
1168
+ patch_lengths = patch_ends - patch_start_ids + 1
1169
+
1170
+ return patch_lengths
1171
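
A single-sequence sketch of the threshold branch above, skipping the sentinel-based batching trick; the entropy values and threshold are illustrative:

import torch

entropies = torch.tensor([[0.1, 0.9, 0.2, 0.2, 1.5, 0.3, 1.2, 0.1]])
threshold, sequence_length = 1.0, 8

init_tokens = torch.tensor([0, 1])                         # tokens 0 and 1 always start patches
patch_mask = entropies[:, 1:] > threshold                  # ignore the first (BOS) entropy
starts = patch_mask[0].nonzero().squeeze(-1) + init_tokens.shape[0]
patch_start_ids = torch.cat([init_tokens, starts])         # tensor([0, 1, 5, 7])
ends = torch.cat([patch_start_ids[1:] - 1, torch.tensor([sequence_length - 1])])
patch_lengths = ends - patch_start_ids + 1                 # tensor([1, 4, 2, 1])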
+
1172
+ __all__ = [
1173
+ "BLTPreTrainedModel",
1174
+ "BLTModel",
1175
+ "BLTPatcher",
1176
+ "BLTLocalEncoder",
1177
+ "BLTLocalDecoder",
1178
+ "BLTGlobalTransformer",
1179
+ "BLTTransformerLayer",
1180
+ ]
backup_blt_wip copy/tokenization_blt.py ADDED
@@ -0,0 +1,271 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for BLT."""
16
+
17
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
18
+
19
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
20
+ from ...utils import logging
21
+
22
+
23
+ if TYPE_CHECKING:
24
+ from ...tokenization_utils_base import TextInput
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ # BLT tokenizer constants
29
+ SEP = " "
30
+ BOS_ID: int = 1
31
+ EOS_ID: int = 2
32
+ PAD_ID: int = -1
33
+ BOE_ID: int = 0
34
+ BPE_ID: int = 3
35
+ OFFSET: int = 4
36
+ BYTE_UNITS: int = 256
37
+
38
+ VOCAB_FILES_NAMES = {} # BLT doesn't require external vocab files
39
+
40
+
41
+ class BLTTokenizer(PreTrainedTokenizer):
42
+ """
43
+ Construct a BLT tokenizer. Based on byte-level tokenization where each byte is treated as a token.
44
+
45
+ This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with an offset.
46
+ It supports special tokens for beginning of sequence (BOS), end of sequence (EOS),
47
+ beginning of example (BOE), and padding (PAD).
48
+
49
+ Args:
50
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
51
+ The beginning of sequence token.
52
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
53
+ The end of sequence token.
54
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
55
+ The padding token.
56
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
57
+ The unknown token. Not used in BLT but kept for compatibility.
58
+ boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<boe>"`):
59
+ The beginning of example token, specific to BLT.
60
+ add_bos_token (`bool`, *optional*, defaults to `True`):
61
+ Whether or not to add a `bos_token` at the start of sequences.
62
+ add_eos_token (`bool`, *optional*, defaults to `True`):
63
+ Whether or not to add an `eos_token` at the end of sequences.
64
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
65
+ Whether or not to cleanup spaces after decoding.
66
+ spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
67
+ Whether or not to add spaces between special tokens.
68
+ """
69
+
70
+ vocab_files_names = VOCAB_FILES_NAMES
71
+ model_input_names = ["input_ids", "attention_mask"]
72
+
73
+ def __init__(
74
+ self,
75
+ bos_token="<s>",
76
+ eos_token="</s>",
77
+ pad_token="<pad>",
78
+ unk_token="<unk>",
79
+ boe_token="<boe>",
80
+ add_bos_token=True,
81
+ add_eos_token=True,
82
+ clean_up_tokenization_spaces=False,
83
+ spaces_between_special_tokens=False,
84
+ **kwargs,
85
+ ):
86
+ # Store BLT-specific parameters first
87
+ self.add_bos_token = add_bos_token
88
+ self.add_eos_token = add_eos_token
89
+ self.vocab_size_unit_1 = BYTE_UNITS
90
+ self.offsetting_special_char = OFFSET
91
+
92
+ # BLT token IDs (exactly like original)
93
+ self.boe_id = BOE_ID
94
+ self.bos_id = BOS_ID
95
+ self.eos_id = EOS_ID
96
+ self.pad_id = PAD_ID
97
+ self.bpe_id = BPE_ID
98
+ self.n_words = self.vocab_size_unit_1 + self.offsetting_special_char
99
+
100
+ # Convert string tokens to AddedToken objects
101
+ bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
102
+ eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
103
+ pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
104
+ unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
105
+ self.boe_token = AddedToken(boe_token, normalized=False, special=True) if isinstance(boe_token, str) else boe_token
106
+
107
+ super().__init__(
108
+ bos_token=bos_token,
109
+ eos_token=eos_token,
110
+ pad_token=pad_token,
111
+ unk_token=unk_token,
112
+ add_bos_token=add_bos_token,
113
+ add_eos_token=add_eos_token,
114
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
115
+ spaces_between_special_tokens=spaces_between_special_tokens,
116
+ **kwargs,
117
+ )
118
+
119
+ @property
120
+ def vocab_size(self):
121
+ """Returns vocab size"""
122
+ return self.vocab_size_unit_1 + self.offsetting_special_char
123
+
124
+ def get_vocab(self):
125
+ """Returns vocab as a dict"""
126
+ # Create a mapping for byte values + offset
127
+ vocab = {}
128
+
129
+ # Add special tokens (with defensive checks)
130
+ if hasattr(self, 'bos_token'):
131
+ vocab[str(self.bos_token)] = self.bos_id
132
+ if hasattr(self, 'eos_token'):
133
+ vocab[str(self.eos_token)] = self.eos_id
134
+ if hasattr(self, 'pad_token'):
135
+ vocab[str(self.pad_token)] = self.pad_id
136
+ if hasattr(self, 'boe_token'):
137
+ vocab[str(self.boe_token)] = self.boe_id
138
+
139
+ # Add byte tokens as string representations of byte values
140
+ vocab_size_unit_1 = getattr(self, 'vocab_size_unit_1', BYTE_UNITS)
141
+ offsetting_special_char = getattr(self, 'offsetting_special_char', OFFSET)
142
+ for i in range(vocab_size_unit_1):
143
+ vocab[str(i)] = i + offsetting_special_char
144
+
145
+ # Add any additional tokens if available
146
+ if hasattr(self, 'added_tokens_encoder'):
147
+ vocab.update(self.added_tokens_encoder)
148
+ return vocab
149
+
150
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
151
+ """
152
+ Converts a string to a list of tokens. For BLT, we work directly with byte values.
153
+ Returns a list of strings that represent the byte values.
154
+ """
155
+ # Convert text to UTF-8 bytes, just like the original
156
+ # encoding with errors="ignore" never raises, so no exception handling is needed
+ bytes_data = text.encode("utf-8", errors="ignore")
160
+
161
+ # Return string representations of byte values for the tokenizer framework
162
+ return [str(byte_val) for byte_val in bytes_data]
163
+
164
+ def _convert_token_to_id(self, token: str) -> int:
165
+ """Converts a token (str) to an id using the vocab."""
166
+ # Handle special tokens
167
+ if token == str(self.bos_token):
168
+ return self.bos_id
169
+ elif token == str(self.eos_token):
170
+ return self.eos_id
171
+ elif token == str(self.pad_token):
172
+ return self.pad_id
173
+ elif token == str(self.boe_token):
174
+ return self.boe_id
175
+ else:
176
+ try:
177
+ # Convert byte value string to int and add offset (like original)
178
+ byte_val = int(token)
179
+ if 0 <= byte_val <= 255:
180
+ return byte_val + self.offsetting_special_char
181
+ except ValueError:
182
+ pass
183
+
184
+ # Check if it's in added tokens
185
+ return self.added_tokens_encoder.get(token, self.unk_token_id)
186
+
187
+ def _convert_id_to_token(self, index: int) -> str:
188
+ """Converts an index (integer) to a token (str) using the vocab."""
189
+ # Handle special tokens
190
+ if index == self.bos_id:
191
+ return str(self.bos_token)
192
+ elif index == self.eos_id:
193
+ return str(self.eos_token)
194
+ elif index == self.pad_id:
195
+ return str(self.pad_token)
196
+ elif index == self.boe_id:
197
+ return str(self.boe_token)
198
+ elif index >= self.offsetting_special_char and index < self.vocab_size:
199
+ # Convert back to byte value (like original)
200
+ byte_val = index - self.offsetting_special_char
201
+ return str(byte_val)
202
+ else:
203
+ # Check added tokens
204
+ for token, token_id in self.added_tokens_encoder.items():
205
+ if token_id == index:
206
+ return token
207
+ return str(self.unk_token)
208
+
209
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
210
+ """Converts a sequence of tokens to a single string."""
211
+ byte_values = []
212
+
213
+ for token in tokens:
214
+ # Skip special tokens
215
+ if token in [str(self.bos_token), str(self.eos_token), str(self.pad_token), str(self.boe_token)]:
216
+ continue
217
+
218
+ try:
219
+ # Convert token back to byte value (like original decode method)
220
+ byte_val = int(token)
221
+ if 0 <= byte_val <= 255:
222
+ byte_values.append(byte_val)
223
+ except ValueError:
224
+ continue
225
+
226
+ # Convert byte values back to string (exactly like original)
227
+ try:
228
+ return bytes(byte_values).decode("utf-8", errors="ignore")
229
+ except (UnicodeDecodeError, ValueError):
230
+ return ""
231
+
232
+ def encode(self, text: str, add_bos: bool | None = None, add_eos: bool | None = None):
233
+ """
234
+ Encode text exactly like the original BLT tokenizer.
235
+ """
236
+ if add_bos is None:
237
+ add_bos = self.add_bos_token
238
+ if add_eos is None:
239
+ add_eos = self.add_eos_token
240
+
241
+ # Since bpe_delim=False, we use the simple byte encoding
242
+ tokens = bytes(text, encoding="utf-8", errors="ignore")
243
+
244
+ # Offsetting (exactly like original)
245
+ tokens = [int(unit) + self.offsetting_special_char for unit in tokens]
246
+
247
+ if add_bos:
248
+ tokens.insert(0, self.bos_id)
249
+ if add_eos:
250
+ tokens.append(self.eos_id)
251
+
252
+ return tokens
253
+
254
+ def decode(self, tokens: list[int], cut_at_eos: bool = False):
255
+ """
256
+ Decode tokens exactly like the original BLT tokenizer.
257
+ """
258
+ if cut_at_eos:
259
+ for k, t in enumerate(tokens):
260
+ if t == self.eos_id:
261
+ tokens = tokens[: k + 1]
262
+ break
263
+ return bytes(
264
+ [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0]
265
+ ).decode("utf-8", errors="ignore")
266
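
A hypothetical usage sketch of the byte-level round trip, assuming the class above is importable; 'h' and 'i' are bytes 104 and 105, shifted by the offset of 4 special tokens:

tokenizer = BLTTokenizer()
ids = tokenizer.encode("hi")      # [1, 108, 109, 2] -> BOS, 104 + 4, 105 + 4, EOS
text = tokenizer.decode(ids)      # "hi"; ids that fall below the offset (BOS/EOS) are dropped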
+
267
+ def get_vocab_size(self) -> int:
268
+ """Get vocab size like the original tokenizer."""
269
+ return self.vocab_size_unit_1 + self.offsetting_special_char
270
+
271
+ #__all__ = ["BLTTokenizer"]
backup_blt_wip_backup/__pycache__/blt_args.cpython-312.pyc ADDED
Binary file (7.05 kB). View file
 
backup_blt_wip_backup/__pycache__/blt_one_file.cpython-312.pyc ADDED
Binary file (96.4 kB). View file
 
backup_blt_wip_backup/__pycache__/configuration_blt.cpython-312.pyc ADDED
Binary file (22.1 kB). View file
 
backup_blt_wip_backup/__pycache__/modeling_blt_wip.cpython-312.pyc ADDED
Binary file (78.7 kB). View file
 
backup_blt_wip_backup/__pycache__/modeling_blt_wip_backup.cpython-312.pyc ADDED
Binary file (89.6 kB). View file
 
backup_blt_wip_backup/__pycache__/tokenization_blt.cpython-312.pyc ADDED
Binary file (11.6 kB). View file
 
backup_blt_wip_backup/blt_args.py ADDED
@@ -0,0 +1,187 @@
1
+ from enum import Enum
2
+ from typing import Any
3
+
4
+ from pydantic import BaseModel, ConfigDict, model_validator
5
+ from typing_extensions import Self
6
+
7
+
8
+ EOS_ID: int = 2
9
+
10
+
11
+ class InitStdFactor(str, Enum):
12
+ DISABLED = "disabled" # Init std is divided by 1.0
13
+ GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers)
14
+ CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)
15
+ DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096
16
+
17
+
18
+ class PatchingModeEnum(str, Enum):
19
+ entropy = "entropy"
20
+ bpe = "bpe"
21
+ bpe_patcher = "bpe_patcher"
22
+ space = "space"
23
+ static = "static"
24
+ byte = "byte"
25
+
26
+
27
+ class LMTransformerArgs(BaseModel):
28
+ """Arguments for the Language Model Transformer (used as entropy model for patching)"""
29
+
30
+ model_config = ConfigDict()
31
+
32
+ # Basic architecture
33
+ dim: int = 512
34
+ n_layers: int = 8
35
+ head_dim: int | None = None
36
+ n_heads: int | None = None
37
+ n_kv_heads: int | None = None
38
+
39
+ # Transformer configuration
40
+ max_seqlen: int = 1024
41
+ norm_eps: float = 1e-5
42
+ dropout: float = 0
43
+ vocab_size: int = -1
44
+ sliding_window: int | None = None
45
+
46
+ # Feedforward
47
+ ffn_dim_multiplier: float | None = None
48
+ multiple_of: int = 256
49
+
50
+ # Positional encoding
51
+ rope_theta: float = 10000.0
52
+ rope_use_fp32_in_outer_product: bool = False
53
+
54
+ # Attention
55
+ attn_impl: str = "sdpa"
56
+ attn_bias_type: str = "causal"
57
+
58
+ # Initialization
59
+ init_base_std: float | None = None
60
+ init_std_factor: InitStdFactor = InitStdFactor.DISABLED
61
+
62
+ # Embedding dimensions
63
+ dim_token_emb: int | None = None
64
+
65
+ # Model behavior
66
+ weight_tying: bool = False
67
+ seed: int = 42
68
+
69
+ # Special token config
70
+ eos_id: int = EOS_ID
71
+
72
+
73
+ class ByteLatentTransformerArgs(BaseModel):
74
+ """Arguments for the Byte Latent Transformer (main BLT model)"""
75
+
76
+ model_config = ConfigDict()
77
+
78
+ # Basic model configuration
79
+ seed: int = 42
80
+ vocab_size: int = -1
81
+
82
+ # Main architecture dimensions (these will be used for creating transformer args)
83
+ dim: int = 512
84
+ n_layers: int = 8
85
+ head_dim: int | None = None
86
+ n_heads: int | None = None
87
+ n_kv_heads: int | None = None
88
+
89
+ # Component-specific dimensions
90
+ dim_global: int = 512
91
+ dim_local_decoder: int = 512
92
+ dim_local_encoder: int = 512
93
+ n_layers_global: int = 8
94
+ n_layers_local_decoder: int = 8
95
+ n_layers_local_encoder: int = 8
96
+ n_heads_global: int = 8
97
+ n_heads_local_decoder: int = 8
98
+ n_heads_local_encoder: int = 8
99
+ n_kv_heads_global: int | None = None
100
+
101
+ # Transformer configuration (needed by transformer components)
102
+ max_seqlen: int = 1024
103
+ norm_eps: float = 1e-5
104
+ dropout: float = 0
105
+
106
+ # Feedforward (needed by transformer components)
107
+ ffn_dim_multiplier: float = 1.0
108
+ multiple_of: int = 256
109
+
110
+ # Positional encoding (needed by transformer components)
111
+ rope_theta: float = 10000.0
112
+ rope_use_fp32_in_outer_product: bool = False
113
+
114
+ # Attention (needed by transformer components)
115
+ attn_impl: str = "sdpa"
116
+ attn_bias_type: str = "causal"
117
+
118
+ # Initialization (needed by transformer components)
119
+ init_base_std: float | None = None
120
+ init_std_factor: InitStdFactor = InitStdFactor.DISABLED
121
+
122
+ # Embedding dimensions (needed by transformer components)
123
+ dim_token_emb: int | None = None
124
+
125
+ # Patching configuration
126
+ patch_in_forward: bool = False
127
+ realtime_patching: bool = True
128
+ patch_size: float | None = None
129
+ patching_mode: str | None = None
130
+ patching_threshold: float | None = None
131
+ patching_threshold_add: float | None = None
132
+ monotonicity: bool = False
133
+ patching_batch_size: int = 1
134
+ patching_device: str = "cuda"
135
+ max_patch_length: int | None = None
136
+ entropy_model_checkpoint_dir: str | None = None
137
+
138
+ # Cross attention configurations
139
+ cross_attn_encoder: bool = False
140
+ cross_attn_decoder: bool = False
141
+ cross_attn_window_encoder: int | None = None
142
+ cross_attn_window_decoder: int | None = None
143
+ cross_attn_k: int | None = None
144
+ cross_attn_nheads: int | None = None
145
+ cross_attn_all_layers_decoder: bool = False
146
+ cross_attn_all_layers_encoder: bool = False
147
+ cross_attn_use_flex_attention: bool = True
148
+ cross_attn_init_by_pooling: bool = False
149
+
150
+ # Encoder configurations
151
+ use_local_encoder_transformer: bool = False
152
+ max_encoder_seq_length: int | None = None
153
+ encoder_hash_byte_group_size: Any | None = None
154
+ encoder_hash_byte_group_vocab: int = 30000
155
+ encoder_hash_byte_group_nb_functions: int = 3
156
+ encoder_enable_byte_ngrams: bool = False
157
+ encoder_ngram_to_size_str: str | None = None
158
+ downsampling_by_pooling: str | None = None
159
+
160
+ # Architecture and dimensions
161
+ dim_token: int | None = None
162
+ share_encoder_decoder_emb: bool = True
163
+ weight_tying: bool = False
164
+
165
+ # Attention configuration
166
+ local_attention_window_len: int | None = None
167
+ use_rope: bool = True
168
+
169
+ # Performance optimization
170
+ sequence_parallel: bool = False
171
+ loss_parallel: bool = False
172
+ fuse_sequence_parallel: bool = False
173
+ use_fsdp: bool = True
174
+
175
+ # Parameter mixing
176
+ pm_size: int = 0
177
+
178
+ # Special token config
179
+ eos_id: int = EOS_ID
180
+
181
+ @model_validator(mode="after")
182
+ def check_hash_byte_sizes(self) -> Self:
183
+ if self.encoder_hash_byte_group_size is not None and isinstance(self.encoder_hash_byte_group_size, str):
184
+ self.encoder_hash_byte_group_size = [
185
+ int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0
186
+ ]
187
+ return self
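
Illustrative use of the validator above, assuming the args class is importable: a comma-separated string is normalized to a list of ints after model construction.

args = ByteLatentTransformerArgs(encoder_hash_byte_group_size="3,4,5")
print(args.encoder_hash_byte_group_size)   # [3, 4, 5]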
backup_blt_wip_backup/configuration_blt.py ADDED
@@ -0,0 +1,590 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """BLT (Byte Latent Transformer) model configuration"""
16
+
17
+ from enum import Enum
18
+
19
+ from ...configuration_utils import PretrainedConfig
20
+ from ...utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class InitStdFactor(str, Enum):
27
+ DISABLED = "disabled" # Init std is divided by 1.0
28
+ CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)
29
+
30
+
31
+ class PatchingModeEnum(str, Enum):
32
+ entropy = "entropy"
33
+ bpe = "bpe"
34
+ bpe_patcher = "bpe_patcher"
35
+ space = "space"
36
+ static = "static"
37
+ byte = "byte"
38
+
39
+
40
+ class BLTConfig(PretrainedConfig):
41
+ r"""
42
+ This is the configuration class to store the configuration of a [`ByteLatentTransformer`]. It is used to instantiate a
43
+ BLT model according to the specified arguments, defining the model architecture.
44
+
45
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
46
+ documentation from [`PretrainedConfig`] for more information.
47
+
48
+ Args:
49
+ vocab_size (`int`, *optional*, defaults to 256):
50
+ Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented.
51
+ max_seqlen (`int`, *optional*, defaults to 1024):
52
+ The maximum sequence length that this model can handle.
53
+
54
+ # Main architecture dimensions
55
+ dim (`int`, *optional*, defaults to 512):
56
+ Main dimension of the model.
57
+ n_layers (`int`, *optional*, defaults to 8):
58
+ Number of layers in the main transformer.
59
+ n_heads (`int`, *optional*, defaults to 8):
60
+ Number of attention heads in the main transformer.
61
+ head_dim (`int`, *optional*):
62
+ Dimension of each attention head. If not specified, computed as dim // n_heads.
63
+ n_kv_heads (`int`, *optional*):
64
+ Number of key-value heads for grouped query attention. If not specified, defaults to n_heads.
65
+
66
+ # Component-specific dimensions
67
+ dim_global (`int`, *optional*, defaults to 512):
68
+ Dimension of the global transformer component.
69
+ dim_local_decoder (`int`, *optional*, defaults to 512):
70
+ Dimension of the local decoder component.
71
+ dim_local_encoder (`int`, *optional*, defaults to 512):
72
+ Dimension of the local encoder component.
73
+ n_layers_global (`int`, *optional*, defaults to 8):
74
+ Number of layers in the global transformer.
75
+ n_layers_local_decoder (`int`, *optional*, defaults to 8):
76
+ Number of layers in the local decoder.
77
+ n_layers_local_encoder (`int`, *optional*, defaults to 8):
78
+ Number of layers in the local encoder.
79
+ n_heads_global (`int`, *optional*, defaults to 8):
80
+ Number of attention heads in the global transformer.
81
+ n_heads_local_decoder (`int`, *optional*, defaults to 8):
82
+ Number of attention heads in the local decoder.
83
+ n_heads_local_encoder (`int`, *optional*, defaults to 8):
84
+ Number of attention heads in the local encoder.
85
+ n_kv_heads_global (`int`, *optional*):
86
+ Number of key-value heads in the global transformer.
87
+
88
+ # Transformer configuration
89
+ norm_eps (`float`, *optional*, defaults to 1e-5):
90
+ The epsilon used by the layer normalization layers.
91
+ dropout (`float`, *optional*, defaults to 0.0):
92
+ The dropout probability for all fully connected layers.
93
+ ffn_dim_multiplier (`float`, *optional*, defaults to 1.0):
94
+ Multiplier for the feedforward network dimension.
95
+ multiple_of (`int`, *optional*, defaults to 256):
96
+ Make feedforward network dimension multiple of this value.
97
+
98
+ # Positional encoding
99
+ rope_theta (`float`, *optional*, defaults to 10000.0):
100
+ The base period of the RoPE embeddings.
101
+ rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False):
102
+ Whether to use fp32 in RoPE outer product computation.
103
+
104
+ # Attention configuration
105
+ attn_impl (`str`, *optional*, defaults to "sdpa"):
106
+ Attention implementation to use ("sdpa" or "flex_attention").
107
+ attn_bias_type (`str`, *optional*, defaults to "causal"):
108
+ Type of attention bias to apply.
109
+ local_attention_window_len (`int`, *optional*):
110
+ Window length for local attention.
111
+ use_rope (`bool`, *optional*, defaults to True):
112
+ Whether to use rotary position embeddings.
113
+
114
+ # Initialization
115
+ init_base_std (`float`, *optional*):
116
+ Base standard deviation for weight initialization.
117
+ init_std_factor (`str`, *optional*, defaults to "disabled"):
118
+ Factor for adjusting initialization standard deviation.
119
+
120
+ # Embedding dimensions
121
+ dim_token_emb (`int`, *optional*):
122
+ Token embedding dimension.
123
+ dim_token (`int`, *optional*):
124
+ Token dimension.
125
+
126
+ # Patching configuration
127
+ patch_in_forward (`bool`, *optional*, defaults to False):
128
+ Whether to perform patching during forward pass.
129
+ realtime_patching (`bool`, *optional*, defaults to True):
130
+ Whether to use realtime patching.
131
+ patch_size (`float`, *optional*):
132
+ Size of patches for static patching.
133
+ patching_mode (`str`, *optional*):
134
+ Mode for patching ("entropy", "static", etc.).
135
+ patching_threshold (`float`, *optional*):
136
+ Threshold for entropy-based patching.
137
+ patching_threshold_add (`float`, *optional*):
138
+ Additional threshold parameter for patching.
139
+ monotonicity (`bool`, *optional*, defaults to False):
140
+ Whether to enforce monotonicity in patching.
141
+ patching_batch_size (`int`, *optional*, defaults to 1):
142
+ Batch size for patching operations.
143
+ patching_device (`str`, *optional*, defaults to "cuda"):
144
+ Device to use for patching operations.
145
+ max_patch_length (`int`, *optional*):
146
+ Maximum length of patches.
147
+ entropy_model_checkpoint_dir (`str`, *optional*):
148
+ Directory containing entropy model checkpoint.
149
+
150
+ # Cross attention configurations
151
+ cross_attn_encoder (`bool`, *optional*, defaults to False):
152
+ Whether to use cross attention in encoder.
153
+ cross_attn_decoder (`bool`, *optional*, defaults to False):
154
+ Whether to use cross attention in decoder.
155
+ cross_attn_window_encoder (`int`, *optional*):
156
+ Cross attention window for encoder.
157
+ cross_attn_window_decoder (`int`, *optional*):
158
+ Cross attention window for decoder.
159
+ cross_attn_k (`int`, *optional*):
160
+ Number of cross attention components.
161
+ cross_attn_nheads (`int`, *optional*):
162
+ Number of heads for cross attention.
163
+ cross_attn_all_layers_decoder (`bool`, *optional*, defaults to False):
164
+ Whether to apply cross attention to all decoder layers.
165
+ cross_attn_all_layers_encoder (`bool`, *optional*, defaults to False):
166
+ Whether to apply cross attention to all encoder layers.
167
+ cross_attn_use_flex_attention (`bool`, *optional*, defaults to True):
168
+ Whether to use flexible attention for cross attention.
169
+ cross_attn_init_by_pooling (`bool`, *optional*, defaults to False):
170
+ Whether to initialize cross attention by pooling.
171
+
172
+ # Encoder configurations
173
+ use_local_encoder_transformer (`bool`, *optional*, defaults to False):
174
+ Whether to use transformer in local encoder.
175
+ max_encoder_seq_length (`int`, *optional*):
176
+ Maximum sequence length for encoder.
177
+ encoder_hash_byte_group_size (`Any`, *optional*):
178
+ Hash byte group size for encoder.
179
+ encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 30000):
180
+ Vocabulary size for hash byte groups.
181
+ encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 3):
182
+ Number of hash functions for byte groups.
183
+ encoder_enable_byte_ngrams (`bool`, *optional*, defaults to False):
184
+ Whether to enable byte n-grams in encoder.
185
+ encoder_ngram_to_size_str (`str`, *optional*):
186
+ String defining n-gram sizes.
187
+ downsampling_by_pooling (`str`, *optional*):
188
+ Type of pooling for downsampling.
189
+
190
+ # Model behavior
191
+ share_encoder_decoder_emb (`bool`, *optional*, defaults to True):
192
+ Whether to share encoder and decoder embeddings.
193
+ weight_tying (`bool`, *optional*, defaults to False):
194
+ Whether to tie input and output embeddings.
195
+
196
+ # Performance optimization
197
+ sequence_parallel (`bool`, *optional*, defaults to False):
198
+ Whether to use sequence parallelism.
199
+ loss_parallel (`bool`, *optional*, defaults to False):
200
+ Whether to use loss parallelism.
201
+ fuse_sequence_parallel (`bool`, *optional*, defaults to False):
202
+ Whether to fuse sequence parallel operations.
203
+ use_fsdp (`bool`, *optional*, defaults to True):
204
+ Whether to use fully sharded data parallel.
205
+
206
+ # Parameter mixing
207
+ pm_size (`int`, *optional*, defaults to 0):
208
+ Parameter mixing size.
209
+
210
+ # Special tokens
211
+ bos_token_id (`int`, *optional*, defaults to 1):
212
+ The id of the "beginning-of-sequence" token.
213
+ eos_token_id (`int`, *optional*, defaults to 2):
214
+ The id of the "end-of-sequence" token.
215
+ pad_token_id (`int`, *optional*, defaults to -1):
216
+ The id of the padding token.
217
+
218
+ # Patcher/Entropy model configuration
219
+ patcher_vocab_size (`int`, *optional*, defaults to 256):
220
+ Vocabulary size for the entropy model used in patching.
221
+ patcher_dim (`int`, *optional*, defaults to 512):
222
+ Hidden dimension for the entropy model.
223
+ patcher_n_layers (`int`, *optional*, defaults to 8):
224
+ Number of layers in the entropy model.
225
+ patcher_n_heads (`int`, *optional*, defaults to 8):
226
+ Number of attention heads in the entropy model.
227
+ patcher_head_dim (`int`, *optional*):
228
+ Dimension of each attention head in the entropy model.
229
+ patcher_n_kv_heads (`int`, *optional*):
230
+ Number of key-value heads in the entropy model.
231
+ patcher_max_seqlen (`int`, *optional*, defaults to 1024):
232
+ Maximum sequence length for the entropy model.
233
+ patcher_norm_eps (`float`, *optional*, defaults to 1e-5):
234
+ Layer normalization epsilon for the entropy model.
235
+ patcher_dropout (`float`, *optional*, defaults to 0.0):
236
+ Dropout probability for the entropy model.
237
+ patcher_sliding_window (`int`, *optional*):
238
+ Sliding window size for the entropy model attention.
239
+ patcher_ffn_dim_multiplier (`float`, *optional*):
240
+ Feedforward dimension multiplier for the entropy model.
241
+ patcher_multiple_of (`int`, *optional*, defaults to 256):
242
+ Make feedforward dimension multiple of this for the entropy model.
243
+ patcher_rope_theta (`float`, *optional*, defaults to 10000.0):
244
+ RoPE theta parameter for the entropy model.
245
+ patcher_rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False):
246
+ Whether to use fp32 in RoPE outer product for the entropy model.
247
+ patcher_attn_impl (`str`, *optional*, defaults to "sdpa"):
248
+ Attention implementation for the entropy model.
249
+ patcher_attn_bias_type (`str`, *optional*, defaults to "causal"):
250
+ Attention bias type for the entropy model.
251
+ patcher_init_base_std (`float`, *optional*):
252
+ Base initialization standard deviation for the entropy model.
253
+ patcher_init_std_factor (`str`, *optional*, defaults to "disabled"):
254
+ Initialization std factor for the entropy model.
255
+ patcher_dim_token_emb (`int`, *optional*):
256
+ Token embedding dimension for the entropy model.
257
+ patcher_weight_tying (`bool`, *optional*, defaults to False):
258
+ Whether to tie embeddings in the entropy model.
259
+ patcher_bos_token_id (`int`, *optional*, defaults to 1):
260
+ Beginning of sequence token id for the entropy model.
261
+ patcher_eos_token_id (`int`, *optional*, defaults to 2):
262
+ End of sequence token id for the entropy model.
263
+
264
+ ```python
265
+ >>> from transformers import ByteLatentTransformer, BLTConfig
266
+
267
+ >>> # Initializing a BLT configuration
268
+ >>> configuration = BLTConfig()
269
+
270
+ >>> # Initializing a model from the configuration
271
+ >>> model = ByteLatentTransformer(configuration)
272
+
273
+ >>> # Accessing the model configuration
274
+ >>> configuration = model.config
275
+ ```"""
276
+
277
+ model_type = "blt"
278
+ keys_to_ignore_at_inference = ["past_key_values"]
279
+
280
+ def __init__(
281
+ self,
282
+ vocab_size=256,
283
+ max_seqlen=1024,
284
+ # Main architecture dimensions
285
+ dim=512,
286
+ n_layers=8,
287
+ n_heads=8,
288
+ head_dim=None,
289
+ n_kv_heads=None,
290
+ # Component-specific dimensions
291
+ dim_global=512,
292
+ dim_local_decoder=512,
293
+ dim_local_encoder=512,
294
+ n_layers_global=8,
295
+ n_layers_local_decoder=8,
296
+ n_layers_local_encoder=8,
297
+ n_heads_global=8,
298
+ n_heads_local_decoder=8,
299
+ n_heads_local_encoder=8,
300
+ n_kv_heads_global=None,
301
+ # Transformer configuration
302
+ norm_eps=1e-5,
303
+ dropout=0.0,
304
+ ffn_dim_multiplier=1.0,
305
+ multiple_of=256,
306
+ # Positional encoding
307
+ rope_theta=10000.0,
308
+ rope_use_fp32_in_outer_product=False,
309
+ # Attention configuration
310
+ attn_impl="sdpa",
311
+ attn_bias_type="causal",
312
+ local_attention_window_len=None,
313
+ use_rope=True,
314
+ # Initialization
315
+ init_base_std=None,
316
+ init_std_factor="disabled",
317
+ # Embedding dimensions
318
+ dim_token_emb=None,
319
+ dim_token=None,
320
+ # Patching configuration
321
+ patch_in_forward=False,
322
+ realtime_patching=True,
323
+ patch_size=None,
324
+ patching_mode=None,
325
+ patching_threshold=None,
326
+ patching_threshold_add=None,
327
+ monotonicity=False,
328
+ patching_batch_size=1,
329
+ patching_device="cuda",
330
+ max_patch_length=None,
331
+ entropy_model_checkpoint_dir=None,
332
+ # Cross attention configurations
333
+ cross_attn_encoder=False,
334
+ cross_attn_decoder=False,
335
+ cross_attn_window_encoder=None,
336
+ cross_attn_window_decoder=None,
337
+ cross_attn_k=None,
338
+ cross_attn_nheads=None,
339
+ cross_attn_all_layers_decoder=False,
340
+ cross_attn_all_layers_encoder=False,
341
+ cross_attn_use_flex_attention=True,
342
+ cross_attn_init_by_pooling=False,
343
+ # Encoder configurations
344
+ use_local_encoder_transformer=False,
345
+ max_encoder_seq_length=None,
346
+ encoder_hash_byte_group_size=None,
347
+ encoder_hash_byte_group_vocab=30000,
348
+ encoder_hash_byte_group_nb_functions=3,
349
+ encoder_enable_byte_ngrams=False,
350
+ encoder_ngram_to_size_str=None,
351
+ downsampling_by_pooling=None,
352
+ # Model behavior
353
+ share_encoder_decoder_emb=True,
354
+ weight_tying=False,
355
+ # Performance optimization
356
+ sequence_parallel=False,
357
+ loss_parallel=False,
358
+ fuse_sequence_parallel=False,
359
+ use_fsdp=True,
360
+ # Parameter mixing
361
+ pm_size=0,
362
+ # Special tokens
363
+ bos_token_id=1,
364
+ eos_token_id=2,
365
+ pad_token_id=-1,
366
+ # Patcher/Entropy model configuration
367
+ patcher_vocab_size=256,
368
+ patcher_dim=512,
369
+ patcher_n_layers=8,
370
+ patcher_n_heads=8,
371
+ patcher_head_dim=None,
372
+ patcher_n_kv_heads=None,
373
+ patcher_max_seqlen=1024,
374
+ patcher_norm_eps=1e-5,
375
+ patcher_dropout=0.0,
376
+ patcher_sliding_window=None,
377
+ patcher_ffn_dim_multiplier=None,
378
+ patcher_multiple_of=256,
379
+ patcher_rope_theta=10000.0,
380
+ patcher_rope_use_fp32_in_outer_product=False,
381
+ patcher_attn_impl="sdpa",
382
+ patcher_attn_bias_type="causal",
383
+ patcher_init_base_std=None,
384
+ patcher_init_std_factor="disabled",
385
+ patcher_dim_token_emb=None,
386
+ patcher_weight_tying=False,
387
+ patcher_bos_token_id=1,
388
+ patcher_eos_token_id=2,
389
+ # Inherited
390
+ **kwargs,
391
+ ):
392
+ # Basic model configuration
393
+ self.vocab_size = vocab_size
394
+ self.max_seqlen = max_seqlen
395
+
396
+ # Main architecture dimensions
397
+ self.dim = dim
398
+ self.n_layers = n_layers
399
+ self.n_heads = n_heads
400
+ self.head_dim = head_dim
401
+ self.n_kv_heads = n_kv_heads
402
+
403
+ # Component-specific dimensions
404
+ self.dim_global = dim_global
405
+ self.dim_local_decoder = dim_local_decoder
406
+ self.dim_local_encoder = dim_local_encoder
407
+ self.n_layers_global = n_layers_global
408
+ self.n_layers_local_decoder = n_layers_local_decoder
409
+ self.n_layers_local_encoder = n_layers_local_encoder
410
+ self.n_heads_global = n_heads_global
411
+ self.n_heads_local_decoder = n_heads_local_decoder
412
+ self.n_heads_local_encoder = n_heads_local_encoder
413
+ self.n_kv_heads_global = n_kv_heads_global
414
+
415
+ # Transformer configuration
416
+ self.norm_eps = norm_eps
417
+ self.dropout = dropout
418
+ self.ffn_dim_multiplier = ffn_dim_multiplier
419
+ self.multiple_of = multiple_of
420
+
421
+ # Positional encoding
422
+ self.rope_theta = rope_theta
423
+ self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
424
+
425
+ # Attention configuration
426
+ self.attn_impl = attn_impl
427
+ self.attn_bias_type = attn_bias_type
428
+ self.local_attention_window_len = local_attention_window_len
429
+ self.use_rope = use_rope
430
+
431
+ # Initialization
432
+ self.init_base_std = init_base_std
433
+ self.init_std_factor = InitStdFactor(init_std_factor)
434
+
435
+ # Embedding dimensions
436
+ self.dim_token_emb = dim_token_emb
437
+ self.dim_token = dim_token
438
+
439
+ # Patching configuration
440
+ self.patch_in_forward = patch_in_forward
441
+ self.realtime_patching = realtime_patching
442
+ self.patch_size = patch_size
443
+ self.patching_mode = patching_mode
444
+ self.patching_threshold = patching_threshold
445
+ self.patching_threshold_add = patching_threshold_add
446
+ self.monotonicity = monotonicity
447
+ self.patching_batch_size = patching_batch_size
448
+ self.patching_device = patching_device
449
+ self.max_patch_length = max_patch_length
450
+ self.entropy_model_checkpoint_dir = entropy_model_checkpoint_dir
451
+
452
+ # Cross attention configurations
453
+ self.cross_attn_encoder = cross_attn_encoder
454
+ self.cross_attn_decoder = cross_attn_decoder
455
+ self.cross_attn_window_encoder = cross_attn_window_encoder
456
+ self.cross_attn_window_decoder = cross_attn_window_decoder
457
+ self.cross_attn_k = cross_attn_k
458
+ self.cross_attn_nheads = cross_attn_nheads
459
+ self.cross_attn_all_layers_decoder = cross_attn_all_layers_decoder
460
+ self.cross_attn_all_layers_encoder = cross_attn_all_layers_encoder
461
+ self.cross_attn_use_flex_attention = cross_attn_use_flex_attention
462
+ self.cross_attn_init_by_pooling = cross_attn_init_by_pooling
463
+
464
+ # Encoder configurations
465
+ self.use_local_encoder_transformer = use_local_encoder_transformer
466
+ self.max_encoder_seq_length = max_encoder_seq_length
467
+ self.encoder_hash_byte_group_size = encoder_hash_byte_group_size
468
+ self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab
469
+ self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions
470
+ self.encoder_enable_byte_ngrams = encoder_enable_byte_ngrams
471
+ self.encoder_ngram_to_size_str = encoder_ngram_to_size_str
472
+ self.downsampling_by_pooling = downsampling_by_pooling
473
+
474
+ # Model behavior
475
+ self.share_encoder_decoder_emb = share_encoder_decoder_emb
476
+ self.weight_tying = weight_tying
477
+
478
+ # Performance optimization
479
+ self.sequence_parallel = sequence_parallel
480
+ self.loss_parallel = loss_parallel
481
+ self.fuse_sequence_parallel = fuse_sequence_parallel
482
+ self.use_fsdp = use_fsdp
483
+
484
+ # Parameter mixing
485
+ self.pm_size = pm_size
486
+
487
+ # Patcher/Entropy model configuration
488
+ self.patcher_vocab_size = patcher_vocab_size
489
+ self.patcher_dim = patcher_dim
490
+ self.patcher_n_layers = patcher_n_layers
491
+ self.patcher_n_heads = patcher_n_heads
492
+ self.patcher_head_dim = patcher_head_dim
493
+ self.patcher_n_kv_heads = patcher_n_kv_heads
494
+ self.patcher_max_seqlen = patcher_max_seqlen
495
+ self.patcher_norm_eps = patcher_norm_eps
496
+ self.patcher_dropout = patcher_dropout
497
+ self.patcher_sliding_window = patcher_sliding_window
498
+ self.patcher_ffn_dim_multiplier = patcher_ffn_dim_multiplier
499
+ self.patcher_multiple_of = patcher_multiple_of
500
+ self.patcher_rope_theta = patcher_rope_theta
501
+ self.patcher_rope_use_fp32_in_outer_product = patcher_rope_use_fp32_in_outer_product
502
+ self.patcher_attn_impl = patcher_attn_impl
503
+ self.patcher_attn_bias_type = patcher_attn_bias_type
504
+ self.patcher_init_base_std = patcher_init_base_std
505
+ self.patcher_init_std_factor = InitStdFactor(patcher_init_std_factor)
506
+ self.patcher_dim_token_emb = patcher_dim_token_emb
507
+ self.patcher_weight_tying = patcher_weight_tying
508
+ self.patcher_bos_token_id = patcher_bos_token_id
509
+ self.patcher_eos_token_id = patcher_eos_token_id
510
+
511
+ # Parse a comma-separated string of hash byte group sizes into a list of ints
512
+ if self.encoder_hash_byte_group_size is not None and isinstance(self.encoder_hash_byte_group_size, str):
513
+ self.encoder_hash_byte_group_size = [
514
+ int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0
515
+ ]
516
+
517
+ super().__init__(
518
+ bos_token_id=bos_token_id,
519
+ eos_token_id=eos_token_id,
520
+ pad_token_id=pad_token_id,
521
+ **kwargs,
522
+ )
523
+
524
+ @property
525
+ def encoder_dim_token_emb(self):
526
+ """Compute encoder token embedding dimension."""
527
+ if self.dim_token is not None:
528
+ return self.dim_token
529
+ elif self.use_local_encoder_transformer:
530
+ return self.dim_local_encoder
531
+ else:
532
+ # Use default patch_size of 8 if not set
533
+ patch_size = self.patch_size if self.patch_size is not None else 8
534
+ return self.dim_global // patch_size
535
+
536
+ @property
537
+ def encoder_dim_patch_emb(self):
538
+ """Compute encoder patch embedding dimension."""
539
+ if self.cross_attn_encoder:
540
+ if self.cross_attn_init_by_pooling:
541
+ return self.dim_local_encoder
542
+ else:
543
+ return self.dim_global
544
+ return None
545
+
546
+ @property
547
+ def global_dim_patch_emb(self):
548
+ """Compute global patch embedding dimension."""
549
+ dim_token_emb = self.encoder_dim_token_emb
550
+ if self.cross_attn_encoder:
551
+ cross_attn_k = self.cross_attn_k if self.cross_attn_k is not None else 1
552
+ return dim_token_emb * cross_attn_k
553
+ elif (
554
+ self.downsampling_by_pooling is None
555
+ or not self.downsampling_by_pooling
556
+ or len(self.downsampling_by_pooling) == 0
557
+ ):
558
+ # Use default patch_size of 8 if not set
559
+ patch_size = self.patch_size if self.patch_size is not None else 8
560
+ return dim_token_emb * patch_size
561
+ else:
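+ # e.g., downsampling_by_pooling="avg,max" matches two of the three pooling modes, giving 2 * dim_token_emb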
562
+ return dim_token_emb * sum([pooling in self.downsampling_by_pooling for pooling in ["avg", "min", "max"]])
563
+
564
+ @property
565
+ def decoder_dim_token_emb(self):
566
+ """Compute decoder token embedding dimension."""
567
+ if self.share_encoder_decoder_emb:
568
+ return self.encoder_dim_token_emb
569
+ elif self.dim_token is not None:
570
+ return self.dim_token
571
+ else:
572
+ return self.dim_local_decoder
573
+
574
+ def get_init_std_factor(self, depth: int) -> float:
575
+ """
576
+ Calculate the initialization standard deviation scaling factor for a given layer depth.
577
+
578
+ Args:
579
+ depth: Current layer depth (0-indexed)
580
+
581
+ Returns:
582
+ Scaling factor to divide the base initialization std by
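+ Example: with init_std_factor=InitStdFactor.CURRENT_DEPTH and depth=3 this returns (2 * 4) ** 0.5 ≈ 2.83; with DISABLED it returns 1.0.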
583
+ """
584
+ if self.init_std_factor == InitStdFactor.CURRENT_DEPTH:
585
+ return (2 * (depth + 1)) ** 0.5
586
+ else: # DISABLED
587
+ return 1.0
588
+
589
+
590
+ __all__ = ["BLTConfig", "InitStdFactor", "PatchingModeEnum"]
backup_blt_wip_backup/convert_hf_blt_original_to_unified.py ADDED
@@ -0,0 +1,540 @@
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ from typing import Dict, Any, Optional
6
+
7
+ import torch
8
+ from huggingface_hub import hf_hub_download, snapshot_download
9
+ from safetensors.torch import load_file, save_file
10
+
11
+ from transformers.utils import logging as transformers_logging
12
+
13
+ logger = transformers_logging.get_logger(__name__)
14
+ transformers_logging.set_verbosity_info()
15
+
16
+ # Model validation requires the in-tree BLT classes; if they cannot be imported, fall back gracefully
17
+ # so the script still writes the unified config and weights without testing model instantiation.
18
+ try:
19
+ from transformers.models.blt_wip.modeling_blt_wip import BLTModel
20
+ from transformers.models.blt_wip.configuration_blt import BLTConfig
21
+
22
+ ENABLE_MODEL_VALIDATION = True
23
+ except ImportError:
24
+ BLTModel = None
25
+ BLTConfig = None
26
+
27
+ ENABLE_MODEL_VALIDATION = False
28
+
29
+ def download_model_files(model_id: str, cache_dir: Optional[str] = None) -> Dict[str, str]:
30
+ """
31
+ Download all necessary files from HuggingFace Hub.
32
+
33
+ Args:
34
+ model_id: HuggingFace model ID (e.g., "facebook/blt-1b")
35
+ cache_dir: Optional cache directory
36
+
37
+ Returns:
38
+ Dictionary with paths to downloaded files
39
+ """
40
+ logger.info(f"Downloading model files from {model_id}...")
41
+
42
+ try:
43
+ # Download main config
44
+ config_path = hf_hub_download(
45
+ repo_id=model_id,
46
+ filename="config.json",
47
+ cache_dir=cache_dir
48
+ )
49
+
50
+ # Download main model weights
51
+ weights_path = hf_hub_download(
52
+ repo_id=model_id,
53
+ filename="model.safetensors",
54
+ cache_dir=cache_dir
55
+ )
56
+
57
+ # Download entropy model params
58
+ entropy_params_path = hf_hub_download(
59
+ repo_id=model_id,
60
+ filename="entropy_model/params.json",
61
+ cache_dir=cache_dir
62
+ )
63
+
64
+ # Download entropy model weights
65
+ entropy_weights_path = hf_hub_download(
66
+ repo_id=model_id,
67
+ filename="entropy_model/consolidated.pth",
68
+ cache_dir=cache_dir
69
+ )
70
+
71
+ return {
72
+ "config": config_path,
73
+ "weights": weights_path,
74
+ "entropy_params": entropy_params_path,
75
+ "entropy_weights": entropy_weights_path
76
+ }
77
+
78
+ except Exception as e:
79
+ logger.error(f"Failed to download files from {model_id}: {e}")
80
+ raise
81
+
82
+
83
+ def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]:
84
+ """
85
+ Merge main configuration with entropy model parameters.
86
+
87
+ Args:
88
+ config_path: Path to main config.json
89
+ entropy_params_path: Path to entropy_model/params.json
90
+
91
+ Returns:
92
+ Merged configuration dictionary
93
+ """
94
+ logger.info("Merging configurations...")
95
+
96
+ # Load main configuration
97
+ with open(config_path, 'r') as f:
98
+ main_config = json.load(f)
99
+
100
+ # Load entropy model parameters
101
+ with open(entropy_params_path, 'r') as f:
102
+ entropy_data = json.load(f)
103
+
104
+ # Extract entropy model and patcher parameters
105
+ entropy_model_params = entropy_data.get("entropy_model", {})
106
+ patcher_args = entropy_data.get("data", {}).get("patcher_args", {})
107
+
108
+ # Create unified configuration
109
+ unified_config = main_config.copy()
110
+
111
+ # Ensure required main model parameters are present with correct types
112
+ # Sometimes the original config may have different key names
113
+ if "vocab_size" not in unified_config:
114
+ unified_config["vocab_size"] = int(main_config.get("vocab_size", 256))
115
+ if "dim" not in unified_config:
116
+ unified_config["dim"] = int(main_config.get("dim", main_config.get("hidden_size", main_config.get("d_model", 512))))
117
+ if "n_layers" not in unified_config:
118
+ unified_config["n_layers"] = int(main_config.get("n_layers", main_config.get("num_layers", main_config.get("num_hidden_layers", 8))))
119
+ if "n_heads" not in unified_config:
120
+ unified_config["n_heads"] = int(main_config.get("n_heads", main_config.get("num_attention_heads", main_config.get("num_heads", 8))))
121
+ if "max_seqlen" not in unified_config:
122
+ unified_config["max_seqlen"] = int(main_config.get("max_seqlen", main_config.get("max_position_embeddings", main_config.get("seq_length", 1024))))
123
+
124
+ # Ensure other integer parameters are properly typed
125
+ for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]:
126
+ if key in unified_config and not isinstance(unified_config[key], int):
127
+ unified_config[key] = int(unified_config[key])
128
+
129
+ # Convert all patch_size values to integers to avoid float/int type errors
130
+ patch_size = patcher_args.get("patch_size", 8)
131
+ if isinstance(patch_size, float):
132
+ patch_size = int(patch_size)
133
+
134
+ # Add patching configuration
135
+ unified_config.update({
136
+ "patch_in_forward": True,
137
+ "realtime_patching": True,
138
+ "patching_mode": "entropy",
139
+
140
+ # Patcher arguments
141
+ "patch_size": patch_size,
142
+ "patching_threshold": patcher_args.get("threshold", 0.5),
143
+ "patching_threshold_add": patcher_args.get("threshold_add", 0.0),
144
+ "max_patch_length": patcher_args.get("max_patch_length"),
145
+ "patching_batch_size": patcher_args.get("patching_batch_size", 1),
146
+ "patching_device": patcher_args.get("patching_device", "cuda"),
147
+ "monotonicity": patcher_args.get("monotonicity", False),
148
+
149
+ # Entropy model (patcher) architecture parameters
150
+ "patcher_vocab_size": int(entropy_model_params.get("vocab_size", 256)),
151
+ "patcher_dim": int(entropy_model_params.get("dim", 512)),
152
+ "patcher_n_layers": int(entropy_model_params.get("n_layers", 8)),
153
+ "patcher_n_heads": int(entropy_model_params.get("n_heads", 8)),
154
+ "patcher_head_dim": int(entropy_model_params.get("head_dim")) if entropy_model_params.get("head_dim") is not None else None,
155
+ "patcher_n_kv_heads": int(entropy_model_params.get("n_kv_heads")) if entropy_model_params.get("n_kv_heads") is not None else None,
156
+ "patcher_max_seqlen": int(entropy_model_params.get("max_seqlen", 1024)),
157
+ "patcher_norm_eps": entropy_model_params.get("norm_eps", 1e-5),
158
+ "patcher_dropout": entropy_model_params.get("dropout", 0.0),
159
+ "patcher_sliding_window": int(entropy_model_params.get("sliding_window", 512)) if entropy_model_params.get("sliding_window") is not None else None,
160
+ "patcher_ffn_dim_multiplier": entropy_model_params.get("ffn_dim_multiplier"),
161
+ "patcher_multiple_of": int(entropy_model_params.get("multiple_of", 256)),
162
+ "patcher_rope_theta": entropy_model_params.get("rope_theta", 10000.0),
163
+ "patcher_rope_use_fp32_in_outer_product": entropy_model_params.get("rope_use_fp32_in_outer_product", False),
164
+ "patcher_attn_impl": entropy_model_params.get("attn_impl", "sdpa"),
165
+ "patcher_attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"),
166
+ "patcher_init_base_std": entropy_model_params.get("init_base_std"),
167
+ "patcher_init_std_factor": entropy_model_params.get("init_std_factor", "disabled"),
168
+ "patcher_dim_token_emb": entropy_model_params.get("dim_token_emb"),
169
+ "patcher_weight_tying": entropy_model_params.get("weight_tying", False),
170
+ "patcher_bos_token_id": entropy_model_params.get("bos_token_id", 1),
171
+ "patcher_eos_token_id": entropy_model_params.get("eos_token_id", 2),
172
+ })
173
+
174
+ logger.info(f"Merged configuration with {len(unified_config)} parameters")
175
+ return unified_config
176
+
177
+
178
+ def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]:
179
+ """
180
+ Merge main model weights with entropy model weights.
181
+
182
+ Args:
183
+ weights_path: Path to main model.safetensors
184
+ entropy_weights_path: Path to entropy_model/consolidated.pth
185
+
186
+ Returns:
187
+ Merged state dictionary
188
+ """
189
+ logger.info("Merging model weights...")
190
+
191
+ # Load main model weights
192
+ main_weights = load_file(weights_path)
193
+ logger.info(f"Loaded main model weights: {len(main_weights)} tensors")
194
+
195
+ # Load entropy model weights
196
+ entropy_weights = torch.load(entropy_weights_path, map_location='cpu', weights_only=True)
197
+
198
+ # Handle nested entropy model structure
199
+ if 'model' in entropy_weights:
200
+ entropy_weights = entropy_weights['model']
201
+ elif 'state_dict' in entropy_weights:
202
+ entropy_weights = entropy_weights['state_dict']
203
+
204
+ logger.info(f"Loaded entropy model weights: {len(entropy_weights)} tensors")
205
+
206
+ # Create unified state dict
207
+ unified_weights = main_weights.copy()
208
+
209
+ # Add entropy model weights with "patcher." prefix
210
+ for key, tensor in entropy_weights.items():
211
+ patcher_key = f"patcher.{key}"
212
+ unified_weights[patcher_key] = tensor
213
+
214
+ logger.info(f"Merged weights: {len(unified_weights)} tensors total")
215
+ return unified_weights
216
+
217
+
218
+ def create_tokenizer_config(output_dir: str, config: Dict[str, Any]):
219
+ """
220
+ Create tokenizer configuration file.
221
+
222
+ Args:
223
+ output_dir: Output directory
224
+ config: Model configuration
225
+ """
226
+ logger.info("Creating tokenizer configuration...")
227
+
228
+ tokenizer_config = {
229
+ "tokenizer_class": "BltTokenizer",
230
+ "vocab_size": config.get("vocab_size", 256),
231
+ "model_max_length": config.get("max_seqlen", 1024),
232
+ "add_bos_token": True,
233
+ "add_eos_token": True,
234
+ "bos_token": "<s>",
235
+ "eos_token": "</s>",
236
+ "pad_token": "<pad>",
237
+ "unk_token": "<unk>",
238
+ }
239
+
240
+ tokenizer_path = os.path.join(output_dir, "tokenizer_config.json")
241
+ with open(tokenizer_path, 'w') as f:
242
+ json.dump(tokenizer_config, f, indent=2)
243
+
244
+ logger.info(f"Tokenizer config saved to {tokenizer_path}")
245
+
246
+
247
+ def validate_unified_model(config: Dict[str, Any], weights: Dict[str, torch.Tensor]):
248
+ """
249
+ Validate the unified model configuration and weights.
250
+
251
+ Args:
252
+ config: Unified configuration
253
+ weights: Unified weights
254
+ """
255
+ logger.info("Validating unified model...")
256
+
257
+ # Check required configuration keys
258
+ required_keys = [
259
+ "vocab_size", "dim", "n_layers", "n_heads",
260
+ "patch_in_forward", "patcher_vocab_size", "patcher_dim"
261
+ ]
262
+
263
+ missing_keys = [key for key in required_keys if key not in config]
264
+ if missing_keys:
265
+ logger.warning(f"Missing configuration keys: {missing_keys}")
266
+
267
+ # Check for patcher weights
268
+ patcher_weights = [key for key in weights.keys() if key.startswith("patcher.")]
269
+ if not patcher_weights:
270
+ logger.warning("No patcher weights found in unified weights")
271
+ else:
272
+ logger.info(f"Found {len(patcher_weights)} patcher weight tensors")
273
+
274
+ # Check for main model weights
275
+ main_weights = [key for key in weights.keys() if not key.startswith("patcher.")]
276
+ logger.info(f"Found {len(main_weights)} main model weight tensors")
277
+
278
+ # Try to create the model with the configuration (if imports are available)
279
+ if ENABLE_MODEL_VALIDATION and BLTConfig is not None and BLTModel is not None:
280
+ try:
281
+ logger.info("Testing model instantiation...")
282
+
283
+ # Debug: Print config keys to help diagnose issues
284
+ logger.debug(f"Config keys: {list(config.keys())}")
285
+ logger.debug(f"Config vocab_size: {config.get('vocab_size')} (type: {type(config.get('vocab_size'))})")
286
+ logger.debug(f"Config dim: {config.get('dim')} (type: {type(config.get('dim'))})")
287
+
288
+ blt_config = BLTConfig(**config)
289
+ model = BLTModel(blt_config)
290
+ logger.info("✓ Model instantiation successful")
291
+
292
+ # Try to load the weights
293
+ logger.info("Testing weight loading...")
294
+ try:
295
+ missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)
296
+ if missing_keys:
297
+ logger.warning(f"Missing keys during weight loading: {missing_keys[:5]}...") # Show first 5
298
+ if unexpected_keys:
299
+ logger.warning(f"Unexpected keys during weight loading: {unexpected_keys[:5]}...") # Show first 5
300
+ logger.info("✓ Weight loading successful")
301
+ except Exception as weight_error:
302
+ logger.warning(f"Weight loading failed: {weight_error}")
303
+ logger.info("Model instantiation successful, but weight loading had issues")
304
+
305
+ except Exception as e:
306
+ logger.error(f"Model validation failed: {e}")
307
+ logger.debug(f"Full error details:", exc_info=True)
308
+ logger.warning("Model may not be compatible with modeling_blt_wip.py")
309
+ logger.info("You can still use the converted files and test manually")
310
+ else:
311
+ logger.info("Skipping model instantiation test (BLT classes not available)")
312
+ logger.info("You can test the model manually after conversion")
313
+
314
+ logger.info("Model validation completed")
315
+
316
+
317
+ def convert_hf_blt_to_unified(
318
+ model_id: str,
319
+ output_dir: str,
320
+ config_name: str = "config.json",
321
+ weights_name: str = "pytorch_model.bin",
322
+ safe_serialization: bool = True,
323
+ cache_dir: Optional[str] = None,
324
+ validate: bool = True,
325
+ ) -> None:
326
+ """
327
+ Convert BLT model from HuggingFace Hub format to unified format.
328
+
329
+ Args:
330
+ model_id: HuggingFace model ID (e.g., "facebook/blt-1b")
331
+ output_dir: Output directory for unified model
332
+ config_name: Name for unified config file
333
+ weights_name: Name for unified weights file
334
+ safe_serialization: Whether to use safetensors format
335
+ cache_dir: Cache directory for downloads
336
+ validate: Whether to validate the unified model
337
+ """
338
+ logger.info(f"Converting {model_id} to unified format...")
339
+
340
+ # Download model files
341
+ file_paths = download_model_files(model_id, cache_dir)
342
+
343
+ # Merge configurations
344
+ unified_config = merge_configurations(
345
+ file_paths["config"],
346
+ file_paths["entropy_params"]
347
+ )
348
+
349
+ # Merge weights
350
+ unified_weights = merge_weights(
351
+ file_paths["weights"],
352
+ file_paths["entropy_weights"]
353
+ )
354
+
355
+ # Validate if requested
356
+ if validate:
357
+ validate_unified_model(unified_config, unified_weights)
358
+
359
+ # Create output directory
360
+ os.makedirs(output_dir, exist_ok=True)
361
+
362
+ # Save unified configuration
363
+ config_path = os.path.join(output_dir, config_name)
364
+ with open(config_path, 'w') as f:
365
+ json.dump(unified_config, f, indent=2)
366
+ logger.info(f"Unified config saved to {config_path}")
367
+
368
+ # Save unified weights
369
+ if safe_serialization and weights_name.endswith('.bin'):
370
+ weights_name = weights_name.replace('.bin', '.safetensors')
371
+ elif not safe_serialization and weights_name.endswith('.safetensors'):
372
+ weights_name = weights_name.replace('.safetensors', '.bin')
373
+
374
+ weights_path = os.path.join(output_dir, weights_name)
375
+ if safe_serialization:
376
+ save_file(unified_weights, weights_path)
377
+ else:
378
+ torch.save(unified_weights, weights_path)
379
+ logger.info(f"Unified weights saved to {weights_path}")
380
+
381
+ # Create tokenizer config
382
+ create_tokenizer_config(output_dir, unified_config)
383
+
384
+ # Create README
385
+ readme_path = os.path.join(output_dir, "README.md")
386
+ with open(readme_path, 'w') as f:
387
+ f.write(f"""# Unified BLT Model
388
+
389
+ This model was converted from {model_id} to unified format compatible with modeling_blt_wip.py.
390
+
391
+ ## Files
392
+
393
+ - `{config_name}`: Unified configuration (main config + entropy model params)
394
+ - `{weights_name}`: Unified weights (main model + entropy model weights with "patcher." prefix)
395
+ - `tokenizer_config.json`: Tokenizer configuration
396
+
397
+ ## Usage
398
+
399
+ ```python
400
+ import torch
401
+ import json
402
+ from modeling_blt_wip import BLTModel, BLTConfig
403
+
404
+ # Load configuration
405
+ with open('{config_name}', 'r') as f:
406
+ config_dict = json.load(f)
407
+
408
+ config = BLTConfig(**config_dict)
409
+
410
+ # Load model
411
+ model = BLTModel(config)
412
+
413
+ # Load weights
414
+ if '{weights_name}'.endswith('.safetensors'):
415
+ from safetensors.torch import load_file
416
+ state_dict = load_file('{weights_name}')
417
+ else:
418
+ state_dict = torch.load('{weights_name}', map_location='cpu')
419
+
420
+ model.load_state_dict(state_dict, strict=False)
421
+ ```
422
+
423
+ ## Original Model
424
+
425
+ Converted from: {model_id}
426
+ """)
427
+
428
+ logger.info(f"Conversion completed! Unified model saved to: {output_dir}")
429
+
430
+
431
+ def main():
432
+ parser = argparse.ArgumentParser(
433
+ description="Convert BLT models from HuggingFace Hub format to unified format",
434
+ formatter_class=argparse.RawDescriptionHelpFormatter,
435
+ epilog="""
436
+ Examples:
437
+ # Convert facebook/blt-1b to unified format
438
+ python convert_hf_blt_original_to_unified.py \\
439
+ --model_id facebook/blt-1b \\
440
+ --output_dir ./unified_blt_1b
441
+
442
+ # Convert with custom file names
443
+ python convert_hf_blt_original_to_unified.py \\
444
+ --model_id facebook/blt-7b \\
445
+ --output_dir ./unified_blt_7b \\
446
+ --config_name unified_config.json \\
447
+ --weights_name unified_model.safetensors
448
+
449
+ # Convert without validation
450
+ python convert_hf_blt_original_to_unified.py \\
451
+ --model_id facebook/blt-1b \\
452
+ --output_dir ./my_blt \\
453
+ --no_validate
454
+ """
455
+ )
456
+
457
+ # Required arguments (with defaults for debugging)
458
+ parser.add_argument(
459
+ "--model_id",
460
+ type=str,
461
+ default="facebook/blt-1b",
462
+ help="HuggingFace model ID (e.g., facebook/blt-1b)"
463
+ )
464
+ parser.add_argument(
465
+ "--output_dir",
466
+ type=str,
467
+ default="./unified_blt_debug",
468
+ help="Output directory for unified model"
469
+ )
470
+
471
+ # Optional arguments
472
+ parser.add_argument(
473
+ "--config_name",
474
+ type=str,
475
+ default="config.json",
476
+ help="Name for unified config file (default: config.json)"
477
+ )
478
+ parser.add_argument(
479
+ "--weights_name",
480
+ type=str,
481
+ default="pytorch_model.bin",
482
+ help="Name for unified weights file (default: pytorch_model.bin)"
483
+ )
484
+ parser.add_argument(
485
+ "--safe_serialization",
486
+ action="store_true",
487
+ default=True,
488
+ help="Use safetensors format for weights (default: True)"
489
+ )
490
+ parser.add_argument(
491
+ "--no_safe_serialization",
492
+ dest="safe_serialization",
493
+ action="store_false",
494
+ help="Use .bin format instead of safetensors"
495
+ )
496
+ parser.add_argument(
497
+ "--cache_dir",
498
+ type=str,
499
+ default=None,
500
+ help="Cache directory for downloads"
501
+ )
502
+ parser.add_argument(
503
+ "--no_validate",
504
+ dest="validate",
505
+ action="store_false",
506
+ default=True,
507
+ help="Skip model validation"
508
+ )
509
+ parser.add_argument(
510
+ "--debug",
511
+ action="store_true",
512
+ default=True, # Enable debug by default for easier debugging
513
+ help="Enable debug logging"
514
+ )
515
+
516
+ args = parser.parse_args()
517
+
518
+ # Setup logging
519
+ if args.debug:
520
+ transformers_logging.set_verbosity_debug()
521
+ logging.basicConfig(level=logging.DEBUG)
522
+
523
+ # Run conversion
524
+ try:
525
+ convert_hf_blt_to_unified(
526
+ model_id=args.model_id,
527
+ output_dir=args.output_dir,
528
+ config_name=args.config_name,
529
+ weights_name=args.weights_name,
530
+ safe_serialization=args.safe_serialization,
531
+ cache_dir=args.cache_dir,
532
+ validate=args.validate,
533
+ )
534
+ except Exception as e:
535
+ logger.error(f"Conversion failed: {e}")
536
+ raise
537
+
538
+
539
+ if __name__ == "__main__":
540
+ main()
backup_blt_wip_backup/modeling_blt_wip.py ADDED
@@ -0,0 +1,1836 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import logging
4
+ import os
5
+ from typing import List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ import torch.nn
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+ from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
12
+
13
+ from ...modeling_utils import PreTrainedModel
14
+ from .configuration_blt import (
15
+ BLTConfig,
16
+ PatchingModeEnum,
17
+ )
18
+
19
+
20
+ SEP = " "
21
+ BOS_ID: int = 1
22
+ EOS_ID: int = 2
23
+ PAD_ID: int = -1
24
+ BOE_ID: int = 0
25
+ BPE_ID: int = 3
26
+ OFFSET: int = 4
27
+
28
+ BYTE_UNITS: int = 256
29
+
30
+ RMSNorm = nn.RMSNorm
31
+
32
+ logger = logging.getLogger()
33
+
34
+ flex_attention_comp = flex_attention
35
+
36
+
37
+ def causal_mask(b, h, q_idx, kv_idx):
38
+ return q_idx >= kv_idx
39
+
40
+
41
+ def create_causal_mask(
42
+ seqlen,
43
+ attn_impl: str,
44
+ attn_bias_type: str | None,
45
+ *,
46
+ eos_id: int | None = None,
47
+ tokens: torch.Tensor | None = None,
48
+ sliding_window: int | None = None,
49
+ ):
50
+ if attn_impl == "sdpa":
51
+ BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))
52
+
53
+ if attn_bias_type == "causal":
54
+ return "causal"
55
+
56
+ if BLT_SUPPRESS_ATTN_ERROR == 1:
57
+ return "causal"
58
+ else:
59
+ raise ValueError(
60
+ "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
61
+ )
62
+ elif attn_impl == "flex_attention":
63
+ return create_block_mask(causal_mask, None, None, seqlen, seqlen)
64
+ else:
65
+ raise NotImplementedError(f"Attention {attn_impl} with {sliding_window} sliding window not implemented")
66
+
67
+
68
+ def cross_entropy(pred, target, **kwargs):
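+ # Equivalent to F.cross_entropy on the flattened logits, with the log-softmax computed in float32 for numerical stability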
69
+ return F.nll_loss(
70
+ F.log_softmax(pred.flatten(end_dim=-2).float(), -1),
71
+ target.flatten(end_dim=-1),
72
+ **kwargs,
73
+ )
74
+
75
+
76
+ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
77
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
78
+ assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
79
+ bs, slen, n_kv_heads, head_dim = x.shape
80
+ if n_rep == 1:
81
+ return x
82
+ return (
83
+ x[:, :, :, None, :]
84
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
85
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
86
+ )
87
+
88
+
89
+ def precompute_freqs_cis(
90
+ dim: int,
91
+ end: int,
92
+ theta: float = 10000.0,
93
+ rope_use_fp32_in_outer_product: bool = False,
94
+ ):
95
+ """
96
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
97
+
98
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
99
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
100
+ The rotations are stored as real-valued cos/sin entries of 2x2 matrices rather than as complex64 values.
101
+
102
+ Args:
103
+ dim (int): Dimension of the frequency tensor.
104
+ end (int): End index for precomputing frequencies.
105
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
106
+
107
+ Returns:
108
+ torch.Tensor: Precomputed rotation tensor of shape (end, dim // 2, 2, 2).
109
+ """
110
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
111
+ t = torch.arange(end, device=freqs.device)
112
+ if rope_use_fp32_in_outer_product:
113
+ t = t.to(torch.float32)
114
+
115
+ freqs = torch.outer(t, freqs).float()
116
+
117
+ cos, sin = freqs.cos(), freqs.sin()
118
+
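+ # Stack cos/sin into per-position 2x2 rotation matrices: output shape (end, dim // 2, 2, 2)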
119
+ return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
120
+
121
+
122
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
123
+ """
124
+ Reshape frequency tensor for broadcasting it with another tensor.
125
+
126
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
127
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
128
+
129
+ Args:
130
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
131
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
132
+ seq_dim (int): Sequence dimension index.
133
+
134
+ Returns:
135
+ torch.Tensor: Reshaped frequency tensor.
136
+ """
137
+ ndim = x.ndim
138
+ assert 0 <= seq_dim < ndim
139
+ assert freqs_cis.shape == (
140
+ x.shape[seq_dim],
141
+ x.shape[-3],
142
+ 2,
143
+ 2,
144
+ ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
145
+ shape = [d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])] + [2, 2]
146
+ return freqs_cis.view(*shape)
147
+
148
+
149
+ def apply_rotary_emb(
150
+ xq: torch.Tensor,
151
+ xk: torch.Tensor,
152
+ seq_dim: int,
153
+ freqs_cis: torch.Tensor,
154
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
155
+ xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
156
+ xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
157
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, seq_dim).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
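+ # Apply each 2x2 rotation matrix to its (x, y) pair via a broadcast multiply and sum, then flatten back to head_dim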
158
+ xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
159
+ xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
160
+ return xq_out.type_as(xq), xk_out.type_as(xk)
161
+
162
+
163
+ # Rotary embedding as in xformers; check whether the torchtrain implementation is better. It might also be useful to make this work with batch*seqlen collapsed.
164
+ class RotaryEmbedding(torch.nn.Module):
165
+ """
166
+ RotaryEmbedding Module
167
+ """
168
+
169
+ def __init__(
170
+ self,
171
+ theta: float,
172
+ head_dim: int,
173
+ max_seqlen: int = 1024,
174
+ rope_use_fp32_in_outer_product: bool = False,
175
+ ):
176
+ super().__init__()
177
+
178
+ self.theta = theta
179
+ self.head_dim = head_dim
180
+ self.max_seqlen = max_seqlen
181
+ self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
182
+
183
+ self.register_buffer(
184
+ "freqs_cis",
185
+ precompute_freqs_cis(
186
+ dim=head_dim,
187
+ end=max_seqlen,
188
+ theta=theta,
189
+ rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
190
+ ),
191
+ persistent=False,
192
+ )
193
+
194
+
195
+ def forward(self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None):
196
+ """
197
+ Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions
198
+ Args:
199
+ seqlen (int): Contiguous sequence length
200
+ tok_idx (torch.Tensor[int]): Position indices of each token; overrides seqlen when provided
201
+
202
+ Returns:
203
+ torch.Tensor: The freqs_cis entries for the requested positions
204
+ """
205
+ test = (seqlen is not None) or (tok_idx is not None)
206
+ assert test, "Should provide at least seqlen or tok_idx"
207
+ if tok_idx is not None:
208
+ return self.freqs_cis[tok_idx]
209
+ elif seqlen is not None:
210
+ return self.freqs_cis[0:seqlen]
211
+
212
+
213
+ class BLTAttention(nn.Module):
214
+ def __init__(
215
+ self,
216
+ dim: int,
217
+ head_dim: int,
218
+ n_heads: int,
219
+ n_kv_heads: int,
220
+ rope_theta: float,
221
+ ):
222
+ super().__init__()
223
+
224
+ self.dim = dim
225
+ self.head_dim = head_dim
226
+ self.rope_theta = rope_theta
227
+
228
+ self.n_heads = n_heads
229
+ self.n_kv_heads = n_kv_heads
230
+ self.heads_per_group = self.n_heads // self.n_kv_heads
231
+
232
+ self.wq = nn.Linear(
233
+ dim,
234
+ n_heads * head_dim,
235
+ bias=False,
236
+ )
237
+ self.wk = nn.Linear(
238
+ dim,
239
+ n_kv_heads * head_dim,
240
+ bias=False,
241
+ )
242
+ self.wv = nn.Linear(
243
+ dim,
244
+ n_kv_heads * head_dim,
245
+ bias=False,
246
+ )
247
+
248
+ self.wo = nn.Linear(
249
+ n_heads * head_dim,
250
+ dim,
251
+ bias=False,
252
+ )
253
+
254
+ def forward(
255
+ self,
256
+ x: torch.Tensor,
257
+ freq_cis: torch.Tensor,
258
+ tok_idx: Optional[torch.Tensor] = None,
259
+ mask: Optional[Union[BlockMask, str]] = None,
260
+ attn_impl: str = "sdpa",
261
+ ) -> torch.Tensor:
262
+ # B S D
263
+ bsz, seq_len, dim = x.shape
264
+ xq = self.wq(x.view_as(x))
265
+ xk = self.wk(x.view_as(x))
266
+ xv = self.wv(x.view_as(x))
267
+
268
+ output_shape = xq.shape
269
+ # B S D -> B S H D
270
+ xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
271
+ xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
272
+ xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
273
+
274
+ xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
275
+
276
+ # This condition helps us be easily compatible
277
+ # with inference by adding a pluggable KVCache
278
+ if hasattr(self, "kv_cache"):
279
+ xk, xv = self.kv_cache.update(xk, xv, tok_idx)
280
+
281
+ xk = repeat_kv(xk, self.heads_per_group, dim=2)
282
+ xv = repeat_kv(xv, self.heads_per_group, dim=2)
283
+
284
+ if attn_impl == "flex_attention":
285
+ assert mask is None or isinstance(mask, BlockMask)
286
+ xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
287
+ output = flex_attention_comp(xq, xk, xv, block_mask=mask)
288
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
289
+
290
+ elif attn_impl == "sdpa":
291
+ xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
292
+ assert mask is None or isinstance(mask, (str, torch.Tensor))
293
+ is_causal = (mask == "causal") if isinstance(mask, str) else False
294
+ mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None
295
+ output = F.scaled_dot_product_attention(
296
+ xq,
297
+ xk,
298
+ xv,
299
+ is_causal=is_causal,
300
+ attn_mask=mask,
301
+ )
302
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
303
+ else:
304
+ raise NotImplementedError(f"Attention implementation {attn_impl} not supported")
305
+
306
+ output_reshaped = output.reshape(output_shape)
307
+
308
+ output = self.wo(output_reshaped)
309
+
310
+ return output
311
+
312
+
313
+
314
+
315
+ class BLTMLP(nn.Module):
316
+ def __init__(
317
+ self,
318
+ dim: int,
319
+ hidden_dim: int,
320
+ multiple_of: int,
321
+ ffn_dim_multiplier: Optional[float],
322
+ mp_size: int = 1,
323
+ ):
324
+ super().__init__()
325
+
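+ # SwiGLU-style sizing: take 2/3 of the nominal hidden dim, apply the optional multiplier, then round up to a multiple of multiple_of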
326
+ hidden_dim = int(2 * hidden_dim / 3)
327
+ if ffn_dim_multiplier is not None:
328
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
329
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
330
+ assert hidden_dim % mp_size == 0
331
+
332
+ self.dim = dim
333
+ self.hidden_dim = hidden_dim
334
+
335
+ self.w1 = nn.Linear(
336
+ dim,
337
+ hidden_dim,
338
+ bias=False,
339
+ )
340
+ self.w3 = nn.Linear(
341
+ dim,
342
+ hidden_dim,
343
+ bias=False,
344
+ )
345
+ self.w2 = nn.Linear(
346
+ hidden_dim,
347
+ dim,
348
+ bias=False,
349
+ )
350
+
351
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
352
+ # B S D
353
+ x1 = self.w1(x.view_as(x))
354
+ x3 = self.w3(x.view_as(x))
355
+ output = self.w2(F.silu(x1) * x3)
356
+ return output
357
+
358
+
359
+
360
+
361
+ class BLTTransformerLayer(nn.Module):
362
+ def __init__(self, args):
363
+ super().__init__()
364
+
365
+ # Extract parameters from dictionary
366
+ dim = args["dim"]
367
+ n_heads = args["n_heads"]
368
+ head_dim = args["head_dim"]
369
+ n_kv_heads = args["n_kv_heads"]
370
+ rope_theta = args["rope_theta"]
371
+ multiple_of = args["multiple_of"]
372
+ ffn_dim_multiplier = args["ffn_dim_multiplier"]
373
+ norm_eps = args["norm_eps"]
374
+
375
+ assert (head_dim is not None) or (n_heads is not None), "Should specify at least head_dim or n_heads"
376
+ self.head_dim = head_dim or dim // n_heads
377
+ self.n_heads = n_heads or dim // head_dim
378
+ self.n_kv_heads = n_kv_heads or self.n_heads
379
+
380
+ assert n_heads % self.n_kv_heads == 0
381
+ assert dim % n_heads == 0
382
+
383
+ self.attention = BLTAttention(
384
+ dim=dim,
385
+ head_dim=self.head_dim,
386
+ n_heads=self.n_heads,
387
+ n_kv_heads=self.n_kv_heads,
388
+ rope_theta=rope_theta,
389
+ )
390
+ self.feed_forward = BLTMLP(
391
+ dim=dim,
392
+ hidden_dim=4 * dim,
393
+ multiple_of=multiple_of,
394
+ ffn_dim_multiplier=ffn_dim_multiplier,
395
+ )
396
+ self.attention_norm = RMSNorm(dim, eps=norm_eps)
397
+ self.ffn_norm = RMSNorm(dim, eps=norm_eps)
398
+
399
+ def forward(
400
+ self,
401
+ x: torch.Tensor,
402
+ freq_cis: torch.Tensor,
403
+ tok_idx: Optional[torch.Tensor] = None,
404
+ mask: Optional[Union[BlockMask, str]] = None,
405
+ attn_impl: str = "sdpa",
406
+ ) -> torch.Tensor:
407
+ norm_x = self.attention_norm(x)
408
+ attn_out = self.attention(
409
+ norm_x,
410
+ freq_cis,
411
+ tok_idx=tok_idx,
412
+ mask=mask,
413
+ attn_impl=attn_impl,
414
+ )
415
+ h = x + attn_out
416
+ h_norm = self.ffn_norm(h)
417
+ out = h + self.feed_forward(h_norm)
418
+ return out
419
+
420
+
421
+
422
+
423
+ def rightpad(seq, pad_id, max_len):
424
+ return seq + [pad_id] * (max_len - len(seq))
425
+
426
+
427
+ def check_non_zero_after_zero(tensor):
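+ # Returns True if any row contains a non-zero value immediately following a zero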
428
+ zero_mask = tensor == 0
429
+ shifted_mask = torch.cat(
430
+ [
431
+ torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
432
+ zero_mask[:, :-1],
433
+ ],
434
+ dim=1,
435
+ )
436
+ non_zero_after_zero = (tensor != 0) & shifted_mask
437
+ return non_zero_after_zero.any()
438
+
439
+
440
+ def fill_tokens(tokens, patch_size, fill_id):
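+ # Right-pad the sequence with fill_id so its length becomes a multiple of patch_size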
441
+ batch_size, seq_len = tokens.shape
442
+ if seq_len % patch_size == 0:
443
+ return tokens
444
+ else:
445
+ remaining = patch_size - seq_len % patch_size
446
+ final_padding = tokens.new(batch_size, remaining).fill_(fill_id)
447
+ return torch.cat((tokens, final_padding), dim=1)
448
+
449
+
450
+ def rolling_polynomial_hash(t, hash_func_nb: int = 0):
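+ # Hash each window as sum_i t[..., i] * prime**i in int64 (a rolling polynomial hash; overflow wraps around)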
451
+ primes = [
452
+ 1000000007,
453
+ 5915587277,
454
+ 1500450271,
455
+ 3267000013,
456
+ 5754853343,
457
+ 4093082899,
458
+ 9576890767,
459
+ 3628273133,
460
+ 2860486313,
461
+ 5463458053,
462
+ 3367900313,
463
+ ]
464
+ prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
465
+ prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
466
+ return torch.sum(t * prime_powers, dim=-1)
467
+
468
+
469
+ def byte_group_hash_function(x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000):
470
+ """
471
+ Returns a hash of the input x, mapped into the range [0, max_hash) via a modulo.
472
+
473
+ expects: x of shape (batch_size, seq_len) with values as ids in the token vocab.
474
+ returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash).
475
+
476
+ Note: max hash can make a big difference on the number of collisions.
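+ Example: for x = [[5, 9, 11]] with group_size=2, the hashed windows are (0, 5), (5, 9), and (9, 11), since group_size - 1 zeros are prepended before unfolding.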
477
+ """
478
+ with torch.no_grad():
479
+ bs, seq_len = x.shape
480
+ prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
481
+ x = torch.cat([prefix, x], dim=1)
482
+ windows = x.unfold(1, group_size, 1)
483
+ # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
484
+ hashes = rolling_polynomial_hash(windows, hash_func_nb)
485
+ hash_values_range = hashes % max_hash
486
+ hash_values_range.requires_grad = False
487
+ return hash_values_range
488
+
489
+
490
+ def create_patch_mask_from_ids(patch_ids, num_patches, window=None, patches_as_queries=False):
491
+ """
492
+ Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k)
493
+ is True if the patch id at position (i, j) equals k (or lies in [k, k + window) when a window is given).
494
+ Args:
495
+ patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids.
496
+ num_patches (int): Total number of patches.
497
+ window (int): If not None, only considers patches within a window of size window.
498
+ patches_as_queries (bool): If True, the patches are used as queries
499
+ Returns:
500
+ torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask.
501
+ """
502
+ bs, seq_len = patch_ids.shape
503
+ if not patches_as_queries:
504
+ q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
505
+ kv_ids = (
506
+ torch.arange(num_patches, device=patch_ids.device)
507
+ .unsqueeze(0)
508
+ .unsqueeze(0)
509
+ .expand(bs, seq_len, num_patches)
510
+ )
511
+ else:
512
+ kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
513
+ q_ids = (
514
+ torch.arange(num_patches, device=patch_ids.device)
515
+ .unsqueeze(0)
516
+ .unsqueeze(-1)
517
+ .expand(bs, num_patches, seq_len)
518
+ )
519
+ if window is None:
520
+ mask = q_ids == kv_ids
521
+ else:
522
+ mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window)
523
+ return mask
524
+
525
+
526
+ def cross_attn_mask(
527
+ patch_ids,
528
+ patch_lengths,
529
+ N,
530
+ patches_as_queries=False,
531
+ cross_attn_k=1,
532
+ window=None,
533
+ block_mask=True,
534
+ ):
535
+ bs = patch_ids.shape[0]
536
+ with torch.no_grad():
537
+ # Create the patch mask
538
+ cross_mask = create_patch_mask_from_ids(
539
+ patch_ids,
540
+ patch_lengths.shape[1],
541
+ window=window,
542
+ patches_as_queries=patches_as_queries,
543
+ ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
544
+ q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N
545
+ kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k
546
+ assert cross_mask.shape == (
547
+ bs,
548
+ q_len,
549
+ kv_len,
550
+ ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
551
+ block_mask = None  # NOTE: overrides the block_mask argument, so the dense additive mask branch below is always used
552
+ if block_mask:
553
+
554
+ def patch_mask(b, h, q_idx, kv_idx):
555
+ return cross_mask[b, q_idx, kv_idx]
556
+
557
+ block_mask = create_block_mask(
558
+ patch_mask,
559
+ B=bs,
560
+ H=None,
561
+ Q_LEN=q_len,
562
+ KV_LEN=kv_len,
563
+ _compile=True,
564
+ )
565
+ return block_mask
566
+ else:
567
+ return torch.where(cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))).unsqueeze(
568
+ 1
569
+ ) # [bs, 1, q_len, kv_len]
570
+
571
+
572
+ def get_blt_input(
573
+ tokens: torch.Tensor,
574
+ enforce_patch_size_multiple: bool,
575
+ nb_boe: torch.Tensor,
576
+ patch_size: int,
577
+ boe_id: int,
578
+ ):
579
+ """
580
+ This function returns X_et, X_gt and X_dt, the encoder, global, and decoder
581
+ tokens respectively.
582
+
583
+ Consider the input and target sequences:
584
+ X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13]
585
+ Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14]
586
+ with patch_size=4
587
+
588
+ Note 1: there will be no special tokens introduced at the patch level.
589
+ Note 2: X_e needs to be trimmed to be passed to Global
590
+
591
+ Current without boe:
592
+ X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]]
593
+ X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch
594
+ X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
595
+ Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]
596
+
597
+ --> lag fix:
598
+ X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]]
599
+ X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]]
600
+ X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
601
+ Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]
602
+
603
+ Dynamic (current):
604
+ X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos]
605
+ Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
606
+
607
+ entropy patching:
608
+ input: 7, bos, 9, 10
609
+ pred (high entropy): eos, 8, 10, eos
610
+
611
+ X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos]
612
+ X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]]
613
+ X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]]
614
+ Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
615
+
616
+ --> lag fix no boe (force single byte first patch):
617
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
618
+ X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
619
+ X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
620
+ Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
621
+
622
+ input: 4, 7, bos, 9, 10
623
+ pred (high entropy): 5, eos, 8, 10, eos
624
+
625
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
626
+ X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
627
+ X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
628
+ Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
629
+
630
+ Handle the last byte properly.
631
+ patch_lengths = [1, 1, 3, 2, 2 1 2 2 1]
632
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
633
+ X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch
634
+ X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]]
635
+ Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]]
636
+
637
+
638
+ bpe delim
639
+ X_et = [[3,4,5,6,7,<d>,eos,bos,<d>,8,9,<d>,10,<d>,eos,bos,11,12]
640
+ X_g = [[3], [4,5,6,7,<d>], [eos,bos,<d>], ..
641
+ X_dt = [[3,4,5,6,7], [<d>,eos,bos], [<d>,bos,8], ..
642
+ Y = [4,5,6,7,<d>, eos,bos,<d> 8,9,<d>, ..
643
+
644
+
645
+ Note 1: that there will be no special tokens introduced at the patch level.
646
+ Note 2: X_e needs to be trimmed to be passed to Global
647
+ """
648
+ batch_size, seq_len = tokens.shape
649
+ local_encoder_tokens = tokens
650
+ local_decoder_tokens = tokens
651
+
652
+ if nb_boe > 0:
653
+ padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id)
654
+ local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1)
655
+ # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id)
656
+
657
+ # create global tokens, contains boe tokens and eos
658
+ # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
659
+ # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size)
660
+ # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:]
661
+ # global_tokens += global_tokens.eq(0).int() * boe_id
662
+ # TODO: fix this when we want to use block causal in the global.
663
+
664
+ if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0:
665
+ local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
666
+
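+ # The global token stream is not materialized here; callers receive None in place of X_gt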
667
+ return local_encoder_tokens, None, local_decoder_tokens
668
+
669
+
670
+ class LocalModelBase(nn.Module):
671
+ def __init__(self, config: BLTConfig, component_type: str = "encoder"):
672
+ super().__init__()
673
+
674
+ # Store config for later use
675
+ self.config = config
676
+
677
+ # Use component-specific dimensions
678
+ if component_type == "encoder":
679
+ self.dim = config.dim_local_encoder
680
+ self.n_layers = config.n_layers_local_encoder
681
+ self.n_heads = config.n_heads_local_encoder
682
+ self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen
683
+ self.attn_bias_type = "local_block_causal"
684
+ self.sliding_window = config.local_attention_window_len
685
+ elif component_type == "decoder":
686
+ self.dim = config.dim_local_decoder
687
+ self.n_layers = config.n_layers_local_decoder
688
+ self.n_heads = config.n_heads_local_decoder
689
+ self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen
690
+ self.attn_bias_type = "local_block_causal"
691
+ self.sliding_window = config.local_attention_window_len
692
+ else:
693
+ raise ValueError(f"Unknown component_type: {component_type}")
694
+
695
+ self.dropout = config.dropout
696
+ self.vocab_size = config.vocab_size + config.pm_size
697
+ self.patch_size = config.patch_size
698
+
699
+ self.attn_impl = config.attn_impl
700
+ self.use_rope = config.use_rope
701
+ self.init_std_factor = config.init_std_factor
702
+ self.init_base_std = config.init_base_std
703
+ self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None)
704
+ self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None)
705
+ self.cross_attn_k = getattr(config, "cross_attn_k", None)
706
+ self.eos_id = config.eos_token_id
707
+
708
+ self.boe_id = BOE_ID
709
+
710
+ # Initialize cross attention layers as None (will be set by subclasses if needed)
711
+ self.cross_attn_layers = None
712
+
713
+ # Create parameter dict for BLTTransformerLayers
714
+ layer_params = {
715
+ "dim": self.dim,
716
+ "n_heads": self.n_heads,
717
+ "head_dim": config.head_dim,
718
+ "n_kv_heads": getattr(config, "n_kv_heads", None),
719
+ "rope_theta": config.rope_theta,
720
+ "multiple_of": getattr(config, "multiple_of", 256),
721
+ "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None),
722
+ "norm_eps": config.norm_eps,
723
+ }
724
+
725
+ self.layers = nn.ModuleList([BLTTransformerLayer(layer_params) for _ in range(self.n_layers)])
726
+
727
+ if not self.use_rope:
728
+ self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length
729
+ else:
730
+ self.rope = RotaryEmbedding(
731
+ theta=config.rope_theta,
732
+ head_dim=config.head_dim or self.dim // self.n_heads,
733
+ max_seqlen=self.max_seqlen,
734
+ rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product,
735
+ )
736
+ self.pos_embeddings = None
737
+
738
+ # Set dimension-specific embedding dimensions
739
+ if component_type == "encoder":
740
+ self.dim_token_emb = config.encoder_dim_token_emb
741
+ self.dim_patch_emb = config.encoder_dim_patch_emb
742
+ elif component_type == "decoder":
743
+ self.dim_token_emb = config.decoder_dim_token_emb
744
+ self.dim_patch_emb = config.dim_global
745
+
746
+ self.token_embedding_projection = (
747
+ nn.Linear(self.dim_token_emb, self.dim, bias=False)
748
+ if self.dim_token_emb is not None and self.dim_token_emb != self.dim
749
+ else None
750
+ )
751
+
752
+ self.patch_embedding_projection = self._create_patch_projection(config)
753
+
754
+ def _should_create_patch_projection(self, config: BLTConfig):
755
+ dimension_mismatch = self.dim_patch_emb is not None and self.dim_patch_emb != self.dim
756
+
757
+ # Check cross attention conditions
758
+ cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or (
759
+ config.cross_attn_decoder and config.cross_attn_init_by_pooling
760
+ )
761
+
762
+ return dimension_mismatch or cross_attn_conditions
763
+
764
+ def _create_patch_projection(self, config):
765
+ if not self._should_create_patch_projection(config):
766
+ return None
767
+
768
+ output_dim = self.dim_token_emb * (self.cross_attn_k or 1)
769
+
770
+ return nn.Linear(
771
+ in_features=self.dim_patch_emb,
772
+ out_features=output_dim,
773
+ bias=False,
774
+ )
775
+
776
+ def apply_embedding(self, tokens, embeds):
777
+ if embeds is not None:
778
+ return embeds
779
+ else:
780
+ return self.tok_embeddings(tokens)
781
+
782
+
783
+
784
+
785
+ class LocalEncoder(LocalModelBase):
786
+ def __init__(self, config: BLTConfig):
787
+ super().__init__(config, component_type="encoder")
788
+
789
+ self.apply_transformer = config.use_local_encoder_transformer
790
+ self.downsampling_by_pooling = config.downsampling_by_pooling
791
+ self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None
792
+ self.cross_attn_encoder = config.cross_attn_encoder
793
+ self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder
794
+ self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling
795
+ self.cross_attn_nheads = config.cross_attn_nheads
796
+
797
+ self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim)
798
+
799
+ if self.cross_attn_encoder:
800
+ self.cross_attn_layers = torch.nn.ModuleList()
801
+ layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1
802
+ for _ in range(layers_to_add):
803
+ self.cross_attn_layers.append(
804
+ BLTCrossAttention(
805
+ dim=self.dim,
806
+ head_dim=self.dim // self.cross_attn_nheads,
807
+ n_heads=self.cross_attn_nheads,
808
+ n_kv_heads=self.cross_attn_nheads,
809
+ norm_eps=config.norm_eps,
810
+ )
811
+ )
812
+
813
+ def apply_embedding(self, tokens, embeds):
814
+ if embeds is not None:
815
+ assert self.expects_hash_embeddings, "Not expecting embeddings to be passed."
816
+ return embeds
817
+ else:
818
+ return self.tok_embeddings(tokens)
819
+
820
+ def forward(
821
+ self,
822
+ tokens: torch.Tensor,
823
+ embeds: Optional[torch.Tensor] = None,
824
+ patch_embeds: Optional[torch.Tensor] = None,
825
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
826
+ cross_mask: Optional[torch.Tensor] = None,
827
+ num_patches: Optional[int] = None,
828
+ patch_ids: Optional[torch.Tensor] = None,
829
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
830
+ ):
831
+ """ """
832
+ bs, seqlen = tokens.shape
833
+ if mask is None:
834
+ mask = create_causal_mask(
835
+ seqlen,
836
+ self.attn_impl,
837
+ "local_block_causal",
838
+ sliding_window=self.sliding_window,
839
+ tokens=tokens,
840
+ eos_id=self.eos_id,
841
+ )
842
+
843
+ h = self.apply_embedding(tokens, embeds)
844
+ freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
845
+
846
+ h = F.dropout(h, p=self.dropout, training=self.training)
847
+
848
+ for i, layer in enumerate(self.layers):
849
+ h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
850
+ # check if cross attention should be applied to all layers or only the last layer
851
+ if self.cross_attn_encoder and (i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder):
852
+ # apply pooling and project
853
+ if self.cross_attn_init_by_pooling and patch_embeds is None:
854
+ patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids)
855
+ if self.patch_embedding_projection is not None:
856
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
857
+ patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim)
858
+
859
+ layer_idx = i if self.cross_attn_all_layers_encoder else 0
860
+ patch_embeds_cross = self.cross_attn_layers[layer_idx](
861
+ x=patch_embeds,
862
+ kv=h,
863
+ mask=cross_mask,
864
+ )
865
+ patch_embeds = patch_embeds + patch_embeds_cross
866
+
867
+ h_residual = patch_embeds if self.cross_attn_encoder else None
868
+ return (h, h_residual), cache
869
+
870
+ def patch_reduce(self, h, max_num_patches, reduction, patch_ids):
871
+ """
872
+ Reduce variable length patches to single embedding per patch
873
+ Note: this works with a variable number of patches across sequences in the batch.
874
+ It handles variable-length patches by assuming that patch_lengths will be 0 for any
875
+ extra patches on the *right*. Since the number of patches can vary between
876
+ sequences in the batch, the output is padded to max_num_patches patches.
877
+ Any embeddings on the right that are not allocated to a patch
878
+ (i.e. if the sum(patch_lengths[i]) < seq_len for any i)
879
+ will be sent to a dummy patch, which is trimmed before returning.
880
+ """
881
+ bs, seq_len, emb_dim = h.shape
882
+
883
+ patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
884
+
885
+ reduced_embs = torch.zeros((bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device)
886
+ reduced_embs = reduced_embs.scatter_reduce(
887
+ src=h,
888
+ dim=1,
889
+ index=patch_ids,
890
+ reduce=reduction,
891
+ include_self=False,
892
+ )
893
+ reduced_embs = reduced_embs[:, :max_num_patches, :]
894
+
895
+ return reduced_embs
896
+
897
+
898
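+ # Illustrative sketch (not part of the original checkpoint code): how patch_reduce pools
+ # byte-level hidden states into one embedding per patch with scatter_reduce. The shapes
+ # and values below are toy examples chosen only for clarity.
+ import torch
+ toy_h = torch.arange(12, dtype=torch.float32).reshape(1, 6, 2)   # [bs, seq_len, emb_dim]
+ toy_patch_ids = torch.tensor([[0, 0, 1, 1, 1, 2]])               # token -> patch index
+ toy_index = toy_patch_ids.unsqueeze(-1).expand(-1, -1, toy_h.shape[-1])
+ toy_pooled = torch.zeros(1, 3, 2).scatter_reduce(
+     src=toy_h, dim=1, index=toy_index, reduce="amax", include_self=False
+ )
+ # toy_pooled[0, p] holds the element-wise max over the tokens assigned to patch p.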
+ class LocalDecoder(LocalModelBase):
899
+ def __init__(self, config: BLTConfig):
900
+ super().__init__(config, component_type="decoder")
901
+
902
+ # Model configuration flags
903
+ self.cross_attn_decoder = config.cross_attn_decoder
904
+ self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder
905
+ self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling
906
+ self.cross_attn_nheads = config.cross_attn_nheads
907
+
908
+ self.norm = RMSNorm(self.dim, eps=config.norm_eps)
909
+
910
+ if self.cross_attn_decoder:
911
+ self.cross_attn_layers = torch.nn.ModuleList()
912
+ layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1
913
+ for _ in range(layers_to_add):
914
+ self.cross_attn_layers.append(
915
+ BLTCrossAttention(
916
+ dim=self.dim,
917
+ head_dim=self.dim // self.cross_attn_nheads,
918
+ n_heads=self.cross_attn_nheads,
919
+ n_kv_heads=self.cross_attn_nheads,
920
+ norm_eps=config.norm_eps,
921
+ )
922
+ )
923
+
924
+ self.output = nn.Linear(
925
+ self.dim,
926
+ config.vocab_size,
927
+ bias=False,
928
+ )
929
+
930
+ def forward(
931
+ self,
932
+ tokens: torch.Tensor,
933
+ embeds: Optional[torch.Tensor],
934
+ patch_embeds: Optional[torch.Tensor] = None,
935
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
936
+ cross_mask: Optional[torch.Tensor] = None,
937
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
938
+ ):
939
+ bs, seqlen = tokens.shape
940
+ assert embeds is not None, "Embeddings must be provided"
941
+
942
+ if mask is None:
943
+ mask = create_causal_mask(
944
+ seqlen,
945
+ self.attn_impl,
946
+ "local_block_causal",
947
+ sliding_window=self.sliding_window,
948
+ tokens=tokens,
949
+ eos_id=self.eos_id,
950
+ )
951
+
952
+ h = embeds
953
+
954
+ if self.patch_embedding_projection is not None:
955
+ assert patch_embeds is not None, "Patch embeddings must be passed."
956
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
957
+ if self.cross_attn_k is not None:
958
+ patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim)
959
+
960
+ if patch_embeds is not None and not self.cross_attn_decoder:
961
+ h = h + patch_embeds
962
+
963
+ freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
964
+
965
+ h = F.dropout(h, p=self.dropout, training=self.training)
966
+ for i, layer in enumerate(self.layers):
967
+ if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder):
968
+ # Use cross attention to extract info from patch_embeds into h
969
+ h_cross = self.cross_attn_layers[i](
970
+ x=h,
971
+ kv=patch_embeds,
972
+ mask=cross_mask,
973
+ )
974
+ h = h + h_cross
975
+
976
+ h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
977
+
978
+ h_preds = self.norm(h)
979
+ h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
980
+ h_preds = self.output(h_preds)
981
+ h_preds = h_preds.float()
982
+ return h_preds, cache
983
+
984
+
985
+ class BLTCrossAttention(nn.Module):
986
+ """
987
+ BLTCrossAttention block used by both the local encoder (to pool byte states into patch embeddings) and the local decoder (to attend back to patch embeddings).
988
+ Rope is not supported.
989
+ """
990
+
991
+ def __init__(
992
+ self,
993
+ dim: int,
994
+ head_dim: int,
995
+ n_heads: int,
996
+ n_kv_heads: int,
997
+ norm_eps: float,
998
+ ):
999
+ super().__init__()
1000
+
1001
+ self.dim = dim
1002
+ self.head_dim = head_dim
1003
+
1004
+ self.n_heads = n_heads
1005
+ self.n_kv_heads = n_kv_heads
1006
+ self.heads_per_group = self.n_heads // self.n_kv_heads
1007
+
1008
+ self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
1009
+ self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
1010
+
1011
+ self.wq = nn.Linear(
1012
+ dim,
1013
+ n_heads * head_dim,
1014
+ bias=False,
1015
+ )
1016
+ self.wk = nn.Linear(
1017
+ dim,
1018
+ n_kv_heads * head_dim,
1019
+ bias=False,
1020
+ )
1021
+ self.wv = nn.Linear(
1022
+ dim,
1023
+ n_kv_heads * head_dim,
1024
+ bias=False,
1025
+ )
1026
+
1027
+ self.wo = nn.Linear(
1028
+ n_heads * head_dim,
1029
+ dim,
1030
+ bias=False,
1031
+ )
1032
+
1033
+ def forward(
1034
+ self,
1035
+ x: torch.Tensor,
1036
+ kv: torch.Tensor,
1037
+ mask: Optional[Union[BlockMask, str]] = None,
1038
+ ) -> torch.Tensor:
1039
+ # B S D
1040
+ bsz, seq_len, _ = x.shape
1041
+ _, slen_kv, _ = kv.shape
1042
+ x_norm = self.cross_attn_norm_q(x)
1043
+ kv = self.cross_attn_norm_kv(kv)
1044
+
1045
+ xq = self.wq(x_norm)
1046
+ xk = self.wk(kv)
1047
+ xv = self.wv(kv)
1048
+
1049
+ output_shape = xq.shape
1050
+ # B S D -> B S H D
1051
+ xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
1052
+ xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
1053
+ xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
1054
+
1055
+ xk = repeat_kv(xk, self.heads_per_group, dim=2)
1056
+ xv = repeat_kv(xv, self.heads_per_group, dim=2)
1057
+
1058
+ # assert mask is None or isinstance(mask, BlockMask)
1059
+ xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
1060
+ # output = flex_attention_comp(xq, xk, xv, block_mask=mask)
1061
+ is_causal = (mask == "causal") if isinstance(mask, str) else False
1062
+ mask = mask if isinstance(mask, torch.Tensor) else None
1063
+ mask = mask.to(dtype=xq.dtype, device=xq.device) if mask is not None else None  # avoid calling .to() on a missing mask
1064
+ output = F.scaled_dot_product_attention(
1065
+ xq,
1066
+ xk,
1067
+ xv,
1068
+ is_causal=is_causal,
1069
+ attn_mask=mask,
1070
+ )
1071
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
1072
+
1073
+ output = self.wo(output.reshape(output_shape))
1074
+
1075
+ return x + output
1076
+
1077
+
1078
+
1079
+
1080
+ class GlobalTransformer(nn.Module):
1081
+ def __init__(self, config):
1082
+ super().__init__()
1083
+
1084
+ # Store config for later use
1085
+ self.config = config
1086
+
1087
+ self.dim = config.dim_global
1088
+ self.init_base_std = config.init_base_std
1089
+ self.attn_impl = config.attn_impl
1090
+ self.attn_bias_type = config.attn_bias_type
1091
+ self.init_std_factor = config.init_std_factor
1092
+ self.max_seqlen = config.max_seqlen
1093
+ self.rope_embeddings = RotaryEmbedding(
1094
+ theta=config.rope_theta,
1095
+ head_dim=config.head_dim or config.dim_global // config.n_heads_global,
1096
+ max_seqlen=config.max_seqlen,
1097
+ rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product,
1098
+ )
1099
+ # Handle both eos_id and eos_token_id for compatibility
1100
+ self.eos_id = getattr(config, "eos_id", getattr(config, "eos_token_id", 2))
1101
+
1102
+ # Create parameter dict for BLTTransformerLayers
1103
+ layer_params = {
1104
+ "dim": self.dim,
1105
+ "n_heads": config.n_heads_global,
1106
+ "head_dim": config.head_dim,
1107
+ "n_kv_heads": getattr(config, "n_kv_heads_global", None),
1108
+ "rope_theta": config.rope_theta,
1109
+ "multiple_of": getattr(config, "multiple_of", 256),
1110
+ "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None),
1111
+ "norm_eps": config.norm_eps,
1112
+ }
1113
+
1114
+ self.layers = nn.ModuleList()
1115
+ for _ in range(config.n_layers_global):
1116
+ self.layers.append(BLTTransformerLayer(layer_params))
1117
+
1118
+ # GlobalTransformer specific attributes
1119
+ self.dropout = config.dropout
1120
+ self.dim_token_emb = config.global_dim_patch_emb
1121
+
1122
+ self.token_embedding_projection = None
1123
+ if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim:
1124
+ self.token_embedding_projection = nn.Linear(
1125
+ config.global_dim_patch_emb,
1126
+ config.dim_global,
1127
+ bias=False,
1128
+ )
1129
+
1130
+ def forward(
1131
+ self,
1132
+ tokens: torch.Tensor,
1133
+ tok_idx: Optional[torch.Tensor] = None,
1134
+ embeds: Optional[torch.Tensor] = None,
1135
+ mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
1136
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
1137
+ ):
1138
+ bs, seqlen = tokens.shape
1139
+
1140
+ h = embeds
1141
+
1142
+ mask = (
1143
+ mask
1144
+ if mask is not None
1145
+ else create_causal_mask(
1146
+ seqlen,
1147
+ self.attn_impl,
1148
+ self.attn_bias_type,
1149
+ tokens=tokens,
1150
+ eos_id=self.eos_id,
1151
+ )
1152
+ )
1153
+
1154
+ if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
1155
+ h = self.token_embedding_projection(h)
1156
+
1157
+ h = F.dropout(h, p=self.dropout, training=self.training)
1158
+
1159
+ freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
1160
+
1161
+ for i, layer in enumerate(self.layers):
1162
+ h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
1163
+
1164
+ return h, cache
1165
+
1166
+
1167
+
1168
+
1169
+ def compute_hash_embeddings(
1170
+ local_encoder_tokens: torch.Tensor,
1171
+ local_encoder,
1172
+ encoder_hash_tok_embedding: nn.ModuleList,
1173
+ encoder_hash_byte_group_nb_functions: int,
1174
+ encoder_hash_byte_group_size: list,
1175
+ encoder_hash_byte_group_vocab: int,
1176
+ ) -> torch.Tensor:
1177
+ """
1178
+ Compute embeddings using hash token embeddings.
1179
+
1180
+ Args:
1181
+ local_encoder_tokens: Input tokens tensor
1182
+ local_encoder: Encoder object with tok_embeddings method
1183
+ encoder_hash_tok_embedding: ModuleList of hash token embeddings
1184
+ encoder_hash_byte_group_nb_functions: Number of hash functions
1185
+ encoder_hash_byte_group_size: List of byte group sizes
1186
+ encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings
1187
+
1188
+ Returns:
1189
+ torch.Tensor: Combined embeddings
1190
+ """
1191
+ if encoder_hash_tok_embedding is None:
1192
+ return None
1193
+
1194
+ local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens)
1195
+
1196
+ i = 0
1197
+ for func_nb in range(encoder_hash_byte_group_nb_functions):
1198
+ for byte_group_size in encoder_hash_byte_group_size:
1199
+ hash_ids = byte_group_hash_function(
1200
+ local_encoder_tokens,
1201
+ byte_group_size,
1202
+ hash_func_nb=func_nb,
1203
+ max_hash=encoder_hash_byte_group_vocab,
1204
+ )
1205
+ hash_tok_embedding = encoder_hash_tok_embedding[i]
1206
+ local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
1207
+ i += 1
1208
+
1209
+ assert i == len(encoder_hash_tok_embedding)
1210
+ return local_encoder_embeds
1211
+
1212
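+ # Illustrative sketch (not part of the original checkpoint code): the additive combination
+ # performed by compute_hash_embeddings, shown with toy sizes. toy_hash below is a
+ # hypothetical stand-in for byte_group_hash_function defined elsewhere in this file.
+ import torch
+ from torch import nn
+
+ def toy_hash(t, group_size, vocab):
+     return (t * 31 + group_size) % vocab
+
+ toy_tokens = torch.randint(0, 260, (2, 8))                                  # toy byte ids
+ toy_base_emb = nn.Embedding(260, 16)
+ toy_hash_embs = nn.ModuleList([nn.Embedding(500, 16) for _ in range(2)])    # 1 hash fn x 2 group sizes
+ toy_embeds = toy_base_emb(toy_tokens)
+ for i, gs in enumerate([3, 4]):
+     toy_embeds = toy_embeds + toy_hash_embs[i](toy_hash(toy_tokens, gs, 500))
+ # toy_embeds: [2, 8, 16] = base byte embedding plus one hash-n-gram embedding per group size.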
+
1213
+ class BLTPreTrainedModel(PreTrainedModel):
1214
+ config_class = BLTConfig
1215
+ base_model_prefix = "model"
1216
+ supports_gradient_checkpointing = True
1217
+ _no_split_modules = ["BLTTransformerLayer", "LocalEncoder", "LocalDecoder", "GlobalTransformer"]
1218
+ _skip_keys_device_placement = ["past_key_values"]
1219
+ _supports_flash_attn_2 = False # BLT uses its own attention implementation
1220
+ _supports_sdpa = True
1221
+ _supports_cache_class = False
1222
+
1223
+ def _init_weights(self, module):
1224
+ if isinstance(module, nn.Linear):
1225
+ std = getattr(module, '_custom_std', module.in_features ** (-0.5))
1226
+
1227
+ nn.init.trunc_normal_(
1228
+ module.weight,
1229
+ mean=0.0,
1230
+ std=std,
1231
+ a=-3 * std,
1232
+ b=3 * std,
1233
+ )
1234
+ if module.bias is not None:
1235
+ nn.init.zeros_(module.bias)
1236
+
1237
+ elif isinstance(module, nn.Embedding):
1238
+ std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5))
1239
+
1240
+ nn.init.trunc_normal_(
1241
+ module.weight,
1242
+ mean=0.0,
1243
+ std=std,
1244
+ a=-3 * std,
1245
+ b=3 * std,
1246
+ )
1247
+
1248
+ elif isinstance(module, (nn.RMSNorm, nn.LayerNorm)):
1249
+ nn.init.ones_(module.weight)
1250
+ if module.bias is not None:
1251
+ nn.init.zeros_(module.bias)
1252
+
1253
+ elif isinstance(module, RotaryEmbedding):
1254
+ module.freqs_cis[...] = precompute_freqs_cis(
1255
+ dim=module.head_dim,
1256
+ end=module.max_seqlen,
1257
+ theta=module.theta,
1258
+ rope_use_fp32_in_outer_product=module.rope_use_fp32_in_outer_product,
1259
+ )
1260
+
1261
+ elif isinstance(module, BLTModel):
1262
+ if module.encoder_hash_tok_embedding is not None:
1263
+ emb_std = module.local_encoder.dim ** (-0.5)
1264
+ for emb in module.encoder_hash_tok_embedding:
1265
+ emb._custom_std = emb_std
1266
+
1267
+ elif isinstance(module, (LocalEncoder, LocalDecoder)):
1268
+ if module.token_embedding_projection is not None:
1269
+ module.token_embedding_projection._custom_std = module.dim ** (-0.5)
1270
+
1271
+ if module.patch_embedding_projection is not None:
1272
+ module.patch_embedding_projection._custom_std = module.dim_patch_emb ** (-0.5)
1273
+
1274
+ elif isinstance(module, GlobalTransformer):
1275
+ if module.token_embedding_projection is not None:
1276
+ module.token_embedding_projection._custom_std = module.dim_token_emb ** (-0.5)
1277
+
1278
+ elif isinstance(module, BLTPatcher):
1279
+ emb_std = module.config.patcher_dim ** (-0.5)
1280
+ module.tok_embeddings._custom_std = emb_std
1281
+ module.output._custom_std = emb_std
1282
+
1283
+
1284
+ class BLTModel(BLTPreTrainedModel):
1285
+ def __init__(self, config: BLTConfig):
1286
+ super().__init__(config)
1287
+
1288
+ self.config = config
1289
+ self.local_encoder = LocalEncoder(config)
1290
+ self.global_transformer = GlobalTransformer(config)
1291
+ self.local_decoder = LocalDecoder(config)
1292
+
1293
+ self.encoder_hash_tok_embedding = init_hash_embeddings(
1294
+ config,
1295
+ local_encoder_dim=self.local_encoder.dim,
1296
+ encoder_hash_byte_group_size=config.encoder_hash_byte_group_size,
1297
+ )
1298
+
1299
+ if config.patch_in_forward:
1300
+ self.patcher = BLTPatcher(config)
1301
+ self.patcher.eval()
1302
+ for param in self.patcher.parameters():
1303
+ param.requires_grad = False
1304
+ else:
1305
+ self.patcher = None
1306
+
1307
+
1308
+
1309
+ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
1310
+ """
1311
+ Convert patch lengths to patch IDs for each token position.
1312
+
1313
+ For each token position in the sequence, determines which patch it belongs to.
1314
+
1315
+ Args:
1316
+ patch_lengths: [batch_size, num_patches] - length of each patch
1317
+ seq_len: total sequence length
1318
+
1319
+ Returns:
1320
+ patch_ids: [batch_size, seq_len] - patch index for each token position
1321
+
1322
+ Example:
1323
+ patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1
1324
+ seq_len = 10
1325
+ Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]]
1326
+ # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3
1327
+ """
1328
+ batch_size, num_patches = patch_lengths.shape
1329
+
1330
+ # Create patch start positions: [0, 3, 5, 9] for the example above
1331
+ patch_starts = torch.cat(
1332
+ [
1333
+ torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
1334
+ patch_lengths.cumsum(dim=-1)[:, :-1], # cumsum without the final total
1335
+ ],
1336
+ dim=-1,
1337
+ )
1338
+
1339
+ # For each token position, find which patch it belongs to
1340
+ # by finding the rightmost patch start that's <= the position
1341
+ token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1]
1342
+
1343
+ # Broadcasting: patch_starts[batch, patch] <= token_positions[position]
1344
+ # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t
1345
+ position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)
1346
+
1347
+ # Count how many patch starts are <= each position, then subtract 1 to get patch index
1348
+ patch_ids = position_ge_patch_start.sum(dim=-1) - 1
1349
+
1350
+ return patch_ids
1351
+
1352
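+ # Illustrative trace (not part of the original checkpoint code) of the docstring example:
+ # patch_lengths = [[3, 2, 4, 1]] gives patch_starts = [[0, 3, 5, 9]]; counting how many
+ # starts are <= each position t in 0..9 and subtracting 1 yields
+ # patch_ids = [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]].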
+ def _decoder_patch_ids_from_lengths(self, patch_lengths: torch.Tensor, nb_boe: int, seq_len: int) -> torch.Tensor:
1353
+ """
1354
+ Create decoder patch IDs by skipping the first encoder patch.
1355
+
1356
+ The decoder starts after the first patch (which contains BOE tokens),
1357
+ so we need to map decoder positions to the remaining patches.
1358
+
1359
+ Args:
1360
+ patch_lengths: [batch_size, num_patches] from encoder
1361
+ nb_boe: number of beginning-of-example tokens in first patch
1362
+ seq_len: decoder sequence length
1363
+
1364
+ Returns:
1365
+ decoder_patch_ids: [batch_size, seq_len] mapping decoder positions to patch indices
1366
+ """
1367
+ # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens)
1368
+ decoder_patch_lengths = patch_lengths[:, 1:]
1369
+
1370
+ # Create patch IDs for the decoder sequence using the remaining patches
1371
+ return self._patch_ids_from_lengths(decoder_patch_lengths, seq_len)
1372
+
1373
+ def forward(
1374
+ self,
1375
+ tokens: torch.Tensor,
1376
+ patch_lengths: Optional[torch.Tensor] = None,
1377
+ ):
1378
+ # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings
1379
+ # are no longer used in the final BLT model
1380
+
1381
+ bs, N = tokens.shape # Batch size and sequence length
1382
+
1383
+ # Get megabyte inputs
1384
+ nb_boe = int(0 if self.config.patching_mode != "" else self.config.patch_size - 1)
1385
+ local_encoder_tokens, _, local_decoder_tokens = get_blt_input(
1386
+ tokens=tokens,
1387
+ enforce_patch_size_multiple=False,
1388
+ nb_boe=nb_boe,
1389
+ patch_size=self.config.patch_size,
1390
+ boe_id=BOE_ID,
1391
+ )
1392
+
1393
+ # Patching
1394
+ if patch_lengths is None:
1395
+ # assert (
1396
+ # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward
1397
+ # ), "Patch in forward not enabled and no patch_lengths passed."
1398
+
1399
+ # Entropy-based patching: run the frozen patcher model to obtain patch lengths
1400
+ if self.config.patching_mode == PatchingModeEnum.entropy:
1401
+ _, patch_lengths, _ = self.patcher(
1402
+ local_encoder_tokens,
1403
+ patch_size=self.config.patch_size,
1404
+ include_next_token=True,
1405
+ threshold=self.config.patching_threshold,
1406
+ threshold_add=self.config.patching_threshold_add,
1407
+ monotonicity=self.config.monotonicity,
1408
+ max_patch_length=self.config.max_patch_length,
1409
+ patching_batch_size=self.config.patching_batch_size,
1410
+ device=self.config.patching_device,
1411
+ )
1412
+ else:
1413
+ # self.config.patching_mode == PatchingModeEnum.byte
1414
+ bs, seq_len = local_encoder_tokens.shape
1415
+ seq_len_next_tok = seq_len + 1 # include_next_token=True
1416
+ patch_lengths = torch.ones(
1417
+ (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device
1418
+ )
1419
+
1420
+ # Apply any processing to patch lengths
1421
+ if self.config.max_patch_length is not None:
1422
+ # TODO: avoid going back to a list here.
1423
+ patch_lengths = [
1424
+ BLTPatcher.split_large_numbers(pl, self.config.max_patch_length)
1425
+ for pl in patch_lengths.tolist()
1426
+ ]
1427
+ max_len = max([len(pl) for pl in patch_lengths])
1428
+ patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
1429
+ patch_lengths = torch.tensor(
1430
+ patch_lengths, dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device
1431
+ )
1432
+ assert not check_non_zero_after_zero(patch_lengths)
1433
+ # Find the last non-zero column index using argmax on a reversed version of the tensor
1434
+ last_non_zero_col_reversed = (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
1435
+ # Slice the tensor up to the last non-zero column
1436
+ patch_lengths = patch_lengths[:, : patch_lengths.shape[1] - last_non_zero_col_reversed]
1437
+ else:
1438
+ if nb_boe > 0:
1439
+ patch_lengths[:, 0] += nb_boe
1440
+
1441
+ assert torch.min(patch_lengths) >= 0
1442
+
1443
+ # Generate patch IDs from patch_lengths
1444
+ patch_ids = self._patch_ids_from_lengths(patch_lengths, local_encoder_tokens.shape[-1])
1445
+ assert torch.max(patch_ids) + 1 <= torch.max((patch_lengths != 0).sum(dim=-1)), (
1446
+ f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"
1447
+ )
1448
+
1449
+ cross_attn_mask_enc = None
1450
+ # Cross-attention encoder
1451
+ if self.config.cross_attn_encoder:
1452
+ cross_attn_mask_enc = cross_attn_mask(
1453
+ patch_ids,
1454
+ patch_lengths,
1455
+ N,
1456
+ patches_as_queries=True,
1457
+ cross_attn_k=self.config.cross_attn_k,
1458
+ window=self.config.cross_attn_window_encoder,
1459
+ block_mask=self.config.cross_attn_use_flex_attention,
1460
+ )
1461
+
1462
+ # Hashing and embedding
1463
+ local_encoder_embeds = compute_hash_embeddings(
1464
+ local_encoder_tokens=local_encoder_tokens,
1465
+ local_encoder=self.local_encoder,
1466
+ encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
1467
+ encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions,
1468
+ encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size,
1469
+ encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab,
1470
+ )
1471
+
1472
+ # NOTE: Frequency-based n-gram embeddings removed as per paper
1473
+ # The final BLT model uses only hash-based n-gram embeddings
1474
+
1475
+ # Local encoder
1476
+ (h_encoder, h_cross), cache_encoder = self.local_encoder(
1477
+ tokens=local_encoder_tokens,
1478
+ embeds=local_encoder_embeds,
1479
+ patch_embeds=None,
1480
+ cross_mask=cross_attn_mask_enc,
1481
+ num_patches=patch_lengths.shape[1],
1482
+ patch_ids=patch_ids,
1483
+ )
1484
+
1485
+ # Downsampling
1486
+ h = h_cross.view(bs, patch_lengths.shape[1], -1)
1487
+
1488
+ # Global transformer
1489
+ global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(BOE_ID)
1490
+ rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id)
1491
+ eos_patch_ids = patch_ids[rows, cols]
1492
+ global_tokens[rows, eos_patch_ids] = self.config.eos_token_id
1493
+
1494
+ h, _ = self.global_transformer(
1495
+ embeds=h,
1496
+ tokens=global_tokens,
1497
+ )
1498
+
1499
+ # Unpatching
1500
+ dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :]
1501
+
1502
+ # Generate decoder patch IDs
1503
+ decoder_patch_ids = self._decoder_patch_ids_from_lengths(patch_lengths, nb_boe, local_decoder_tokens.shape[-1])
1504
+ assert torch.max(decoder_patch_ids) + 1 <= h.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
1505
+ assert decoder_patch_ids.shape[1] == dec_embeds.shape[1], (
1506
+ f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"
1507
+ )
1508
+
1509
+ # Cross-attention decoder
1510
+ if not self.config.cross_attn_decoder:
1511
+ h = torch.gather(h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]))
1512
+ cross_attn_mask_dec = None
1513
+ assert local_decoder_tokens.shape == h.shape[:-1]
1514
+ else:
1515
+ cross_attn_mask_dec = cross_attn_mask(
1516
+ decoder_patch_ids,
1517
+ patch_lengths,
1518
+ N,
1519
+ patches_as_queries=False,
1520
+ cross_attn_k=self.config.cross_attn_k,
1521
+ window=self.config.cross_attn_window_decoder,
1522
+ block_mask=self.config.cross_attn_use_flex_attention,
1523
+ )
1524
+
1525
+ # Local decoder
1526
+ output, _ = self.local_decoder(
1527
+ embeds=dec_embeds,
1528
+ patch_embeds=h,
1529
+ tokens=local_decoder_tokens,
1530
+ cross_mask=cross_attn_mask_dec,
1531
+ )
1532
+ return output
1533
+
1534
+
1535
+ class BLTPatcher(BLTPreTrainedModel):
1536
+ def __init__(self, config):
1537
+ super().__init__(config)
1538
+
1539
+ self.rope_embeddings = RotaryEmbedding(
1540
+ theta=config.patcher_rope_theta,
1541
+ head_dim=config.patcher_head_dim or config.patcher_dim // config.patcher_n_heads,
1542
+ max_seqlen=config.patcher_max_seqlen,
1543
+ rope_use_fp32_in_outer_product=config.patcher_rope_use_fp32_in_outer_product,
1544
+ )
1545
+ # The patcher reads its EOS id from the patcher-specific config field
1546
+ self.eos_id = config.patcher_eos_token_id
1547
+
1548
+ # Extract additional parameters for BLTTransformerLayer
1549
+ n_kv_heads = (
1550
+ getattr(config, "patcher_n_kv_heads", None)
1551
+ if hasattr(config, "patcher_dim")
1552
+ else getattr(config, "n_kv_heads", None)
1553
+ )
1554
+ multiple_of = (
1555
+ getattr(config, "patcher_multiple_of", 256)
1556
+ if hasattr(config, "patcher_dim")
1557
+ else getattr(config, "multiple_of", 256)
1558
+ )
1559
+ ffn_dim_multiplier = (
1560
+ getattr(config, "patcher_ffn_dim_multiplier", None)
1561
+ if hasattr(config, "patcher_dim")
1562
+ else getattr(config, "ffn_dim_multiplier", None)
1563
+ )
1564
+
1565
+ self.layers = nn.ModuleList()
1566
+ for _ in range(config.patcher_n_layers):
1567
+ self.layers.append(
1568
+ BLTTransformerLayer(
1569
+ {
1570
+ "dim": config.patcher_dim,
1571
+ "n_heads": config.patcher_n_heads,
1572
+ "head_dim": config.patcher_head_dim,
1573
+ "n_kv_heads": n_kv_heads,
1574
+ "rope_theta": config.patcher_rope_theta,
1575
+ "multiple_of": multiple_of,
1576
+ "ffn_dim_multiplier": ffn_dim_multiplier,
1577
+ "norm_eps": config.patcher_norm_eps,
1578
+ }
1579
+ )
1580
+ )
1581
+
1582
+ # LMTransformer specific attributes
1583
+ self.sliding_window = config.patcher_sliding_window
1584
+
1585
+ assert config.patcher_vocab_size > 0
1586
+
1587
+ self.tok_embeddings = torch.nn.Embedding(config.patcher_vocab_size, config.patcher_dim)
1588
+
1589
+ self.norm = RMSNorm(config.patcher_dim, eps=config.patcher_norm_eps)
1590
+
1591
+ self.output = nn.Linear(
1592
+ config.patcher_dim,
1593
+ config.patcher_vocab_size,
1594
+ bias=False,
1595
+ )
1596
+
1597
+ def forward(
1598
+ self,
1599
+ token_values: torch.Tensor,
1600
+ target: Optional[torch.Tensor] = None,
1601
+ tok_idx: Optional[torch.Tensor] = None,
1602
+ mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
1603
+ attn_impl: str | None = None,
1604
+ patch_size: Optional[int] = None,
1605
+ include_next_token: bool = True,
1606
+ threshold: Optional[float] = None,
1607
+ threshold_add: Optional[float] = None,
1608
+ monotonicity: bool = False,
1609
+ max_patch_length: Optional[int] = None,
1610
+ patching_batch_size: int = 1,
1611
+ device: Optional[str] = None,
1612
+ enable_grad: bool = False,
1613
+ ):
1614
+ attn_impl = self.config.patcher_attn_impl if attn_impl is None else attn_impl
1615
+
1616
+ # Handle chunked processing for entropy calculation
1617
+ entropies = []
1618
+ preds = []
1619
+ max_length = min(getattr(self, "max_length", 8192), self.config.patcher_max_seqlen)
1620
+ batch_numel = max_length * patching_batch_size
1621
+ splits = torch.split(token_values.flatten(), batch_numel)
1622
+
1623
+ for split in splits:
1624
+ pad_size = (max_length - (split.numel() % max_length)) % max_length
1625
+ pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False)
1626
+ split = torch.cat((split, pad), dim=0)
1627
+ split = split.reshape(-1, max_length)
1628
+ if device is not None:
1629
+ split = split.to(device)
1630
+
1631
+ # Process chunk: embeddings -> layers -> output
1632
+ bsz, seqlen = split.shape
1633
+ h = self.tok_embeddings(split)
1634
+ chunk_mask = create_causal_mask(
1635
+ seqlen,
1636
+ attn_impl,
1637
+ self.config.patcher_attn_bias_type,
1638
+ sliding_window=self.sliding_window,
1639
+ tokens=split,
1640
+ eos_id=self.eos_id,
1641
+ )
1642
+ freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None)
1643
+
1644
+ for i, layer in enumerate(self.layers):
1645
+ h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=attn_impl)
1646
+
1647
+ pred = self.output(self.norm(h))
1648
+ pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab]
1649
+ preds.append(pred)
1650
+ pred_entropies = self.entropy(pred)
1651
+ entropies.append(pred_entropies)
1652
+
1653
+ concat_entropies = torch.cat(entropies, dim=0)
1654
+ concat_entropies = concat_entropies.reshape(token_values.shape)
1655
+ concat_preds = torch.cat(preds, dim=0)
1656
+ concat_preds = concat_preds.reshape(token_values.shape[0], -1)
1657
+
1658
+ # Always compute patch lengths from concatenated entropies
1659
+ bs, seq_len = token_values.shape
1660
+ seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
1661
+
1662
+ # Find patch start IDs based on entropy
1663
+ if patch_size is not None:
1664
+ patch_start_ids = self.find_entropy_patch_start_ids(
1665
+ concat_entropies,
1666
+ patch_size,
1667
+ include_next_token=include_next_token,
1668
+ threshold=threshold,
1669
+ threshold_add=threshold_add,
1670
+ monotonicity=monotonicity,
1671
+ )
1672
+ patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok)
1673
+ else:
1674
+ # Default to byte-level patching
1675
+ patch_lengths = torch.ones((bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device)
1676
+
1677
+ # Apply any processing to patch lengths
1678
+ if max_patch_length is not None:
1679
+ # TODO: avoid going back to a list here.
1680
+ patch_lengths = [self.split_large_numbers(pl, max_patch_length) for pl in patch_lengths.tolist()]
1681
+ max_len = max([len(pl) for pl in patch_lengths])
1682
+ patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
1683
+ patch_lengths = torch.tensor(patch_lengths, dtype=token_values.dtype, device=token_values.device)
1684
+ assert not check_non_zero_after_zero(patch_lengths)
1685
+ # Find the last non-zero column index using argmax on a reversed version of the tensor
1686
+ last_non_zero_col_reversed = (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
1687
+ # Slice the tensor up to the last non-zero column
1688
+ patch_lengths = patch_lengths[:, : patch_lengths.shape[1] - last_non_zero_col_reversed]
1689
+
1690
+ return concat_entropies, patch_lengths, concat_preds
1691
+
1692
+
1693
+
1694
+
1695
+
1696
+ @staticmethod
1697
+ def entropy(scores):
1698
+ """
1699
+ scores: [bs, seq_len, vocab]
1700
+ returns [bs, seq_len]
1701
+
1702
+ Computes the entropy for each token in the batch.
1703
+ Note: uses natural log.
1704
+ """
1705
+ log_probs = F.log_softmax(scores, dim=-1)
1706
+ probs = torch.exp(log_probs)
1707
+ p_log_p = log_probs * probs
1708
+ entropy = -p_log_p.sum(dim=-1)
1709
+ return entropy
1710
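+ # Illustrative worked example (not part of the original checkpoint code): for uniform
+ # logits over a vocabulary of size V, log_softmax gives log(1/V) everywhere, so the
+ # entropy is -sum((1/V) * log(1/V)) = log(V) nats; V = 256 gives about 5.545 per position.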
+
1711
+ @staticmethod
1712
+ def patch_start_ids_from_patch_start_mask(patch_start_mask):
1713
+ bs, trunc_seq_len = patch_start_mask.shape
1714
+ max_patches = patch_start_mask.sum(dim=1).max()
1715
+ if max_patches == 0:
1716
+ patch_start_ids = torch.full(
1717
+ (bs, trunc_seq_len),
1718
+ trunc_seq_len,
1719
+ dtype=torch.long,
1720
+ device=patch_start_mask.device,
1721
+ )
1722
+ else:
1723
+ patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(bs, 1)
1724
+ extra_patch_ids = torch.full(
1725
+ (bs, trunc_seq_len),
1726
+ trunc_seq_len,
1727
+ dtype=torch.long,
1728
+ device=patch_start_mask.device,
1729
+ )
1730
+ all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
1731
+ patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)
1732
+ patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, trunc_seq_len)[:, :max_patches]
1733
+ return patch_start_ids
1734
+
1735
+ @staticmethod
1736
+ def patch_lengths_from_start_ids(patch_start_ids, seq_len):
1737
+ """
1738
+ Calculate patch lengths from start ids.
1739
+ patch_start_ids: e.g. [0, 1, 7, 7, 7, 7, 7]; it holds the start ids of the patches (here 0 and 1),
1740
+ with the remaining slots padded up to seq_len.
1741
+ seq_len: e.g. 7, the length of the sequence.
1742
+
1743
+ returns the patch lengths:
1744
+ [1, 6] (followed by zeros for the padded slots) in the above example.
1745
+ """
1746
+ last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
1747
+ patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
1748
+ patch_lengths = patch_end_ids - patch_start_ids + 1
1749
+ assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
1750
+ assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
1751
+ return patch_lengths
1752
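+ # Illustrative worked example (not part of the original checkpoint code), following the
+ # docstring: patch_start_ids = [[0, 1, 7, 7, 7, 7, 7]] with seq_len = 7 gives
+ # patch_end_ids = [[0, 6, 6, 6, 6, 6, 6]] and patch_lengths = [[1, 6, 0, 0, 0, 0, 0]],
+ # i.e. the real patches have lengths 1 and 6 and the padded slots collapse to zero.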
+
1753
+ @staticmethod
1754
+ def find_entropy_patch_start_ids(
1755
+ entropies,
1756
+ patch_size=None,
1757
+ threshold=None,
1758
+ threshold_add=None,
1759
+ monotonicity=False,
1760
+ include_next_token=True,
1761
+ ):
1762
+ """
1763
+ Use entropies to find the start ids of each patch.
1764
+ Use patch_size or threshold to figure out the total number of patches to allocate.
1765
+
1766
+ When threshold is not None the number of patches is not constant between
1767
+ different sequences, but patches can be identified incrementally rather than
1768
+ decided globally using the entire sequence.
1769
+ """
1770
+ bs, seq_len = entropies.shape[:2]
1771
+
1772
+ first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(bs, 1)
1773
+ preds_truncation_len = first_ids.shape[1] # remove the first preds because they will be start of patches.
1774
+ entropies = entropies[:, 1:]
1775
+ if threshold is None:
1776
+ num_patches = seq_len // patch_size
1777
+ patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices
1778
+ patch_start_ids = patch_start_ids.sort(dim=1).values
1779
+ else:
1780
+ patch_start_mask = entropies > threshold
1781
+ if not include_next_token:
1782
+ patch_start_mask = patch_start_mask[:, :-1]
1783
+ # patch_start_mask[1:] |= tokens[:-1] < OFFSET
1784
+ patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask)
1785
+
1786
+ patch_start_ids = torch.cat((first_ids, patch_start_ids + preds_truncation_len), dim=1)
1787
+ return patch_start_ids
1788
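+ # Illustrative worked example (not part of the original checkpoint code): with
+ # entropies = [[0.1, 0.9, 0.2, 0.8, 0.3]] and threshold = 0.5, positions 0 and 1 always
+ # start patches; after dropping the first entropy, the values 0.9 and 0.8 exceed the
+ # threshold at offsets 0 and 2, which shift by 2 to give patch_start_ids = [[0, 1, 2, 4]].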
+
1789
+ @staticmethod
1790
+ def split_large_numbers(lst, m):
1791
+ new_lst = []
1792
+ for i in lst:
1793
+ if i > m:
1794
+ while i > m:
1795
+ new_lst.append(m)
1796
+ i -= m
1797
+ new_lst.append(i)
1798
+ else:
1799
+ new_lst.append(i)
1800
+ assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}"
1801
+ return new_lst
1802
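+ # Illustrative usage sketch (not part of the original checkpoint code): capping patch
+ # lengths at m = 4 while preserving the byte total:
+ # split_large_numbers([9, 2, 5], 4) -> [4, 4, 1, 2, 4, 1]   (both lists sum to 16)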
+
1803
+
1804
+ def init_hash_embeddings(
1805
+ config,
1806
+ local_encoder_dim: int,
1807
+ encoder_hash_byte_group_size: list,
1808
+ ):
1809
+ """Initialize hash-based token embeddings for the BLT encoder."""
1810
+ if config.encoder_hash_byte_group_size is None:
1811
+ return None
1812
+
1813
+ embeddings = []
1814
+ emb_dim = local_encoder_dim
1815
+ encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab
1816
+
1817
+ for _ in range(config.encoder_hash_byte_group_nb_functions):
1818
+ for _ in encoder_hash_byte_group_size:
1819
+ embeddings.append(
1820
+ nn.Embedding(
1821
+ encoder_hash_byte_group_vocab,
1822
+ emb_dim,
1823
+ )
1824
+ )
1825
+
1826
+ return nn.ModuleList(embeddings)
1827
+
1828
+
1829
+ __all__ = [
1830
+ "BLTPreTrainedModel",
1831
+ "BLTModel",
1832
+ "BLTPatcher",
1833
+ "LocalEncoder",
1834
+ "LocalDecoder",
1835
+ "GlobalTransformer",
1836
+ ]
backup_blt_wip_backup/modeling_blt_wip_backup.py ADDED
@@ -0,0 +1,2166 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from enum import Enum
4
+ from typing import Any, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from pydantic import model_validator
8
+ from torch import nn
9
+ from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention
10
+ import json
11
+ import logging
12
+
13
+ from torch.nn import functional as F
17
+
18
+ import os
19
+ from contextlib import nullcontext
20
+
21
+ SEP = " "
22
+ BOS_ID: int = 1
23
+ EOS_ID: int = 2
24
+ PAD_ID: int = -1
25
+ BOE_ID: int = 0
26
+ BPE_ID: int = 3
27
+ OFFSET: int = 4
28
+
29
+ BYTE_UNITS: int = 256
30
+
31
+ RMSNorm = nn.RMSNorm
32
+
33
+ logger = logging.getLogger()
34
+
35
+ from .configuration_blt import (
36
+ BLTConfig,
37
+ PatchingModeEnum,
38
+ InitStdFactor,
39
+ )
40
+
41
+ from ...modeling_utils import PreTrainedModel
42
+ from ...utils import logging as transformers_logging
43
+
44
+ flex_attention_comp = flex_attention
45
+
46
+
47
+ def causal_mask(b, h, q_idx, kv_idx):
48
+ return q_idx >= kv_idx
49
+
50
+
51
+ def create_causal_mask(
52
+ seqlen,
53
+ attn_impl: str,
54
+ attn_bias_type: str | None,
55
+ *,
56
+ eos_id: int | None = None,
57
+ tokens: torch.Tensor | None = None,
58
+ sliding_window: int | None = None,
59
+ ):
60
+ if attn_impl == "sdpa":
61
+ BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))
62
+
63
+ if attn_bias_type == "causal":
64
+ return "causal"
65
+
66
+ if BLT_SUPPRESS_ATTN_ERROR == 1:
67
+ return "causal"
68
+ else:
69
+ raise ValueError(
70
+ "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
71
+ )
72
+ elif attn_impl == "flex_attention":
73
+ return create_block_mask(causal_mask, None, None, seqlen, seqlen)
74
+ else:
75
+ raise NotImplementedError(
76
+ f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
77
+ )
78
+
79
+ def cross_entropy(pred, target, **kwargs):
80
+ return F.nll_loss(
81
+ F.log_softmax(pred.flatten(end_dim=-2).float(), -1),
82
+ target.flatten(end_dim=-1),
83
+ **kwargs,
84
+ )
85
+
86
+
87
+ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
88
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
89
+ assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
90
+ bs, slen, n_kv_heads, head_dim = x.shape
91
+ if n_rep == 1:
92
+ return x
93
+ return (
94
+ x[:, :, :, None, :]
95
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
96
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
97
+ )
98
+
99
+
100
+ def precompute_freqs_cis(
101
+ dim: int,
102
+ end: int,
103
+ theta: float = 10000.0,
104
+ rope_use_fp32_in_outer_product: bool = False,
105
+ ):
106
+ """
107
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
108
+
109
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
110
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
111
+ The returned tensor contains complex values in complex64 data type.
112
+
113
+ Args:
114
+ dim (int): Dimension of the frequency tensor.
115
+ end (int): End index for precomputing frequencies.
116
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
117
+
118
+ Returns:
119
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
120
+ """
121
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
122
+ t = torch.arange(end, device=freqs.device)
123
+ if rope_use_fp32_in_outer_product:
124
+ t = t.to(torch.float32)
125
+
126
+ freqs = torch.outer(t, freqs).float()
127
+
128
+ cos, sin = freqs.cos(), freqs.sin()
129
+
130
+ return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
131
+
132
+
133
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
134
+ """
135
+ Reshape frequency tensor for broadcasting it with another tensor.
136
+
137
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
138
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
139
+
140
+ Args:
141
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
142
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
143
+ seq_dim (int): Sequence dimension index.
144
+
145
+ Returns:
146
+ torch.Tensor: Reshaped frequency tensor.
147
+ """
148
+ ndim = x.ndim
149
+ assert 0 <= seq_dim < ndim
150
+ assert freqs_cis.shape == (
151
+ x.shape[seq_dim],
152
+ x.shape[-3],
153
+ 2,
154
+ 2,
155
+ ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
156
+ shape = [
157
+ d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
158
+ ] + [2, 2]
159
+ return freqs_cis.view(*shape)
160
+
161
+
162
+ def apply_rotary_emb(
163
+ xq: torch.Tensor,
164
+ xk: torch.Tensor,
165
+ seq_dim: int,
166
+ freqs_cis: torch.Tensor,
167
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
168
+ xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
169
+ xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
170
+ freqs_cis = reshape_for_broadcast(
171
+ freqs_cis, xq_, seq_dim
172
+ ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
173
+ xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
174
+ xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
175
+ return xq_out.type_as(xq), xk_out.type_as(xk)
176
+
177
+
178
+ # Rotary embedding as in xformers; see whether the torchtrain implementation is better. It might also be useful to make it work with batch*seqlen collapsed.
179
+ class RotaryEmbedding(torch.nn.Module):
180
+ """
181
+ RotaryEmbedding Module
182
+ """
183
+
184
+ def __init__(
185
+ self,
186
+ theta: float,
187
+ head_dim: int,
188
+ max_seqlen: int = 1024,
189
+ rope_use_fp32_in_outer_product: bool = False,
190
+ ):
191
+ super().__init__()
192
+
193
+ self.theta = theta
194
+ self.head_dim = head_dim
195
+ self.max_seqlen = max_seqlen
196
+ self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
197
+
198
+ self.register_buffer(
199
+ "freqs_cis",
200
+ precompute_freqs_cis(
201
+ dim=head_dim,
202
+ end=max_seqlen,
203
+ theta=theta,
204
+ rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
205
+ ),
206
+ persistent=False,
207
+ )
208
+
209
+ def reset_parameters(self):
210
+ self.freqs_cis[...] = precompute_freqs_cis(
211
+ dim=self.head_dim,
212
+ end=self.max_seqlen,
213
+ theta=self.theta,
214
+ rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
215
+ )
216
+
217
+ def forward(
218
+ self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None
219
+ ):
220
+ """
221
+ Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions
222
+ Args:
223
+ seqlen (int): Contiguous sequence length
224
+ tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen
225
+
226
+ Returns:
227
+ Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis
228
+ """
229
+ test = (seqlen is not None) or (tok_idx is not None)
230
+ assert test, "Should provide at least seqlen or tok_idx"
231
+ if tok_idx is not None:
232
+ return self.freqs_cis[tok_idx]
233
+ elif seqlen is not None:
234
+ return self.freqs_cis[0:seqlen]
235
+
236
+
237
+ class BLTAttention(nn.Module):
238
+ def __init__(
239
+ self,
240
+ dim: int,
241
+ head_dim: int,
242
+ n_heads: int,
243
+ n_kv_heads: int,
244
+ rope_theta: float,
245
+ ):
246
+ super().__init__()
247
+
248
+ self.dim = dim
249
+ self.head_dim = head_dim
250
+ self.rope_theta = rope_theta
251
+
252
+ self.n_heads = n_heads
253
+ self.n_kv_heads = n_kv_heads
254
+ self.heads_per_group = self.n_heads // self.n_kv_heads
255
+
256
+ self.wq = nn.Linear(
257
+ dim,
258
+ n_heads * head_dim,
259
+ bias=False,
260
+ )
261
+ self.wk = nn.Linear(
262
+ dim,
263
+ n_kv_heads * head_dim,
264
+ bias=False,
265
+ )
266
+ self.wv = nn.Linear(
267
+ dim,
268
+ n_kv_heads * head_dim,
269
+ bias=False,
270
+ )
271
+
272
+ self.wo = nn.Linear(
273
+ n_heads * head_dim,
274
+ dim,
275
+ bias=False,
276
+ )
277
+
278
+ def forward(
279
+ self,
280
+ x: torch.Tensor,
281
+ freq_cis: torch.Tensor,
282
+ tok_idx: Optional[torch.Tensor] = None,
283
+ mask: Optional[Union[BlockMask, str]] = None,
284
+ attn_impl: str = "sdpa",
285
+ ) -> torch.Tensor:
286
+ # B S D
287
+ bsz, seq_len, dim = x.shape
288
+ xq = self.wq(x.view_as(x))
289
+ xk = self.wk(x.view_as(x))
290
+ xv = self.wv(x.view_as(x))
291
+
292
+ output_shape = xq.shape
293
+ # B S D -> B S H D
294
+ xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
295
+ xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
296
+ xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
297
+
298
+ xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
299
+
300
+ # This condition helps us be easily compatible
301
+ # with inference by adding a pluggable KVCache
302
+ if hasattr(self, "kv_cache"):
303
+ xk, xv = self.kv_cache.update(xk, xv, tok_idx)
304
+
305
+ xk = repeat_kv(xk, self.heads_per_group, dim=2)
306
+ xv = repeat_kv(xv, self.heads_per_group, dim=2)
307
+
308
+ if attn_impl == "flex_attention":
309
+ assert mask is None or isinstance(mask, BlockMask)
310
+ xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
311
+ output = flex_attention_comp(xq, xk, xv, block_mask=mask)
312
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
313
+
314
+ elif attn_impl == "sdpa":
315
+ xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
316
+ assert mask is None or isinstance(mask, (str, torch.Tensor))
317
+ is_causal = (mask == "causal") if isinstance(mask, str) else False
318
+ mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None
319
+ output = F.scaled_dot_product_attention(
320
+ xq,
321
+ xk,
322
+ xv,
323
+ is_causal=is_causal,
324
+ attn_mask=mask,
325
+ )
326
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
327
+ else:
328
+ raise NotImplementedError(
329
+ f"Attention implementation {attn_impl} not supported"
330
+ )
331
+
332
+ output_reshaped = output.reshape(output_shape)
333
+
334
+ output = self.wo(output_reshaped)
335
+
336
+ return output
337
+
338
+ def reset_parameters(self, init_std=None, factor=1.0):
339
+ init_std = init_std or (self.dim ** (-0.5)) / factor
340
+
341
+ for w in [self.wq, self.wk, self.wv]:
342
+ nn.init.trunc_normal_(
343
+ w.weight,
344
+ mean=0.0,
345
+ std=init_std,
346
+ a=-3 * init_std,
347
+ b=3 * init_std,
348
+ )
349
+
350
+ nn.init.trunc_normal_(
351
+ self.wo.weight,
352
+ mean=0.0,
353
+ std=init_std,
354
+ a=-3 * init_std,
355
+ b=3 * init_std,
356
+ )
357
+
358
+
359
+ class BLTMLP(nn.Module):
360
+ def __init__(
361
+ self,
362
+ dim: int,
363
+ hidden_dim: int,
364
+ multiple_of: int,
365
+ ffn_dim_multiplier: Optional[float],
366
+ mp_size: int = 1,
367
+ ):
368
+ super().__init__()
369
+
370
+ hidden_dim = int(2 * hidden_dim / 3)
371
+ if ffn_dim_multiplier is not None:
372
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
373
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
374
+ assert hidden_dim % mp_size == 0
375
+
376
+ self.dim = dim
377
+ self.hidden_dim = hidden_dim
378
+
379
+ self.w1 = nn.Linear(
380
+ dim,
381
+ hidden_dim,
382
+ bias=False,
383
+ )
384
+ self.w3 = nn.Linear(
385
+ dim,
386
+ hidden_dim,
387
+ bias=False,
388
+ )
389
+ self.w2 = nn.Linear(
390
+ hidden_dim,
391
+ dim,
392
+ bias=False,
393
+ )
394
+
395
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
396
+ # B S D
397
+ x1 = self.w1(x.view_as(x))
398
+ x3 = self.w3(x.view_as(x))
399
+ output = self.w2(F.silu(x1) * x3)
400
+ return output
401
+
402
+ def reset_parameters(self, init_std=None, factor=1.0):
403
+ in_init_std = init_std or (self.dim ** (-0.5)) / factor
404
+ out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor
405
+
406
+ nn.init.trunc_normal_(
407
+ self.w1.weight,
408
+ mean=0.0,
409
+ std=in_init_std,
410
+ a=-3 * in_init_std,
411
+ b=3 * in_init_std,
412
+ )
413
+ nn.init.trunc_normal_(
414
+ self.w2.weight,
415
+ mean=0.0,
416
+ std=out_init_std,
417
+ a=-3 * out_init_std,
418
+ b=3 * out_init_std,
419
+ )
420
+ nn.init.trunc_normal_(
421
+ self.w3.weight,
422
+ mean=0.0,
423
+ std=in_init_std,
424
+ a=-3 * in_init_std,
425
+ b=3 * in_init_std,
426
+ )
427
+
428
+
429
+ class BLTTransformerLayer(nn.Module):
430
+ def __init__(self, args):
431
+ super().__init__()
432
+
433
+ # Extract parameters from dictionary
434
+ dim = args['dim']
435
+ n_heads = args['n_heads']
436
+ head_dim = args['head_dim']
437
+ n_kv_heads = args['n_kv_heads']
438
+ rope_theta = args['rope_theta']
439
+ multiple_of = args['multiple_of']
440
+ ffn_dim_multiplier = args['ffn_dim_multiplier']
441
+ norm_eps = args['norm_eps']
442
+
443
+ assert (head_dim is not None) or (
444
+ n_heads is not None
445
+ ), "Should specify at least head_dim or n_heads"
446
+ self.head_dim = head_dim or dim // n_heads
447
+ self.n_heads = n_heads or dim // head_dim
448
+ self.n_kv_heads = n_kv_heads or self.n_heads
449
+
450
+ assert n_heads % self.n_kv_heads == 0
451
+ assert dim % n_heads == 0
452
+
453
+ self.attention = BLTAttention(
454
+ dim=dim,
455
+ head_dim=self.head_dim,
456
+ n_heads=self.n_heads,
457
+ n_kv_heads=self.n_kv_heads,
458
+ rope_theta=rope_theta,
459
+ )
460
+ self.feed_forward = BLTMLP(
461
+ dim=dim,
462
+ hidden_dim=4 * dim,
463
+ multiple_of=multiple_of,
464
+ ffn_dim_multiplier=ffn_dim_multiplier,
465
+ )
466
+ self.attention_norm = RMSNorm(dim, eps=norm_eps)
467
+ self.ffn_norm = RMSNorm(dim, eps=norm_eps)
468
+
469
+ def forward(
470
+ self,
471
+ x: torch.Tensor,
472
+ freq_cis: torch.Tensor,
473
+ tok_idx: Optional[torch.Tensor] = None,
474
+ mask: Optional[Union[BlockMask, str]] = None,
475
+ attn_impl: str = "sdpa",
476
+ ) -> torch.Tensor:
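+ # Pre-norm residual block: h = x + Attn(RMSNorm(x)), then out = h + MLP(RMSNorm(h)).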
477
+ norm_x = self.attention_norm(x)
478
+ attn_out = self.attention(
479
+ norm_x,
480
+ freq_cis,
481
+ tok_idx=tok_idx,
482
+ mask=mask,
483
+ attn_impl=attn_impl,
484
+ )
485
+ h = x + attn_out
486
+ h_norm = self.ffn_norm(h)
487
+ out = h + self.feed_forward(h_norm)
488
+ return out
489
+
490
+ def init_weights(self, init_std=None, factor=1.0):
491
+ self.attention.reset_parameters(init_std, factor)
492
+ self.attention_norm.reset_parameters()
493
+
494
+ self.feed_forward.reset_parameters(init_std, factor)
495
+ self.ffn_norm.reset_parameters()
496
+
497
+
498
+ def rightpad(seq, pad_id, max_len):
499
+ return seq + [pad_id] * (max_len - len(seq))
500
+
501
+
502
+ def check_non_zero_after_zero(tensor):
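+ # Returns True if any row of `tensor` contains a non-zero value after a zero, e.g.
+ # [[3, 2, 0, 0]] -> False, [[3, 0, 2, 0]] -> True. Used below to verify that
+ # patch_lengths are only right-padded with zeros.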
503
+ zero_mask = tensor == 0
504
+ shifted_mask = torch.cat(
505
+ [
506
+ torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
507
+ zero_mask[:, :-1],
508
+ ],
509
+ dim=1,
510
+ )
511
+ non_zero_after_zero = (tensor != 0) & shifted_mask
512
+ return non_zero_after_zero.any()
513
+
514
+
515
+ def fill_tokens(tokens, patch_size, fill_id):
516
+ batch_size, seq_len = tokens.shape
517
+ if seq_len % patch_size == 0:
518
+ return tokens
519
+ else:
520
+ remaining = patch_size - seq_len % patch_size
521
+ final_padding = tokens.new(batch_size, remaining).fill_(fill_id)
522
+ return torch.cat((tokens, final_padding), dim=1)
523
+
524
+
525
+ def rolling_polynomial_hash(t, hash_func_nb: int = 0):
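+ # Polynomial hash over the last dimension: hash(t) = sum_i t[..., i] * prime**i,
+ # with the prime selected by hash_func_nb (computed in int64, so large powers wrap).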
526
+ primes = [
527
+ 1000000007,
528
+ 5915587277,
529
+ 1500450271,
530
+ 3267000013,
531
+ 5754853343,
532
+ 4093082899,
533
+ 9576890767,
534
+ 3628273133,
535
+ 2860486313,
536
+ 5463458053,
537
+ 3367900313,
538
+ ]
539
+ prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
540
+ prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
541
+ return torch.sum(t * prime_powers, dim=-1)
542
+
543
+ def byte_group_hash_function(
544
+ x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000
545
+ ):
546
+ """
547
+ Returns a hash of the input x and maps it to a value in the range [0, max_hash).
548
+
549
+ expects: x of shape (batch_size, seq_len) with values as ids in the token vocab.
550
+ returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash).
551
+
552
+ Note: max hash can make a big difference on the number of collisions.
553
+ """
554
+ with torch.no_grad():
555
+ bs, seq_len = x.shape
556
+ prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
557
+ x = torch.cat([prefix, x], dim=1)
558
+ windows = x.unfold(1, group_size, 1)
559
+ # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
560
+ hashes = rolling_polynomial_hash(windows, hash_func_nb)
561
+ hash_values_range = hashes % max_hash
562
+ hash_values_range.requires_grad = False
563
+ return hash_values_range
564
+
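+ # Example (illustrative shapes): for x of shape (2, 5) and group_size=3, x is
+ # left-padded with two zero columns and unfolded into length-3 sliding windows, so
+ # the returned hash ids keep shape (2, 5) with values in [0, max_hash).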
565
+
566
+ def create_patch_mask_from_ids(
567
+ patch_ids, num_patches, window=None, patches_as_queries=False
568
+ ):
569
+ """
570
+ Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k)
571
+ is True if the patch id at position (i, j) is less than or equal to k.
572
+ Args:
573
+ patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids.
574
+ num_patches (int): Total number of patches.
575
+ window (int): If not None, only considers patches within a window of size window.
576
+ patches_as_queries (bool): If True, the patches are used as queries
577
+ Returns:
578
+ torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask.
579
+ """
580
+ bs, seq_len = patch_ids.shape
581
+ if not patches_as_queries:
582
+ q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
583
+ kv_ids = (
584
+ torch.arange(num_patches, device=patch_ids.device)
585
+ .unsqueeze(0)
586
+ .unsqueeze(0)
587
+ .expand(bs, seq_len, num_patches)
588
+ )
589
+ else:
590
+ kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
591
+ q_ids = (
592
+ torch.arange(num_patches, device=patch_ids.device)
593
+ .unsqueeze(0)
594
+ .unsqueeze(-1)
595
+ .expand(bs, num_patches, seq_len)
596
+ )
597
+ if window is None:
598
+ mask = q_ids == kv_ids
599
+ else:
600
+ mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window)
601
+ return mask
602
+
603
+
604
+ def cross_attn_mask(
605
+ patch_ids,
606
+ patch_lengths,
607
+ N,
608
+ patches_as_queries=False,
609
+ cross_attn_k=1,
610
+ window=None,
611
+ block_mask=True,
612
+ ):
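+ # Two layouts: with patches_as_queries=True (local encoder) each of the
+ # num_patches * cross_attn_k patch queries may attend to the bytes of its patch;
+ # with patches_as_queries=False (local decoder) each byte query attends to the
+ # cross_attn_k slots of its own patch.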
613
+ bs = patch_ids.shape[0]
614
+ with torch.no_grad():
615
+ # Create the patch mask
616
+ cross_mask = create_patch_mask_from_ids(
617
+ patch_ids,
618
+ patch_lengths.shape[1],
619
+ window=window,
620
+ patches_as_queries=patches_as_queries,
621
+ ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
622
+ q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N
623
+ kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k
624
+ assert cross_mask.shape == (
625
+ bs,
626
+ q_len,
627
+ kv_len,
628
+ ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
629
+ # Force the dense additive-mask path: BLTCrossAttention runs SDPA, so a
+ # flex-attention BlockMask would not be consumed even if the caller requests one.
+ block_mask = False
630
+ if block_mask:
631
+
632
+ def patch_mask(b, h, q_idx, kv_idx):
633
+ return cross_mask[b, q_idx, kv_idx]
634
+
635
+ block_mask = create_block_mask(
636
+ patch_mask,
637
+ B=bs,
638
+ H=None,
639
+ Q_LEN=q_len,
640
+ KV_LEN=kv_len,
641
+ _compile=True,
642
+ )
643
+ return block_mask
644
+ else:
645
+ return torch.where(
646
+ cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))
647
+ ).unsqueeze(
648
+ 1
649
+ ) # [bs, 1, q_len, kv_len]
650
+
651
+
652
+ def get_blt_input(
653
+ tokens: torch.Tensor,
654
+ enforce_patch_size_multiple: bool,
655
+ nb_boe: torch.Tensor,
656
+ patch_size: int,
657
+ boe_id: int,
658
+ ):
659
+ """
660
+ This function returns X_et, X_gt and X_dt, the encoder, global, and decoder
661
+ tokens respectively.
662
+
663
+ Consider the input and target sequences:
664
+ X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13]
665
+ Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14]
666
+ with patch_size=4
667
+
668
+ Note 1: that there will be no special tokens introduced at the patch level.
669
+ Note 2: X_e needs to be trimmed to be passed to Global
670
+
671
+ Current without boe:
672
+ X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]]
673
+ X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch
674
+ X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
675
+ Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]
676
+
677
+ --> lag fix:
678
+ X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]]
679
+ X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]]
680
+ X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
681
+ Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]
682
+
683
+ Dynamic (current):
684
+ X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos]
685
+ Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
686
+
687
+ entropy patching:
688
+ input: 7, bos, 9, 10
689
+ pred (high entropy): eos, 8, 10, eos
690
+
691
+ X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos]
692
+ X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]]
693
+ X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]]
694
+ Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
695
+
696
+ --> lag fix no boe (force single byte first patch):
697
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
698
+ X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
699
+ X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
700
+ Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
701
+
702
+ input: 4, 7, bos, 9, 10
703
+ pred (high entropy): 5, eos, 8, 10, eos
704
+
705
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
706
+ X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
707
+ X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
708
+ Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
709
+
710
+ Handle the last byte properly.
711
+ patch_lengths = [1, 1, 3, 2, 2, 1, 2, 2, 1]
712
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
713
+ X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch
714
+ X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]]
715
+ Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]]
716
+
717
+
718
+ bpe delim
719
+ X_et = [[3,4,5,6,7,<d>,eos,bos,<d>,8,9,<d>,10,<d>,eos,bos,11,12]
720
+ X_g = [[3], [4,5,6,7,<d>], [eos,bos,<d>], ..
721
+ X_dt = [[3,4,5,6,7], [<d>,eos,bos], [<d>,bos,8], ..
722
+ Y = [4,5,6,7,<d>, eos,bos,<d> 8,9,<d>, ..
723
+
724
+
725
+ Note 1: that there will be no special tokens introduced at the patch level.
726
+ Note 2: X_e needs to be trimmed to be passed to Global
727
+ """
728
+ batch_size, seq_len = tokens.shape
729
+ local_encoder_tokens = tokens
730
+ local_decoder_tokens = tokens
731
+
732
+ if nb_boe > 0:
733
+ padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id)
734
+ local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1)
735
+ # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id)
736
+
737
+ # create global tokens, contains boe tokens and eos
738
+ # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
739
+ # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size)
740
+ # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:]
741
+ # global_tokens += global_tokens.eq(0).int() * boe_id
742
+ # TODO: fix this when we want to use block causal in the global.
743
+
744
+ if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0:
745
+ local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
746
+
747
+ return local_encoder_tokens, None, local_decoder_tokens
748
+
749
+
750
+ class LocalModelBase(nn.Module):
751
+ def __init__(self, config: BLTConfig, component_type: str = "encoder"):
752
+ super().__init__()
753
+
754
+ # Store config for later use
755
+ self.config = config
756
+
757
+ # Use component-specific dimensions
758
+ if component_type == "encoder":
759
+ self.dim = config.dim_local_encoder
760
+ self.n_layers = config.n_layers_local_encoder
761
+ self.n_heads = config.n_heads_local_encoder
762
+ self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen
763
+ self.attn_bias_type = "local_block_causal"
764
+ self.sliding_window = config.local_attention_window_len
765
+ elif component_type == "decoder":
766
+ self.dim = config.dim_local_decoder
767
+ self.n_layers = config.n_layers_local_decoder
768
+ self.n_heads = config.n_heads_local_decoder
769
+ self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen
770
+ self.attn_bias_type = "local_block_causal"
771
+ self.sliding_window = config.local_attention_window_len
772
+ else:
773
+ raise ValueError(f"Unknown component_type: {component_type}")
774
+
775
+ self.dropout = config.dropout
776
+ self.vocab_size = config.vocab_size + config.pm_size
777
+ self.patch_size = config.patch_size
778
+
779
+ self.attn_impl = config.attn_impl
780
+ self.use_rope = config.use_rope
781
+ self.init_std_factor = config.init_std_factor
782
+ self.init_base_std = config.init_base_std
783
+ self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None)
784
+ self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None)
785
+ self.cross_attn_k = getattr(config, "cross_attn_k", None)
786
+ self.eos_id = config.eos_token_id
787
+
788
+ self.boe_id = BOE_ID
789
+
790
+ # Initialize cross attention layers as None (will be set by subclasses if needed)
791
+ self.cross_attn_layers = None
792
+
793
+ # Create parameter dict for BLTTransformerLayers
794
+ layer_params = {
795
+ 'dim': self.dim,
796
+ 'n_heads': self.n_heads,
797
+ 'head_dim': config.head_dim,
798
+ 'n_kv_heads': getattr(config, 'n_kv_heads', None),
799
+ 'rope_theta': config.rope_theta,
800
+ 'multiple_of': getattr(config, 'multiple_of', 256),
801
+ 'ffn_dim_multiplier': getattr(config, 'ffn_dim_multiplier', None),
802
+ 'norm_eps': config.norm_eps,
803
+ }
804
+
805
+ self.layers = nn.ModuleList(
806
+ [BLTTransformerLayer(layer_params) for _ in range(self.n_layers)]
807
+ )
808
+
809
+ if not self.use_rope:
810
+ self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length
811
+ else:
812
+ self.rope = RotaryEmbedding(
813
+ theta=config.rope_theta,
814
+ head_dim=config.head_dim or self.dim // self.n_heads,
815
+ max_seqlen=self.max_seqlen,
816
+ rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product,
817
+ )
818
+ self.pos_embeddings = None
819
+
820
+ # Set dimension-specific embedding dimensions
821
+ if component_type == "encoder":
822
+ self.dim_token_emb = config.encoder_dim_token_emb
823
+ self.dim_patch_emb = config.encoder_dim_patch_emb
824
+ elif component_type == "decoder":
825
+ self.dim_token_emb = config.decoder_dim_token_emb
826
+ self.dim_patch_emb = config.dim_global
827
+
828
+ self.token_embedding_projection = (
829
+ nn.Linear(self.dim_token_emb, self.dim, bias=False)
830
+ if self.dim_token_emb is not None and self.dim_token_emb != self.dim
831
+ else None
832
+ )
833
+
834
+ self.patch_embedding_projection = self._create_patch_projection(config)
835
+
836
+ def _should_create_patch_projection(self, config: BLTConfig):
837
+ dimension_mismatch = (
838
+ self.dim_patch_emb is not None and self.dim_patch_emb != self.dim
839
+ )
840
+
841
+ # Check cross attention conditions
842
+ cross_attn_conditions = (
843
+ config.cross_attn_encoder and config.cross_attn_init_by_pooling
844
+ ) or (config.cross_attn_decoder and config.cross_attn_init_by_pooling)
845
+
846
+ return dimension_mismatch or cross_attn_conditions
847
+
848
+ def _create_patch_projection(self, config):
849
+ if not self._should_create_patch_projection(config):
850
+ return None
851
+
852
+ output_dim = self.dim_token_emb * (self.cross_attn_k or 1)
853
+
854
+ return nn.Linear(
855
+ in_features=self.dim_patch_emb,
856
+ out_features=output_dim,
857
+ bias=False,
858
+ )
859
+
860
+ def apply_embedding(self, tokens, embeds):
861
+ if embeds is not None:
862
+ return embeds
863
+ else:
864
+ return self.tok_embeddings(tokens)
865
+
866
+ def init_weights(self, init_std=None):
867
+ self.rope.reset_parameters()
868
+ if hasattr(self, "norm"):
869
+ self.norm.reset_parameters()
870
+
871
+ init_std = init_std or (self.dim ** (-0.5))
872
+ if hasattr(self, "tok_embeddings"):
873
+ nn.init.trunc_normal_(
874
+ self.tok_embeddings.weight,
875
+ mean=0.0,
876
+ std=init_std,
877
+ a=-3 * init_std,
878
+ b=3 * init_std,
879
+ )
880
+ if self.pos_embeddings is not None:
881
+ nn.init.trunc_normal_(
882
+ self.pos_embeddings.weight,
883
+ mean=0.0,
884
+ std=init_std,
885
+ a=-3 * init_std,
886
+ b=3 * init_std,
887
+ )
888
+
889
+ for depth, layer in enumerate(self.layers):
890
+ factor = self.config.get_init_std_factor(depth)
891
+ layer.init_weights(self.init_base_std, factor)
892
+
893
+ if hasattr(self, "output"):
894
+ nn.init.trunc_normal_(
895
+ self.output.weight,
896
+ mean=0.0,
897
+ std=init_std,
898
+ a=-3 * init_std,
899
+ b=3 * init_std,
900
+ )
901
+
902
+ if self.token_embedding_projection is not None:
903
+ nn.init.trunc_normal_(
904
+ self.token_embedding_projection.weight,
905
+ mean=0.0,
906
+ std=init_std,
907
+ a=-3 * init_std,
908
+ b=3 * init_std,
909
+ )
910
+
911
+ if self.patch_embedding_projection is not None:
912
+ patch_emb_std = self.dim_patch_emb ** (-0.5)
913
+ nn.init.trunc_normal_(
914
+ self.patch_embedding_projection.weight,
915
+ mean=0.0,
916
+ std=patch_emb_std,
917
+ a=-3 * patch_emb_std,
918
+ b=3 * patch_emb_std,
919
+ )
920
+
921
+ if self.cross_attn_layers is not None:
922
+ for depth, layer in enumerate(self.cross_attn_layers):
923
+ factor = self.config.get_init_std_factor(depth)
924
+ layer.init_weights(None, factor)
925
+
926
+
927
+ class LocalEncoder(LocalModelBase):
928
+ def __init__(self, config: BLTConfig):
929
+ super().__init__(config, component_type="encoder")
930
+
931
+ self.apply_transformer = config.use_local_encoder_transformer
932
+ self.downsampling_by_pooling = config.downsampling_by_pooling
933
+ self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None
934
+ self.cross_attn_encoder = config.cross_attn_encoder
935
+ self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder
936
+ self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling
937
+ self.cross_attn_nheads = config.cross_attn_nheads
938
+
939
+ self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim)
940
+
941
+ if self.cross_attn_encoder:
942
+ self.cross_attn_layers = torch.nn.ModuleList()
943
+ layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1
944
+ for _ in range(layers_to_add):
945
+ self.cross_attn_layers.append(
946
+ BLTCrossAttention(
947
+ dim=self.dim,
948
+ head_dim=self.dim // self.cross_attn_nheads,
949
+ n_heads=self.cross_attn_nheads,
950
+ n_kv_heads=self.cross_attn_nheads,
951
+ norm_eps=config.norm_eps,
952
+ )
953
+ )
954
+
955
+ def apply_embedding(self, tokens, embeds):
956
+ if embeds is not None:
957
+ assert (
958
+ self.expects_hash_embeddings
959
+ ), "Not expecting embeddings to be passed."
960
+ return embeds
961
+ else:
962
+ return self.tok_embeddings(tokens)
963
+
964
+ def forward(
965
+ self,
966
+ tokens: torch.Tensor,
967
+ embeds: Optional[torch.Tensor] = None,
968
+ patch_embeds: Optional[torch.Tensor] = None,
969
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
970
+ cross_mask: Optional[torch.Tensor] = None,
971
+ num_patches: Optional[int] = None,
972
+ patch_ids: Optional[torch.Tensor] = None,
973
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
974
+ ):
975
+ """ """
976
+ bs, seqlen = tokens.shape
977
+ if mask is None:
978
+ mask = create_causal_mask(
979
+ seqlen,
980
+ self.attn_impl,
981
+ "local_block_causal",
982
+ sliding_window=self.sliding_window,
983
+ tokens=tokens,
984
+ eos_id=self.eos_id,
985
+ )
986
+
987
+ h = self.apply_embedding(tokens, embeds)
988
+ freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
989
+
990
+ h = F.dropout(h, p=self.dropout, training=self.training)
991
+
992
+ for i, layer in enumerate(self.layers):
993
+ h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
994
+ # apply cross attention either after every layer or only after the last layer
995
+ if self.cross_attn_encoder and (
996
+ i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
997
+ ):
998
+ # apply pooling and project
999
+ if self.cross_attn_init_by_pooling and patch_embeds is None:
1000
+ patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids)
1001
+ if self.patch_embedding_projection is not None:
1002
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
1003
+ patch_embeds = patch_embeds.reshape(
1004
+ bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
1005
+ )
1006
+
1007
+ layer_idx = i if self.cross_attn_all_layers_encoder else 0
1008
+ patch_embeds_cross = self.cross_attn_layers[layer_idx](
1009
+ x=patch_embeds,
1010
+ kv=h,
1011
+ mask=cross_mask,
1012
+ )
1013
+ patch_embeds = patch_embeds + patch_embeds_cross
1014
+
1015
+ h_residual = patch_embeds if self.cross_attn_encoder else None
1016
+ return (h, h_residual), cache
1017
+
1018
+
1019
+
1020
+ def patch_reduce(self, h, max_num_patches, reduction, patch_ids):
1021
+ """
1022
+ Reduce variable length patches to single embedding per patch
1023
+ Note: this works with variable number of patches for different sequences in the batch
1024
+ It handles variable length patches by assuming that patch_lengths will be 0 for any
1025
+ extra patches on the *right*.
1027
+ Any embeddings on the right that are not allocated to a patch
1028
+ (i.e. if the sum(patch_lengths[i]) < seq_len for any i)
1029
+ will be sent to a dummy patch, which is trimmed before returning.
1030
+ """
1031
+ bs, seq_len, emb_dim = h.shape
1032
+
1033
+ patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
1034
+
1035
+ reduced_embs = torch.zeros(
1036
+ (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device
1037
+ )
1038
+ reduced_embs = reduced_embs.scatter_reduce(
1039
+ src=h,
1040
+ dim=1,
1041
+ index=patch_ids,
1042
+ reduce=reduction,
1043
+ include_self=False,
1044
+ )
1045
+ reduced_embs = reduced_embs[:, :max_num_patches, :]
1046
+
1047
+ return reduced_embs
1048
+
1049
+
1050
+ class LocalDecoder(LocalModelBase):
1051
+ def __init__(self, config: BLTConfig):
1052
+ super().__init__(config, component_type="decoder")
1053
+
1054
+ # Model configuration flags
1055
+ self.cross_attn_decoder = config.cross_attn_decoder
1056
+ self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder
1057
+ self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling
1058
+ self.cross_attn_nheads = config.cross_attn_nheads
1059
+
1060
+ self.norm = RMSNorm(self.dim, eps=config.norm_eps)
1061
+
1062
+ if self.cross_attn_decoder:
1063
+ self.cross_attn_layers = torch.nn.ModuleList()
1064
+ layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1
1065
+ for _ in range(layers_to_add):
1066
+ self.cross_attn_layers.append(
1067
+ BLTCrossAttention(
1068
+ dim=self.dim,
1069
+ head_dim=self.dim // self.cross_attn_nheads,
1070
+ n_heads=self.cross_attn_nheads,
1071
+ n_kv_heads=self.cross_attn_nheads,
1072
+ norm_eps=config.norm_eps,
1073
+ )
1074
+ )
1075
+
1076
+ self.output = nn.Linear(
1077
+ self.dim,
1078
+ config.vocab_size,
1079
+ bias=False,
1080
+ )
1081
+
1082
+ def forward(
1083
+ self,
1084
+ tokens: torch.Tensor,
1085
+ embeds: Optional[torch.Tensor],
1086
+ patch_embeds: Optional[torch.Tensor] = None,
1087
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
1088
+ cross_mask: Optional[torch.Tensor] = None,
1089
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
1090
+ ):
1091
+ bs, seqlen = tokens.shape
1092
+ assert embeds is not None, "Embeddings must be provided"
1093
+
1094
+ if mask is None:
1095
+ mask = create_causal_mask(
1096
+ seqlen,
1097
+ self.attn_impl,
1098
+ "local_block_causal",
1099
+ sliding_window=self.sliding_window,
1100
+ tokens=tokens,
1101
+ eos_id=self.eos_id,
1102
+ )
1103
+
1104
+ h = embeds
1105
+
1106
+ if self.patch_embedding_projection is not None:
1107
+ assert patch_embeds is not None, "Patch embeddings must be passed."
1108
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
1109
+ if self.cross_attn_k is not None:
1110
+ patch_embeds = patch_embeds.reshape(
1111
+ bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
1112
+ )
1113
+
1114
+ if patch_embeds is not None and not self.cross_attn_decoder:
1115
+ h = h + patch_embeds
1116
+
1117
+ freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
1118
+
1119
+ h = F.dropout(h, p=self.dropout, training=self.training)
1120
+ for i, layer in enumerate(self.layers):
1121
+ if self.cross_attn_decoder and (
1122
+ i == 0 or self.cross_attn_all_layers_decoder
1123
+ ):
1124
+ # Use cross attention to extract info from patch_embeds into h
1125
+ h_cross = self.cross_attn_layers[i](
1126
+ x=h,
1127
+ kv=patch_embeds,
1128
+ mask=cross_mask,
1129
+ )
1130
+ h = h + h_cross
1131
+
1132
+ h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
1133
+
1134
+ h_preds = self.norm(h)
1135
+ h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
1136
+ h_preds = self.output(h_preds)
1137
+ h_preds = h_preds.float()
1138
+ return h_preds, cache
1139
+
1140
+
1141
+ class BLTCrossAttention(nn.Module):
1142
+ """
1143
+ Cross-attention block between the byte stream and the patch stream: the local encoder
+ uses patch queries over byte states, and the local decoder uses byte queries over patch states.
1144
+ Rope is not supported.
1145
+ """
1146
+
1147
+ def __init__(
1148
+ self,
1149
+ dim: int,
1150
+ head_dim: int,
1151
+ n_heads: int,
1152
+ n_kv_heads: int,
1153
+ norm_eps: float,
1154
+ ):
1155
+ super().__init__()
1156
+
1157
+ self.dim = dim
1158
+ self.head_dim = head_dim
1159
+
1160
+ self.n_heads = n_heads
1161
+ self.n_kv_heads = n_kv_heads
1162
+ self.heads_per_group = self.n_heads // self.n_kv_heads
1163
+
1164
+ self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
1165
+ self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
1166
+
1167
+ self.wq = nn.Linear(
1168
+ dim,
1169
+ n_heads * head_dim,
1170
+ bias=False,
1171
+ )
1172
+ self.wk = nn.Linear(
1173
+ dim,
1174
+ n_kv_heads * head_dim,
1175
+ bias=False,
1176
+ )
1177
+ self.wv = nn.Linear(
1178
+ dim,
1179
+ n_kv_heads * head_dim,
1180
+ bias=False,
1181
+ )
1182
+
1183
+ self.wo = nn.Linear(
1184
+ n_heads * head_dim,
1185
+ dim,
1186
+ bias=False,
1187
+ )
1188
+
1189
+ def forward(
1190
+ self,
1191
+ x: torch.Tensor,
1192
+ kv: torch.Tensor,
1193
+ mask: Optional[Union[BlockMask, str]] = None,
1194
+ ) -> torch.Tensor:
1195
+ # B S D
1196
+ bsz, seq_len, _ = x.shape
1197
+ _, slen_kv, _ = kv.shape
1198
+ x_norm = self.cross_attn_norm_q(x)
1199
+ kv = self.cross_attn_norm_kv(kv)
1200
+
1201
+ xq = self.wq(x_norm)
1202
+ xk = self.wk(kv)
1203
+ xv = self.wv(kv)
1204
+
1205
+ output_shape = xq.shape
1206
+ # B S D -> B S H D
1207
+ xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
1208
+ xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
1209
+ xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
1210
+
1211
+ xk = repeat_kv(xk, self.heads_per_group, dim=2)
1212
+ xv = repeat_kv(xv, self.heads_per_group, dim=2)
1213
+
1214
+ # assert mask is None or isinstance(mask, BlockMask)
1215
+ xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
1216
+ #output = flex_attention_comp(xq, xk, xv, block_mask=mask)
1217
+ is_causal = (mask == "causal") if isinstance(mask, str) else False
1218
+ mask = mask if isinstance(mask, torch.Tensor) else None
1219
+ mask = mask.to(dtype=xq.dtype, device=xq.device) if mask is not None else None
1220
+ output = F.scaled_dot_product_attention(
1221
+ xq,
1222
+ xk,
1223
+ xv,
1224
+ is_causal=is_causal,
1225
+ attn_mask=mask,
1226
+ )
1227
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
1228
+
1229
+ output = self.wo(output.reshape(output_shape))
1230
+
1231
+ return x + output
1232
+
1233
+ def init_weights(self, base_std: float, factor: float = 1.0):
1234
+ std = base_std or (self.dim ** (-0.5)) / factor
1235
+
1236
+ nn.init.trunc_normal_(
1237
+ self.wq.weight,
1238
+ mean=0.0,
1239
+ std=std,
1240
+ a=-3 * std,
1241
+ b=3 * std,
1242
+ )
1243
+
1244
+ nn.init.trunc_normal_(
1245
+ self.wk.weight,
1246
+ mean=0.0,
1247
+ std=std,
1248
+ a=-3 * std,
1249
+ b=3 * std,
1250
+ )
1251
+
1252
+ nn.init.trunc_normal_(
1253
+ self.wv.weight,
1254
+ mean=0.0,
1255
+ std=std,
1256
+ a=-3 * std,
1257
+ b=3 * std,
1258
+ )
1259
+
1260
+ nn.init.trunc_normal_(
1261
+ self.wo.weight,
1262
+ mean=0.0,
1263
+ std=std,
1264
+ a=-3 * std,
1265
+ b=3 * std,
1266
+ )
1267
+ self.cross_attn_norm_q.reset_parameters()
1268
+ self.cross_attn_norm_kv.reset_parameters()
1269
+
1270
+
1271
+ class GlobalTransformer(nn.Module):
1272
+ def __init__(self, config):
1273
+ super().__init__()
1274
+
1275
+ # Store config for later use
1276
+ self.config = config
1277
+
1278
+ self.dim = config.dim
1279
+ self.init_base_std = config.init_base_std
1280
+ self.attn_impl = config.attn_impl
1281
+ self.attn_bias_type = config.attn_bias_type
1282
+ self.init_std_factor = config.init_std_factor
1283
+ self.max_seqlen = config.max_seqlen
1284
+ self.rope_embeddings = RotaryEmbedding(
1285
+ theta=config.rope_theta,
1286
+ head_dim=config.head_dim or config.dim // config.n_heads,
1287
+ max_seqlen=config.max_seqlen,
1288
+ rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product,
1289
+ )
1290
+ # Handle both eos_id and eos_token_id for compatibility
1291
+ self.eos_id = getattr(config, 'eos_id', getattr(config, 'eos_token_id', 2))
1292
+
1293
+ # Create parameter dict for BLTTransformerLayers
1294
+ layer_params = {
1295
+ 'dim': self.dim,
1296
+ 'n_heads': config.n_heads,
1297
+ 'head_dim': config.head_dim,
1298
+ 'n_kv_heads': getattr(config, 'n_kv_heads', None),
1299
+ 'rope_theta': config.rope_theta,
1300
+ 'multiple_of': getattr(config, 'multiple_of', 256),
1301
+ 'ffn_dim_multiplier': getattr(config, 'ffn_dim_multiplier', None),
1302
+ 'norm_eps': config.norm_eps,
1303
+ }
1304
+
1305
+ self.layers = nn.ModuleList()
1306
+ for _ in range(config.n_layers):
1307
+ self.layers.append(BLTTransformerLayer(layer_params))
1308
+
1309
+ # GlobalTransformer specific attributes
1310
+ self.dropout = config.dropout
1311
+ self.dim_token_emb = config.dim_token_emb
1312
+
1313
+ self.token_embedding_projection = None
1314
+ if config.dim_token_emb is not None and config.dim_token_emb != self.dim:
1315
+ self.token_embedding_projection = nn.Linear(
1316
+ config.dim_token_emb,
1317
+ config.dim,
1318
+ bias=False,
1319
+ )
1320
+
1321
+ def forward(
1322
+ self,
1323
+ tokens: torch.Tensor,
1324
+ tok_idx: Optional[torch.Tensor] = None,
1325
+ embeds: Optional[torch.Tensor] = None,
1326
+ mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
1327
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
1328
+ ):
1329
+ bs, seqlen = tokens.shape
1330
+
1331
+ h = embeds
1332
+
1333
+ mask = (
1334
+ mask
1335
+ if mask is not None
1336
+ else create_causal_mask(
1337
+ seqlen,
1338
+ self.attn_impl,
1339
+ self.attn_bias_type,
1340
+ tokens=tokens,
1341
+ eos_id=self.eos_id,
1342
+ )
1343
+ )
1344
+
1345
+ if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
1346
+ h = self.token_embedding_projection(h)
1347
+
1348
+ h = F.dropout(h, p=self.dropout, training=self.training)
1349
+
1350
+ freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
1351
+
1352
+ for i, layer in enumerate(self.layers):
1353
+ h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
1354
+
1355
+ return h, cache
1356
+
1357
+ def init_weights(self):
1358
+ self.rope_embeddings.reset_parameters()
1359
+ for depth, layer in enumerate(self.layers):
1360
+ factor = self.config.get_init_std_factor(depth)
1361
+ layer.init_weights(self.init_base_std, factor)
1362
+
1363
+ # GlobalTransformer specific initialization
1364
+ std = self.dim_token_emb ** (-0.5)
1365
+ if self.token_embedding_projection is not None:
1366
+ nn.init.trunc_normal_(
1367
+ self.token_embedding_projection.weight,
1368
+ mean=0.0,
1369
+ std=std,
1370
+ a=-3 * std,
1371
+ b=3 * std,
1372
+ )
1373
+
1374
+ def compute_hash_embeddings(
1375
+ local_encoder_tokens: torch.Tensor,
1376
+ local_encoder,
1377
+ encoder_hash_tok_embedding: nn.ModuleList,
1378
+ encoder_hash_byte_group_nb_functions: int,
1379
+ encoder_hash_byte_group_size: list,
1380
+ encoder_hash_byte_group_vocab: int,
1381
+ ) -> torch.Tensor:
1382
+ """
1383
+ Compute embeddings using hash token embeddings.
1384
+
1385
+ Args:
1386
+ local_encoder_tokens: Input tokens tensor
1387
+ local_encoder: Encoder object with tok_embeddings method
1388
+ encoder_hash_tok_embedding: ModuleList of hash token embeddings
1389
+ encoder_hash_byte_group_nb_functions: Number of hash functions
1390
+ encoder_hash_byte_group_size: List of byte group sizes
1391
+ encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings
1392
+
1393
+ Returns:
1394
+ torch.Tensor: Combined embeddings
1395
+ """
1396
+ if encoder_hash_tok_embedding is None:
1397
+ return None
1398
+
1399
+ local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens)
1400
+
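+ # One hash-embedding table per (hash function, byte-group size) pair; each lookup is
+ # added onto the plain byte embeddings, so the module list must contain exactly
+ # nb_functions * len(group_sizes) embeddings (checked by the assert below).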
1401
+ i = 0
1402
+ for func_nb in range(encoder_hash_byte_group_nb_functions):
1403
+ for byte_group_size in encoder_hash_byte_group_size:
1404
+ hash_ids = byte_group_hash_function(
1405
+ local_encoder_tokens,
1406
+ byte_group_size,
1407
+ hash_func_nb=func_nb,
1408
+ max_hash=encoder_hash_byte_group_vocab,
1409
+ )
1410
+ hash_tok_embedding = encoder_hash_tok_embedding[i]
1411
+ local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
1412
+ i += 1
1413
+
1414
+ assert i == len(encoder_hash_tok_embedding)
1415
+ return local_encoder_embeds
1416
+
1417
+
1418
+ class BLTPreTrainedModel(PreTrainedModel):
1419
+ """
1420
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1421
+ BLT models.
1422
+
1423
+ This class provides the interface for model loading, saving, and weight initialization for all BLT model variants.
1424
+ It inherits from [`PreTrainedModel`] which provides the core functionality for working with HuggingFace models.
1425
+
1426
+ Args:
1427
+ config ([`BLTConfig`]): Model configuration class with all the parameters of the model.
1428
+ Initializing with a config file does not load the weights associated with the model, only the
1429
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1430
+ """
1431
+
1432
+ config_class = BLTConfig
1433
+ base_model_prefix = "model"
1434
+ supports_gradient_checkpointing = True
1435
+ _no_split_modules = ["BLTTransformerLayer", "LocalEncoder", "LocalDecoder", "GlobalTransformer"]
1436
+ _skip_keys_device_placement = ["past_key_values"]
1437
+ _supports_flash_attn_2 = False # BLT uses its own attention implementation
1438
+ _supports_sdpa = True
1439
+ _supports_cache_class = False
1440
+
1441
+ def _init_weights(self, module):
1442
+ """Initialize the weights - this is called by PreTrainedModel but we delegate to our custom init"""
1443
+ # Don't do anything here - we use the custom init_weights method instead
1444
+ pass
1445
+
1446
+
1447
+ class BLTModel(BLTPreTrainedModel):
1448
+ """
1449
+ The BLTModel (BLT) is a byte-level language model architecture that processes byte sequences
1450
+ by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
1451
+ and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for
1452
+ improved performance and inference efficiency.
1453
+ """
1454
+
1455
+ def __init__(self, config: BLTConfig):
1456
+ super().__init__(config)
1457
+
1458
+ # Store config reference
1459
+ self.config = config
1460
+
1461
+ # Create main components - they will read their parameters from config
1462
+ self.local_encoder = LocalEncoder(config)
1463
+
1464
+ # Create global-specific config by copying config and overriding dimensions
1465
+ global_config = type(config)(**config.to_dict())
1466
+ global_config.dim = config.dim_global
1467
+ global_config.n_layers = config.n_layers_global
1468
+ global_config.n_heads = config.n_heads_global
1469
+ global_config.n_kv_heads = config.n_kv_heads_global
1470
+ global_config.dim_token_emb = config.global_dim_patch_emb
1471
+
1472
+ self.global_transformer = GlobalTransformer(global_config)
1473
+ self.local_decoder = LocalDecoder(config)
1474
+
1475
+ # Initialize hash embeddings
1476
+ self.encoder_hash_tok_embedding = init_hash_embeddings(
1477
+ config,
1478
+ local_encoder_dim=self.local_encoder.dim,
1479
+ encoder_hash_byte_group_size=config.encoder_hash_byte_group_size,
1480
+ )
1481
+
1482
+ # Initialize patcher if needed
1483
+ if config.patch_in_forward:
1484
+ if config.realtime_patching and config.entropy_model_checkpoint_dir is not None:
1485
+ # Load entropy model directly
1486
+ entropy_model_checkpoint_dir = config.entropy_model_checkpoint_dir
1487
+
1488
+ if not os.path.exists(entropy_model_checkpoint_dir):
1489
+ raise FileNotFoundError(f"Entropy model checkpoint directory not found: {entropy_model_checkpoint_dir}")
1490
+
1491
+ # Load entropy model parameters
1492
+ params_path = os.path.join(entropy_model_checkpoint_dir, "params.json")
1493
+ if not os.path.exists(params_path):
1494
+ raise FileNotFoundError(f"params.json not found in: {entropy_model_checkpoint_dir}")
1495
+
1496
+ with open(params_path) as fr:
1497
+ reloaded = json.loads(fr.read())
1498
+
1499
+ torch.set_default_dtype(torch.bfloat16)
1500
+ model_params = reloaded["entropy_model"]
1501
+ logger.warning(
1502
+ "Update checkpoint to load attn and sliding window args from checkpoint"
1503
+ )
1504
+
1505
+ # Override patcher configuration with actual entropy model parameters from checkpoint
1506
+ config.patcher_dim = model_params["dim"]
1507
+ config.patcher_n_layers = model_params["n_layers"]
1508
+ config.patcher_n_heads = model_params["n_heads"]
1509
+ config.patcher_max_seqlen = model_params["max_seqlen"]
1510
+ config.patcher_ffn_dim_multiplier = model_params["ffn_dim_multiplier"]
1511
+ config.patcher_vocab_size = model_params["vocab_size"]
1512
+ # Use sensible defaults for parameters not in checkpoint
1513
+ config.patcher_attn_bias_type = "local_block_causal"
1514
+ config.patcher_attn_impl = "sdpa" # originally xformers
1515
+ config.patcher_sliding_window = 512
1516
+
1517
+ # BLTPatcher will extract patcher_ parameters from config directly
1518
+ self.patcher = BLTPatcher(config)
1519
+
1520
+ state_path = os.path.join(
1521
+ entropy_model_checkpoint_dir, "consolidated.pth"
1522
+ )
1523
+
1524
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1525
+ self.patcher.load_state_dict(
1526
+ torch.load(state_path, map_location=device)["model"], strict=False
1527
+ )
1528
+ self.patcher.to(device)
1529
+ self.patcher = self.patcher.eval()
1530
+ # no grads for the model:
1531
+ for param in self.patcher.parameters():
1532
+ param.requires_grad = False
1533
+ else:
1534
+ self.patcher = None
1535
+
1536
+ # Initialize weights and apply final processing
1537
+ self.post_init()
1538
+
1539
+ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
1540
+ """
1541
+ Convert patch lengths to patch IDs for each token position.
1542
+
1543
+ For each token position in the sequence, determines which patch it belongs to.
1544
+
1545
+ Args:
1546
+ patch_lengths: [batch_size, num_patches] - length of each patch
1547
+ seq_len: total sequence length
1548
+
1549
+ Returns:
1550
+ patch_ids: [batch_size, seq_len] - patch index for each token position
1551
+
1552
+ Example:
1553
+ patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1
1554
+ seq_len = 10
1555
+ Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]]
1556
+ # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3
1557
+ """
1558
+ batch_size, num_patches = patch_lengths.shape
1559
+
1560
+ # Create patch start positions: [0, 3, 5, 9] for the example above
1561
+ patch_starts = torch.cat([
1562
+ torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
1563
+ patch_lengths.cumsum(dim=-1)[:, :-1] # cumsum without the final total
1564
+ ], dim=-1)
1565
+
1566
+ # For each token position, find which patch it belongs to
1567
+ # by finding the rightmost patch start that's <= the position
1568
+ token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1]
1569
+
1570
+ # Broadcasting: patch_starts[batch, patch] <= token_positions[position]
1571
+ # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t
1572
+ position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)
1573
+
1574
+ # Count how many patch starts are <= each position, then subtract 1 to get patch index
1575
+ patch_ids = position_ge_patch_start.sum(dim=-1) - 1
1576
+
1577
+ return patch_ids
1578
+
1579
+ def _decoder_patch_ids_from_lengths(self, patch_lengths: torch.Tensor, nb_boe: int, seq_len: int) -> torch.Tensor:
1580
+ """
1581
+ Create decoder patch IDs by skipping the first encoder patch.
1582
+
1583
+ The decoder starts after the first patch (which contains BOE tokens),
1584
+ so we need to map decoder positions to the remaining patches.
1585
+
1586
+ Args:
1587
+ patch_lengths: [batch_size, num_patches] from encoder
1588
+ nb_boe: number of beginning-of-example tokens in first patch
1589
+ seq_len: decoder sequence length
1590
+
1591
+ Returns:
1592
+ decoder_patch_ids: [batch_size, seq_len] mapping decoder positions to patch indices
1593
+ """
1594
+ # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens)
1595
+ decoder_patch_lengths = patch_lengths[:, 1:]
1596
+
1597
+ # Create patch IDs for the decoder sequence using the remaining patches
1598
+ return self._patch_ids_from_lengths(decoder_patch_lengths, seq_len)
1599
+
1600
+
1601
+
1602
+ def forward(
1603
+ self,
1604
+ tokens: torch.Tensor,
1605
+ patch_lengths: Optional[torch.Tensor] = None,
1606
+ ):
1607
+ # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings
1608
+ # are no longer used in the final BLT model
1609
+
1610
+ bs, N = tokens.shape # Batch size and sequence length
1611
+
1612
+ # Get megabyte inputs
1613
+ nb_boe = int(0 if self.config.patching_mode != "" else self.config.patch_size - 1)
1614
+ local_encoder_tokens, _, local_decoder_tokens = get_blt_input(
1615
+ tokens=tokens,
1616
+ enforce_patch_size_multiple=False,
1617
+ nb_boe=nb_boe,
1618
+ patch_size=self.config.patch_size,
1619
+ boe_id=BOE_ID,
1620
+ )
1621
+
1622
+ # Patching
1623
+ if patch_lengths is None:
1624
+ # assert (
1625
+ # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward
1626
+ # ), "Patch in forward not enabled and no patch_lengths passed."
1627
+
1628
+ # PATCHER MODEL DEFINED
1629
+ if self.config.patching_mode == PatchingModeEnum.entropy:
1630
+ _, patch_lengths, _ = self.patcher(
1631
+ local_encoder_tokens,
1632
+ patch_size=self.config.patch_size,
1633
+ include_next_token=True,
1634
+ threshold=self.config.patching_threshold,
1635
+ threshold_add=self.config.patching_threshold_add,
1636
+ monotonicity=self.config.monotonicity,
1637
+ max_patch_length=self.config.max_patch_length,
1638
+ patching_batch_size=self.config.patching_batch_size,
1639
+ device=self.config.patching_device,
1640
+ )
1641
+ else:
1642
+ # self.config.patching_mode == PatchingModeEnum.byte
1643
+ bs, seq_len = local_encoder_tokens.shape
1644
+ seq_len_next_tok = seq_len + 1 # include_next_token=True
1645
+ patch_lengths = torch.ones(
1646
+ (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device
1647
+ )
1648
+
1649
+ # Apply any processing to patch lengths
1650
+ if self.config.max_patch_length is not None:
1651
+ # TODO: avoid going back to a list here.
1652
+ patch_lengths = [
1653
+ BLTPatcher.split_large_numbers(pl, self.config.max_patch_length)
1654
+ for pl in patch_lengths.tolist()
1655
+ ]
1656
+ max_len = max([len(pl) for pl in patch_lengths])
1657
+ patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
1658
+ patch_lengths = torch.tensor(
1659
+ patch_lengths, dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device
1660
+ )
1661
+ assert not check_non_zero_after_zero(patch_lengths)
1662
+ # Find the last non-zero column index using argmax on a reversed version of the tensor
1663
+ last_non_zero_col_reversed = (
1664
+ (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
1665
+ )
1666
+ # Slice the tensor up to the last non-zero column
1667
+ patch_lengths = patch_lengths[
1668
+ :, : patch_lengths.shape[1] - last_non_zero_col_reversed
1669
+ ]
1670
+ else:
1671
+ if nb_boe > 0:
1672
+ patch_lengths[:, 0] += nb_boe
1673
+
1674
+ assert torch.min(patch_lengths) >= 0
1675
+
1676
+ # Generate patch IDs from patch_lengths
1677
+ patch_ids = self._patch_ids_from_lengths(
1678
+ patch_lengths, local_encoder_tokens.shape[-1]
1679
+ )
1680
+ assert torch.max(patch_ids) + 1 <= torch.max(
1681
+ (patch_lengths != 0).sum(dim=-1)
1682
+ ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"
1683
+
1684
+ cross_attn_mask_enc = None
1685
+ # Cross-attention encoder
1686
+ if self.config.cross_attn_encoder:
1687
+ cross_attn_mask_enc = cross_attn_mask(
1688
+ patch_ids,
1689
+ patch_lengths,
1690
+ N,
1691
+ patches_as_queries=True,
1692
+ cross_attn_k=self.config.cross_attn_k,
1693
+ window=self.config.cross_attn_window_encoder,
1694
+ block_mask=self.config.cross_attn_use_flex_attention,
1695
+ )
1696
+
1697
+ # Hashing and embedding
1698
+ local_encoder_embeds = compute_hash_embeddings(
1699
+ local_encoder_tokens=local_encoder_tokens,
1700
+ local_encoder=self.local_encoder,
1701
+ encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
1702
+ encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions,
1703
+ encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size,
1704
+ encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab,
1705
+ )
1706
+
1707
+ # NOTE: Frequency-based n-gram embeddings removed as per paper
1708
+ # The final BLT model uses only hash-based n-gram embeddings
1709
+
1710
+ # Local encoder
1711
+ (h_encoder, h_cross), cache_encoder = self.local_encoder(
1712
+ tokens=local_encoder_tokens,
1713
+ embeds=local_encoder_embeds,
1714
+ patch_embeds=None,
1715
+ cross_mask=cross_attn_mask_enc,
1716
+ num_patches=patch_lengths.shape[1],
1717
+ patch_ids=patch_ids,
1718
+ )
1719
+
1720
+ # Downsampling
1721
+ h = h_cross.view(bs, patch_lengths.shape[1], -1)
1722
+
1723
+ # Global transformer
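+ # Patch-level placeholder tokens: every patch is marked BOE, except patches that
+ # contain an EOS byte, which are marked EOS so the global mask construction sees
+ # document boundaries.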
1724
+ global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(BOE_ID)
1725
+ rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id)
1726
+ eos_patch_ids = patch_ids[rows, cols]
1727
+ global_tokens[rows, eos_patch_ids] = self.config.eos_token_id
1728
+
1729
+ h, _ = self.global_transformer(
1730
+ embeds=h,
1731
+ tokens=global_tokens,
1732
+ )
1733
+
1734
+ # Unpatching
1735
+ dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :]
1736
+
1737
+ # Generate decoder patch IDs
1738
+ decoder_patch_ids = self._decoder_patch_ids_from_lengths(
1739
+ patch_lengths, nb_boe, local_decoder_tokens.shape[-1]
1740
+ )
1741
+ assert (
1742
+ torch.max(decoder_patch_ids) + 1 <= h.shape[1]
1743
+ ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
1744
+ assert (
1745
+ decoder_patch_ids.shape[1] == dec_embeds.shape[1]
1746
+ ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"
1747
+
1748
+ # Cross-attention decoder
1749
+ if not self.config.cross_attn_decoder:
1750
+ h = torch.gather(
1751
+ h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
1752
+ )
1753
+ cross_attn_mask_dec = None
1754
+ assert local_decoder_tokens.shape == h.shape[:-1]
1755
+ else:
1756
+ cross_attn_mask_dec = cross_attn_mask(
1757
+ decoder_patch_ids,
1758
+ patch_lengths,
1759
+ N,
1760
+ patches_as_queries=False,
1761
+ cross_attn_k=self.config.cross_attn_k,
1762
+ window=self.config.cross_attn_window_decoder,
1763
+ block_mask=self.config.cross_attn_use_flex_attention,
1764
+ )
1765
+
1766
+ # Local decoder
1767
+ output, _ = self.local_decoder(
1768
+ embeds=dec_embeds,
1769
+ patch_embeds=h,
1770
+ tokens=local_decoder_tokens,
1771
+ cross_mask=cross_attn_mask_dec,
1772
+ )
1773
+ return output
1774
+
1775
+ def init_weights(self):
1776
+ self.local_encoder.init_weights()
1777
+ self.global_transformer.init_weights()
1778
+ self.local_decoder.init_weights()
1779
+
1780
+ if self.encoder_hash_tok_embedding is not None:
1781
+ emb_std = self.local_encoder.dim ** (-0.5)
1782
+ for emb in self.encoder_hash_tok_embedding:
1783
+ nn.init.trunc_normal_(
1784
+ emb.weight,
1785
+ mean=0.0,
1786
+ std=emb_std,
1787
+ a=-3 * emb_std,
1788
+ b=3 * emb_std,
1789
+ )
1790
+
1791
+
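+ # Minimal usage sketch (illustrative only; assumes a BLTConfig set up for byte-level
+ # patching so that no entropy-model checkpoint is required):
+ #   config = BLTConfig(patching_mode="byte", patch_in_forward=False)
+ #   model = BLTModel(config)
+ #   tokens = torch.randint(0, config.vocab_size, (1, 128), dtype=torch.long)
+ #   logits = model(tokens)  # [1, 128, vocab_size]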
1792
+ class BLTPatcher(BLTPreTrainedModel):
1793
+ def __init__(self, config):
1794
+ super().__init__(config)
1795
+
1796
+ # Store config reference for later use
1797
+ self.config = config
1798
+
1799
+ # Extract patcher parameters from BLTConfig
1800
+ self.dim = config.patcher_dim
1801
+ self.init_base_std = config.patcher_init_base_std
1802
+ self.attn_impl = config.patcher_attn_impl
1803
+ self.attn_bias_type = config.patcher_attn_bias_type
1804
+ self.init_std_factor = config.patcher_init_std_factor
1805
+ self.max_seqlen = config.patcher_max_seqlen
1806
+ n_layers = config.patcher_n_layers
1807
+ n_heads = config.patcher_n_heads
1808
+ head_dim = config.patcher_head_dim
1809
+ rope_theta = config.patcher_rope_theta
1810
+ rope_use_fp32_in_outer_product = config.patcher_rope_use_fp32_in_outer_product
1811
+ norm_eps = config.patcher_norm_eps
1812
+ vocab_size = config.patcher_vocab_size
1813
+ weight_tying = config.patcher_weight_tying
1814
+ sliding_window = config.patcher_sliding_window
1815
+ eos_token_id = config.patcher_eos_token_id
1816
+
1817
+ self.rope_embeddings = RotaryEmbedding(
1818
+ theta=rope_theta,
1819
+ head_dim=head_dim or self.dim // n_heads,
1820
+ max_seqlen=self.max_seqlen,
1821
+ rope_use_fp32_in_outer_product=rope_use_fp32_in_outer_product,
1822
+ )
1823
+ # Handle both eos_id and eos_token_id for compatibility
1824
+ self.eos_id = eos_token_id
1825
+
1826
+ # Extract additional parameters for BLTTransformerLayer
1827
+ n_kv_heads = getattr(config, 'patcher_n_kv_heads', None) if hasattr(config, 'patcher_dim') else getattr(config, 'n_kv_heads', None)
1828
+ multiple_of = getattr(config, 'patcher_multiple_of', 256) if hasattr(config, 'patcher_dim') else getattr(config, 'multiple_of', 256)
1829
+ ffn_dim_multiplier = getattr(config, 'patcher_ffn_dim_multiplier', None) if hasattr(config, 'patcher_dim') else getattr(config, 'ffn_dim_multiplier', None)
1830
+
1831
+ # Create a simple parameter dict for BLTTransformerLayer
1832
+ layer_params = {
1833
+ 'dim': self.dim,
1834
+ 'n_heads': n_heads,
1835
+ 'head_dim': head_dim,
1836
+ 'n_kv_heads': n_kv_heads,
1837
+ 'rope_theta': rope_theta,
1838
+ 'multiple_of': multiple_of,
1839
+ 'ffn_dim_multiplier': ffn_dim_multiplier,
1840
+ 'norm_eps': norm_eps,
1841
+ }
1842
+
1843
+ self.layers = nn.ModuleList()
1844
+ for _ in range(n_layers):
1845
+ self.layers.append(BLTTransformerLayer(layer_params))
1846
+
1847
+ # LMTransformer specific attributes
1848
+ self.weight_tying = weight_tying
1849
+ self.sliding_window = sliding_window
1850
+
1851
+ assert vocab_size > 0
1852
+
1853
+ self.tok_embeddings = torch.nn.Embedding(vocab_size, self.dim)
1854
+
1855
+ self.norm = RMSNorm(self.dim, eps=norm_eps)
1856
+
1857
+ self.output = nn.Linear(
1858
+ self.dim,
1859
+ vocab_size,
1860
+ bias=False,
1861
+ )
1862
+
1863
+ if self.weight_tying:
1864
+ self.output.weight = self.tok_embeddings.weight
1865
+
1866
+ def forward(
1867
+ self,
1868
+ token_values: torch.Tensor,
1869
+ target: Optional[torch.Tensor] = None,
1870
+ tok_idx: Optional[torch.Tensor] = None,
1871
+ mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
1872
+ attn_impl: str | None = None,
1873
+ patch_size: Optional[int] = None,
1874
+ include_next_token: bool = True,
1875
+ threshold: Optional[float] = None,
1876
+ threshold_add: Optional[float] = None,
1877
+ monotonicity: bool = False,
1878
+ max_patch_length: Optional[int] = None,
1879
+ patching_batch_size: int = 1, # Changed from Optional[int] = None to int = 1
1880
+ device: Optional[str] = None,
1881
+ enable_grad: bool = False,
1882
+ ):
1883
+ attn_impl = self.attn_impl if attn_impl is None else attn_impl
1884
+
1885
+ # Handle chunked processing for entropy calculation
1886
+ # grad_context = nullcontext() if enable_grad else torch.no_grad()
1887
+ # with grad_context:
1888
+ entropies = []
1889
+ preds = []
1890
+ max_length = min(getattr(self, "max_length", 8192), self.max_seqlen)
1891
+ batch_numel = max_length * patching_batch_size
1892
+ splits = torch.split(token_values.flatten(), batch_numel)
1893
+
1894
+ for split in splits:
1895
+ pad_size = (max_length - (split.numel() % max_length)) % max_length
1896
+ pad = torch.zeros(
1897
+ pad_size, dtype=split.dtype, device=split.device, requires_grad=False
1898
+ )
1899
+ split = torch.cat((split, pad), dim=0)
1900
+ split = split.reshape(-1, max_length)
1901
+ if device is not None:
1902
+ split = split.to(device)
1903
+
1904
+ # Process chunk: embeddings -> layers -> output
1905
+ bsz, seqlen = split.shape
1906
+ h = self.tok_embeddings(split)
1907
+ chunk_mask = create_causal_mask(
1908
+ seqlen,
1909
+ attn_impl,
1910
+ self.attn_bias_type,
1911
+ sliding_window=self.sliding_window,
1912
+ tokens=split,
1913
+ eos_id=self.eos_id,
1914
+ )
1915
+ freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None)
1916
+
1917
+ for i, layer in enumerate(self.layers):
1918
+ h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=attn_impl)
1919
+
1920
+ pred = self.output(self.norm(h))
1921
+ pred = pred.reshape(-1, pred.shape[-1])[
1922
+ : split.numel() - pad_size, :
1923
+ ] # [batch_size * seq_len, vocab]
1924
+ preds.append(pred)
1925
+ pred_entropies = self.entropy(pred)
1926
+ entropies.append(pred_entropies)
1927
+
1928
+ concat_entropies = torch.cat(entropies, dim=0)
1929
+ concat_entropies = concat_entropies.reshape(token_values.shape)
1930
+ concat_preds = torch.cat(preds, dim=0)
1931
+ concat_preds = concat_preds.reshape(token_values.shape[0], -1)
1932
+
1933
+ # Always compute patch lengths from concatenated entropies
1934
+ bs, seq_len = token_values.shape
1935
+ seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
1936
+
1937
+ # Find patch start IDs based on entropy
1938
+ if patch_size is not None:
1939
+ patch_start_ids = self.find_entropy_patch_start_ids(
1940
+ concat_entropies,
1941
+ patch_size,
1942
+ include_next_token=include_next_token,
1943
+ threshold=threshold,
1944
+ threshold_add=threshold_add,
1945
+ monotonicity=monotonicity,
1946
+ )
1947
+ patch_lengths = self.patch_lengths_from_start_ids(
1948
+ patch_start_ids, seq_len_next_tok
1949
+ )
1950
+ else:
1951
+ # Default to byte-level patching
1952
+ patch_lengths = torch.ones(
1953
+ (bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device
1954
+ )
1955
+
1956
+ # Apply any processing to patch lengths
1957
+ if max_patch_length is not None:
1958
+ # TODO: avoid going back to a list here.
1959
+ patch_lengths = [
1960
+ self.split_large_numbers(pl, max_patch_length)
1961
+ for pl in patch_lengths.tolist()
1962
+ ]
1963
+ max_len = max([len(pl) for pl in patch_lengths])
1964
+ patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
1965
+ patch_lengths = torch.tensor(
1966
+ patch_lengths, dtype=token_values.dtype, device=token_values.device
1967
+ )
1968
+ assert not check_non_zero_after_zero(patch_lengths)
1969
+ # Find the last non-zero column index using argmax on a reversed version of the tensor
1970
+ last_non_zero_col_reversed = (
1971
+ (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
1972
+ )
1973
+ # Slice the tensor up to the last non-zero column
1974
+ patch_lengths = patch_lengths[
1975
+ :, : patch_lengths.shape[1] - last_non_zero_col_reversed
1976
+ ]
1977
+
1978
+ return concat_entropies, patch_lengths, concat_preds
1979
+
1980
+ def reset_parameters(self, init_std=None):
1981
+ self.norm.reset_parameters()
1982
+
1983
+ def init_weights(self):
1984
+ self.reset_parameters()
1985
+ init_std = self.dim ** (-0.5)
1986
+ nn.init.trunc_normal_(
1987
+ self.tok_embeddings.weight,
1988
+ mean=0.0,
1989
+ std=init_std,
1990
+ a=-3 * init_std,
1991
+ b=3 * init_std,
1992
+ )
1993
+
1994
+ self.rope_embeddings.reset_parameters()
1995
+ for depth, layer in enumerate(self.layers):
1996
+ factor = self.config.get_init_std_factor(depth)
1997
+ layer.init_weights(self.init_base_std, factor)
1998
+
1999
+ if not self.weight_tying:
2000
+ nn.init.trunc_normal_(
2001
+ self.output.weight,
2002
+ mean=0.0,
2003
+ std=init_std,
2004
+ a=-3 * init_std,
2005
+ b=3 * init_std,
2006
+ )
2007
+
2008
+ @staticmethod
2009
+ def entropy(scores):
2010
+ """
2011
+ scores: [bs, seq_len, vocab]
2012
+ returns [bs, seq_len]
2013
+
2014
+ Computes the entropy for each token in the batch.
2015
+ Note: uses natural log.
2016
+ """
2017
+ log_probs = F.log_softmax(scores, dim=-1)
2018
+ probs = torch.exp(log_probs)
2019
+ p_log_p = log_probs * probs
2020
+ entropy = -p_log_p.sum(dim=-1)
2021
+ return entropy
2022
+
2023
+
2024
+
2025
+ @staticmethod
2026
+ def patch_start_ids_from_patch_start_mask(patch_start_mask):
2027
+ bs, trunc_seq_len = patch_start_mask.shape
2028
+ max_patches = patch_start_mask.sum(dim=1).max()
2029
+ if max_patches == 0:
2030
+ patch_start_ids = torch.full(
2031
+ (bs, trunc_seq_len),
2032
+ trunc_seq_len,
2033
+ dtype=torch.long,
2034
+ device=patch_start_mask.device,
2035
+ )
2036
+ else:
2037
+ patch_ids = (
2038
+ torch.arange(trunc_seq_len, device=patch_start_mask.device)
2039
+ .unsqueeze(0)
2040
+ .repeat(bs, 1)
2041
+ )
2042
+ extra_patch_ids = torch.full(
2043
+ (bs, trunc_seq_len),
2044
+ trunc_seq_len,
2045
+ dtype=torch.long,
2046
+ device=patch_start_mask.device,
2047
+ )
2048
+ all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
2049
+ patch_start_mask_padded = torch.cat(
2050
+ (patch_start_mask, ~patch_start_mask), dim=1
2051
+ )
2052
+ patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(
2053
+ bs, trunc_seq_len
2054
+ )[:, :max_patches]
2055
+ return patch_start_ids
2056
+
2057
+ @staticmethod
2058
+ def patch_lengths_from_start_ids(patch_start_ids, seq_len):
2059
+ """
2060
+ Calculate patch lengths from start ids.
2061
+ patch_start_ids: e.g. [0, 1, 7, 7, 7, 7, 7]; the leading entries are the patch start ids (here 0 and 1),
2062
+ and the remaining slots are padded with the sequence length.
2063
+ seq_len: e.g. 7, the length of the sequence.
2064
+
2065
+ Returns the patch lengths, e.g. [1, 6, 0, 0, 0, 0, 0] for the example above
2066
+ (trailing zeros come from the padded start ids).
2067
+ """
2068
+ last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
2069
+ patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
2070
+ patch_lengths = patch_end_ids - patch_start_ids + 1
2071
+ assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
2072
+ assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
2073
+ return patch_lengths
2074
+
2075
+ @staticmethod
2076
+ def find_entropy_patch_start_ids(
2077
+ entropies,
2078
+ patch_size=None,
2079
+ threshold=None,
2080
+ threshold_add=None,
2081
+ monotonicity=False,
2082
+ include_next_token=True,
2083
+ ):
2084
+ """
2085
+ Use entropies to find the start ids of each patch.
2086
+ Use patch_size or threshold to determine the total number of patches to allocate.
2087
+
2088
+ When threshold is not None, the number of patches varies between
2089
+ sequences, but patch boundaries can be identified incrementally rather than
2090
+ decided globally from the entire sequence.
2091
+ """
2092
+ bs, seq_len = entropies.shape[:2]
2093
+
2094
+ first_ids = (
2095
+ torch.tensor([0, 1], dtype=torch.long, device=entropies.device)
2096
+ .unsqueeze(0)
2097
+ .repeat(bs, 1)
2098
+ )
2099
+ preds_truncation_len = first_ids.shape[
2100
+ 1
2101
+ ] # remove the first preds because they will be start of patches.
2102
+ entropies = entropies[:, 1:]
2103
+ if threshold is None:
2104
+ num_patches = seq_len // patch_size
2105
+ patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices
2106
+ patch_start_ids = patch_start_ids.sort(dim=1).values
2107
+ else:
2108
+ patch_start_mask = entropies > threshold
2109
+ if not include_next_token:
2110
+ patch_start_mask = patch_start_mask[:, :-1]
2111
+ # patch_start_mask[1:] |= tokens[:-1] < OFFSET
2112
+ patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask)
2113
+
2114
+ patch_start_ids = torch.cat(
2115
+ (first_ids, patch_start_ids + preds_truncation_len), dim=1
2116
+ )
2117
+ return patch_start_ids
2118
+
2119
+ @staticmethod
2120
+ def split_large_numbers(lst, m):
2121
+ new_lst = []
2122
+ for i in lst:
2123
+ if i > m:
2124
+ while i > m:
2125
+ new_lst.append(m)
2126
+ i -= m
2127
+ new_lst.append(i)
2128
+ else:
2129
+ new_lst.append(i)
2130
+ assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}"
2131
+ return new_lst
2132
+
2133
+
2134
+ def init_hash_embeddings(
2135
+ config,
2136
+ local_encoder_dim: int,
2137
+ encoder_hash_byte_group_size: list,
2138
+ ):
2139
+ """Initialize hash-based token embeddings for the BLT encoder."""
2140
+ if config.encoder_hash_byte_group_size is None:
2141
+ return None
2142
+
2143
+ embeddings = []
2144
+ emb_dim = local_encoder_dim
2145
+ encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab
2146
+
2147
+ for _ in range(config.encoder_hash_byte_group_nb_functions):
2148
+ for _ in encoder_hash_byte_group_size:
2149
+ embeddings.append(
2150
+ nn.Embedding(
2151
+ encoder_hash_byte_group_vocab,
2152
+ emb_dim,
2153
+ )
2154
+ )
2155
+
2156
+ return nn.ModuleList(embeddings)
2157
+
2158
+
2159
+ __all__ = [
2160
+ "BLTPreTrainedModel",
2161
+ "BLTModel",
2162
+ "BLTPatcher",
2163
+ "LocalEncoder",
2164
+ "LocalDecoder",
2165
+ "GlobalTransformer",
2166
+ ]
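For reference, the entropy-based patching implemented by `BLTPatcher` above reduces to three steps: compute per-token entropy from the logits, turn positions whose entropy exceeds a threshold into patch start ids (positions 0 and 1 are always forced to be starts), and convert start ids into patch lengths. The following is a minimal, self-contained sketch of that pipeline on toy tensors; the helper names are illustrative, not part of the module, and it skips the truncation, padding, and `max_patch_length` handling done in the file above.

```python
import torch
import torch.nn.functional as F


def token_entropy(scores: torch.Tensor) -> torch.Tensor:
    """Per-token entropy in nats (mirrors BLTPatcher.entropy). scores: [bs, seq_len, vocab]."""
    log_probs = F.log_softmax(scores, dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1)


def lengths_from_start_ids(patch_start_ids: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Mirrors BLTPatcher.patch_lengths_from_start_ids: each patch ends right before the
    next start id (or at seq_len - 1); start ids padded with seq_len yield length 0."""
    last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
    patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
    return patch_end_ids - patch_start_ids + 1


# Toy example: batch of 1, 8 byte positions, 260-entry byte vocab (256 bytes + 4 special ids).
torch.manual_seed(0)
logits = torch.randn(1, 8, 260)
entropies = token_entropy(logits)                 # [1, 8]

threshold = entropies.mean()                      # illustrative threshold, not a trained value
starts = torch.nonzero(entropies[0] > threshold).squeeze(-1)
starts = torch.unique(torch.cat([torch.tensor([0, 1]), starts]))  # positions 0 and 1 always start patches
patch_lengths = lengths_from_start_ids(starts.unsqueeze(0), seq_len=8)
assert patch_lengths.sum().item() == 8            # patches tile the sequence exactly
print(patch_lengths.tolist())
```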
backup_blt_wip_backup/tokenization_blt.py ADDED
@@ -0,0 +1,273 @@
1
+ # coding=utf-8
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Tokenization classes for BLT."""
17
+
18
+ import os
19
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
20
+
21
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
22
+ from ...utils import logging
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from ...tokenization_utils_base import TextInput
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+ # BLT tokenizer constants
31
+ SEP = " "
32
+ BOS_ID: int = 1
33
+ EOS_ID: int = 2
34
+ PAD_ID: int = -1
35
+ BOE_ID: int = 0
36
+ BPE_ID: int = 3
37
+ OFFSET: int = 4
38
+ BYTE_UNITS: int = 256
39
+
40
+ VOCAB_FILES_NAMES = {} # BLT doesn't require external vocab files
41
+
42
+
43
+ class BLTTokenizer(PreTrainedTokenizer):
44
+ """
45
+ Construct a BLT tokenizer. Based on byte-level tokenization where each byte is treated as a token.
46
+
47
+ This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with an offset.
48
+ It supports special tokens for beginning of sequence (BOS), end of sequence (EOS),
49
+ beginning of example (BOE), and padding (PAD).
50
+
51
+ Args:
52
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
53
+ The beginning of sequence token.
54
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
55
+ The end of sequence token.
56
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
57
+ The padding token.
58
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
59
+ The unknown token. Not used in BLT but kept for compatibility.
60
+ boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<boe>"`):
61
+ The beginning of example token, specific to BLT.
62
+ add_bos_token (`bool`, *optional*, defaults to `True`):
63
+ Whether or not to add a `bos_token` at the start of sequences.
64
+ add_eos_token (`bool`, *optional*, defaults to `True`):
65
+ Whether or not to add an `eos_token` at the end of sequences.
66
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
67
+ Whether or not to clean up spaces after decoding.
68
+ spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
69
+ Whether or not to add spaces between special tokens.
70
+ """
71
+
72
+ vocab_files_names = VOCAB_FILES_NAMES
73
+ model_input_names = ["input_ids", "attention_mask"]
74
+
75
+ def __init__(
76
+ self,
77
+ bos_token="<s>",
78
+ eos_token="</s>",
79
+ pad_token="<pad>",
80
+ unk_token="<unk>",
81
+ boe_token="<boe>",
82
+ add_bos_token=True,
83
+ add_eos_token=True,
84
+ clean_up_tokenization_spaces=False,
85
+ spaces_between_special_tokens=False,
86
+ **kwargs,
87
+ ):
88
+ # Store BLT-specific parameters first
89
+ self.add_bos_token = add_bos_token
90
+ self.add_eos_token = add_eos_token
91
+ self.vocab_size_unit_1 = BYTE_UNITS
92
+ self.offsetting_special_char = OFFSET
93
+
94
+ # BLT token IDs (exactly like original)
95
+ self.boe_id = BOE_ID
96
+ self.bos_id = BOS_ID
97
+ self.eos_id = EOS_ID
98
+ self.pad_id = PAD_ID
99
+ self.bpe_id = BPE_ID
100
+ self.n_words = self.vocab_size_unit_1 + self.offsetting_special_char
101
+
102
+ # Convert string tokens to AddedToken objects
103
+ bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
104
+ eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
105
+ pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
106
+ unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
107
+ self.boe_token = AddedToken(boe_token, normalized=False, special=True) if isinstance(boe_token, str) else boe_token
108
+
109
+ super().__init__(
110
+ bos_token=bos_token,
111
+ eos_token=eos_token,
112
+ pad_token=pad_token,
113
+ unk_token=unk_token,
114
+ add_bos_token=add_bos_token,
115
+ add_eos_token=add_eos_token,
116
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
117
+ spaces_between_special_tokens=spaces_between_special_tokens,
118
+ **kwargs,
119
+ )
120
+
121
+ @property
122
+ def vocab_size(self):
123
+ """Returns vocab size"""
124
+ return self.vocab_size_unit_1 + self.offsetting_special_char
125
+
126
+ def get_vocab(self):
127
+ """Returns vocab as a dict"""
128
+ # Create a mapping for byte values + offset
129
+ vocab = {}
130
+
131
+ # Add special tokens (with defensive checks)
132
+ if hasattr(self, 'bos_token'):
133
+ vocab[str(self.bos_token)] = self.bos_id
134
+ if hasattr(self, 'eos_token'):
135
+ vocab[str(self.eos_token)] = self.eos_id
136
+ if hasattr(self, 'pad_token'):
137
+ vocab[str(self.pad_token)] = self.pad_id
138
+ if hasattr(self, 'boe_token'):
139
+ vocab[str(self.boe_token)] = self.boe_id
140
+
141
+ # Add byte tokens as string representations of byte values
142
+ vocab_size_unit_1 = getattr(self, 'vocab_size_unit_1', BYTE_UNITS)
143
+ offsetting_special_char = getattr(self, 'offsetting_special_char', OFFSET)
144
+ for i in range(vocab_size_unit_1):
145
+ vocab[str(i)] = i + offsetting_special_char
146
+
147
+ # Add any additional tokens if available
148
+ if hasattr(self, 'added_tokens_encoder'):
149
+ vocab.update(self.added_tokens_encoder)
150
+ return vocab
151
+
152
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
153
+ """
154
+ Converts a string to a list of tokens. For BLT, we work directly with byte values.
155
+ Returns a list of strings that represent the byte values.
156
+ """
157
+ # Convert text to UTF-8 bytes, just like the original
158
+ try:
159
+ bytes_data = text.encode("utf-8", errors="ignore")
160
+ except UnicodeEncodeError:
161
+ bytes_data = text.encode("utf-8", errors="ignore")
162
+
163
+ # Return string representations of byte values for the tokenizer framework
164
+ return [str(byte_val) for byte_val in bytes_data]
165
+
166
+ def _convert_token_to_id(self, token: str) -> int:
167
+ """Converts a token (str) to an id using the vocab."""
168
+ # Handle special tokens
169
+ if token == str(self.bos_token):
170
+ return self.bos_id
171
+ elif token == str(self.eos_token):
172
+ return self.eos_id
173
+ elif token == str(self.pad_token):
174
+ return self.pad_id
175
+ elif token == str(self.boe_token):
176
+ return self.boe_id
177
+ else:
178
+ try:
179
+ # Convert byte value string to int and add offset (like original)
180
+ byte_val = int(token)
181
+ if 0 <= byte_val <= 255:
182
+ return byte_val + self.offsetting_special_char
183
+ except ValueError:
184
+ pass
185
+
186
+ # Check if it's in added tokens
187
+ return self.added_tokens_encoder.get(token, self.unk_token_id)
188
+
189
+ def _convert_id_to_token(self, index: int) -> str:
190
+ """Converts an index (integer) to a token (str) using the vocab."""
191
+ # Handle special tokens
192
+ if index == self.bos_id:
193
+ return str(self.bos_token)
194
+ elif index == self.eos_id:
195
+ return str(self.eos_token)
196
+ elif index == self.pad_id:
197
+ return str(self.pad_token)
198
+ elif index == self.boe_id:
199
+ return str(self.boe_token)
200
+ elif index >= self.offsetting_special_char and index < self.vocab_size:
201
+ # Convert back to byte value (like original)
202
+ byte_val = index - self.offsetting_special_char
203
+ return str(byte_val)
204
+ else:
205
+ # Check added tokens
206
+ for token, token_id in self.added_tokens_encoder.items():
207
+ if token_id == index:
208
+ return token
209
+ return str(self.unk_token)
210
+
211
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
212
+ """Converts a sequence of tokens to a single string."""
213
+ byte_values = []
214
+
215
+ for token in tokens:
216
+ # Skip special tokens
217
+ if token in [str(self.bos_token), str(self.eos_token), str(self.pad_token), str(self.boe_token)]:
218
+ continue
219
+
220
+ try:
221
+ # Convert token back to byte value (like original decode method)
222
+ byte_val = int(token)
223
+ if 0 <= byte_val <= 255:
224
+ byte_values.append(byte_val)
225
+ except ValueError:
226
+ continue
227
+
228
+ # Convert byte values back to string (exactly like original)
229
+ try:
230
+ return bytes(byte_values).decode("utf-8", errors="ignore")
231
+ except (UnicodeDecodeError, ValueError):
232
+ return ""
233
+
234
+ def encode(self, text: str, add_bos: bool | None = None, add_eos: bool | None = None):
235
+ """
236
+ Encode text exactly like the original BLT tokenizer.
237
+ """
238
+ if add_bos is None:
239
+ add_bos = self.add_bos_token
240
+ if add_eos is None:
241
+ add_eos = self.add_eos_token
242
+
243
+ # Since bpe_delim=False, we use the simple byte encoding
244
+ tokens = bytes(text, encoding="utf-8", errors="ignore")
245
+
246
+ # Offsetting (exactly like original)
247
+ tokens = [int(unit) + self.offsetting_special_char for unit in tokens]
248
+
249
+ if add_bos:
250
+ tokens.insert(0, self.bos_id)
251
+ if add_eos:
252
+ tokens.append(self.eos_id)
253
+
254
+ return tokens
255
+
256
+ def decode(self, tokens: list[int], cut_at_eos: bool = False):
257
+ """
258
+ Decode tokens exactly like the original BLT tokenizer.
259
+ """
260
+ if cut_at_eos:
261
+ for k, t in enumerate(tokens):
262
+ if t == self.eos_id:
263
+ tokens = tokens[: k + 1]
264
+ break
265
+ return bytes(
266
+ [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0]
267
+ ).decode("utf-8", errors="ignore")
268
+
269
+ def get_vocab_size(self) -> int:
270
+ """Get vocab size like the original tokenizer."""
271
+ return self.vocab_size_unit_1 + self.offsetting_special_char
272
+
273
+ __all__ = ["BLTTokenizer"]
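The core of this tokenizer is the byte-plus-offset scheme: text is encoded as UTF-8, every byte is shifted up by OFFSET = 4 so that ids 0-3 stay reserved for BOE/BOS/EOS/BPE, and decoding drops anything that falls below the offset. A minimal round-trip sketch of that scheme, independent of the transformers machinery (constants copied from the file above; function names are illustrative):

```python
OFFSET, BOS_ID, EOS_ID = 4, 1, 2


def blt_encode(text: str, add_bos: bool = True, add_eos: bool = True) -> list[int]:
    # Each UTF-8 byte becomes byte_value + OFFSET.
    ids = [b + OFFSET for b in text.encode("utf-8", errors="ignore")]
    if add_bos:
        ids.insert(0, BOS_ID)
    if add_eos:
        ids.append(EOS_ID)
    return ids


def blt_decode(ids: list[int]) -> str:
    # Special ids (0-3) map below the offset and are dropped by the >= 0 filter.
    return bytes(i - OFFSET for i in ids if i - OFFSET >= 0).decode("utf-8", errors="ignore")


ids = blt_encode("héllo")
print(ids)              # [1, 108, 199, 173, 112, 112, 115, 2]
print(blt_decode(ids))  # héllo
```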
backup_blt_wip_backup/tokenizers/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (185 Bytes). View file
 
backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (149 Bytes). View file
 
backup_blt_wip_backup/tokenizers/__pycache__/abstract_tokenizer.cpython-312.pyc ADDED
Binary file (1.51 kB). View file
 
backup_blt_wip_backup/tokenizers/__pycache__/blt_tokenizer.cpython-312.pyc ADDED
Binary file (6.69 kB). View file
 
backup_blt_wip_backup/tokenizers/__pycache__/build_tokenizer.cpython-312.pyc ADDED
Binary file (3.08 kB). View file
 
backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-312.pyc ADDED
Binary file (491 Bytes). View file
 
backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-39.pyc ADDED
Binary file (411 Bytes). View file
 
backup_blt_wip_backup/tokenizers/__pycache__/sentence_piece_tokenizer.cpython-312.pyc ADDED
Binary file (3.78 kB). View file
 
backup_blt_wip_backup/tokenizers/__pycache__/tiktoken_tokenizer.cpython-312.pyc ADDED
Binary file (5.19 kB). View file
 
backup_blt_wip_backup/tokenizers/abstract_tokenizer.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import abc
3
+
4
+
5
+ class Tokenizer(abc.ABC):
6
+ @abc.abstractmethod
7
+ def encode(self, text: str, add_bos: bool, add_eos: bool):
8
+ pass
9
+
10
+ @abc.abstractmethod
11
+ def decode(self, tokens: list[int]):
12
+ pass
13
+
14
+ @abc.abstractmethod
15
+ def get_token_offsets(self, text: str, tokens: list[int] | None = None) -> tuple[list[str], list[int]]:
16
+ """Return the offsets of the tokens in the original text. Only used for evaluation."""
17
+ pass
18
+
19
+ @abc.abstractmethod
20
+ def get_vocab_size(self) -> int:
21
+ pass
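The ABC above only fixes the interface: encode/decode, token offsets for evaluation, and vocab size. A tiny hypothetical subclass makes the contract concrete; it is purely illustrative (the import path is an assumption, and the offset logic only holds for ASCII text).

```python
# Assumes the Tokenizer ABC from abstract_tokenizer.py above is importable.
from abstract_tokenizer import Tokenizer


class RawByteTokenizer(Tokenizer):
    """Illustrative only: raw UTF-8 bytes as token ids, no BOS/EOS handling."""

    def encode(self, text: str, add_bos: bool, add_eos: bool):
        return list(text.encode("utf-8"))

    def decode(self, tokens: list[int]):
        return bytes(tokens).decode("utf-8", errors="ignore")

    def get_token_offsets(self, text: str, tokens: list[int] | None = None):
        tokens = tokens if tokens is not None else self.encode(text, False, False)
        # One byte per character only holds for ASCII; good enough for a sketch.
        return [chr(t) for t in tokens], list(range(len(tokens)))

    def get_vocab_size(self) -> int:
        return 256
```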
backup_blt_wip_backup/tokenizers/blt_tokenizer.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import re
3
+
4
+ from .abstract_tokenizer import Tokenizer
5
+ from .sentence_piece_tokenizer import SentencePieceTokenizer
6
+
7
+
8
+ SEP = " "
9
+ BOS_ID: int = 1
10
+ EOS_ID: int = 2
11
+ PAD_ID: int = -1
12
+ BOE_ID: int = 0
13
+ BPE_ID: int = 3
14
+ OFFSET: int = 4
15
+
16
+ BYTE_UNITS: int = 256
17
+
18
+
19
+ def convert_to_bytes(s):
20
+ # check if the output is a bytes like object of the format <0x00>
21
+ if re.match(r"<0x[0-9a-fA-F]+>", s):
22
+ return bytes.fromhex(s[3:-1])
23
+ else:
24
+ return bytes(s, "utf-8", errors="ignore")
25
+
26
+
27
+ def text2bytes_bpe_delims(
28
+ text: str,
29
+ *,
30
+ bpe_tokenizer,
31
+ bpe_id: int,
32
+ offsetting_special_char: int,
33
+ add_bos: bool,
34
+ add_eos: bool,
35
+ ):
36
+ cur_bpe = bpe_tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos)
37
+ # merge the leading space tokens
38
+ leading_space_tokens = []
39
+ other_bpe_tokens = []
40
+ leading = True
41
+ for token in cur_bpe:
42
+ bpe_str = bpe_tokenizer.sp_model.id_to_piece(token)
43
+ if leading and all(c == "▁" for c in bpe_str):
44
+ leading_space_tokens.append(bpe_str)
45
+ else:
46
+ leading = False
47
+ other_bpe_tokens.append(bpe_str)
48
+ cur_bpe_strs = ["".join(leading_space_tokens)] + other_bpe_tokens
49
+
50
+ # Remove the '▁' characters
51
+ bpe_strs = []
52
+ for i, bpe_str in enumerate(cur_bpe_strs):
53
+ if len(bpe_strs) <= 1 and all([c == " " for s in bpe_strs for c in s]) and not all(c == "▁" for c in bpe_str):
54
+ # Remove leading space for first non space token.
55
+ bpe_str = bpe_str.replace("▁", "")
56
+ elif i == 0 and all(c == "▁" for c in bpe_str):
57
+ bpe_str = " " * (len(text) - len(text.lstrip(" ")))
58
+ else:
59
+ bpe_str = bpe_str.replace("▁", " ")
60
+ if len(bpe_str) > 0:
61
+ bpe_strs.append(bpe_str)
62
+ ex_seq = []
63
+ # Convert bpe tokens to bytes
64
+ for s in bpe_strs:
65
+ byte_chunk = convert_to_bytes(s)
66
+ proc_chunk = [int(unit) for unit in byte_chunk]
67
+ ex_seq.extend([bpe_id - offsetting_special_char] + proc_chunk)
68
+
69
+ return ex_seq
70
+
71
+
72
+ class BltTokenizer(Tokenizer):
73
+ def __init__(
74
+ self,
75
+ *,
76
+ vocab_size_unit_1: int = BYTE_UNITS,
77
+ bpe_delim: bool = False,
78
+ bpe_tokenizer_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
79
+ add_bos: bool = True,
80
+ add_eos: bool = True,
81
+ ):
82
+ self.add_bos = add_bos
83
+ self.add_eos = add_eos
84
+ self.vocab_size_unit_1 = vocab_size_unit_1
85
+ self.boe_id = BOE_ID
86
+ self.bos_id = BOS_ID
87
+ self.eos_id = EOS_ID
88
+ self.pad_id = PAD_ID
89
+ self.bpe_id = BPE_ID
90
+ self.bpe_tokenizer_path = bpe_tokenizer_path
91
+ if bpe_delim:
92
+ self.bpe_tokenizer = SentencePieceTokenizer(model_path=self.bpe_tokenizer_path)
93
+ else:
94
+ self.bpe_tokenizer = None
95
+ self.bpe_delim = bpe_delim
96
+ self.offsetting_special_char = OFFSET
97
+ self.vocab_size_unit_1 = vocab_size_unit_1
98
+ self.n_words = vocab_size_unit_1 + self.offsetting_special_char
99
+
100
+ def get_vocab_size(self) -> int:
101
+ return self.n_words
102
+
103
+ def encode(self, text: str, add_bos: bool | None = None, add_eos: bool | None = None):
104
+ if add_bos is None:
105
+ add_bos = self.add_bos
106
+ if add_eos is None:
107
+ add_eos = self.add_eos
108
+
109
+ if self.bpe_delim:
110
+ tokens = text2bytes_bpe_delims(
111
+ text,
112
+ bpe_tokenizer=self.bpe_tokenizer,
113
+ bpe_id=self.bpe_id,
114
+ offsetting_special_char=self.offsetting_special_char,
115
+ add_bos=False,
116
+ add_eos=False,
117
+ )
118
+ else:
119
+ tokens = bytes(text, encoding="utf-8", errors="ignore")
120
+
121
+ # Offsetting
122
+ tokens = [int(unit) + self.offsetting_special_char for unit in tokens]
123
+
124
+ if add_bos:
125
+ tokens.insert(0, self.bos_id)
126
+ if add_eos:
127
+ tokens.append(self.eos_id)
128
+
129
+ return tokens
130
+
131
+ def decode(self, tokens: list[int], cut_at_eos: bool = False):
132
+ if cut_at_eos:
133
+ for k, t in enumerate(tokens):
134
+ if t == self.eos_id:
135
+ tokens = tokens[: k + 1]
136
+ break
137
+ return bytes(
138
+ [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0]
139
+ ).decode("utf-8", errors="ignore")
140
+
141
+ def get_token_offsets(self, text: str, tokens: list[int] | None = None):
142
+ # TODO: Figure out what this does
143
+ raise NotImplementedError()
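With bpe_delim=False (the default), BltTokenizer never touches the SentencePiece model path above and reduces to the same byte-plus-offset scheme as the HF tokenizer earlier in this commit. A usage sketch, assuming the module is importable under the path shown (which is a guess):

```python
# Usage sketch; the import path is hypothetical and depends on how the package is laid out.
from tokenizers.blt_tokenizer import BltTokenizer

tok = BltTokenizer(bpe_delim=False)
ids = tok.encode("byte-level patching")          # BOS + 19 shifted byte ids + EOS
assert tok.decode(ids) == "byte-level patching"  # special ids fall below the offset and are dropped
print(len(ids), tok.get_vocab_size())            # 21 260
```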