diff --git a/backup_blt_modellike/__init__.py b/backup_blt_modellike/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..703b81ecdd09dda47a97c641f7e440bcb5e81119
--- /dev/null
+++ b/backup_blt_modellike/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_blt import *
+ from .modeling_blt import *
+ from .tokenization_blt import *
+else:
+ import sys
+
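+ # Register a lazy proxy module so the configuration/modeling/tokenization submodules are only
+ # imported when one of their attributes is first accessed (standard transformers lazy-import pattern).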
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/backup_blt_modellike/configuration_blt.py b/backup_blt_modellike/configuration_blt.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ee59362d955060fbe5bff81b635da55dc012981
--- /dev/null
+++ b/backup_blt_modellike/configuration_blt.py
@@ -0,0 +1,225 @@
+# coding=utf-8
+# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BLT model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class BLTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`BLTModel`]. It is used to instantiate a BLT
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the BLT-7B.
+ e.g. [meta-blt/BLT-2-7b-hf](https://huggingface.co/meta-blt/BLT-2-7b-hf)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the BLT model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`BLTModel`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by mean-pooling all the original heads within that group. For more details, check out [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
+ understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
+ results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie the input and output word embeddings.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+ and expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to the value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to the low frequency components of the RoPE.
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to the high frequency components of the RoPE.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+ head_dim (`int`, *optional*):
+ The attention head dimension. If None, it will default to hidden_size // num_attention_heads
+
+ ```python
+ >>> from transformers import BLTModel, BLTConfig
+
+ >>> # Initializing a BLT blt-7b style configuration
+ >>> configuration = BLTConfig()
+
+ >>> # Initializing a model from the blt-7b style configuration
+ >>> model = BLTModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
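+
+ >>> # Illustrative sketch (not part of the original example): enabling YaRN RoPE scaling via the
+ >>> # standard `rope_scaling` fields described above.
+ >>> scaled_config = BLTConfig(rope_scaling={"rope_type": "yarn", "factor": 2.0})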
+ ```"""
+
+ model_type = "blt"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `BLTModel`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ mlp_bias=False,
+ head_dim=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.mlp_bias = mlp_bias
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+__all__ = ["BLTConfig"]
diff --git a/backup_blt_modellike/convert_blt_weights_to_hf.py b/backup_blt_modellike/convert_blt_weights_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..26c05477a169d41213133c3dffcc00ac32cd2401
--- /dev/null
+++ b/backup_blt_modellike/convert_blt_weights_to_hf.py
@@ -0,0 +1,397 @@
+import argparse
+import json
+import logging
+import os
+from typing import Any, Dict, Optional
+
+import torch
+from huggingface_hub import hf_hub_download, upload_folder
+from safetensors.torch import load_file, save_file
+
+from transformers.models.blt_wip.configuration_blt import BLTConfig
+from transformers.models.blt_wip.modeling_blt import BLTModel
+from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM
+from transformers.utils import logging as transformers_logging
+
+
+logger = transformers_logging.get_logger(__name__)
+transformers_logging.set_verbosity_info()
+
+
+def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]:
+ logger.info("Merging configurations")
+
+ with open(config_path, "r") as f:
+ main_config = json.load(f)
+
+ with open(entropy_params_path, "r") as f:
+ entropy_data = json.load(f)
+
+ entropy_model_params = entropy_data.get("entropy_model", {})
+ patcher_args = entropy_data.get("data", {}).get("patcher_args", {})
+
+ unified_config = main_config.copy()["args"]
+
+ for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]:
+ if key in unified_config and not isinstance(unified_config[key], int):
+ unified_config[key] = int(unified_config[key])
+
+ patch_size = patcher_args.get("patch_size", 8)
+ if isinstance(patch_size, float):
+ patch_size = int(patch_size)
+
+ # Create patcher config
+ patcher_hidden_size = int(entropy_model_params.get("dim", 512))
+ patcher_multiple_of = int(entropy_model_params.get("multiple_of", 256))
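+ # SwiGLU-style FFN sizing: roughly 8/3 of the hidden size, rounded up to the nearest multiple of
+ # `multiple_of` (e.g. dim=512 with multiple_of=256: int(8 * 512 / 3) = 1365 -> 1536).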
+ patcher_intermediate_size = patcher_multiple_of * ((int(8 * patcher_hidden_size / 3) + patcher_multiple_of - 1) // patcher_multiple_of)
+
+ patcher_config = {
+ "vocab_size": int(entropy_model_params.get("vocab_size", 256)),
+ "hidden_size": patcher_hidden_size,
+ "num_hidden_layers": int(entropy_model_params.get("n_layers", 8)),
+ "num_attention_heads": int(entropy_model_params.get("n_heads", 8)),
+ "num_key_value_heads": int(entropy_model_params.get("n_kv_heads"))
+ if entropy_model_params.get("n_kv_heads") is not None
+ else None,
+ "max_position_embeddings": int(entropy_model_params.get("max_seqlen", 1024)),
+ "norm_eps": entropy_model_params.get("norm_eps", 1e-5),
+ "dropout": entropy_model_params.get("dropout", 0.0),
+ "rope_theta": entropy_model_params.get("rope_theta", 10000.0),
+ "attn_impl": entropy_model_params.get("attn_impl", "sdpa"),
+ "attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"),
+ "intermediate_size": patcher_intermediate_size,
+ }
+
+ # Create encoder config
+ encoder_hidden_size = unified_config.get("dim_local_encoder", 1024)
+ encoder_multiple_of = unified_config.get("multiple_of", 256)
+ encoder_intermediate_size = encoder_multiple_of * ((int(8 * encoder_hidden_size / 3) + encoder_multiple_of - 1) // encoder_multiple_of)
+
+ encoder_config = {
+ "vocab_size": unified_config.get("vocab_size", 256),
+ "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_encoder", False),
+ "cross_attn_k": unified_config.get("cross_attn_k", 2),
+ "hidden_size_global": unified_config.get("hidden_size_global", 2048),
+ "pm_size": unified_config.get("pm_size", 0),
+ "hidden_size": encoder_hidden_size,
+ "num_attention_heads": unified_config.get("n_heads_local_encoder", 16),
+ "num_key_value_heads": unified_config.get("n_kv_heads"),
+ "num_hidden_layers": unified_config.get("n_layers_local_encoder", 1),
+ "norm_eps": unified_config.get("norm_eps", 1e-5),
+ "dropout": unified_config.get("dropout", 0.0),
+ "max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024),
+ "rope_theta": unified_config.get("rope_theta", 10000.0),
+ "rope_scaling": {"rope_type": "default"},
+ "hidden_act": unified_config.get("hidden_act", "silu"),
+ "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
+ "intermediate_size": encoder_intermediate_size,
+ }
+
+ # Create decoder config
+ decoder_hidden_size = unified_config.get("dim_local_decoder", 1024)
+ decoder_multiple_of = unified_config.get("multiple_of", 256)
+ decoder_intermediate_size = decoder_multiple_of * ((int(8 * decoder_hidden_size / 3) + decoder_multiple_of - 1) // decoder_multiple_of)
+
+ decoder_config = {
+ "vocab_size": unified_config.get("vocab_size", 256),
+ "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_decoder", False),
+ "cross_attn_k": unified_config.get("cross_attn_k", 2),
+ "hidden_size_global": unified_config.get("hidden_size_global", 2048),
+ "hidden_size": decoder_hidden_size,
+ "num_attention_heads": unified_config.get("n_heads_local_decoder", 16),
+ "num_key_value_heads": unified_config.get("n_kv_heads"),
+ "num_hidden_layers": unified_config.get("n_layers_local_decoder", 9),
+ "norm_eps": unified_config.get("norm_eps", 1e-5),
+ "dropout": unified_config.get("dropout", 0.0),
+ "max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024),
+ "rope_theta": unified_config.get("rope_theta", 10000.0),
+ "rope_scaling": {"rope_type": "default"},
+ "hidden_act": unified_config.get("hidden_act", "silu"),
+ "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
+ "intermediate_size": decoder_intermediate_size,
+ }
+
+ # Create global transformer config
+ global_hidden_size = unified_config.get("dim_global", 2048)
+ global_multiple_of = unified_config.get("multiple_of", 256)
+ global_intermediate_size = global_multiple_of * ((int(8 * global_hidden_size / 3) + global_multiple_of - 1) // global_multiple_of)
+
+ global_config = {
+ "hidden_size": global_hidden_size,
+ "num_attention_heads": unified_config.get("n_heads_global", 16),
+ "num_key_value_heads": unified_config.get("n_kv_heads_global"),
+ "num_hidden_layers": unified_config.get("n_layers_global", 25),
+ "norm_eps": unified_config.get("norm_eps", 1e-5),
+ "dropout": unified_config.get("dropout", 0.0),
+ "max_position_embeddings": unified_config.get("max_seqlen", 1024),
+ "rope_theta": unified_config.get("rope_theta", 10000.0),
+ "rope_scaling": {"rope_type": "default"},
+ "hidden_act": unified_config.get("hidden_act", "silu"),
+ "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"),
+ "intermediate_size": global_intermediate_size,
+ }
+
+ # Create main config with sub-configs
+ main_config_dict = {
+ "model_type": "blt",
+ "vocab_size": unified_config.get("vocab_size", 256),
+ "max_position_embeddings": unified_config.get("max_seqlen", 1024),
+ "patch_in_forward": True,
+ "realtime_patching": True,
+ "patching_mode": "entropy",
+ "patch_size": patch_size,
+ "patching_threshold": patcher_args.get("threshold", 0.5),
+ "patching_threshold_add": patcher_args.get("threshold_add", 0.0),
+ "max_patch_length": patcher_args.get("max_patch_length"),
+ "patching_batch_size": patcher_args.get("patching_batch_size", 1),
+ "patching_device": patcher_args.get("patching_device", "cuda"),
+ "monotonicity": patcher_args.get("monotonicity", False),
+ "cross_attn_k": unified_config.get("cross_attn_k", 2),
+ "encoder_hash_byte_group_size": unified_config.get("encoder_hash_byte_group_size"),
+ "encoder_hash_byte_group_vocab": unified_config.get("encoder_hash_byte_group_vocab", 30000),
+ "encoder_hash_byte_group_nb_functions": unified_config.get("encoder_hash_byte_group_nb_functions", 3),
+ "pm_size": unified_config.get("pm_size", 0),
+ "patcher_config": patcher_config,
+ "encoder_config": encoder_config,
+ "decoder_config": decoder_config,
+ "global_config": global_config,
+ }
+
+ main_config_dict["tie_word_embeddings"] = False
+
+ logger.info(f"Merged configuration with {len(main_config_dict)} parameters")
+ return main_config_dict
+
+
+def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ component_mappings = {
+ ".attention.": ".self_attn.",
+ ".feed_forward.": ".mlp.",
+ ".attention_norm.": ".input_layernorm.",
+ ".ffn_norm.": ".post_attention_layernorm.",
+ ".tok_embeddings.": ".embed_tokens.",
+ ".cross_attn_norm_q.": ".q_norm.",
+ ".cross_attn_norm_kv.": ".k_norm.",
+ ".w1.": ".gate_proj.",
+ ".w2.": ".down_proj.",
+ ".w3.": ".up_proj.",
+ ".wq.": ".q_proj.",
+ ".wk.": ".k_proj.",
+ ".wv.": ".v_proj.",
+ ".wo.": ".o_proj.",
+ ".output.": ".lm_head.",
+ }
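+
+ # Illustrative example with a hypothetical key: "local_decoder.layers.0.attention.wq.weight"
+ # is rewritten to "local_decoder.layers.0.self_attn.q_proj.weight" by the substitutions below.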
+
+ new_state_dict = {}
+
+ for old_key, tensor in state_dict.items():
+ new_key = old_key
+
+ for old_pattern, new_pattern in component_mappings.items():
+ if old_pattern in new_key:
+ new_key = new_key.replace(old_pattern, new_pattern)
+
+ new_state_dict[new_key] = tensor
+
+ return new_state_dict
+
+
+def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]:
+ main_weights = load_file(weights_path)
+
+ entropy_weights = torch.load(entropy_weights_path, map_location="cpu", weights_only=True)
+
+ if "model" in entropy_weights:
+ entropy_weights = entropy_weights["model"]
+ elif "state_dict" in entropy_weights:
+ entropy_weights = entropy_weights["state_dict"]
+
+ unified_weights = main_weights.copy()
+
+ for key, tensor in entropy_weights.items():
+ patcher_key = f"patcher.{key}"
+ unified_weights[patcher_key] = tensor
+
+ unified_weights = apply_weight_mapping(unified_weights)
+
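+ # Keep a single top-level lm_head and nest every other tensor under the "model." prefix expected
+ # by the Hugging Face wrapper (see the loop below).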
+ decoder_lm_head_key = "local_decoder.lm_head.weight"
+ top_lm_head_key = "lm_head.weight"
+ unified_weights[top_lm_head_key] = unified_weights[decoder_lm_head_key]
+ del unified_weights[decoder_lm_head_key]
+
+ prefixed_weights = {}
+ for key, tensor in unified_weights.items():
+ if key == top_lm_head_key:
+ prefixed_weights[key] = tensor
+ elif not key.startswith("model."):
+ prefixed_weights[f"model.{key}"] = tensor
+ else:
+ prefixed_weights[key] = tensor
+
+ unified_weights = prefixed_weights
+
+ return unified_weights
+
+
+def create_tokenizer_config(output_dir: str, config: Dict[str, Any]):
+ tokenizer_config = {
+ "tokenizer_class": "BltTokenizer",
+ "vocab_size": config.get("vocab_size", 256),
+ "model_max_length": config.get("max_seqlen", 1024),
+ "add_bos_token": True,
+ "add_eos_token": True,
+ "bos_token": "",
+ "eos_token": "",
+ "pad_token": "",
+ "unk_token": "",
+ }
+
+ tokenizer_path = os.path.join(output_dir, "tokenizer_config.json")
+ with open(tokenizer_path, "w") as f:
+ json.dump(tokenizer_config, f, indent=2)
+
+
+def push_to_hub(
+ local_dir: str,
+ repo_id: str,
+ commit_message: str = "Upload converted BLT model",
+ private: bool = False,
+ token: Optional[str] = None,
+) -> None:
+ try:
+ upload_folder(
+ folder_path=local_dir,
+ repo_id=repo_id,
+ commit_message=commit_message,
+ repo_type="model",
+ token=token,
+ )
+ logger.info(f"Successfully pushed model to {repo_id}")
+
+ except Exception as e:
+ logger.error(f"Failed to push model to Hub: {e}")
+ raise
+
+
+def convert_hf_blt_to_unified(
+ model_id: str,
+ output_dir: str,
+ config_name: str = "config.json",
+ weights_name: str = "model.bin",
+ cache_dir: Optional[str] = None,
+ push_to_hub_repo: Optional[str] = None,
+ hub_private: bool = False,
+ hub_token: Optional[str] = None,
+) -> None:
+ # Download model files
+ config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir)
+ weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir)
+ entropy_params_path = hf_hub_download(repo_id=model_id, filename="entropy_model/params.json", cache_dir=cache_dir)
+ entropy_weights_path = hf_hub_download(
+ repo_id=model_id, filename="entropy_model/consolidated.pth", cache_dir=cache_dir
+ )
+
+ unified_config = merge_configurations(config_path, entropy_params_path)
+ unified_weights = merge_weights(weights_path, entropy_weights_path)
+
+ os.makedirs(output_dir, exist_ok=True)
+
+ config_path = os.path.join(output_dir, config_name)
+ with open(config_path, "w") as f:
+ json.dump(unified_config, f, indent=2)
+
+ if weights_name.endswith(".bin"):
+ weights_name = weights_name.replace(".bin", ".safetensors")
+
+ weights_path = os.path.join(output_dir, weights_name)
+ save_file(unified_weights, weights_path)
+
+ create_tokenizer_config(output_dir, unified_config)
+
+ logger.info(f"Conversion completed, model saved to: {output_dir}")
+
+ if push_to_hub_repo:
+ push_to_hub(
+ local_dir=output_dir,
+ repo_id=push_to_hub_repo,
+ commit_message="Upload BLT model converted",
+ private=hub_private,
+ token=hub_token,
+ )
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Convert BLT models from HuggingFace Hub format to unified format",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--model_id",
+ type=str,
+ default="facebook/blt-1b",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="./blt_converted",
+ )
+ parser.add_argument(
+ "--config_name",
+ type=str,
+ default="config.json",
+ )
+ parser.add_argument(
+ "--weights_name",
+ type=str,
+ default="model.bin",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ )
+ parser.add_argument(
+ "--debug",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ type=str,
+ default=None,
+ )
+ parser.add_argument(
+ "--hub_private",
+ action="store_true",
+ default=False,
+ )
+ parser.add_argument(
+ "--hub_token",
+ type=str,
+ default=None,
+ )
+
+ args = parser.parse_args()
+
+ if args.debug:
+ transformers_logging.set_verbosity_debug()
+ logging.basicConfig(level=logging.DEBUG)
+
+ try:
+ convert_hf_blt_to_unified(
+ model_id=args.model_id,
+ output_dir=args.output_dir,
+ config_name=args.config_name,
+ weights_name=args.weights_name,
+ cache_dir=args.cache_dir,
+ push_to_hub_repo=args.push_to_hub,
+ hub_private=args.hub_private,
+ hub_token=args.hub_token,
+ )
+ except Exception as e:
+ logger.error(f"Conversion failed: {e}")
+ raise
+
+
+if __name__ == "__main__":
+ main()
diff --git a/backup_blt_modellike/modeling_blt.py b/backup_blt_modellike/modeling_blt.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdcaa9062d336430668cc040ae4ca080b5c69977
--- /dev/null
+++ b/backup_blt_modellike/modeling_blt.py
@@ -0,0 +1,971 @@
+# coding=utf-8
+# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
+from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from .configuration_blt import BLTConfig
+
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+from ...integrations import use_kernel_forward_from_hub
+
+
+logger = logging.get_logger(__name__)
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->BLT
+class BLTRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ BLTRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
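+ # RMSNorm: scale by the reciprocal root-mean-square of the features (computed in float32 for
+ # numerical stability), then apply the learned weight; unlike LayerNorm, no mean is subtracted.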
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+ALL_LAYERNORM_LAYERS.append(BLTRMSNorm)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->BLT
+class BLTRotaryEmbedding(nn.Module):
+ def __init__(self, config: BLTConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
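+ # The outer product of inverse frequencies and positions gives the rotation angles; cos/sin are
+ # duplicated across both halves of the head dimension and scaled by attention_scaling.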
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->BLT
+class BLTMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
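+ # Reference (eager) attention: expand grouped KV heads, compute scaled dot-product scores, add the
+ # additive causal mask, softmax in float32, apply dropout, then weight the values.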
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->BLT
+class BLTAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: BLTConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+
+ if self.config._attn_implementation != "eager":
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+ logger.warning_once(
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ else:
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->BLT
+class BLTDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: BLTConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = BLTAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = BLTMLP(config)
+ self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->BLT
+class BLTPreTrainedModel(PreTrainedModel):
+ config_class = BLTConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["BLTDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, BLTRMSNorm):
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->BLT
+class BLTModel(BLTPreTrainedModel):
+ def __init__(self, config: BLTConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [BLTDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = BLTRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+ ) -> BaseModelOutputWithPast:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
+ if not isinstance(past_key_values, (type(None), Cache)):
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **flash_attn_kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, "BlockMask"],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ return attention_mask
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
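+ # Build a fully masked (min_dtype) square, keep only strictly-future positions masked via triu,
+ # then unmask key positions at or before each query's cache position.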
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
+@auto_docstring
+# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->BLT,llama->blt
+class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = BLTModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[KwargsForCausalLM],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, BLTForCausalLM
+
+ >>> model = BLTForCausalLM.from_pretrained("meta-blt/BLT-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-blt/BLT-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The BLT Model transformer with a sequence classification head on top (linear layer).
+
+ [`BLTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """
+)
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->BLT
+class BLTForSequenceClassification(BLTPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = BLTModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> SequenceClassifierOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ transformer_outputs: BaseModelOutputWithPast = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ hidden_states = transformer_outputs.last_hidden_state
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
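+            # Multiplying the index range by the non-pad mask zeroes out pad positions, so argmax picks the rightmost non-pad index per row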
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->BLT
+class BLTForQuestionAnswering(BLTPreTrainedModel):
+ base_model_prefix = "transformer"
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = BLTModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.transformer.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.transformer.embed_tokens = value
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ **kwargs,
+ ) -> QuestionAnsweringModelOutput:
+ outputs: BaseModelOutputWithPast = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
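+        # qa_outputs maps each token to two scores; splitting the last dim and squeezing gives per-token start and end logits of shape (batch_size, seq_len)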
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ loss = None
+ if start_positions is not None and end_positions is not None:
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
+
+ return QuestionAnsweringModelOutput(
+ loss=loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->BLT
+class BLTForTokenClassification(BLTPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = BLTModel(config)
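+        # Prefer an explicit classifier_dropout, fall back to hidden_dropout, and default to 0.1 otherwise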
+ if getattr(config, "classifier_dropout", None) is not None:
+ classifier_dropout = config.classifier_dropout
+ elif getattr(config, "hidden_dropout", None) is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> TokenClassifierOutput:
+ r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ sequence_output = outputs.last_hidden_state
+ sequence_output = self.dropout(sequence_output)
+ logits = self.score(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.config)
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "BLTForCausalLM",
+ "BLTModel",
+ "BLTPreTrainedModel",
+ "BLTForSequenceClassification",
+ "BLTForQuestionAnswering",
+ "BLTForTokenClassification",
+]
diff --git a/backup_blt_modellike/tokenization_blt.py b/backup_blt_modellike/tokenization_blt.py
new file mode 100644
index 0000000000000000000000000000000000000000..145f257f0eeca91ca4ddc89470300d42f3372025
--- /dev/null
+++ b/backup_blt_modellike/tokenization_blt.py
@@ -0,0 +1,412 @@
+# coding=utf-8
+# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tokenization classes for BLT."""
+
+import os
+from shutil import copyfile
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+
+from ...convert_slow_tokenizer import import_protobuf
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+from ...utils.import_utils import requires
+
+
+if TYPE_CHECKING:
+ from ...tokenization_utils_base import TextInput
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+SPIECE_UNDERLINE = "▁"
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
+ that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
+correct. If you don't know the answer to a question, please don't share false information.""" # fmt: skip
+
+
+@requires(backends=("sentencepiece",))
+class BLTTokenizer(PreTrainedTokenizer):
+ """
+ Construct a BLT tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
+ no padding token in the original model.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+        pad_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token used to make arrays of tokens the same size for batching purposes. Will then be ignored by
+            attention mechanisms or loss computation.
+ sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+ to set:
+
+ - `enable_sampling`: Enable subword regularization.
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+ - `nbest_size = {0,1}`: No sampling is performed.
+ - `nbest_size > 1`: samples from the nbest_size results.
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+ using forward-filtering-and-backward-sampling algorithm.
+
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+ BPE-dropout.
+
+ add_bos_token (`bool`, *optional*, defaults to `True`):
+            Whether or not to add a `bos_token` at the start of sequences.
+ add_eos_token (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an `eos_token` at the end of sequences.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+            Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
+            extra spaces.
+ use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+ Whether or not the default system prompt for BLT should be used.
+ spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to add spaces between special tokens.
+ legacy (`bool`, *optional*):
+ Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
+ and #25224 which includes fixes to properly handle tokens that appear after special tokens.
+ Make sure to also set `from_slow` to `True`.
+ A simple example:
+
+ - `legacy=True`:
+ ```python
+ >>> from transformers import BLTTokenizerFast
+
+ >>> tokenizer = BLTTokenizerFast.from_pretrained("huggyblt/blt-7b", legacy=True, from_slow=True)
+            >>> tokenizer.encode("Hello <s>.") # 869 is '▁.'
+ [1, 15043, 29871, 1, 869]
+ ```
+ - `legacy=False`:
+ ```python
+ >>> from transformers import BLTTokenizerFast
+
+ >>> tokenizer = BLTTokenizerFast.from_pretrained("huggyblt/blt-7b", legacy=False, from_slow=True)
+            >>> tokenizer.encode("Hello <s>.") # 29889 is '.'
+ [1, 15043, 29871, 1, 29889]
+ ```
+            Check out the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
+ add_prefix_space (`bool`, *optional*, defaults to `True`):
+            Whether or not to add an initial space to the input. This allows the leading word to be treated just like
+            any other word. Again, this should be set with `from_slow=True` to make sure it's taken into account.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+        unk_token="<unk>",
+        bos_token="<s>",
+        eos_token="</s>",
+ pad_token=None,
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
+ add_bos_token=True,
+ add_eos_token=False,
+ clean_up_tokenization_spaces=False,
+ use_default_system_prompt=False,
+ spaces_between_special_tokens=False,
+ legacy=None,
+ add_prefix_space=True,
+ **kwargs,
+ ):
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+ bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+ eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+ unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+ pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
+
+ if legacy is None:
+ logger.warning_once(
+ f"You are using the default legacy behaviour of the {self.__class__}. This is"
+ " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
+ " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
+ " means, and thoroughly read the reason why this was added as explained in"
+ " https://github.com/huggingface/transformers/pull/24565 - if you loaded a blt tokenizer from a GGUF file"
+ " you can ignore this message"
+ )
+ legacy = True
+
+ self.legacy = legacy
+ self.vocab_file = vocab_file
+ self.add_bos_token = add_bos_token
+ self.add_eos_token = add_eos_token
+ self.use_default_system_prompt = use_default_system_prompt
+ self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
+ self.add_prefix_space = add_prefix_space
+
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token,
+ sp_model_kwargs=self.sp_model_kwargs,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ use_default_system_prompt=use_default_system_prompt,
+ spaces_between_special_tokens=spaces_between_special_tokens,
+ legacy=legacy,
+ add_prefix_space=add_prefix_space,
+ **kwargs,
+ )
+
+ @property
+ def unk_token_length(self):
+ return len(self.sp_model.encode(str(self.unk_token)))
+
+ def get_spm_processor(self, from_slow=False):
+ tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ if self.legacy or from_slow: # no dependency on protobuf
+ tokenizer.Load(self.vocab_file)
+ return tokenizer
+
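+        # Non-legacy path: parse the serialized sentencepiece model with protobuf so `add_dummy_prefix` can be disabled before loading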
+ with open(self.vocab_file, "rb") as f:
+ sp_model = f.read()
+ model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
+ model = model_pb2.ModelProto.FromString(sp_model)
+ normalizer_spec = model_pb2.NormalizerSpec()
+ normalizer_spec.add_dummy_prefix = False
+ model.normalizer_spec.MergeFrom(normalizer_spec)
+ sp_model = model.SerializeToString()
+ tokenizer.LoadFromSerializedProto(sp_model)
+ return tokenizer
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sp_model"] = None
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__.update(d)
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+ @property
+ def vocab_size(self):
+ """Returns vocab size"""
+ return self.sp_model.get_piece_size()
+
+ def get_vocab(self):
+ """Returns vocab as a dict"""
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
+ """
+ Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
+ first token is special.
+ """
+ if self.legacy or len(text) == 0:
+ return super().tokenize(text, **kwargs)
+
+ text = text.replace(SPIECE_UNDERLINE, " ")
+ if self.add_prefix_space:
+ text = SPIECE_UNDERLINE + text
+
+ tokens = super().tokenize(text, **kwargs)
+
+ if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
+ tokens = tokens[1:]
+ return tokens
+
+ def _tokenize(self, text, **kwargs):
+ """
+ Returns a tokenized string.
+
+ We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
+ SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
+ `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
+        `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
+ `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`.
+ """
+ if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
+ return self.sp_model.encode(text, out_type=str)
+
+ # 1. Encode string + prefix ex: " Hey"
+ tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
+ # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
+ return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.sp_model.piece_to_id(token)
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ token = self.sp_model.IdToPiece(index)
+ return token
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ # since we manually add the prefix space, we have to remove it when decoding
+ if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
+ tokens[0] = tokens[0][1:]
+
+ current_sub_tokens = []
+ out_string = ""
+ prev_is_special = False
+ for i, token in enumerate(tokens):
+ # make sure that special tokens are not decoded using sentencepiece model
+ if token in self.all_special_tokens:
+ if not prev_is_special and i != 0 and self.legacy:
+ out_string += " "
+ out_string += self.sp_model.decode(current_sub_tokens) + token
+ prev_is_special = True
+ current_sub_tokens = []
+ else:
+ if prev_is_special and i == 1 and self.add_prefix_space and not token.startswith(SPIECE_UNDERLINE):
+ out_string += " "
+ current_sub_tokens.append(token)
+ prev_is_special = False
+ out_string += self.sp_model.decode(current_sub_tokens)
+ return out_string
+
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ """
+ Save the vocabulary and special tokens file to a directory.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
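+        # Wrap the sequence(s) with the optional BOS/EOS ids, e.g. [bos] token_ids_0 [eos] [bos] token_ids_1 [eos]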
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+ output = bos_token_id + token_ids_0 + eos_token_id
+
+ if token_ids_1 is not None:
+ output = output + bos_token_id + token_ids_1 + eos_token_id
+
+ return output
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ bos_token_id = [1] if self.add_bos_token else []
+ eos_token_id = [1] if self.add_eos_token else []
+
+ if token_ids_1 is None:
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+ return (
+ bos_token_id
+ + ([0] * len(token_ids_0))
+ + eos_token_id
+ + bos_token_id
+ + ([0] * len(token_ids_1))
+ + eos_token_id
+ )
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. A BLT
+        sequence pair mask has the following format:
+
+ ```
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ ```
+
+        If `token_ids_1` is None, only the first portion of the mask (0s) is returned.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of ids.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+ """
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
+
+ if token_ids_1 is not None:
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
+
+ return output
+
+
+__all__ = ["BLTTokenizer"]
diff --git a/backup_blt_wip copy/__init__.py b/backup_blt_wip copy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backup_blt_wip copy/__pycache__/__init__.cpython-312.pyc b/backup_blt_wip copy/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aec13a4ab4fd4175a82f4cc9acbcae04b89ed6fd
Binary files /dev/null and b/backup_blt_wip copy/__pycache__/__init__.cpython-312.pyc differ
diff --git a/backup_blt_wip copy/__pycache__/blt_args.cpython-312.pyc b/backup_blt_wip copy/__pycache__/blt_args.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe1c3fac23fde31ac2fe95aa5c15d35368e766fc
Binary files /dev/null and b/backup_blt_wip copy/__pycache__/blt_args.cpython-312.pyc differ
diff --git a/backup_blt_wip copy/__pycache__/blt_one_file.cpython-312.pyc b/backup_blt_wip copy/__pycache__/blt_one_file.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad67ee5b9c82f98e084ea14790b83f2fbbca1539
Binary files /dev/null and b/backup_blt_wip copy/__pycache__/blt_one_file.cpython-312.pyc differ
diff --git a/backup_blt_wip copy/__pycache__/configuration_blt.cpython-312.pyc b/backup_blt_wip copy/__pycache__/configuration_blt.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6619d112a99fe0a3db54bb3649041aa9d3254a71
Binary files /dev/null and b/backup_blt_wip copy/__pycache__/configuration_blt.cpython-312.pyc differ
diff --git a/backup_blt_wip copy/__pycache__/configuration_blt_og.cpython-312.pyc b/backup_blt_wip copy/__pycache__/configuration_blt_og.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65e042fc413fbde011b9ac7c4982694313c9d1e9
Binary files /dev/null and b/backup_blt_wip copy/__pycache__/configuration_blt_og.cpython-312.pyc differ
diff --git a/backup_blt_wip copy/__pycache__/modeling_blt.cpython-312.pyc b/backup_blt_wip copy/__pycache__/modeling_blt.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d11c9fbc187865ec15972562ad1fc46629f6a986
Binary files /dev/null and b/backup_blt_wip copy/__pycache__/modeling_blt.cpython-312.pyc differ
diff --git a/backup_blt_wip copy/__pycache__/modeling_blt_dev.cpython-312.pyc b/backup_blt_wip copy/__pycache__/modeling_blt_dev.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df77d3f4ae267eada2d86e052112370cda866d91
Binary files /dev/null and b/backup_blt_wip copy/__pycache__/modeling_blt_dev.cpython-312.pyc differ
diff --git a/backup_blt_wip copy/__pycache__/modeling_blt_modellike.cpython-312.pyc b/backup_blt_wip copy/__pycache__/modeling_blt_modellike.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8dcf920f19d283f8290013fc57df016f94f08853
Binary files /dev/null and b/backup_blt_wip copy/__pycache__/modeling_blt_modellike.cpython-312.pyc differ
diff --git a/backup_blt_wip copy/__pycache__/modeling_blt_old.cpython-312.pyc b/backup_blt_wip copy/__pycache__/modeling_blt_old.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0066695b7e54ec3a4c8d924f4fb37dfa0c9f5ec
Binary files /dev/null and b/backup_blt_wip copy/__pycache__/modeling_blt_old.cpython-312.pyc differ
diff --git a/backup_blt_wip copy/__pycache__/modeling_blt_wip.cpython-312.pyc b/backup_blt_wip copy/__pycache__/modeling_blt_wip.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14471f3e8bde08cd94f2ba3b8ea1dc01597af449
Binary files /dev/null and b/backup_blt_wip copy/__pycache__/modeling_blt_wip.cpython-312.pyc differ
diff --git a/backup_blt_wip copy/__pycache__/modeling_blt_wip_backup.cpython-312.pyc b/backup_blt_wip copy/__pycache__/modeling_blt_wip_backup.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..01bb41a2f95cf657f1cb253e41a8eadbe9a4c5b7
Binary files /dev/null and b/backup_blt_wip copy/__pycache__/modeling_blt_wip_backup.cpython-312.pyc differ
diff --git a/backup_blt_wip copy/__pycache__/tokenization_blt.cpython-312.pyc b/backup_blt_wip copy/__pycache__/tokenization_blt.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..871072a6257040ba60bbb0473fb9be6b73117ccb
Binary files /dev/null and b/backup_blt_wip copy/__pycache__/tokenization_blt.cpython-312.pyc differ
diff --git a/backup_blt_wip copy/configuration_blt.py b/backup_blt_wip copy/configuration_blt.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9a245afebc3be28dc5bc10f7ff9ef6097ea8247
--- /dev/null
+++ b/backup_blt_wip copy/configuration_blt.py
@@ -0,0 +1,390 @@
+# coding=utf-8
+# Copyright 2024 Facebook Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BLT model configuration"""
+
+from enum import Enum
+from typing import Union
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+class BLTLocalEncoderConfig(PretrainedConfig):
+ """
+ Configuration class for the BLT Local Encoder component.
+ """
+
+ model_type = "blt_local_encoder"
+
+ def __init__(
+ self,
+ vocab_size=256,
+ cross_attn_all_layers=True,
+ cross_attn_k=2,
+ hidden_size_global=2048,
+ hidden_size=512,
+ num_attention_heads=8,
+ num_key_value_heads=None,
+ num_hidden_layers=8,
+ norm_eps=1e-5,
+ dropout=0.0,
+ max_position_embeddings=1024,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ hidden_act="silu",
+ intermediate_size=None,
+ _attn_implementation="sdpa",
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.cross_attn_all_layers = cross_attn_all_layers
+ self.cross_attn_k = cross_attn_k
+ self.hidden_size_global = hidden_size_global
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads or num_attention_heads
+ self.head_dim = hidden_size // num_attention_heads
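+        # When intermediate_size is not given, default to a LLaMA-style ~8/3 expansion of the hidden size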
+ self.intermediate_size = intermediate_size or int(8 * hidden_size / 3)
+ self.num_hidden_layers = num_hidden_layers
+ self.norm_eps = norm_eps
+ self.dropout = dropout
+ self.max_position_embeddings = max_position_embeddings
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling or {"rope_type": "default"}
+ self.hidden_act = hidden_act
+ self._attn_implementation = _attn_implementation
+
+ super().__init__(**kwargs)
+
+class BLTLocalDecoderConfig(PretrainedConfig):
+ """
+ Configuration class for the BLT Local Decoder component.
+ """
+
+ model_type = "blt_local_decoder"
+
+ def __init__(
+ self,
+ vocab_size=256,
+ cross_attn_all_layers=True,
+ cross_attn_k=2,
+ hidden_size_global=2048,
+ hidden_size=512,
+ num_attention_heads=8,
+ num_key_value_heads=None,
+ num_hidden_layers=8,
+ norm_eps=1e-5,
+ dropout=0.0,
+ max_position_embeddings=1024,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ hidden_act="silu",
+ intermediate_size=None,
+ _attn_implementation="sdpa",
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.cross_attn_all_layers = cross_attn_all_layers
+ self.cross_attn_k = cross_attn_k
+ self.hidden_size_global = hidden_size_global
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads or num_attention_heads
+ self.head_dim = hidden_size // num_attention_heads
+ self.intermediate_size = intermediate_size or int(8 * hidden_size / 3)
+ self.num_hidden_layers = num_hidden_layers
+ self.norm_eps = norm_eps
+ self.dropout = dropout
+ self.max_position_embeddings = max_position_embeddings
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling or {"rope_type": "default"}
+ self.hidden_act = hidden_act
+ self._attn_implementation = _attn_implementation
+
+ super().__init__(**kwargs)
+
+
+class BLTGlobalTransformerConfig(PretrainedConfig):
+ """
+ Configuration class for the BLT Global Transformer component.
+ """
+
+ model_type = "blt_global_transformer"
+
+ def __init__(
+ self,
+ hidden_size=512,
+ num_attention_heads=8,
+ num_key_value_heads=None,
+ num_hidden_layers=8,
+ norm_eps=1e-5,
+ dropout=0.0,
+ max_position_embeddings=1024,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ hidden_act="silu",
+ intermediate_size=None,
+ _attn_implementation="sdpa",
+ **kwargs,
+ ):
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads or num_attention_heads
+ self.head_dim = hidden_size // num_attention_heads
+ self.intermediate_size = intermediate_size or int(8 * hidden_size / 3)
+ self.num_hidden_layers = num_hidden_layers
+ self.norm_eps = norm_eps
+ self.dropout = dropout
+ self.max_position_embeddings = max_position_embeddings
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling or {"rope_type": "default"}
+ self.hidden_act = hidden_act
+ self._attn_implementation = _attn_implementation
+
+ super().__init__(**kwargs)
+
+
+class BLTPatcherConfig(PretrainedConfig):
+ r"""
+ Configuration class for the BLT Patcher/Entropy model component.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 256):
+ Vocabulary size for the entropy model used in patching.
+ hidden_size (`int`, *optional*, defaults to 512):
+ Hidden dimension for the entropy model.
+ num_hidden_layers (`int`, *optional*, defaults to 8):
+ Number of layers in the entropy model.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads in the entropy model.
+        num_key_value_heads (`int`, *optional*):
+            Number of key-value heads in the entropy model. Defaults to `num_attention_heads` when not set.
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
+ Maximum sequence length for the entropy model.
+ norm_eps (`float`, *optional*, defaults to 1e-5):
+ Layer normalization epsilon for the entropy model.
+ dropout (`float`, *optional*, defaults to 0.0):
+ Dropout probability for the entropy model.
+        intermediate_size (`int`, *optional*):
+            Feedforward (MLP) dimension for the entropy model. Defaults to `int(8 * hidden_size / 3)` when not set.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ RoPE theta parameter for the entropy model.
+ attn_impl (`str`, *optional*, defaults to "sdpa"):
+ Attention implementation for the entropy model.
+ attn_bias_type (`str`, *optional*, defaults to "causal"):
+ Attention bias type for the entropy model.
+ """
+
+ model_type = "blt_patcher"
+
+ def __init__(
+ self,
+ vocab_size=256,
+ hidden_size=512,
+ num_hidden_layers=8,
+ num_attention_heads=8,
+ num_key_value_heads=None,
+ max_position_embeddings=1024,
+ norm_eps=1e-5,
+ dropout=0.0,
+ rope_theta=10000.0,
+ attn_impl="sdpa",
+ attn_bias_type="causal",
+ intermediate_size=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = hidden_size // num_attention_heads
+ self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.norm_eps = norm_eps
+ self.dropout = dropout
+ self.rope_theta = rope_theta
+ self.attn_impl = attn_impl
+ self.attn_bias_type = attn_bias_type
+ self.hidden_act = "silu" # BLT uses silu activation
+ self.intermediate_size = intermediate_size or int(8 * self.hidden_size / 3)
+ self.rope_scaling = {"rope_type": "default"}
+ super().__init__(**kwargs)
+
+
+class BLTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`BLTModel`]. It is used to instantiate a
+ BLT model according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 256):
+ Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented.
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model can handle.
+
+ # Patching configuration
+ patch_in_forward (`bool`, *optional*, defaults to False):
+ Whether to perform patching during forward pass.
+ patch_size (`float`, *optional*):
+ Size of patches for static patching.
+ patching_mode (`str`, *optional*):
+ Mode for patching ("entropy", "static", etc.).
+ patching_threshold (`float`, *optional*):
+ Threshold for entropy-based patching.
+ patching_batch_size (`int`, *optional*, defaults to 1):
+ Batch size for patching operations.
+ max_patch_length (`int`, *optional*):
+ Maximum length of patches.
+
+ # Cross attention configurations
+ cross_attn_k (`int`, *optional*):
+ Number of cross attention components.
+
+ # Encoder configurations
+ encoder_hash_byte_group_size (`Any`, *optional*):
+ Hash byte group size for encoder.
+ encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 30000):
+ Vocabulary size for hash byte groups.
+ encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 3):
+ Number of hash functions for byte groups.
+
+ # Component configurations
+ patcher_config (`Union[BLTPatcherConfig, dict]`, *optional*):
+ Configuration for the BLT patcher/entropy model component.
+ encoder_config (`Union[BLTLocalEncoderConfig, dict]`, *optional*):
+ Configuration for the BLT local encoder component.
+ decoder_config (`Union[BLTLocalDecoderConfig, dict]`, *optional*):
+ Configuration for the BLT local decoder component.
+ global_config (`Union[BLTGlobalTransformerConfig, dict]`, *optional*):
+ Configuration for the BLT global transformer component.
+
+ ```python
+ >>> from transformers import BLTModel, BLTConfig
+
+ >>> # Initializing a BLT configuration
+ >>> configuration = BLTConfig()
+
+ >>> # Initializing a model from the configuration
+ >>> model = BLTModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
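+
+    >>> # Sub-configurations can also be passed as plain dicts (illustrative values; a minimal sketch)
+    >>> configuration = BLTConfig(encoder_config={"hidden_size": 256, "num_hidden_layers": 4})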
+ ```"""
+
+ model_type = "blt"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ sub_configs = {
+ "patcher_config": BLTPatcherConfig,
+ "encoder_config": BLTLocalEncoderConfig,
+ "decoder_config": BLTLocalDecoderConfig,
+ "global_config": BLTGlobalTransformerConfig
+ }
+
+ def __init__(
+ self,
+ vocab_size=256,
+ max_position_embeddings=1024,
+ patch_in_forward=False,
+ patch_size=None,
+ patching_mode=None,
+ patching_threshold=None,
+ patching_batch_size=1,
+ max_patch_length=None,
+ cross_attn_k=2,
+ encoder_hash_byte_group_size=None,
+ encoder_hash_byte_group_vocab=30000,
+ encoder_hash_byte_group_nb_functions=3,
+ patcher_config=None,
+ encoder_config=None,
+ decoder_config=None,
+ global_config=None,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+
+ # Basic model configuration
+ self.tie_word_embeddings = tie_word_embeddings
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+
+ # Patching configuration
+ self.patch_in_forward = patch_in_forward
+ self.patch_size = patch_size
+ self.patching_mode = patching_mode
+ self.patching_threshold = patching_threshold
+ self.patching_batch_size = patching_batch_size
+ self.max_patch_length = max_patch_length
+
+ # Cross attention configurations
+ self.cross_attn_k = cross_attn_k
+
+ # Encoder configurations
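+        # Default hash byte group sizes of 2, 3 and 4 are used when none are provided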
+ self.encoder_hash_byte_group_size = encoder_hash_byte_group_size or [2, 3, 4]
+ self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab
+ self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions
+
+ # Initialize component configurations
+ if patcher_config is None:
+ self.patcher_config = BLTPatcherConfig()
+ logger.info("patcher_config is None, using default BLT patcher config")
+ elif isinstance(patcher_config, dict):
+ self.patcher_config = BLTPatcherConfig(**patcher_config)
+ elif isinstance(patcher_config, BLTPatcherConfig):
+ self.patcher_config = patcher_config
+
+ if encoder_config is None:
+ self.encoder_config = BLTLocalEncoderConfig()
+ logger.info("encoder_config is None, using default BLT encoder config")
+ elif isinstance(encoder_config, dict):
+ self.encoder_config = BLTLocalEncoderConfig(**encoder_config)
+ elif isinstance(encoder_config, BLTLocalEncoderConfig):
+ self.encoder_config = encoder_config
+
+ if decoder_config is None:
+ self.decoder_config = BLTLocalDecoderConfig()
+ logger.info("decoder_config is None, using default BLT decoder config")
+ elif isinstance(decoder_config, dict):
+ self.decoder_config = BLTLocalDecoderConfig(**decoder_config)
+ elif isinstance(decoder_config, BLTLocalDecoderConfig):
+ self.decoder_config = decoder_config
+
+ if global_config is None:
+ self.global_config = BLTGlobalTransformerConfig()
+ logger.info("global_config is None, using default BLT global config")
+ elif isinstance(global_config, dict):
+ self.global_config = BLTGlobalTransformerConfig(**global_config)
+ elif isinstance(global_config, BLTGlobalTransformerConfig):
+ self.global_config = global_config
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+__all__ = [
+ "BLTConfig",
+ "BLTPatcherConfig",
+ "BLTLocalEncoderConfig",
+ "BLTLocalDecoderConfig",
+ "BLTGlobalTransformerConfig",
+]
diff --git a/backup_blt_wip copy/configuration_blt_og.py b/backup_blt_wip copy/configuration_blt_og.py
new file mode 100644
index 0000000000000000000000000000000000000000..60af723d89d72599eb0add3881c2b99e9e7390ce
--- /dev/null
+++ b/backup_blt_wip copy/configuration_blt_og.py
@@ -0,0 +1,608 @@
+# old config
+
+# coding=utf-8
+# Copyright 2024 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BLT (Byte Latent Transformer) model configuration"""
+
+from enum import Enum
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class InitStdFactor(str, Enum):
+ DISABLED = "disabled" # Init std is divided by 1.0
+ CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)
+
+
+class PatchingModeEnum(str, Enum):
+ entropy = "entropy"
+ bpe = "bpe"
+ bpe_patcher = "bpe_patcher"
+ space = "space"
+ static = "static"
+ byte = "byte"
+
+
+class BLTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ByteLatentTransformer`]. It is used to instantiate a
+ BLT model according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 256):
+ Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented.
+ max_seqlen (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model can handle.
+
+ # Main architecture dimensions
+ dim (`int`, *optional*, defaults to 512):
+ Main dimension of the model.
+ n_layers (`int`, *optional*, defaults to 8):
+ Number of layers in the main transformer.
+ n_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads in the main transformer.
+ head_dim (`int`, *optional*):
+ Dimension of each attention head. If not specified, computed as dim // n_heads.
+ n_kv_heads (`int`, *optional*):
+ Number of key-value heads for grouped query attention. If not specified, defaults to n_heads.
+
+ # Component-specific dimensions
+ dim_global (`int`, *optional*, defaults to 512):
+ Dimension of the global transformer component.
+ dim_local_decoder (`int`, *optional*, defaults to 512):
+ Dimension of the local decoder component.
+ dim_local_encoder (`int`, *optional*, defaults to 512):
+ Dimension of the local encoder component.
+ n_layers_global (`int`, *optional*, defaults to 8):
+ Number of layers in the global transformer.
+ n_layers_local_decoder (`int`, *optional*, defaults to 8):
+ Number of layers in the local decoder.
+ n_layers_local_encoder (`int`, *optional*, defaults to 8):
+ Number of layers in the local encoder.
+ n_heads_global (`int`, *optional*, defaults to 8):
+ Number of attention heads in the global transformer.
+ n_heads_local_decoder (`int`, *optional*, defaults to 8):
+ Number of attention heads in the local decoder.
+ n_heads_local_encoder (`int`, *optional*, defaults to 8):
+ Number of attention heads in the local encoder.
+ n_kv_heads_global (`int`, *optional*):
+ Number of key-value heads in the global transformer.
+
+ # Transformer configuration
+ norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers.
+ ffn_dim_multiplier (`float`, *optional*, defaults to 1.0):
+ Multiplier for the feedforward network dimension.
+ multiple_of (`int`, *optional*, defaults to 256):
+ Make feedforward network dimension multiple of this value.
+
+ # Positional encoding
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False):
+ Whether to use fp32 in RoPE outer product computation.
+
+ # Attention configuration
+ attn_impl (`str`, *optional*, defaults to "sdpa"):
+ Attention implementation to use ("sdpa" or "flex_attention").
+ attn_bias_type (`str`, *optional*, defaults to "causal"):
+ Type of attention bias to apply.
+ local_attention_window_len (`int`, *optional*):
+ Window length for local attention.
+ use_rope (`bool`, *optional*, defaults to True):
+ Whether to use rotary position embeddings.
+
+ # Initialization
+ init_base_std (`float`, *optional*):
+ Base standard deviation for weight initialization.
+ init_std_factor (`str`, *optional*, defaults to "disabled"):
+ Factor for adjusting initialization standard deviation.
+
+ # Embedding dimensions
+ dim_token_emb (`int`, *optional*):
+ Token embedding dimension.
+ dim_token (`int`, *optional*):
+ Token dimension.
+
+ # Patching configuration
+ patch_in_forward (`bool`, *optional*, defaults to False):
+ Whether to perform patching during forward pass.
+ realtime_patching (`bool`, *optional*, defaults to True):
+ Whether to use realtime patching.
+ patch_size (`float`, *optional*):
+ Size of patches for static patching.
+ patching_mode (`str`, *optional*):
+ Mode for patching ("entropy", "static", etc.).
+ patching_threshold (`float`, *optional*):
+ Threshold for entropy-based patching.
+ patching_threshold_add (`float`, *optional*):
+ Additional threshold parameter for patching.
+ monotonicity (`bool`, *optional*, defaults to False):
+ Whether to enforce monotonicity in patching.
+ patching_batch_size (`int`, *optional*, defaults to 1):
+ Batch size for patching operations.
+ patching_device (`str`, *optional*, defaults to "cuda"):
+ Device to use for patching operations.
+ max_patch_length (`int`, *optional*):
+ Maximum length of patches.
+ entropy_model_checkpoint_dir (`str`, *optional*):
+ Directory containing entropy model checkpoint.
+
+ # Cross attention configurations
+ cross_attn_encoder (`bool`, *optional*, defaults to False):
+ Whether to use cross attention in encoder.
+ cross_attn_decoder (`bool`, *optional*, defaults to False):
+ Whether to use cross attention in decoder.
+ cross_attn_window_encoder (`int`, *optional*):
+ Cross attention window for encoder.
+ cross_attn_window_decoder (`int`, *optional*):
+ Cross attention window for decoder.
+ cross_attn_k (`int`, *optional*):
+ Number of cross attention components.
+ cross_attn_nheads (`int`, *optional*):
+ Number of heads for cross attention.
+ cross_attn_all_layers_decoder (`bool`, *optional*, defaults to False):
+ Whether to apply cross attention to all decoder layers.
+ cross_attn_all_layers_encoder (`bool`, *optional*, defaults to False):
+ Whether to apply cross attention to all encoder layers.
+ cross_attn_use_flex_attention (`bool`, *optional*, defaults to True):
+ Whether to use flexible attention for cross attention.
+ cross_attn_init_by_pooling (`bool`, *optional*, defaults to False):
+ Whether to initialize cross attention by pooling.
+
+ # Encoder configurations
+ use_local_encoder_transformer (`bool`, *optional*, defaults to False):
+ Whether to use transformer in local encoder.
+ max_encoder_seq_length (`int`, *optional*):
+ Maximum sequence length for encoder.
+ encoder_hash_byte_group_size (`Any`, *optional*):
+ Hash byte group size for encoder.
+ encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 30000):
+ Vocabulary size for hash byte groups.
+ encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 3):
+ Number of hash functions for byte groups.
+ encoder_enable_byte_ngrams (`bool`, *optional*, defaults to False):
+ Whether to enable byte n-grams in encoder.
+ encoder_ngram_to_size_str (`str`, *optional*):
+ String defining n-gram sizes.
+ downsampling_by_pooling (`str`, *optional*):
+ Type of pooling for downsampling.
+
+ # Model behavior
+ share_encoder_decoder_emb (`bool`, *optional*, defaults to True):
+ Whether to share encoder and decoder embeddings.
+ weight_tying (`bool`, *optional*, defaults to False):
+ Whether to tie input and output embeddings.
+
+ # Performance optimization
+ sequence_parallel (`bool`, *optional*, defaults to False):
+ Whether to use sequence parallelism.
+ loss_parallel (`bool`, *optional*, defaults to False):
+ Whether to use loss parallelism.
+ fuse_sequence_parallel (`bool`, *optional*, defaults to False):
+ Whether to fuse sequence parallel operations.
+ use_fsdp (`bool`, *optional*, defaults to True):
+ Whether to use fully sharded data parallel.
+
+ # Parameter mixing
+ pm_size (`int`, *optional*, defaults to 0):
+ Parameter mixing size.
+
+ # Special tokens
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the "end-of-sequence" token.
+ pad_token_id (`int`, *optional*, defaults to -1):
+ The id of the padding token.
+
+ # Patcher/Entropy model configuration
+ patcher_vocab_size (`int`, *optional*, defaults to 256):
+ Vocabulary size for the entropy model used in patching.
+ patcher_dim (`int`, *optional*, defaults to 512):
+ Hidden dimension for the entropy model.
+ patcher_n_layers (`int`, *optional*, defaults to 8):
+ Number of layers in the entropy model.
+ patcher_n_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads in the entropy model.
+ patcher_head_dim (`int`, *optional*):
+ Dimension of each attention head in the entropy model.
+ patcher_n_kv_heads (`int`, *optional*):
+ Number of key-value heads in the entropy model.
+ patcher_max_seqlen (`int`, *optional*, defaults to 1024):
+ Maximum sequence length for the entropy model.
+ patcher_norm_eps (`float`, *optional*, defaults to 1e-5):
+ Layer normalization epsilon for the entropy model.
+ patcher_dropout (`float`, *optional*, defaults to 0.0):
+ Dropout probability for the entropy model.
+ patcher_sliding_window (`int`, *optional*):
+ Sliding window size for the entropy model attention.
+ patcher_ffn_dim_multiplier (`float`, *optional*):
+ Feedforward dimension multiplier for the entropy model.
+ patcher_multiple_of (`int`, *optional*, defaults to 256):
+ Make feedforward dimension multiple of this for the entropy model.
+ patcher_rope_theta (`float`, *optional*, defaults to 10000.0):
+ RoPE theta parameter for the entropy model.
+ patcher_rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False):
+ Whether to use fp32 in RoPE outer product for the entropy model.
+ patcher_attn_impl (`str`, *optional*, defaults to "sdpa"):
+ Attention implementation for the entropy model.
+ patcher_attn_bias_type (`str`, *optional*, defaults to "causal"):
+ Attention bias type for the entropy model.
+ patcher_init_base_std (`float`, *optional*):
+ Base initialization standard deviation for the entropy model.
+ patcher_init_std_factor (`str`, *optional*, defaults to "disabled"):
+ Initialization std factor for the entropy model.
+ patcher_dim_token_emb (`int`, *optional*):
+ Token embedding dimension for the entropy model.
+ patcher_weight_tying (`bool`, *optional*, defaults to False):
+ Whether to tie embeddings in the entropy model.
+ patcher_bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of sequence token id for the entropy model.
+ patcher_eos_token_id (`int`, *optional*, defaults to 2):
+ End of sequence token id for the entropy model.
+
+ ```python
+ >>> from transformers import ByteLatentTransformer, BLTConfig
+
+ >>> # Initializing a BLT configuration
+ >>> configuration = BLTConfig()
+
+ >>> # Initializing a model from the configuration
+ >>> model = ByteLatentTransformer(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "blt"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=256,
+ max_seqlen=1024,
+ # Main architecture dimensions
+ dim=512,
+ n_layers=8,
+ n_heads=8,
+ head_dim=None,
+ n_kv_heads=None,
+ # Component-specific dimensions
+ dim_global=512,
+ dim_local_decoder=512,
+ dim_local_encoder=512,
+ n_layers_global=8,
+ n_layers_local_decoder=8,
+ n_layers_local_encoder=8,
+ n_heads_global=8,
+ n_heads_local_decoder=8,
+ n_heads_local_encoder=8,
+ n_kv_heads_global=None,
+ # Transformer configuration
+ norm_eps=1e-5,
+ dropout=0.0,
+ ffn_dim_multiplier=1.0,
+ multiple_of=256,
+ # Positional encoding
+ rope_theta=10000.0,
+ rope_use_fp32_in_outer_product=False,
+ # Attention configuration
+ attn_impl="sdpa",
+ attn_bias_type="causal",
+ local_attention_window_len=None,
+ use_rope=True,
+ # Initialization
+ init_base_std=None,
+ init_std_factor="disabled",
+ # Embedding dimensions
+ dim_token_emb=None,
+ dim_token=None,
+ # Patching configuration
+ patch_in_forward=False,
+ realtime_patching=True,
+ patch_size=None,
+ patching_mode=None,
+ patching_threshold=None,
+ patching_threshold_add=None,
+ monotonicity=False,
+ patching_batch_size=1,
+ patching_device="cuda",
+ max_patch_length=None,
+ entropy_model_checkpoint_dir=None,
+ # Cross attention configurations
+ cross_attn_encoder=False,
+ cross_attn_decoder=False,
+ cross_attn_window_encoder=None,
+ cross_attn_window_decoder=None,
+ cross_attn_k=None,
+ cross_attn_nheads=None,
+ cross_attn_all_layers_decoder=False,
+ cross_attn_all_layers_encoder=False,
+ cross_attn_use_flex_attention=True,
+ cross_attn_init_by_pooling=False,
+ # Encoder configurations
+ use_local_encoder_transformer=False,
+ max_encoder_seq_length=None,
+ encoder_hash_byte_group_size=None,
+ encoder_hash_byte_group_vocab=30000,
+ encoder_hash_byte_group_nb_functions=3,
+ encoder_enable_byte_ngrams=False,
+ encoder_ngram_to_size_str=None,
+ downsampling_by_pooling=None,
+ # Model behavior
+ share_encoder_decoder_emb=True,
+ weight_tying=False,
+ # Performance optimization
+ sequence_parallel=False,
+ loss_parallel=False,
+ fuse_sequence_parallel=False,
+ use_fsdp=True,
+ # Parameter mixing
+ pm_size=0,
+ # Special tokens
+ bos_token_id=1,
+ eos_token_id=2,
+ pad_token_id=-1,
+ # Patcher/Entropy model configuration
+ patcher_vocab_size=256,
+ patcher_dim=512,
+ patcher_n_layers=8,
+ patcher_n_heads=8,
+ patcher_head_dim=None,
+ patcher_n_kv_heads=None,
+ patcher_max_seqlen=1024,
+ patcher_norm_eps=1e-5,
+ patcher_dropout=0.0,
+ patcher_sliding_window=None,
+ patcher_ffn_dim_multiplier=None,
+ patcher_multiple_of=256,
+ patcher_rope_theta=10000.0,
+ patcher_rope_use_fp32_in_outer_product=False,
+ patcher_attn_impl="sdpa",
+ patcher_attn_bias_type="causal",
+ patcher_init_base_std=None,
+ patcher_init_std_factor="disabled",
+ patcher_dim_token_emb=None,
+ patcher_weight_tying=False,
+ patcher_bos_token_id=1,
+ patcher_eos_token_id=2,
+ # Inherited
+ **kwargs,
+ ):
+
+ self.sliding_window = None
+ # Basic model configuration
+ self.vocab_size = vocab_size
+ self.max_seqlen = max_seqlen
+
+ # Main architecture dimensions
+ self.dim = dim
+ self.n_layers = n_layers
+ self.n_heads = n_heads
+ self.head_dim = head_dim
+ self.n_kv_heads = n_kv_heads
+
+ # Component-specific dimensions
+ self.dim_global = dim_global
+ self.dim_local_decoder = dim_local_decoder
+ self.dim_local_encoder = dim_local_encoder
+ self.n_layers_global = n_layers_global
+ self.n_layers_local_decoder = n_layers_local_decoder
+ self.n_layers_local_encoder = n_layers_local_encoder
+ self.n_heads_global = n_heads_global
+ self.n_heads_local_decoder = n_heads_local_decoder
+ self.n_heads_local_encoder = n_heads_local_encoder
+ self.n_kv_heads_global = n_kv_heads_global
+
+ # Transformer configuration
+ self.norm_eps = norm_eps
+ self.dropout = dropout
+ self.ffn_dim_multiplier = ffn_dim_multiplier
+ self.multiple_of = multiple_of
+
+ # Positional encoding
+ self.rope_theta = rope_theta
+ self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
+
+ # Attention configuration
+ self.attn_impl = attn_impl
+ self.attn_bias_type = attn_bias_type
+ self.local_attention_window_len = local_attention_window_len
+ self.use_rope = use_rope
+
+ # Initialization
+ self.init_base_std = init_base_std
+ self.init_std_factor = InitStdFactor(init_std_factor)
+
+ # Embedding dimensions
+ self.dim_token_emb = dim_token_emb
+ self.dim_token = dim_token
+
+ # Patching configuration
+ self.patch_in_forward = patch_in_forward
+ self.realtime_patching = realtime_patching
+ self.patch_size = patch_size
+ self.patching_mode = patching_mode
+ self.patching_threshold = patching_threshold
+ self.patching_threshold_add = patching_threshold_add
+ self.monotonicity = monotonicity
+ self.patching_batch_size = patching_batch_size
+ self.patching_device = patching_device
+ self.max_patch_length = max_patch_length
+ self.entropy_model_checkpoint_dir = entropy_model_checkpoint_dir
+
+ # Cross attention configurations
+ self.cross_attn_encoder = cross_attn_encoder
+ self.cross_attn_decoder = cross_attn_decoder
+ self.cross_attn_window_encoder = cross_attn_window_encoder
+ self.cross_attn_window_decoder = cross_attn_window_decoder
+ self.cross_attn_k = cross_attn_k
+ self.cross_attn_nheads = cross_attn_nheads
+ self.cross_attn_all_layers_decoder = cross_attn_all_layers_decoder
+ self.cross_attn_all_layers_encoder = cross_attn_all_layers_encoder
+ self.cross_attn_use_flex_attention = cross_attn_use_flex_attention
+ self.cross_attn_init_by_pooling = cross_attn_init_by_pooling
+
+ # Encoder configurations
+ self.use_local_encoder_transformer = use_local_encoder_transformer
+ self.max_encoder_seq_length = max_encoder_seq_length
+ self.encoder_hash_byte_group_size = encoder_hash_byte_group_size
+ self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab
+ self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions
+ self.encoder_enable_byte_ngrams = encoder_enable_byte_ngrams
+ self.encoder_ngram_to_size_str = encoder_ngram_to_size_str
+ self.downsampling_by_pooling = downsampling_by_pooling
+
+ # Model behavior
+ self.share_encoder_decoder_emb = share_encoder_decoder_emb
+ self.weight_tying = weight_tying
+
+ # Performance optimization
+ self.sequence_parallel = sequence_parallel
+ self.loss_parallel = loss_parallel
+ self.fuse_sequence_parallel = fuse_sequence_parallel
+ self.use_fsdp = use_fsdp
+
+ # Parameter mixing
+ self.pm_size = pm_size
+
+ # Patcher/Entropy model configuration
+ self.patcher_vocab_size = patcher_vocab_size
+ self.patcher_dim = patcher_dim
+ self.patcher_n_layers = patcher_n_layers
+ self.patcher_n_heads = patcher_n_heads
+ self.patcher_head_dim = patcher_head_dim
+ self.patcher_n_kv_heads = patcher_n_kv_heads
+ self.patcher_max_seqlen = patcher_max_seqlen
+ self.patcher_norm_eps = patcher_norm_eps
+ self.patcher_dropout = patcher_dropout
+ self.patcher_sliding_window = patcher_sliding_window
+ self.patcher_ffn_dim_multiplier = patcher_ffn_dim_multiplier
+ self.patcher_multiple_of = patcher_multiple_of
+ self.patcher_rope_theta = patcher_rope_theta
+ self.patcher_rope_use_fp32_in_outer_product = patcher_rope_use_fp32_in_outer_product
+ self.patcher_attn_impl = patcher_attn_impl
+ self.patcher_attn_bias_type = patcher_attn_bias_type
+ self.patcher_init_base_std = patcher_init_base_std
+ self.patcher_init_std_factor = InitStdFactor(patcher_init_std_factor)
+ self.patcher_dim_token_emb = patcher_dim_token_emb
+ self.patcher_weight_tying = patcher_weight_tying
+ self.patcher_bos_token_id = patcher_bos_token_id
+ self.patcher_eos_token_id = patcher_eos_token_id
+
+ # Handle hash byte group size validation
+ if self.encoder_hash_byte_group_size is not None and isinstance(self.encoder_hash_byte_group_size, str):
+ self.encoder_hash_byte_group_size = [
+ int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0
+ ]
+
+ # Rope
+ self.rope_scaling = {
+ "type": "dynamic",
+ "factor": 2.0,
+ "rope_type": "dynamic",
+ }
+
+ self.num_key_value_heads = n_heads_local_encoder
+ self.max_position_embeddings = max_seqlen
+ self.hidden_size = dim_local_encoder
+ self.num_attention_heads = n_heads_local_encoder
+ # self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+
+ super().__init__(
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ **kwargs,
+ )
+
+ @property
+ def encoder_dim_token_emb(self):
+ """Compute encoder token embedding dimension."""
+ if self.dim_token is not None:
+ return self.dim_token
+ elif self.use_local_encoder_transformer:
+ return self.dim_local_encoder
+ else:
+ # Use default patch_size of 8 if not set
+ patch_size = self.patch_size if self.patch_size is not None else 8
+ return self.dim_global // patch_size
+
+ @property
+ def encoder_dim_patch_emb(self):
+ """Compute encoder patch embedding dimension."""
+ if self.cross_attn_encoder:
+ if self.cross_attn_init_by_pooling:
+ return self.dim_local_encoder
+ else:
+ return self.dim_global
+ return None
+
+ @property
+ def global_dim_patch_emb(self):
+ """Compute global patch embedding dimension."""
+ dim_token_emb = self.encoder_dim_token_emb
+ if self.cross_attn_encoder:
+ cross_attn_k = self.cross_attn_k if self.cross_attn_k is not None else 1
+ return dim_token_emb * cross_attn_k
+ elif (
+ self.downsampling_by_pooling is None
+ or not self.downsampling_by_pooling
+ or len(self.downsampling_by_pooling) == 0
+ ):
+ # Use default patch_size of 8 if not set
+ patch_size = self.patch_size if self.patch_size is not None else 8
+ return dim_token_emb * patch_size
+ else:
+ return dim_token_emb * sum([pooling in self.downsampling_by_pooling for pooling in ["avg", "min", "max"]])
+
+ @property
+ def decoder_dim_token_emb(self):
+ """Compute decoder token embedding dimension."""
+ if self.share_encoder_decoder_emb:
+ return self.encoder_dim_token_emb
+ elif self.dim_token is not None:
+ return self.dim_token
+ else:
+ return self.dim_local_decoder
+
+ def get_init_std_factor(self, depth: int) -> float:
+ """
+ Calculate the initialization standard deviation scaling factor for a given layer depth.
+
+ Args:
+ depth: Current layer depth (0-indexed)
+
+ Returns:
+ Scaling factor to divide the base initialization std by
+ """
+ if self.init_std_factor == InitStdFactor.CURRENT_DEPTH:
+ return (2 * (depth + 1)) ** 0.5
+ else: # DISABLED
+ return 1.0
+
+
+__all__ = ["BLTConfig", "InitStdFactor", "PatchingModeEnum"]
+
diff --git a/backup_blt_wip copy/modeling_blt.py b/backup_blt_wip copy/modeling_blt.py
new file mode 100644
index 0000000000000000000000000000000000000000..84d874daa76ef209652702d5b9a432fd721f5679
--- /dev/null
+++ b/backup_blt_wip copy/modeling_blt.py
@@ -0,0 +1,1287 @@
+# coding=utf-8
+# Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BLT model."""
+
+from enum import Enum
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.distributions
+import torch.nn as nn
+from torch.nn import functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...generation.utils import GenerationMixin
+from ...modeling_outputs import CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...utils import is_torch_flex_attn_available, logging
+from .configuration_blt import (
+ BLTConfig,
+ BLTGlobalTransformerConfig,
+ BLTLocalDecoderConfig,
+ BLTLocalEncoderConfig,
+ BLTPatcherConfig,
+)
+
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+class PatchingModeEnum(str, Enum):
+ entropy = "entropy"
+ bpe = "bpe"
+ bpe_patcher = "bpe_patcher"
+ space = "space"
+ static = "static"
+ byte = "byte"
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
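+
+# Shape example: with hidden_states of shape (2, 4, 128, 64), i.e. 4 key/value heads, and
+# n_rep=2, the result has shape (2, 8, 128, 64), matching 8 attention heads for
+# grouped-query attention.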
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ # TODO: not exactly equivalent to other transformers implementations; needs feedback
+ # Extract first head_dim//2 elements which correspond to the unique frequencies
+ # This matches the original BLT approach which uses head_dim//2 frequency pairs
+ head_dim = q.shape[-1]
+ cos_freqs = cos[..., :head_dim//2] # [B, S, D/2]
+ sin_freqs = sin[..., :head_dim//2] # [B, S, D/2]
+
+ # Expand cos/sin to match query/key tensor format [B, H, S, D/2]
+ cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2]
+ sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2]
+
+ # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ...
+ q_pairs = q.view(*q.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2]
+ k_pairs = k.view(*k.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2]
+
+ # Extract real and imaginary parts
+ q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2]
+ k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2]
+
+ # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag]
+ q_real_rot = cos_freqs * q_real - sin_freqs * q_imag
+ q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag
+ k_real_rot = cos_freqs * k_real - sin_freqs * k_imag
+ k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag
+
+ # Recombine pairs and reshape back to original format
+ q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D]
+ k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D]
+
+ return q_rot.type_as(q), k_rot.type_as(k)
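+
+# For a single frequency pair (x0, x1) rotated by angle t, the code above computes
+#   x0' = cos(t) * x0 - sin(t) * x1
+#   x1' = sin(t) * x0 + cos(t) * x1
+# i.e. multiplication of the complex number x0 + i*x1 by e^{i*t}, the interleaved-pair
+# RoPE convention used by the original BLT implementation.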
+
+
+class BLTMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class BLTRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ BLTRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+class BLTTransformerLayer(nn.Module):
+ def __init__(self, config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.layer_idx = layer_idx
+
+ self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx)
+ self.mlp = BLTMLP(config)
+ self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ position_ids (`torch.LongTensor`, *optional*):
+ Position indices of tokens in the sequence for RoPE computation.
+ past_key_value (`Cache`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class BLTSelfAttention(nn.Module):
+ def __init__(self, config, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.num_heads = config.num_attention_heads
+ self.dropout = config.dropout
+ self.hidden_size = config.hidden_size
+ self.num_key_value_heads = config.num_key_value_heads
+ self.head_dim = config.hidden_size // self.num_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.scaling = None
+ self.rope_theta = config.rope_theta
+ self.layer_idx = layer_idx
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_embeddings: torch.Tensor,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ past_key_value=None,
+ cache_position=None,
+ **kwargs,
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ # NOTE: attention outputs and the attention implementation are temporarily hard-coded
+ # here; `output_attentions` is ignored and SDPA is always used.
+ output_attentions = False
+ self.config._attn_implementation = "sdpa"
+ if self.config._attn_implementation != "eager":
+ if self.config._attn_implementation == "sdpa" and output_attentions:
+ logger.warning_once(
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ else:
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0):
+ primes = [
+ 1000000007, 5915587277, 1500450271, 3267000013, 5754853343,
+ 4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313,
+ ]
+ prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device)
+ powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device)
+ prime_powers = prime ** powers
+ return torch.sum(token_tensor * prime_powers, dim=-1)
+
+
+def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000):
+ """Hash token groups and map to range [0, max_hash]."""
+ with torch.no_grad():
+ batch_size, seq_len = token_ids.shape
+ # Add padding for sliding window
+ padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
+ padded_tokens = torch.cat([padding, token_ids], dim=1)
+
+ # Create sliding windows and compute hashes
+ windows = padded_tokens.unfold(1, group_size, 1)
+ hashes = rolling_polynomial_hash(windows, hash_func_nb)
+ hash_values = hashes % max_hash
+
+ hash_values.requires_grad = False
+ return hash_values
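+
+# Example: for token_ids = [[65, 66, 67]] and group_size=2, the tokens are left-padded with
+# one zero and the sliding windows [0, 65], [65, 66], [66, 67] are hashed with
+# rolling_polynomial_hash, then reduced modulo `max_hash`, so the output keeps the input
+# shape: one hash id per token position.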
+
+
+def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list):
+ """Initialize hash-based token embeddings for the BLT encoder."""
+ num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size)
+ embeddings = [
+ nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim)
+ for _ in range(num_embeddings)
+ ]
+ return nn.ModuleList(embeddings)
+
+
+def compute_hash_embeddings(
+ local_encoder_tokens: torch.Tensor,
+ local_encoder,
+ encoder_hash_tok_embedding: nn.ModuleList,
+ encoder_hash_byte_group_nb_functions: int,
+ encoder_hash_byte_group_size: list,
+ encoder_hash_byte_group_vocab: int,
+) -> torch.Tensor:
+ """Compute token embeddings enhanced with hash-based embeddings."""
+ embeddings = local_encoder.embed_tokens(local_encoder_tokens)
+ embedding_idx = 0
+ for func_nb in range(encoder_hash_byte_group_nb_functions):
+ for group_size in encoder_hash_byte_group_size:
+ hash_ids = byte_group_hash_function(
+ local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab
+ )
+ embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids)
+ embedding_idx += 1
+
+ return embeddings
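+
+# Note: the ordering of `encoder_hash_tok_embedding` must match the nested loop above, i.e.
+# one nn.Embedding per (hash function, group size) pair, for a total of
+# encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size) tables, exactly
+# as allocated by `init_hash_embeddings`.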
+
+
+def _prepare_patch_cross_attention_mask(
+ patch_ids: torch.Tensor,
+ num_patches: int,
+ sequence_length: int,
+ patches_as_queries: bool = False,
+ cross_attn_k: int = 1,
+ dtype: torch.dtype = torch.float32,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Prepare cross-attention mask for patch-based attention, following mllama's robust approach.
+
+ This function creates masks that control which patches can attend to which other patches,
+ with support for query/key role swapping and cross-attention multipliers.
+
+ Args:
+ patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids.
+ num_patches (int): Total number of patches.
+ sequence_length (int): Length of the sequence.
+ patches_as_queries (bool): If True, patches are used as queries, otherwise as keys.
+ cross_attn_k (int): Cross-attention multiplier for repeating patches.
+ dtype (torch.dtype): Data type for the output mask.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]:
+ - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len]
+ - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows
+ """
+ batch_size, seq_len = patch_ids.shape
+ device = patch_ids.device
+
+ # Determine query and key lengths based on configuration
+ if patches_as_queries:
+ q_len = num_patches * cross_attn_k
+ kv_len = sequence_length
+ # Create patch-to-sequence mapping
+ q_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(-1).expand(
+ batch_size, num_patches, seq_len
+ )
+ kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len)
+ else:
+ q_len = sequence_length
+ kv_len = num_patches * cross_attn_k
+ # Create sequence-to-patch mapping
+ q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches)
+ kv_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(
+ batch_size, seq_len, num_patches
+ )
+
+ # Create base attention mask - boolean mask where True means "should attend"
+ # Exact patch matching
+ cross_attention_mask = q_patch_ids == kv_patch_ids
+
+ # Handle cross_attn_k multiplier by repeating along appropriate dimension
+ repeat_dim = 1 if patches_as_queries else -1
+ cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim)
+
+ # Validate dimensions
+ expected_shape = (batch_size, q_len, kv_len)
+ if cross_attention_mask.shape != expected_shape:
+ raise ValueError(f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}")
+
+ # Reshape so it can be used by attn module - add head dimension
+ cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len]
+
+ # Invert the mask (following mllama pattern exactly)
+ # True -> 0.0 (attend), False -> 1.0 (will become -inf)
+ inverted_cross_attn_mask = (1.0 - cross_attention_mask.to(dtype))
+ cross_attention_mask = inverted_cross_attn_mask.masked_fill(
+ inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
+ )
+
+ # Apply full-row bias (following mllama pattern exactly)
+ # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's
+ # last dimension contains negative infinity values, otherwise it's 1
+ negative_inf_value = torch.finfo(dtype).min
+ full_text_row_masked_out_mask = (
+ (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
+ )
+ cross_attention_mask *= full_text_row_masked_out_mask
+
+ return cross_attention_mask, full_text_row_masked_out_mask
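+
+# Example: with patch_ids = [[0, 0, 1, 1, 2]], num_patches=3, sequence_length=5,
+# patches_as_queries=False and cross_attn_k=1, cross_attention_mask has shape [1, 1, 5, 3]:
+# each token may only attend to the patch it belongs to (0.0 where the patch ids match,
+# torch.finfo(dtype).min elsewhere), and full_text_row_masked_out_mask has shape
+# [1, 1, 5, 1] filled with ones, since every row keeps at least one attendable key.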
+
+
+def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor:
+ """
+ Splits patch lengths into smaller segments if they exceed `max_patch_length`.
+ Pads the result to uniform length across the batch.
+
+ Args:
+ patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
+ max_patch_length (int, optional): Maximum allowed length per patch.
+
+ Returns:
+ torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
+ """
+ if max_patch_length is None:
+ return patch_lengths
+
+ batch_size = patch_lengths.size(0)
+ processed = []
+
+ for seq in patch_lengths:
+ splits = []
+ for length in seq[seq > 0]:
+ length = length.item()
+ full_chunks, remainder = divmod(length, max_patch_length)
+ splits.extend([max_patch_length] * full_chunks)
+ if remainder:
+ splits.append(remainder)
+ processed.append(splits)
+
+ # Find max length to pad to
+ max_len = max(len(splits) for splits in processed)
+ padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
+
+ for i, splits in enumerate(processed):
+ if splits:
+ padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)
+
+ # Trim zero columns
+ if (padded != 0).any(dim=0).sum() < padded.shape[1]:
+ last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
+ padded = padded[:, :last_nonzero]
+
+ return padded
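+
+# Example: process_patch_lengths(torch.tensor([[5, 3]]), max_patch_length=2) splits
+# 5 -> [2, 2, 1] and 3 -> [2, 1], giving tensor([[2, 2, 1, 2, 1]]); with several sequences
+# in a batch, shorter rows are right-padded with zeros.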
+
+
+class BLTRotaryEmbedding(nn.Module):
+ def __init__(self, config, device=None):
+ super().__init__()
+ self.rope_type = config.rope_scaling["rope_type"]
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class BLTLocalEncoder(nn.Module):
+ def __init__(self, config: BLTLocalEncoderConfig):
+ super().__init__()
+
+ self.config = config
+
+ self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+
+ self.rotary_emb = BLTRotaryEmbedding(config=config)
+
+ self.patch_embedding_projection = nn.Linear(
+ in_features=config.hidden_size,
+ out_features=config.hidden_size * config.cross_attn_k,
+ bias=False,
+ )
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+
+ self.cross_attn_layers = torch.nn.ModuleList()
+ layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1
+ for layer_idx in range(layers_to_add):
+ self.cross_attn_layers.append(
+ BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ input_embeds: Optional[torch.Tensor] = None,
+ patch_embeds: Optional[torch.Tensor] = None,
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
+ cross_mask: Optional[torch.Tensor] = None,
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ num_patches: Optional[int] = None,
+ patch_ids: Optional[torch.Tensor] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ """ """
+ if input_embeds is None:
+ input_embeds = self.embed_tokens(input_ids)
+
+ batch_size, _, _ = input_embeds.shape
+
+ hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training)
+
+ position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
+
+ for idx, layer in enumerate(self.layers):
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None)
+ hidden_states = layer_outputs[0]
+
+ if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers:
+ patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids)
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
+ patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size)
+
+ layer_idx = idx if self.config.cross_attn_all_layers else 0
+ cross_attention_output, _, _ = self.cross_attn_layers[layer_idx](
+ hidden_states=patch_embeds,
+ cross_attention_states=hidden_states,
+ attention_mask=cross_mask,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+ output_attentions=False,
+ use_cache=False,
+ cache_position=None,
+ )
+ patch_embeds = patch_embeds + cross_attention_output
+
+ encoder_cross_states = patch_embeds
+ return hidden_states, encoder_cross_states
+
+ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids):
+ """
+ Reduce variable length patches to single embedding per patch
+ Note: this works with variable number of patches for different sequences in the batch
+ It handles variable length patches by assuming that patch_lengths will be 0 for any
+ extra patches on the *right*. Since there can be a variable number of patches
+ this function also return the number of patches for each sequence in the batch.
+ Any embeddings on the right that are not allocated to a patch
+ (i.e. if the sum(patch_lengths[i]) < seq_len for any i)
+ will be sent to a dummy patch, which is trimmed before returning.
+ """
+ batch_size, _, embedding_dim = hidden_states.shape
+
+ patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
+
+ reduced_embeddings = torch.zeros((batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device)
+ reduced_embeddings = reduced_embeddings.scatter_reduce(
+ src=hidden_states,
+ dim=1,
+ index=patch_ids,
+ reduce=reduction,
+ include_self=False,
+ )
+ reduced_embeddings = reduced_embeddings[:, :max_num_patches, :]
+
+ return reduced_embeddings
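+
+ # Example: with hidden_states of shape [1, 4, D], patch_ids = [[0, 0, 1, 1]],
+ # max_num_patches=2 and reduction="amax", the result has shape [1, 2, D]: patch 0 is the
+ # element-wise max over tokens 0-1 and patch 1 over tokens 2-3.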
+
+
+class BLTLocalDecoder(nn.Module):
+ def __init__(self, config: BLTLocalDecoderConfig):
+ super().__init__()
+
+ # Extract config values to instance attributes
+ self.config = config
+ self.cross_attn_decoder = True  # TODO: hard-coded for now instead of config.cross_attn_decoder; maybe remove
+
+ self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+
+ self.rotary_emb = BLTRotaryEmbedding(config=config)
+
+ self.patch_embedding_projection = nn.Linear(
+ in_features=config.hidden_size_global,
+ out_features=config.hidden_size * config.cross_attn_k,
+ bias=False,
+ )
+
+ self.norm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps)
+
+ self.cross_attn_layers = torch.nn.ModuleList()
+ layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1
+ for layer_idx in range(layers_to_add):
+ self.cross_attn_layers.append(
+ BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
+ )
+
+ # self.lm_head = nn.Linear(
+ # config.hidden_size,
+ # config.vocab_size,
+ # bias=False,
+ # )
+
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ embeds: Optional[torch.Tensor],
+ patch_embeds: Optional[torch.Tensor] = None,
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
+ cross_mask: Optional[torch.Tensor] = None,
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ batch_size, _, _ = embeds.shape
+
+ hidden_states = embeds
+
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
+ patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size)
+
+ if patch_embeds is not None and not self.cross_attn_decoder:
+ hidden_states = hidden_states + patch_embeds
+
+ position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1)
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
+ for i, layer in enumerate(self.layers):
+ if i == 0 or self.config.cross_attn_all_layers:
+ # Use cross attention to extract info from patch_embeds into hidden_states
+ cross_attention_output, _, _ = self.cross_attn_layers[i](
+ hidden_states=hidden_states,
+ cross_attention_states=patch_embeds,
+ attention_mask=cross_mask,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+ output_attentions=False,
+ use_cache=False,
+ cache_position=None,
+ )
+ hidden_states = hidden_states + cross_attention_output
+
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None)
+ hidden_states = layer_outputs[0]
+
+ logits = self.norm(hidden_states)
+ # logits = self.lm_head(logits)
+ return logits, cache
+
+
+class BLTCrossAttention(nn.Module):
+ """Cross-attention module for BLT, following transformers style"""
+
+ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ # Use provided hidden_size or fallback to encoder dimension
+ self.hidden_size = hidden_size or config.encoder_config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.scaling = None  # None lets SDPA fall back to its default 1 / sqrt(head_dim)
+ self.dropout = config.dropout
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ self.q_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps)
+ self.k_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_norm(hidden_states) # BLT normalizes first
+ query_states = self.q_proj(query_states)
+
+ if cross_attention_states is not None:
+ cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first
+ key_states = self.k_proj(cross_attention_states)
+ value_states = self.v_proj(cross_attention_states)
+ if past_key_value is not None:
+ # If new cross-attention states arrive together with new tokens, key/value states are
+ # computed only on those new states; update the cache with them and use the result.
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ elif cache_position is not None and cache_position[0] != 0:
+ key_states, value_states = (
+ past_key_value.key_cache[self.layer_idx],
+ past_key_value.value_cache[self.layer_idx],
+ )
+ else:
+ if cross_attention_states is None:
+ raise ValueError(
+ "Cross attention layer can't find neither `cross_attention_states` nor cached values for key/values!"
+ )
+
+ attention_interface: Callable = eager_attention_forward
+
+ self.config._attn_implementation = "sdpa"
+ if self.config._attn_implementation != "eager":
+ if self.config._attn_implementation == "sdpa" and output_attentions:
+ logger.warning_once(
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ else:
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if full_text_row_masked_out_mask is not None:
+ attn_output = full_text_row_masked_out_mask[:, 0] * attn_output
+
+ attn_output = attn_output + hidden_states
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class BLTGlobalTransformer(nn.Module):
+ def __init__(self, config: BLTGlobalTransformerConfig):
+ super().__init__()
+
+ self.config = config
+
+ self.layers = nn.ModuleList()
+ for layer_idx in range(config.num_hidden_layers):
+ self.layers.append(BLTTransformerLayer(config, layer_idx))
+
+ self.rotary_emb = BLTRotaryEmbedding(config=config)
+
+
+ def forward(
+ self,
+ input_embeds: torch.Tensor,
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ batch_size, seq_len, _ = input_embeds.shape
+
+ hidden_states = input_embeds
+
+ hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
+
+ position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for i, layer in enumerate(self.layers):
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None)
+ hidden_states = layer_outputs[0]
+
+ return hidden_states, cache
+
+
+
+
+class BLTPreTrainedModel(PreTrainedModel):
+ config_class = BLTConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = False # BLT uses its own attention implementation
+ _supports_sdpa = True
+ _supports_cache_class = False
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Linear):
+ std = getattr(module, '_custom_std', module.in_features ** (-0.5))
+ nn.init.trunc_normal_(
+ module.weight,
+ mean=0.0,
+ std=std,
+ a=-3 * std,
+ b=3 * std,
+ )
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ elif isinstance(module, nn.Embedding):
+ std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5))
+ nn.init.trunc_normal_(
+ module.weight,
+ mean=0.0,
+ std=std,
+ a=-3 * std,
+ b=3 * std,
+ )
+
+ elif isinstance(module, BLTModel):
+ if module.encoder_hash_tok_embedding is not None:
+ emb_std = module.config.encoder_config.hidden_size ** (-0.5)
+ for emb in module.encoder_hash_tok_embedding:
+ emb._custom_std = emb_std
+
+ elif isinstance(module, BLTLocalEncoder):
+ if module.patch_embedding_projection is not None:
+ module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5)
+
+ elif isinstance(module, BLTLocalDecoder):
+ if module.patch_embedding_projection is not None:
+ module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5)
+
+ elif isinstance(module, BLTPatcher):
+ emb_std = module.config.hidden_size ** (-0.5)
+ module.embed_tokens._custom_std = emb_std
+ module.lm_head._custom_std = emb_std
+
+ elif isinstance(module, BLTForCausalLM):
+ if module.lm_head is not None:
+ module.lm_head._custom_std = module.config.decoder_config.hidden_size ** (-0.5)
+
+
+class BLTModel(BLTPreTrainedModel):
+ def __init__(self, config: BLTConfig):
+ super().__init__(config)
+ self.config = config
+ self.local_encoder = BLTLocalEncoder(config.encoder_config)
+ self.global_transformer = BLTGlobalTransformer(config.global_config)
+ self.local_decoder = BLTLocalDecoder(config.decoder_config)
+ self.encoder_hash_tok_embedding = init_hash_embeddings(
+ config,
+ local_encoder_dim=config.encoder_config.hidden_size,
+ encoder_hash_byte_group_size=config.encoder_hash_byte_group_size,
+ )
+ if self.config.patch_in_forward:
+ self.patcher = BLTPatcher(config.patcher_config)
+ self.patcher.eval()
+ for param in self.patcher.parameters():
+ param.requires_grad = False
+ else:
+ self.patcher = None
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ patch_lengths: Optional[torch.Tensor] = None,
+ attention_mask=None,
+ position_ids=None,
+ past_key_values=None,
+ inputs_embeds=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ cache_position=None,
+ **kwargs,
+ ):
+ """
+ Args:
+ tokens (torch.Tensor): Input token ids.
+ patch_lengths (Optional[torch.Tensor]): Patch lengths for patching.
+ attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Ignored, for compatibility.
+ Returns:
+ torch.Tensor: Final hidden states (as before).
+ """
+ batch_size, sequence_length = tokens.shape
+ # Handle patching
+ if patch_lengths is None:
+ if self.config.patching_mode == PatchingModeEnum.entropy:
+ _, patch_lengths, _ = self.patcher(
+ tokens,
+ patch_size=self.config.patch_size,
+ threshold=self.config.patching_threshold,
+ max_patch_length=self.config.max_patch_length,
+ patching_batch_size=self.config.patching_batch_size,
+ device=tokens.device,
+ )
+ else:
+ patch_lengths = process_patch_lengths(
+ torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device),
+ self.config.max_patch_length
+ )
+ patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length)
+ cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask(
+ patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32
+ )
+ encoder_embeds = compute_hash_embeddings(
+ tokens, self.local_encoder, self.encoder_hash_tok_embedding,
+ self.config.encoder_hash_byte_group_nb_functions,
+ self.config.encoder_hash_byte_group_size,
+ self.config.encoder_hash_byte_group_vocab,
+ )
+ encoder_hidden_states, encoder_cross_states = self.local_encoder(
+ input_ids=tokens,
+ input_embeds=encoder_embeds,
+ patch_embeds=None,
+ cross_mask=cross_attn_mask_enc,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc,
+ num_patches=patch_lengths.shape[1],
+ patch_ids=patch_ids,
+ )
+ global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
+ global_hidden_states, _ = self.global_transformer(
+ input_embeds=global_hidden_states,
+ )
+ decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length)
+ cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask(
+ decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32
+ )
+ output, _ = self.local_decoder(
+ tokens=tokens,
+ embeds=encoder_hidden_states,
+ patch_embeds=global_hidden_states,
+ mask=None,
+ cross_mask=cross_attn_mask_dec,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec,
+ )
+ if output_hidden_states or output_attentions:
+ if return_dict:
+ return {"last_hidden_state": output, "hidden_states": None, "attentions": None}
+ else:
+ return (output, None, None)
+ return output
+
+ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
+ """Convert patch lengths to patch IDs for each token position."""
+ batch_size = patch_lengths.shape[0]
+ patch_starts = torch.cat([
+ torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
+ patch_lengths.cumsum(dim=-1)[:, :-1]
+ ], dim=-1)
+
+ token_positions = torch.arange(seq_len, device=patch_lengths.device)
+ return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1
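+
+ # Example: patch_lengths = [[2, 3]] with seq_len = 5 gives patch_starts = [[0, 2]], so
+ # token positions 0-1 map to patch 0 and positions 2-4 map to patch 1, i.e.
+ # patch_ids = [[0, 0, 1, 1, 1]].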
+
+
+class BLTPatcher(BLTPreTrainedModel):
+ def __init__(self, config: BLTPatcherConfig):
+ super().__init__(config)
+
+ self.rotary_emb = BLTRotaryEmbedding(config=self.config)
+
+ self.layers = nn.ModuleList()
+
+ for layer_idx in range(self.config.num_hidden_layers):
+ self.layers.append(BLTTransformerLayer(self.config, layer_idx))
+
+
+ self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
+
+ self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps)
+
+ self.lm_head = nn.Linear(
+ self.config.hidden_size,
+ self.config.vocab_size,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ token_values: torch.Tensor,
+ patch_size: Optional[int] = None,
+ threshold: Optional[float] = None,
+ max_patch_length: Optional[int] = None,
+ patching_batch_size: int = 1,
+ device: Optional[str] = None,
+ ):
+
+ # Handle chunked processing for entropy calculation
+ entropies = []
+ predictions = []
+ max_length = self.config.max_position_embeddings
+ batch_numel = max_length * patching_batch_size
+ splits = torch.split(token_values.flatten(), batch_numel)
+
+ for split in splits:
+ pad_size = (max_length - (split.numel() % max_length)) % max_length
+ pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False)
+ split = torch.cat((split, pad), dim=0)
+ split = split.reshape(-1, max_length)
+ if device is not None:
+ split = split.to(device)
+
+ # Process chunk: embeddings -> layers -> output
+ batch_size, sequence_length = split.shape
+ input_embeds = self.embed_tokens(split)
+
+ hidden_states = input_embeds
+
+ batch_size, _, _ = input_embeds.shape
+
+ position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
+
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for i, layer in enumerate(self.layers):
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None)
+ hidden_states = layer_outputs[0]
+
+ logits = self.lm_head(self.norm(hidden_states))
+ logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :]
+ predictions.append(logits)
+ prediction_entropies = torch.distributions.Categorical(logits=logits).entropy()
+ entropies.append(prediction_entropies)
+
+ concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape)
+ concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1)
+
+ # Always compute patch lengths from concatenated entropies
+ batch_size, sequence_length = token_values.shape
+
+ # Find patch start IDs based on entropy
+ if patch_size is not None:
+ patch_lengths = self.patch_lengths_from_entropies(
+ entropies=concat_entropies,
+ sequence_length=sequence_length,
+ patch_size=patch_size,
+ threshold=threshold,
+ )
+ else:
+ # Default to byte-level patching
+ patch_lengths = torch.ones((batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device)
+ patch_lengths = process_patch_lengths(patch_lengths, max_patch_length)
+ return concat_entropies, patch_lengths, concat_predictions
+
+ @staticmethod
+ def patch_lengths_from_entropies(
+ entropies,
+ sequence_length,
+ patch_size=None,
+ threshold=None,
+ ):
+ """
+ Computes patch lengths from token entropies.
+
+ Depending on whether a threshold is provided, the function uses either:
+ - Top-k selection based on entropy (when `threshold` is None), or
+ - Thresholding the entropy values (when `threshold` is set).
+ """
+
+ batch_size = entropies.shape[0]
+
+ # Always include token 0 and 1 as starting tokens
+ init_tokens = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1)
+ offset = init_tokens.shape[1]
+
+ # Ignore first token entropy (BOS)
+ entropies = entropies[:, 1:]
+
+ if threshold is None:
+ # Use top-k entropy values to define patch start points
+ num_patches = sequence_length // patch_size
+ topk_indices = entropies.topk(num_patches - 2, dim=1).indices
+ patch_starts = topk_indices.sort(dim=1).values
+ else:
+ # Threshold the entropy values to define patch start points
+ patch_mask = entropies > threshold
+
+ seq_len = patch_mask.shape[1]
+
+ # Create patch IDs (token indices), and add a sentinel to ensure alignment
+ token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1)
+ sentinel = torch.full_like(token_indices, seq_len)
+ padded_indices = torch.cat([token_indices, sentinel], dim=1)
+
+ # Pad mask with inverse to align sentinel correctly
+ padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1)
+
+ # Select indices where mask is True
+ patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len)
+ max_valid_patches = patch_mask.sum(dim=1).max()
+ patch_starts = patch_starts[:, :max_valid_patches]
+
+ # Offset patch starts to account for the two initial tokens
+ patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1)
+
+ # Compute patch end positions by shifting start positions
+ last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1)
+ patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1)
+
+ patch_lengths = patch_ends - patch_start_ids + 1
+
+ return patch_lengths
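+
+ # Example (threshold mode): for a single sequence of length 6, the first entropy is
+ # dropped (BOS) and patches always start at tokens 0 and 1. If only the entropy at
+ # shifted index 2 exceeds the threshold, patch_start_ids = [0, 1, 4],
+ # patch_ends = [0, 3, 5] and patch_lengths = [1, 3, 2], which sums to the sequence length.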
+
+
+class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin):
+ config_class = BLTConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = BLTModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.local_encoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.local_encoder.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ past_key_values=None,
+ inputs_embeds=None,
+ labels=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ cache_position=None,
+ **kwargs,
+ ):
+ """
+ Args:
+ input_ids (torch.LongTensor): Input token ids.
+ attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Standard transformers arguments.
+ labels (torch.LongTensor, optional): Labels for language modeling loss.
+ Returns:
+ CausalLMOutputWithPast or tuple: Standard transformers output.
+ """
+ # Route only input_ids to BLTModel (as tokens)
+ hidden_states = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ if isinstance(hidden_states, dict):
+ sequence_output = hidden_states["last_hidden_state"]
+ elif isinstance(hidden_states, tuple):
+ sequence_output = hidden_states[0]
+ else:
+ sequence_output = hidden_states
+ logits = self.lm_head(sequence_output)
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ loss_fct = torch.nn.CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+ if not return_dict:
+ output = (logits,)
+ if loss is not None:
+ output = (loss,) + output
+ return output
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=None,
+ hidden_states=None,
+ attentions=None,
+ )
+
+__all__ = [
+ "BLTPreTrainedModel",
+ "BLTModel",
+ "BLTPatcher",
+ "BLTLocalEncoder",
+ "BLTLocalDecoder",
+ "BLTGlobalTransformer",
+ "BLTTransformerLayer",
+ "BLTForCausalLM",
+]
\ No newline at end of file
diff --git a/backup_blt_wip copy/modeling_blt_old.py b/backup_blt_wip copy/modeling_blt_old.py
new file mode 100644
index 0000000000000000000000000000000000000000..7005a4f0fbba1f6bf389c175d608281c068903cd
--- /dev/null
+++ b/backup_blt_wip copy/modeling_blt_old.py
@@ -0,0 +1,1602 @@
+# BLT: old implementation
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import logging
+import os
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
+
+from ...modeling_utils import PreTrainedModel
+from .configuration_blt_og import (
+ BLTConfig,
+ PatchingModeEnum,
+)
+
+RMSNorm = nn.RMSNorm
+
+logger = logging.getLogger()
+
+flex_attention_comp = flex_attention
+
+
+def causal_mask(b, h, q_idx, kv_idx):
+ return q_idx >= kv_idx
+
+
+def create_causal_mask(
+ seqlen,
+ attn_impl: str,
+ attn_bias_type: str | None,
+ *,
+ eos_id: int | None = None,
+ tokens: torch.Tensor | None = None,
+ sliding_window: int | None = None,
+):
+ if attn_impl == "sdpa":
+ BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))
+
+ if attn_bias_type == "causal":
+ return "causal"
+
+ if BLT_SUPPRESS_ATTN_ERROR == 1:
+ return "causal"
+ else:
+ raise ValueError(
+ "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
+ )
+ elif attn_impl == "flex_attention":
+ return create_block_mask(causal_mask, None, None, seqlen, seqlen)
+ else:
+ raise NotImplementedError(f"Attention {attn_impl} with {sliding_window} sliding window not implemented")
+
+
+def cross_entropy(pred, target, **kwargs):
+ return F.nll_loss(
+ F.log_softmax(pred.flatten(end_dim=-2).float(), -1),
+ target.flatten(end_dim=-1),
+ **kwargs,
+ )
+
+
+def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+ assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
+ bs, slen, n_kv_heads, head_dim = x.shape
+ if n_rep == 1:
+ return x
+ return (
+ x[:, :, :, None, :]
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+ )
+
+
+def precompute_freqs_cis(
+ dim: int,
+ end: int,
+ theta: float = 10000.0,
+ rope_use_fp32_in_outer_product: bool = False,
+):
+ """
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
+ The returned tensor contains complex values in complex64 data type.
+
+ Args:
+ dim (int): Dimension of the frequency tensor.
+ end (int): End index for precomputing frequencies.
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+
+ Returns:
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
+ """
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ t = torch.arange(end, device=freqs.device)
+ if rope_use_fp32_in_outer_product:
+ t = t.to(torch.float32)
+
+ freqs = torch.outer(t, freqs).float()
+
+ cos, sin = freqs.cos(), freqs.sin()
+
+ return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
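+
+# Illustrative note (added comment, not part of the original code): for dim=4 and end=3 the
+# returned tensor has shape (3, 2, 2, 2), i.e. one 2x2 rotation matrix [[cos, -sin], [sin, cos]]
+# per position and per frequency pair.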
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
+ """
+ Reshape frequency tensor for broadcasting it with another tensor.
+
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+ Args:
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
+ seq_dim (int): Sequence dimension index.
+
+ Returns:
+ torch.Tensor: Reshaped frequency tensor.
+ """
+ ndim = x.ndim
+ assert 0 <= seq_dim < ndim
+ assert freqs_cis.shape == (
+ x.shape[seq_dim],
+ x.shape[-3],
+ 2,
+ 2,
+ ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
+ shape = [d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])] + [2, 2]
+ return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ seq_dim: int,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
+ xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, seq_dim).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
+ xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
+ xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+# Rotary embedding as in xformers; check whether the torchtrain implementation is better. It might also be useful to make it work with batch*seqlen collapsed.
+class RotaryEmbedding(torch.nn.Module):
+ """
+ RotaryEmbedding Module
+ """
+
+ def __init__(
+ self,
+ theta: float,
+ head_dim: int,
+ max_seqlen: int = 1024,
+ rope_use_fp32_in_outer_product: bool = False,
+ ):
+ super().__init__()
+
+ self.theta = theta
+ self.head_dim = head_dim
+ self.max_seqlen = max_seqlen
+ self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
+
+ self.register_buffer(
+ "freqs_cis",
+ precompute_freqs_cis(
+ dim=head_dim,
+ end=max_seqlen,
+ theta=theta,
+ rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
+ ),
+ persistent=False,
+ )
+
+
+ def forward(self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None):
+ """
+ Return the freqs_cis corresponding to consecutive seqlen positions, or to the positions given in tok_idx.
+ Args:
+ seqlen (int): Contiguous sequence length
+ tok_idx (torch.Tensor[int]): Position indices of each token; overrides seqlen when provided
+
+ Returns:
+ torch.Tensor: freqs_cis for the requested positions
+ """
+ test = (seqlen is not None) or (tok_idx is not None)
+ assert test, "Should provide at least seqlen or tok_idx"
+ if tok_idx is not None:
+ return self.freqs_cis[tok_idx]
+ elif seqlen is not None:
+ return self.freqs_cis[0:seqlen]
+
+
+class BLTSelfAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ head_dim: int,
+ n_heads: int,
+ n_kv_heads: int,
+ rope_theta: float,
+ ):
+ super().__init__()
+
+ self.dim = dim
+ self.head_dim = head_dim
+ self.rope_theta = rope_theta
+
+ self.n_heads = n_heads
+ self.n_kv_heads = n_kv_heads
+ self.heads_per_group = self.n_heads // self.n_kv_heads
+
+ self.wq = nn.Linear(
+ dim,
+ n_heads * head_dim,
+ bias=False,
+ )
+ self.wk = nn.Linear(
+ dim,
+ n_kv_heads * head_dim,
+ bias=False,
+ )
+ self.wv = nn.Linear(
+ dim,
+ n_kv_heads * head_dim,
+ bias=False,
+ )
+
+ self.wo = nn.Linear(
+ n_heads * head_dim,
+ dim,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ freq_cis: torch.Tensor,
+ tok_idx: Optional[torch.Tensor] = None,
+ mask: Optional[Union[BlockMask, str]] = None,
+ attn_impl: str = "sdpa",
+ ) -> torch.Tensor:
+ # B S D
+ bsz, seq_len, dim = x.shape
+
+ xq = self.wq(x.view_as(x))
+ xk = self.wk(x.view_as(x))
+ xv = self.wv(x.view_as(x))
+
+ output_shape = xq.shape
+ # B S D -> B S H D
+ xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
+ xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+ xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+
+ xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
+
+ # This condition helps us be easily compatible
+ # with inference by adding a pluggable KVCache
+ if hasattr(self, "kv_cache"):
+ xk, xv = self.kv_cache.update(xk, xv, tok_idx)
+
+ xk = repeat_kv(xk, self.heads_per_group, dim=2)
+ xv = repeat_kv(xv, self.heads_per_group, dim=2)
+
+ if attn_impl == "flex_attention":
+ assert mask is None or isinstance(mask, BlockMask)
+ xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
+ output = flex_attention_comp(xq, xk, xv, block_mask=mask)
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
+
+ elif attn_impl == "sdpa":
+ xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
+ assert mask is None or isinstance(mask, (str, torch.Tensor))
+ is_causal = (mask == "causal") if isinstance(mask, str) else False
+ mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None
+ output = F.scaled_dot_product_attention(
+ xq,
+ xk,
+ xv,
+ is_causal=is_causal,
+ attn_mask=mask,
+ )
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
+ else:
+ raise NotImplementedError(f"Attention implementation {attn_impl} not supported")
+
+ output_reshaped = output.reshape(output_shape)
+
+ output = self.wo(output_reshaped)
+
+ return output
+
+
+class BLTMLP(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ multiple_of: int,
+ ffn_dim_multiplier: Optional[float],
+ mp_size: int = 1,
+ ):
+ super().__init__()
+
+ hidden_dim = int(2 * hidden_dim / 3)
+ if ffn_dim_multiplier is not None:
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+ assert hidden_dim % mp_size == 0
+
+ self.dim = dim
+ self.hidden_dim = hidden_dim
+
+ self.w1 = nn.Linear(
+ dim,
+ hidden_dim,
+ bias=False,
+ )
+ self.w3 = nn.Linear(
+ dim,
+ hidden_dim,
+ bias=False,
+ )
+ self.w2 = nn.Linear(
+ hidden_dim,
+ dim,
+ bias=False,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # B S D
+ x1 = self.w1(x.view_as(x))
+ x3 = self.w3(x.view_as(x))
+ output = self.w2(F.silu(x1) * x3)
+ return output
+
+
+
+
+class BLTTransformerLayer(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+
+ # Extract parameters from dictionary
+ dim = args["dim"]
+ n_heads = args["n_heads"]
+ head_dim = args["head_dim"]
+ n_kv_heads = args["n_kv_heads"]
+ rope_theta = args["rope_theta"]
+ multiple_of = args["multiple_of"]
+ ffn_dim_multiplier = args["ffn_dim_multiplier"]
+ norm_eps = args["norm_eps"]
+
+ assert (head_dim is not None) or (n_heads is not None), "Should specify at least head_dim or n_heads"
+ self.head_dim = head_dim or dim // n_heads
+ self.n_heads = n_heads or dim // head_dim
+ self.n_kv_heads = n_kv_heads or self.n_heads
+
+ assert n_heads % self.n_kv_heads == 0
+ assert dim % n_heads == 0
+
+ self.attention = BLTSelfAttention(
+ dim=dim,
+ head_dim=self.head_dim,
+ n_heads=self.n_heads,
+ n_kv_heads=self.n_kv_heads,
+ rope_theta=rope_theta,
+ )
+ self.feed_forward = BLTMLP(
+ dim=dim,
+ hidden_dim=4 * dim,
+ multiple_of=multiple_of,
+ ffn_dim_multiplier=ffn_dim_multiplier,
+ )
+ self.attention_norm = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm = RMSNorm(dim, eps=norm_eps)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ freq_cis: torch.Tensor,
+ tok_idx: Optional[torch.Tensor] = None,
+ mask: Optional[Union[BlockMask, str]] = None,
+ attn_impl: str = "sdpa",
+ ) -> torch.Tensor:
+ norm_x = self.attention_norm(x)
+ attn_out = self.attention(
+ norm_x,
+ freq_cis,
+ tok_idx=tok_idx,
+ mask=mask,
+ attn_impl=attn_impl,
+ )
+ h = x + attn_out
+ h_norm = self.ffn_norm(h)
+ out = h + self.feed_forward(h_norm)
+ return out
+
+def check_non_zero_after_zero(tensor):
+ zero_mask = tensor == 0
+ shifted_mask = torch.cat(
+ [
+ torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
+ zero_mask[:, :-1],
+ ],
+ dim=1,
+ )
+ non_zero_after_zero = (tensor != 0) & shifted_mask
+ return non_zero_after_zero.any()
+
+def rolling_polynomial_hash(t, hash_func_nb: int = 0):
+ primes = [
+ 1000000007,
+ 5915587277,
+ 1500450271,
+ 3267000013,
+ 5754853343,
+ 4093082899,
+ 9576890767,
+ 3628273133,
+ 2860486313,
+ 5463458053,
+ 3367900313,
+ ]
+ prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
+ prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
+ return torch.sum(t * prime_powers, dim=-1)
+
+
+def byte_group_hash_function(x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000):
+ """
+ Returns a hash of the input x and maps it to a value in the range [0, max_hash].
+
+ expects: x of shape (batch_size, seq_len) with values as ids in the token vocab.
+ returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash].
+
+ Note: max hash can make a big difference on the number of collisions.
+ """
+ with torch.no_grad():
+ bs, seq_len = x.shape
+ prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
+ x = torch.cat([prefix, x], dim=1)
+ windows = x.unfold(1, group_size, 1)
+ # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
+ hashes = rolling_polynomial_hash(windows, hash_func_nb)
+ hash_values_range = hashes % max_hash
+ hash_values_range.requires_grad = False
+ return hash_values_range
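+
+# Illustrative example (added comment, not part of the original code): with group_size=2 and
+# hash_func_nb=0, each position is hashed from the window [previous_byte, current_byte] as
+# (previous_byte + current_byte * 1000000007) % max_hash, with a zero prefix so the first
+# position only sees [0, first_byte].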
+
+
+def create_patch_mask_from_ids(patch_ids, num_patches, window=None, patches_as_queries=False):
+ """
+ Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k)
+ is True if the patch id at position (i, j) is less than or equal to k.
+ Args:
+ patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids.
+ num_patches (int): Total number of patches.
+ window (int): If not None, only considers patches within a window of size window.
+ patches_as_queries (bool): If True, the patches are used as queries
+ Returns:
+ torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask.
+ """
+ bs, seq_len = patch_ids.shape
+ if not patches_as_queries:
+ q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
+ kv_ids = (
+ torch.arange(num_patches, device=patch_ids.device)
+ .unsqueeze(0)
+ .unsqueeze(0)
+ .expand(bs, seq_len, num_patches)
+ )
+ else:
+ kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
+ q_ids = (
+ torch.arange(num_patches, device=patch_ids.device)
+ .unsqueeze(0)
+ .unsqueeze(-1)
+ .expand(bs, num_patches, seq_len)
+ )
+ if window is None:
+ mask = q_ids == kv_ids
+ else:
+ mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window)
+ return mask
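+
+# Illustrative example (added comment, not part of the original code): patch_ids=[[0, 0, 1]],
+# num_patches=2, patches_as_queries=False and window=None yield a [1, 3, 2] mask
+# [[[True, False], [True, False], [False, True]]], i.e. each byte position may only attend
+# to its own patch.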
+
+
+def cross_attn_mask(
+ patch_ids,
+ patch_lengths,
+ N,
+ patches_as_queries=False,
+ cross_attn_k=1,
+ window=None,
+ block_mask=True,
+):
+ bs = patch_ids.shape[0]
+ with torch.no_grad():
+ # Create the patch mask
+ cross_mask = create_patch_mask_from_ids(
+ patch_ids,
+ patch_lengths.shape[1],
+ window=window,
+ patches_as_queries=patches_as_queries,
+ ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
+ q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N
+ kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k
+ assert cross_mask.shape == (
+ bs,
+ q_len,
+ kv_len,
+ ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
+ # block_mask = None
+ if block_mask:
+
+ def patch_mask(b, h, q_idx, kv_idx):
+ return cross_mask[b, q_idx, kv_idx]
+
+ block_mask = create_block_mask(
+ patch_mask,
+ B=bs,
+ H=None,
+ Q_LEN=q_len,
+ KV_LEN=kv_len,
+ _compile=True,
+ )
+ return block_mask
+ else:
+ return torch.where(cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))).unsqueeze(
+ 1
+ ) # [bs, 1, q_len, kv_len]
+
+
+def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor:
+ if max_patch_length is None:
+ return patch_lengths
+
+ batch_size = patch_lengths.size(0)
+ split_all = []
+ max_len = 0
+
+ for seq in patch_lengths:
+ splits = []
+ for length in seq[seq > 0]:
+ # Split long patches into max_patch_length chunks
+ full, rem = divmod(length.item(), max_patch_length)
+ splits.extend([max_patch_length] * full + ([rem] if rem else []))
+ split_all.append(splits)
+ max_len = max(max_len, len(splits))
+
+ # Pad sequences to the maximum length
+ padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
+ for i, splits in enumerate(split_all):
+ if splits:
+ padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)
+
+ # Trim trailing columns that are all zeros
+ last_non_zero = (padded != 0).flip(1).int().argmax(1).min()
+ if last_non_zero < padded.shape[1]:
+ padded = padded[:, :padded.shape[1] - last_non_zero]
+
+ return padded
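+
+# Illustrative example (added comment, not part of the original code): with max_patch_length=2,
+# patch_lengths [[5, 3]] is split into [[2, 2, 1, 2, 1]]; zero-length entries are dropped and
+# trailing all-zero columns are trimmed.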
+
+class BLTLocalModelBase(nn.Module):
+ def __init__(self, config: BLTConfig, component_type: str = "encoder"):
+ super().__init__()
+
+ self.config = config
+
+ if component_type == "encoder":
+ self.dim = config.dim_local_encoder
+ self.n_layers = config.n_layers_local_encoder
+ self.n_heads = config.n_heads_local_encoder
+ self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen
+ self.attn_bias_type = "local_block_causal"
+ self.sliding_window = config.local_attention_window_len
+ elif component_type == "decoder":
+ self.dim = config.dim_local_decoder
+ self.n_layers = config.n_layers_local_decoder
+ self.n_heads = config.n_heads_local_decoder
+ self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen
+ self.attn_bias_type = "local_block_causal"
+ self.sliding_window = config.local_attention_window_len
+ else:
+ raise ValueError(f"Unknown component_type: {component_type}")
+
+ self.dropout = config.dropout
+ self.vocab_size = config.vocab_size + config.pm_size
+ self.patch_size = config.patch_size
+
+ self.attn_impl = config.attn_impl
+ self.use_rope = config.use_rope
+ self.init_std_factor = config.init_std_factor
+ self.init_base_std = config.init_base_std
+ self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None)
+ self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None)
+ self.cross_attn_k = getattr(config, "cross_attn_k", None)
+ self.eos_id = config.eos_token_id
+
+ self.boe_id = config.boe_id
+
+ # Initialize cross attention layers as None (will be set by subclasses if needed)
+ self.cross_attn_layers = None
+
+ # Create parameter dict for BLTTransformerLayers
+ layer_params = {
+ "dim": self.dim,
+ "n_heads": self.n_heads,
+ "head_dim": config.head_dim,
+ "n_kv_heads": getattr(config, "n_kv_heads", None),
+ "rope_theta": config.rope_theta,
+ "multiple_of": getattr(config, "multiple_of", 256),
+ "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None),
+ "norm_eps": config.norm_eps,
+ }
+
+ self.layers = nn.ModuleList([BLTTransformerLayer(layer_params) for _ in range(self.n_layers)])
+
+ if not self.use_rope:
+ self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length
+ else:
+ self.rope = RotaryEmbedding(
+ theta=config.rope_theta,
+ head_dim=config.head_dim or self.dim // self.n_heads,
+ max_seqlen=self.max_seqlen,
+ rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product,
+ )
+ self.pos_embeddings = None
+
+ # Set dimension-specific embedding dimensions
+ if component_type == "encoder":
+ self.dim_token_emb = config.encoder_dim_token_emb
+ self.dim_patch_emb = config.encoder_dim_patch_emb
+ elif component_type == "decoder":
+ self.dim_token_emb = config.decoder_dim_token_emb
+ self.dim_patch_emb = config.dim_global
+
+ self.token_embedding_projection = (
+ nn.Linear(self.dim_token_emb, self.dim, bias=False)
+ if self.dim_token_emb is not None and self.dim_token_emb != self.dim
+ else None
+ )
+
+ self.patch_embedding_projection = self._create_patch_projection(config)
+
+ def _should_create_patch_projection(self, config: BLTConfig):
+ dimension_mismatch = self.dim_patch_emb is not None and self.dim_patch_emb != self.dim
+
+ # Check cross attention conditions
+ cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or (
+ config.cross_attn_decoder and config.cross_attn_init_by_pooling
+ )
+
+ return dimension_mismatch or cross_attn_conditions
+
+ def _create_patch_projection(self, config):
+ if not self._should_create_patch_projection(config):
+ return None
+
+ output_dim = self.dim_token_emb * (self.cross_attn_k or 1)
+
+ return nn.Linear(
+ in_features=self.dim_patch_emb,
+ out_features=output_dim,
+ bias=False,
+ )
+
+ def apply_embedding(self, tokens, embeds):
+ if embeds is not None:
+ return embeds
+ else:
+ return self.tok_embeddings(tokens)
+
+
+class BLTLocalEncoder(BLTLocalModelBase):
+ def __init__(self, config: BLTConfig):
+ super().__init__(config, component_type="encoder")
+
+ self.apply_transformer = config.use_local_encoder_transformer
+ self.downsampling_by_pooling = config.downsampling_by_pooling
+ self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None
+ self.cross_attn_encoder = config.cross_attn_encoder
+ self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder
+ self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling
+ self.cross_attn_nheads = config.cross_attn_nheads
+
+ self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim)
+
+ if self.cross_attn_encoder:
+ self.cross_attn_layers = torch.nn.ModuleList()
+ layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1
+ for _ in range(layers_to_add):
+ self.cross_attn_layers.append(
+ BLTCrossAttention(
+ dim=self.dim,
+ head_dim=self.dim // self.cross_attn_nheads,
+ n_heads=self.cross_attn_nheads,
+ n_kv_heads=self.cross_attn_nheads,
+ norm_eps=config.norm_eps,
+ )
+ )
+
+ def apply_embedding(self, tokens, embeds):
+ if embeds is not None:
+ assert self.expects_hash_embeddings, "Not expecting embeddings to be passed."
+ return embeds
+ else:
+ return self.tok_embeddings(tokens)
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ embeds: Optional[torch.Tensor] = None,
+ patch_embeds: Optional[torch.Tensor] = None,
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
+ cross_mask: Optional[torch.Tensor] = None,
+ num_patches: Optional[int] = None,
+ patch_ids: Optional[torch.Tensor] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ """ """
+ bs, seqlen = tokens.shape
+ if mask is None:
+ mask = create_causal_mask(
+ seqlen,
+ self.attn_impl,
+ "local_block_causal",
+ sliding_window=self.sliding_window,
+ tokens=tokens,
+ eos_id=self.eos_id,
+ )
+
+ h = self.apply_embedding(tokens, embeds)
+
+
+ freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
+
+
+ h = F.dropout(h, p=self.dropout, training=self.training)
+
+ for i, layer in enumerate(self.layers):
+ h = layer(h, freqs_cis, tok_idx=None, mask=mask, attn_impl=self.attn_impl)
+ # check if cross attention should be applied to either all layer or only the last layer
+ if self.cross_attn_encoder and (i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder):
+ # apply pooling and project
+ if self.cross_attn_init_by_pooling and patch_embeds is None:
+ patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids)
+ if self.patch_embedding_projection is not None:
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
+ patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim)
+
+ layer_idx = i if self.cross_attn_all_layers_encoder else 0
+ patch_embeds_cross = self.cross_attn_layers[layer_idx](
+ x=patch_embeds,
+ kv=h,
+ mask=cross_mask,
+ )
+ patch_embeds = patch_embeds + patch_embeds_cross
+
+ h_residual = patch_embeds if self.cross_attn_encoder else None
+ return (h, h_residual), cache
+
+ def patch_reduce(self, h, max_num_patches, reduction, patch_ids):
+ """
+ Reduce variable-length patches to a single embedding per patch.
+ Note: this works with a variable number of patches across the sequences in the batch;
+ it assumes patch_lengths is 0 for any extra patches on the *right*.
+ Any embeddings on the right that are not allocated to a patch
+ (i.e. if sum(patch_lengths[i]) < seq_len for some i)
+ are scattered into a dummy patch slot, which is trimmed before returning.
+ """
+ bs, seq_len, emb_dim = h.shape
+
+ patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
+
+ reduced_embs = torch.zeros((bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device)
+ reduced_embs = reduced_embs.scatter_reduce(
+ src=h,
+ dim=1,
+ index=patch_ids,
+ reduce=reduction,
+ include_self=False,
+ )
+ reduced_embs = reduced_embs[:, :max_num_patches, :]
+
+ return reduced_embs
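+
+ # Illustrative note (added comment, not part of the original code): with reduction="amax",
+ # each patch embedding is the element-wise max over the hidden states of the byte positions
+ # assigned to that patch; patch slots that receive no bytes stay at their initial zeros.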
+
+
+class BLTLocalDecoder(BLTLocalModelBase):
+ def __init__(self, config: BLTConfig):
+ super().__init__(config, component_type="decoder")
+
+ # Model configuration flags
+ self.cross_attn_decoder = config.cross_attn_decoder
+ self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder
+ self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling
+ self.cross_attn_nheads = config.cross_attn_nheads
+
+ self.norm = RMSNorm(self.dim, eps=config.norm_eps)
+
+ if self.cross_attn_decoder:
+ self.cross_attn_layers = torch.nn.ModuleList()
+ layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1
+ for _ in range(layers_to_add):
+ self.cross_attn_layers.append(
+ BLTCrossAttention(
+ dim=self.dim,
+ head_dim=self.dim // self.cross_attn_nheads,
+ n_heads=self.cross_attn_nheads,
+ n_kv_heads=self.cross_attn_nheads,
+ norm_eps=config.norm_eps,
+ )
+ )
+
+ self.output = nn.Linear(
+ self.dim,
+ config.vocab_size,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ embeds: Optional[torch.Tensor],
+ patch_embeds: Optional[torch.Tensor] = None,
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
+ cross_mask: Optional[torch.Tensor] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ bs, seqlen = tokens.shape
+ assert embeds is not None, "Embeddings must be provided"
+
+ if mask is None:
+ mask = create_causal_mask(
+ seqlen,
+ self.attn_impl,
+ "local_block_causal",
+ sliding_window=self.sliding_window,
+ tokens=tokens,
+ eos_id=self.eos_id,
+ )
+
+ h = embeds
+
+ if self.patch_embedding_projection is not None:
+ assert patch_embeds is not None, "Patch embeddings must be passed."
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
+ if self.cross_attn_k is not None:
+ patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim)
+
+ if patch_embeds is not None and not self.cross_attn_decoder:
+ h = h + patch_embeds
+
+ freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
+
+ h = F.dropout(h, p=self.dropout, training=self.training)
+ for i, layer in enumerate(self.layers):
+ if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder):
+ # Use cross attention to extract info from patch_embeds into h
+ h_cross = self.cross_attn_layers[i](
+ x=h,
+ kv=patch_embeds,
+ mask=cross_mask,
+ )
+ h = h + h_cross
+
+ h = layer(h, freqs_cis, tok_idx=None, mask=mask, attn_impl=self.attn_impl)
+
+ h_preds = self.norm(h)
+ h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
+ h_preds = self.output(h_preds)
+ h_preds = h_preds.float()
+ return h_preds, cache
+
+
+class BLTCrossAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ head_dim: int,
+ n_heads: int,
+ n_kv_heads: int,
+ norm_eps: float,
+ ):
+ super().__init__()
+
+ self.dim = dim
+ self.head_dim = head_dim
+
+ self.n_heads = n_heads
+ self.n_kv_heads = n_kv_heads
+ self.heads_per_group = self.n_heads // self.n_kv_heads
+
+ self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
+ self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
+
+ self.wq = nn.Linear(
+ dim,
+ n_heads * head_dim,
+ bias=False,
+ )
+ self.wk = nn.Linear(
+ dim,
+ n_kv_heads * head_dim,
+ bias=False,
+ )
+ self.wv = nn.Linear(
+ dim,
+ n_kv_heads * head_dim,
+ bias=False,
+ )
+
+ self.wo = nn.Linear(
+ n_heads * head_dim,
+ dim,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ kv: torch.Tensor,
+ mask: Optional[Union[BlockMask, str]] = None,
+ ) -> torch.Tensor:
+ # B S D
+ bsz, seq_len, _ = x.shape
+ _, slen_kv, _ = kv.shape
+ x_norm = self.cross_attn_norm_q(x)
+ kv = self.cross_attn_norm_kv(kv)
+
+ xq = self.wq(x_norm)
+ xk = self.wk(kv)
+ xv = self.wv(kv)
+
+ output_shape = xq.shape
+ # B S D -> B S H D
+ xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
+ xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
+ xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
+
+ xk = repeat_kv(xk, self.heads_per_group, dim=2)
+ xv = repeat_kv(xv, self.heads_per_group, dim=2)
+
+ # assert mask is None or isinstance(mask, BlockMask)
+ xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
+ # output = flex_attention_comp(xq, xk, xv, block_mask=mask)
+ is_causal = (mask == "causal") if isinstance(mask, str) else False
+ # Guard against a None mask before moving it to the query's dtype and device.
+ mask = mask.to(dtype=xq.dtype, device=xq.device) if isinstance(mask, torch.Tensor) else None
+ output = F.scaled_dot_product_attention(
+ xq,
+ xk,
+ xv,
+ is_causal=is_causal,
+ attn_mask=mask,
+ )
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
+
+ output = self.wo(output.reshape(output_shape))
+
+ return x + output
+
+
+class BLTGlobalTransformer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+
+ self.dim = config.dim_global
+ self.rope_embeddings = RotaryEmbedding(
+ theta=config.rope_theta,
+ head_dim=config.head_dim or self.config.dim_global // config.n_heads_global,
+ max_seqlen=config.max_seqlen,
+ rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product,
+ )
+ # Handle both eos_id and eos_token_id for compatibility
+ self.eos_id = getattr(config, "eos_id", getattr(config, "eos_token_id", 2))
+
+ # Create parameter dict for BLTTransformerLayers
+ layer_params = {
+ "dim": self.dim,
+ "n_heads": config.n_heads_global,
+ "head_dim": config.head_dim,
+ "n_kv_heads": getattr(config, "n_kv_heads_global", None),
+ "rope_theta": config.rope_theta,
+ "multiple_of": getattr(config, "multiple_of", 256),
+ "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None),
+ "norm_eps": config.norm_eps,
+ }
+
+ self.layers = nn.ModuleList()
+ for _ in range(config.n_layers_global):
+ self.layers.append(BLTTransformerLayer(layer_params))
+
+ self.token_embedding_projection = None
+ if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim:
+ self.token_embedding_projection = nn.Linear(
+ config.global_dim_patch_emb,
+ config.dim_global,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ tok_idx: Optional[torch.Tensor] = None,
+ embeds: Optional[torch.Tensor] = None,
+ mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ bs, seqlen = tokens.shape
+
+ h = embeds
+
+ mask = (
+ mask
+ if mask is not None
+ else create_causal_mask(
+ seqlen,
+ self.config.attn_impl,
+ self.config.attn_bias_type,
+ tokens=tokens,
+ eos_id=self.eos_id,
+ )
+ )
+
+ if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
+ h = self.token_embedding_projection(h)
+
+ h = F.dropout(h, p=self.config.dropout, training=self.training)
+ freq_cis = self.rope_embeddings(seqlen=self.config.max_seqlen, tok_idx=tok_idx)
+
+ for i, layer in enumerate(self.layers):
+ h = layer(h, freq_cis, tok_idx=None, mask=mask, attn_impl=self.config.attn_impl)
+
+ return h, cache
+
+
+def compute_hash_embeddings(
+ local_encoder_tokens: torch.Tensor,
+ local_encoder,
+ encoder_hash_tok_embedding: nn.ModuleList,
+ encoder_hash_byte_group_nb_functions: int,
+ encoder_hash_byte_group_size: list,
+ encoder_hash_byte_group_vocab: int,
+) -> torch.Tensor:
+ """
+ Compute embeddings using hash token embeddings.
+
+ Args:
+ local_encoder_tokens: Input tokens tensor
+ local_encoder: Encoder object with tok_embeddings method
+ encoder_hash_tok_embedding: ModuleList of hash token embeddings
+ encoder_hash_byte_group_nb_functions: Number of hash functions
+ encoder_hash_byte_group_size: List of byte group sizes
+ encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings
+
+ Returns:
+ torch.Tensor: Combined embeddings
+ """
+ if encoder_hash_tok_embedding is None:
+ return None
+
+ local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens)
+
+ i = 0
+ for func_nb in range(encoder_hash_byte_group_nb_functions):
+ for byte_group_size in encoder_hash_byte_group_size:
+ hash_ids = byte_group_hash_function(
+ local_encoder_tokens,
+ byte_group_size,
+ hash_func_nb=func_nb,
+ max_hash=encoder_hash_byte_group_vocab,
+ )
+ hash_tok_embedding = encoder_hash_tok_embedding[i]
+ local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
+ i += 1
+
+ assert i == len(encoder_hash_tok_embedding)
+ return local_encoder_embeds
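+
+# Illustrative note (added comment, not part of the original code): the result is
+# tok_embeddings(tokens) plus one hash_embedding(hash_ids) term per (hash function, group size)
+# pair, in the same order in which the tables are created by init_hash_embeddings.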
+
+
+class BLTPreTrainedModel(PreTrainedModel):
+ config_class = BLTConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = False # BLT uses its own attention implementation
+ _supports_sdpa = True
+ _supports_cache_class = False
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Linear):
+ std = getattr(module, '_custom_std', module.in_features ** (-0.5))
+
+ nn.init.trunc_normal_(
+ module.weight,
+ mean=0.0,
+ std=std,
+ a=-3 * std,
+ b=3 * std,
+ )
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ elif isinstance(module, nn.Embedding):
+ std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5))
+
+ nn.init.trunc_normal_(
+ module.weight,
+ mean=0.0,
+ std=std,
+ a=-3 * std,
+ b=3 * std,
+ )
+
+ elif isinstance(module, (nn.RMSNorm, nn.LayerNorm)):
+ nn.init.ones_(module.weight)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ elif isinstance(module, RotaryEmbedding):
+ module.freqs_cis[...] = precompute_freqs_cis(
+ dim=module.head_dim,
+ end=module.max_seqlen,
+ theta=module.theta,
+ rope_use_fp32_in_outer_product=module.rope_use_fp32_in_outer_product,
+ )
+
+ elif isinstance(module, BLTModel):
+ if module.encoder_hash_tok_embedding is not None:
+ emb_std = module.local_encoder.dim ** (-0.5)
+ for emb in module.encoder_hash_tok_embedding:
+ emb._custom_std = emb_std
+
+ elif isinstance(module, (BLTLocalEncoder, BLTLocalDecoder)):
+ if module.token_embedding_projection is not None:
+ module.token_embedding_projection._custom_std = module.dim ** (-0.5)
+
+ if module.patch_embedding_projection is not None:
+ module.patch_embedding_projection._custom_std = module.dim_patch_emb ** (-0.5)
+
+ elif isinstance(module, BLTGlobalTransformer):
+ if module.token_embedding_projection is not None:
+ module.token_embedding_projection._custom_std = module.dim_token_emb ** (-0.5)
+
+ elif isinstance(module, BLTPatcher):
+ emb_std = module.config.patcher_dim ** (-0.5)
+ module.tok_embeddings._custom_std = emb_std
+ module.output._custom_std = emb_std
+
+
+class BLTModel(BLTPreTrainedModel):
+ def __init__(self, config: BLTConfig):
+ super().__init__(config)
+
+ self.config = config
+ self.local_encoder = BLTLocalEncoder(config)
+ self.global_transformer = BLTGlobalTransformer(config)
+ self.local_decoder = BLTLocalDecoder(config)
+
+ self.encoder_hash_tok_embedding = init_hash_embeddings(
+ config,
+ local_encoder_dim=self.local_encoder.dim,
+ encoder_hash_byte_group_size=config.encoder_hash_byte_group_size,
+ )
+
+ if config.patch_in_forward:
+ self.patcher = BLTPatcher(config)
+ self.patcher.eval()
+ for param in self.patcher.parameters():
+ param.requires_grad = False
+ else:
+ self.patcher = None
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ patch_lengths: Optional[torch.Tensor] = None,
+ ):
+ # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings
+ # are no longer used in the final BLT model
+
+ bs, N = tokens.shape # Batch size and sequence length
+
+ local_encoder_tokens, local_decoder_tokens = tokens, tokens
+
+ # Patching
+ if patch_lengths is None:
+ # assert (
+ # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward
+ # ), "Patch in forward not enabled and no patch_lengths passed."
+
+ # PATCHER MODEL DEFINED
+ if self.config.patching_mode == PatchingModeEnum.entropy:
+ _, patch_lengths, _ = self.patcher(
+ local_encoder_tokens,
+ patch_size=self.config.patch_size,
+ include_next_token=True,
+ threshold=self.config.patching_threshold,
+ max_patch_length=self.config.max_patch_length,
+ patching_batch_size=self.config.patching_batch_size,
+ device=self.config.patching_device,
+ )
+ else:
+ # self.config.patching_mode == PatchingModeEnum.byte
+ bs, seq_len = local_encoder_tokens.shape
+ seq_len_next_tok = seq_len + 1 # include_next_token=True
+ patch_lengths = torch.ones(
+ (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device
+ )
+
+ patch_lengths = process_patch_lengths(patch_lengths, self.config.max_patch_length)
+
+ #assert torch.min(patch_lengths) >= 0
+ # Generate patch IDs from patch_lengths
+ patch_ids = self._patch_ids_from_lengths(patch_lengths, local_encoder_tokens.shape[-1])
+ # assert torch.max(patch_ids) + 1 <= torch.max((patch_lengths != 0).sum(dim=-1)), (
+ # f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"
+ # )
+
+ cross_attn_mask_enc = None
+ # Cross-attention encoder
+ if self.config.cross_attn_encoder:
+ cross_attn_mask_enc = cross_attn_mask(
+ patch_ids,
+ patch_lengths,
+ N,
+ patches_as_queries=True,
+ cross_attn_k=self.config.cross_attn_k,
+ window=self.config.cross_attn_window_encoder,
+ block_mask=self.config.cross_attn_use_flex_attention,
+ )
+
+ # Hashing and embedding
+ local_encoder_embeds = compute_hash_embeddings(
+ local_encoder_tokens=local_encoder_tokens,
+ local_encoder=self.local_encoder,
+ encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
+ encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions,
+ encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size,
+ encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab,
+ )
+
+ # NOTE: Frequency-based n-gram embeddings removed as per paper
+ # The final BLT model uses only hash-based n-gram embeddings
+
+ # Local encoder
+ (h_encoder, h_cross), cache_encoder = self.local_encoder(
+ tokens=local_encoder_tokens,
+ embeds=local_encoder_embeds,
+ patch_embeds=None,
+ cross_mask=cross_attn_mask_enc,
+ num_patches=patch_lengths.shape[1],
+ patch_ids=patch_ids,
+ )
+
+ # Downsampling
+ h = h_cross.view(bs, patch_lengths.shape[1], -1)
+
+ # Global transformer
+ global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.config.boe_id)
+ rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id)
+ eos_patch_ids = patch_ids[rows, cols]
+ global_tokens[rows, eos_patch_ids] = self.config.eos_token_id
+
+ h, _ = self.global_transformer(
+ embeds=h,
+ tokens=global_tokens,
+ )
+
+ # Unpatching
+
+ dec_embeds = h_encoder
+
+ # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens), so we need to map decoder positions to the remaining patches.
+ decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], local_decoder_tokens.shape[-1])
+ # assert torch.max(decoder_patch_ids) + 1 <= h.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
+ # assert decoder_patch_ids.shape[1] == dec_embeds.shape[1], (
+ # f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"
+ # )
+
+ # Cross-attention decoder
+ if not self.config.cross_attn_decoder:
+ h = torch.gather(h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]))
+ cross_attn_mask_dec = None
+ # assert local_decoder_tokens.shape == h.shape[:-1]
+ else:
+ cross_attn_mask_dec = cross_attn_mask(
+ decoder_patch_ids,
+ patch_lengths,
+ N,
+ patches_as_queries=False,
+ cross_attn_k=self.config.cross_attn_k,
+ window=self.config.cross_attn_window_decoder,
+ block_mask=self.config.cross_attn_use_flex_attention,
+ )
+
+ # Local decoder
+ output, _ = self.local_decoder(
+ embeds=dec_embeds,
+ patch_embeds=h,
+ tokens=local_decoder_tokens,
+ cross_mask=cross_attn_mask_dec,
+ )
+ return output
+
+ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
+ """
+ Convert patch lengths to patch IDs for each token position.
+ For each token position in the sequence, determines which patch it belongs to.
+
+ Args:
+ patch_lengths: [batch_size, num_patches] - length of each patch
+ seq_len: total sequence length
+
+ Returns:
+ patch_ids: [batch_size, seq_len] - patch index for each token position
+
+ Example:
+ patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1
+ seq_len = 10
+ Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]]
+ # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3
+ """
+ batch_size, num_patches = patch_lengths.shape
+
+ # Create patch start positions: [0, 3, 5, 9] for the example above
+ patch_starts = torch.cat(
+ [
+ torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
+ patch_lengths.cumsum(dim=-1)[:, :-1], # cumsum without the final total
+ ],
+ dim=-1,
+ )
+
+ # For each token position, find which patch it belongs to
+ # by finding the rightmost patch start that's <= the position
+ token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1]
+
+ # Broadcasting: patch_starts[batch, patch] <= token_positions[position]
+ # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t
+ position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)
+
+ # Count how many patch starts are <= each position, then subtract 1 to get patch index
+ patch_ids = position_ge_patch_start.sum(dim=-1) - 1
+
+ return patch_ids
+
+
+class BLTPatcher(BLTPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.rope_embeddings = RotaryEmbedding(
+ theta=config.patcher_rope_theta,
+ head_dim=config.patcher_head_dim or config.patcher_dim // config.patcher_n_heads,
+ max_seqlen=config.patcher_max_seqlen,
+ rope_use_fp32_in_outer_product=config.patcher_rope_use_fp32_in_outer_product,
+ )
+
+ self.layers = nn.ModuleList()
+ for _ in range(config.patcher_n_layers):
+ self.layers.append(
+ BLTTransformerLayer(
+ {
+ "dim": config.patcher_dim,
+ "n_heads": config.patcher_n_heads,
+ "head_dim": config.patcher_head_dim,
+ "n_kv_heads": config.patcher_n_kv_heads,
+ "rope_theta": config.patcher_rope_theta,
+ "multiple_of": config.patcher_multiple_of,
+ "ffn_dim_multiplier": config.patcher_ffn_dim_multiplier,
+ "norm_eps": config.patcher_norm_eps,
+ }
+ )
+ )
+
+ #assert config.patcher_vocab_size > 0
+
+ self.tok_embeddings = torch.nn.Embedding(config.patcher_vocab_size, config.patcher_dim)
+
+ self.norm = RMSNorm(config.patcher_dim, eps=config.patcher_norm_eps)
+
+ self.output = nn.Linear(
+ config.patcher_dim,
+ config.patcher_vocab_size,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ token_values: torch.Tensor,
+ patch_size: Optional[int] = None,
+ include_next_token: bool = True,
+ threshold: Optional[float] = None,
+ max_patch_length: Optional[int] = None,
+ patching_batch_size: int = 1,
+ device: Optional[str] = None,
+ ):
+
+ # Handle chunked processing for entropy calculation
+ entropies = []
+ preds = []
+ max_length = self.config.patcher_max_seqlen
+ batch_numel = max_length * patching_batch_size
+ splits = torch.split(token_values.flatten(), batch_numel)
+
+ for split in splits:
+ pad_size = (max_length - (split.numel() % max_length)) % max_length
+ pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False)
+ split = torch.cat((split, pad), dim=0)
+ split = split.reshape(-1, max_length)
+ if device is not None:
+ split = split.to(device)
+
+ # Process chunk: embeddings -> layers -> output
+ bsz, seqlen = split.shape
+ h = self.tok_embeddings(split)
+ chunk_mask = create_causal_mask(
+ seqlen,
+ self.config.patcher_attn_impl,
+ self.config.patcher_attn_bias_type,
+ sliding_window=self.config.patcher_sliding_window,
+ tokens=split,
+ eos_id=self.config.eos_id,
+ )
+
+ freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None)
+
+ for i, layer in enumerate(self.layers):
+ h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=self.config.patcher_attn_impl)
+
+ pred = self.output(self.norm(h))
+ pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab]
+ preds.append(pred)
+ pred_entropies = self.entropy(pred)
+ entropies.append(pred_entropies)
+
+ concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape)
+ concat_preds = torch.cat(preds, dim=0).reshape(token_values.shape[0], -1)
+
+ # Always compute patch lengths from concatenated entropies
+ bs, seq_len = token_values.shape
+ seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
+
+ # Find patch start IDs based on entropy
+ if patch_size is not None:
+ patch_start_ids = self.find_entropy_patch_start_ids(
+ concat_entropies,
+ patch_size,
+ include_next_token=include_next_token,
+ threshold=threshold
+ )
+ patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok)
+ else:
+ # Default to byte-level patching
+ patch_lengths = torch.ones((bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device)
+
+ patch_lengths = process_patch_lengths(patch_lengths, max_patch_length)
+ return concat_entropies, patch_lengths, concat_preds
+
+
+ @staticmethod
+ def entropy(scores):
+ """
+ scores: [bs, seq_len, vocab]
+ returns [bs, seq_len]
+
+ Computes the entropy for each token in the batch.
+ Note: uses natural log.
+ """
+ log_probs = F.log_softmax(scores, dim=-1)
+ probs = torch.exp(log_probs)
+ p_log_p = log_probs * probs
+ entropy = -p_log_p.sum(dim=-1)
+ return entropy
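+
+ # Illustrative example (added comment, not part of the original code): for uniform logits
+ # over a vocabulary of size V, the per-token entropy is log(V) nats (about 5.545 for V=256).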
+
+ @staticmethod
+ def patch_start_ids_from_patch_start_mask(patch_start_mask):
+ bs, trunc_seq_len = patch_start_mask.shape
+ max_patches = patch_start_mask.sum(dim=1).max()
+ if max_patches == 0:
+ patch_start_ids = torch.full(
+ (bs, trunc_seq_len),
+ trunc_seq_len,
+ dtype=torch.long,
+ device=patch_start_mask.device,
+ )
+ else:
+ patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(bs, 1)
+ extra_patch_ids = torch.full(
+ (bs, trunc_seq_len),
+ trunc_seq_len,
+ dtype=torch.long,
+ device=patch_start_mask.device,
+ )
+ all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
+ patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)
+ patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, trunc_seq_len)[:, :max_patches]
+ return patch_start_ids
+
+ @staticmethod
+ def patch_lengths_from_start_ids(patch_start_ids, seq_len):
+ """
+ Calculate patch lengths from start ids.
+ start_ids: e.g. [0, 1, 7, 7, 7, 7, 7], containing the actual patch start ids (here 0 and 1)
+ with the remaining slots padded up to the sequence length.
+ seq_len: e.g. 7, the length of the sequence.
+
+ Returns the patch lengths, e.g. [1, 6, 0, 0, 0, 0, 0] for the example above
+ (the trailing zeros correspond to the padded start ids).
+ """
+ last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
+ patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
+ patch_lengths = patch_end_ids - patch_start_ids + 1
+ assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
+ assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
+ return patch_lengths
+
+ @staticmethod
+ def find_entropy_patch_start_ids(
+ entropies,
+ patch_size=None,
+ threshold=None,
+ include_next_token=True,
+ ):
+ """
+ Use entropies to find the start ids of each patch.
+ Use patch_size or threshold to figure out the total number of patches to allocate.
+
+ When threshold is not None the number of patches is not constant between
+ different sequences, but patches can be identified incrementally rather than
+ decided globally using the entire sequence.
+ """
+ bs, seq_len = entropies.shape[:2]
+
+ first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(bs, 1)
+ preds_truncation_len = first_ids.shape[1] # remove the first preds because they will be start of patches.
+ entropies = entropies[:, 1:]
+ if threshold is None:
+ num_patches = seq_len // patch_size
+ patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices
+ patch_start_ids = patch_start_ids.sort(dim=1).values
+ else:
+ patch_start_mask = entropies > threshold
+ if not include_next_token:
+ patch_start_mask = patch_start_mask[:, :-1]
+ # patch_start_mask[1:] |= tokens[:-1] < OFFSET
+ patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask)
+
+ patch_start_ids = torch.cat((first_ids, patch_start_ids + preds_truncation_len), dim=1)
+ return patch_start_ids
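+
+ # Illustrative note (added comment, not part of the original code): in threshold mode,
+ # positions 0 and 1 always start patches, and a new patch starts at p + 1 for every later
+ # position p whose prediction entropy exceeds the threshold; in patch_size mode, the
+ # seq_len // patch_size - 2 highest-entropy positions play that role instead.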
+
+def init_hash_embeddings(
+ config,
+ local_encoder_dim: int,
+ encoder_hash_byte_group_size: list,
+):
+ """Initialize hash-based token embeddings for the BLT encoder."""
+ if config.encoder_hash_byte_group_size is None:
+ return None
+
+ embeddings = []
+ emb_dim = local_encoder_dim
+ encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab
+
+ for _ in range(config.encoder_hash_byte_group_nb_functions):
+ for _ in encoder_hash_byte_group_size:
+ embeddings.append(
+ nn.Embedding(
+ encoder_hash_byte_group_vocab,
+ emb_dim,
+ )
+ )
+
+ return nn.ModuleList(embeddings)
+
+
+__all__ = [
+ "BLTPreTrainedModel",
+ "BLTModel",
+ "BLTPatcher",
+ "BLTLocalEncoder",
+ "BLTLocalDecoder",
+ "BLTGlobalTransformer",
+]
diff --git a/backup_blt_wip copy/modular_blt.py b/backup_blt_wip copy/modular_blt.py
new file mode 100644
index 0000000000000000000000000000000000000000..f433e2c8b799f87ea4e86d0a2221bd18db4acd09
--- /dev/null
+++ b/backup_blt_wip copy/modular_blt.py
@@ -0,0 +1,1180 @@
+# coding=utf-8
+# Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BLT model."""
+
+from ...utils import is_torch_flex_attn_available, logging
+from typing import Callable, List, Optional, Tuple, Union
+
+from ...cache_utils import Cache
+from ...activations import ACT2FN
+
+import torch
+import torch.distributions
+import torch.nn as nn
+from torch.nn import functional as F
+
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from .configuration_blt import (
+ BLTConfig,
+ BLTLocalEncoderConfig,
+ BLTLocalDecoderConfig,
+ BLTGlobalTransformerConfig,
+ BLTPatcherConfig,
+ PatchingModeEnum,
+)
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+from ..mllama.modeling_mllama import repeat_kv, eager_attention_forward, MllamaRotaryEmbedding, MllamaTextRMSNorm, MllamaCrossAttentionDecoderLayer, MllamaTextCrossAttention, MllamaTextSelfAttention
+
+logger = logging.get_logger(__name__)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ # TODO: not exactly equivalent to other transformers implementations; feedback needed
+ # Extract first head_dim//2 elements which correspond to the unique frequencies
+ # This matches the original BLT approach which uses head_dim//2 frequency pairs
+ head_dim = q.shape[-1]
+ cos_freqs = cos[..., :head_dim//2] # [B, S, D/2]
+ sin_freqs = sin[..., :head_dim//2] # [B, S, D/2]
+
+ # Expand cos/sin to match query/key tensor format [B, H, S, D/2]
+ cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2]
+ sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2]
+
+ # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ...
+ q_pairs = q.view(*q.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2]
+ k_pairs = k.view(*k.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2]
+
+ # Extract real and imaginary parts
+ q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2]
+ k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2]
+
+ # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag]
+ q_real_rot = cos_freqs * q_real - sin_freqs * q_imag
+ q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag
+ k_real_rot = cos_freqs * k_real - sin_freqs * k_imag
+ k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag
+
+ # Recombine pairs and reshape back to original format
+ q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D]
+ k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D]
+
+ return q_rot.type_as(q), k_rot.type_as(k)
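+
+# Illustrative note (added comment, not part of the original code): each adjacent pair
+# (x_{2j}, x_{2j+1}) is treated as a complex number and multiplied by cos(theta_j) + i*sin(theta_j),
+# which mirrors the 2x2 rotation matrices used by the original BLT rope implementation.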
+
+
+class BLTMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class BLTRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ BLTRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class BLTTransformerLayer(nn.Module):
+ def __init__(self, config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.layer_idx = layer_idx
+
+ self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx)
+ self.mlp = BLTMLP(config)
+ self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ position_ids (`torch.LongTensor`, *optional*):
+ Position indices of tokens in the sequence for RoPE computation.
+ past_key_value (`Cache`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class BLTSelfAttention(nn.Module):
+ def __init__(self, config, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.num_heads = config.num_attention_heads
+ self.dropout = config.dropout
+ self.hidden_size = config.hidden_size
+ self.num_key_value_heads = config.num_key_value_heads
+ self.head_dim = config.hidden_size // self.num_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.scaling = self.head_dim**-0.5  # default 1/sqrt(head_dim), matching SDPA's default scale
+ self.rope_theta = config.rope_theta
+ self.layer_idx = layer_idx
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_embeddings: torch.Tensor,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ past_key_value=None,
+ cache_position=None,
+ **kwargs,
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ if self.config._attn_implementation == "sdpa" and output_attentions:
+ logger.warning_once(
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ else:
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0):
+ primes = [
+ 1000000007, 5915587277, 1500450271, 3267000013, 5754853343,
+ 4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313,
+ ]
+ prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device)
+ powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device)
+ prime_powers = prime ** powers
+ return torch.sum(token_tensor * prime_powers, dim=-1)
+
+
+def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000):
+ """Hash token groups and map to range [0, max_hash]."""
+ with torch.no_grad():
+ batch_size, seq_len = token_ids.shape
+ # Add padding for sliding window
+ padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
+ padded_tokens = torch.cat([padding, token_ids], dim=1)
+
+ # Create sliding windows and compute hashes
+ windows = padded_tokens.unfold(1, group_size, 1)
+ hashes = rolling_polynomial_hash(windows, hash_func_nb)
+ hash_values = hashes % max_hash
+
+ hash_values.requires_grad = False
+ return hash_values
+
+
+def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list):
+ """Initialize hash-based token embeddings for the BLT encoder."""
+ num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size)
+ embeddings = [
+ nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim)
+ for _ in range(num_embeddings)
+ ]
+ return nn.ModuleList(embeddings)
+
+
+def compute_hash_embeddings(
+ local_encoder_tokens: torch.Tensor,
+ local_encoder,
+ encoder_hash_tok_embedding: nn.ModuleList,
+ encoder_hash_byte_group_nb_functions: int,
+ encoder_hash_byte_group_size: list,
+ encoder_hash_byte_group_vocab: int,
+) -> torch.Tensor:
+ """Compute token embeddings enhanced with hash-based embeddings."""
+ embeddings = local_encoder.embed_tokens(local_encoder_tokens)
+ embedding_idx = 0
+ for func_nb in range(encoder_hash_byte_group_nb_functions):
+ for group_size in encoder_hash_byte_group_size:
+ hash_ids = byte_group_hash_function(
+ local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab
+ )
+ embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids)
+ embedding_idx += 1
+
+ return embeddings
+
+
+def _prepare_patch_cross_attention_mask(
+ patch_ids: torch.Tensor,
+ num_patches: int,
+ sequence_length: int,
+ patches_as_queries: bool = False,
+ cross_attn_k: int = 1,
+ dtype: torch.dtype = torch.float32,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Prepare cross-attention mask for patch-based attention, following mllama's robust approach.
+
+ This function creates masks that control which patches can attend to which other patches,
+ with support for query/key role swapping and cross-attention multipliers.
+
+ Args:
+ patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids.
+ num_patches (int): Total number of patches.
+ sequence_length (int): Length of the sequence.
+ patches_as_queries (bool): If True, patches are used as queries, otherwise as keys.
+ cross_attn_k (int): Cross-attention multiplier for repeating patches.
+ dtype (torch.dtype): Data type for the output mask.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]:
+ - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len]
+ - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows
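+
+ Example (illustrative): with patch_ids=[[0, 0, 1]], num_patches=2,
+ sequence_length=3, patches_as_queries=True and cross_attn_k=1, the returned
+ cross_attention_mask has shape [1, 1, 2, 3]: patch 0 may attend to byte
+ positions 0 and 1, patch 1 only to position 2, and masked entries hold
+ torch.finfo(dtype).min.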
+ """
+ batch_size, seq_len = patch_ids.shape
+ device = patch_ids.device
+
+ # Determine query and key lengths based on configuration
+ if patches_as_queries:
+ q_len = num_patches * cross_attn_k
+ kv_len = sequence_length
+ # Create patch-to-sequence mapping
+ q_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(-1).expand(
+ batch_size, num_patches, seq_len
+ )
+ kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len)
+ else:
+ q_len = sequence_length
+ kv_len = num_patches * cross_attn_k
+ # Create sequence-to-patch mapping
+ q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches)
+ kv_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(
+ batch_size, seq_len, num_patches
+ )
+
+ # Create base attention mask - boolean mask where True means "should attend"
+ # Exact patch matching
+ cross_attention_mask = q_patch_ids == kv_patch_ids
+
+ # Handle cross_attn_k multiplier by repeating along appropriate dimension
+ repeat_dim = 1 if patches_as_queries else -1
+ cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim)
+
+ # Validate dimensions
+ expected_shape = (batch_size, q_len, kv_len)
+ if cross_attention_mask.shape != expected_shape:
+ raise ValueError(f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}")
+
+ # Reshape so it can be used by attn module - add head dimension
+ cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len]
+
+ # Invert the mask (following mllama pattern exactly)
+ # True -> 0.0 (attend), False -> 1.0 (will become -inf)
+ inverted_cross_attn_mask = (1.0 - cross_attention_mask.to(dtype))
+ cross_attention_mask = inverted_cross_attn_mask.masked_fill(
+ inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
+ )
+
+ # Apply full-row bias (following mllama pattern exactly)
+ # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's
+ # last dimension contains negative infinity values, otherwise it's 1
+ negative_inf_value = torch.finfo(dtype).min
+ full_text_row_masked_out_mask = (
+ (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
+ )
+ cross_attention_mask *= full_text_row_masked_out_mask
+
+ return cross_attention_mask, full_text_row_masked_out_mask
+
+
+def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor:
+ """
+ Splits patch lengths into smaller segments if they exceed `max_patch_length`.
+ Pads the result to uniform length across the batch.
+
+ Args:
+ patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
+ max_patch_length (int, optional): Maximum allowed length per patch.
+
+ Returns:
+ torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
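+
+ Example (illustrative): with max_patch_length=3, a row [5, 2] becomes
+ [3, 2, 2]; shorter rows are right-padded with zeros to the batch maximum.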
+ """
+ if max_patch_length is None:
+ return patch_lengths
+
+ batch_size = patch_lengths.size(0)
+ processed = []
+
+ for seq in patch_lengths:
+ splits = []
+ for length in seq[seq > 0]:
+ length = length.item()
+ full_chunks, remainder = divmod(length, max_patch_length)
+ splits.extend([max_patch_length] * full_chunks)
+ if remainder:
+ splits.append(remainder)
+ processed.append(splits)
+
+ # Find max length to pad to
+ max_len = max(len(splits) for splits in processed)
+ padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
+
+ for i, splits in enumerate(processed):
+ if splits:
+ padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)
+
+ # Trim zero columns
+ if (padded != 0).any(dim=0).sum() < padded.shape[1]:
+ last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
+ padded = padded[:, :last_nonzero]
+
+ return padded
+
+
+class BLTRotaryEmbedding(nn.Module):
+ def __init__(self, config, device=None):
+ super().__init__()
+ self.rope_type = config.rope_scaling["rope_type"]
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class BLTLocalEncoder(nn.Module):
+ def __init__(self, config: BLTLocalEncoderConfig):
+ super().__init__()
+
+ self.hidden_size = config.hidden_size
+ self.vocab_size = config.vocab_size
+ self.num_hidden_layers = config.num_hidden_layers
+ self.dropout = config.dropout
+ self.cross_attn_all_layers = config.cross_attn_all_layers
+ self.cross_attn_k = config.cross_attn_k
+
+ self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(self.num_hidden_layers)])
+
+ self.rotary_emb = BLTRotaryEmbedding(config=config)
+
+ self.patch_embedding_projection = nn.Linear(
+ in_features=config.encoder_dim_patch_emb,
+ out_features=config.encoder_dim_token_emb * config.cross_attn_k,
+ bias=False,
+ )
+
+ self.embed_tokens = nn.Embedding(self.vocab_size + config.pm_size, self.hidden_size)
+
+ self.cross_attn_layers = torch.nn.ModuleList()
+ layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1
+ for layer_idx in range(layers_to_add):
+ self.cross_attn_layers.append(
+ BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size)
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ input_embeds: Optional[torch.Tensor] = None,
+ patch_embeds: Optional[torch.Tensor] = None,
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
+ cross_mask: Optional[torch.Tensor] = None,
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ num_patches: Optional[int] = None,
+ patch_ids: Optional[torch.Tensor] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ """ """
+ if input_embeds is None:
+ input_embeds = self.embed_tokens(input_ids)
+
+ batch_size, _, _ = input_embeds.shape
+
+ hidden_states = nn.functional.dropout(input_embeds, p=self.dropout, training=self.training)
+
+ position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ for idx, layer in enumerate(self.layers):
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None)
+ hidden_states = layer_outputs[0]
+
+ if idx == len(self.layers) - 1 or self.cross_attn_all_layers:
+ patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids)
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
+ patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size)
+
+ layer_idx = idx if self.cross_attn_all_layers else 0
+ cross_attention_output, _, _ = self.cross_attn_layers[layer_idx](
+ hidden_states=patch_embeds,
+ cross_attention_states=hidden_states,
+ attention_mask=cross_mask,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+ output_attentions=False,
+ use_cache=False,
+ cache_position=None,
+ )
+ patch_embeds = patch_embeds + cross_attention_output
+
+ encoder_cross_states = patch_embeds
+ return hidden_states, encoder_cross_states
+
+ def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids):
+ """
+ Reduce variable-length patches to a single embedding per patch.
+ Note: this works with a variable number of patches across the sequences in the batch.
+ It handles variable-length patches by assuming that patch_lengths will be 0 for any
+ extra patches on the *right*. Any embeddings on the right that are not allocated to a
+ patch (i.e. if sum(patch_lengths[i]) < seq_len for some i) are sent to a dummy patch,
+ which is trimmed before returning.
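+
+ Example (illustrative): with patch_ids [[0, 0, 1]] and reduction "amax",
+ patch 0 receives the elementwise max of the hidden states at positions 0 and 1,
+ and patch 1 receives the hidden state at position 2.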
+ """
+ batch_size, _, embedding_dim = hidden_states.shape
+
+ patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
+
+ reduced_embeddings = torch.zeros((batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device)
+ reduced_embeddings = reduced_embeddings.scatter_reduce(
+ src=hidden_states,
+ dim=1,
+ index=patch_ids,
+ reduce=reduction,
+ include_self=False,
+ )
+ reduced_embeddings = reduced_embeddings[:, :max_num_patches, :]
+
+ return reduced_embeddings
+
+
+class BLTLocalDecoder(nn.Module):
+ def __init__(self, config: BLTLocalDecoderConfig):
+ super().__init__()
+
+ # Extract config values to instance attributes
+ self.hidden_size = config.hidden_size
+ self.vocab_size = config.vocab_size
+ self.num_hidden_layers = config.num_hidden_layers
+ self.dropout = config.dropout
+ self.cross_attn_decoder = True  # hardcoded instead of config.cross_attn_decoder; TODO: maybe remove
+ self.cross_attn_all_layers = config.cross_attn_all_layers
+ self.cross_attn_k = config.cross_attn_k
+
+ self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(self.num_hidden_layers)])
+
+ self.rotary_emb = BLTRotaryEmbedding(config=config)
+
+ self.patch_embedding_projection = nn.Linear(
+ in_features=config.hidden_size_global,
+ out_features=config.decoder_dim_token_emb * config.cross_attn_k,
+ bias=False,
+ )
+
+ self.norm = BLTRMSNorm(self.hidden_size, eps=config.norm_eps)
+
+ self.cross_attn_layers = torch.nn.ModuleList()
+ layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1
+ for layer_idx in range(layers_to_add):
+ self.cross_attn_layers.append(
+ BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size)
+ )
+
+ self.lm_head = nn.Linear(
+ self.hidden_size,
+ self.vocab_size,
+ bias=False,
+ )
+
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ embeds: Optional[torch.Tensor],
+ patch_embeds: Optional[torch.Tensor] = None,
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
+ cross_mask: Optional[torch.Tensor] = None,
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ batch_size, _, _ = embeds.shape
+
+ hidden_states = embeds
+
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
+ patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size)
+
+ if patch_embeds is not None and not self.cross_attn_decoder:
+ hidden_states = hidden_states + patch_embeds
+
+ position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1)
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+ for i, layer in enumerate(self.layers):
+ if i == 0 or self.cross_attn_all_layers:
+ # Use cross attention to extract info from patch_embeds into hidden_states
+ cross_attention_output, _, _ = self.cross_attn_layers[i](
+ hidden_states=hidden_states,
+ cross_attention_states=patch_embeds,
+ attention_mask=cross_mask,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+ output_attentions=False,
+ use_cache=False,
+ cache_position=None,
+ )
+ hidden_states = hidden_states + cross_attention_output
+
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None)
+ hidden_states = layer_outputs[0]
+
+ logits = self.lm_head(self.norm(hidden_states))
+ return logits, cache
+
+
+class BLTCrossAttention(nn.Module):
+ """Cross-attention module for BLT, following transformers style"""
+
+ def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ # Use provided hidden_size or fallback to encoder dimension
+ self.hidden_size = hidden_size or config.hidden_size_local_encoder
+ self.num_heads = config.num_attention_heads
+ self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.scaling = self.head_dim**-0.5  # default 1/sqrt(head_dim), matching SDPA's default scale
+ self.dropout = config.dropout
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ self.q_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps)
+ self.k_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_norm(hidden_states) # BLT normalizes first
+ query_states = self.q_proj(query_states)
+
+ if cross_attention_states is not None:
+ cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first
+ key_states = self.k_proj(cross_attention_states)
+ value_states = self.v_proj(cross_attention_states)
+ if past_key_value is not None:
+ # if we have a new cross attention states + new tokens, we only computed key_states on that new cross attention states
+ # we still update the cross key states, past_cross_states, new_cross_states. And use it!
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ elif cache_position is not None and cache_position[0] != 0:
+ key_states, value_states = (
+ past_key_value.key_cache[self.layer_idx],
+ past_key_value.value_cache[self.layer_idx],
+ )
+ else:
+ if cross_attention_states is None:
+ raise ValueError(
+ "Cross attention layer can't find neither `cross_attention_states` nor cached values for key/values!"
+ )
+
+ attention_interface: Callable = eager_attention_forward
+
+ if self.config._attn_implementation != "eager":
+ if self.config._attn_implementation == "sdpa" and output_attentions:
+ logger.warning_once(
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ else:
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0,  # cross-attention dropout is disabled (was: 0.0 if not self.training else self.dropout)
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ # Apply full row masking if provided (following mllama pattern)
+ if full_text_row_masked_out_mask is not None:
+ attn_output = full_text_row_masked_out_mask[:, 0] * attn_output
+
+ attn_output = attn_output + hidden_states
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class BLTGlobalTransformer(nn.Module):
+ def __init__(self, config: BLTGlobalTransformerConfig):
+ super().__init__()
+
+ self.hidden_size = config.hidden_size
+ self.num_hidden_layers = config.num_hidden_layers
+ self.dropout = config.dropout
+
+ self.layers = nn.ModuleList()
+ for layer_idx in range(self.num_hidden_layers):
+ self.layers.append(BLTTransformerLayer(config, layer_idx))
+
+ self.rotary_emb = BLTRotaryEmbedding(config=config)
+
+
+ def forward(
+ self,
+ input_embeds: torch.Tensor,
+ mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ batch_size, seq_len, _ = input_embeds.shape
+
+ hidden_states = input_embeds
+
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ for i, layer in enumerate(self.layers):
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None)
+ hidden_states = layer_outputs[0]
+
+ return hidden_states, cache
+
+
+class BLTPreTrainedModel(PreTrainedModel):
+ config_class = BLTConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = False # BLT uses its own attention implementation
+ _supports_sdpa = True
+ _supports_cache_class = False
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Linear):
+ std = getattr(module, '_custom_std', module.in_features ** (-0.5))
+ nn.init.trunc_normal_(
+ module.weight,
+ mean=0.0,
+ std=std,
+ a=-3 * std,
+ b=3 * std,
+ )
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ elif isinstance(module, nn.Embedding):
+ std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5))
+ nn.init.trunc_normal_(
+ module.weight,
+ mean=0.0,
+ std=std,
+ a=-3 * std,
+ b=3 * std,
+ )
+
+ elif isinstance(module, BLTModel):
+ if module.encoder_hash_tok_embedding is not None:
+ emb_std = module.config.hidden_size_local_encoder ** (-0.5)
+ for emb in module.encoder_hash_tok_embedding:
+ emb._custom_std = emb_std
+
+ elif isinstance(module, BLTLocalEncoder):
+ if module.patch_embedding_projection is not None:
+ module.patch_embedding_projection._custom_std = module.config.encoder_dim_patch_emb ** (-0.5)
+
+ elif isinstance(module, BLTLocalDecoder):
+ if module.patch_embedding_projection is not None:
+ module.patch_embedding_projection._custom_std = module.config.hidden_size_global ** (-0.5)
+
+ elif isinstance(module, BLTPatcher):
+ emb_std = module.config.hidden_size ** (-0.5)
+ module.embed_tokens._custom_std = emb_std
+ module.lm_head._custom_std = emb_std
+
+
+class BLTModel(BLTPreTrainedModel):
+ def __init__(self, config: BLTConfig):
+ super().__init__(config)
+
+ self.config = config
+
+ self.local_encoder = BLTLocalEncoder(config.encoder_config)
+ self.global_transformer = BLTGlobalTransformer(config.global_config)
+ self.local_decoder = BLTLocalDecoder(config.decoder_config)
+
+ self.encoder_hash_tok_embedding = init_hash_embeddings(
+ config,
+ local_encoder_dim=config.hidden_size_local_encoder,
+ encoder_hash_byte_group_size=config.encoder_hash_byte_group_size,
+ )
+
+ if self.config.patch_in_forward:
+ self.patcher = BLTPatcher(config.patcher_config)
+ self.patcher.eval()
+ for param in self.patcher.parameters():
+ param.requires_grad = False
+ else:
+ self.patcher = None
+
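+ # Forward-pass outline: bytes are embedded (plus hash n-gram embeddings),
+ # pooled into patch representations by the local encoder's cross-attention,
+ # processed at patch granularity by the global transformer, and expanded back
+ # to byte positions by the local decoder, which returns per-byte logits.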
+ def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = None):
+ batch_size, sequence_length = tokens.shape
+
+ # Handle patching
+ if patch_lengths is None:
+ if self.config.patching_mode == PatchingModeEnum.entropy:
+ _, patch_lengths, _ = self.patcher(
+ tokens,
+ patch_size=self.config.patch_size,
+ threshold=self.config.patching_threshold,
+ max_patch_length=self.config.max_patch_length,
+ patching_batch_size=self.config.patching_batch_size,
+ device=self.config.patching_device,
+ )
+ else:
+ # Default to byte-level patching
+ patch_lengths = process_patch_lengths(
+ torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device),
+ self.config.max_patch_length
+ )
+
+ patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length)
+ cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask(
+ patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32
+ )
+
+ encoder_embeds = compute_hash_embeddings(
+ tokens, self.local_encoder, self.encoder_hash_tok_embedding,
+ self.config.encoder_hash_byte_group_nb_functions,
+ self.config.encoder_hash_byte_group_size,
+ self.config.encoder_hash_byte_group_vocab,
+ )
+
+ encoder_hidden_states, encoder_cross_states = self.local_encoder(
+ input_ids=tokens,
+ input_embeds=encoder_embeds,
+ patch_embeds=None,
+ cross_mask=cross_attn_mask_enc,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc,
+ num_patches=patch_lengths.shape[1],
+ patch_ids=patch_ids,
+ )
+
+ global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
+
+ global_hidden_states, _ = self.global_transformer(
+ input_embeds=global_hidden_states,
+ )
+
+ decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length)
+ cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask(
+ decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32
+ )
+
+ output, _ = self.local_decoder(
+ tokens=tokens,
+ embeds=encoder_hidden_states,
+ patch_embeds=global_hidden_states,
+ mask=None,
+ cross_mask=cross_attn_mask_dec,
+ full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec,
+ )
+
+ return output
+
+ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
+ """Convert patch lengths to patch IDs for each token position."""
+ batch_size = patch_lengths.shape[0]
+ patch_starts = torch.cat([
+ torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
+ patch_lengths.cumsum(dim=-1)[:, :-1]
+ ], dim=-1)
+
+ token_positions = torch.arange(seq_len, device=patch_lengths.device)
+ return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1
+
+
+class BLTPatcher(BLTPreTrainedModel):
+ def __init__(self, config: BLTPatcherConfig):
+ super().__init__(config)
+
+ self.rotary_emb = BLTRotaryEmbedding(config=self.config)
+
+ self.layers = nn.ModuleList()
+ # Create transformer layers using the patcher config
+ for layer_idx in range(self.config.num_hidden_layers):
+ self.layers.append(BLTTransformerLayer(self.config, layer_idx))
+
+
+ self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.hidden_size)
+
+ self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps)
+
+ self.lm_head = nn.Linear(
+ self.config.hidden_size,
+ self.config.vocab_size,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ token_values: torch.Tensor,
+ patch_size: Optional[int] = None,
+ threshold: Optional[float] = None,
+ max_patch_length: Optional[int] = None,
+ patching_batch_size: int = 1,
+ device: Optional[str] = None,
+ ):
+
+ # Handle chunked processing for entropy calculation
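+ # Tokens are flattened and split into chunks of at most
+ # max_position_embeddings * patching_batch_size values; each chunk is padded
+ # to a multiple of max_position_embeddings, scored independently, and the
+ # per-token entropies are concatenated back to the original input shape.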
+ entropies = []
+ predictions = []
+ max_length = self.config.max_position_embeddings
+ batch_numel = max_length * patching_batch_size
+ splits = torch.split(token_values.flatten(), batch_numel)
+
+ for split in splits:
+ pad_size = (max_length - (split.numel() % max_length)) % max_length
+ pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False)
+ split = torch.cat((split, pad), dim=0)
+ split = split.reshape(-1, max_length)
+ if device is not None:
+ split = split.to(device)
+
+ # Process chunk: embeddings -> layers -> output
+ batch_size, sequence_length = split.shape
+ input_embeds = self.embed_tokens(split)
+
+ hidden_states = input_embeds
+
+ batch_size, _, _ = input_embeds.shape
+
+ position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
+
+ position_embeddings = self.rotary_emb(hidden_states, position_ids) # = BLT self.rope
+
+ for i, layer in enumerate(self.layers):
+ layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) #, attn_impl=self.config.patcher_attn_impl )
+ hidden_states = layer_outputs[0]
+
+ logits = self.lm_head(self.norm(hidden_states))
+ logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab]
+ predictions.append(logits)
+ prediction_entropies = torch.distributions.Categorical(logits=logits).entropy()
+ entropies.append(prediction_entropies)
+
+ concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape)
+ concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1)
+
+ # Always compute patch lengths from concatenated entropies
+ batch_size, sequence_length = token_values.shape
+
+ # Find patch start IDs based on entropy
+ if patch_size is not None:
+ patch_lengths = self.patch_lengths_from_entropies(
+ entropies=concat_entropies,
+ sequence_length=sequence_length,
+ patch_size=patch_size,
+ threshold=threshold,
+ )
+ else:
+ # Default to byte-level patching
+ patch_lengths = torch.ones((batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device)
+ patch_lengths = process_patch_lengths(patch_lengths, max_patch_length)
+ return concat_entropies, patch_lengths, concat_predictions
+
+ @staticmethod
+ def patch_lengths_from_entropies(
+ entropies,
+ sequence_length,
+ patch_size=None,
+ threshold=None,
+ ):
+ """
+ Computes patch lengths from token entropies.
+
+ Depending on whether a threshold is provided, the function uses either:
+ - Top-k selection based on entropy (when `threshold` is None), or
+ - Thresholding the entropy values (when `threshold` is set).
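+
+ In both cases the patch starts always include tokens 0 and 1, and the
+ returned lengths of each row sum to `sequence_length`.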
+ """
+
+ batch_size = entropies.shape[0]
+
+ # Always include tokens 0 and 1 as patch starting tokens
+ init_tokens = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1)
+ offset = init_tokens.shape[1]
+
+ # Ignore first token entropy (BOS)
+ entropies = entropies[:, 1:]
+
+ if threshold is None:
+ # Use top-k entropy values to define patch start points
+ num_patches = sequence_length // patch_size
+ topk_indices = entropies.topk(num_patches - 2, dim=1).indices
+ patch_starts = topk_indices.sort(dim=1).values
+ else:
+ # Threshold the entropy values to define patch start points
+ patch_mask = entropies > threshold
+
+ seq_len = patch_mask.shape[1]
+
+ # Create patch IDs (token indices), and add a sentinel to ensure alignment
+ token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1)
+ sentinel = torch.full_like(token_indices, seq_len)
+ padded_indices = torch.cat([token_indices, sentinel], dim=1)
+
+ # Pad mask with inverse to align sentinel correctly
+ padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1)
+
+ # Select indices where mask is True
+ patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len)
+ max_valid_patches = patch_mask.sum(dim=1).max()
+ patch_starts = patch_starts[:, :max_valid_patches]
+
+ # Offset patch starts to account for the two initial tokens
+ patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1)
+
+ # Compute patch end positions by shifting start positions
+ last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1)
+ patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1)
+
+ patch_lengths = patch_ends - patch_start_ids + 1
+
+ return patch_lengths
+
+__all__ = [
+ "BLTPreTrainedModel",
+ "BLTModel",
+ "BLTPatcher",
+ "BLTLocalEncoder",
+ "BLTLocalDecoder",
+ "BLTGlobalTransformer",
+ "BLTTransformerLayer",
+]
\ No newline at end of file
diff --git a/backup_blt_wip copy/tokenization_blt.py b/backup_blt_wip copy/tokenization_blt.py
new file mode 100644
index 0000000000000000000000000000000000000000..d116fae5188fcd12c741ddfcdb6dfb5973585b28
--- /dev/null
+++ b/backup_blt_wip copy/tokenization_blt.py
@@ -0,0 +1,271 @@
+# coding=utf-8
+# Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for BLT."""
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+if TYPE_CHECKING:
+ from ...tokenization_utils_base import TextInput
+
+logger = logging.get_logger(__name__)
+
+# BLT tokenizer constants
+SEP = " "
+BOS_ID: int = 1
+EOS_ID: int = 2
+PAD_ID: int = -1
+BOE_ID: int = 0
+BPE_ID: int = 3
+OFFSET: int = 4
+BYTE_UNITS: int = 256
+
+VOCAB_FILES_NAMES = {} # BLT doesn't require external vocab files
+
+
+class BLTTokenizer(PreTrainedTokenizer):
+ """
+ Construct a BLT tokenizer. Based on byte-level tokenization where each byte is treated as a token.
+
+ This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with an offset.
+ It supports special tokens for beginning of sequence (BOS), end of sequence (EOS),
+ beginning of example (BOE), and padding (PAD).
+
+ Args:
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`):
+ The beginning of sequence token.
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`):
+ The end of sequence token.
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`):
+ The padding token.
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`):
+ The unknown token. Not used in BLT but kept for compatibility.
+ boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`):
+ The beginning of example token, specific to BLT.
+ add_bos_token (`bool`, *optional*, defaults to `True`):
+ Whether or not to add a `bos_token` at the start of sequences.
+ add_eos_token (`bool`, *optional*, defaults to `True`):
+ Whether or not to add an `eos_token` at the end of sequences.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+ Whether or not to cleanup spaces after decoding.
+ spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to add spaces between special tokens.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ bos_token="",
+ eos_token="",
+ pad_token="",
+ unk_token="",
+ boe_token="",
+ add_bos_token=True,
+ add_eos_token=True,
+ clean_up_tokenization_spaces=False,
+ spaces_between_special_tokens=False,
+ **kwargs,
+ ):
+ # Store BLT-specific parameters first
+ self.add_bos_token = add_bos_token
+ self.add_eos_token = add_eos_token
+ self.vocab_size_unit_1 = BYTE_UNITS
+ self.offsetting_special_char = OFFSET
+
+ # BLT token IDs (exactly like original)
+ self.boe_id = BOE_ID
+ self.bos_id = BOS_ID
+ self.eos_id = EOS_ID
+ self.pad_id = PAD_ID
+ self.bpe_id = BPE_ID
+ self.n_words = self.vocab_size_unit_1 + self.offsetting_special_char
+
+ # Convert string tokens to AddedToken objects
+ bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+ eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+ pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
+ unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+ self.boe_token = AddedToken(boe_token, normalized=False, special=True) if isinstance(boe_token, str) else boe_token
+
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ unk_token=unk_token,
+ add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ spaces_between_special_tokens=spaces_between_special_tokens,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self):
+ """Returns vocab size"""
+ return self.vocab_size_unit_1 + self.offsetting_special_char
+
+ def get_vocab(self):
+ """Returns vocab as a dict"""
+ # Create a mapping for byte values + offset
+ vocab = {}
+
+ # Add special tokens (with defensive checks)
+ if hasattr(self, 'bos_token'):
+ vocab[str(self.bos_token)] = self.bos_id
+ if hasattr(self, 'eos_token'):
+ vocab[str(self.eos_token)] = self.eos_id
+ if hasattr(self, 'pad_token'):
+ vocab[str(self.pad_token)] = self.pad_id
+ if hasattr(self, 'boe_token'):
+ vocab[str(self.boe_token)] = self.boe_id
+
+ # Add byte tokens as string representations of byte values
+ vocab_size_unit_1 = getattr(self, 'vocab_size_unit_1', BYTE_UNITS)
+ offsetting_special_char = getattr(self, 'offsetting_special_char', OFFSET)
+ for i in range(vocab_size_unit_1):
+ vocab[str(i)] = i + offsetting_special_char
+
+ # Add any additional tokens if available
+ if hasattr(self, 'added_tokens_encoder'):
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
+ """
+ Converts a string to a list of tokens. For BLT, we work directly with byte values.
+ Returns a list of strings that represent the byte values.
+ """
+ # Convert text to UTF-8 bytes, just like the original tokenizer.
+ # errors="ignore" means encode() never raises, so no try/except is needed.
+ bytes_data = text.encode("utf-8", errors="ignore")
+
+ # Return string representations of byte values for the tokenizer framework
+ return [str(byte_val) for byte_val in bytes_data]
+
+ def _convert_token_to_id(self, token: str) -> int:
+ """Converts a token (str) to an id using the vocab."""
+ # Handle special tokens
+ if token == str(self.bos_token):
+ return self.bos_id
+ elif token == str(self.eos_token):
+ return self.eos_id
+ elif token == str(self.pad_token):
+ return self.pad_id
+ elif token == str(self.boe_token):
+ return self.boe_id
+ else:
+ try:
+ # Convert byte value string to int and add offset (like original)
+ byte_val = int(token)
+ if 0 <= byte_val <= 255:
+ return byte_val + self.offsetting_special_char
+ except ValueError:
+ pass
+
+ # Check if it's in added tokens
+ return self.added_tokens_encoder.get(token, self.unk_token_id)
+
+ def _convert_id_to_token(self, index: int) -> str:
+ """Converts an index (integer) to a token (str) using the vocab."""
+ # Handle special tokens
+ if index == self.bos_id:
+ return str(self.bos_token)
+ elif index == self.eos_id:
+ return str(self.eos_token)
+ elif index == self.pad_id:
+ return str(self.pad_token)
+ elif index == self.boe_id:
+ return str(self.boe_token)
+ elif index >= self.offsetting_special_char and index < self.vocab_size:
+ # Convert back to byte value (like original)
+ byte_val = index - self.offsetting_special_char
+ return str(byte_val)
+ else:
+ # Check added tokens
+ for token, token_id in self.added_tokens_encoder.items():
+ if token_id == index:
+ return token
+ return str(self.unk_token)
+
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
+ """Converts a sequence of tokens to a single string."""
+ byte_values = []
+
+ for token in tokens:
+ # Skip special tokens
+ if token in [str(self.bos_token), str(self.eos_token), str(self.pad_token), str(self.boe_token)]:
+ continue
+
+ try:
+ # Convert token back to byte value (like original decode method)
+ byte_val = int(token)
+ if 0 <= byte_val <= 255:
+ byte_values.append(byte_val)
+ except ValueError:
+ continue
+
+ # Convert byte values back to string (exactly like original)
+ try:
+ return bytes(byte_values).decode("utf-8", errors="ignore")
+ except (UnicodeDecodeError, ValueError):
+ return ""
+
+ def encode(self, text: str, add_bos: bool | None = None, add_eos: bool | None = None):
+ """
+ Encode text exactly like the original BLT tokenizer.
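+
+ For example, encode("Hi") returns [1, 76, 109, 2] with the default BOS/EOS
+ settings: bytes 72 and 105 shifted by the offset of 4, wrapped in BOS_ID=1
+ and EOS_ID=2.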
+ """
+ if add_bos is None:
+ add_bos = self.add_bos_token
+ if add_eos is None:
+ add_eos = self.add_eos_token
+
+ # Since bpe_delim=False, we use the simple byte encoding
+ tokens = bytes(text, encoding="utf-8", errors="ignore")
+
+ # Offsetting (exactly like original)
+ tokens = [int(unit) + self.offsetting_special_char for unit in tokens]
+
+ if add_bos:
+ tokens.insert(0, self.bos_id)
+ if add_eos:
+ tokens.append(self.eos_id)
+
+ return tokens
+
+ def decode(self, tokens: list[int], cut_at_eos: bool = False):
+ """
+ Decode tokens exactly like the original BLT tokenizer.
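+
+ For example, decode([1, 76, 109, 2]) returns "Hi": ids below the offset of 4
+ (the special tokens) are dropped and the rest are shifted back to bytes.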
+ """
+ if cut_at_eos:
+ for k, t in enumerate(tokens):
+ if t == self.eos_id:
+ tokens = tokens[: k + 1]
+ break
+ return bytes(
+ [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0]
+ ).decode("utf-8", errors="ignore")
+
+ def get_vocab_size(self) -> int:
+ """Get vocab size like the original tokenizer."""
+ return self.vocab_size_unit_1 + self.offsetting_special_char
+
+#__all__ = ["BLTTokenizer"]
\ No newline at end of file
diff --git a/backup_blt_wip_backup/__pycache__/blt_args.cpython-312.pyc b/backup_blt_wip_backup/__pycache__/blt_args.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe1c3fac23fde31ac2fe95aa5c15d35368e766fc
Binary files /dev/null and b/backup_blt_wip_backup/__pycache__/blt_args.cpython-312.pyc differ
diff --git a/backup_blt_wip_backup/__pycache__/blt_one_file.cpython-312.pyc b/backup_blt_wip_backup/__pycache__/blt_one_file.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad67ee5b9c82f98e084ea14790b83f2fbbca1539
Binary files /dev/null and b/backup_blt_wip_backup/__pycache__/blt_one_file.cpython-312.pyc differ
diff --git a/backup_blt_wip_backup/__pycache__/configuration_blt.cpython-312.pyc b/backup_blt_wip_backup/__pycache__/configuration_blt.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..29c901cfa128a76934a45a17664faae1d2ff226c
Binary files /dev/null and b/backup_blt_wip_backup/__pycache__/configuration_blt.cpython-312.pyc differ
diff --git a/backup_blt_wip_backup/__pycache__/modeling_blt_wip.cpython-312.pyc b/backup_blt_wip_backup/__pycache__/modeling_blt_wip.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14471f3e8bde08cd94f2ba3b8ea1dc01597af449
Binary files /dev/null and b/backup_blt_wip_backup/__pycache__/modeling_blt_wip.cpython-312.pyc differ
diff --git a/backup_blt_wip_backup/__pycache__/modeling_blt_wip_backup.cpython-312.pyc b/backup_blt_wip_backup/__pycache__/modeling_blt_wip_backup.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..01bb41a2f95cf657f1cb253e41a8eadbe9a4c5b7
Binary files /dev/null and b/backup_blt_wip_backup/__pycache__/modeling_blt_wip_backup.cpython-312.pyc differ
diff --git a/backup_blt_wip_backup/__pycache__/tokenization_blt.cpython-312.pyc b/backup_blt_wip_backup/__pycache__/tokenization_blt.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f38c87100166c7b926f6683567c5df1366316c48
Binary files /dev/null and b/backup_blt_wip_backup/__pycache__/tokenization_blt.cpython-312.pyc differ
diff --git a/backup_blt_wip_backup/blt_args.py b/backup_blt_wip_backup/blt_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..e043d1dd20a85564ed998ec608481313b30f5406
--- /dev/null
+++ b/backup_blt_wip_backup/blt_args.py
@@ -0,0 +1,187 @@
+from enum import Enum
+from typing import Any
+
+from pydantic import BaseModel, ConfigDict, model_validator
+from typing_extensions import Self
+
+
+EOS_ID: int = 2
+
+
+class InitStdFactor(str, Enum):
+ DISABLED = "disabled" # Init std is divided by 1.0
+ GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers)
+ CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)
+ DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096
+
+
+class PatchingModeEnum(str, Enum):
+ entropy = "entropy"
+ bpe = "bpe"
+ bpe_patcher = "bpe_patcher"
+ space = "space"
+ static = "static"
+ byte = "byte"
+
+
+class LMTransformerArgs(BaseModel):
+ """Arguments for the Language Model Transformer (used as entropy model for patching)"""
+
+ model_config = ConfigDict()
+
+ # Basic architecture
+ dim: int = 512
+ n_layers: int = 8
+ head_dim: int | None = None
+ n_heads: int | None = None
+ n_kv_heads: int | None = None
+
+ # Transformer configuration
+ max_seqlen: int = 1024
+ norm_eps: float = 1e-5
+ dropout: float = 0
+ vocab_size: int = -1
+ sliding_window: int | None = None
+
+ # Feedforward
+ ffn_dim_multiplier: float | None = None
+ multiple_of: int = 256
+
+ # Positional encoding
+ rope_theta: float = 10000.0
+ rope_use_fp32_in_outer_product: bool = False
+
+ # Attention
+ attn_impl: str = "sdpa"
+ attn_bias_type: str = "causal"
+
+ # Initialization
+ init_base_std: float | None = None
+ init_std_factor: InitStdFactor = InitStdFactor.DISABLED
+
+ # Embedding dimensions
+ dim_token_emb: int | None = None
+
+ # Model behavior
+ weight_tying: bool = False
+ seed: int = 42
+
+ # Special token config
+ eos_id: int = EOS_ID
+
+
+class ByteLatentTransformerArgs(BaseModel):
+ """Arguments for the Byte Latent Transformer (main BLT model)"""
+
+ model_config = ConfigDict()
+
+ # Basic model configuration
+ seed: int = 42
+ vocab_size: int = -1
+
+ # Main architecture dimensions (these will be used for creating transformer args)
+ dim: int = 512
+ n_layers: int = 8
+ head_dim: int | None = None
+ n_heads: int | None = None
+ n_kv_heads: int | None = None
+
+ # Component-specific dimensions
+ dim_global: int = 512
+ dim_local_decoder: int = 512
+ dim_local_encoder: int = 512
+ n_layers_global: int = 8
+ n_layers_local_decoder: int = 8
+ n_layers_local_encoder: int = 8
+ n_heads_global: int = 8
+ n_heads_local_decoder: int = 8
+ n_heads_local_encoder: int = 8
+ n_kv_heads_global: int | None = None
+
+ # Transformer configuration (needed by transformer components)
+ max_seqlen: int = 1024
+ norm_eps: float = 1e-5
+ dropout: float = 0
+
+ # Feedforward (needed by transformer components)
+ ffn_dim_multiplier: float = 1.0
+ multiple_of: int = 256
+
+ # Positional encoding (needed by transformer components)
+ rope_theta: float = 10000.0
+ rope_use_fp32_in_outer_product: bool = False
+
+ # Attention (needed by transformer components)
+ attn_impl: str = "sdpa"
+ attn_bias_type: str = "causal"
+
+ # Initialization (needed by transformer components)
+ init_base_std: float | None = None
+ init_std_factor: InitStdFactor = InitStdFactor.DISABLED
+
+ # Embedding dimensions (needed by transformer components)
+ dim_token_emb: int | None = None
+
+ # Patching configuration
+ patch_in_forward: bool = False
+ realtime_patching: bool = True
+ patch_size: float | None = None
+ patching_mode: str | None = None
+ patching_threshold: float | None = None
+ patching_threshold_add: float | None = None
+ monotonicity: bool = False
+ patching_batch_size: int = 1
+ patching_device: str = "cuda"
+ max_patch_length: int | None = None
+ entropy_model_checkpoint_dir: str | None = None
+
+ # Cross attention configurations
+ cross_attn_encoder: bool = False
+ cross_attn_decoder: bool = False
+ cross_attn_window_encoder: int | None = None
+ cross_attn_window_decoder: int | None = None
+ cross_attn_k: int | None = None
+ cross_attn_nheads: int | None = None
+ cross_attn_all_layers_decoder: bool = False
+ cross_attn_all_layers_encoder: bool = False
+ cross_attn_use_flex_attention: bool = True
+ cross_attn_init_by_pooling: bool = False
+
+ # Encoder configurations
+ use_local_encoder_transformer: bool = False
+ max_encoder_seq_length: int | None = None
+ encoder_hash_byte_group_size: Any | None = None
+ encoder_hash_byte_group_vocab: int = 30000
+ encoder_hash_byte_group_nb_functions: int = 3
+ encoder_enable_byte_ngrams: bool = False
+ encoder_ngram_to_size_str: str | None = None
+ downsampling_by_pooling: str | None = None
+
+ # Architecture and dimensions
+ dim_token: int | None = None
+ share_encoder_decoder_emb: bool = True
+ weight_tying: bool = False
+
+ # Attention configuration
+ local_attention_window_len: int | None = None
+ use_rope: bool = True
+
+ # Performance optimization
+ sequence_parallel: bool = False
+ loss_parallel: bool = False
+ fuse_sequence_parallel: bool = False
+ use_fsdp: bool = True
+
+ # Parameter mixing
+ pm_size: int = 0
+
+ # Special token config
+ eos_id: int = EOS_ID
+
+ @model_validator(mode="after")
+ def check_hash_byte_sizes(self) -> Self:
+ if self.encoder_hash_byte_group_size is not None and isinstance(self.encoder_hash_byte_group_size, str):
+ self.encoder_hash_byte_group_size = [
+ int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0
+ ]
+ return self
diff --git a/backup_blt_wip_backup/configuration_blt.py b/backup_blt_wip_backup/configuration_blt.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c645e8b18f63496c4f3419243250465eac941e8
--- /dev/null
+++ b/backup_blt_wip_backup/configuration_blt.py
@@ -0,0 +1,590 @@
+# coding=utf-8
+# Copyright 2024 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BLT (Byte Latent Transformer) model configuration"""
+
+from enum import Enum
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class InitStdFactor(str, Enum):
+ DISABLED = "disabled" # Init std is divided by 1.0
+ CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)
+
+
+class PatchingModeEnum(str, Enum):
+ entropy = "entropy"
+ bpe = "bpe"
+ bpe_patcher = "bpe_patcher"
+ space = "space"
+ static = "static"
+ byte = "byte"
+
+
+class BLTConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ByteLatentTransformer`]. It is used to instantiate a
+ BLT model according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 256):
+ Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented.
+ max_seqlen (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model can handle.
+
+ # Main architecture dimensions
+ dim (`int`, *optional*, defaults to 512):
+ Main dimension of the model.
+ n_layers (`int`, *optional*, defaults to 8):
+ Number of layers in the main transformer.
+ n_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads in the main transformer.
+ head_dim (`int`, *optional*):
+ Dimension of each attention head. If not specified, computed as dim // n_heads.
+ n_kv_heads (`int`, *optional*):
+ Number of key-value heads for grouped query attention. If not specified, defaults to n_heads.
+
+ # Component-specific dimensions
+ dim_global (`int`, *optional*, defaults to 512):
+ Dimension of the global transformer component.
+ dim_local_decoder (`int`, *optional*, defaults to 512):
+ Dimension of the local decoder component.
+ dim_local_encoder (`int`, *optional*, defaults to 512):
+ Dimension of the local encoder component.
+ n_layers_global (`int`, *optional*, defaults to 8):
+ Number of layers in the global transformer.
+ n_layers_local_decoder (`int`, *optional*, defaults to 8):
+ Number of layers in the local decoder.
+ n_layers_local_encoder (`int`, *optional*, defaults to 8):
+ Number of layers in the local encoder.
+ n_heads_global (`int`, *optional*, defaults to 8):
+ Number of attention heads in the global transformer.
+ n_heads_local_decoder (`int`, *optional*, defaults to 8):
+ Number of attention heads in the local decoder.
+ n_heads_local_encoder (`int`, *optional*, defaults to 8):
+ Number of attention heads in the local encoder.
+ n_kv_heads_global (`int`, *optional*):
+ Number of key-value heads in the global transformer.
+
+ # Transformer configuration
+ norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers.
+ ffn_dim_multiplier (`float`, *optional*, defaults to 1.0):
+ Multiplier for the feedforward network dimension.
+ multiple_of (`int`, *optional*, defaults to 256):
+ Make feedforward network dimension multiple of this value.
+
+ # Positional encoding
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False):
+ Whether to use fp32 in RoPE outer product computation.
+
+ # Attention configuration
+ attn_impl (`str`, *optional*, defaults to "sdpa"):
+ Attention implementation to use ("sdpa" or "flex_attention").
+ attn_bias_type (`str`, *optional*, defaults to "causal"):
+ Type of attention bias to apply.
+ local_attention_window_len (`int`, *optional*):
+ Window length for local attention.
+ use_rope (`bool`, *optional*, defaults to True):
+ Whether to use rotary position embeddings.
+
+ # Initialization
+ init_base_std (`float`, *optional*):
+ Base standard deviation for weight initialization.
+ init_std_factor (`str`, *optional*, defaults to "disabled"):
+ Factor for adjusting initialization standard deviation.
+
+ # Embedding dimensions
+ dim_token_emb (`int`, *optional*):
+ Token embedding dimension.
+ dim_token (`int`, *optional*):
+ Token dimension.
+
+ # Patching configuration
+ patch_in_forward (`bool`, *optional*, defaults to False):
+ Whether to perform patching during forward pass.
+ realtime_patching (`bool`, *optional*, defaults to True):
+ Whether to use realtime patching.
+ patch_size (`float`, *optional*):
+ Size of patches for static patching.
+ patching_mode (`str`, *optional*):
+ Mode for patching ("entropy", "static", etc.).
+ patching_threshold (`float`, *optional*):
+ Threshold for entropy-based patching.
+ patching_threshold_add (`float`, *optional*):
+ Additional threshold parameter for patching.
+ monotonicity (`bool`, *optional*, defaults to False):
+ Whether to enforce monotonicity in patching.
+ patching_batch_size (`int`, *optional*, defaults to 1):
+ Batch size for patching operations.
+ patching_device (`str`, *optional*, defaults to "cuda"):
+ Device to use for patching operations.
+ max_patch_length (`int`, *optional*):
+ Maximum length of patches.
+ entropy_model_checkpoint_dir (`str`, *optional*):
+ Directory containing entropy model checkpoint.
+
+ # Cross attention configurations
+ cross_attn_encoder (`bool`, *optional*, defaults to False):
+ Whether to use cross attention in encoder.
+ cross_attn_decoder (`bool`, *optional*, defaults to False):
+ Whether to use cross attention in decoder.
+ cross_attn_window_encoder (`int`, *optional*):
+ Cross attention window for encoder.
+ cross_attn_window_decoder (`int`, *optional*):
+ Cross attention window for decoder.
+ cross_attn_k (`int`, *optional*):
+ Number of cross attention components.
+ cross_attn_nheads (`int`, *optional*):
+ Number of heads for cross attention.
+ cross_attn_all_layers_decoder (`bool`, *optional*, defaults to False):
+ Whether to apply cross attention to all decoder layers.
+ cross_attn_all_layers_encoder (`bool`, *optional*, defaults to False):
+ Whether to apply cross attention to all encoder layers.
+ cross_attn_use_flex_attention (`bool`, *optional*, defaults to True):
+            Whether to use FlexAttention for cross attention.
+ cross_attn_init_by_pooling (`bool`, *optional*, defaults to False):
+ Whether to initialize cross attention by pooling.
+
+ # Encoder configurations
+ use_local_encoder_transformer (`bool`, *optional*, defaults to False):
+ Whether to use transformer in local encoder.
+ max_encoder_seq_length (`int`, *optional*):
+ Maximum sequence length for encoder.
+        encoder_hash_byte_group_size (`list[int]` or `str`, *optional*):
+            Hash byte group sizes for the encoder, given as a list of ints or as a comma-separated string (e.g. `"3,4,5"`).
+ encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 30000):
+ Vocabulary size for hash byte groups.
+ encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 3):
+ Number of hash functions for byte groups.
+ encoder_enable_byte_ngrams (`bool`, *optional*, defaults to False):
+ Whether to enable byte n-grams in encoder.
+ encoder_ngram_to_size_str (`str`, *optional*):
+ String defining n-gram sizes.
+ downsampling_by_pooling (`str`, *optional*):
+ Type of pooling for downsampling.
+
+ # Model behavior
+ share_encoder_decoder_emb (`bool`, *optional*, defaults to True):
+ Whether to share encoder and decoder embeddings.
+ weight_tying (`bool`, *optional*, defaults to False):
+ Whether to tie input and output embeddings.
+
+ # Performance optimization
+ sequence_parallel (`bool`, *optional*, defaults to False):
+ Whether to use sequence parallelism.
+ loss_parallel (`bool`, *optional*, defaults to False):
+ Whether to use loss parallelism.
+ fuse_sequence_parallel (`bool`, *optional*, defaults to False):
+ Whether to fuse sequence parallel operations.
+ use_fsdp (`bool`, *optional*, defaults to True):
+ Whether to use fully sharded data parallel.
+
+ # Parameter mixing
+ pm_size (`int`, *optional*, defaults to 0):
+ Parameter mixing size.
+
+ # Special tokens
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the "end-of-sequence" token.
+ pad_token_id (`int`, *optional*, defaults to -1):
+ The id of the padding token.
+
+ # Patcher/Entropy model configuration
+ patcher_vocab_size (`int`, *optional*, defaults to 256):
+ Vocabulary size for the entropy model used in patching.
+ patcher_dim (`int`, *optional*, defaults to 512):
+ Hidden dimension for the entropy model.
+ patcher_n_layers (`int`, *optional*, defaults to 8):
+ Number of layers in the entropy model.
+ patcher_n_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads in the entropy model.
+ patcher_head_dim (`int`, *optional*):
+ Dimension of each attention head in the entropy model.
+ patcher_n_kv_heads (`int`, *optional*):
+ Number of key-value heads in the entropy model.
+ patcher_max_seqlen (`int`, *optional*, defaults to 1024):
+ Maximum sequence length for the entropy model.
+ patcher_norm_eps (`float`, *optional*, defaults to 1e-5):
+ Layer normalization epsilon for the entropy model.
+ patcher_dropout (`float`, *optional*, defaults to 0.0):
+ Dropout probability for the entropy model.
+ patcher_sliding_window (`int`, *optional*):
+ Sliding window size for the entropy model attention.
+ patcher_ffn_dim_multiplier (`float`, *optional*):
+ Feedforward dimension multiplier for the entropy model.
+ patcher_multiple_of (`int`, *optional*, defaults to 256):
+ Make feedforward dimension multiple of this for the entropy model.
+ patcher_rope_theta (`float`, *optional*, defaults to 10000.0):
+ RoPE theta parameter for the entropy model.
+ patcher_rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False):
+ Whether to use fp32 in RoPE outer product for the entropy model.
+ patcher_attn_impl (`str`, *optional*, defaults to "sdpa"):
+ Attention implementation for the entropy model.
+ patcher_attn_bias_type (`str`, *optional*, defaults to "causal"):
+ Attention bias type for the entropy model.
+ patcher_init_base_std (`float`, *optional*):
+ Base initialization standard deviation for the entropy model.
+ patcher_init_std_factor (`str`, *optional*, defaults to "disabled"):
+ Initialization std factor for the entropy model.
+ patcher_dim_token_emb (`int`, *optional*):
+ Token embedding dimension for the entropy model.
+ patcher_weight_tying (`bool`, *optional*, defaults to False):
+ Whether to tie embeddings in the entropy model.
+ patcher_bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of sequence token id for the entropy model.
+ patcher_eos_token_id (`int`, *optional*, defaults to 2):
+ End of sequence token id for the entropy model.
+
+ ```python
+ >>> from transformers import ByteLatentTransformer, BLTConfig
+
+ >>> # Initializing a BLT configuration
+ >>> configuration = BLTConfig()
+
+ >>> # Initializing a model from the configuration
+ >>> model = ByteLatentTransformer(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "blt"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=256,
+ max_seqlen=1024,
+ # Main architecture dimensions
+ dim=512,
+ n_layers=8,
+ n_heads=8,
+ head_dim=None,
+ n_kv_heads=None,
+ # Component-specific dimensions
+ dim_global=512,
+ dim_local_decoder=512,
+ dim_local_encoder=512,
+ n_layers_global=8,
+ n_layers_local_decoder=8,
+ n_layers_local_encoder=8,
+ n_heads_global=8,
+ n_heads_local_decoder=8,
+ n_heads_local_encoder=8,
+ n_kv_heads_global=None,
+ # Transformer configuration
+ norm_eps=1e-5,
+ dropout=0.0,
+ ffn_dim_multiplier=1.0,
+ multiple_of=256,
+ # Positional encoding
+ rope_theta=10000.0,
+ rope_use_fp32_in_outer_product=False,
+ # Attention configuration
+ attn_impl="sdpa",
+ attn_bias_type="causal",
+ local_attention_window_len=None,
+ use_rope=True,
+ # Initialization
+ init_base_std=None,
+ init_std_factor="disabled",
+ # Embedding dimensions
+ dim_token_emb=None,
+ dim_token=None,
+ # Patching configuration
+ patch_in_forward=False,
+ realtime_patching=True,
+ patch_size=None,
+ patching_mode=None,
+ patching_threshold=None,
+ patching_threshold_add=None,
+ monotonicity=False,
+ patching_batch_size=1,
+ patching_device="cuda",
+ max_patch_length=None,
+ entropy_model_checkpoint_dir=None,
+ # Cross attention configurations
+ cross_attn_encoder=False,
+ cross_attn_decoder=False,
+ cross_attn_window_encoder=None,
+ cross_attn_window_decoder=None,
+ cross_attn_k=None,
+ cross_attn_nheads=None,
+ cross_attn_all_layers_decoder=False,
+ cross_attn_all_layers_encoder=False,
+ cross_attn_use_flex_attention=True,
+ cross_attn_init_by_pooling=False,
+ # Encoder configurations
+ use_local_encoder_transformer=False,
+ max_encoder_seq_length=None,
+ encoder_hash_byte_group_size=None,
+ encoder_hash_byte_group_vocab=30000,
+ encoder_hash_byte_group_nb_functions=3,
+ encoder_enable_byte_ngrams=False,
+ encoder_ngram_to_size_str=None,
+ downsampling_by_pooling=None,
+ # Model behavior
+ share_encoder_decoder_emb=True,
+ weight_tying=False,
+ # Performance optimization
+ sequence_parallel=False,
+ loss_parallel=False,
+ fuse_sequence_parallel=False,
+ use_fsdp=True,
+ # Parameter mixing
+ pm_size=0,
+ # Special tokens
+ bos_token_id=1,
+ eos_token_id=2,
+ pad_token_id=-1,
+ # Patcher/Entropy model configuration
+ patcher_vocab_size=256,
+ patcher_dim=512,
+ patcher_n_layers=8,
+ patcher_n_heads=8,
+ patcher_head_dim=None,
+ patcher_n_kv_heads=None,
+ patcher_max_seqlen=1024,
+ patcher_norm_eps=1e-5,
+ patcher_dropout=0.0,
+ patcher_sliding_window=None,
+ patcher_ffn_dim_multiplier=None,
+ patcher_multiple_of=256,
+ patcher_rope_theta=10000.0,
+ patcher_rope_use_fp32_in_outer_product=False,
+ patcher_attn_impl="sdpa",
+ patcher_attn_bias_type="causal",
+ patcher_init_base_std=None,
+ patcher_init_std_factor="disabled",
+ patcher_dim_token_emb=None,
+ patcher_weight_tying=False,
+ patcher_bos_token_id=1,
+ patcher_eos_token_id=2,
+ # Inherited
+ **kwargs,
+ ):
+ # Basic model configuration
+ self.vocab_size = vocab_size
+ self.max_seqlen = max_seqlen
+
+ # Main architecture dimensions
+ self.dim = dim
+ self.n_layers = n_layers
+ self.n_heads = n_heads
+ self.head_dim = head_dim
+ self.n_kv_heads = n_kv_heads
+
+ # Component-specific dimensions
+ self.dim_global = dim_global
+ self.dim_local_decoder = dim_local_decoder
+ self.dim_local_encoder = dim_local_encoder
+ self.n_layers_global = n_layers_global
+ self.n_layers_local_decoder = n_layers_local_decoder
+ self.n_layers_local_encoder = n_layers_local_encoder
+ self.n_heads_global = n_heads_global
+ self.n_heads_local_decoder = n_heads_local_decoder
+ self.n_heads_local_encoder = n_heads_local_encoder
+ self.n_kv_heads_global = n_kv_heads_global
+
+ # Transformer configuration
+ self.norm_eps = norm_eps
+ self.dropout = dropout
+ self.ffn_dim_multiplier = ffn_dim_multiplier
+ self.multiple_of = multiple_of
+
+ # Positional encoding
+ self.rope_theta = rope_theta
+ self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
+
+ # Attention configuration
+ self.attn_impl = attn_impl
+ self.attn_bias_type = attn_bias_type
+ self.local_attention_window_len = local_attention_window_len
+ self.use_rope = use_rope
+
+ # Initialization
+ self.init_base_std = init_base_std
+ self.init_std_factor = InitStdFactor(init_std_factor)
+
+ # Embedding dimensions
+ self.dim_token_emb = dim_token_emb
+ self.dim_token = dim_token
+
+ # Patching configuration
+ self.patch_in_forward = patch_in_forward
+ self.realtime_patching = realtime_patching
+ self.patch_size = patch_size
+ self.patching_mode = patching_mode
+ self.patching_threshold = patching_threshold
+ self.patching_threshold_add = patching_threshold_add
+ self.monotonicity = monotonicity
+ self.patching_batch_size = patching_batch_size
+ self.patching_device = patching_device
+ self.max_patch_length = max_patch_length
+ self.entropy_model_checkpoint_dir = entropy_model_checkpoint_dir
+
+ # Cross attention configurations
+ self.cross_attn_encoder = cross_attn_encoder
+ self.cross_attn_decoder = cross_attn_decoder
+ self.cross_attn_window_encoder = cross_attn_window_encoder
+ self.cross_attn_window_decoder = cross_attn_window_decoder
+ self.cross_attn_k = cross_attn_k
+ self.cross_attn_nheads = cross_attn_nheads
+ self.cross_attn_all_layers_decoder = cross_attn_all_layers_decoder
+ self.cross_attn_all_layers_encoder = cross_attn_all_layers_encoder
+ self.cross_attn_use_flex_attention = cross_attn_use_flex_attention
+ self.cross_attn_init_by_pooling = cross_attn_init_by_pooling
+
+ # Encoder configurations
+ self.use_local_encoder_transformer = use_local_encoder_transformer
+ self.max_encoder_seq_length = max_encoder_seq_length
+ self.encoder_hash_byte_group_size = encoder_hash_byte_group_size
+ self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab
+ self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions
+ self.encoder_enable_byte_ngrams = encoder_enable_byte_ngrams
+ self.encoder_ngram_to_size_str = encoder_ngram_to_size_str
+ self.downsampling_by_pooling = downsampling_by_pooling
+
+ # Model behavior
+ self.share_encoder_decoder_emb = share_encoder_decoder_emb
+ self.weight_tying = weight_tying
+
+ # Performance optimization
+ self.sequence_parallel = sequence_parallel
+ self.loss_parallel = loss_parallel
+ self.fuse_sequence_parallel = fuse_sequence_parallel
+ self.use_fsdp = use_fsdp
+
+ # Parameter mixing
+ self.pm_size = pm_size
+
+ # Patcher/Entropy model configuration
+ self.patcher_vocab_size = patcher_vocab_size
+ self.patcher_dim = patcher_dim
+ self.patcher_n_layers = patcher_n_layers
+ self.patcher_n_heads = patcher_n_heads
+ self.patcher_head_dim = patcher_head_dim
+ self.patcher_n_kv_heads = patcher_n_kv_heads
+ self.patcher_max_seqlen = patcher_max_seqlen
+ self.patcher_norm_eps = patcher_norm_eps
+ self.patcher_dropout = patcher_dropout
+ self.patcher_sliding_window = patcher_sliding_window
+ self.patcher_ffn_dim_multiplier = patcher_ffn_dim_multiplier
+ self.patcher_multiple_of = patcher_multiple_of
+ self.patcher_rope_theta = patcher_rope_theta
+ self.patcher_rope_use_fp32_in_outer_product = patcher_rope_use_fp32_in_outer_product
+ self.patcher_attn_impl = patcher_attn_impl
+ self.patcher_attn_bias_type = patcher_attn_bias_type
+ self.patcher_init_base_std = patcher_init_base_std
+ self.patcher_init_std_factor = InitStdFactor(patcher_init_std_factor)
+ self.patcher_dim_token_emb = patcher_dim_token_emb
+ self.patcher_weight_tying = patcher_weight_tying
+ self.patcher_bos_token_id = patcher_bos_token_id
+ self.patcher_eos_token_id = patcher_eos_token_id
+
+ # Handle hash byte group size validation
+        if self.encoder_hash_byte_group_size is not None and isinstance(self.encoder_hash_byte_group_size, str):
+ self.encoder_hash_byte_group_size = [
+ int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0
+ ]
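+            # e.g. the string "3,4,5" is normalized to [3, 4, 5]; list values are left unchanged.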
+
+ super().__init__(
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ **kwargs,
+ )
+
+ @property
+ def encoder_dim_token_emb(self):
+ """Compute encoder token embedding dimension."""
+ if self.dim_token is not None:
+ return self.dim_token
+ elif self.use_local_encoder_transformer:
+ return self.dim_local_encoder
+ else:
+ # Use default patch_size of 8 if not set
+ patch_size = self.patch_size if self.patch_size is not None else 8
+ return self.dim_global // patch_size
+
+ @property
+ def encoder_dim_patch_emb(self):
+ """Compute encoder patch embedding dimension."""
+ if self.cross_attn_encoder:
+ if self.cross_attn_init_by_pooling:
+ return self.dim_local_encoder
+ else:
+ return self.dim_global
+ return None
+
+ @property
+ def global_dim_patch_emb(self):
+ """Compute global patch embedding dimension."""
+ dim_token_emb = self.encoder_dim_token_emb
+ if self.cross_attn_encoder:
+ cross_attn_k = self.cross_attn_k if self.cross_attn_k is not None else 1
+ return dim_token_emb * cross_attn_k
+ elif (
+ self.downsampling_by_pooling is None
+ or not self.downsampling_by_pooling
+ or len(self.downsampling_by_pooling) == 0
+ ):
+ # Use default patch_size of 8 if not set
+ patch_size = self.patch_size if self.patch_size is not None else 8
+ return dim_token_emb * patch_size
+ else:
+ return dim_token_emb * sum([pooling in self.downsampling_by_pooling for pooling in ["avg", "min", "max"]])
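+    # Illustrative values: with cross_attn_encoder=True and cross_attn_k=2 this is
+    # 2 * encoder_dim_token_emb; without cross attention or pooling it falls back to
+    # encoder_dim_token_emb * patch_size (patch_size defaults to 8 here).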
+
+ @property
+ def decoder_dim_token_emb(self):
+ """Compute decoder token embedding dimension."""
+ if self.share_encoder_decoder_emb:
+ return self.encoder_dim_token_emb
+ elif self.dim_token is not None:
+ return self.dim_token
+ else:
+ return self.dim_local_decoder
+
+ def get_init_std_factor(self, depth: int) -> float:
+ """
+ Calculate the initialization standard deviation scaling factor for a given layer depth.
+
+ Args:
+ depth: Current layer depth (0-indexed)
+
+ Returns:
+ Scaling factor to divide the base initialization std by
+ """
+ if self.init_std_factor == InitStdFactor.CURRENT_DEPTH:
+ return (2 * (depth + 1)) ** 0.5
+ else: # DISABLED
+ return 1.0
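+    # Illustrative values: with init_std_factor="current_depth", depth 0 gives sqrt(2) ~ 1.41
+    # and depth 3 gives sqrt(8) ~ 2.83; with "disabled" the factor is always 1.0.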
+
+
+__all__ = ["BLTConfig", "InitStdFactor", "PatchingModeEnum"]
diff --git a/backup_blt_wip_backup/convert_hf_blt_original_to_unified.py b/backup_blt_wip_backup/convert_hf_blt_original_to_unified.py
new file mode 100644
index 0000000000000000000000000000000000000000..dad247b19c62d983d71526ecdc8ff6c13ca9a5c8
--- /dev/null
+++ b/backup_blt_wip_backup/convert_hf_blt_original_to_unified.py
@@ -0,0 +1,540 @@
+import argparse
+import json
+import logging
+import os
+from typing import Dict, Any, Optional
+
+import torch
+from huggingface_hub import hf_hub_download, snapshot_download
+from safetensors.torch import load_file, save_file
+
+from transformers.utils import logging as transformers_logging
+
+logger = transformers_logging.get_logger(__name__)
+transformers_logging.set_verbosity_info()
+
+# Model validation needs the in-repo BLT classes; if they cannot be imported (for example
+# when running this script standalone), fall back to writing the unified config and weights
+# without testing model instantiation.
+try:
+    from transformers.models.blt_wip.configuration_blt import BLTConfig
+    from transformers.models.blt_wip.modeling_blt_wip import BLTModel
+
+    ENABLE_MODEL_VALIDATION = True
+except ImportError:
+    BLTConfig = None
+    BLTModel = None
+    ENABLE_MODEL_VALIDATION = False
+
+def download_model_files(model_id: str, cache_dir: Optional[str] = None) -> Dict[str, str]:
+ """
+ Download all necessary files from HuggingFace Hub.
+
+ Args:
+ model_id: HuggingFace model ID (e.g., "facebook/blt-1b")
+ cache_dir: Optional cache directory
+
+ Returns:
+ Dictionary with paths to downloaded files
+ """
+ logger.info(f"Downloading model files from {model_id}...")
+
+ try:
+ # Download main config
+ config_path = hf_hub_download(
+ repo_id=model_id,
+ filename="config.json",
+ cache_dir=cache_dir
+ )
+
+ # Download main model weights
+ weights_path = hf_hub_download(
+ repo_id=model_id,
+ filename="model.safetensors",
+ cache_dir=cache_dir
+ )
+
+ # Download entropy model params
+ entropy_params_path = hf_hub_download(
+ repo_id=model_id,
+ filename="entropy_model/params.json",
+ cache_dir=cache_dir
+ )
+
+ # Download entropy model weights
+ entropy_weights_path = hf_hub_download(
+ repo_id=model_id,
+ filename="entropy_model/consolidated.pth",
+ cache_dir=cache_dir
+ )
+
+ return {
+ "config": config_path,
+ "weights": weights_path,
+ "entropy_params": entropy_params_path,
+ "entropy_weights": entropy_weights_path
+ }
+
+ except Exception as e:
+ logger.error(f"Failed to download files from {model_id}: {e}")
+ raise
+
+
+def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]:
+ """
+ Merge main configuration with entropy model parameters.
+
+ Args:
+ config_path: Path to main config.json
+ entropy_params_path: Path to entropy_model/params.json
+
+ Returns:
+ Merged configuration dictionary
+ """
+ logger.info("Merging configurations...")
+
+ # Load main configuration
+ with open(config_path, 'r') as f:
+ main_config = json.load(f)
+
+ # Load entropy model parameters
+ with open(entropy_params_path, 'r') as f:
+ entropy_data = json.load(f)
+
+ # Extract entropy model and patcher parameters
+ entropy_model_params = entropy_data.get("entropy_model", {})
+ patcher_args = entropy_data.get("data", {}).get("patcher_args", {})
+
+ # Create unified configuration
+ unified_config = main_config.copy()
+
+ # Ensure required main model parameters are present with correct types
+ # Sometimes the original config may have different key names
+ if "vocab_size" not in unified_config:
+ unified_config["vocab_size"] = int(main_config.get("vocab_size", 256))
+ if "dim" not in unified_config:
+ unified_config["dim"] = int(main_config.get("dim", main_config.get("hidden_size", main_config.get("d_model", 512))))
+ if "n_layers" not in unified_config:
+ unified_config["n_layers"] = int(main_config.get("n_layers", main_config.get("num_layers", main_config.get("num_hidden_layers", 8))))
+ if "n_heads" not in unified_config:
+ unified_config["n_heads"] = int(main_config.get("n_heads", main_config.get("num_attention_heads", main_config.get("num_heads", 8))))
+ if "max_seqlen" not in unified_config:
+ unified_config["max_seqlen"] = int(main_config.get("max_seqlen", main_config.get("max_position_embeddings", main_config.get("seq_length", 1024))))
+
+ # Ensure other integer parameters are properly typed
+ for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]:
+ if key in unified_config and not isinstance(unified_config[key], int):
+ unified_config[key] = int(unified_config[key])
+
+ # Convert all patch_size values to integers to avoid float/int type errors
+ patch_size = patcher_args.get("patch_size", 8)
+ if isinstance(patch_size, float):
+ patch_size = int(patch_size)
+
+ # Add patching configuration
+ unified_config.update({
+ "patch_in_forward": True,
+ "realtime_patching": True,
+ "patching_mode": "entropy",
+
+ # Patcher arguments
+ "patch_size": patch_size,
+ "patching_threshold": patcher_args.get("threshold", 0.5),
+ "patching_threshold_add": patcher_args.get("threshold_add", 0.0),
+ "max_patch_length": patcher_args.get("max_patch_length"),
+ "patching_batch_size": patcher_args.get("patching_batch_size", 1),
+ "patching_device": patcher_args.get("patching_device", "cuda"),
+ "monotonicity": patcher_args.get("monotonicity", False),
+
+ # Entropy model (patcher) architecture parameters
+ "patcher_vocab_size": int(entropy_model_params.get("vocab_size", 256)),
+ "patcher_dim": int(entropy_model_params.get("dim", 512)),
+ "patcher_n_layers": int(entropy_model_params.get("n_layers", 8)),
+ "patcher_n_heads": int(entropy_model_params.get("n_heads", 8)),
+ "patcher_head_dim": int(entropy_model_params.get("head_dim")) if entropy_model_params.get("head_dim") is not None else None,
+ "patcher_n_kv_heads": int(entropy_model_params.get("n_kv_heads")) if entropy_model_params.get("n_kv_heads") is not None else None,
+ "patcher_max_seqlen": int(entropy_model_params.get("max_seqlen", 1024)),
+ "patcher_norm_eps": entropy_model_params.get("norm_eps", 1e-5),
+ "patcher_dropout": entropy_model_params.get("dropout", 0.0),
+ "patcher_sliding_window": int(entropy_model_params.get("sliding_window", 512)) if entropy_model_params.get("sliding_window") is not None else None,
+ "patcher_ffn_dim_multiplier": entropy_model_params.get("ffn_dim_multiplier"),
+ "patcher_multiple_of": int(entropy_model_params.get("multiple_of", 256)),
+ "patcher_rope_theta": entropy_model_params.get("rope_theta", 10000.0),
+ "patcher_rope_use_fp32_in_outer_product": entropy_model_params.get("rope_use_fp32_in_outer_product", False),
+ "patcher_attn_impl": entropy_model_params.get("attn_impl", "sdpa"),
+ "patcher_attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"),
+ "patcher_init_base_std": entropy_model_params.get("init_base_std"),
+ "patcher_init_std_factor": entropy_model_params.get("init_std_factor", "disabled"),
+ "patcher_dim_token_emb": entropy_model_params.get("dim_token_emb"),
+ "patcher_weight_tying": entropy_model_params.get("weight_tying", False),
+ "patcher_bos_token_id": entropy_model_params.get("bos_token_id", 1),
+ "patcher_eos_token_id": entropy_model_params.get("eos_token_id", 2),
+ })
+
+ logger.info(f"Merged configuration with {len(unified_config)} parameters")
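+    # The unified dict keeps the original top-level keys and adds flattened "patcher_*" entries,
+    # so BLTConfig(**unified_config) can build both the main model and its entropy patcher.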
+ return unified_config
+
+
+def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]:
+ """
+ Merge main model weights with entropy model weights.
+
+ Args:
+ weights_path: Path to main model.safetensors
+ entropy_weights_path: Path to entropy_model/consolidated.pth
+
+ Returns:
+ Merged state dictionary
+ """
+ logger.info("Merging model weights...")
+
+ # Load main model weights
+ main_weights = load_file(weights_path)
+ logger.info(f"Loaded main model weights: {len(main_weights)} tensors")
+
+ # Load entropy model weights
+ entropy_weights = torch.load(entropy_weights_path, map_location='cpu', weights_only=True)
+
+ # Handle nested entropy model structure
+ if 'model' in entropy_weights:
+ entropy_weights = entropy_weights['model']
+ elif 'state_dict' in entropy_weights:
+ entropy_weights = entropy_weights['state_dict']
+
+ logger.info(f"Loaded entropy model weights: {len(entropy_weights)} tensors")
+
+ # Create unified state dict
+ unified_weights = main_weights.copy()
+
+ # Add entropy model weights with "patcher." prefix
+ for key, tensor in entropy_weights.items():
+ patcher_key = f"patcher.{key}"
+ unified_weights[patcher_key] = tensor
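+    # e.g. an entropy-model tensor named "tok_embeddings.weight" (name for illustration) is
+    # stored as "patcher.tok_embeddings.weight", while main-model keys are kept unchanged.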
+
+ logger.info(f"Merged weights: {len(unified_weights)} tensors total")
+ return unified_weights
+
+
+def create_tokenizer_config(output_dir: str, config: Dict[str, Any]):
+ """
+ Create tokenizer configuration file.
+
+ Args:
+ output_dir: Output directory
+ config: Model configuration
+ """
+ logger.info("Creating tokenizer configuration...")
+
+ tokenizer_config = {
+ "tokenizer_class": "BltTokenizer",
+ "vocab_size": config.get("vocab_size", 256),
+ "model_max_length": config.get("max_seqlen", 1024),
+ "add_bos_token": True,
+ "add_eos_token": True,
+ "bos_token": "",
+ "eos_token": "",
+ "pad_token": "",
+ "unk_token": "",
+ }
+
+ tokenizer_path = os.path.join(output_dir, "tokenizer_config.json")
+ with open(tokenizer_path, 'w') as f:
+ json.dump(tokenizer_config, f, indent=2)
+
+ logger.info(f"Tokenizer config saved to {tokenizer_path}")
+
+
+def validate_unified_model(config: Dict[str, Any], weights: Dict[str, torch.Tensor]):
+ """
+ Validate the unified model configuration and weights.
+
+ Args:
+ config: Unified configuration
+ weights: Unified weights
+ """
+ logger.info("Validating unified model...")
+
+ # Check required configuration keys
+ required_keys = [
+ "vocab_size", "dim", "n_layers", "n_heads",
+ "patch_in_forward", "patcher_vocab_size", "patcher_dim"
+ ]
+
+ missing_keys = [key for key in required_keys if key not in config]
+ if missing_keys:
+ logger.warning(f"Missing configuration keys: {missing_keys}")
+
+ # Check for patcher weights
+ patcher_weights = [key for key in weights.keys() if key.startswith("patcher.")]
+ if not patcher_weights:
+ logger.warning("No patcher weights found in unified weights")
+ else:
+ logger.info(f"Found {len(patcher_weights)} patcher weight tensors")
+
+ # Check for main model weights
+ main_weights = [key for key in weights.keys() if not key.startswith("patcher.")]
+ logger.info(f"Found {len(main_weights)} main model weight tensors")
+
+ # Try to create the model with the configuration (if imports are available)
+ if ENABLE_MODEL_VALIDATION and BLTConfig is not None and BLTModel is not None:
+ try:
+ logger.info("Testing model instantiation...")
+
+ # Debug: Print config keys to help diagnose issues
+ logger.debug(f"Config keys: {list(config.keys())}")
+ logger.debug(f"Config vocab_size: {config.get('vocab_size')} (type: {type(config.get('vocab_size'))})")
+ logger.debug(f"Config dim: {config.get('dim')} (type: {type(config.get('dim'))})")
+
+ blt_config = BLTConfig(**config)
+ model = BLTModel(blt_config)
+ logger.info("✓ Model instantiation successful")
+
+ # Try to load the weights
+ logger.info("Testing weight loading...")
+ try:
+ missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)
+ if missing_keys:
+ logger.warning(f"Missing keys during weight loading: {missing_keys[:5]}...") # Show first 5
+ if unexpected_keys:
+ logger.warning(f"Unexpected keys during weight loading: {unexpected_keys[:5]}...") # Show first 5
+ logger.info("✓ Weight loading successful")
+ except Exception as weight_error:
+ logger.warning(f"Weight loading failed: {weight_error}")
+ logger.info("Model instantiation successful, but weight loading had issues")
+
+ except Exception as e:
+ logger.error(f"Model validation failed: {e}")
+ logger.debug(f"Full error details:", exc_info=True)
+ logger.warning("Model may not be compatible with modeling_blt_wip.py")
+ logger.info("You can still use the converted files and test manually")
+ else:
+ logger.info("Skipping model instantiation test (BLT classes not available)")
+ logger.info("You can test the model manually after conversion")
+
+ logger.info("Model validation completed")
+
+
+def convert_hf_blt_to_unified(
+ model_id: str,
+ output_dir: str,
+ config_name: str = "config.json",
+ weights_name: str = "pytorch_model.bin",
+ safe_serialization: bool = True,
+ cache_dir: Optional[str] = None,
+ validate: bool = True,
+) -> None:
+ """
+ Convert BLT model from HuggingFace Hub format to unified format.
+
+ Args:
+ model_id: HuggingFace model ID (e.g., "facebook/blt-1b")
+ output_dir: Output directory for unified model
+ config_name: Name for unified config file
+ weights_name: Name for unified weights file
+ safe_serialization: Whether to use safetensors format
+ cache_dir: Cache directory for downloads
+ validate: Whether to validate the unified model
+ """
+ logger.info(f"Converting {model_id} to unified format...")
+
+ # Download model files
+ file_paths = download_model_files(model_id, cache_dir)
+
+ # Merge configurations
+ unified_config = merge_configurations(
+ file_paths["config"],
+ file_paths["entropy_params"]
+ )
+
+ # Merge weights
+ unified_weights = merge_weights(
+ file_paths["weights"],
+ file_paths["entropy_weights"]
+ )
+
+ # Validate if requested
+ if validate:
+ validate_unified_model(unified_config, unified_weights)
+
+ # Create output directory
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Save unified configuration
+ config_path = os.path.join(output_dir, config_name)
+ with open(config_path, 'w') as f:
+ json.dump(unified_config, f, indent=2)
+ logger.info(f"Unified config saved to {config_path}")
+
+ # Save unified weights
+ if safe_serialization and weights_name.endswith('.bin'):
+ weights_name = weights_name.replace('.bin', '.safetensors')
+ elif not safe_serialization and weights_name.endswith('.safetensors'):
+ weights_name = weights_name.replace('.safetensors', '.bin')
+
+ weights_path = os.path.join(output_dir, weights_name)
+ if safe_serialization:
+ save_file(unified_weights, weights_path)
+ else:
+ torch.save(unified_weights, weights_path)
+ logger.info(f"Unified weights saved to {weights_path}")
+
+ # Create tokenizer config
+ create_tokenizer_config(output_dir, unified_config)
+
+ # Create README
+ readme_path = os.path.join(output_dir, "README.md")
+ with open(readme_path, 'w') as f:
+ f.write(f"""# Unified BLT Model
+
+This model was converted from {model_id} to unified format compatible with modeling_blt_wip.py.
+
+## Files
+
+- `{config_name}`: Unified configuration (main config + entropy model params)
+- `{weights_name}`: Unified weights (main model + entropy model weights with "patcher." prefix)
+- `tokenizer_config.json`: Tokenizer configuration
+
+## Usage
+
+```python
+import torch
+import json
+from modeling_blt_wip import BLTModel, BLTConfig
+
+# Load configuration
+with open('{config_name}', 'r') as f:
+ config_dict = json.load(f)
+
+config = BLTConfig(**config_dict)
+
+# Load model
+model = BLTModel(config)
+
+# Load weights
+if '{weights_name}'.endswith('.safetensors'):
+ from safetensors.torch import load_file
+ state_dict = load_file('{weights_name}')
+else:
+ state_dict = torch.load('{weights_name}', map_location='cpu')
+
+model.load_state_dict(state_dict, strict=False)
+```
+
+## Original Model
+
+Converted from: {model_id}
+""")
+
+ logger.info(f"Conversion completed! Unified model saved to: {output_dir}")
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Convert BLT models from HuggingFace Hub format to unified format",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Convert facebook/blt-1b to unified format
+  python convert_hf_blt_original_to_unified.py \\
+ --model_id facebook/blt-1b \\
+ --output_dir ./unified_blt_1b
+
+ # Convert with custom file names
+  python convert_hf_blt_original_to_unified.py \\
+ --model_id facebook/blt-7b \\
+ --output_dir ./unified_blt_7b \\
+ --config_name unified_config.json \\
+ --weights_name unified_model.safetensors
+
+ # Convert without validation
+  python convert_hf_blt_original_to_unified.py \\
+ --model_id facebook/blt-1b \\
+ --output_dir ./my_blt \\
+ --no_validate
+ """
+ )
+
+ # Required arguments (with defaults for debugging)
+ parser.add_argument(
+ "--model_id",
+ type=str,
+ default="facebook/blt-1b",
+ help="HuggingFace model ID (e.g., facebook/blt-1b)"
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="./unified_blt_debug",
+ help="Output directory for unified model"
+ )
+
+ # Optional arguments
+ parser.add_argument(
+ "--config_name",
+ type=str,
+ default="config.json",
+ help="Name for unified config file (default: config.json)"
+ )
+ parser.add_argument(
+ "--weights_name",
+ type=str,
+ default="pytorch_model.bin",
+ help="Name for unified weights file (default: pytorch_model.bin)"
+ )
+ parser.add_argument(
+ "--safe_serialization",
+ action="store_true",
+ default=True,
+ help="Use safetensors format for weights (default: True)"
+ )
+ parser.add_argument(
+ "--no_safe_serialization",
+ dest="safe_serialization",
+ action="store_false",
+ help="Use .bin format instead of safetensors"
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="Cache directory for downloads"
+ )
+ parser.add_argument(
+ "--no_validate",
+ dest="validate",
+ action="store_false",
+ default=True,
+ help="Skip model validation"
+ )
+ parser.add_argument(
+ "--debug",
+ action="store_true",
+ default=True, # Enable debug by default for easier debugging
+ help="Enable debug logging"
+ )
+
+ args = parser.parse_args()
+
+ # Setup logging
+ if args.debug:
+ transformers_logging.set_verbosity_debug()
+ logging.basicConfig(level=logging.DEBUG)
+
+ # Run conversion
+ try:
+ convert_hf_blt_to_unified(
+ model_id=args.model_id,
+ output_dir=args.output_dir,
+ config_name=args.config_name,
+ weights_name=args.weights_name,
+ safe_serialization=args.safe_serialization,
+ cache_dir=args.cache_dir,
+ validate=args.validate,
+ )
+ except Exception as e:
+ logger.error(f"Conversion failed: {e}")
+ raise
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/backup_blt_wip_backup/modeling_blt_wip.py b/backup_blt_wip_backup/modeling_blt_wip.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe846fc0ab9f7f5e794e32fd3809aea26451676c
--- /dev/null
+++ b/backup_blt_wip_backup/modeling_blt_wip.py
@@ -0,0 +1,1836 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import logging
+import os
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
+
+from ...modeling_utils import PreTrainedModel
+from .configuration_blt import (
+ BLTConfig,
+ PatchingModeEnum,
+)
+
+
+SEP = " "
+BOS_ID: int = 1
+EOS_ID: int = 2
+PAD_ID: int = -1
+BOE_ID: int = 0
+BPE_ID: int = 3
+OFFSET: int = 4
+
+BYTE_UNITS: int = 256
+
+RMSNorm = nn.RMSNorm
+
+logger = logging.getLogger()
+
+flex_attention_comp = flex_attention
+
+
+def causal_mask(b, h, q_idx, kv_idx):
+ return q_idx >= kv_idx
+
+
+def create_causal_mask(
+ seqlen,
+ attn_impl: str,
+ attn_bias_type: str | None,
+ *,
+ eos_id: int | None = None,
+ tokens: torch.Tensor | None = None,
+ sliding_window: int | None = None,
+):
+ if attn_impl == "sdpa":
+ BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))
+
+ if attn_bias_type == "causal":
+ return "causal"
+
+ if BLT_SUPPRESS_ATTN_ERROR == 1:
+ return "causal"
+ else:
+ raise ValueError(
+ "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
+ )
+ elif attn_impl == "flex_attention":
+ return create_block_mask(causal_mask, None, None, seqlen, seqlen)
+ else:
+ raise NotImplementedError(f"Attention {attn_impl} with {sliding_window} sliding window not implemented")
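+# Note: with attn_impl="sdpa" and attn_bias_type="causal" this simply returns the string
+# "causal", which BLTAttention turns into is_causal=True for scaled_dot_product_attention;
+# with attn_impl="flex_attention" a full BlockMask is built via create_block_mask.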
+
+
+def cross_entropy(pred, target, **kwargs):
+ return F.nll_loss(
+ F.log_softmax(pred.flatten(end_dim=-2).float(), -1),
+ target.flatten(end_dim=-1),
+ **kwargs,
+ )
+
+
+def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+ assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
+ bs, slen, n_kv_heads, head_dim = x.shape
+ if n_rep == 1:
+ return x
+ return (
+ x[:, :, :, None, :]
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+ )
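+# e.g. (illustrative shapes) x of shape (2, 16, 4, 64) with n_rep=2 becomes (2, 16, 8, 64),
+# duplicating each KV head so grouped query heads can attend to it.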
+
+
+def precompute_freqs_cis(
+ dim: int,
+ end: int,
+ theta: float = 10000.0,
+ rope_use_fp32_in_outer_product: bool = False,
+):
+ """
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
+ The returned tensor contains complex values in complex64 data type.
+
+ Args:
+ dim (int): Dimension of the frequency tensor.
+ end (int): End index for precomputing frequencies.
+        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+        rope_use_fp32_in_outer_product (bool, optional): Whether to compute the position/frequency outer product in fp32. Defaults to False.
+
+ Returns:
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
+ """
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ t = torch.arange(end, device=freqs.device)
+ if rope_use_fp32_in_outer_product:
+ t = t.to(torch.float32)
+
+ freqs = torch.outer(t, freqs).float()
+
+ cos, sin = freqs.cos(), freqs.sin()
+
+ return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
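+# Shape note: for head_dim D and end L, the returned tensor has shape (L, D // 2, 2, 2),
+# i.e. one 2x2 rotation matrix [[cos, -sin], [sin, cos]] per (position, frequency) pair.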
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
+ """
+ Reshape frequency tensor for broadcasting it with another tensor.
+
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+ Args:
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
+ seq_dim (int): Sequence dimension index.
+
+ Returns:
+ torch.Tensor: Reshaped frequency tensor.
+ """
+ ndim = x.ndim
+ assert 0 <= seq_dim < ndim
+ assert freqs_cis.shape == (
+ x.shape[seq_dim],
+ x.shape[-3],
+ 2,
+ 2,
+ ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
+ shape = [d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])] + [2, 2]
+ return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ seq_dim: int,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
+ xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, seq_dim).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
+ xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
+ xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
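+# Each consecutive pair of head-dim channels is rotated by the matching 2x2 matrix in
+# freqs_cis, so xq_out and xk_out keep the input shape (B, S, H, D).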
+
+
+# Rotary embedding as in xformers; see if the torchtrain implementation is better. It might also be useful to make it work with batch*seqlen collapsed.
+class RotaryEmbedding(torch.nn.Module):
+ """
+ RotaryEmbedding Module
+ """
+
+ def __init__(
+ self,
+ theta: float,
+ head_dim: int,
+ max_seqlen: int = 1024,
+ rope_use_fp32_in_outer_product: bool = False,
+ ):
+ super().__init__()
+
+ self.theta = theta
+ self.head_dim = head_dim
+ self.max_seqlen = max_seqlen
+ self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
+
+ self.register_buffer(
+ "freqs_cis",
+ precompute_freqs_cis(
+ dim=head_dim,
+ end=max_seqlen,
+ theta=theta,
+ rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
+ ),
+ persistent=False,
+ )
+
+
+ def forward(self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None):
+ """
+        Return freqs_cis corresponding to consecutive seqlen positions or to the given tok_idx positions.
+        Args:
+            seqlen (int): Contiguous sequence length.
+            tok_idx (torch.Tensor[int]): Position indices of each token; overrides seqlen if provided.
+
+        Returns:
+            torch.Tensor: The freqs_cis entries for the requested positions.
+ """
+        test = (seqlen is not None) or (tok_idx is not None)
+        assert test, "Should provide at least seqlen or tok_idx"
+ if tok_idx is not None:
+ return self.freqs_cis[tok_idx]
+ elif seqlen is not None:
+ return self.freqs_cis[0:seqlen]
+
+
+class BLTAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ head_dim: int,
+ n_heads: int,
+ n_kv_heads: int,
+ rope_theta: float,
+ ):
+ super().__init__()
+
+ self.dim = dim
+ self.head_dim = head_dim
+ self.rope_theta = rope_theta
+
+ self.n_heads = n_heads
+ self.n_kv_heads = n_kv_heads
+ self.heads_per_group = self.n_heads // self.n_kv_heads
+
+ self.wq = nn.Linear(
+ dim,
+ n_heads * head_dim,
+ bias=False,
+ )
+ self.wk = nn.Linear(
+ dim,
+ n_kv_heads * head_dim,
+ bias=False,
+ )
+ self.wv = nn.Linear(
+ dim,
+ n_kv_heads * head_dim,
+ bias=False,
+ )
+
+ self.wo = nn.Linear(
+ n_heads * head_dim,
+ dim,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ freq_cis: torch.Tensor,
+ tok_idx: Optional[torch.Tensor] = None,
+ mask: Optional[Union[BlockMask, str]] = None,
+ attn_impl: str = "sdpa",
+ ) -> torch.Tensor:
+ # B S D
+ bsz, seq_len, dim = x.shape
+ xq = self.wq(x.view_as(x))
+ xk = self.wk(x.view_as(x))
+ xv = self.wv(x.view_as(x))
+
+ output_shape = xq.shape
+ # B S D -> B S H D
+ xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
+ xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+ xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+
+ xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
+
+ # This condition helps us be easily compatible
+ # with inference by adding a pluggable KVCache
+ if hasattr(self, "kv_cache"):
+ xk, xv = self.kv_cache.update(xk, xv, tok_idx)
+
+ xk = repeat_kv(xk, self.heads_per_group, dim=2)
+ xv = repeat_kv(xv, self.heads_per_group, dim=2)
+
+ if attn_impl == "flex_attention":
+ assert mask is None or isinstance(mask, BlockMask)
+ xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
+ output = flex_attention_comp(xq, xk, xv, block_mask=mask)
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
+
+ elif attn_impl == "sdpa":
+ xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
+ assert mask is None or isinstance(mask, (str, torch.Tensor))
+ is_causal = (mask == "causal") if isinstance(mask, str) else False
+ mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None
+ output = F.scaled_dot_product_attention(
+ xq,
+ xk,
+ xv,
+ is_causal=is_causal,
+ attn_mask=mask,
+ )
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
+ else:
+ raise NotImplementedError(f"Attention implementation {attn_impl} not supported")
+
+ output_reshaped = output.reshape(output_shape)
+
+ output = self.wo(output_reshaped)
+
+ return output
+
+
+
+
+class BLTMLP(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ multiple_of: int,
+ ffn_dim_multiplier: Optional[float],
+ mp_size: int = 1,
+ ):
+ super().__init__()
+
+ hidden_dim = int(2 * hidden_dim / 3)
+ if ffn_dim_multiplier is not None:
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
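+        # SwiGLU sizing: shrink to 2/3 of the incoming hidden_dim, apply the optional multiplier,
+        # then round up to a multiple of `multiple_of`; e.g. dim=512 -> 4*512=2048 -> 1365 -> 1536
+        # with multiple_of=256 (illustrative numbers).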
+ assert hidden_dim % mp_size == 0
+
+ self.dim = dim
+ self.hidden_dim = hidden_dim
+
+ self.w1 = nn.Linear(
+ dim,
+ hidden_dim,
+ bias=False,
+ )
+ self.w3 = nn.Linear(
+ dim,
+ hidden_dim,
+ bias=False,
+ )
+ self.w2 = nn.Linear(
+ hidden_dim,
+ dim,
+ bias=False,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # B S D
+ x1 = self.w1(x.view_as(x))
+ x3 = self.w3(x.view_as(x))
+ output = self.w2(F.silu(x1) * x3)
+ return output
+
+
+
+
+class BLTTransformerLayer(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+
+ # Extract parameters from dictionary
+ dim = args["dim"]
+ n_heads = args["n_heads"]
+ head_dim = args["head_dim"]
+ n_kv_heads = args["n_kv_heads"]
+ rope_theta = args["rope_theta"]
+ multiple_of = args["multiple_of"]
+ ffn_dim_multiplier = args["ffn_dim_multiplier"]
+ norm_eps = args["norm_eps"]
+
+ assert (head_dim is not None) or (n_heads is not None), "Should specify at least head_dim or n_heads"
+ self.head_dim = head_dim or dim // n_heads
+ self.n_heads = n_heads or dim // head_dim
+ self.n_kv_heads = n_kv_heads or self.n_heads
+
+ assert n_heads % self.n_kv_heads == 0
+ assert dim % n_heads == 0
+
+ self.attention = BLTAttention(
+ dim=dim,
+ head_dim=self.head_dim,
+ n_heads=self.n_heads,
+ n_kv_heads=self.n_kv_heads,
+ rope_theta=rope_theta,
+ )
+ self.feed_forward = BLTMLP(
+ dim=dim,
+ hidden_dim=4 * dim,
+ multiple_of=multiple_of,
+ ffn_dim_multiplier=ffn_dim_multiplier,
+ )
+ self.attention_norm = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm = RMSNorm(dim, eps=norm_eps)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ freq_cis: torch.Tensor,
+ tok_idx: Optional[torch.Tensor] = None,
+ mask: Optional[Union[BlockMask, str]] = None,
+ attn_impl: str = "sdpa",
+ ) -> torch.Tensor:
+ norm_x = self.attention_norm(x)
+ attn_out = self.attention(
+ norm_x,
+ freq_cis,
+ tok_idx=tok_idx,
+ mask=mask,
+ attn_impl=attn_impl,
+ )
+ h = x + attn_out
+ h_norm = self.ffn_norm(h)
+ out = h + self.feed_forward(h_norm)
+ return out
+
+
+
+
+def rightpad(seq, pad_id, max_len):
+ return seq + [pad_id] * (max_len - len(seq))
+
+
+def check_non_zero_after_zero(tensor):
+ zero_mask = tensor == 0
+ shifted_mask = torch.cat(
+ [
+ torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
+ zero_mask[:, :-1],
+ ],
+ dim=1,
+ )
+ non_zero_after_zero = (tensor != 0) & shifted_mask
+ return non_zero_after_zero.any()
+
+
+def fill_tokens(tokens, patch_size, fill_id):
+ batch_size, seq_len = tokens.shape
+ if seq_len % patch_size == 0:
+ return tokens
+ else:
+ remaining = patch_size - seq_len % patch_size
+ final_padding = tokens.new(batch_size, remaining).fill_(fill_id)
+ return torch.cat((tokens, final_padding), dim=1)
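+# e.g. tokens of shape (B, 10) with patch_size=4 are right-padded with fill_id to shape (B, 12).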
+
+
+def rolling_polynomial_hash(t, hash_func_nb: int = 0):
+ primes = [
+ 1000000007,
+ 5915587277,
+ 1500450271,
+ 3267000013,
+ 5754853343,
+ 4093082899,
+ 9576890767,
+ 3628273133,
+ 2860486313,
+ 5463458053,
+ 3367900313,
+ ]
+ prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
+ prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
+ return torch.sum(t * prime_powers, dim=-1)
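+# This computes sum_i t[..., i] * prime**i over the last dimension, i.e. a polynomial hash of
+# each byte group under one of several fixed primes (int64 arithmetic may wrap for larger
+# groups, which is acceptable for hashing purposes).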
+
+
+def byte_group_hash_function(x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000):
+ """
+    Returns a rolling hash of the input x, mapped to a value in the range [0, max_hash).
+
+    Expects: x of shape (batch_size, seq_len) with values that are ids in the token vocab.
+    Returns: a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash).
+
+    Note: the choice of max_hash can make a big difference in the number of collisions.
+ """
+ with torch.no_grad():
+ bs, seq_len = x.shape
+ prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
+ x = torch.cat([prefix, x], dim=1)
+ windows = x.unfold(1, group_size, 1)
+ # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
+ hashes = rolling_polynomial_hash(windows, hash_func_nb)
+ hash_values_range = hashes % max_hash
+ hash_values_range.requires_grad = False
+ return hash_values_range
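+# Illustrative call: for x of shape (batch_size, seq_len) and group_size=3, every position is
+# hashed together with its two preceding bytes (zero-padded at the start of the sequence),
+# yielding ids in [0, max_hash).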
+
+
+def create_patch_mask_from_ids(patch_ids, num_patches, window=None, patches_as_queries=False):
+ """
+    Creates a boolean tensor of shape [bs, seq_len, num_patches] (transposed when patches_as_queries=True).
+    With window=None, element (i, j, k) is True iff the patch id at position (i, j) equals k; with a
+    window, it is True iff k <= patch_id < k + window.
+ Args:
+ patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids.
+ num_patches (int): Total number of patches.
+ window (int): If not None, only considers patches within a window of size window.
+ patches_as_queries (bool): If True, the patches are used as queries
+ Returns:
+ torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask.
+ """
+ bs, seq_len = patch_ids.shape
+ if not patches_as_queries:
+ q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
+ kv_ids = (
+ torch.arange(num_patches, device=patch_ids.device)
+ .unsqueeze(0)
+ .unsqueeze(0)
+ .expand(bs, seq_len, num_patches)
+ )
+ else:
+ kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
+ q_ids = (
+ torch.arange(num_patches, device=patch_ids.device)
+ .unsqueeze(0)
+ .unsqueeze(-1)
+ .expand(bs, num_patches, seq_len)
+ )
+ if window is None:
+ mask = q_ids == kv_ids
+ else:
+ mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window)
+ return mask
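+# Small illustrative example: patch_ids = [[0, 0, 1, 1, 2]] with num_patches=3 and window=None
+# yields a (1, 5, 3) boolean mask in which each byte position is True only in the column of
+# its own patch.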
+
+
+def cross_attn_mask(
+ patch_ids,
+ patch_lengths,
+ N,
+ patches_as_queries=False,
+ cross_attn_k=1,
+ window=None,
+ block_mask=True,
+):
+ bs = patch_ids.shape[0]
+ with torch.no_grad():
+ # Create the patch mask
+ cross_mask = create_patch_mask_from_ids(
+ patch_ids,
+ patch_lengths.shape[1],
+ window=window,
+ patches_as_queries=patches_as_queries,
+ ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
+ q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N
+ kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k
+ assert cross_mask.shape == (
+ bs,
+ q_len,
+ kv_len,
+ ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
+ if block_mask:
+
+ def patch_mask(b, h, q_idx, kv_idx):
+ return cross_mask[b, q_idx, kv_idx]
+
+ block_mask = create_block_mask(
+ patch_mask,
+ B=bs,
+ H=None,
+ Q_LEN=q_len,
+ KV_LEN=kv_len,
+ _compile=True,
+ )
+ return block_mask
+ else:
+ return torch.where(cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))).unsqueeze(
+ 1
+ ) # [bs, 1, q_len, kv_len]
+
+
+def get_blt_input(
+ tokens: torch.Tensor,
+ enforce_patch_size_multiple: bool,
+ nb_boe: torch.Tensor,
+ patch_size: int,
+ boe_id: int,
+):
+ """
+ This function returns X_et, X_gt and X_dt, the encoder, global, and decoder
+ tokens respectively.
+
+ Consider the input and target sequences:
+ X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13]
+ Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14]
+ with patch_size=4
+
+ Note 1: that there will be no special tokens introduced at the patch level.
+ Note 2: X_e needs to be trimmed to be passed to Global
+
+ Current without boe:
+ X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]]
+ X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch
+ X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
+ Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]
+
+ --> lag fix:
+ X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]]
+ X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]]
+ X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
+ Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]
+
+ Dynamic (current):
+ X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos]
+ Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
+
+ entropy patching:
+ input: 7, bos, 9, 10
+ pred (high entropy): eos, 8, 10, eos
+
+ X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos]
+ X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]]
+ X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]]
+ Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
+
+ --> lag fix no boe (force single byte first patch):
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
+ X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
+ X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
+ Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
+
+ input: 4, 7, bos, 9, 10
+ pred (high entropy): 5, eos, 8, 10, eos
+
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
+ X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
+ X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
+ Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
+
+ Handle the last byte properly.
+    patch_lengths = [1, 1, 3, 2, 2, 1, 2, 2, 1]
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
+ X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch
+ X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]]
+ Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]]
+
+
+ bpe delim
+ X_et = [[3,4,5,6,7,,eos,bos,,8,9,,10,,eos,bos,11,12]
+ X_g = [[3], [4,5,6,7,], [eos,bos,], ..
+ X_dt = [[3,4,5,6,7], [,eos,bos], [,bos,8], ..
+ Y = [4,5,6,7,, eos,bos, 8,9,, ..
+
+
+ Note 1: that there will be no special tokens introduced at the patch level.
+ Note 2: X_e needs to be trimmed to be passed to Global
+ """
+ batch_size, seq_len = tokens.shape
+ local_encoder_tokens = tokens
+ local_decoder_tokens = tokens
+
+ if nb_boe > 0:
+ padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id)
+ local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1)
+ # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id)
+
+ # create global tokens, contains boe tokens and eos
+ # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
+ # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size)
+ # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:]
+ # global_tokens += global_tokens.eq(0).int() * boe_id
+ # TODO: fix this when we want to use block causal in the global.
+
+ if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0:
+ local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
+
+ return local_encoder_tokens, None, local_decoder_tokens
+
+
+class LocalModelBase(nn.Module):
+ def __init__(self, config: BLTConfig, component_type: str = "encoder"):
+ super().__init__()
+
+ # Store config for later use
+ self.config = config
+
+ # Use component-specific dimensions
+ if component_type == "encoder":
+ self.dim = config.dim_local_encoder
+ self.n_layers = config.n_layers_local_encoder
+ self.n_heads = config.n_heads_local_encoder
+ self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen
+ self.attn_bias_type = "local_block_causal"
+ self.sliding_window = config.local_attention_window_len
+ elif component_type == "decoder":
+ self.dim = config.dim_local_decoder
+ self.n_layers = config.n_layers_local_decoder
+ self.n_heads = config.n_heads_local_decoder
+ self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen
+ self.attn_bias_type = "local_block_causal"
+ self.sliding_window = config.local_attention_window_len
+ else:
+ raise ValueError(f"Unknown component_type: {component_type}")
+
+ self.dropout = config.dropout
+ self.vocab_size = config.vocab_size + config.pm_size
+ self.patch_size = config.patch_size
+
+ self.attn_impl = config.attn_impl
+ self.use_rope = config.use_rope
+ self.init_std_factor = config.init_std_factor
+ self.init_base_std = config.init_base_std
+ self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None)
+ self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None)
+ self.cross_attn_k = getattr(config, "cross_attn_k", None)
+ self.eos_id = config.eos_token_id
+
+ self.boe_id = BOE_ID
+
+ # Initialize cross attention layers as None (will be set by subclasses if needed)
+ self.cross_attn_layers = None
+
+ # Create parameter dict for BLTTransformerLayers
+ layer_params = {
+ "dim": self.dim,
+ "n_heads": self.n_heads,
+ "head_dim": config.head_dim,
+ "n_kv_heads": getattr(config, "n_kv_heads", None),
+ "rope_theta": config.rope_theta,
+ "multiple_of": getattr(config, "multiple_of", 256),
+ "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None),
+ "norm_eps": config.norm_eps,
+ }
+
+ self.layers = nn.ModuleList([BLTTransformerLayer(layer_params) for _ in range(self.n_layers)])
+
+ if not self.use_rope:
+ self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length
+ else:
+ self.rope = RotaryEmbedding(
+ theta=config.rope_theta,
+ head_dim=config.head_dim or self.dim // self.n_heads,
+ max_seqlen=self.max_seqlen,
+ rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product,
+ )
+ self.pos_embeddings = None
+
+ # Set dimension-specific embedding dimensions
+ if component_type == "encoder":
+ self.dim_token_emb = config.encoder_dim_token_emb
+ self.dim_patch_emb = config.encoder_dim_patch_emb
+ elif component_type == "decoder":
+ self.dim_token_emb = config.decoder_dim_token_emb
+ self.dim_patch_emb = config.dim_global
+
+ self.token_embedding_projection = (
+ nn.Linear(self.dim_token_emb, self.dim, bias=False)
+ if self.dim_token_emb is not None and self.dim_token_emb != self.dim
+ else None
+ )
+
+ self.patch_embedding_projection = self._create_patch_projection(config)
+
+ def _should_create_patch_projection(self, config: BLTConfig):
+ dimension_mismatch = self.dim_patch_emb is not None and self.dim_patch_emb != self.dim
+
+ # Check cross attention conditions
+ cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or (
+ config.cross_attn_decoder and config.cross_attn_init_by_pooling
+ )
+
+ return dimension_mismatch or cross_attn_conditions
+
+ def _create_patch_projection(self, config):
+ if not self._should_create_patch_projection(config):
+ return None
+
+ output_dim = self.dim_token_emb * (self.cross_attn_k or 1)
+
+ return nn.Linear(
+ in_features=self.dim_patch_emb,
+ out_features=output_dim,
+ bias=False,
+ )
+
+ def apply_embedding(self, tokens, embeds):
+ if embeds is not None:
+ return embeds
+ else:
+ return self.tok_embeddings(tokens)
+
+
+
+
+class LocalEncoder(LocalModelBase):
+ def __init__(self, config: BLTConfig):
+ super().__init__(config, component_type="encoder")
+
+ self.apply_transformer = config.use_local_encoder_transformer
+ self.downsampling_by_pooling = config.downsampling_by_pooling
+ self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None
+ self.cross_attn_encoder = config.cross_attn_encoder
+ self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder
+ self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling
+ self.cross_attn_nheads = config.cross_attn_nheads
+
+ self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim)
+
+ if self.cross_attn_encoder:
+ self.cross_attn_layers = torch.nn.ModuleList()
+ layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1
+ for _ in range(layers_to_add):
+ self.cross_attn_layers.append(
+ BLTCrossAttention(
+ dim=self.dim,
+ head_dim=self.dim // self.cross_attn_nheads,
+ n_heads=self.cross_attn_nheads,
+ n_kv_heads=self.cross_attn_nheads,
+ norm_eps=config.norm_eps,
+ )
+ )
+
+ def apply_embedding(self, tokens, embeds):
+ if embeds is not None:
+ assert self.expects_hash_embeddings, "Not expecting embeddings to be passed."
+ return embeds
+ else:
+ return self.tok_embeddings(tokens)
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ embeds: Optional[torch.Tensor] = None,
+ patch_embeds: Optional[torch.Tensor] = None,
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
+ cross_mask: Optional[torch.Tensor] = None,
+ num_patches: Optional[int] = None,
+ patch_ids: Optional[torch.Tensor] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ """ """
+ bs, seqlen = tokens.shape
+ if mask is None:
+ mask = create_causal_mask(
+ seqlen,
+ self.attn_impl,
+ "local_block_causal",
+ sliding_window=self.sliding_window,
+ tokens=tokens,
+ eos_id=self.eos_id,
+ )
+
+ h = self.apply_embedding(tokens, embeds)
+ freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
+
+ h = F.dropout(h, p=self.dropout, training=self.training)
+
+ for i, layer in enumerate(self.layers):
+ h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
+ # check if cross attention should be applied to either all layer or only the last layer
+ if self.cross_attn_encoder and (i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder):
+ # apply pooling and project
+ if self.cross_attn_init_by_pooling and patch_embeds is None:
+ patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids)
+ if self.patch_embedding_projection is not None:
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
+ patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim)
+
+ layer_idx = i if self.cross_attn_all_layers_encoder else 0
+ patch_embeds_cross = self.cross_attn_layers[layer_idx](
+ x=patch_embeds,
+ kv=h,
+ mask=cross_mask,
+ )
+ patch_embeds = patch_embeds + patch_embeds_cross
+
+ h_residual = patch_embeds if self.cross_attn_encoder else None
+ return (h, h_residual), cache
+
+ def patch_reduce(self, h, max_num_patches, reduction, patch_ids):
+ """
+ Reduce variable length patches to single embedding per patch
+ Note: this works with variable number of patches for different sequences in the batch
+ It handles variable length patches by assuming that patch_lengths will be 0 for any
+ extra patches on the *right*. Since there can be a variable number of patches
+ this function also return the number of patches for each sequence in the batch.
+ Any embeddings on the right that are not allocated to a patch
+ (i.e. if the sum(patch_lengths[i]) < seq_len for any i)
+ will be sent to a dummy patch, which is trimmed before returning.
+ """
+ bs, seq_len, emb_dim = h.shape
+
+ patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
+
+ reduced_embs = torch.zeros((bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device)
+ reduced_embs = reduced_embs.scatter_reduce(
+ src=h,
+ dim=1,
+ index=patch_ids,
+ reduce=reduction,
+ include_self=False,
+ )
+ reduced_embs = reduced_embs[:, :max_num_patches, :]
+
+ return reduced_embs
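+
+ # Illustrative sketch (an assumption about usage, not original code): with h of shape
+ # [1, 5, emb_dim], patch_ids = [[0, 0, 1, 1, 2]], max_num_patches = 3 and reduction = "amax",
+ # patch_reduce returns a [1, 3, emb_dim] tensor whose patch-0 row is the element-wise max of
+ # token rows 0-1, patch-1 of rows 2-3, and patch-2 of row 4.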
+
+
+class LocalDecoder(LocalModelBase):
+ def __init__(self, config: BLTConfig):
+ super().__init__(config, component_type="decoder")
+
+ # Model configuration flags
+ self.cross_attn_decoder = config.cross_attn_decoder
+ self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder
+ self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling
+ self.cross_attn_nheads = config.cross_attn_nheads
+
+ self.norm = RMSNorm(self.dim, eps=config.norm_eps)
+
+ if self.cross_attn_decoder:
+ self.cross_attn_layers = torch.nn.ModuleList()
+ layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1
+ for _ in range(layers_to_add):
+ self.cross_attn_layers.append(
+ BLTCrossAttention(
+ dim=self.dim,
+ head_dim=self.dim // self.cross_attn_nheads,
+ n_heads=self.cross_attn_nheads,
+ n_kv_heads=self.cross_attn_nheads,
+ norm_eps=config.norm_eps,
+ )
+ )
+
+ self.output = nn.Linear(
+ self.dim,
+ config.vocab_size,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ embeds: Optional[torch.Tensor],
+ patch_embeds: Optional[torch.Tensor] = None,
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
+ cross_mask: Optional[torch.Tensor] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ bs, seqlen = tokens.shape
+ assert embeds is not None, "Embeddings must be provided"
+
+ if mask is None:
+ mask = create_causal_mask(
+ seqlen,
+ self.attn_impl,
+ "local_block_causal",
+ sliding_window=self.sliding_window,
+ tokens=tokens,
+ eos_id=self.eos_id,
+ )
+
+ h = embeds
+
+ if self.patch_embedding_projection is not None:
+ assert patch_embeds is not None, "Patch embeddings must be passed."
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
+ if self.cross_attn_k is not None:
+ patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim)
+
+ if patch_embeds is not None and not self.cross_attn_decoder:
+ h = h + patch_embeds
+
+ freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
+
+ h = F.dropout(h, p=self.dropout, training=self.training)
+ for i, layer in enumerate(self.layers):
+ if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder):
+ # Use cross attention to extract info from patch_embeds into h
+ h_cross = self.cross_attn_layers[i](
+ x=h,
+ kv=patch_embeds,
+ mask=cross_mask,
+ )
+ h = h + h_cross
+
+ h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
+
+ h_preds = self.norm(h)
+ h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
+ h_preds = self.output(h_preds)
+ h_preds = h_preds.float()
+ return h_preds, cache
+
+
+class BLTCrossAttention(nn.Module):
+ """
+ BLTCrossAttention block to attend to the encoder states from the decoder.
+ Rope is not supported.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ head_dim: int,
+ n_heads: int,
+ n_kv_heads: int,
+ norm_eps: float,
+ ):
+ super().__init__()
+
+ self.dim = dim
+ self.head_dim = head_dim
+
+ self.n_heads = n_heads
+ self.n_kv_heads = n_kv_heads
+ self.heads_per_group = self.n_heads // self.n_kv_heads
+
+ self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
+ self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
+
+ self.wq = nn.Linear(
+ dim,
+ n_heads * head_dim,
+ bias=False,
+ )
+ self.wk = nn.Linear(
+ dim,
+ n_kv_heads * head_dim,
+ bias=False,
+ )
+ self.wv = nn.Linear(
+ dim,
+ n_kv_heads * head_dim,
+ bias=False,
+ )
+
+ self.wo = nn.Linear(
+ n_heads * head_dim,
+ dim,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ kv: torch.Tensor,
+ mask: Optional[Union[BlockMask, str]] = None,
+ ) -> torch.Tensor:
+ # B S D
+ bsz, seq_len, _ = x.shape
+ _, slen_kv, _ = kv.shape
+ x_norm = self.cross_attn_norm_q(x)
+ kv = self.cross_attn_norm_kv(kv)
+
+ xq = self.wq(x_norm)
+ xk = self.wk(kv)
+ xv = self.wv(kv)
+
+ output_shape = xq.shape
+ # B S D -> B S H D
+ xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
+ xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
+ xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
+
+ xk = repeat_kv(xk, self.heads_per_group, dim=2)
+ xv = repeat_kv(xv, self.heads_per_group, dim=2)
+
+ # assert mask is None or isinstance(mask, BlockMask)
+ xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
+ # output = flex_attention_comp(xq, xk, xv, block_mask=mask)
+ is_causal = (mask == "causal") if isinstance(mask, str) else False
+ mask = mask if isinstance(mask, torch.Tensor) else None
+ # mask may be None (or the string "causal"), in which case there is no tensor to move
+ if mask is not None:
+ mask = mask.to(dtype=xq.dtype, device=xq.device)
+ output = F.scaled_dot_product_attention(
+ xq,
+ xk,
+ xv,
+ is_causal=is_causal,
+ attn_mask=mask,
+ )
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
+
+ output = self.wo(output.reshape(output_shape))
+
+ return x + output
+
+
+
+
+class GlobalTransformer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ # Store config for later use
+ self.config = config
+
+ self.dim = config.dim_global
+ self.init_base_std = config.init_base_std
+ self.attn_impl = config.attn_impl
+ self.attn_bias_type = config.attn_bias_type
+ self.init_std_factor = config.init_std_factor
+ self.max_seqlen = config.max_seqlen
+ self.rope_embeddings = RotaryEmbedding(
+ theta=config.rope_theta,
+ head_dim=config.head_dim or config.dim_global // config.n_heads_global,
+ max_seqlen=config.max_seqlen,
+ rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product,
+ )
+ # Handle both eos_id and eos_token_id for compatibility
+ self.eos_id = getattr(config, "eos_id", getattr(config, "eos_token_id", 2))
+
+ # Create parameter dict for BLTTransformerLayers
+ layer_params = {
+ "dim": self.dim,
+ "n_heads": config.n_heads_global,
+ "head_dim": config.head_dim,
+ "n_kv_heads": getattr(config, "n_kv_heads_global", None),
+ "rope_theta": config.rope_theta,
+ "multiple_of": getattr(config, "multiple_of", 256),
+ "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None),
+ "norm_eps": config.norm_eps,
+ }
+
+ self.layers = nn.ModuleList()
+ for _ in range(config.n_layers_global):
+ self.layers.append(BLTTransformerLayer(layer_params))
+
+ # GlobalTransformer specific attributes
+ self.dropout = config.dropout
+ self.dim_token_emb = config.global_dim_patch_emb
+
+ self.token_embedding_projection = None
+ if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim:
+ self.token_embedding_projection = nn.Linear(
+ config.global_dim_patch_emb,
+ config.dim_global,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ tok_idx: Optional[torch.Tensor] = None,
+ embeds: Optional[torch.Tensor] = None,
+ mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ bs, seqlen = tokens.shape
+
+ h = embeds
+
+ mask = (
+ mask
+ if mask is not None
+ else create_causal_mask(
+ seqlen,
+ self.attn_impl,
+ self.attn_bias_type,
+ tokens=tokens,
+ eos_id=self.eos_id,
+ )
+ )
+
+ if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
+ h = self.token_embedding_projection(h)
+
+ h = F.dropout(h, p=self.dropout, training=self.training)
+
+ freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
+
+ for i, layer in enumerate(self.layers):
+ h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
+
+ return h, cache
+
+
+
+
+def compute_hash_embeddings(
+ local_encoder_tokens: torch.Tensor,
+ local_encoder,
+ encoder_hash_tok_embedding: nn.ModuleList,
+ encoder_hash_byte_group_nb_functions: int,
+ encoder_hash_byte_group_size: list,
+ encoder_hash_byte_group_vocab: int,
+) -> torch.Tensor:
+ """
+ Compute embeddings using hash token embeddings.
+
+ Args:
+ local_encoder_tokens: Input tokens tensor
+ local_encoder: Encoder object with tok_embeddings method
+ encoder_hash_tok_embedding: ModuleList of hash token embeddings
+ encoder_hash_byte_group_nb_functions: Number of hash functions
+ encoder_hash_byte_group_size: List of byte group sizes
+ encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings
+
+ Returns:
+ torch.Tensor: Combined embeddings
+ """
+ if encoder_hash_tok_embedding is None:
+ return None
+
+ local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens)
+
+ i = 0
+ for func_nb in range(encoder_hash_byte_group_nb_functions):
+ for byte_group_size in encoder_hash_byte_group_size:
+ hash_ids = byte_group_hash_function(
+ local_encoder_tokens,
+ byte_group_size,
+ hash_func_nb=func_nb,
+ max_hash=encoder_hash_byte_group_vocab,
+ )
+ hash_tok_embedding = encoder_hash_tok_embedding[i]
+ local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
+ i += 1
+
+ assert i == len(encoder_hash_tok_embedding)
+ return local_encoder_embeds
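+
+# Illustrative usage sketch (hedged; the shapes and the `model` object below are assumptions,
+# not part of this file): with 2 hash functions and byte group sizes [3, 4], the loop above adds
+# 2 * 2 = 4 hash-embedding lookups on top of the plain byte-token embeddings, e.g.
+#   tokens = torch.randint(0, 260, (2, 128))          # [bs, seq_len] byte ids
+#   embeds = compute_hash_embeddings(
+#       local_encoder_tokens=tokens,
+#       local_encoder=model.local_encoder,
+#       encoder_hash_tok_embedding=model.encoder_hash_tok_embedding,
+#       encoder_hash_byte_group_nb_functions=2,
+#       encoder_hash_byte_group_size=[3, 4],
+#       encoder_hash_byte_group_vocab=30000,
+#   )                                                  # [bs, seq_len, encoder dim]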
+
+
+class BLTPreTrainedModel(PreTrainedModel):
+ config_class = BLTConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["BLTTransformerLayer", "LocalEncoder", "LocalDecoder", "GlobalTransformer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = False # BLT uses its own attention implementation
+ _supports_sdpa = True
+ _supports_cache_class = False
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Linear):
+ std = getattr(module, '_custom_std', module.in_features ** (-0.5))
+
+ nn.init.trunc_normal_(
+ module.weight,
+ mean=0.0,
+ std=std,
+ a=-3 * std,
+ b=3 * std,
+ )
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ elif isinstance(module, nn.Embedding):
+ std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5))
+
+ nn.init.trunc_normal_(
+ module.weight,
+ mean=0.0,
+ std=std,
+ a=-3 * std,
+ b=3 * std,
+ )
+
+ elif isinstance(module, (nn.RMSNorm, nn.LayerNorm)):
+ nn.init.ones_(module.weight)
+ # nn.RMSNorm has no bias attribute, so guard with getattr
+ if getattr(module, "bias", None) is not None:
+ nn.init.zeros_(module.bias)
+
+ elif isinstance(module, RotaryEmbedding):
+ module.freqs_cis[...] = precompute_freqs_cis(
+ dim=module.head_dim,
+ end=module.max_seqlen,
+ theta=module.theta,
+ rope_use_fp32_in_outer_product=module.rope_use_fp32_in_outer_product,
+ )
+
+ elif isinstance(module, BLTModel):
+ if module.encoder_hash_tok_embedding is not None:
+ emb_std = module.local_encoder.dim ** (-0.5)
+ for emb in module.encoder_hash_tok_embedding:
+ emb._custom_std = emb_std
+
+ elif isinstance(module, (LocalEncoder, LocalDecoder)):
+ if module.token_embedding_projection is not None:
+ module.token_embedding_projection._custom_std = module.dim ** (-0.5)
+
+ if module.patch_embedding_projection is not None:
+ module.patch_embedding_projection._custom_std = module.dim_patch_emb ** (-0.5)
+
+ elif isinstance(module, GlobalTransformer):
+ if module.token_embedding_projection is not None:
+ module.token_embedding_projection._custom_std = module.dim_token_emb ** (-0.5)
+
+ elif isinstance(module, BLTPatcher):
+ emb_std = module.config.patcher_dim ** (-0.5)
+ module.tok_embeddings._custom_std = emb_std
+ module.output._custom_std = emb_std
+
+
+class BLTModel(BLTPreTrainedModel):
+ def __init__(self, config: BLTConfig):
+ super().__init__(config)
+
+ self.config = config
+ self.local_encoder = LocalEncoder(config)
+ self.global_transformer = GlobalTransformer(config)
+ self.local_decoder = LocalDecoder(config)
+
+ self.encoder_hash_tok_embedding = init_hash_embeddings(
+ config,
+ local_encoder_dim=self.local_encoder.dim,
+ encoder_hash_byte_group_size=config.encoder_hash_byte_group_size,
+ )
+
+ if config.patch_in_forward:
+ self.patcher = BLTPatcher(config)
+ self.patcher.eval()
+ for param in self.patcher.parameters():
+ param.requires_grad = False
+ else:
+ self.patcher = None
+
+
+
+ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
+ """
+ Convert patch lengths to patch IDs for each token position.
+
+ For each token position in the sequence, determines which patch it belongs to.
+
+ Args:
+ patch_lengths: [batch_size, num_patches] - length of each patch
+ seq_len: total sequence length
+
+ Returns:
+ patch_ids: [batch_size, seq_len] - patch index for each token position
+
+ Example:
+ patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1
+ seq_len = 10
+ Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]]
+ # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3
+ """
+ batch_size, num_patches = patch_lengths.shape
+
+ # Create patch start positions: [0, 3, 5, 9] for the example above
+ patch_starts = torch.cat(
+ [
+ torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
+ patch_lengths.cumsum(dim=-1)[:, :-1], # cumsum without the final total
+ ],
+ dim=-1,
+ )
+
+ # For each token position, find which patch it belongs to
+ # by finding the rightmost patch start that's <= the position
+ token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1]
+
+ # Broadcasting: patch_starts[batch, patch] <= token_positions[position]
+ # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t
+ position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)
+
+ # Count how many patch starts are <= each position, then subtract 1 to get patch index
+ patch_ids = position_ge_patch_start.sum(dim=-1) - 1
+
+ return patch_ids
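+
+ # Worked step for the docstring example above (illustrative only): patch_starts is
+ # [[0, 3, 5, 9]], so for token position 4 the starts <= 4 are {0, 3}; the count is 2 and
+ # 2 - 1 = 1, i.e. position 4 falls in patch 1, matching the expected output.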
+
+ def _decoder_patch_ids_from_lengths(self, patch_lengths: torch.Tensor, nb_boe: int, seq_len: int) -> torch.Tensor:
+ """
+ Create decoder patch IDs by skipping the first encoder patch.
+
+ The decoder starts after the first patch (which contains BOE tokens),
+ so we need to map decoder positions to the remaining patches.
+
+ Args:
+ patch_lengths: [batch_size, num_patches] from encoder
+ nb_boe: number of beginning-of-example tokens in first patch
+ seq_len: decoder sequence length
+
+ Returns:
+ decoder_patch_ids: [batch_size, seq_len] mapping decoder positions to patch indices
+ """
+ # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens)
+ decoder_patch_lengths = patch_lengths[:, 1:]
+
+ # Create patch IDs for the decoder sequence using the remaining patches
+ return self._patch_ids_from_lengths(decoder_patch_lengths, seq_len)
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ patch_lengths: Optional[torch.Tensor] = None,
+ ):
+ # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings
+ # are no longer used in the final BLT model
+
+ bs, N = tokens.shape # Batch size and sequence length
+
+ # Get megabyte inputs
+ nb_boe = int(0 if self.config.patching_mode != "" else self.config.patch_size - 1)
+ local_encoder_tokens, _, local_decoder_tokens = get_blt_input(
+ tokens=tokens,
+ enforce_patch_size_multiple=False,
+ nb_boe=nb_boe,
+ patch_size=self.config.patch_size,
+ boe_id=BOE_ID,
+ )
+
+ # Patching
+ if patch_lengths is None:
+ # assert (
+ # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward
+ # ), "Patch in forward not enabled and no patch_lengths passed."
+
+ # PATCHER MODEL DEFINED
+ if self.config.patching_mode == PatchingModeEnum.entropy:
+ _, patch_lengths, _ = self.patcher(
+ local_encoder_tokens,
+ patch_size=self.config.patch_size,
+ include_next_token=True,
+ threshold=self.config.patching_threshold,
+ threshold_add=self.config.patching_threshold_add,
+ monotonicity=self.config.monotonicity,
+ max_patch_length=self.config.max_patch_length,
+ patching_batch_size=self.config.patching_batch_size,
+ device=self.config.patching_device,
+ )
+ else:
+ # self.config.patching_mode == PatchingModeEnum.byte
+ bs, seq_len = local_encoder_tokens.shape
+ seq_len_next_tok = seq_len + 1 # include_next_token=True
+ patch_lengths = torch.ones(
+ (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device
+ )
+
+ # Apply any processing to patch lengths
+ if self.config.max_patch_length is not None:
+ # TODO: avoid going back to a list here.
+ patch_lengths = [
+ BLTPatcher.split_large_numbers(pl, self.config.max_patch_length)
+ for pl in patch_lengths.tolist()
+ ]
+ max_len = max([len(pl) for pl in patch_lengths])
+ patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
+ patch_lengths = torch.tensor(
+ patch_lengths, dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device
+ )
+ assert not check_non_zero_after_zero(patch_lengths)
+ # Find the last non-zero column index using argmax on a reversed version of the tensor
+ last_non_zero_col_reversed = (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
+ # Slice the tensor up to the last non-zero column
+ patch_lengths = patch_lengths[:, : patch_lengths.shape[1] - last_non_zero_col_reversed]
+ else:
+ if nb_boe > 0:
+ patch_lengths[:, 0] += nb_boe
+
+ assert torch.min(patch_lengths) >= 0
+
+ # Generate patch IDs from patch_lengths
+ patch_ids = self._patch_ids_from_lengths(patch_lengths, local_encoder_tokens.shape[-1])
+ assert torch.max(patch_ids) + 1 <= torch.max((patch_lengths != 0).sum(dim=-1)), (
+ f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"
+ )
+
+ cross_attn_mask_enc = None
+ # Cross-attention encoder
+ if self.config.cross_attn_encoder:
+ cross_attn_mask_enc = cross_attn_mask(
+ patch_ids,
+ patch_lengths,
+ N,
+ patches_as_queries=True,
+ cross_attn_k=self.config.cross_attn_k,
+ window=self.config.cross_attn_window_encoder,
+ block_mask=self.config.cross_attn_use_flex_attention,
+ )
+
+ # Hashing and embedding
+ local_encoder_embeds = compute_hash_embeddings(
+ local_encoder_tokens=local_encoder_tokens,
+ local_encoder=self.local_encoder,
+ encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
+ encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions,
+ encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size,
+ encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab,
+ )
+
+ # NOTE: Frequency-based n-gram embeddings removed as per paper
+ # The final BLT model uses only hash-based n-gram embeddings
+
+ # Local encoder
+ (h_encoder, h_cross), cache_encoder = self.local_encoder(
+ tokens=local_encoder_tokens,
+ embeds=local_encoder_embeds,
+ patch_embeds=None,
+ cross_mask=cross_attn_mask_enc,
+ num_patches=patch_lengths.shape[1],
+ patch_ids=patch_ids,
+ )
+
+ # Downsampling
+ h = h_cross.view(bs, patch_lengths.shape[1], -1)
+
+ # Global transformer
+ global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(BOE_ID)
+ rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id)
+ eos_patch_ids = patch_ids[rows, cols]
+ global_tokens[rows, eos_patch_ids] = self.config.eos_token_id
+
+ h, _ = self.global_transformer(
+ embeds=h,
+ tokens=global_tokens,
+ )
+
+ # Unpatching
+ dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :]
+
+ # Generate decoder patch IDs
+ decoder_patch_ids = self._decoder_patch_ids_from_lengths(patch_lengths, nb_boe, local_decoder_tokens.shape[-1])
+ assert torch.max(decoder_patch_ids) + 1 <= h.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
+ assert decoder_patch_ids.shape[1] == dec_embeds.shape[1], (
+ f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"
+ )
+
+ # Cross-attention decoder
+ if not self.config.cross_attn_decoder:
+ h = torch.gather(h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]))
+ cross_attn_mask_dec = None
+ assert local_decoder_tokens.shape == h.shape[:-1]
+ else:
+ cross_attn_mask_dec = cross_attn_mask(
+ decoder_patch_ids,
+ patch_lengths,
+ N,
+ patches_as_queries=False,
+ cross_attn_k=self.config.cross_attn_k,
+ window=self.config.cross_attn_window_decoder,
+ block_mask=self.config.cross_attn_use_flex_attention,
+ )
+
+ # Local decoder
+ output, _ = self.local_decoder(
+ embeds=dec_embeds,
+ patch_embeds=h,
+ tokens=local_decoder_tokens,
+ cross_mask=cross_attn_mask_dec,
+ )
+ return output
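+
+# Minimal usage sketch (an illustration under assumed config values, not part of the original
+# file): given byte-level input ids of shape [batch_size, seq_len], BLTModel.forward returns
+# per-byte logits over the byte vocabulary, e.g.
+#   model = BLTModel(config)                     # config: a BLTConfig with patch_in_forward=True
+#   tokens = torch.randint(4, 260, (1, 256))     # byte ids after the assumed special-token offset
+#   logits = model(tokens)                       # [1, 256, config.vocab_size]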
+
+
+class BLTPatcher(BLTPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.rope_embeddings = RotaryEmbedding(
+ theta=config.patcher_rope_theta,
+ head_dim=config.patcher_head_dim or config.patcher_dim // config.patcher_n_heads,
+ max_seqlen=config.patcher_max_seqlen,
+ rope_use_fp32_in_outer_product=config.patcher_rope_use_fp32_in_outer_product,
+ )
+ # Handle both eos_id and eos_token_id for compatibility
+ self.eos_id = config.patcher_eos_token_id
+
+ # Extract additional parameters for BLTTransformerLayer
+ n_kv_heads = (
+ getattr(config, "patcher_n_kv_heads", None)
+ if hasattr(config, "patcher_dim")
+ else getattr(config, "n_kv_heads", None)
+ )
+ multiple_of = (
+ getattr(config, "patcher_multiple_of", 256)
+ if hasattr(config, "patcher_dim")
+ else getattr(config, "multiple_of", 256)
+ )
+ ffn_dim_multiplier = (
+ getattr(config, "patcher_ffn_dim_multiplier", None)
+ if hasattr(config, "patcher_dim")
+ else getattr(config, "ffn_dim_multiplier", None)
+ )
+
+ self.layers = nn.ModuleList()
+ for _ in range(config.patcher_n_layers):
+ self.layers.append(
+ BLTTransformerLayer(
+ {
+ "dim": config.patcher_dim,
+ "n_heads": config.patcher_n_heads,
+ "head_dim": config.patcher_head_dim,
+ "n_kv_heads": n_kv_heads,
+ "rope_theta": config.patcher_rope_theta,
+ "multiple_of": multiple_of,
+ "ffn_dim_multiplier": ffn_dim_multiplier,
+ "norm_eps": config.patcher_norm_eps,
+ }
+ )
+ )
+
+ # LMTransformer specific attributes
+ self.sliding_window = config.patcher_sliding_window
+
+ assert config.patcher_vocab_size > 0
+
+ self.tok_embeddings = torch.nn.Embedding(config.patcher_vocab_size, config.patcher_dim)
+
+ self.norm = RMSNorm(config.patcher_dim, eps=config.patcher_norm_eps)
+
+ self.output = nn.Linear(
+ config.patcher_dim,
+ config.patcher_vocab_size,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ token_values: torch.Tensor,
+ target: Optional[torch.Tensor] = None,
+ tok_idx: Optional[torch.Tensor] = None,
+ mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
+ attn_impl: str | None = None,
+ patch_size: Optional[int] = None,
+ include_next_token: bool = True,
+ threshold: Optional[float] = None,
+ threshold_add: Optional[float] = None,
+ monotonicity: bool = False,
+ max_patch_length: Optional[int] = None,
+ patching_batch_size: int = 1,
+ device: Optional[str] = None,
+ enable_grad: bool = False,
+ ):
+ attn_impl = self.config.patcher_attn_impl if attn_impl is None else attn_impl
+
+ # Handle chunked processing for entropy calculation
+ entropies = []
+ preds = []
+ max_length = min(getattr(self, "max_length", 8192), self.config.patcher_max_seqlen)
+ batch_numel = max_length * patching_batch_size
+ splits = torch.split(token_values.flatten(), batch_numel)
+
+ for split in splits:
+ pad_size = (max_length - (split.numel() % max_length)) % max_length
+ pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False)
+ split = torch.cat((split, pad), dim=0)
+ split = split.reshape(-1, max_length)
+ if device is not None:
+ split = split.to(device)
+
+ # Process chunk: embeddings -> layers -> output
+ bsz, seqlen = split.shape
+ h = self.tok_embeddings(split)
+ chunk_mask = create_causal_mask(
+ seqlen,
+ attn_impl,
+ self.config.patcher_attn_bias_type,
+ sliding_window=self.sliding_window,
+ tokens=split,
+ eos_id=self.eos_id,
+ )
+ freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None)
+
+ for i, layer in enumerate(self.layers):
+ h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=attn_impl)
+
+ pred = self.output(self.norm(h))
+ pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab]
+ preds.append(pred)
+ pred_entropies = self.entropy(pred)
+ entropies.append(pred_entropies)
+
+ concat_entropies = torch.cat(entropies, dim=0)
+ concat_entropies = concat_entropies.reshape(token_values.shape)
+ concat_preds = torch.cat(preds, dim=0)
+ concat_preds = concat_preds.reshape(token_values.shape[0], -1)
+
+ # Always compute patch lengths from concatenated entropies
+ bs, seq_len = token_values.shape
+ seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
+
+ # Find patch start IDs based on entropy
+ if patch_size is not None:
+ patch_start_ids = self.find_entropy_patch_start_ids(
+ concat_entropies,
+ patch_size,
+ include_next_token=include_next_token,
+ threshold=threshold,
+ threshold_add=threshold_add,
+ monotonicity=monotonicity,
+ )
+ patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok)
+ else:
+ # Default to byte-level patching
+ patch_lengths = torch.ones((bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device)
+
+ # Apply any processing to patch lengths
+ if max_patch_length is not None:
+ # TODO: avoid going back to a list here.
+ patch_lengths = [self.split_large_numbers(pl, max_patch_length) for pl in patch_lengths.tolist()]
+ max_len = max([len(pl) for pl in patch_lengths])
+ patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
+ patch_lengths = torch.tensor(patch_lengths, dtype=token_values.dtype, device=token_values.device)
+ assert not check_non_zero_after_zero(patch_lengths)
+ # Find the last non-zero column index using argmax on a reversed version of the tensor
+ last_non_zero_col_reversed = (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
+ # Slice the tensor up to the last non-zero column
+ patch_lengths = patch_lengths[:, : patch_lengths.shape[1] - last_non_zero_col_reversed]
+
+ return concat_entropies, patch_lengths, concat_preds
+
+
+
+
+
+ @staticmethod
+ def entropy(scores):
+ """
+ scores: [bs, seq_len, vocab]
+ returns [bs, seq_len]
+
+ Computes the entropy for each token in the batch.
+ Note: uses natural log.
+ """
+ log_probs = F.log_softmax(scores, dim=-1)
+ probs = torch.exp(log_probs)
+ p_log_p = log_probs * probs
+ entropy = -p_log_p.sum(dim=-1)
+ return entropy
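+
+ # Quick sanity example (illustrative, not original code): uniform logits give the maximum
+ # entropy log(vocab), e.g. BLTPatcher.entropy(torch.zeros(1, 1, 256)) is approximately
+ # log(256), about 5.545 (natural log, matching the note above).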
+
+ @staticmethod
+ def patch_start_ids_from_patch_start_mask(patch_start_mask):
+ bs, trunc_seq_len = patch_start_mask.shape
+ max_patches = patch_start_mask.sum(dim=1).max()
+ if max_patches == 0:
+ patch_start_ids = torch.full(
+ (bs, trunc_seq_len),
+ trunc_seq_len,
+ dtype=torch.long,
+ device=patch_start_mask.device,
+ )
+ else:
+ patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(bs, 1)
+ extra_patch_ids = torch.full(
+ (bs, trunc_seq_len),
+ trunc_seq_len,
+ dtype=torch.long,
+ device=patch_start_mask.device,
+ )
+ all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
+ patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)
+ patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, trunc_seq_len)[:, :max_patches]
+ return patch_start_ids
+
+ @staticmethod
+ def patch_lengths_from_start_ids(patch_start_ids, seq_len):
+ """
+ Calculate patch lengths from start ids.
+
+ patch_start_ids: e.g. [0, 1, 7, 7, 7, 7, 7]; it holds the actual patch start ids (here 0 and 1),
+ right-padded with seq_len up to a fixed width.
+ seq_len: e.g. 7, the length of the sequence.
+
+ Returns the patch lengths, [1, 6] for the example above (padded entries yield length 0).
+ """
+ last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
+ patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
+ patch_lengths = patch_end_ids - patch_start_ids + 1
+ assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
+ assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
+ return patch_lengths
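+
+ # Spelled-out version of the docstring example (illustrative): for
+ # patch_start_ids = [[0, 1, 7, 7, 7, 7, 7]] and seq_len = 7, patch_end_ids becomes
+ # [[0, 6, 6, 6, 6, 6, 6]], so the returned tensor is [[1, 6, 0, 0, 0, 0, 0]] -- the real
+ # patches have lengths 1 and 6, and the right-padded entries come out as length 0.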
+
+ @staticmethod
+ def find_entropy_patch_start_ids(
+ entropies,
+ patch_size=None,
+ threshold=None,
+ threshold_add=None,
+ monotonicity=False,
+ include_next_token=True,
+ ):
+ """
+ Use entropies to find the start ids of each patch.
+ Use patch_size or threshold to figure out the total number of patches to allocate.
+
+ When threshold is not None the number of patches is not constant between
+ different sequences, but patches can be identified incrementally rather than
+ decided globally using the entire sequence.
+ """
+ bs, seq_len = entropies.shape[:2]
+
+ first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(bs, 1)
+ preds_truncation_len = first_ids.shape[1] # remove the first preds because they will be start of patches.
+ entropies = entropies[:, 1:]
+ if threshold is None:
+ num_patches = seq_len // patch_size
+ patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices
+ patch_start_ids = patch_start_ids.sort(dim=1).values
+ else:
+ patch_start_mask = entropies > threshold
+ if not include_next_token:
+ patch_start_mask = patch_start_mask[:, :-1]
+ # patch_start_mask[1:] |= tokens[:-1] < OFFSET
+ patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask)
+
+ patch_start_ids = torch.cat((first_ids, patch_start_ids + preds_truncation_len), dim=1)
+ return patch_start_ids
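+
+ # Worked example under assumed values (illustrative only): with
+ # entropies = [[0.1, 2.0, 0.3, 0.2, 1.8, 0.4, 0.1, 0.9]] and threshold = 1.5, the shifted
+ # entropies[:, 1:] exceed the threshold at offsets 0 and 3, which become start ids 2 and 5
+ # after adding preds_truncation_len; together with the forced first_ids this yields
+ # patch_start_ids = [[0, 1, 2, 5]], i.e. patches starting at bytes 0, 1, 2 and 5.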
+
+ @staticmethod
+ def split_large_numbers(lst, m):
+ new_lst = []
+ for i in lst:
+ if i > m:
+ while i > m:
+ new_lst.append(m)
+ i -= m
+ new_lst.append(i)
+ else:
+ new_lst.append(i)
+ assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}"
+ return new_lst
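+
+ # Illustrative example (not original code): split_large_numbers([5, 2, 7], 3) returns
+ # [3, 2, 2, 3, 3, 1] -- every entry larger than m is chopped into chunks of size m plus a
+ # remainder, and the total sum (here 14) is preserved, as the assert checks.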
+
+
+def init_hash_embeddings(
+ config,
+ local_encoder_dim: int,
+ encoder_hash_byte_group_size: list,
+):
+ """Initialize hash-based token embeddings for the BLT encoder."""
+ if config.encoder_hash_byte_group_size is None:
+ return None
+
+ embeddings = []
+ emb_dim = local_encoder_dim
+ encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab
+
+ for _ in range(config.encoder_hash_byte_group_nb_functions):
+ for _ in encoder_hash_byte_group_size:
+ embeddings.append(
+ nn.Embedding(
+ encoder_hash_byte_group_vocab,
+ emb_dim,
+ )
+ )
+
+ return nn.ModuleList(embeddings)
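+
+# Sizing note with an assumed configuration (illustrative): for
+# encoder_hash_byte_group_nb_functions = 2 and encoder_hash_byte_group_size = [3, 4, 5], this
+# returns an nn.ModuleList of 2 * 3 = 6 embeddings, each of shape
+# (encoder_hash_byte_group_vocab, local_encoder_dim), matching the lookup order expected by
+# compute_hash_embeddings.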
+
+
+__all__ = [
+ "BLTPreTrainedModel",
+ "BLTModel",
+ "BLTPatcher",
+ "LocalEncoder",
+ "LocalDecoder",
+ "GlobalTransformer",
+]
\ No newline at end of file
diff --git a/backup_blt_wip_backup/modeling_blt_wip_backup.py b/backup_blt_wip_backup/modeling_blt_wip_backup.py
new file mode 100644
index 0000000000000000000000000000000000000000..adc4104dcbebb38e3a72866ade92eba569f3df3c
--- /dev/null
+++ b/backup_blt_wip_backup/modeling_blt_wip_backup.py
@@ -0,0 +1,2166 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from enum import Enum
+from typing import Any, List, Optional, Tuple, Union
+
+import json
+import logging
+import os
+from contextlib import nullcontext
+
+import torch
+from pydantic import model_validator
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
+
+SEP = " "
+BOS_ID: int = 1
+EOS_ID: int = 2
+PAD_ID: int = -1
+BOE_ID: int = 0
+BPE_ID: int = 3
+OFFSET: int = 4
+
+BYTE_UNITS: int = 256
+
+RMSNorm = nn.RMSNorm
+
+logger = logging.getLogger()
+
+from .configuration_blt import (
+ BLTConfig,
+ PatchingModeEnum,
+ InitStdFactor,
+)
+
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging as transformers_logging
+
+flex_attention_comp = flex_attention
+
+
+def causal_mask(b, h, q_idx, kv_idx):
+ return q_idx >= kv_idx
+
+
+def create_causal_mask(
+ seqlen,
+ attn_impl: str,
+ attn_bias_type: str | None,
+ *,
+ eos_id: int | None = None,
+ tokens: torch.Tensor | None = None,
+ sliding_window: int | None = None,
+):
+ if attn_impl == "sdpa":
+ BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))
+
+ if attn_bias_type == "causal":
+ return "causal"
+
+ if BLT_SUPPRESS_ATTN_ERROR == 1:
+ return "causal"
+ else:
+ raise ValueError(
+ "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
+ )
+ elif attn_impl == "flex_attention":
+ return create_block_mask(causal_mask, None, None, seqlen, seqlen)
+ else:
+ raise NotImplementedError(
+ f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
+ )
+
+def cross_entropy(pred, target, **kwargs):
+ return F.nll_loss(
+ F.log_softmax(pred.flatten(end_dim=-2).float(), -1),
+ target.flatten(end_dim=-1),
+ **kwargs,
+ )
+
+
+def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+ assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
+ bs, slen, n_kv_heads, head_dim = x.shape
+ if n_rep == 1:
+ return x
+ return (
+ x[:, :, :, None, :]
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+ )
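+
+# Shape example (illustrative): for x of shape (2, 16, 4, 64) and n_rep = 3, repeat_kv returns
+# a tensor of shape (2, 16, 12, 64), i.e. each of the 4 KV heads is repeated 3 times along dim 2.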
+
+
+def precompute_freqs_cis(
+ dim: int,
+ end: int,
+ theta: float = 10000.0,
+ rope_use_fp32_in_outer_product: bool = False,
+):
+ """
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
+ The returned tensor contains complex values in complex64 data type.
+
+ Args:
+ dim (int): Dimension of the frequency tensor.
+ end (int): End index for precomputing frequencies.
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+ rope_use_fp32_in_outer_product (bool, optional): If True, compute the position-frequency outer product in float32. Defaults to False.
+
+ Returns:
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
+ """
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ t = torch.arange(end, device=freqs.device)
+ if rope_use_fp32_in_outer_product:
+ t = t.to(torch.float32)
+
+ freqs = torch.outer(t, freqs).float()
+
+ cos, sin = freqs.cos(), freqs.sin()
+
+ return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
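+
+# Shape example (illustrative): precompute_freqs_cis(dim=64, end=1024) returns a tensor of shape
+# (1024, 32, 2, 2) -- one 2x2 rotation matrix per position and per pair of head dimensions.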
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
+ """
+ Reshape frequency tensor for broadcasting it with another tensor.
+
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+ Args:
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
+ seq_dim (int): Sequence dimension index.
+
+ Returns:
+ torch.Tensor: Reshaped frequency tensor.
+ """
+ ndim = x.ndim
+ assert 0 <= seq_dim < ndim
+ assert freqs_cis.shape == (
+ x.shape[seq_dim],
+ x.shape[-3],
+ 2,
+ 2,
+ ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
+ shape = [
+ d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
+ ] + [2, 2]
+ return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ seq_dim: int,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
+ xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
+ freqs_cis = reshape_for_broadcast(
+ freqs_cis, xq_, seq_dim
+ ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
+ xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
+ xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+# Rotary embedding as in xformers; see if the torchtrain implementation is not better. It might also be useful to make it work with batch*seqlen collapsed.
+class RotaryEmbedding(torch.nn.Module):
+ """
+ RotaryEmbedding Module
+ """
+
+ def __init__(
+ self,
+ theta: float,
+ head_dim: int,
+ max_seqlen: int = 1024,
+ rope_use_fp32_in_outer_product: bool = False,
+ ):
+ super().__init__()
+
+ self.theta = theta
+ self.head_dim = head_dim
+ self.max_seqlen = max_seqlen
+ self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
+
+ self.register_buffer(
+ "freqs_cis",
+ precompute_freqs_cis(
+ dim=head_dim,
+ end=max_seqlen,
+ theta=theta,
+ rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
+ ),
+ persistent=False,
+ )
+
+ def reset_parameters(self):
+ self.freqs_cis[...] = precompute_freqs_cis(
+ dim=self.head_dim,
+ end=self.max_seqlen,
+ theta=self.theta,
+ rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
+ )
+
+ def forward(
+ self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None
+ ):
+ """
+ Return freqs_cis corresponding to consecutive seqlen positions or to the given tok_idx positions.
+ Args:
+ seqlen (int): Contiguous sequence length
+ tok_idx (torch.Tensor[int]): Position indices of each token; overrides seqlen when provided
+
+ Returns:
+ torch.Tensor: The freqs_cis entries for the requested positions
+ """
+ test = (seqlen is not None) or (tok_idx is not None)
+ assert test, "Should provide atleast seqlen or tok_idx"
+ if tok_idx is not None:
+ return self.freqs_cis[tok_idx]
+ elif seqlen is not None:
+ return self.freqs_cis[0:seqlen]
+
+
+class BLTAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ head_dim: int,
+ n_heads: int,
+ n_kv_heads: int,
+ rope_theta: float,
+ ):
+ super().__init__()
+
+ self.dim = dim
+ self.head_dim = head_dim
+ self.rope_theta = rope_theta
+
+ self.n_heads = n_heads
+ self.n_kv_heads = n_kv_heads
+ self.heads_per_group = self.n_heads // self.n_kv_heads
+
+ self.wq = nn.Linear(
+ dim,
+ n_heads * head_dim,
+ bias=False,
+ )
+ self.wk = nn.Linear(
+ dim,
+ n_kv_heads * head_dim,
+ bias=False,
+ )
+ self.wv = nn.Linear(
+ dim,
+ n_kv_heads * head_dim,
+ bias=False,
+ )
+
+ self.wo = nn.Linear(
+ n_heads * head_dim,
+ dim,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ freq_cis: torch.Tensor,
+ tok_idx: Optional[torch.Tensor] = None,
+ mask: Optional[Union[BlockMask, str]] = None,
+ attn_impl: str = "sdpa",
+ ) -> torch.Tensor:
+ # B S D
+ bsz, seq_len, dim = x.shape
+ xq = self.wq(x.view_as(x))
+ xk = self.wk(x.view_as(x))
+ xv = self.wv(x.view_as(x))
+
+ output_shape = xq.shape
+ # B S D -> B S H D
+ xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
+ xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+ xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+
+ xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
+
+ # This condition helps us be easily compatible
+ # with inference by adding a pluggable KVCache
+ if hasattr(self, "kv_cache"):
+ xk, xv = self.kv_cache.update(xk, xv, tok_idx)
+
+ xk = repeat_kv(xk, self.heads_per_group, dim=2)
+ xv = repeat_kv(xv, self.heads_per_group, dim=2)
+
+ if attn_impl == "flex_attention":
+ assert mask is None or isinstance(mask, BlockMask)
+ xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
+ output = flex_attention_comp(xq, xk, xv, block_mask=mask)
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
+
+ elif attn_impl == "sdpa":
+ xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
+ assert mask is None or isinstance(mask, (str, torch.Tensor))
+ is_causal = (mask == "causal") if isinstance(mask, str) else False
+ mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None
+ output = F.scaled_dot_product_attention(
+ xq,
+ xk,
+ xv,
+ is_causal=is_causal,
+ attn_mask=mask,
+ )
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
+ else:
+ raise NotImplementedError(
+ f"Attention implementation {attn_impl} not supported"
+ )
+
+ output_reshaped = output.reshape(output_shape)
+
+ output = self.wo(output_reshaped)
+
+ return output
+
+ def reset_parameters(self, init_std=None, factor=1.0):
+ init_std = init_std or (self.dim ** (-0.5)) / factor
+
+ for w in [self.wq, self.wk, self.wv]:
+ nn.init.trunc_normal_(
+ w.weight,
+ mean=0.0,
+ std=init_std,
+ a=-3 * init_std,
+ b=3 * init_std,
+ )
+
+ nn.init.trunc_normal_(
+ self.wo.weight,
+ mean=0.0,
+ std=init_std,
+ a=-3 * init_std,
+ b=3 * init_std,
+ )
+
+
+class BLTMLP(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ multiple_of: int,
+ ffn_dim_multiplier: Optional[float],
+ mp_size: int = 1,
+ ):
+ super().__init__()
+
+ hidden_dim = int(2 * hidden_dim / 3)
+ if ffn_dim_multiplier is not None:
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+ assert hidden_dim % mp_size == 0
+
+ self.dim = dim
+ self.hidden_dim = hidden_dim
+
+ self.w1 = nn.Linear(
+ dim,
+ hidden_dim,
+ bias=False,
+ )
+ self.w3 = nn.Linear(
+ dim,
+ hidden_dim,
+ bias=False,
+ )
+ self.w2 = nn.Linear(
+ hidden_dim,
+ dim,
+ bias=False,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # B S D
+ x1 = self.w1(x.view_as(x))
+ x3 = self.w3(x.view_as(x))
+ output = self.w2(F.silu(x1) * x3)
+ return output
+
+ def reset_parameters(self, init_std=None, factor=1.0):
+ in_init_std = init_std or (self.dim ** (-0.5)) / factor
+ out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor
+
+ nn.init.trunc_normal_(
+ self.w1.weight,
+ mean=0.0,
+ std=in_init_std,
+ a=-3 * in_init_std,
+ b=3 * in_init_std,
+ )
+ nn.init.trunc_normal_(
+ self.w2.weight,
+ mean=0.0,
+ std=out_init_std,
+ a=-3 * out_init_std,
+ b=3 * out_init_std,
+ )
+ nn.init.trunc_normal_(
+ self.w3.weight,
+ mean=0.0,
+ std=in_init_std,
+ a=-3 * in_init_std,
+ b=3 * in_init_std,
+ )
+
+
+class BLTTransformerLayer(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+
+ # Extract parameters from dictionary
+ dim = args['dim']
+ n_heads = args['n_heads']
+ head_dim = args['head_dim']
+ n_kv_heads = args['n_kv_heads']
+ rope_theta = args['rope_theta']
+ multiple_of = args['multiple_of']
+ ffn_dim_multiplier = args['ffn_dim_multiplier']
+ norm_eps = args['norm_eps']
+
+ assert (head_dim is not None) or (
+ n_heads is not None
+ ), "Should specify at least head_dim or n_heads"
+ self.head_dim = head_dim or dim // n_heads
+ self.n_heads = n_heads or dim // head_dim
+ self.n_kv_heads = n_kv_heads or self.n_heads
+
+ assert n_heads % self.n_kv_heads == 0
+ assert dim % n_heads == 0
+
+ self.attention = BLTAttention(
+ dim=dim,
+ head_dim=self.head_dim,
+ n_heads=self.n_heads,
+ n_kv_heads=self.n_kv_heads,
+ rope_theta=rope_theta,
+ )
+ self.feed_forward = BLTMLP(
+ dim=dim,
+ hidden_dim=4 * dim,
+ multiple_of=multiple_of,
+ ffn_dim_multiplier=ffn_dim_multiplier,
+ )
+ self.attention_norm = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm = RMSNorm(dim, eps=norm_eps)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ freq_cis: torch.Tensor,
+ tok_idx: Optional[torch.Tensor] = None,
+ mask: Optional[Union[BlockMask, str]] = None,
+ attn_impl: str = "sdpa",
+ ) -> torch.Tensor:
+ norm_x = self.attention_norm(x)
+ attn_out = self.attention(
+ norm_x,
+ freq_cis,
+ tok_idx=tok_idx,
+ mask=mask,
+ attn_impl=attn_impl,
+ )
+ h = x + attn_out
+ h_norm = self.ffn_norm(h)
+ out = h + self.feed_forward(h_norm)
+ return out
+
+ def init_weights(self, init_std=None, factor=1.0):
+ self.attention.reset_parameters(init_std, factor)
+ self.attention_norm.reset_parameters()
+
+ self.feed_forward.reset_parameters(init_std, factor)
+ self.ffn_norm.reset_parameters()
+
+
+def rightpad(seq, pad_id, max_len):
+ return seq + [pad_id] * (max_len - len(seq))
+
+
+def check_non_zero_after_zero(tensor):
+ zero_mask = tensor == 0
+ shifted_mask = torch.cat(
+ [
+ torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
+ zero_mask[:, :-1],
+ ],
+ dim=1,
+ )
+ non_zero_after_zero = (tensor != 0) & shifted_mask
+ return non_zero_after_zero.any()
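+
+# Behaviour example (illustrative): check_non_zero_after_zero(torch.tensor([[3, 2, 0, 0]])) is
+# False (zeros only as right padding), while torch.tensor([[3, 0, 2, 0]]) gives True because a
+# non-zero value follows a zero.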
+
+
+def fill_tokens(tokens, patch_size, fill_id):
+ batch_size, seq_len = tokens.shape
+ if seq_len % patch_size == 0:
+ return tokens
+ else:
+ remaining = patch_size - seq_len % patch_size
+ final_padding = tokens.new(batch_size, remaining).fill_(fill_id)
+ return torch.cat((tokens, final_padding), dim=1)
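+
+# Padding example (illustrative): for tokens of shape (2, 10) and patch_size = 4, fill_tokens
+# appends 2 columns of fill_id so the result has shape (2, 12), a multiple of the patch size.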
+
+
+def rolling_polynomial_hash(t, hash_func_nb: int = 0):
+ primes = [
+ 1000000007,
+ 5915587277,
+ 1500450271,
+ 3267000013,
+ 5754853343,
+ 4093082899,
+ 9576890767,
+ 3628273133,
+ 2860486313,
+ 5463458053,
+ 3367900313,
+ ]
+ prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
+ prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
+ return torch.sum(t * prime_powers, dim=-1)
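+
+# Hash example (illustrative): with hash_func_nb = 0 the prime is 1000000007, so a window
+# t = torch.tensor([[1, 2]]) hashes to 1 * 1 + 2 * 1000000007 = 2000000015; longer windows
+# multiply later bytes by higher powers of the prime.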
+
+def byte_group_hash_function(
+ x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000
+):
+ """
+ Returns a hash of the input x and maps it to a value in the range [0, max_hash].
+
+ expects: x of shape (batch_size, seq_len) with values as ids in the token vocab.
+ returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash].
+
+ Note: max hash can make a big difference on the number of collisions.
+ """
+ with torch.no_grad():
+ bs, seq_len = x.shape
+ prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
+ x = torch.cat([prefix, x], dim=1)
+ windows = x.unfold(1, group_size, 1)
+ # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
+ hashes = rolling_polynomial_hash(windows, hash_func_nb)
+ hash_values_range = hashes % max_hash
+ hash_values_range.requires_grad = False
+ return hash_values_range
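+
+# Shape/behaviour note (illustrative): for x of shape (bs, seq_len) and group_size = 2, a zero
+# prefix of length 1 is prepended so that position j hashes the pair (x[j-1], x[j]) (with 0 as
+# the missing left neighbour at j = 0); the output keeps the (bs, seq_len) shape with values in
+# [0, max_hash).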
+
+
+def create_patch_mask_from_ids(
+ patch_ids, num_patches, window=None, patches_as_queries=False
+):
+ """
+ Creates a boolean mask pairing token positions with patches. With window=None, an element is
+ True iff the token at that position belongs to that patch; with a window, a patch is also
+ allowed if it is one of the window-1 patches immediately before the token's patch. Whether
+ patches index the query or the key/value axis is controlled by patches_as_queries.
+ Args:
+ patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids.
+ num_patches (int): Total number of patches.
+ window (int): If not None, only considers patches within a window of size window.
+ patches_as_queries (bool): If True, the patches are used as queries
+ Returns:
+ torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask.
+ """
+ bs, seq_len = patch_ids.shape
+ if not patches_as_queries:
+ q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
+ kv_ids = (
+ torch.arange(num_patches, device=patch_ids.device)
+ .unsqueeze(0)
+ .unsqueeze(0)
+ .expand(bs, seq_len, num_patches)
+ )
+ else:
+ kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
+ q_ids = (
+ torch.arange(num_patches, device=patch_ids.device)
+ .unsqueeze(0)
+ .unsqueeze(-1)
+ .expand(bs, num_patches, seq_len)
+ )
+ if window is None:
+ mask = q_ids == kv_ids
+ else:
+ mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window)
+ return mask
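+
+# Small example (illustrative): patch_ids = [[0, 0, 1, 2]], num_patches = 3, window = None and
+# patches_as_queries = False gives a (1, 4, 3) mask whose rows are
+# [True, False, False], [True, False, False], [False, True, False], [False, False, True] --
+# each token is paired only with the patch it belongs to.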
+
+
+def cross_attn_mask(
+ patch_ids,
+ patch_lengths,
+ N,
+ patches_as_queries=False,
+ cross_attn_k=1,
+ window=None,
+ block_mask=True,
+):
+ bs = patch_ids.shape[0]
+ with torch.no_grad():
+ # Create the patch mask
+ cross_mask = create_patch_mask_from_ids(
+ patch_ids,
+ patch_lengths.shape[1],
+ window=window,
+ patches_as_queries=patches_as_queries,
+ ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
+ q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N
+ kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k
+ assert cross_mask.shape == (
+ bs,
+ q_len,
+ kv_len,
+ ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
+ block_mask = None
+ if block_mask:
+
+ def patch_mask(b, h, q_idx, kv_idx):
+ return cross_mask[b, q_idx, kv_idx]
+
+ block_mask = create_block_mask(
+ patch_mask,
+ B=bs,
+ H=None,
+ Q_LEN=q_len,
+ KV_LEN=kv_len,
+ _compile=True,
+ )
+ return block_mask
+ else:
+ return torch.where(
+ cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))
+ ).unsqueeze(
+ 1
+ ) # [bs, 1, q_len, kv_len]
+
+
+def get_blt_input(
+ tokens: torch.Tensor,
+ enforce_patch_size_multiple: bool,
+ nb_boe: torch.Tensor,
+ patch_size: int,
+ boe_id: int,
+):
+ """
+ This function returns X_et, X_gt and X_dt, the encoder, global, and decoder
+ tokens respectively.
+
+ Consider the input and target sequences:
+ X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13]
+ Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14]
+ with patch_size=4
+
+ Note 1: that there will be no special tokens introduced at the patch level.
+ Note 2: X_e needs to be trimmed to be passed to Global
+
+ Current without boe:
+ X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]]
+ X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch
+ X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
+ Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]
+
+ --> lag fix:
+ X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]]
+ X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]]
+ X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
+ Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]
+
+ Dynamic (current):
+ X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos]
+ Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
+
+ entropy patching:
+ input: 7, bos, 9, 10
+ pred (high entropy): eos, 8, 10, eos
+
+ X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos]
+ X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]]
+ X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]]
+ Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
+
+ --> lag fix no boe (force single byte first patch):
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
+ X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
+ X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
+ Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
+
+ input: 4, 7, bos, 9, 10
+ pred (high entropy): 5, eos, 8, 10, eos
+
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
+ X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
+ X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
+ Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
+
+ Handle the last byte properly.
+    patch_lengths = [1, 1, 3, 2, 2, 1, 2, 2, 1]
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
+ X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch
+ X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]]
+ Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]]
+
+
+ bpe delim
+ X_et = [[3,4,5,6,7,,eos,bos,,8,9,,10,,eos,bos,11,12]
+ X_g = [[3], [4,5,6,7,], [eos,bos,], ..
+ X_dt = [[3,4,5,6,7], [,eos,bos], [,bos,8], ..
+ Y = [4,5,6,7,, eos,bos, 8,9,, ..
+
+
+    Note 1: no special tokens are introduced at the patch level.
+    Note 2: X_e needs to be trimmed before it is passed to the global model.
+ """
+ batch_size, seq_len = tokens.shape
+ local_encoder_tokens = tokens
+ local_decoder_tokens = tokens
+
+ if nb_boe > 0:
+ padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id)
+ local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1)
+ # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id)
+
+ # create global tokens, contains boe tokens and eos
+ # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
+ # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size)
+ # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:]
+ # global_tokens += global_tokens.eq(0).int() * boe_id
+ # TODO: fix this when we want to use block causal in the global.
+
+ if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0:
+ local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
+
+ return local_encoder_tokens, None, local_decoder_tokens
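+
+# Illustrative example (assumed values): with tokens = [[5, 6, 7, 8]], nb_boe = 3
+# and boe_id = 0, the encoder stream becomes [[0, 0, 0, 5, 6, 7, 8]], the global
+# stream slot is returned as None (it is built later in BLTModel.forward), and the
+# decoder stream stays [[5, 6, 7, 8]].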
+
+
+class LocalModelBase(nn.Module):
+ def __init__(self, config: BLTConfig, component_type: str = "encoder"):
+ super().__init__()
+
+ # Store config for later use
+ self.config = config
+
+ # Use component-specific dimensions
+ if component_type == "encoder":
+ self.dim = config.dim_local_encoder
+ self.n_layers = config.n_layers_local_encoder
+ self.n_heads = config.n_heads_local_encoder
+ self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen
+ self.attn_bias_type = "local_block_causal"
+ self.sliding_window = config.local_attention_window_len
+ elif component_type == "decoder":
+ self.dim = config.dim_local_decoder
+ self.n_layers = config.n_layers_local_decoder
+ self.n_heads = config.n_heads_local_decoder
+ self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen
+ self.attn_bias_type = "local_block_causal"
+ self.sliding_window = config.local_attention_window_len
+ else:
+ raise ValueError(f"Unknown component_type: {component_type}")
+
+ self.dropout = config.dropout
+ self.vocab_size = config.vocab_size + config.pm_size
+ self.patch_size = config.patch_size
+
+ self.attn_impl = config.attn_impl
+ self.use_rope = config.use_rope
+ self.init_std_factor = config.init_std_factor
+ self.init_base_std = config.init_base_std
+ self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None)
+ self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None)
+ self.cross_attn_k = getattr(config, "cross_attn_k", None)
+ self.eos_id = config.eos_token_id
+
+ self.boe_id = BOE_ID
+
+ # Initialize cross attention layers as None (will be set by subclasses if needed)
+ self.cross_attn_layers = None
+
+ # Create parameter dict for BLTTransformerLayers
+ layer_params = {
+ 'dim': self.dim,
+ 'n_heads': self.n_heads,
+ 'head_dim': config.head_dim,
+ 'n_kv_heads': getattr(config, 'n_kv_heads', None),
+ 'rope_theta': config.rope_theta,
+ 'multiple_of': getattr(config, 'multiple_of', 256),
+ 'ffn_dim_multiplier': getattr(config, 'ffn_dim_multiplier', None),
+ 'norm_eps': config.norm_eps,
+ }
+
+ self.layers = nn.ModuleList(
+ [BLTTransformerLayer(layer_params) for _ in range(self.n_layers)]
+ )
+
+ if not self.use_rope:
+ self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length
+ else:
+ self.rope = RotaryEmbedding(
+ theta=config.rope_theta,
+ head_dim=config.head_dim or self.dim // self.n_heads,
+ max_seqlen=self.max_seqlen,
+ rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product,
+ )
+ self.pos_embeddings = None
+
+ # Set dimension-specific embedding dimensions
+ if component_type == "encoder":
+ self.dim_token_emb = config.encoder_dim_token_emb
+ self.dim_patch_emb = config.encoder_dim_patch_emb
+ elif component_type == "decoder":
+ self.dim_token_emb = config.decoder_dim_token_emb
+ self.dim_patch_emb = config.dim_global
+
+ self.token_embedding_projection = (
+ nn.Linear(self.dim_token_emb, self.dim, bias=False)
+ if self.dim_token_emb is not None and self.dim_token_emb != self.dim
+ else None
+ )
+
+ self.patch_embedding_projection = self._create_patch_projection(config)
+
+ def _should_create_patch_projection(self, config: BLTConfig):
+ dimension_mismatch = (
+ self.dim_patch_emb is not None and self.dim_patch_emb != self.dim
+ )
+
+ # Check cross attention conditions
+ cross_attn_conditions = (
+ config.cross_attn_encoder and config.cross_attn_init_by_pooling
+ ) or (config.cross_attn_decoder and config.cross_attn_init_by_pooling)
+
+ return dimension_mismatch or cross_attn_conditions
+
+ def _create_patch_projection(self, config):
+ if not self._should_create_patch_projection(config):
+ return None
+
+ output_dim = self.dim_token_emb * (self.cross_attn_k or 1)
+
+ return nn.Linear(
+ in_features=self.dim_patch_emb,
+ out_features=output_dim,
+ bias=False,
+ )
+
+ def apply_embedding(self, tokens, embeds):
+ if embeds is not None:
+ return embeds
+ else:
+ return self.tok_embeddings(tokens)
+
+ def init_weights(self, init_std=None):
+        if self.use_rope:
+            self.rope.reset_parameters()
+ if hasattr(self, "norm"):
+ self.norm.reset_parameters()
+
+ init_std = init_std or (self.dim ** (-0.5))
+ if hasattr(self, "tok_embeddings"):
+ nn.init.trunc_normal_(
+ self.tok_embeddings.weight,
+ mean=0.0,
+ std=init_std,
+ a=-3 * init_std,
+ b=3 * init_std,
+ )
+ if self.pos_embeddings is not None:
+ nn.init.trunc_normal_(
+ self.pos_embeddings.weight,
+ mean=0.0,
+ std=init_std,
+ a=-3 * init_std,
+ b=3 * init_std,
+ )
+
+ for depth, layer in enumerate(self.layers):
+ factor = self.config.get_init_std_factor(depth)
+ layer.init_weights(self.init_base_std, factor)
+
+ if hasattr(self, "output"):
+ nn.init.trunc_normal_(
+ self.output.weight,
+ mean=0.0,
+ std=init_std,
+ a=-3 * init_std,
+ b=3 * init_std,
+ )
+
+ if self.token_embedding_projection is not None:
+ nn.init.trunc_normal_(
+ self.token_embedding_projection.weight,
+ mean=0.0,
+ std=init_std,
+ a=-3 * init_std,
+ b=3 * init_std,
+ )
+
+ if self.patch_embedding_projection is not None:
+ patch_emb_std = self.dim_patch_emb ** (-0.5)
+ nn.init.trunc_normal_(
+ self.patch_embedding_projection.weight,
+ mean=0.0,
+ std=patch_emb_std,
+ a=-3 * patch_emb_std,
+ b=3 * patch_emb_std,
+ )
+
+ if self.cross_attn_layers is not None:
+ for depth, layer in enumerate(self.cross_attn_layers):
+ factor = self.config.get_init_std_factor(depth)
+ layer.init_weights(None, factor)
+
+
+class LocalEncoder(LocalModelBase):
+ def __init__(self, config: BLTConfig):
+ super().__init__(config, component_type="encoder")
+
+ self.apply_transformer = config.use_local_encoder_transformer
+ self.downsampling_by_pooling = config.downsampling_by_pooling
+ self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None
+ self.cross_attn_encoder = config.cross_attn_encoder
+ self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder
+ self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling
+ self.cross_attn_nheads = config.cross_attn_nheads
+
+ self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim)
+
+ if self.cross_attn_encoder:
+ self.cross_attn_layers = torch.nn.ModuleList()
+ layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1
+ for _ in range(layers_to_add):
+ self.cross_attn_layers.append(
+ BLTCrossAttention(
+ dim=self.dim,
+ head_dim=self.dim // self.cross_attn_nheads,
+ n_heads=self.cross_attn_nheads,
+ n_kv_heads=self.cross_attn_nheads,
+ norm_eps=config.norm_eps,
+ )
+ )
+
+ def apply_embedding(self, tokens, embeds):
+ if embeds is not None:
+ assert (
+ self.expects_hash_embeddings
+ ), "Not expecting embeddings to be passed."
+ return embeds
+ else:
+ return self.tok_embeddings(tokens)
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ embeds: Optional[torch.Tensor] = None,
+ patch_embeds: Optional[torch.Tensor] = None,
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
+ cross_mask: Optional[torch.Tensor] = None,
+ num_patches: Optional[int] = None,
+ patch_ids: Optional[torch.Tensor] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ """ """
+ bs, seqlen = tokens.shape
+ if mask is None:
+ mask = create_causal_mask(
+ seqlen,
+ self.attn_impl,
+ "local_block_causal",
+ sliding_window=self.sliding_window,
+ tokens=tokens,
+ eos_id=self.eos_id,
+ )
+
+ h = self.apply_embedding(tokens, embeds)
+ freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
+
+ h = F.dropout(h, p=self.dropout, training=self.training)
+
+ for i, layer in enumerate(self.layers):
+ h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
+            # check whether cross attention should be applied to all layers or only the last layer
+ if self.cross_attn_encoder and (
+ i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
+ ):
+ # apply pooling and project
+ if self.cross_attn_init_by_pooling and patch_embeds is None:
+ patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids)
+ if self.patch_embedding_projection is not None:
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
+ patch_embeds = patch_embeds.reshape(
+ bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
+ )
+
+ layer_idx = i if self.cross_attn_all_layers_encoder else 0
+ patch_embeds_cross = self.cross_attn_layers[layer_idx](
+ x=patch_embeds,
+ kv=h,
+ mask=cross_mask,
+ )
+ patch_embeds = patch_embeds + patch_embeds_cross
+
+ h_residual = patch_embeds if self.cross_attn_encoder else None
+ return (h, h_residual), cache
+
+ def patch_reduce(self, h, max_num_patches, reduction, patch_ids):
+ """
+        Reduce variable-length patches to a single embedding per patch.
+        This works with a variable number of patches per sequence in the batch:
+        patch_lengths is assumed to be zero-padded on the *right* for any unused
+        patch slots. Any token positions on the right that are not allocated to a
+        patch (i.e. if sum(patch_lengths[i]) < seq_len for some i) are scattered
+        into a trailing dummy patch, which is trimmed before returning.
+ """
+ bs, seq_len, emb_dim = h.shape
+
+ patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
+
+ reduced_embs = torch.zeros(
+ (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device
+ )
+ reduced_embs = reduced_embs.scatter_reduce(
+ src=h,
+ dim=1,
+ index=patch_ids,
+ reduce=reduction,
+ include_self=False,
+ )
+ reduced_embs = reduced_embs[:, :max_num_patches, :]
+
+ return reduced_embs
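+
+    # Illustrative example (assumed values): with h of shape [1, 4, d],
+    # patch_ids = [[0, 0, 1, 1]], max_num_patches = 2 and reduction="amax", each
+    # output row is the element-wise max of the token embeddings assigned to that
+    # patch: reduced_embs[0, 0] = max(h[0, 0], h[0, 1]) and
+    # reduced_embs[0, 1] = max(h[0, 2], h[0, 3]).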
+
+
+class LocalDecoder(LocalModelBase):
+ def __init__(self, config: BLTConfig):
+ super().__init__(config, component_type="decoder")
+
+ # Model configuration flags
+ self.cross_attn_decoder = config.cross_attn_decoder
+ self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder
+ self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling
+ self.cross_attn_nheads = config.cross_attn_nheads
+
+ self.norm = RMSNorm(self.dim, eps=config.norm_eps)
+
+ if self.cross_attn_decoder:
+ self.cross_attn_layers = torch.nn.ModuleList()
+ layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1
+ for _ in range(layers_to_add):
+ self.cross_attn_layers.append(
+ BLTCrossAttention(
+ dim=self.dim,
+ head_dim=self.dim // self.cross_attn_nheads,
+ n_heads=self.cross_attn_nheads,
+ n_kv_heads=self.cross_attn_nheads,
+ norm_eps=config.norm_eps,
+ )
+ )
+
+ self.output = nn.Linear(
+ self.dim,
+ config.vocab_size,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ embeds: Optional[torch.Tensor],
+ patch_embeds: Optional[torch.Tensor] = None,
+ mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
+ cross_mask: Optional[torch.Tensor] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ bs, seqlen = tokens.shape
+ assert embeds is not None, "Embeddings must be provided"
+
+ if mask is None:
+ mask = create_causal_mask(
+ seqlen,
+ self.attn_impl,
+ "local_block_causal",
+ sliding_window=self.sliding_window,
+ tokens=tokens,
+ eos_id=self.eos_id,
+ )
+
+ h = embeds
+
+ if self.patch_embedding_projection is not None:
+ assert patch_embeds is not None, "Patch embeddings must be passed."
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
+ if self.cross_attn_k is not None:
+ patch_embeds = patch_embeds.reshape(
+ bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
+ )
+
+ if patch_embeds is not None and not self.cross_attn_decoder:
+ h = h + patch_embeds
+
+ freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
+
+ h = F.dropout(h, p=self.dropout, training=self.training)
+ for i, layer in enumerate(self.layers):
+ if self.cross_attn_decoder and (
+ i == 0 or self.cross_attn_all_layers_decoder
+ ):
+ # Use cross attention to extract info from patch_embeds into h
+ h_cross = self.cross_attn_layers[i](
+ x=h,
+ kv=patch_embeds,
+ mask=cross_mask,
+ )
+ h = h + h_cross
+
+ h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
+
+ h_preds = self.norm(h)
+ h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
+ h_preds = self.output(h_preds)
+ h_preds = h_preds.float()
+ return h_preds, cache
+
+
+class BLTCrossAttention(nn.Module):
+ """
+    Cross-attention block between byte-level states and patch embeddings: in the
+    local encoder, patch queries attend to byte states; in the local decoder, byte
+    states attend to patch embeddings. RoPE is not supported.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ head_dim: int,
+ n_heads: int,
+ n_kv_heads: int,
+ norm_eps: float,
+ ):
+ super().__init__()
+
+ self.dim = dim
+ self.head_dim = head_dim
+
+ self.n_heads = n_heads
+ self.n_kv_heads = n_kv_heads
+ self.heads_per_group = self.n_heads // self.n_kv_heads
+
+        self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
+ self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
+
+ self.wq = nn.Linear(
+ dim,
+ n_heads * head_dim,
+ bias=False,
+ )
+ self.wk = nn.Linear(
+ dim,
+ n_kv_heads * head_dim,
+ bias=False,
+ )
+ self.wv = nn.Linear(
+ dim,
+ n_kv_heads * head_dim,
+ bias=False,
+ )
+
+ self.wo = nn.Linear(
+ n_heads * head_dim,
+ dim,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ kv: torch.Tensor,
+ mask: Optional[Union[BlockMask, str]] = None,
+ ) -> torch.Tensor:
+ # B S D
+ bsz, seq_len, _ = x.shape
+ _, slen_kv, _ = kv.shape
+ x_norm = self.cross_attn_norm_q(x)
+ kv = self.cross_attn_norm_kv(kv)
+
+ xq = self.wq(x_norm)
+ xk = self.wk(kv)
+ xv = self.wv(kv)
+
+ output_shape = xq.shape
+ # B S D -> B S H D
+ xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
+ xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
+ xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
+
+ xk = repeat_kv(xk, self.heads_per_group, dim=2)
+ xv = repeat_kv(xv, self.heads_per_group, dim=2)
+
+ # assert mask is None or isinstance(mask, BlockMask)
+ xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
+ #output = flex_attention_comp(xq, xk, xv, block_mask=mask)
+ is_causal = (mask == "causal") if isinstance(mask, str) else False
+ mask = mask if isinstance(mask, torch.Tensor) else None
+ mask = mask.to(dtype=xq.dtype).to(xq.device)
+ output = F.scaled_dot_product_attention(
+ xq,
+ xk,
+ xv,
+ is_causal=is_causal,
+ attn_mask=mask,
+ )
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
+
+ output = self.wo(output.reshape(output_shape))
+
+ return x + output
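+
+    # Illustrative shape note (assumed usage): x is [bs, q_len, dim] and kv is
+    # [bs, kv_len, dim]; a tensor mask is expected to broadcast against the attention
+    # scores, e.g. the [bs, 1, q_len, kv_len] additive mask produced by
+    # cross_attn_mask(..., block_mask=False). The residual connection means the
+    # output has the same shape as x.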
+
+ def init_weights(self, base_std: float, factor: float = 1.0):
+ std = base_std or (self.dim ** (-0.5)) / factor
+
+ nn.init.trunc_normal_(
+ self.wq.weight,
+ mean=0.0,
+ std=std,
+ a=-3 * std,
+ b=3 * std,
+ )
+
+ nn.init.trunc_normal_(
+ self.wk.weight,
+ mean=0.0,
+ std=std,
+ a=-3 * std,
+ b=3 * std,
+ )
+
+ nn.init.trunc_normal_(
+ self.wv.weight,
+ mean=0.0,
+ std=std,
+ a=-3 * std,
+ b=3 * std,
+ )
+
+ nn.init.trunc_normal_(
+ self.wo.weight,
+ mean=0.0,
+ std=std,
+ a=-3 * std,
+ b=3 * std,
+ )
+ self.cross_attn_norm_q.reset_parameters()
+ self.cross_attn_norm_kv.reset_parameters()
+
+
+class GlobalTransformer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ # Store config for later use
+ self.config = config
+
+ self.dim = config.dim
+ self.init_base_std = config.init_base_std
+ self.attn_impl = config.attn_impl
+ self.attn_bias_type = config.attn_bias_type
+ self.init_std_factor = config.init_std_factor
+ self.max_seqlen = config.max_seqlen
+ self.rope_embeddings = RotaryEmbedding(
+ theta=config.rope_theta,
+ head_dim=config.head_dim or config.dim // config.n_heads,
+ max_seqlen=config.max_seqlen,
+ rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product,
+ )
+ # Handle both eos_id and eos_token_id for compatibility
+ self.eos_id = getattr(config, 'eos_id', getattr(config, 'eos_token_id', 2))
+
+ # Create parameter dict for BLTTransformerLayers
+ layer_params = {
+ 'dim': self.dim,
+ 'n_heads': config.n_heads,
+ 'head_dim': config.head_dim,
+ 'n_kv_heads': getattr(config, 'n_kv_heads', None),
+ 'rope_theta': config.rope_theta,
+ 'multiple_of': getattr(config, 'multiple_of', 256),
+ 'ffn_dim_multiplier': getattr(config, 'ffn_dim_multiplier', None),
+ 'norm_eps': config.norm_eps,
+ }
+
+ self.layers = nn.ModuleList()
+ for _ in range(config.n_layers):
+ self.layers.append(BLTTransformerLayer(layer_params))
+
+ # GlobalTransformer specific attributes
+ self.dropout = config.dropout
+ self.dim_token_emb = config.dim_token_emb
+
+ self.token_embedding_projection = None
+ if config.dim_token_emb is not None and config.dim_token_emb != self.dim:
+ self.token_embedding_projection = nn.Linear(
+ config.dim_token_emb,
+ config.dim,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ tok_idx: Optional[torch.Tensor] = None,
+ embeds: Optional[torch.Tensor] = None,
+ mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+ ):
+ bs, seqlen = tokens.shape
+
+ h = embeds
+
+ mask = (
+ mask
+ if mask is not None
+ else create_causal_mask(
+ seqlen,
+ self.attn_impl,
+ self.attn_bias_type,
+ tokens=tokens,
+ eos_id=self.eos_id,
+ )
+ )
+
+ if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
+ h = self.token_embedding_projection(h)
+
+ h = F.dropout(h, p=self.dropout, training=self.training)
+
+ freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
+
+ for i, layer in enumerate(self.layers):
+ h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
+
+ return h, cache
+
+ def init_weights(self):
+ self.rope_embeddings.reset_parameters()
+ for depth, layer in enumerate(self.layers):
+ factor = self.config.get_init_std_factor(depth)
+ layer.init_weights(self.init_base_std, factor)
+
+ # GlobalTransformer specific initialization
+        if self.token_embedding_projection is not None:
+            std = self.dim_token_emb ** (-0.5)
+            nn.init.trunc_normal_(
+ self.token_embedding_projection.weight,
+ mean=0.0,
+ std=std,
+ a=-3 * std,
+ b=3 * std,
+ )
+
+def compute_hash_embeddings(
+ local_encoder_tokens: torch.Tensor,
+ local_encoder,
+ encoder_hash_tok_embedding: nn.ModuleList,
+ encoder_hash_byte_group_nb_functions: int,
+ encoder_hash_byte_group_size: list,
+ encoder_hash_byte_group_vocab: int,
+) -> torch.Tensor:
+ """
+ Compute embeddings using hash token embeddings.
+
+ Args:
+ local_encoder_tokens: Input tokens tensor
+ local_encoder: Encoder object with tok_embeddings method
+ encoder_hash_tok_embedding: ModuleList of hash token embeddings
+ encoder_hash_byte_group_nb_functions: Number of hash functions
+ encoder_hash_byte_group_size: List of byte group sizes
+ encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings
+
+ Returns:
+ torch.Tensor: Combined embeddings
+ """
+ if encoder_hash_tok_embedding is None:
+ return None
+
+ local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens)
+
+ i = 0
+ for func_nb in range(encoder_hash_byte_group_nb_functions):
+ for byte_group_size in encoder_hash_byte_group_size:
+ hash_ids = byte_group_hash_function(
+ local_encoder_tokens,
+ byte_group_size,
+ hash_func_nb=func_nb,
+ max_hash=encoder_hash_byte_group_vocab,
+ )
+ hash_tok_embedding = encoder_hash_tok_embedding[i]
+ local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
+ i += 1
+
+ assert i == len(encoder_hash_tok_embedding)
+ return local_encoder_embeds
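+
+# Illustrative note (assumed numbers): the loop above consumes one embedding table
+# per (hash function, byte-group size) pair, so with
+# encoder_hash_byte_group_nb_functions = 2 and encoder_hash_byte_group_size = [3, 4]
+# the ModuleList must contain 2 * 2 = 4 embeddings, which the final assert checks.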
+
+
+class BLTPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ BLT models.
+
+ This class provides the interface for model loading, saving, and weight initialization for all BLT model variants.
+ It inherits from [`PreTrainedModel`] which provides the core functionality for working with HuggingFace models.
+
+ Args:
+ config ([`BLTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+ config_class = BLTConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["BLTTransformerLayer", "LocalEncoder", "LocalDecoder", "GlobalTransformer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = False # BLT uses its own attention implementation
+ _supports_sdpa = True
+ _supports_cache_class = False
+
+ def _init_weights(self, module):
+ """Initialize the weights - this is called by PreTrainedModel but we delegate to our custom init"""
+ # Don't do anything here - we use the custom init_weights method instead
+ pass
+
+
+class BLTModel(BLTPreTrainedModel):
+ """
+ The BLTModel (BLT) is a byte-level language model architecture that processes byte sequences
+ by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
+ and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for
+ improved performance and inference efficiency.
+ """
+
+ def __init__(self, config: BLTConfig):
+ super().__init__(config)
+
+ # Store config reference
+ self.config = config
+
+ # Create main components - they will read their parameters from config
+ self.local_encoder = LocalEncoder(config)
+
+ # Create global-specific config by copying config and overriding dimensions
+ global_config = type(config)(**config.to_dict())
+ global_config.dim = config.dim_global
+ global_config.n_layers = config.n_layers_global
+ global_config.n_heads = config.n_heads_global
+ global_config.n_kv_heads = config.n_kv_heads_global
+ global_config.dim_token_emb = config.global_dim_patch_emb
+
+ self.global_transformer = GlobalTransformer(global_config)
+ self.local_decoder = LocalDecoder(config)
+
+ # Initialize hash embeddings
+ self.encoder_hash_tok_embedding = init_hash_embeddings(
+ config,
+ local_encoder_dim=self.local_encoder.dim,
+ encoder_hash_byte_group_size=config.encoder_hash_byte_group_size,
+ )
+
+ # Initialize patcher if needed
+ if config.patch_in_forward:
+ if config.realtime_patching and config.entropy_model_checkpoint_dir is not None:
+ # Load entropy model directly
+ entropy_model_checkpoint_dir = config.entropy_model_checkpoint_dir
+
+ if not os.path.exists(entropy_model_checkpoint_dir):
+ raise FileNotFoundError(f"Entropy model checkpoint directory not found: {entropy_model_checkpoint_dir}")
+
+ # Load entropy model parameters
+ params_path = os.path.join(entropy_model_checkpoint_dir, "params.json")
+ if not os.path.exists(params_path):
+ raise FileNotFoundError(f"params.json not found in: {entropy_model_checkpoint_dir}")
+
+ with open(params_path) as fr:
+ reloaded = json.loads(fr.read())
+
+                # NOTE: this changes the process-wide default dtype, not only the patcher
+                torch.set_default_dtype(torch.bfloat16)
+ model_params = reloaded["entropy_model"]
+ logger.warning(
+ "Update checkpoint to load attn and sliding window args from checkpoint"
+ )
+
+ # Override patcher configuration with actual entropy model parameters from checkpoint
+ config.patcher_dim = model_params["dim"]
+ config.patcher_n_layers = model_params["n_layers"]
+ config.patcher_n_heads = model_params["n_heads"]
+ config.patcher_max_seqlen = model_params["max_seqlen"]
+ config.patcher_ffn_dim_multiplier = model_params["ffn_dim_multiplier"]
+ config.patcher_vocab_size = model_params["vocab_size"]
+ # Use sensible defaults for parameters not in checkpoint
+ config.patcher_attn_bias_type = "local_block_causal"
+ config.patcher_attn_impl = "sdpa" # originally xformers
+ config.patcher_sliding_window = 512
+
+ # BLTPatcher will extract patcher_ parameters from config directly
+ self.patcher = BLTPatcher(config)
+
+ state_path = os.path.join(
+ entropy_model_checkpoint_dir, "consolidated.pth"
+ )
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ self.patcher.load_state_dict(
+ torch.load(state_path, map_location=device)["model"], strict=False
+ )
+ self.patcher.to(device)
+ self.patcher = self.patcher.eval()
+ # no grads for the model:
+ for param in self.patcher.parameters():
+ param.requires_grad = False
+ else:
+ self.patcher = None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
+ """
+ Convert patch lengths to patch IDs for each token position.
+
+ For each token position in the sequence, determines which patch it belongs to.
+
+ Args:
+ patch_lengths: [batch_size, num_patches] - length of each patch
+ seq_len: total sequence length
+
+ Returns:
+ patch_ids: [batch_size, seq_len] - patch index for each token position
+
+ Example:
+ patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1
+ seq_len = 10
+ Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]]
+ # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3
+ """
+ batch_size, num_patches = patch_lengths.shape
+
+ # Create patch start positions: [0, 3, 5, 9] for the example above
+ patch_starts = torch.cat([
+ torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
+ patch_lengths.cumsum(dim=-1)[:, :-1] # cumsum without the final total
+ ], dim=-1)
+
+ # For each token position, find which patch it belongs to
+ # by finding the rightmost patch start that's <= the position
+ token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1]
+
+ # Broadcasting: patch_starts[batch, patch] <= token_positions[position]
+ # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t
+ position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)
+
+ # Count how many patch starts are <= each position, then subtract 1 to get patch index
+ patch_ids = position_ge_patch_start.sum(dim=-1) - 1
+
+ return patch_ids
+
+ def _decoder_patch_ids_from_lengths(self, patch_lengths: torch.Tensor, nb_boe: int, seq_len: int) -> torch.Tensor:
+ """
+ Create decoder patch IDs by skipping the first encoder patch.
+
+ The decoder starts after the first patch (which contains BOE tokens),
+ so we need to map decoder positions to the remaining patches.
+
+ Args:
+ patch_lengths: [batch_size, num_patches] from encoder
+ nb_boe: number of beginning-of-example tokens in first patch
+ seq_len: decoder sequence length
+
+ Returns:
+ decoder_patch_ids: [batch_size, seq_len] mapping decoder positions to patch indices
+ """
+ # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens)
+ decoder_patch_lengths = patch_lengths[:, 1:]
+
+ # Create patch IDs for the decoder sequence using the remaining patches
+ return self._patch_ids_from_lengths(decoder_patch_lengths, seq_len)
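+
+    # Illustrative example (assumed values): with patch_lengths = [[3, 2, 4, 1]]
+    # the decoder uses patches [2, 4, 1]; for a decoder sequence of length 7 this
+    # yields [[0, 0, 1, 1, 1, 1, 2]], i.e. positions 0-1 map to encoder patch 1,
+    # positions 2-5 to patch 2 and position 6 to patch 3.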
+
+
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ patch_lengths: Optional[torch.Tensor] = None,
+ ):
+ # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings
+ # are no longer used in the final BLT model
+
+ bs, N = tokens.shape # Batch size and sequence length
+
+ # Get megabyte inputs
+ nb_boe = int(0 if self.config.patching_mode != "" else self.config.patch_size - 1)
+ local_encoder_tokens, _, local_decoder_tokens = get_blt_input(
+ tokens=tokens,
+ enforce_patch_size_multiple=False,
+ nb_boe=nb_boe,
+ patch_size=self.config.patch_size,
+ boe_id=BOE_ID,
+ )
+
+ # Patching
+ if patch_lengths is None:
+ # assert (
+ # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward
+ # ), "Patch in forward not enabled and no patch_lengths passed."
+
+ # PATCHER MODEL DEFINED
+ if self.config.patching_mode == PatchingModeEnum.entropy:
+ _, patch_lengths, _ = self.patcher(
+ local_encoder_tokens,
+ patch_size=self.config.patch_size,
+ include_next_token=True,
+ threshold=self.config.patching_threshold,
+ threshold_add=self.config.patching_threshold_add,
+ monotonicity=self.config.monotonicity,
+ max_patch_length=self.config.max_patch_length,
+ patching_batch_size=self.config.patching_batch_size,
+ device=self.config.patching_device,
+ )
+ else:
+ # self.config.patching_mode == PatchingModeEnum.byte
+ bs, seq_len = local_encoder_tokens.shape
+ seq_len_next_tok = seq_len + 1 # include_next_token=True
+ patch_lengths = torch.ones(
+ (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device
+ )
+
+ # Apply any processing to patch lengths
+ if self.config.max_patch_length is not None:
+ # TODO: avoid going back to a list here.
+ patch_lengths = [
+ BLTPatcher.split_large_numbers(pl, self.config.max_patch_length)
+ for pl in patch_lengths.tolist()
+ ]
+ max_len = max([len(pl) for pl in patch_lengths])
+ patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
+ patch_lengths = torch.tensor(
+ patch_lengths, dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device
+ )
+ assert not check_non_zero_after_zero(patch_lengths)
+ # Find the last non-zero column index using argmax on a reversed version of the tensor
+ last_non_zero_col_reversed = (
+ (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
+ )
+ # Slice the tensor up to the last non-zero column
+ patch_lengths = patch_lengths[
+ :, : patch_lengths.shape[1] - last_non_zero_col_reversed
+ ]
+ else:
+ if nb_boe > 0:
+ patch_lengths[:, 0] += nb_boe
+
+ assert torch.min(patch_lengths) >= 0
+
+ # Generate patch IDs from patch_lengths
+ patch_ids = self._patch_ids_from_lengths(
+ patch_lengths, local_encoder_tokens.shape[-1]
+ )
+ assert torch.max(patch_ids) + 1 <= torch.max(
+ (patch_lengths != 0).sum(dim=-1)
+ ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"
+
+ cross_attn_mask_enc = None
+ # Cross-attention encoder
+ if self.config.cross_attn_encoder:
+ cross_attn_mask_enc = cross_attn_mask(
+ patch_ids,
+ patch_lengths,
+ N,
+ patches_as_queries=True,
+ cross_attn_k=self.config.cross_attn_k,
+ window=self.config.cross_attn_window_encoder,
+ block_mask=self.config.cross_attn_use_flex_attention,
+ )
+
+ # Hashing and embedding
+ local_encoder_embeds = compute_hash_embeddings(
+ local_encoder_tokens=local_encoder_tokens,
+ local_encoder=self.local_encoder,
+ encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
+ encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions,
+ encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size,
+ encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab,
+ )
+
+ # NOTE: Frequency-based n-gram embeddings removed as per paper
+ # The final BLT model uses only hash-based n-gram embeddings
+
+ # Local encoder
+ (h_encoder, h_cross), cache_encoder = self.local_encoder(
+ tokens=local_encoder_tokens,
+ embeds=local_encoder_embeds,
+ patch_embeds=None,
+ cross_mask=cross_attn_mask_enc,
+ num_patches=patch_lengths.shape[1],
+ patch_ids=patch_ids,
+ )
+
+ # Downsampling
+ h = h_cross.view(bs, patch_lengths.shape[1], -1)
+
+ # Global transformer
+ global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(BOE_ID)
+ rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id)
+ eos_patch_ids = patch_ids[rows, cols]
+ global_tokens[rows, eos_patch_ids] = self.config.eos_token_id
+
+ h, _ = self.global_transformer(
+ embeds=h,
+ tokens=global_tokens,
+ )
+
+ # Unpatching
+ dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :]
+
+ # Generate decoder patch IDs
+ decoder_patch_ids = self._decoder_patch_ids_from_lengths(
+ patch_lengths, nb_boe, local_decoder_tokens.shape[-1]
+ )
+ assert (
+ torch.max(decoder_patch_ids) + 1 <= h.shape[1]
+ ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
+ assert (
+ decoder_patch_ids.shape[1] == dec_embeds.shape[1]
+ ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"
+
+ # Cross-attention decoder
+ if not self.config.cross_attn_decoder:
+ h = torch.gather(
+ h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
+ )
+ cross_attn_mask_dec = None
+ assert local_decoder_tokens.shape == h.shape[:-1]
+ else:
+ cross_attn_mask_dec = cross_attn_mask(
+ decoder_patch_ids,
+ patch_lengths,
+ N,
+ patches_as_queries=False,
+ cross_attn_k=self.config.cross_attn_k,
+ window=self.config.cross_attn_window_decoder,
+ block_mask=self.config.cross_attn_use_flex_attention,
+ )
+
+ # Local decoder
+ output, _ = self.local_decoder(
+ embeds=dec_embeds,
+ patch_embeds=h,
+ tokens=local_decoder_tokens,
+ cross_mask=cross_attn_mask_dec,
+ )
+ return output
+
+ def init_weights(self):
+ self.local_encoder.init_weights()
+ self.global_transformer.init_weights()
+ self.local_decoder.init_weights()
+
+ if self.encoder_hash_tok_embedding is not None:
+ emb_std = self.local_encoder.dim ** (-0.5)
+ for emb in self.encoder_hash_tok_embedding:
+ nn.init.trunc_normal_(
+ emb.weight,
+ mean=0.0,
+ std=emb_std,
+ a=-3 * emb_std,
+ b=3 * emb_std,
+ )
+
+
+class BLTPatcher(BLTPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ # Store config reference for later use
+ self.config = config
+
+ # Extract patcher parameters from BLTConfig
+ self.dim = config.patcher_dim
+ self.init_base_std = config.patcher_init_base_std
+ self.attn_impl = config.patcher_attn_impl
+ self.attn_bias_type = config.patcher_attn_bias_type
+ self.init_std_factor = config.patcher_init_std_factor
+ self.max_seqlen = config.patcher_max_seqlen
+ n_layers = config.patcher_n_layers
+ n_heads = config.patcher_n_heads
+ head_dim = config.patcher_head_dim
+ rope_theta = config.patcher_rope_theta
+ rope_use_fp32_in_outer_product = config.patcher_rope_use_fp32_in_outer_product
+ norm_eps = config.patcher_norm_eps
+ vocab_size = config.patcher_vocab_size
+ weight_tying = config.patcher_weight_tying
+ sliding_window = config.patcher_sliding_window
+ eos_token_id = config.patcher_eos_token_id
+
+ self.rope_embeddings = RotaryEmbedding(
+ theta=rope_theta,
+ head_dim=head_dim or self.dim // n_heads,
+ max_seqlen=self.max_seqlen,
+ rope_use_fp32_in_outer_product=rope_use_fp32_in_outer_product,
+ )
+ # Handle both eos_id and eos_token_id for compatibility
+ self.eos_id = eos_token_id
+
+ # Extract additional parameters for BLTTransformerLayer
+ n_kv_heads = getattr(config, 'patcher_n_kv_heads', None) if hasattr(config, 'patcher_dim') else getattr(config, 'n_kv_heads', None)
+ multiple_of = getattr(config, 'patcher_multiple_of', 256) if hasattr(config, 'patcher_dim') else getattr(config, 'multiple_of', 256)
+ ffn_dim_multiplier = getattr(config, 'patcher_ffn_dim_multiplier', None) if hasattr(config, 'patcher_dim') else getattr(config, 'ffn_dim_multiplier', None)
+
+ # Create a simple parameter dict for BLTTransformerLayer
+ layer_params = {
+ 'dim': self.dim,
+ 'n_heads': n_heads,
+ 'head_dim': head_dim,
+ 'n_kv_heads': n_kv_heads,
+ 'rope_theta': rope_theta,
+ 'multiple_of': multiple_of,
+ 'ffn_dim_multiplier': ffn_dim_multiplier,
+ 'norm_eps': norm_eps,
+ }
+
+ self.layers = nn.ModuleList()
+ for _ in range(n_layers):
+ self.layers.append(BLTTransformerLayer(layer_params))
+
+ # LMTransformer specific attributes
+ self.weight_tying = weight_tying
+ self.sliding_window = sliding_window
+
+ assert vocab_size > 0
+
+ self.tok_embeddings = torch.nn.Embedding(vocab_size, self.dim)
+
+ self.norm = RMSNorm(self.dim, eps=norm_eps)
+
+ self.output = nn.Linear(
+ self.dim,
+ vocab_size,
+ bias=False,
+ )
+
+ if self.weight_tying:
+ self.output.weight = self.tok_embeddings.weight
+
+ def forward(
+ self,
+ token_values: torch.Tensor,
+ target: Optional[torch.Tensor] = None,
+ tok_idx: Optional[torch.Tensor] = None,
+ mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
+ attn_impl: str | None = None,
+ patch_size: Optional[int] = None,
+ include_next_token: bool = True,
+ threshold: Optional[float] = None,
+ threshold_add: Optional[float] = None,
+ monotonicity: bool = False,
+ max_patch_length: Optional[int] = None,
+        patching_batch_size: int = 1,
+ device: Optional[str] = None,
+ enable_grad: bool = False,
+ ):
+ attn_impl = self.attn_impl if attn_impl is None else attn_impl
+
+ # Handle chunked processing for entropy calculation
+ # grad_context = nullcontext() if enable_grad else torch.no_grad()
+ # with grad_context:
+ entropies = []
+ preds = []
+ max_length = min(getattr(self, "max_length", 8192), self.max_seqlen)
+ batch_numel = max_length * patching_batch_size
+ splits = torch.split(token_values.flatten(), batch_numel)
+
+ for split in splits:
+ pad_size = (max_length - (split.numel() % max_length)) % max_length
+ pad = torch.zeros(
+ pad_size, dtype=split.dtype, device=split.device, requires_grad=False
+ )
+ split = torch.cat((split, pad), dim=0)
+ split = split.reshape(-1, max_length)
+ if device is not None:
+ split = split.to(device)
+
+ # Process chunk: embeddings -> layers -> output
+ bsz, seqlen = split.shape
+ h = self.tok_embeddings(split)
+ chunk_mask = create_causal_mask(
+ seqlen,
+ attn_impl,
+ self.attn_bias_type,
+ sliding_window=self.sliding_window,
+ tokens=split,
+ eos_id=self.eos_id,
+ )
+ freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None)
+
+ for i, layer in enumerate(self.layers):
+ h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=attn_impl)
+
+ pred = self.output(self.norm(h))
+ pred = pred.reshape(-1, pred.shape[-1])[
+ : split.numel() - pad_size, :
+ ] # [batch_size * seq_len, vocab]
+ preds.append(pred)
+ pred_entropies = self.entropy(pred)
+ entropies.append(pred_entropies)
+
+ concat_entropies = torch.cat(entropies, dim=0)
+ concat_entropies = concat_entropies.reshape(token_values.shape)
+ concat_preds = torch.cat(preds, dim=0)
+ concat_preds = concat_preds.reshape(token_values.shape[0], -1)
+
+ # Always compute patch lengths from concatenated entropies
+ bs, seq_len = token_values.shape
+ seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
+
+ # Find patch start IDs based on entropy
+ if patch_size is not None:
+ patch_start_ids = self.find_entropy_patch_start_ids(
+ concat_entropies,
+ patch_size,
+ include_next_token=include_next_token,
+ threshold=threshold,
+ threshold_add=threshold_add,
+ monotonicity=monotonicity,
+ )
+ patch_lengths = self.patch_lengths_from_start_ids(
+ patch_start_ids, seq_len_next_tok
+ )
+ else:
+ # Default to byte-level patching
+ patch_lengths = torch.ones(
+ (bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device
+ )
+
+ # Apply any processing to patch lengths
+ if max_patch_length is not None:
+ # TODO: avoid going back to a list here.
+ patch_lengths = [
+ self.split_large_numbers(pl, max_patch_length)
+ for pl in patch_lengths.tolist()
+ ]
+ max_len = max([len(pl) for pl in patch_lengths])
+ patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
+ patch_lengths = torch.tensor(
+ patch_lengths, dtype=token_values.dtype, device=token_values.device
+ )
+ assert not check_non_zero_after_zero(patch_lengths)
+ # Find the last non-zero column index using argmax on a reversed version of the tensor
+ last_non_zero_col_reversed = (
+ (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
+ )
+ # Slice the tensor up to the last non-zero column
+ patch_lengths = patch_lengths[
+ :, : patch_lengths.shape[1] - last_non_zero_col_reversed
+ ]
+
+ return concat_entropies, patch_lengths, concat_preds
+
+ def reset_parameters(self, init_std=None):
+ self.norm.reset_parameters()
+
+ def init_weights(self):
+ self.reset_parameters()
+ init_std = self.dim ** (-0.5)
+ nn.init.trunc_normal_(
+ self.tok_embeddings.weight,
+ mean=0.0,
+ std=init_std,
+ a=-3 * init_std,
+ b=3 * init_std,
+ )
+
+ self.rope_embeddings.reset_parameters()
+ for depth, layer in enumerate(self.layers):
+ factor = self.config.get_init_std_factor(depth)
+ layer.init_weights(self.init_base_std, factor)
+
+ if not self.weight_tying:
+ nn.init.trunc_normal_(
+ self.output.weight,
+ mean=0.0,
+ std=init_std,
+ a=-3 * init_std,
+ b=3 * init_std,
+ )
+
+ @staticmethod
+ def entropy(scores):
+ """
+ scores: [bs, seq_len, vocab]
+ returns [bs, seq_len]
+
+ Computes the entropy for each token in the batch.
+ Note: uses natural log.
+ """
+ log_probs = F.log_softmax(scores, dim=-1)
+ probs = torch.exp(log_probs)
+ p_log_p = log_probs * probs
+ entropy = -p_log_p.sum(dim=-1)
+ return entropy
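+
+    # Illustrative check (assumed vocabulary size): a uniform distribution over V
+    # candidates has entropy ln(V), so uniform logits over a 260-entry byte
+    # vocabulary (256 bytes + 4 special offsets) give roughly ln(260) ≈ 5.56 nats.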
+
+
+
+ @staticmethod
+ def patch_start_ids_from_patch_start_mask(patch_start_mask):
+ bs, trunc_seq_len = patch_start_mask.shape
+ max_patches = patch_start_mask.sum(dim=1).max()
+ if max_patches == 0:
+ patch_start_ids = torch.full(
+ (bs, trunc_seq_len),
+ trunc_seq_len,
+ dtype=torch.long,
+ device=patch_start_mask.device,
+ )
+ else:
+ patch_ids = (
+ torch.arange(trunc_seq_len, device=patch_start_mask.device)
+ .unsqueeze(0)
+ .repeat(bs, 1)
+ )
+ extra_patch_ids = torch.full(
+ (bs, trunc_seq_len),
+ trunc_seq_len,
+ dtype=torch.long,
+ device=patch_start_mask.device,
+ )
+ all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
+ patch_start_mask_padded = torch.cat(
+ (patch_start_mask, ~patch_start_mask), dim=1
+ )
+ patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(
+ bs, trunc_seq_len
+ )[:, :max_patches]
+ return patch_start_ids
+
+ @staticmethod
+ def patch_lengths_from_start_ids(patch_start_ids, seq_len):
+ """
+ Calculate patch lengths from start ids.
+        patch_start_ids: the patch start ids, right-padded with seq_len up to the
+        maximum number of patches, e.g. [0, 1, 7, 7, 7, 7, 7] (real start ids 0 and 1).
+        seq_len: the sequence length, e.g. 7.
+
+        Returns the patch lengths, [1, 6, 0, 0, 0, 0, 0] for the example above
+        (trailing zeros correspond to the padded start ids).
+ """
+ last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
+ patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
+ patch_lengths = patch_end_ids - patch_start_ids + 1
+ assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
+ assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
+ return patch_lengths
+
+ @staticmethod
+ def find_entropy_patch_start_ids(
+ entropies,
+ patch_size=None,
+ threshold=None,
+ threshold_add=None,
+ monotonicity=False,
+ include_next_token=True,
+ ):
+ """
+ Use entropies to find the start ids of each patch.
+ Use patch_size or threshold to figure out the total number of patches to allocate.
+
+ When threshold is not None the number of patches is not constant between
+ different sequences, but patches can be identified incrementally rather than
+ decided globally using the entire sequence.
+ """
+ bs, seq_len = entropies.shape[:2]
+
+ first_ids = (
+ torch.tensor([0, 1], dtype=torch.long, device=entropies.device)
+ .unsqueeze(0)
+ .repeat(bs, 1)
+ )
+        # drop the first predictions because they are forced to be patch starts
+        preds_truncation_len = first_ids.shape[1]
+ entropies = entropies[:, 1:]
+ if threshold is None:
+ num_patches = seq_len // patch_size
+ patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices
+ patch_start_ids = patch_start_ids.sort(dim=1).values
+ else:
+ patch_start_mask = entropies > threshold
+ if not include_next_token:
+ patch_start_mask = patch_start_mask[:, :-1]
+ # patch_start_mask[1:] |= tokens[:-1] < OFFSET
+ patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask)
+
+ patch_start_ids = torch.cat(
+ (first_ids, patch_start_ids + preds_truncation_len), dim=1
+ )
+ return patch_start_ids
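+
+    # Illustrative example (assumed values): with
+    # entropies = [[0.5, 2.0, 0.3, 3.0, 0.2]] and threshold = 1.0, the first column
+    # is dropped, positions 0 and 2 of the remainder exceed the threshold, and after
+    # adding the truncation offset and the two forced start ids the result is
+    # patch_start_ids = [[0, 1, 2, 4]].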
+
+ @staticmethod
+ def split_large_numbers(lst, m):
+ new_lst = []
+ for i in lst:
+ if i > m:
+ while i > m:
+ new_lst.append(m)
+ i -= m
+ new_lst.append(i)
+ else:
+ new_lst.append(i)
+ assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}"
+ return new_lst
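+
+    # Illustrative example (assumed values):
+    #   split_large_numbers([2, 9, 3], 4) -> [2, 4, 4, 1, 3]
+    # Every entry larger than m is split into chunks of size m plus a remainder,
+    # and the total sum is preserved (checked by the assert above).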
+
+
+def init_hash_embeddings(
+ config,
+ local_encoder_dim: int,
+ encoder_hash_byte_group_size: list,
+):
+ """Initialize hash-based token embeddings for the BLT encoder."""
+ if config.encoder_hash_byte_group_size is None:
+ return None
+
+ embeddings = []
+ emb_dim = local_encoder_dim
+ encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab
+
+ for _ in range(config.encoder_hash_byte_group_nb_functions):
+ for _ in encoder_hash_byte_group_size:
+ embeddings.append(
+ nn.Embedding(
+ encoder_hash_byte_group_vocab,
+ emb_dim,
+ )
+ )
+
+ return nn.ModuleList(embeddings)
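+
+# Illustrative note (assumed numbers): with encoder_hash_byte_group_nb_functions = 2
+# and encoder_hash_byte_group_size = [3, 4], this builds 2 * 2 = 4 embeddings of
+# shape (encoder_hash_byte_group_vocab, local_encoder_dim), matching the count
+# expected by compute_hash_embeddings.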
+
+
+__all__ = [
+ "BLTPreTrainedModel",
+ "BLTModel",
+ "BLTPatcher",
+ "LocalEncoder",
+ "LocalDecoder",
+ "GlobalTransformer",
+]
\ No newline at end of file
diff --git a/backup_blt_wip_backup/tokenization_blt.py b/backup_blt_wip_backup/tokenization_blt.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf57143de5dd7d594c60b05660f03548dd60689b
--- /dev/null
+++ b/backup_blt_wip_backup/tokenization_blt.py
@@ -0,0 +1,273 @@
+# coding=utf-8
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tokenization classes for BLT."""
+
+import os
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+if TYPE_CHECKING:
+ from ...tokenization_utils_base import TextInput
+
+logger = logging.get_logger(__name__)
+
+# BLT tokenizer constants
+SEP = " "
+BOS_ID: int = 1
+EOS_ID: int = 2
+PAD_ID: int = -1
+BOE_ID: int = 0
+BPE_ID: int = 3
+OFFSET: int = 4
+BYTE_UNITS: int = 256
+
+VOCAB_FILES_NAMES = {} # BLT doesn't require external vocab files
+
+
+class BLTTokenizer(PreTrainedTokenizer):
+ """
+ Construct a BLT tokenizer. Based on byte-level tokenization where each byte is treated as a token.
+
+ This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with an offset.
+ It supports special tokens for beginning of sequence (BOS), end of sequence (EOS),
+ beginning of example (BOE), and padding (PAD).
+
+ Args:
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`):
+ The beginning of sequence token.
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`):
+ The end of sequence token.
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`):
+ The padding token.
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`):
+ The unknown token. Not used in BLT but kept for compatibility.
+ boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`):
+ The beginning of example token, specific to BLT.
+ add_bos_token (`bool`, *optional*, defaults to `True`):
+ Whether or not to add a `bos_token` at the start of sequences.
+ add_eos_token (`bool`, *optional*, defaults to `True`):
+ Whether or not to add an `eos_token` at the end of sequences.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+ Whether or not to cleanup spaces after decoding.
+ spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to add spaces between special tokens.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ bos_token="",
+ eos_token="",
+ pad_token="",
+ unk_token="",
+ boe_token="",
+ add_bos_token=True,
+ add_eos_token=True,
+ clean_up_tokenization_spaces=False,
+ spaces_between_special_tokens=False,
+ **kwargs,
+ ):
+ # Store BLT-specific parameters first
+ self.add_bos_token = add_bos_token
+ self.add_eos_token = add_eos_token
+ self.vocab_size_unit_1 = BYTE_UNITS
+ self.offsetting_special_char = OFFSET
+
+ # BLT token IDs (exactly like original)
+ self.boe_id = BOE_ID
+ self.bos_id = BOS_ID
+ self.eos_id = EOS_ID
+ self.pad_id = PAD_ID
+ self.bpe_id = BPE_ID
+ self.n_words = self.vocab_size_unit_1 + self.offsetting_special_char
+
+ # Convert string tokens to AddedToken objects
+ bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+ eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+ pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
+ unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+ self.boe_token = AddedToken(boe_token, normalized=False, special=True) if isinstance(boe_token, str) else boe_token
+
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ unk_token=unk_token,
+ add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ spaces_between_special_tokens=spaces_between_special_tokens,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self):
+ """Returns vocab size"""
+ return self.vocab_size_unit_1 + self.offsetting_special_char
+
+ def get_vocab(self):
+ """Returns vocab as a dict"""
+ # Create a mapping for byte values + offset
+ vocab = {}
+
+ # Add special tokens (with defensive checks)
+ if hasattr(self, 'bos_token'):
+ vocab[str(self.bos_token)] = self.bos_id
+ if hasattr(self, 'eos_token'):
+ vocab[str(self.eos_token)] = self.eos_id
+ if hasattr(self, 'pad_token'):
+ vocab[str(self.pad_token)] = self.pad_id
+ if hasattr(self, 'boe_token'):
+ vocab[str(self.boe_token)] = self.boe_id
+
+ # Add byte tokens as string representations of byte values
+ vocab_size_unit_1 = getattr(self, 'vocab_size_unit_1', BYTE_UNITS)
+ offsetting_special_char = getattr(self, 'offsetting_special_char', OFFSET)
+ for i in range(vocab_size_unit_1):
+ vocab[str(i)] = i + offsetting_special_char
+
+ # Add any additional tokens if available
+ if hasattr(self, 'added_tokens_encoder'):
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
+ """
+ Converts a string to a list of tokens. For BLT, we work directly with byte values.
+ Returns a list of strings that represent the byte values.
+ """
+ # Convert text to UTF-8 bytes, just like the original
+        bytes_data = text.encode("utf-8", errors="ignore")
+
+ # Return string representations of byte values for the tokenizer framework
+ return [str(byte_val) for byte_val in bytes_data]
+
+ def _convert_token_to_id(self, token: str) -> int:
+ """Converts a token (str) to an id using the vocab."""
+ # Handle special tokens
+ if token == str(self.bos_token):
+ return self.bos_id
+ elif token == str(self.eos_token):
+ return self.eos_id
+ elif token == str(self.pad_token):
+ return self.pad_id
+ elif token == str(self.boe_token):
+ return self.boe_id
+ else:
+ try:
+ # Convert byte value string to int and add offset (like original)
+ byte_val = int(token)
+ if 0 <= byte_val <= 255:
+ return byte_val + self.offsetting_special_char
+ except ValueError:
+ pass
+
+ # Check if it's in added tokens
+ return self.added_tokens_encoder.get(token, self.unk_token_id)
+
+ def _convert_id_to_token(self, index: int) -> str:
+ """Converts an index (integer) to a token (str) using the vocab."""
+ # Handle special tokens
+ if index == self.bos_id:
+ return str(self.bos_token)
+ elif index == self.eos_id:
+ return str(self.eos_token)
+ elif index == self.pad_id:
+ return str(self.pad_token)
+ elif index == self.boe_id:
+ return str(self.boe_token)
+ elif index >= self.offsetting_special_char and index < self.vocab_size:
+ # Convert back to byte value (like original)
+ byte_val = index - self.offsetting_special_char
+ return str(byte_val)
+ else:
+ # Check added tokens
+ for token, token_id in self.added_tokens_encoder.items():
+ if token_id == index:
+ return token
+ return str(self.unk_token)
+
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
+ """Converts a sequence of tokens to a single string."""
+ byte_values = []
+
+ for token in tokens:
+ # Skip special tokens
+ if token in [str(self.bos_token), str(self.eos_token), str(self.pad_token), str(self.boe_token)]:
+ continue
+
+ try:
+ # Convert token back to byte value (like original decode method)
+ byte_val = int(token)
+ if 0 <= byte_val <= 255:
+ byte_values.append(byte_val)
+ except ValueError:
+ continue
+
+ # Convert byte values back to string (exactly like original)
+ try:
+ return bytes(byte_values).decode("utf-8", errors="ignore")
+ except (UnicodeDecodeError, ValueError):
+ return ""
+
+ def encode(self, text: str, add_bos: bool | None = None, add_eos: bool | None = None):
+ """
+ Encode text exactly like the original BLT tokenizer.
+ """
+ if add_bos is None:
+ add_bos = self.add_bos_token
+ if add_eos is None:
+ add_eos = self.add_eos_token
+
+ # Since bpe_delim=False, we use the simple byte encoding
+ tokens = bytes(text, encoding="utf-8", errors="ignore")
+
+ # Offsetting (exactly like original)
+ tokens = [int(unit) + self.offsetting_special_char for unit in tokens]
+
+ if add_bos:
+ tokens.insert(0, self.bos_id)
+ if add_eos:
+ tokens.append(self.eos_id)
+
+ return tokens
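+
+    # Illustrative example (assumed input): with the default add_bos/add_eos,
+    # encode("Hi") maps the UTF-8 bytes 72 and 105 through the +4 offset and wraps
+    # them with BOS_ID=1 and EOS_ID=2, giving [1, 76, 109, 2].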
+
+ def decode(self, tokens: list[int], cut_at_eos: bool = False):
+ """
+ Decode tokens exactly like the original BLT tokenizer.
+ """
+ if cut_at_eos:
+ for k, t in enumerate(tokens):
+ if t == self.eos_id:
+ tokens = tokens[: k + 1]
+ break
+ return bytes(
+ [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0]
+ ).decode("utf-8", errors="ignore")
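+
+    # Illustrative example (assumed input): decode([1, 76, 109, 2]) subtracts the
+    # offset, drops ids that fall below zero (the special tokens) and decodes the
+    # remaining bytes [72, 105] back to "Hi".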
+
+ def get_vocab_size(self) -> int:
+ """Get vocab size like the original tokenizer."""
+ return self.vocab_size_unit_1 + self.offsetting_special_char
+
+__all__ = ["BLTTokenizer"]
\ No newline at end of file
diff --git a/backup_blt_wip_backup/tokenizers/__init__.py b/backup_blt_wip_backup/tokenizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7
--- /dev/null
+++ b/backup_blt_wip_backup/tokenizers/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-312.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e437756c33aec5b5372596b16fd60b8732b57de5
Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-312.pyc differ
diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-39.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ca0dcc4ad6d75cac3868dee0b9bf1183c2183a8
Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-39.pyc differ
diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/abstract_tokenizer.cpython-312.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/abstract_tokenizer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..029146f29c2b8c58f2e50bf95159751d1cf21006
Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/abstract_tokenizer.cpython-312.pyc differ
diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/blt_tokenizer.cpython-312.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/blt_tokenizer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a1c2398abd9aa0a8bef2ac61f3ab44161d6cf055
Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/blt_tokenizer.cpython-312.pyc differ
diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/build_tokenizer.cpython-312.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/build_tokenizer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f344f971978a50d4cb123127a0cb65e1c8327472
Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/build_tokenizer.cpython-312.pyc differ
diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-312.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c380523b8c7191a900bd6bce095df4c63ff30d4d
Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-312.pyc differ
diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-39.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b2e7cfd7e1cc6e85e4d11e731235b383230c2149
Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-39.pyc differ
diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/sentence_piece_tokenizer.cpython-312.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/sentence_piece_tokenizer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4a60bb2cf0afaecb335c7b70118047ded4e34ab
Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/sentence_piece_tokenizer.cpython-312.pyc differ
diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/tiktoken_tokenizer.cpython-312.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/tiktoken_tokenizer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7548c986e56ce17009c2d81c63ce2545c5cc1d5
Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/tiktoken_tokenizer.cpython-312.pyc differ
diff --git a/backup_blt_wip_backup/tokenizers/abstract_tokenizer.py b/backup_blt_wip_backup/tokenizers/abstract_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff31d655ae3461ca1d39230e40d72e0c24246fa3
--- /dev/null
+++ b/backup_blt_wip_backup/tokenizers/abstract_tokenizer.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import abc
+
+
+class Tokenizer(abc.ABC):
+ @abc.abstractmethod
+ def encode(self, text: str, add_bos: bool, add_eos: bool):
+ pass
+
+ @abc.abstractmethod
+ def decode(self, tokens: list[int]):
+ pass
+
+ @abc.abstractmethod
+ def get_token_offsets(self, text: str, tokens: list[int] | None = None) -> tuple[list[str], list[int]]:
+ """Return the offsets of the tokens in the original text. Only used for evaluation."""
+ pass
+
+ @abc.abstractmethod
+ def get_vocab_size(self) -> int:
+ pass
diff --git a/backup_blt_wip_backup/tokenizers/blt_tokenizer.py b/backup_blt_wip_backup/tokenizers/blt_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d018ff90ead466a284e59a18ee5839d9f1a3155
--- /dev/null
+++ b/backup_blt_wip_backup/tokenizers/blt_tokenizer.py
@@ -0,0 +1,143 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import re
+
+from .abstract_tokenizer import Tokenizer
+from .sentence_piece_tokenizer import SentencePieceTokenizer
+
+
+SEP = " "
+BOS_ID: int = 1
+EOS_ID: int = 2
+PAD_ID: int = -1
+BOE_ID: int = 0
+BPE_ID: int = 3
+OFFSET: int = 4
+
+BYTE_UNITS: int = 256
+
+
+def convert_to_bytes(s):
+ # check whether the piece is a SentencePiece byte-fallback token of the form <0x00>
+ if re.fullmatch(r"<0x[0-9a-fA-F]+>", s):
+ return bytes.fromhex(s[3:-1])
+ else:
+ return bytes(s, "utf-8", errors="ignore")
+
+
+def text2bytes_bpe_delims(
+ text: str,
+ *,
+ bpe_tokenizer,
+ bpe_id: int,
+ offsetting_special_char: int,
+ add_bos: bool,
+ add_eos: bool,
+):
+ cur_bpe = bpe_tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos)
+ # merge the leading space tokens
+ leading_space_tokens = []
+ other_bpe_tokens = []
+ leading = True
+ for token in cur_bpe:
+ bpe_str = bpe_tokenizer.sp_model.id_to_piece(token)
+ if leading and all(c == "▁" for c in bpe_str):
+ leading_space_tokens.append(bpe_str)
+ else:
+ leading = False
+ other_bpe_tokens.append(bpe_str)
+ cur_bpe_strs = ["".join(leading_space_tokens)] + other_bpe_tokens
+
+ # Remove the '▁' characters
+ bpe_strs = []
+ for i, bpe_str in enumerate(cur_bpe_strs):
+ if len(bpe_strs) <= 1 and all([c == " " for s in bpe_strs for c in s]) and not all(c == "▁" for c in bpe_str):
+ # Remove leading space for the first non-space token.
+ bpe_str = bpe_str.replace("▁", "")
+ elif i == 0 and all(c == "▁" for c in bpe_str):
+ bpe_str = " " * (len(text) - len(text.lstrip(" ")))
+ else:
+ bpe_str = bpe_str.replace("▁", " ")
+ if len(bpe_str) > 0:
+ bpe_strs.append(bpe_str)
+ ex_seq = []
+ # Convert bpe tokens to bytes
+ for s in bpe_strs:
+ byte_chunk = convert_to_bytes(s)
+ proc_chunk = [int(unit) for unit in byte_chunk]
+ ex_seq.extend([bpe_id - offsetting_special_char] + proc_chunk)
+
+ return ex_seq
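+
+# Illustrative sketch (not part of the original file): the delimiter inserted
+# before every BPE chunk is bpe_id - offsetting_special_char == 3 - 4 == -1,
+# which BltTokenizer.encode later shifts back to BPE_ID == 3. For a
+# hypothetical segmentation of "hi there" into ["hi", " there"] the function
+# would return [-1, 104, 105, -1, 32, 116, 104, 101, 114, 101].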
+
+
+class BltTokenizer(Tokenizer):
+ def __init__(
+ self,
+ *,
+ vocab_size_unit_1: int = BYTE_UNITS,
+ bpe_delim: bool = False,
+ bpe_tokenizer_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
+ add_bos: bool = True,
+ add_eos: bool = True,
+ ):
+ self.add_bos = add_bos
+ self.add_eos = add_eos
+ self.vocab_size_unit_1 = vocab_size_unit_1
+ self.boe_id = BOE_ID
+ self.bos_id = BOS_ID
+ self.eos_id = EOS_ID
+ self.pad_id = PAD_ID
+ self.bpe_id = BPE_ID
+ self.bpe_tokenizer_path = bpe_tokenizer_path
+ if bpe_delim:
+ self.bpe_tokenizer = SentencePieceTokenizer(model_path=self.bpe_tokenizer_path)
+ else:
+ self.bpe_tokenizer = None
+ self.bpe_delim = bpe_delim
+ self.offsetting_special_char = OFFSET
+ self.n_words = vocab_size_unit_1 + self.offsetting_special_char
+
+ def get_vocab_size(self) -> int:
+ return self.n_words
+
+ def encode(self, text: str, add_bos: bool | None = None, add_eos: bool | None = None):
+ if add_bos is None:
+ add_bos = self.add_bos
+ if add_eos is None:
+ add_eos = self.add_eos
+
+ if self.bpe_delim:
+ tokens = text2bytes_bpe_delims(
+ text,
+ bpe_tokenizer=self.bpe_tokenizer,
+ bpe_id=self.bpe_id,
+ offsetting_special_char=self.offsetting_special_char,
+ add_bos=False,
+ add_eos=False,
+ )
+ else:
+ tokens = bytes(text, encoding="utf-8", errors="ignore")
+
+ # Offsetting
+ tokens = [int(unit) + self.offsetting_special_char for unit in tokens]
+
+ if add_bos:
+ tokens.insert(0, self.bos_id)
+ if add_eos:
+ tokens.append(self.eos_id)
+
+ return tokens
+
+ def decode(self, tokens: list[int], cut_at_eos: bool = False):
+ if cut_at_eos:
+ for k, t in enumerate(tokens):
+ if t == self.eos_id:
+ tokens = tokens[: k + 1]
+ break
+ return bytes(
+ [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0]
+ ).decode("utf-8", errors="ignore")
+
+ def get_token_offsets(self, text: str, tokens: list[int] | None = None):
+ # Byte-level token offsets are not implemented; they are only needed for
+ # evaluation (see Tokenizer.get_token_offsets).
+ raise NotImplementedError()
diff --git a/backup_blt_wip_backup/tokenizers/sentence_piece_tokenizer.py b/backup_blt_wip_backup/tokenizers/sentence_piece_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fece8cbce7a85d6aabb3e3545cd256d17b947237
--- /dev/null
+++ b/backup_blt_wip_backup/tokenizers/sentence_piece_tokenizer.py
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
+import os
+
+
+try:
+ from sentencepiece import SentencePieceProcessor
+
+ has_sp = True
+except ImportError:
+ has_sp = False
+
+from .abstract_tokenizer import Tokenizer
+
+
+logger = logging.getLogger(__name__)
+
+
+class SentencePieceTokenizer(Tokenizer):
+ def __init__(self, model_path: str, add_bos: bool = True, add_eos: bool = True) -> None:
+ if not has_sp:
+ raise ImportError("SentencePieceTokenizer requires the `sentencepiece` package to be installed.")
+ assert os.path.isfile(model_path), model_path
+ self.sp_model = SentencePieceProcessor(model_file=model_path)
+
+ logger.info(f"Reloaded SentencePiece model from {model_path}")
+
+ # BOS / EOS token IDs
+ self.n_words: int = self.sp_model.vocab_size()
+ self.bos_id: int = self.sp_model.bos_id()
+ self.eos_id: int = self.sp_model.eos_id()
+ self.pad_id: int = self.sp_model.pad_id()
+ self.add_bos = add_bos
+ self.add_eos = add_eos
+ logger.info(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")
+ assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+ def get_vocab_size(self) -> int:
+ return self.n_words
+
+ def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None):
+ if add_bos is None:
+ add_bos = self.add_bos
+
+ if add_eos is None:
+ add_eos = self.add_eos
+ assert type(s) is str
+ tokens = [self.bos_id] * add_bos + self.sp_model.encode(s) + [self.eos_id] * add_eos
+ return tokens
+
+ def decode(self, tokens: list[int]):
+ return self.sp_model.decode(tokens)
+
+ def get_token_offsets(self, text: str, tokens: list[int] | None = None) -> tuple[list[str], list[int]]:
+ pieces = self.sp_model.encode_as_immutable_proto(text).pieces
+ substrs = [p.surface for p in pieces]
+ offsets = [p.begin for p in pieces]
+ return substrs, offsets
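+
+
+# Usage sketch (illustrative only; the model path is a placeholder):
+# sp = SentencePieceTokenizer(model_path="llama_v2.tokenizer.model")
+# ids = sp.encode("Hello world") # [bos_id, *piece_ids, eos_id] by default
+# substrs, offsets = sp.get_token_offsets("Hello world")
+# # substrs are the piece surfaces, offsets their start positions in the text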
diff --git a/backup_blt_wip_backup/unified_blt_debug/config.json b/backup_blt_wip_backup/unified_blt_debug/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..67ab54a27541379cc9caae6512155029d2adf52c
--- /dev/null
+++ b/backup_blt_wip_backup/unified_blt_debug/config.json
@@ -0,0 +1,144 @@
+{
+ "args": {
+ "alpha_depth": "disabled",
+ "architecture": "vanilla",
+ "attn_bias_type": "block_causal",
+ "attn_impl": "xformers",
+ "attn_to_keep": "all",
+ "conv_kernel_size": null,
+ "cross_attn_all_layers_decoder": true,
+ "cross_attn_all_layers_encoder": false,
+ "cross_attn_decoder": true,
+ "cross_attn_encoder": true,
+ "cross_attn_init_by_pooling": true,
+ "cross_attn_k": 2,
+ "cross_attn_nheads": 16,
+ "cross_attn_use_flex_attention": true,
+ "cross_attn_window_decoder": null,
+ "cross_attn_window_encoder": null,
+ "custom_bwd": false,
+ "dim": 512,
+ "dim_global": 2048,
+ "dim_local_decoder": 1024,
+ "dim_local_encoder": 1024,
+ "dim_patch_emb": null,
+ "dim_token": null,
+ "dim_token_emb": null,
+ "downsampling_by_pooling": "max",
+ "dropout": 0.0,
+ "encoder_enable_byte_group_hash": false,
+ "encoder_enable_byte_ngrams": false,
+ "encoder_hash_byte_group_nb_functions": 1,
+ "encoder_hash_byte_group_size": [
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8
+ ],
+ "encoder_hash_byte_group_vocab": 500002,
+ "encoder_lm_loss": false,
+ "encoder_ngram_table_dir": null,
+ "encoder_ngram_to_size_str": null,
+ "encoder_preds_low_entropy_toks": null,
+ "encoder_preds_random_toks": null,
+ "entropy_model_checkpoint_dir": null,
+ "entropy_model_is_ngram_model": false,
+ "eos_id": 2,
+ "ffn_dim_multiplier": 1.0,
+ "full_logging_n_layers": 4,
+ "fuse_sequence_parallel": false,
+ "global_local_decoder_residual_layer": null,
+ "head_dim": null,
+ "init_base_std": null,
+ "init_std_factor": "current_depth",
+ "init_use_depth": "current",
+ "init_use_gaussian": true,
+ "layer_ckpt": "none",
+ "local_attention_window_len": 512,
+ "log_patch_lengths": false,
+ "loss_parallel": false,
+ "max_encoder_seq_length": 24576,
+ "max_length": 256,
+ "max_patch_length": null,
+ "max_seqlen": 4096,
+ "monotonicity": false,
+ "multiple_of": 256,
+ "n_heads": 8,
+ "n_heads_global": 16,
+ "n_heads_local_decoder": 16,
+ "n_heads_local_encoder": 16,
+ "n_kv_heads": null,
+ "n_kv_heads_global": null,
+ "n_layers": 8,
+ "n_layers_global": 25,
+ "n_layers_local_decoder": 9,
+ "n_layers_local_encoder": 1,
+ "ngram_vocab_sizes": null,
+ "non_linearity": "swiglu",
+ "norm_affine": true,
+ "norm_eps": 1e-05,
+ "norm_type": "rmsnorm",
+ "output_size": -1,
+ "pad_to_max_length": true,
+ "patch_in_forward": true,
+ "patch_size": 4.5,
+ "patching_batch_size": 1,
+ "patching_device": "cuda",
+ "patching_mode": "entropy",
+ "patching_threshold": 1.335442066192627,
+ "patching_threshold_add": null,
+ "patching_thresholds_str": null,
+ "pm_size": 0,
+ "pre_norm": true,
+ "recompute_attn": false,
+ "recompute_fc1_out": false,
+ "recompute_fc3_out": false,
+ "rope_theta": 500000.0,
+ "rope_use_fp32_in_outer_product": true,
+ "seed": 42,
+ "sequence_parallel": false,
+ "share_encoder_decoder_emb": true,
+ "tie_local_encoder_decoder": false,
+ "tie_local_encoder_decoder_logits": false,
+ "tokenize_with_bpe_delimiter": false,
+ "use_fsdp": true,
+ "use_local_encoder_transformer": true,
+ "use_rope": true,
+ "vocab_size": 260,
+ "weight_tying": false
+ },
+ "patch_in_forward": true,
+ "realtime_patching": true,
+ "patching_mode": "entropy",
+ "patch_size": 4.5,
+ "patching_threshold": 1.335442066192627,
+ "patching_threshold_add": null,
+ "max_patch_length": null,
+ "patching_batch_size": 1,
+ "patching_device": "cuda",
+ "monotonicity": false,
+ "patcher_vocab_size": 260,
+ "patcher_dim": 768,
+ "patcher_n_layers": 14,
+ "patcher_n_heads": 12,
+ "patcher_head_dim": null,
+ "patcher_n_kv_heads": null,
+ "patcher_max_seqlen": 8192,
+ "patcher_norm_eps": 1e-05,
+ "patcher_dropout": 0.0,
+ "patcher_sliding_window": 512,
+ "patcher_ffn_dim_multiplier": 1.0,
+ "patcher_multiple_of": 256,
+ "patcher_rope_theta": 10000.0,
+ "patcher_rope_use_fp32_in_outer_product": false,
+ "patcher_attn_impl": "xformers",
+ "patcher_attn_bias_type": "local_block_causal",
+ "patcher_init_base_std": null,
+ "patcher_init_std_factor": "current_depth",
+ "patcher_dim_token_emb": null,
+ "patcher_weight_tying": false,
+ "patcher_bos_token_id": 1,
+ "patcher_eos_token_id": 2
+}
\ No newline at end of file
diff --git a/config.json b/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..60d10db33244fffe880b0acb16c5b5d7938002c0
--- /dev/null
+++ b/config.json
@@ -0,0 +1,100 @@
+{
+ "model_type": "blt",
+ "vocab_size": 260,
+ "max_position_embeddings": 4096,
+ "patch_in_forward": true,
+ "realtime_patching": true,
+ "patching_mode": "entropy",
+ "patch_size": 4,
+ "patching_threshold": 1.335442066192627,
+ "patching_threshold_add": null,
+ "max_patch_length": null,
+ "patching_batch_size": 1,
+ "patching_device": "cuda",
+ "monotonicity": false,
+ "cross_attn_k": 2,
+ "encoder_hash_byte_group_size": [
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8
+ ],
+ "encoder_hash_byte_group_vocab": 500002,
+ "encoder_hash_byte_group_nb_functions": 1,
+ "pm_size": 0,
+ "tie_word_embeddings": false,
+ "initializer_range": 0.02,
+ "rope_theta": 500000.0,
+ "rope_scaling": {
+ "type": "default"
+ },
+ "patcher_config": {
+ "vocab_size": 260,
+ "hidden_size": 768,
+ "num_hidden_layers": 14,
+ "num_attention_heads": 12,
+ "num_key_value_heads": null,
+ "max_position_embeddings": 8192,
+ "norm_eps": 1e-05,
+ "dropout": 0.0,
+ "rope_theta": 10000.0,
+ "attn_bias_type": "local_block_causal",
+ "intermediate_size": 2048
+ },
+ "encoder_config": {
+ "vocab_size": 260,
+ "cross_attn_all_layers": false,
+ "cross_attn_k": 2,
+ "hidden_size_global": 2048,
+ "pm_size": 0,
+ "hidden_size": 1024,
+ "num_attention_heads": 16,
+ "num_key_value_heads": null,
+ "num_hidden_layers": 1,
+ "norm_eps": 1e-05,
+ "dropout": 0.0,
+ "max_position_embeddings": 24576,
+ "rope_theta": 500000.0,
+ "rope_scaling": {
+ "type": "default"
+ },
+ "hidden_act": "silu",
+ "intermediate_size": 2816
+ },
+ "decoder_config": {
+ "vocab_size": 260,
+ "cross_attn_all_layers": true,
+ "cross_attn_k": 2,
+ "hidden_size_global": 2048,
+ "hidden_size": 1024,
+ "num_attention_heads": 16,
+ "num_key_value_heads": null,
+ "num_hidden_layers": 9,
+ "norm_eps": 1e-05,
+ "dropout": 0.0,
+ "max_position_embeddings": 24576,
+ "rope_theta": 500000.0,
+ "rope_scaling": {
+ "type": "default"
+ },
+ "hidden_act": "silu",
+ "intermediate_size": 2816
+ },
+ "global_config": {
+ "hidden_size": 2048,
+ "num_attention_heads": 16,
+ "num_key_value_heads": null,
+ "num_hidden_layers": 25,
+ "norm_eps": 1e-05,
+ "dropout": 0.0,
+ "max_position_embeddings": 4096,
+ "rope_theta": 500000.0,
+ "rope_scaling": {
+ "type": "default"
+ },
+ "hidden_act": "silu",
+ "intermediate_size": 5632
+ }
+}
\ No newline at end of file
diff --git a/model.safetensors b/model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..5e4e87591312fda1029d125dfeb22f37ea1e0935
--- /dev/null
+++ b/model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b42ece52607eacbb4e538c695b137a53d38ea68dcc4a03dd825a9656f476162d
+size 9266850624
diff --git a/tokenizer_config.json b/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..3b3a3cd875566cd2014531f023e025f7a69e7539
--- /dev/null
+++ b/tokenizer_config.json
@@ -0,0 +1,11 @@
+{
+ "tokenizer_class": "BLTTokenizer",
+ "vocab_size": 260,
+ "model_max_length": 4096,
+ "add_bos_token": true,
+ "add_eos_token": true,
+ "bos_token": "",
+ "eos_token": "",
+ "pad_token": "",
+ "unk_token": ""
+}
\ No newline at end of file