diff --git a/backup_blt_modellike/__init__.py b/backup_blt_modellike/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..703b81ecdd09dda47a97c641f7e440bcb5e81119 --- /dev/null +++ b/backup_blt_modellike/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_blt import * + from .modeling_blt import * + from .tokenization_blt import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/backup_blt_modellike/__pycache__/__init__.cpython-312.pyc b/backup_blt_modellike/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1eead81c4f00feeba4ec4166a9cd4cad1d228c8 Binary files /dev/null and b/backup_blt_modellike/__pycache__/__init__.cpython-312.pyc differ diff --git a/backup_blt_modellike/__pycache__/tokenization_blt.cpython-312.pyc b/backup_blt_modellike/__pycache__/tokenization_blt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0d2019b6d449a711c0fc06bdf6d36fc4f9caa47 Binary files /dev/null and b/backup_blt_modellike/__pycache__/tokenization_blt.cpython-312.pyc differ diff --git a/backup_blt_modellike/configuration_blt.py b/backup_blt_modellike/configuration_blt.py new file mode 100644 index 0000000000000000000000000000000000000000..9ee59362d955060fbe5bff81b635da55dc012981 --- /dev/null +++ b/backup_blt_modellike/configuration_blt.py @@ -0,0 +1,225 @@ +# coding=utf-8 +# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BLT model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class BLTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BLTModel`]. 
It is used to instantiate an BLT + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the BLT-7B. + e.g. [meta-blt/BLT-2-7b-hf](https://huggingface.co/meta-blt/BLT-2-7b-hf) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the BLT model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BLTModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. BLT 1 supports up to 2048 tokens, + BLT 2 up to 4096, CodeBLT up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. 
+ rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'blt3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'blt3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'blt3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'blt3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. 
If None, it will default to hidden_size // num_attention_heads + + ```python + >>> from transformers import BLTModel, BLTConfig + + >>> # Initializing a BLT blt-7b style configuration + >>> configuration = BLTConfig() + + >>> # Initializing a model from the blt-7b style configuration + >>> model = BLTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "blt" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `BLTModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + head_dim=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. 
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["BLTConfig"] diff --git a/backup_blt_modellike/convert_blt_weights_to_hf.py b/backup_blt_modellike/convert_blt_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..26c05477a169d41213133c3dffcc00ac32cd2401 --- /dev/null +++ b/backup_blt_modellike/convert_blt_weights_to_hf.py @@ -0,0 +1,397 @@ +import argparse +import json +import logging +import os +from typing import Any, Dict, Optional + +import torch +from huggingface_hub import hf_hub_download, upload_folder +from safetensors.torch import load_file, save_file + +from transformers.models.blt_wip.configuration_blt import BLTConfig +from transformers.models.blt_wip.modeling_blt import BLTModel +from transformers.models.blt_wip.modeling_blt_dev import BLTForCausalLM +from transformers.utils import logging as transformers_logging + + +logger = transformers_logging.get_logger(__name__) +transformers_logging.set_verbosity_info() + + +def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]: + logger.info("Merging configurations") + + with open(config_path, "r") as f: + main_config = json.load(f) + + with open(entropy_params_path, "r") as f: + entropy_data = json.load(f) + + entropy_model_params = entropy_data.get("entropy_model", {}) + patcher_args = entropy_data.get("data", {}).get("patcher_args", {}) + + unified_config = main_config.copy()["args"] + + for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]: + if key in unified_config and not isinstance(unified_config[key], int): + unified_config[key] = int(unified_config[key]) + + patch_size = patcher_args.get("patch_size", 8) + if isinstance(patch_size, float): + patch_size = int(patch_size) + + # Create patcher config + patcher_hidden_size = int(entropy_model_params.get("dim", 512)) + patcher_multiple_of = int(entropy_model_params.get("multiple_of", 256)) + patcher_intermediate_size = patcher_multiple_of * ((int(8 * patcher_hidden_size / 3) + patcher_multiple_of - 1) // patcher_multiple_of) + + patcher_config = { + "vocab_size": int(entropy_model_params.get("vocab_size", 256)), + "hidden_size": patcher_hidden_size, + "num_hidden_layers": int(entropy_model_params.get("n_layers", 8)), + "num_attention_heads": int(entropy_model_params.get("n_heads", 8)), + "num_key_value_heads": int(entropy_model_params.get("n_kv_heads")) + if entropy_model_params.get("n_kv_heads") is not None + else None, + "max_position_embeddings": int(entropy_model_params.get("max_seqlen", 1024)), + "norm_eps": entropy_model_params.get("norm_eps", 1e-5), + "dropout": entropy_model_params.get("dropout", 0.0), + "rope_theta": entropy_model_params.get("rope_theta", 10000.0), + "attn_impl": entropy_model_params.get("attn_impl", "sdpa"), + "attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"), + "intermediate_size": patcher_intermediate_size, + } + + # Create encoder config + encoder_hidden_size = unified_config.get("dim_local_encoder", 1024) + encoder_multiple_of = unified_config.get("multiple_of", 256) + encoder_intermediate_size = encoder_multiple_of * ((int(8 * encoder_hidden_size / 3) + encoder_multiple_of - 1) // encoder_multiple_of) + + encoder_config = { + "vocab_size": 
unified_config.get("vocab_size", 256), + "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_encoder", False), + "cross_attn_k": unified_config.get("cross_attn_k", 2), + "hidden_size_global": unified_config.get("hidden_size_global", 2048), + "pm_size": unified_config.get("pm_size", 0), + "hidden_size": encoder_hidden_size, + "num_attention_heads": unified_config.get("n_heads_local_encoder", 16), + "num_key_value_heads": unified_config.get("n_kv_heads"), + "num_hidden_layers": unified_config.get("n_layers_local_encoder", 1), + "norm_eps": unified_config.get("norm_eps", 1e-5), + "dropout": unified_config.get("dropout", 0.0), + "max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024), + "rope_theta": unified_config.get("rope_theta", 10000.0), + "rope_scaling": {"rope_type": "default"}, + "hidden_act": unified_config.get("hidden_act", "silu"), + "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"), + "intermediate_size": encoder_intermediate_size, + } + + # Create decoder config + decoder_hidden_size = unified_config.get("dim_local_decoder", 1024) + decoder_multiple_of = unified_config.get("multiple_of", 256) + decoder_intermediate_size = decoder_multiple_of * ((int(8 * decoder_hidden_size / 3) + decoder_multiple_of - 1) // decoder_multiple_of) + + decoder_config = { + "vocab_size": unified_config.get("vocab_size", 256), + "cross_attn_all_layers": unified_config.get("cross_attn_all_layers_decoder", False), + "cross_attn_k": unified_config.get("cross_attn_k", 2), + "hidden_size_global": unified_config.get("hidden_size_global", 2048), + "hidden_size": decoder_hidden_size, + "num_attention_heads": unified_config.get("n_heads_local_decoder", 16), + "num_key_value_heads": unified_config.get("n_kv_heads"), + "num_hidden_layers": unified_config.get("n_layers_local_decoder", 9), + "norm_eps": unified_config.get("norm_eps", 1e-5), + "dropout": unified_config.get("dropout", 0.0), + "max_position_embeddings": unified_config.get("max_encoder_seq_length") or unified_config.get("max_seqlen", 1024), + "rope_theta": unified_config.get("rope_theta", 10000.0), + "rope_scaling": {"rope_type": "default"}, + "hidden_act": unified_config.get("hidden_act", "silu"), + "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"), + "intermediate_size": decoder_intermediate_size, + } + + # Create global transformer config + global_hidden_size = unified_config.get("dim_global", 2048) + global_multiple_of = unified_config.get("multiple_of", 256) + global_intermediate_size = global_multiple_of * ((int(8 * global_hidden_size / 3) + global_multiple_of - 1) // global_multiple_of) + + global_config = { + "hidden_size": global_hidden_size, + "num_attention_heads": unified_config.get("n_heads_global", 16), + "num_key_value_heads": unified_config.get("n_kv_heads_global"), + "num_hidden_layers": unified_config.get("n_layers_global", 25), + "norm_eps": unified_config.get("norm_eps", 1e-5), + "dropout": unified_config.get("dropout", 0.0), + "max_position_embeddings": unified_config.get("max_seqlen", 1024), + "rope_theta": unified_config.get("rope_theta", 10000.0), + "rope_scaling": {"rope_type": "default"}, + "hidden_act": unified_config.get("hidden_act", "silu"), + "_attn_implementation": unified_config.get("_attn_implementation", "sdpa"), + "intermediate_size": global_intermediate_size, + } + + # Create main config with sub-configs + main_config_dict = { + "model_type": "blt", + "vocab_size": 
unified_config.get("vocab_size", 256), + "max_position_embeddings": unified_config.get("max_seqlen", 1024), + "patch_in_forward": True, + "realtime_patching": True, + "patching_mode": "entropy", + "patch_size": patch_size, + "patching_threshold": patcher_args.get("threshold", 0.5), + "patching_threshold_add": patcher_args.get("threshold_add", 0.0), + "max_patch_length": patcher_args.get("max_patch_length"), + "patching_batch_size": patcher_args.get("patching_batch_size", 1), + "patching_device": patcher_args.get("patching_device", "cuda"), + "monotonicity": patcher_args.get("monotonicity", False), + "cross_attn_k": unified_config.get("cross_attn_k", 2), + "encoder_hash_byte_group_size": unified_config.get("encoder_hash_byte_group_size"), + "encoder_hash_byte_group_vocab": unified_config.get("encoder_hash_byte_group_vocab", 30000), + "encoder_hash_byte_group_nb_functions": unified_config.get("encoder_hash_byte_group_nb_functions", 3), + "pm_size": unified_config.get("pm_size", 0), + "patcher_config": patcher_config, + "encoder_config": encoder_config, + "decoder_config": decoder_config, + "global_config": global_config, + } + + main_config_dict["tie_word_embeddings"] = False + + logger.info(f"Merged configuration with {len(main_config_dict)} parameters") + return main_config_dict + + +def apply_weight_mapping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + component_mappings = { + ".attention.": ".self_attn.", + ".feed_forward.": ".mlp.", + ".attention_norm.": ".input_layernorm.", + ".ffn_norm.": ".post_attention_layernorm.", + ".tok_embeddings.": ".embed_tokens.", + ".cross_attn_norm_q.": ".q_norm.", + ".cross_attn_norm_kv.": ".k_norm.", + ".w1.": ".gate_proj.", + ".w2.": ".down_proj.", + ".w3.": ".up_proj.", + ".wq.": ".q_proj.", + ".wk.": ".k_proj.", + ".wv.": ".v_proj.", + ".wo.": ".o_proj.", + ".output.": ".lm_head.", + } + + new_state_dict = {} + + for old_key, tensor in state_dict.items(): + new_key = old_key + + for old_pattern, new_pattern in component_mappings.items(): + if old_pattern in new_key: + new_key = new_key.replace(old_pattern, new_pattern) + + new_state_dict[new_key] = tensor + + return new_state_dict + + +def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]: + main_weights = load_file(weights_path) + + entropy_weights = torch.load(entropy_weights_path, map_location="cpu", weights_only=True) + + if "model" in entropy_weights: + entropy_weights = entropy_weights["model"] + elif "state_dict" in entropy_weights: + entropy_weights = entropy_weights["state_dict"] + + unified_weights = main_weights.copy() + + for key, tensor in entropy_weights.items(): + patcher_key = f"patcher.{key}" + unified_weights[patcher_key] = tensor + + unified_weights = apply_weight_mapping(unified_weights) + + decoder_lm_head_key = "local_decoder.lm_head.weight" + top_lm_head_key = "lm_head.weight" + unified_weights[top_lm_head_key] = unified_weights[decoder_lm_head_key] + del unified_weights[decoder_lm_head_key] + + prefixed_weights = {} + for key, tensor in unified_weights.items(): + if key == top_lm_head_key: + prefixed_weights[key] = tensor + elif not key.startswith("model."): + prefixed_weights[f"model.{key}"] = tensor + else: + prefixed_weights[key] = tensor + + unified_weights = prefixed_weights + + return unified_weights + + +def create_tokenizer_config(output_dir: str, config: Dict[str, Any]): + tokenizer_config = { + "tokenizer_class": "BltTokenizer", + "vocab_size": config.get("vocab_size", 256), + "model_max_length": 
config.get("max_seqlen", 1024), + "add_bos_token": True, + "add_eos_token": True, + "bos_token": "", + "eos_token": "", + "pad_token": "", + "unk_token": "", + } + + tokenizer_path = os.path.join(output_dir, "tokenizer_config.json") + with open(tokenizer_path, "w") as f: + json.dump(tokenizer_config, f, indent=2) + + +def push_to_hub( + local_dir: str, + repo_id: str, + commit_message: str = "Upload converted BLT model", + private: bool = False, + token: Optional[str] = None, +) -> None: + try: + upload_folder( + folder_path=local_dir, + repo_id=repo_id, + commit_message=commit_message, + repo_type="model", + token=token, + ) + logger.info(f"Successfully pushed model to {repo_id}") + + except Exception as e: + logger.error(f"Failed to push model to Hub: {e}") + raise + + +def convert_hf_blt_to_unified( + model_id: str, + output_dir: str, + config_name: str = "config.json", + weights_name: str = "model.bin", + cache_dir: Optional[str] = None, + push_to_hub_repo: Optional[str] = None, + hub_private: bool = False, + hub_token: Optional[str] = None, +) -> None: + # Download model files + config_path = hf_hub_download(repo_id=model_id, filename="config.json", cache_dir=cache_dir) + weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", cache_dir=cache_dir) + entropy_params_path = hf_hub_download(repo_id=model_id, filename="entropy_model/params.json", cache_dir=cache_dir) + entropy_weights_path = hf_hub_download( + repo_id=model_id, filename="entropy_model/consolidated.pth", cache_dir=cache_dir + ) + + unified_config = merge_configurations(config_path, entropy_params_path) + unified_weights = merge_weights(weights_path, entropy_weights_path) + + os.makedirs(output_dir, exist_ok=True) + + config_path = os.path.join(output_dir, config_name) + with open(config_path, "w") as f: + json.dump(unified_config, f, indent=2) + + if weights_name.endswith(".bin"): + weights_name = weights_name.replace(".bin", ".safetensors") + + weights_path = os.path.join(output_dir, weights_name) + save_file(unified_weights, weights_path) + + create_tokenizer_config(output_dir, unified_config) + + logger.info(f"Conversion completed, model saved to: {output_dir}") + + if push_to_hub_repo: + push_to_hub( + local_dir=output_dir, + repo_id=push_to_hub_repo, + commit_message="Upload BLT model converted", + private=hub_private, + token=hub_token, + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Convert BLT models from HuggingFace Hub format to unified format", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--model_id", + type=str, + default="facebook/blt-1b", + ) + parser.add_argument( + "--output_dir", + type=str, + default="./blt_converted", + ) + parser.add_argument( + "--config_name", + type=str, + default="config.json", + ) + parser.add_argument( + "--weights_name", + type=str, + default="model.bin", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + ) + parser.add_argument( + "--debug", + action="store_true", + default=True, + ) + parser.add_argument( + "--push_to_hub", + type=str, + default=None, + ) + parser.add_argument( + "--hub_private", + action="store_true", + default=False, + ) + parser.add_argument( + "--hub_token", + type=str, + default="hf_token", + ) + + args = parser.parse_args() + + transformers_logging.set_verbosity_debug() + logging.basicConfig(level=logging.DEBUG) + + try: + convert_hf_blt_to_unified( + model_id=args.model_id, + output_dir=args.output_dir, + config_name=args.config_name, + 
weights_name=args.weights_name, + cache_dir=args.cache_dir, + push_to_hub_repo=args.push_to_hub, + hub_private=args.hub_private, + hub_token=args.hub_token, + ) + except Exception as e: + logger.error(f"Conversion failed: {e}") + raise + + +if __name__ == "__main__": + main() diff --git a/backup_blt_modellike/modeling_blt.py b/backup_blt_modellike/modeling_blt.py new file mode 100644 index 0000000000000000000000000000000000000000..fdcaa9062d336430668cc040ae4ca080b5c69977 --- /dev/null +++ b/backup_blt_modellike/modeling_blt.py @@ -0,0 +1,971 @@ +# coding=utf-8 +# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from .configuration_blt import BLTConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + +from ...integrations import use_kernel_forward_from_hub + + +logger = logging.get_logger(__name__) + + +@use_kernel_forward_from_hub("RMSNorm") +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->BLT +class BLTRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + BLTRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, 
eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(BLTRMSNorm) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->BLT +class BLTRotaryEmbedding(nn.Module): + def __init__(self, config: BLTConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->BLT +class BLTMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->BLT +class BLTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: BLTConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + 
self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->BLT +class BLTDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: BLTConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = BLTAttention(config=config, layer_idx=layer_idx) + + self.mlp = BLTMLP(config) + self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + 
position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->BLT +class BLTPreTrainedModel(PreTrainedModel): + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, BLTRMSNorm): + module.weight.data.fill_(1.0) + + +@auto_docstring +# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->BLT +class BLTModel(BLTPreTrainedModel): + def __init__(self, config: BLTConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [BLTDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = BLTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = BLTRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +@auto_docstring +# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->BLT,llama->blt +class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = BLTModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, BLTForCausalLM + + >>> model = BLTForCausalLM.from_pretrained("meta-blt/BLT-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-blt/BLT-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" 
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`BLTForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """ +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->BLT +class BLTForSequenceClassification(BLTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = BLTModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring +# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->BLT +class BLTForQuestionAnswering(BLTPreTrainedModel): + base_model_prefix = "transformer" + + def __init__(self, config): + super().__init__(config) + self.transformer = BLTModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> QuestionAnsweringModelOutput: + outputs: BaseModelOutputWithPast = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs.last_hidden_state + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->BLT +class BLTForTokenClassification(BLTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = BLTModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + 
self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> TokenClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "BLTForCausalLM", + "BLTModel", + "BLTPreTrainedModel", + "BLTForSequenceClassification", + "BLTForQuestionAnswering", + "BLTForTokenClassification", +] diff --git a/backup_blt_modellike/tokenization_blt.py b/backup_blt_modellike/tokenization_blt.py new file mode 100644 index 0000000000000000000000000000000000000000..145f257f0eeca91ca4ddc89470300d42f3372025 --- /dev/null +++ b/backup_blt_modellike/tokenization_blt.py @@ -0,0 +1,412 @@ +# coding=utf-8 +# Copyright 2025 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Tokenization classes for BLT."""
+
+import os
+from shutil import copyfile
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+
+from ...convert_slow_tokenizer import import_protobuf
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+from ...utils.import_utils import requires
+
+
+if TYPE_CHECKING:
+    from ...tokenization_utils_base import TextInput
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+SPIECE_UNDERLINE = "▁"
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
+ that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
+correct. If you don't know the answer to a question, please don't share false information."""  # fmt: skip
+
+
+@requires(backends=("sentencepiece",))
+class BLTTokenizer(PreTrainedTokenizer):
+    """
+    Construct a BLT tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
+    no padding token in the original model.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+        pad_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token used to make arrays of tokens the same size for batching purposes. Will then be ignored by
+            attention mechanisms or loss computation.
+        sp_model_kwargs (`Dict[str, Any]`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
+                using the forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+
+        add_bos_token (`bool`, *optional*, defaults to `True`):
+            Whether or not to add a `bos_token` at the start of sequences.
+        add_eos_token (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an `eos_token` at the end of sequences.
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+            Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
+            extra spaces.
+        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+            Whether or not the default system prompt for BLT should be used.
+        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
+            Whether or not to add spaces between special tokens.
+        legacy (`bool`, *optional*):
+            Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
+            and #25224, which include fixes to properly handle tokens that appear after special tokens.
+            Make sure to also set `from_slow` to `True`.
+            A simple example:
+
+            - `legacy=True`:
+            ```python
+            >>> from transformers import BLTTokenizerFast
+
+            >>> tokenizer = BLTTokenizerFast.from_pretrained("huggyblt/blt-7b", legacy=True, from_slow=True)
+            >>> tokenizer.encode("Hello <s>.")  # 869 is '▁.'
+            [1, 15043, 29871, 1, 869]
+            ```
+            - `legacy=False`:
+            ```python
+            >>> from transformers import BLTTokenizerFast
+
+            >>> tokenizer = BLTTokenizerFast.from_pretrained("huggyblt/blt-7b", legacy=False, from_slow=True)
+            >>> tokenizer.encode("Hello <s>.")  # 29889 is '.'
+            [1, 15043, 29871, 1, 29889]
+            ```
+            Check out the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
+        add_prefix_space (`bool`, *optional*, defaults to `True`):
+            Whether or not to add an initial space to the input. This allows the leading word to be treated just like
+            any other word. Again, this should be set with `from_slow=True` to make sure it's taken into account.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="<unk>",
+        bos_token="<s>",
+        eos_token="</s>",
+        pad_token=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        add_bos_token=True,
+        add_eos_token=False,
+        clean_up_tokenization_spaces=False,
+        use_default_system_prompt=False,
+        spaces_between_special_tokens=False,
+        legacy=None,
+        add_prefix_space=True,
+        **kwargs,
+    ):
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
+
+        if legacy is None:
+            logger.warning_once(
+                f"You are using the default legacy behaviour of the {self.__class__}. This is"
+                " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
+                " If you want to use the new behaviour, set `legacy=False`.
This should only be set if you understand what it" + " means, and thoroughly read the reason why this was added as explained in" + " https://github.com/huggingface/transformers/pull/24565 - if you loaded a blt tokenizer from a GGUF file" + " you can ignore this message" + ) + legacy = True + + self.legacy = legacy + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.use_default_system_prompt = use_default_system_prompt + self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False)) + self.add_prefix_space = add_prefix_space + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + sp_model_kwargs=self.sp_model_kwargs, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + use_default_system_prompt=use_default_system_prompt, + spaces_between_special_tokens=spaces_between_special_tokens, + legacy=legacy, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + def get_spm_processor(self, from_slow=False): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + if self.legacy or from_slow: # no dependency on protobuf + tokenizer.Load(self.vocab_file) + return tokenizer + + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + def __setstate__(self, d): + self.__dict__.update(d) + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def tokenize(self, text: "TextInput", **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the + first token is special. + """ + if self.legacy or len(text) == 0: + return super().tokenize(text, **kwargs) + + text = text.replace(SPIECE_UNDERLINE, " ") + if self.add_prefix_space: + text = SPIECE_UNDERLINE + text + + tokens = super().tokenize(text, **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. 
Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return self.sp_model.encode(text, out_type=str) + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + # since we manually add the prefix space, we have to remove it when decoding + if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space: + tokens[0] = tokens[0][1:] + + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for i, token in enumerate(tokens): + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special and i != 0 and self.legacy: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + if prev_is_special and i == 1 and self.add_prefix_space and not token.startswith(SPIECE_UNDERLINE): + out_string += " " + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. 
+ token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return ( + bos_token_id + + ([0] * len(token_ids_0)) + + eos_token_id + + bos_token_id + + ([0] * len(token_ids_1)) + + eos_token_id + ) + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + if token_ids_1 is None, only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) + + if token_ids_1 is not None: + output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) + + return output + + +#__all__ = ["BLTTokenizer"] diff --git a/backup_blt_wip copy/__init__.py b/backup_blt_wip copy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backup_blt_wip copy/__pycache__/__init__.cpython-312.pyc b/backup_blt_wip copy/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aec13a4ab4fd4175a82f4cc9acbcae04b89ed6fd Binary files /dev/null and b/backup_blt_wip copy/__pycache__/__init__.cpython-312.pyc differ diff --git a/backup_blt_wip copy/__pycache__/blt_args.cpython-312.pyc b/backup_blt_wip copy/__pycache__/blt_args.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe1c3fac23fde31ac2fe95aa5c15d35368e766fc Binary files /dev/null and b/backup_blt_wip copy/__pycache__/blt_args.cpython-312.pyc differ diff --git a/backup_blt_wip copy/__pycache__/blt_one_file.cpython-312.pyc b/backup_blt_wip copy/__pycache__/blt_one_file.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad67ee5b9c82f98e084ea14790b83f2fbbca1539 Binary files /dev/null and b/backup_blt_wip copy/__pycache__/blt_one_file.cpython-312.pyc differ diff --git a/backup_blt_wip copy/__pycache__/configuration_blt.cpython-312.pyc b/backup_blt_wip copy/__pycache__/configuration_blt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6619d112a99fe0a3db54bb3649041aa9d3254a71 Binary files /dev/null and b/backup_blt_wip copy/__pycache__/configuration_blt.cpython-312.pyc differ diff --git a/backup_blt_wip 
copy/__pycache__/configuration_blt_og.cpython-312.pyc b/backup_blt_wip copy/__pycache__/configuration_blt_og.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65e042fc413fbde011b9ac7c4982694313c9d1e9 Binary files /dev/null and b/backup_blt_wip copy/__pycache__/configuration_blt_og.cpython-312.pyc differ diff --git a/backup_blt_wip copy/__pycache__/modeling_blt.cpython-312.pyc b/backup_blt_wip copy/__pycache__/modeling_blt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d11c9fbc187865ec15972562ad1fc46629f6a986 Binary files /dev/null and b/backup_blt_wip copy/__pycache__/modeling_blt.cpython-312.pyc differ diff --git a/backup_blt_wip copy/__pycache__/modeling_blt_dev.cpython-312.pyc b/backup_blt_wip copy/__pycache__/modeling_blt_dev.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df77d3f4ae267eada2d86e052112370cda866d91 Binary files /dev/null and b/backup_blt_wip copy/__pycache__/modeling_blt_dev.cpython-312.pyc differ diff --git a/backup_blt_wip copy/__pycache__/modeling_blt_modellike.cpython-312.pyc b/backup_blt_wip copy/__pycache__/modeling_blt_modellike.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dcf920f19d283f8290013fc57df016f94f08853 Binary files /dev/null and b/backup_blt_wip copy/__pycache__/modeling_blt_modellike.cpython-312.pyc differ diff --git a/backup_blt_wip copy/__pycache__/modeling_blt_old.cpython-312.pyc b/backup_blt_wip copy/__pycache__/modeling_blt_old.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0066695b7e54ec3a4c8d924f4fb37dfa0c9f5ec Binary files /dev/null and b/backup_blt_wip copy/__pycache__/modeling_blt_old.cpython-312.pyc differ diff --git a/backup_blt_wip copy/__pycache__/modeling_blt_wip.cpython-312.pyc b/backup_blt_wip copy/__pycache__/modeling_blt_wip.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14471f3e8bde08cd94f2ba3b8ea1dc01597af449 Binary files /dev/null and b/backup_blt_wip copy/__pycache__/modeling_blt_wip.cpython-312.pyc differ diff --git a/backup_blt_wip copy/__pycache__/modeling_blt_wip_backup.cpython-312.pyc b/backup_blt_wip copy/__pycache__/modeling_blt_wip_backup.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01bb41a2f95cf657f1cb253e41a8eadbe9a4c5b7 Binary files /dev/null and b/backup_blt_wip copy/__pycache__/modeling_blt_wip_backup.cpython-312.pyc differ diff --git a/backup_blt_wip copy/__pycache__/tokenization_blt.cpython-312.pyc b/backup_blt_wip copy/__pycache__/tokenization_blt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..871072a6257040ba60bbb0473fb9be6b73117ccb Binary files /dev/null and b/backup_blt_wip copy/__pycache__/tokenization_blt.cpython-312.pyc differ diff --git a/backup_blt_wip copy/configuration_blt.py b/backup_blt_wip copy/configuration_blt.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a245afebc3be28dc5bc10f7ff9ef6097ea8247 --- /dev/null +++ b/backup_blt_wip copy/configuration_blt.py @@ -0,0 +1,390 @@ +# coding=utf-8 +# Copyright 2024 Facebook Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BLT model configuration""" + +from enum import Enum +from typing import Union + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +class BLTLocalEncoderConfig(PretrainedConfig): + """ + Configuration class for the BLT Local Encoder component. + """ + + model_type = "blt_local_encoder" + + def __init__( + self, + vocab_size=256, + cross_attn_all_layers=True, + cross_attn_k=2, + hidden_size_global=2048, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=None, + num_hidden_layers=8, + norm_eps=1e-5, + dropout=0.0, + max_position_embeddings=1024, + rope_theta=10000.0, + rope_scaling=None, + hidden_act="silu", + intermediate_size=None, + _attn_implementation="sdpa", + **kwargs, + ): + self.vocab_size = vocab_size + self.cross_attn_all_layers = cross_attn_all_layers + self.cross_attn_k = cross_attn_k + self.hidden_size_global = hidden_size_global + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads or num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.intermediate_size = intermediate_size or int(8 * hidden_size / 3) + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.dropout = dropout + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling or {"rope_type": "default"} + self.hidden_act = hidden_act + self._attn_implementation = _attn_implementation + + super().__init__(**kwargs) + +class BLTLocalDecoderConfig(PretrainedConfig): + """ + Configuration class for the BLT Local Decoder component. + """ + + model_type = "blt_local_decoder" + + def __init__( + self, + vocab_size=256, + cross_attn_all_layers=True, + cross_attn_k=2, + hidden_size_global=2048, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=None, + num_hidden_layers=8, + norm_eps=1e-5, + dropout=0.0, + max_position_embeddings=1024, + rope_theta=10000.0, + rope_scaling=None, + hidden_act="silu", + intermediate_size=None, + _attn_implementation="sdpa", + **kwargs, + ): + self.vocab_size = vocab_size + self.cross_attn_all_layers = cross_attn_all_layers + self.cross_attn_k = cross_attn_k + self.hidden_size_global = hidden_size_global + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads or num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.intermediate_size = intermediate_size or int(8 * hidden_size / 3) + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.dropout = dropout + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling or {"rope_type": "default"} + self.hidden_act = hidden_act + self._attn_implementation = _attn_implementation + + super().__init__(**kwargs) + + +class BLTGlobalTransformerConfig(PretrainedConfig): + """ + Configuration class for the BLT Global Transformer component. 
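+
+    A minimal instantiation sketch (the import assumes the class is exported via this module's `__all__`; the values
+    are illustrative):
+
+    ```python
+    >>> from transformers import BLTGlobalTransformerConfig
+
+    >>> # head_dim is derived as hidden_size // num_attention_heads
+    >>> global_config = BLTGlobalTransformerConfig(hidden_size=768, num_attention_heads=12, num_hidden_layers=12)
+    >>> global_config.head_dim
+    64
+    ```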
+ """ + + model_type = "blt_global_transformer" + + def __init__( + self, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=None, + num_hidden_layers=8, + norm_eps=1e-5, + dropout=0.0, + max_position_embeddings=1024, + rope_theta=10000.0, + rope_scaling=None, + hidden_act="silu", + intermediate_size=None, + _attn_implementation="sdpa", + **kwargs, + ): + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads or num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.intermediate_size = intermediate_size or int(8 * hidden_size / 3) + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.dropout = dropout + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling or {"rope_type": "default"} + self.hidden_act = hidden_act + self._attn_implementation = _attn_implementation + + super().__init__(**kwargs) + + +class BLTPatcherConfig(PretrainedConfig): + r""" + Configuration class for the BLT Patcher/Entropy model component. + + Args: + vocab_size (`int`, *optional*, defaults to 256): + Vocabulary size for the entropy model used in patching. + hidden_size (`int`, *optional*, defaults to 512): + Hidden dimension for the entropy model. + num_hidden_layers (`int`, *optional*, defaults to 8): + Number of layers in the entropy model. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads in the entropy model. + head_dim (`int`, *optional*): + Dimension of each attention head in the entropy model. + num_key_value_heads (`int`, *optional*): + Number of key-value heads in the entropy model. + max_position_embeddings (`int`, *optional*, defaults to 1024): + Maximum sequence length for the entropy model. + norm_eps (`float`, *optional*, defaults to 1e-5): + Layer normalization epsilon for the entropy model. + dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for the entropy model. + ffn_dim_multiplier (`float`, *optional*): + Feedforward dimension multiplier for the entropy model. + multiple_of (`int`, *optional*, defaults to 256): + Make feedforward dimension multiple of this for the entropy model. + rope_theta (`float`, *optional*, defaults to 10000.0): + RoPE theta parameter for the entropy model. + attn_impl (`str`, *optional*, defaults to "sdpa"): + Attention implementation for the entropy model. + attn_bias_type (`str`, *optional*, defaults to "causal"): + Attention bias type for the entropy model. 
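+
+    Example (a minimal sketch; the import assumes the class is exported via this module's `__all__`):
+
+    ```python
+    >>> from transformers import BLTPatcherConfig
+
+    >>> # A smaller entropy model; intermediate_size defaults to int(8 * hidden_size / 3)
+    >>> patcher_config = BLTPatcherConfig(hidden_size=256, num_hidden_layers=4)
+    >>> patcher_config.intermediate_size
+    682
+    ```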
+ """ + + model_type = "blt_patcher" + + def __init__( + self, + vocab_size=256, + hidden_size=512, + num_hidden_layers=8, + num_attention_heads=8, + num_key_value_heads=None, + max_position_embeddings=1024, + norm_eps=1e-5, + dropout=0.0, + rope_theta=10000.0, + attn_impl="sdpa", + attn_bias_type="causal", + intermediate_size=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.norm_eps = norm_eps + self.dropout = dropout + self.rope_theta = rope_theta + self.attn_impl = attn_impl + self.attn_bias_type = attn_bias_type + self.hidden_act = "silu" # BLT uses silu activation + self.intermediate_size = intermediate_size or int(8 * self.hidden_size / 3) + self.rope_scaling = {"rope_type": "default"} + super().__init__(**kwargs) + + +class BLTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BLTModel`]. It is used to instantiate a + BLT model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 256): + Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented. + max_position_embeddings (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model can handle. + + # Patching configuration + patch_in_forward (`bool`, *optional*, defaults to False): + Whether to perform patching during forward pass. + patch_size (`float`, *optional*): + Size of patches for static patching. + patching_mode (`str`, *optional*): + Mode for patching ("entropy", "static", etc.). + patching_threshold (`float`, *optional*): + Threshold for entropy-based patching. + patching_batch_size (`int`, *optional*, defaults to 1): + Batch size for patching operations. + patching_device (`str`, *optional*, defaults to "cuda"): + Device to use for patching operations. + max_patch_length (`int`, *optional*): + Maximum length of patches. + + # Cross attention configurations + cross_attn_k (`int`, *optional*): + Number of cross attention components. + + # Encoder configurations + encoder_hash_byte_group_size (`Any`, *optional*): + Hash byte group size for encoder. + encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 30000): + Vocabulary size for hash byte groups. + encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 3): + Number of hash functions for byte groups. + + # Component configurations + patcher_config (`Union[BLTPatcherConfig, dict]`, *optional*): + Configuration for the BLT patcher/entropy model component. + encoder_config (`Union[BLTLocalEncoderConfig, dict]`, *optional*): + Configuration for the BLT local encoder component. + decoder_config (`Union[BLTLocalDecoderConfig, dict]`, *optional*): + Configuration for the BLT local decoder component. + global_config (`Union[BLTGlobalTransformerConfig, dict]`, *optional*): + Configuration for the BLT global transformer component. 
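+
+    Each component configuration may also be passed as a plain `dict`; `__init__` converts it to the corresponding
+    config class, as in this short sketch:
+
+    ```python
+    >>> from transformers import BLTConfig
+
+    >>> config = BLTConfig(encoder_config={"hidden_size": 256, "num_hidden_layers": 4})
+    >>> config.encoder_config.hidden_size
+    256
+    ```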
+ + ```python + >>> from transformers import BLTModel, BLTConfig + + >>> # Initializing a BLT configuration + >>> configuration = BLTConfig() + + >>> # Initializing a model from the configuration + >>> model = BLTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "blt" + keys_to_ignore_at_inference = ["past_key_values"] + sub_configs = { + "patcher_config": BLTPatcherConfig, + "encoder_config": BLTLocalEncoderConfig, + "decoder_config": BLTLocalDecoderConfig, + "global_config": BLTGlobalTransformerConfig + } + + def __init__( + self, + vocab_size=256, + max_position_embeddings=1024, + patch_in_forward=False, + patch_size=None, + patching_mode=None, + patching_threshold=None, + patching_batch_size=1, + max_patch_length=None, + cross_attn_k=2, + encoder_hash_byte_group_size=None, + encoder_hash_byte_group_vocab=30000, + encoder_hash_byte_group_nb_functions=3, + patcher_config=None, + encoder_config=None, + decoder_config=None, + global_config=None, + tie_word_embeddings=False, + **kwargs, + ): + + # Basic model configuration + self.tie_word_embeddings = tie_word_embeddings + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + + # Patching configuration + self.patch_in_forward = patch_in_forward + self.patch_size = patch_size + self.patching_mode = patching_mode + self.patching_threshold = patching_threshold + self.patching_batch_size = patching_batch_size + self.max_patch_length = max_patch_length + + # Cross attention configurations + self.cross_attn_k = cross_attn_k + + # Encoder configurations + self.encoder_hash_byte_group_size = encoder_hash_byte_group_size or [2, 3, 4] + self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab + self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions + + # Initialize component configurations + if patcher_config is None: + self.patcher_config = BLTPatcherConfig() + logger.info("patcher_config is None, using default BLT patcher config") + elif isinstance(patcher_config, dict): + self.patcher_config = BLTPatcherConfig(**patcher_config) + elif isinstance(patcher_config, BLTPatcherConfig): + self.patcher_config = patcher_config + + if encoder_config is None: + self.encoder_config = BLTLocalEncoderConfig() + logger.info("encoder_config is None, using default BLT encoder config") + elif isinstance(encoder_config, dict): + self.encoder_config = BLTLocalEncoderConfig(**encoder_config) + elif isinstance(encoder_config, BLTLocalEncoderConfig): + self.encoder_config = encoder_config + + if decoder_config is None: + self.decoder_config = BLTLocalDecoderConfig() + logger.info("decoder_config is None, using default BLT decoder config") + elif isinstance(decoder_config, dict): + self.decoder_config = BLTLocalDecoderConfig(**decoder_config) + elif isinstance(decoder_config, BLTLocalDecoderConfig): + self.decoder_config = decoder_config + + if global_config is None: + self.global_config = BLTGlobalTransformerConfig() + logger.info("global_config is None, using default BLT global config") + elif isinstance(global_config, dict): + self.global_config = BLTGlobalTransformerConfig(**global_config) + elif isinstance(global_config, BLTGlobalTransformerConfig): + self.global_config = global_config + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + +__all__ = [ + "BLTConfig", + "BLTPatcherConfig", + "BLTLocalEncoderConfig", + "BLTLocalDecoderConfig", + "BLTGlobalTransformerConfig", +] diff --git a/backup_blt_wip 
copy/configuration_blt_og.py b/backup_blt_wip copy/configuration_blt_og.py new file mode 100644 index 0000000000000000000000000000000000000000..60af723d89d72599eb0add3881c2b99e9e7390ce --- /dev/null +++ b/backup_blt_wip copy/configuration_blt_og.py @@ -0,0 +1,608 @@ +# old config + +# coding=utf-8 +# Copyright 2024 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BLT (Byte Latent Transformer) model configuration""" + +from enum import Enum + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class InitStdFactor(str, Enum): + DISABLED = "disabled" # Init std is divided by 1.0 + CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) + + +class PatchingModeEnum(str, Enum): + entropy = "entropy" + bpe = "bpe" + bpe_patcher = "bpe_patcher" + space = "space" + static = "static" + byte = "byte" + + +class BLTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ByteLatentTransformer`]. It is used to instantiate a + BLT model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 256): + Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented. + max_seqlen (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model can handle. + + # Main architecture dimensions + dim (`int`, *optional*, defaults to 512): + Main dimension of the model. + n_layers (`int`, *optional*, defaults to 8): + Number of layers in the main transformer. + n_heads (`int`, *optional*, defaults to 8): + Number of attention heads in the main transformer. + head_dim (`int`, *optional*): + Dimension of each attention head. If not specified, computed as dim // n_heads. + n_kv_heads (`int`, *optional*): + Number of key-value heads for grouped query attention. If not specified, defaults to n_heads. + + # Component-specific dimensions + dim_global (`int`, *optional*, defaults to 512): + Dimension of the global transformer component. + dim_local_decoder (`int`, *optional*, defaults to 512): + Dimension of the local decoder component. + dim_local_encoder (`int`, *optional*, defaults to 512): + Dimension of the local encoder component. + n_layers_global (`int`, *optional*, defaults to 8): + Number of layers in the global transformer. + n_layers_local_decoder (`int`, *optional*, defaults to 8): + Number of layers in the local decoder. + n_layers_local_encoder (`int`, *optional*, defaults to 8): + Number of layers in the local encoder. + n_heads_global (`int`, *optional*, defaults to 8): + Number of attention heads in the global transformer. 
+ n_heads_local_decoder (`int`, *optional*, defaults to 8): + Number of attention heads in the local decoder. + n_heads_local_encoder (`int`, *optional*, defaults to 8): + Number of attention heads in the local encoder. + n_kv_heads_global (`int`, *optional*): + Number of key-value heads in the global transformer. + + # Transformer configuration + norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers. + ffn_dim_multiplier (`float`, *optional*, defaults to 1.0): + Multiplier for the feedforward network dimension. + multiple_of (`int`, *optional*, defaults to 256): + Make feedforward network dimension multiple of this value. + + # Positional encoding + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False): + Whether to use fp32 in RoPE outer product computation. + + # Attention configuration + attn_impl (`str`, *optional*, defaults to "sdpa"): + Attention implementation to use ("sdpa" or "flex_attention"). + attn_bias_type (`str`, *optional*, defaults to "causal"): + Type of attention bias to apply. + local_attention_window_len (`int`, *optional*): + Window length for local attention. + use_rope (`bool`, *optional*, defaults to True): + Whether to use rotary position embeddings. + + # Initialization + init_base_std (`float`, *optional*): + Base standard deviation for weight initialization. + init_std_factor (`str`, *optional*, defaults to "disabled"): + Factor for adjusting initialization standard deviation. + + # Embedding dimensions + dim_token_emb (`int`, *optional*): + Token embedding dimension. + dim_token (`int`, *optional*): + Token dimension. + + # Patching configuration + patch_in_forward (`bool`, *optional*, defaults to False): + Whether to perform patching during forward pass. + realtime_patching (`bool`, *optional*, defaults to True): + Whether to use realtime patching. + patch_size (`float`, *optional*): + Size of patches for static patching. + patching_mode (`str`, *optional*): + Mode for patching ("entropy", "static", etc.). + patching_threshold (`float`, *optional*): + Threshold for entropy-based patching. + patching_threshold_add (`float`, *optional*): + Additional threshold parameter for patching. + monotonicity (`bool`, *optional*, defaults to False): + Whether to enforce monotonicity in patching. + patching_batch_size (`int`, *optional*, defaults to 1): + Batch size for patching operations. + patching_device (`str`, *optional*, defaults to "cuda"): + Device to use for patching operations. + max_patch_length (`int`, *optional*): + Maximum length of patches. + entropy_model_checkpoint_dir (`str`, *optional*): + Directory containing entropy model checkpoint. + + # Cross attention configurations + cross_attn_encoder (`bool`, *optional*, defaults to False): + Whether to use cross attention in encoder. + cross_attn_decoder (`bool`, *optional*, defaults to False): + Whether to use cross attention in decoder. + cross_attn_window_encoder (`int`, *optional*): + Cross attention window for encoder. + cross_attn_window_decoder (`int`, *optional*): + Cross attention window for decoder. + cross_attn_k (`int`, *optional*): + Number of cross attention components. + cross_attn_nheads (`int`, *optional*): + Number of heads for cross attention. 
+ cross_attn_all_layers_decoder (`bool`, *optional*, defaults to False): + Whether to apply cross attention to all decoder layers. + cross_attn_all_layers_encoder (`bool`, *optional*, defaults to False): + Whether to apply cross attention to all encoder layers. + cross_attn_use_flex_attention (`bool`, *optional*, defaults to True): + Whether to use flexible attention for cross attention. + cross_attn_init_by_pooling (`bool`, *optional*, defaults to False): + Whether to initialize cross attention by pooling. + + # Encoder configurations + use_local_encoder_transformer (`bool`, *optional*, defaults to False): + Whether to use transformer in local encoder. + max_encoder_seq_length (`int`, *optional*): + Maximum sequence length for encoder. + encoder_hash_byte_group_size (`Any`, *optional*): + Hash byte group size for encoder. + encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 30000): + Vocabulary size for hash byte groups. + encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 3): + Number of hash functions for byte groups. + encoder_enable_byte_ngrams (`bool`, *optional*, defaults to False): + Whether to enable byte n-grams in encoder. + encoder_ngram_to_size_str (`str`, *optional*): + String defining n-gram sizes. + downsampling_by_pooling (`str`, *optional*): + Type of pooling for downsampling. + + # Model behavior + share_encoder_decoder_emb (`bool`, *optional*, defaults to True): + Whether to share encoder and decoder embeddings. + weight_tying (`bool`, *optional*, defaults to False): + Whether to tie input and output embeddings. + + # Performance optimization + sequence_parallel (`bool`, *optional*, defaults to False): + Whether to use sequence parallelism. + loss_parallel (`bool`, *optional*, defaults to False): + Whether to use loss parallelism. + fuse_sequence_parallel (`bool`, *optional*, defaults to False): + Whether to fuse sequence parallel operations. + use_fsdp (`bool`, *optional*, defaults to True): + Whether to use fully sharded data parallel. + + # Parameter mixing + pm_size (`int`, *optional*, defaults to 0): + Parameter mixing size. + + # Special tokens + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to -1): + The id of the padding token. + + # Patcher/Entropy model configuration + patcher_vocab_size (`int`, *optional*, defaults to 256): + Vocabulary size for the entropy model used in patching. + patcher_dim (`int`, *optional*, defaults to 512): + Hidden dimension for the entropy model. + patcher_n_layers (`int`, *optional*, defaults to 8): + Number of layers in the entropy model. + patcher_n_heads (`int`, *optional*, defaults to 8): + Number of attention heads in the entropy model. + patcher_head_dim (`int`, *optional*): + Dimension of each attention head in the entropy model. + patcher_n_kv_heads (`int`, *optional*): + Number of key-value heads in the entropy model. + patcher_max_seqlen (`int`, *optional*, defaults to 1024): + Maximum sequence length for the entropy model. + patcher_norm_eps (`float`, *optional*, defaults to 1e-5): + Layer normalization epsilon for the entropy model. + patcher_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for the entropy model. + patcher_sliding_window (`int`, *optional*): + Sliding window size for the entropy model attention. 
+ patcher_ffn_dim_multiplier (`float`, *optional*): + Feedforward dimension multiplier for the entropy model. + patcher_multiple_of (`int`, *optional*, defaults to 256): + Make feedforward dimension multiple of this for the entropy model. + patcher_rope_theta (`float`, *optional*, defaults to 10000.0): + RoPE theta parameter for the entropy model. + patcher_rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False): + Whether to use fp32 in RoPE outer product for the entropy model. + patcher_attn_impl (`str`, *optional*, defaults to "sdpa"): + Attention implementation for the entropy model. + patcher_attn_bias_type (`str`, *optional*, defaults to "causal"): + Attention bias type for the entropy model. + patcher_init_base_std (`float`, *optional*): + Base initialization standard deviation for the entropy model. + patcher_init_std_factor (`str`, *optional*, defaults to "disabled"): + Initialization std factor for the entropy model. + patcher_dim_token_emb (`int`, *optional*): + Token embedding dimension for the entropy model. + patcher_weight_tying (`bool`, *optional*, defaults to False): + Whether to tie embeddings in the entropy model. + patcher_bos_token_id (`int`, *optional*, defaults to 1): + Beginning of sequence token id for the entropy model. + patcher_eos_token_id (`int`, *optional*, defaults to 2): + End of sequence token id for the entropy model. + + ```python + >>> from transformers import ByteLatentTransformer, BLTConfig + + >>> # Initializing a BLT configuration + >>> configuration = BLTConfig() + + >>> # Initializing a model from the configuration + >>> model = ByteLatentTransformer(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "blt" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=256, + max_seqlen=1024, + # Main architecture dimensions + dim=512, + n_layers=8, + n_heads=8, + head_dim=None, + n_kv_heads=None, + # Component-specific dimensions + dim_global=512, + dim_local_decoder=512, + dim_local_encoder=512, + n_layers_global=8, + n_layers_local_decoder=8, + n_layers_local_encoder=8, + n_heads_global=8, + n_heads_local_decoder=8, + n_heads_local_encoder=8, + n_kv_heads_global=None, + # Transformer configuration + norm_eps=1e-5, + dropout=0.0, + ffn_dim_multiplier=1.0, + multiple_of=256, + # Positional encoding + rope_theta=10000.0, + rope_use_fp32_in_outer_product=False, + # Attention configuration + attn_impl="sdpa", + attn_bias_type="causal", + local_attention_window_len=None, + use_rope=True, + # Initialization + init_base_std=None, + init_std_factor="disabled", + # Embedding dimensions + dim_token_emb=None, + dim_token=None, + # Patching configuration + patch_in_forward=False, + realtime_patching=True, + patch_size=None, + patching_mode=None, + patching_threshold=None, + patching_threshold_add=None, + monotonicity=False, + patching_batch_size=1, + patching_device="cuda", + max_patch_length=None, + entropy_model_checkpoint_dir=None, + # Cross attention configurations + cross_attn_encoder=False, + cross_attn_decoder=False, + cross_attn_window_encoder=None, + cross_attn_window_decoder=None, + cross_attn_k=None, + cross_attn_nheads=None, + cross_attn_all_layers_decoder=False, + cross_attn_all_layers_encoder=False, + cross_attn_use_flex_attention=True, + cross_attn_init_by_pooling=False, + # Encoder configurations + use_local_encoder_transformer=False, + max_encoder_seq_length=None, + encoder_hash_byte_group_size=None, + 
encoder_hash_byte_group_vocab=30000, + encoder_hash_byte_group_nb_functions=3, + encoder_enable_byte_ngrams=False, + encoder_ngram_to_size_str=None, + downsampling_by_pooling=None, + # Model behavior + share_encoder_decoder_emb=True, + weight_tying=False, + # Performance optimization + sequence_parallel=False, + loss_parallel=False, + fuse_sequence_parallel=False, + use_fsdp=True, + # Parameter mixing + pm_size=0, + # Special tokens + bos_token_id=1, + eos_token_id=2, + pad_token_id=-1, + # Patcher/Entropy model configuration + patcher_vocab_size=256, + patcher_dim=512, + patcher_n_layers=8, + patcher_n_heads=8, + patcher_head_dim=None, + patcher_n_kv_heads=None, + patcher_max_seqlen=1024, + patcher_norm_eps=1e-5, + patcher_dropout=0.0, + patcher_sliding_window=None, + patcher_ffn_dim_multiplier=None, + patcher_multiple_of=256, + patcher_rope_theta=10000.0, + patcher_rope_use_fp32_in_outer_product=False, + patcher_attn_impl="sdpa", + patcher_attn_bias_type="causal", + patcher_init_base_std=None, + patcher_init_std_factor="disabled", + patcher_dim_token_emb=None, + patcher_weight_tying=False, + patcher_bos_token_id=1, + patcher_eos_token_id=2, + # Inherited + **kwargs, + ): + + self.sliding_window = None + # Basic model configuration + self.vocab_size = vocab_size + self.max_seqlen = max_seqlen + + # Main architecture dimensions + self.dim = dim + self.n_layers = n_layers + self.n_heads = n_heads + self.head_dim = head_dim + self.n_kv_heads = n_kv_heads + + # Component-specific dimensions + self.dim_global = dim_global + self.dim_local_decoder = dim_local_decoder + self.dim_local_encoder = dim_local_encoder + self.n_layers_global = n_layers_global + self.n_layers_local_decoder = n_layers_local_decoder + self.n_layers_local_encoder = n_layers_local_encoder + self.n_heads_global = n_heads_global + self.n_heads_local_decoder = n_heads_local_decoder + self.n_heads_local_encoder = n_heads_local_encoder + self.n_kv_heads_global = n_kv_heads_global + + # Transformer configuration + self.norm_eps = norm_eps + self.dropout = dropout + self.ffn_dim_multiplier = ffn_dim_multiplier + self.multiple_of = multiple_of + + # Positional encoding + self.rope_theta = rope_theta + self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product + + # Attention configuration + self.attn_impl = attn_impl + self.attn_bias_type = attn_bias_type + self.local_attention_window_len = local_attention_window_len + self.use_rope = use_rope + + # Initialization + self.init_base_std = init_base_std + self.init_std_factor = InitStdFactor(init_std_factor) + + # Embedding dimensions + self.dim_token_emb = dim_token_emb + self.dim_token = dim_token + + # Patching configuration + self.patch_in_forward = patch_in_forward + self.realtime_patching = realtime_patching + self.patch_size = patch_size + self.patching_mode = patching_mode + self.patching_threshold = patching_threshold + self.patching_threshold_add = patching_threshold_add + self.monotonicity = monotonicity + self.patching_batch_size = patching_batch_size + self.patching_device = patching_device + self.max_patch_length = max_patch_length + self.entropy_model_checkpoint_dir = entropy_model_checkpoint_dir + + # Cross attention configurations + self.cross_attn_encoder = cross_attn_encoder + self.cross_attn_decoder = cross_attn_decoder + self.cross_attn_window_encoder = cross_attn_window_encoder + self.cross_attn_window_decoder = cross_attn_window_decoder + self.cross_attn_k = cross_attn_k + self.cross_attn_nheads = cross_attn_nheads + self.cross_attn_all_layers_decoder 
= cross_attn_all_layers_decoder + self.cross_attn_all_layers_encoder = cross_attn_all_layers_encoder + self.cross_attn_use_flex_attention = cross_attn_use_flex_attention + self.cross_attn_init_by_pooling = cross_attn_init_by_pooling + + # Encoder configurations + self.use_local_encoder_transformer = use_local_encoder_transformer + self.max_encoder_seq_length = max_encoder_seq_length + self.encoder_hash_byte_group_size = encoder_hash_byte_group_size + self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab + self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions + self.encoder_enable_byte_ngrams = encoder_enable_byte_ngrams + self.encoder_ngram_to_size_str = encoder_ngram_to_size_str + self.downsampling_by_pooling = downsampling_by_pooling + + # Model behavior + self.share_encoder_decoder_emb = share_encoder_decoder_emb + self.weight_tying = weight_tying + + # Performance optimization + self.sequence_parallel = sequence_parallel + self.loss_parallel = loss_parallel + self.fuse_sequence_parallel = fuse_sequence_parallel + self.use_fsdp = use_fsdp + + # Parameter mixing + self.pm_size = pm_size + + # Patcher/Entropy model configuration + self.patcher_vocab_size = patcher_vocab_size + self.patcher_dim = patcher_dim + self.patcher_n_layers = patcher_n_layers + self.patcher_n_heads = patcher_n_heads + self.patcher_head_dim = patcher_head_dim + self.patcher_n_kv_heads = patcher_n_kv_heads + self.patcher_max_seqlen = patcher_max_seqlen + self.patcher_norm_eps = patcher_norm_eps + self.patcher_dropout = patcher_dropout + self.patcher_sliding_window = patcher_sliding_window + self.patcher_ffn_dim_multiplier = patcher_ffn_dim_multiplier + self.patcher_multiple_of = patcher_multiple_of + self.patcher_rope_theta = patcher_rope_theta + self.patcher_rope_use_fp32_in_outer_product = patcher_rope_use_fp32_in_outer_product + self.patcher_attn_impl = patcher_attn_impl + self.patcher_attn_bias_type = patcher_attn_bias_type + self.patcher_init_base_std = patcher_init_base_std + self.patcher_init_std_factor = InitStdFactor(patcher_init_std_factor) + self.patcher_dim_token_emb = patcher_dim_token_emb + self.patcher_weight_tying = patcher_weight_tying + self.patcher_bos_token_id = patcher_bos_token_id + self.patcher_eos_token_id = patcher_eos_token_id + + # Handle hash byte group size validation + if self.encoder_hash_byte_group_size is not None and type(self.encoder_hash_byte_group_size) == str: + self.encoder_hash_byte_group_size = [ + int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0 + ] + + # Rope + self.rope_scaling={ + "type": "dynamic", + "factor": 2.0, + "rope_type": "dynamic" + } + + self.num_key_value_heads=n_heads_local_encoder + self.max_position_embeddings=max_seqlen + self.hidden_size=dim_local_encoder + self.num_attention_heads=n_heads_local_encoder + # self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + **kwargs, + ) + + @property + def encoder_dim_token_emb(self): + """Compute encoder token embedding dimension.""" + if self.dim_token is not None: + return self.dim_token + elif self.use_local_encoder_transformer: + return self.dim_local_encoder + else: + # Use default patch_size of 8 if not set + patch_size = self.patch_size if self.patch_size is not None else 8 + return self.dim_global // patch_size + + @property + def encoder_dim_patch_emb(self): + """Compute encoder patch 
embedding dimension.""" + if self.cross_attn_encoder: + if self.cross_attn_init_by_pooling: + return self.dim_local_encoder + else: + return self.dim_global + return None + + @property + def global_dim_patch_emb(self): + """Compute global patch embedding dimension.""" + dim_token_emb = self.encoder_dim_token_emb + if self.cross_attn_encoder: + cross_attn_k = self.cross_attn_k if self.cross_attn_k is not None else 1 + return dim_token_emb * cross_attn_k + elif ( + self.downsampling_by_pooling is None + or not self.downsampling_by_pooling + or len(self.downsampling_by_pooling) == 0 + ): + # Use default patch_size of 8 if not set + patch_size = self.patch_size if self.patch_size is not None else 8 + return dim_token_emb * patch_size + else: + return dim_token_emb * sum([pooling in self.downsampling_by_pooling for pooling in ["avg", "min", "max"]]) + + @property + def decoder_dim_token_emb(self): + """Compute decoder token embedding dimension.""" + if self.share_encoder_decoder_emb: + return self.encoder_dim_token_emb + elif self.dim_token is not None: + return self.dim_token + else: + return self.dim_local_decoder + + def get_init_std_factor(self, depth: int) -> float: + """ + Calculate the initialization standard deviation scaling factor for a given layer depth. + + Args: + depth: Current layer depth (0-indexed) + + Returns: + Scaling factor to divide the base initialization std by + """ + if self.init_std_factor == InitStdFactor.CURRENT_DEPTH: + return (2 * (depth + 1)) ** 0.5 + else: # DISABLED + return 1.0 + + +__all__ = ["BLTConfig", "InitStdFactor", "PatchingModeEnum"] + diff --git a/backup_blt_wip copy/modeling_blt.py b/backup_blt_wip copy/modeling_blt.py new file mode 100644 index 0000000000000000000000000000000000000000..84d874daa76ef209652702d5b9a432fd721f5679 --- /dev/null +++ b/backup_blt_wip copy/modeling_blt.py @@ -0,0 +1,1287 @@ +# coding=utf-8 +# Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
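+# High-level dataflow of the classes defined below (grounded in BLTModel.forward further down):
+#   1. bytes -> BLTLocalEncoder: token embeddings plus hashed byte n-gram embeddings give
+#      per-byte hidden states,
+#   2. per-byte states are pooled into patch embeddings (patch boundaries come from the
+#      entropy-based BLTPatcher or from fixed-size patching) via max-pooling and cross-attention,
+#   3. patch embeddings -> BLTGlobalTransformer,
+#   4. BLTLocalDecoder cross-attends from byte positions to the global patch states, and
+#      BLTForCausalLM projects the resulting per-byte states to byte logits.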
+"""BLT model.""" + +from ...utils import is_torch_flex_attn_available, logging +from typing import Callable, List, Optional, Tuple, Union + +from enum import Enum + +from ...cache_utils import Cache +from ...activations import ACT2FN + +import torch +import torch.distributions +import torch.nn +import torch.nn as nn +from torch.nn import functional as F + +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update + +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from .configuration_blt import ( + BLTConfig, + BLTLocalEncoderConfig, + BLTLocalDecoderConfig, + BLTGlobalTransformerConfig, + BLTPatcherConfig, +) + +from ...generation.utils import GenerationMixin +from ...modeling_outputs import CausalLMOutputWithPast + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +class PatchingModeEnum(str, Enum): + entropy = "entropy" + bpe = "bpe" + bpe_patcher = "bpe_patcher" + space = "space" + static = "static" + byte = "byte" + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = F.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + # TODO: not exactly equivalent to other transformers implementations,, need feedback + # Extract first head_dim//2 elements which correspond to the unique frequencies + # This matches the original BLT approach which uses head_dim//2 frequency pairs + head_dim = q.shape[-1] + cos_freqs = cos[..., :head_dim//2] # [B, S, D/2] + sin_freqs = sin[..., :head_dim//2] # [B, S, D/2] + + # Expand cos/sin to match query/key tensor format [B, H, S, D/2] + cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + + # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... 
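+    # Note: this is the interleaved adjacent-pair rotation, rotating (d0, d1), (d2, d3), ...
+    # together, rather than the rotate_half convention (pairing d_i with d_{i + head_dim//2})
+    # used by most other models in the library; that is why only the first head_dim//2
+    # columns of cos/sin are needed (sliced above).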
+ q_pairs = q.view(*q.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] + k_pairs = k.view(*k.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] + + # Extract real and i parts + q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] + k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] + + # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] + q_real_rot = cos_freqs * q_real - sin_freqs * q_imag + q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag + k_real_rot = cos_freqs * k_real - sin_freqs * k_imag + k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag + + # Recombine pairs and reshape back to original format + q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] + k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] + + return q_rot.type_as(q), k_rot.type_as(k) + + +class BLTMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class BLTRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + BLTRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class BLTTransformerLayer(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) + self.mlp = BLTMLP(config) + self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor`, *optional*): + Position indices of tokens in the sequence for RoPE computation. 
+ past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class BLTSelfAttention(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.config = config + self.num_heads = config.num_attention_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = None + self.rope_theta = config.rope_theta + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if 
past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + output_attentions = False + self.config._attn_implementation = "sdpa" + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): + primes = [ + 1000000007, 5915587277, 1500450271, 3267000013, 5754853343, + 4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313, + ] + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) + powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) + prime_powers = prime ** powers + return torch.sum(token_tensor * prime_powers, dim=-1) + + +def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): + """Hash token groups and map to range [0, max_hash].""" + with torch.no_grad(): + batch_size, seq_len = token_ids.shape + # Add padding for sliding window + padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) + padded_tokens = torch.cat([padding, token_ids], dim=1) + + # Create sliding windows and compute hashes + windows = padded_tokens.unfold(1, group_size, 1) + hashes = rolling_polynomial_hash(windows, hash_func_nb) + hash_values = hashes % max_hash + + hash_values.requires_grad = False + return hash_values + + +def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list): + """Initialize hash-based token embeddings for the BLT encoder.""" + num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size) + embeddings = [ + nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim) + for _ in range(num_embeddings) + ] + return nn.ModuleList(embeddings) + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """Compute token embeddings enhanced with hash-based embeddings.""" + embeddings = local_encoder.embed_tokens(local_encoder_tokens) + embedding_idx = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + for group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, group_size, 
func_nb, encoder_hash_byte_group_vocab + ) + embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids) + embedding_idx += 1 + + return embeddings + + +def _prepare_patch_cross_attention_mask( + patch_ids: torch.Tensor, + num_patches: int, + sequence_length: int, + patches_as_queries: bool = False, + cross_attn_k: int = 1, + dtype: torch.dtype = torch.float32, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Prepare cross-attention mask for patch-based attention, following mllama's robust approach. + + This function creates masks that control which patches can attend to which other patches, + with support for query/key role swapping and cross-attention multipliers. + + Args: + patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. + num_patches (int): Total number of patches. + sequence_length (int): Length of the sequence. + patches_as_queries (bool): If True, patches are used as queries, otherwise as keys. + cross_attn_k (int): Cross-attention multiplier for repeating patches. + dtype (torch.dtype): Data type for the output mask. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] + - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows + """ + batch_size, seq_len = patch_ids.shape + device = patch_ids.device + + # Determine query and key lengths based on configuration + if patches_as_queries: + q_len = num_patches * cross_attn_k + kv_len = sequence_length + # Create patch-to-sequence mapping + q_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(-1).expand( + batch_size, num_patches, seq_len + ) + kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) + else: + q_len = sequence_length + kv_len = num_patches * cross_attn_k + # Create sequence-to-patch mapping + q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) + kv_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand( + batch_size, seq_len, num_patches + ) + + # Create base attention mask - boolean mask where True means "should attend" + # Exact patch matching + cross_attention_mask = q_patch_ids == kv_patch_ids + + # Handle cross_attn_k multiplier by repeating along appropriate dimension + repeat_dim = 1 if patches_as_queries else -1 + cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim) + + # Validate dimensions + expected_shape = (batch_size, q_len, kv_len) + if cross_attention_mask.shape != expected_shape: + raise ValueError(f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}") + + # Reshape so it can be used by attn module - add head dimension + cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len] + + # Invert the mask (following mllama pattern exactly) + # True -> 0.0 (attend), False -> 1.0 (will become -inf) + inverted_cross_attn_mask = (1.0 - cross_attention_mask.to(dtype)) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # Apply full-row bias (following mllama pattern exactly) + # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != 
negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + +def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: + """ + Splits patch lengths into smaller segments if they exceed `max_patch_length`. + Pads the result to uniform length across the batch. + + Args: + patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. + max_patch_length (int, optional): Maximum allowed length per patch. + + Returns: + torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. + """ + if max_patch_length is None: + return patch_lengths + + batch_size = patch_lengths.size(0) + processed = [] + + for seq in patch_lengths: + splits = [] + for length in seq[seq > 0]: + length = length.item() + full_chunks, remainder = divmod(length, max_patch_length) + splits.extend([max_patch_length] * full_chunks) + if remainder: + splits.append(remainder) + processed.append(splits) + + # Find max length to pad to + max_len = max(len(splits) for splits in processed) + padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) + + for i, splits in enumerate(processed): + if splits: + padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) + + # Trim zero columns + if (padded != 0).any(dim=0).sum() < padded.shape[1]: + last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 + padded = padded[:, :last_nonzero] + + return padded + + +class BLTRotaryEmbedding(nn.Module): + def __init__(self, config, device=None): + super().__init__() + self.rope_type = config.rope_scaling["rope_type"] + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class BLTLocalEncoder(nn.Module): + def __init__(self, config: BLTLocalEncoderConfig): + super().__init__() + + self.config = config + + self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + + self.rotary_emb = BLTRotaryEmbedding(config=config) + + self.patch_embedding_projection = nn.Linear( + in_features=config.hidden_size, + out_features=config.hidden_size * config.cross_attn_k, + bias=False, + ) + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) + ) + + def forward( + self, + input_ids: torch.Tensor, + input_embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + num_patches: Optional[int] = None, + patch_ids: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ """ + if input_embeds is None: + input_embeds = self.embed_tokens(input_ids) + + batch_size, _, _ = input_embeds.shape + + hidden_states = F.dropout(input_embeds, p=self.config.dropout, training=self.training) + + position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + + for idx, layer in enumerate(self.layers): + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] + + if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers: + patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size) + + layer_idx = idx if self.config.cross_attn_all_layers else 0 + cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( + hidden_states=patch_embeds, + cross_attention_states=hidden_states, + attention_mask=cross_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + output_attentions=False, + use_cache=False, + cache_position=None, + ) + patch_embeds = patch_embeds + cross_attention_output + + encoder_cross_states = patch_embeds + return hidden_states, encoder_cross_states + + def 
patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. + """ + batch_size, _, embedding_dim = hidden_states.shape + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) + + reduced_embeddings = torch.zeros((batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device) + reduced_embeddings = reduced_embeddings.scatter_reduce( + src=hidden_states, + dim=1, + index=patch_ids, + reduce=reduction, + include_self=False, + ) + reduced_embeddings = reduced_embeddings[:, :max_num_patches, :] + + return reduced_embeddings + + +class BLTLocalDecoder(nn.Module): + def __init__(self, config: BLTLocalDecoderConfig): + super().__init__() + + # Extract config values to instance attributes + self.config = config + self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove + + self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + + self.rotary_emb = BLTRotaryEmbedding(config=config) + + self.patch_embedding_projection = nn.Linear( + in_features=config.hidden_size_global, + out_features=config.hidden_size * config.cross_attn_k, + bias=False, + ) + + self.norm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) + ) + + # self.lm_head = nn.Linear( + # config.hidden_size, + # config.vocab_size, + # bias=False, + # ) + + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor], + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + batch_size, _, _ = embeds.shape + + hidden_states = embeds + + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size) + + if patch_embeds is not None and not self.cross_attn_decoder: + hidden_states = hidden_states + patch_embeds + + position_ids = torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + for i, layer in enumerate(self.layers): + if i == 0 or self.config.cross_attn_all_layers: + # Use cross attention to extract info from patch_embeds into hidden_states + cross_attention_output, _, _ = self.cross_attn_layers[i]( + 
hidden_states=hidden_states, + cross_attention_states=patch_embeds, + attention_mask=cross_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + output_attentions=False, + use_cache=False, + cache_position=None, + ) + hidden_states = hidden_states + cross_attention_output + + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] + + logits = self.norm(hidden_states) + # logits = self.lm_head(logits) + return logits, cache + + +class BLTCrossAttention(nn.Module): + """Cross-attention module for BLT, following transformers style""" + + def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + # Use provided hidden_size or fallback to encoder dimension + self.hidden_size = hidden_size or config.encoder_config.hidden_size + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = None #self.head_dim ** -0.5 + self.dropout = config.dropout + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.q_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) + self.k_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(hidden_states) # BLT normalizes first + query_states = self.q_proj(query_states) + + if cross_attention_states is not None: + cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + if past_key_value is not None: + # if we have a new cross attention states + new tokens, we only computed key_states on that new cross attention states + # we still update the cross key states, past_cross_states, new_cross_states. And use it! + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + elif cache_position is not None and cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + if cross_attention_states is None: + raise ValueError( + "Cross attention layer can't find neither `cross_attention_states` nor cached values for key/values!" 
+ ) + + attention_interface: Callable = eager_attention_forward + + self.config._attn_implementation = "sdpa" + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if full_text_row_masked_out_mask is not None: + attn_output = full_text_row_masked_out_mask[:, 0] * attn_output + + attn_output = attn_output + hidden_states + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class BLTGlobalTransformer(nn.Module): + def __init__(self, config: BLTGlobalTransformerConfig): + super().__init__() + + self.config = config + + self.layers = nn.ModuleList() + for layer_idx in range(config.num_hidden_layers): + self.layers.append(BLTTransformerLayer(config, layer_idx)) + + self.rotary_emb = BLTRotaryEmbedding(config=config) + + + def forward( + self, + input_embeds: torch.Tensor, + mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + batch_size, seq_len, _ = input_embeds.shape + + hidden_states = input_embeds + + hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) + + position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for i, layer in enumerate(self.layers): + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] + + return hidden_states, cache + + + + +class BLTPreTrainedModel(PreTrainedModel): + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = False # BLT uses its own attention implementation + _supports_sdpa = True + _supports_cache_class = False + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + std = getattr(module, '_custom_std', module.in_features ** (-0.5)) + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.Embedding): + std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5)) + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + elif isinstance(module, BLTModel): + if 
module.encoder_hash_tok_embedding is not None: + emb_std = module.config.encoder_config.hidden_size ** (-0.5) + for emb in module.encoder_hash_tok_embedding: + emb._custom_std = emb_std + + elif isinstance(module, BLTLocalEncoder): + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) + + elif isinstance(module, BLTLocalDecoder): + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.config.hidden_size ** (-0.5) + + elif isinstance(module, BLTPatcher): + emb_std = module.config.hidden_size ** (-0.5) + module.embed_tokens._custom_std = emb_std + module.lm_head._custom_std = emb_std + + elif isinstance(module, BLTForCausalLM): + if module.lm_head is not None: + module.lm_head._custom_std = module.config.decoder_config.hidden_size ** (-0.5) + + +class BLTModel(BLTPreTrainedModel): + def __init__(self, config: BLTConfig): + super().__init__(config) + self.config = config + self.local_encoder = BLTLocalEncoder(config.encoder_config) + self.global_transformer = BLTGlobalTransformer(config.global_config) + self.local_decoder = BLTLocalDecoder(config.decoder_config) + self.encoder_hash_tok_embedding = init_hash_embeddings( + config, + local_encoder_dim=config.encoder_config.hidden_size, + encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, + ) + if self.config.patch_in_forward: + self.patcher = BLTPatcher(config.patcher_config) + self.patcher.eval() + for param in self.patcher.parameters(): + param.requires_grad = False + else: + self.patcher = None + + def forward( + self, + tokens: torch.Tensor, + patch_lengths: Optional[torch.Tensor] = None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + **kwargs, + ): + """ + Args: + tokens (torch.Tensor): Input token ids. + patch_lengths (Optional[torch.Tensor]): Patch lengths for patching. + attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Ignored, for compatibility. + Returns: + torch.Tensor: Final hidden states (as before). 
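+                These are the per-byte decoder states of shape `(batch_size, sequence_length,
+                decoder hidden size)`; a dict or tuple wrapping the same tensor is returned
+                instead when `output_hidden_states` or `output_attentions` is set.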
+ """ + batch_size, sequence_length = tokens.shape + # Handle patching + if patch_lengths is None: + if self.config.patching_mode == PatchingModeEnum.entropy: + _, patch_lengths, _ = self.patcher( + tokens, + patch_size=self.config.patch_size, + threshold=self.config.patching_threshold, + max_patch_length=self.config.max_patch_length, + patching_batch_size=self.config.patching_batch_size, + device=tokens.device, + ) + else: + patch_lengths = process_patch_lengths( + torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device), + self.config.max_patch_length + ) + patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) + cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( + patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32 + ) + encoder_embeds = compute_hash_embeddings( + tokens, self.local_encoder, self.encoder_hash_tok_embedding, + self.config.encoder_hash_byte_group_nb_functions, + self.config.encoder_hash_byte_group_size, + self.config.encoder_hash_byte_group_vocab, + ) + encoder_hidden_states, encoder_cross_states = self.local_encoder( + input_ids=tokens, + input_embeds=encoder_embeds, + patch_embeds=None, + cross_mask=cross_attn_mask_enc, + full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) + global_hidden_states, _ = self.global_transformer( + input_embeds=global_hidden_states, + ) + decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) + cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( + decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32 + ) + output, _ = self.local_decoder( + tokens=tokens, + embeds=encoder_hidden_states, + patch_embeds=global_hidden_states, + mask=None, + cross_mask=cross_attn_mask_dec, + full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, + ) + if output_hidden_states or output_attentions: + if return_dict: + return {"last_hidden_state": output, "hidden_states": None, "attentions": None} + else: + return (output, None, None) + return output + + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: + """Convert patch lengths to patch IDs for each token position.""" + batch_size = patch_lengths.shape[0] + patch_starts = torch.cat([ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1] + ], dim=-1) + + token_positions = torch.arange(seq_len, device=patch_lengths.device) + return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 + + +class BLTPatcher(BLTPreTrainedModel): + def __init__(self, config: BLTPatcherConfig): + super().__init__(config) + + self.rotary_emb = BLTRotaryEmbedding(config=self.config) + + self.layers = nn.ModuleList() + + for layer_idx in range(self.config.num_hidden_layers): + self.layers.append(BLTTransformerLayer(self.config, layer_idx)) + + + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) + + self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps) + + self.lm_head = nn.Linear( + self.config.hidden_size, + self.config.vocab_size, + bias=False, + ) + + def forward( + self, + token_values: torch.Tensor, 
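+        # token_values: byte-level input ids of shape (batch_size, seq_len); entropies are
+        # computed below in chunks of max_position_embeddings * patching_batch_size tokens.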
+ patch_size: Optional[int] = None, + threshold: Optional[float] = None, + max_patch_length: Optional[int] = None, + patching_batch_size: int = 1, + device: Optional[str] = None, + ): + + # Handle chunked processing for entropy calculation + entropies = [] + predictions = [] + max_length = self.config.max_position_embeddings + batch_numel = max_length * patching_batch_size + splits = torch.split(token_values.flatten(), batch_numel) + + for split in splits: + pad_size = (max_length - (split.numel() % max_length)) % max_length + pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False) + split = torch.cat((split, pad), dim=0) + split = split.reshape(-1, max_length) + if device is not None: + split = split.to(device) + + # Process chunk: embeddings -> layers -> output + batch_size, sequence_length = split.shape + input_embeds = self.embed_tokens(split) + + hidden_states = input_embeds + + batch_size, _, _ = input_embeds.shape + + position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for i, layer in enumerate(self.layers): + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] + + logits = self.lm_head(self.norm(hidden_states)) + logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] + predictions.append(logits) + prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() + entropies.append(prediction_entropies) + + concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) + concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1) + + # Always compute patch lengths from concatenated entropies + batch_size, sequence_length = token_values.shape + + # Find patch start IDs based on entropy + if patch_size is not None: + patch_lengths = self.patch_lengths_from_entropies( + entropies=concat_entropies, + sequence_length=sequence_length, + patch_size=patch_size, + threshold=threshold, + ) + else: + # Default to byte-level patching + patch_lengths = torch.ones((batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device) + patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) + return concat_entropies, patch_lengths, concat_predictions + + @staticmethod + def patch_lengths_from_entropies( + entropies, + sequence_length, + patch_size=None, + threshold=None, + ): + """ + Computes patch lengths from token entropies. + + Depending on whether a threshold is provided, the function uses either: + - Top-k selection based on entropy (when `threshold` is None), or + - Thresholding the entropy values (when `threshold` is set). 
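+
+        Illustrative example: with `sequence_length=16`, `patch_size=4` and `threshold=None`,
+        positions 0 and 1 always start patches and the `16 // 4 - 2 = 2` highest-entropy
+        positions (first token excluded) supply the remaining starts; with a `threshold`,
+        every position whose entropy exceeds it starts a new patch instead.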
+ """ + + batch_size = entropies.shape[0] + + # Always include token 0 and 1 as starting tokens + init_tokens = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + offset = init_tokens.shape[1] + + # Ignore first token entropy (BOS) + entropies = entropies[:, 1:] + + if threshold is None: + # Use top-k entropy values to define patch start points + num_patches = sequence_length // patch_size + topk_indices = entropies.topk(num_patches - 2, dim=1).indices + patch_starts = topk_indices.sort(dim=1).values + else: + # Threshold the entropy values to define patch start points + patch_mask = entropies > threshold + + seq_len = patch_mask.shape[1] + + # Create patch IDs (token indices), and add a sentinel to ensure alignment + token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) + sentinel = torch.full_like(token_indices, seq_len) + padded_indices = torch.cat([token_indices, sentinel], dim=1) + + # Pad mask with inverse to align sentinel correctly + padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) + + # Select indices where mask is True + patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) + max_valid_patches = patch_mask.sum(dim=1).max() + patch_starts = patch_starts[:, :max_valid_patches] + + # Offset patch starts to account for the two initial tokens + patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) + + # Compute patch end positions by shifting start positions + last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1) + patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1) + + patch_lengths = patch_ends - patch_start_ids + 1 + + return patch_lengths + + +class BLTForCausalLM(BLTPreTrainedModel, GenerationMixin): + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + + def __init__(self, config): + super().__init__(config) + self.model = BLTModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.local_encoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.local_encoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + **kwargs, + ): + """ + Args: + input_ids (torch.LongTensor): Input token ids. + attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs: Standard transformers arguments. + labels (torch.LongTensor, optional): Labels for language modeling loss. + Returns: + CausalLMOutputWithPast or tuple: Standard transformers output. 
+ """ + # Route only input_ids to BLTModel (as tokens) + hidden_states = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + if isinstance(hidden_states, dict): + sequence_output = hidden_states["last_hidden_state"] + elif isinstance(hidden_states, tuple): + sequence_output = hidden_states[0] + else: + sequence_output = hidden_states + logits = self.lm_head(sequence_output) + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = torch.nn.CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + if not return_dict: + output = (logits,) + if loss is not None: + output = (loss,) + output + return output + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + +__all__ = [ + "BLTPreTrainedModel", + "BLTModel", + "BLTPatcher", + "BLTLocalEncoder", + "BLTLocalDecoder", + "BLTGlobalTransformer", + "BLTTransformerLayer", + "BLTForCausalLM", +] \ No newline at end of file diff --git a/backup_blt_wip copy/modeling_blt_old.py b/backup_blt_wip copy/modeling_blt_old.py new file mode 100644 index 0000000000000000000000000000000000000000..7005a4f0fbba1f6bf389c175d608281c068903cd --- /dev/null +++ b/backup_blt_wip copy/modeling_blt_old.py @@ -0,0 +1,1602 @@ +#blt old + +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import logging +import os +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn +import torch.nn as nn +from torch.nn import functional as F +from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention + +from ...modeling_utils import PreTrainedModel +from .configuration_blt_og import ( + BLTConfig, + PatchingModeEnum, +) + +RMSNorm = nn.RMSNorm + +logger = logging.getLogger() + +flex_attention_comp = flex_attention + + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "sdpa": + BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) + + if attn_bias_type == "causal": + return "causal" + + if BLT_SUPPRESS_ATTN_ERROR == 1: + return "causal" + else: + raise ValueError( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. 
To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1" + ) + elif attn_impl == "flex_attention": + return create_block_mask(causal_mask, None, None, seqlen, seqlen) + else: + raise NotImplementedError(f"Attention {attn_impl} with {sliding_window} sliding window not implemented") + + +def cross_entropy(pred, target, **kwargs): + return F.nll_loss( + F.log_softmax(pred.flatten(end_dim=-2).float(), -1), + target.flatten(end_dim=-1), + **kwargs, + ) + + +def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +def precompute_freqs_cis( + dim: int, + end: int, + theta: float = 10000.0, + rope_use_fp32_in_outer_product: bool = False, +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + if rope_use_fp32_in_outer_product: + t = t.to(torch.float32) + + freqs = torch.outer(t, freqs).float() + + cos, sin = freqs.cos(), freqs.sin() + + return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + seq_dim (int): Sequence dimension index. + + Returns: + torch.Tensor: Reshaped frequency tensor. 
+ """ + ndim = x.ndim + assert 0 <= seq_dim < ndim + assert freqs_cis.shape == ( + x.shape[seq_dim], + x.shape[-3], + 2, + 2, + ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" + shape = [d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])] + [2, 2] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + seq_dim: int, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, seq_dim).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 + xq_out = (xq_ * freqs_cis).sum(5).flatten(3) + xk_out = (xk_ * freqs_cis).sum(5).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +# Rotary embedding as in xformer, see if torchtrain implementation is not better. Also might be usefull to make it work with batch*seqlen collapsed. +class RotaryEmbedding(torch.nn.Module): + """ + RotaryEmbedding Module + """ + + def __init__( + self, + theta: float, + head_dim: int, + max_seqlen: int = 1024, + rope_use_fp32_in_outer_product: bool = False, + ): + super().__init__() + + self.theta = theta + self.head_dim = head_dim + self.max_seqlen = max_seqlen + self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + dim=head_dim, + end=max_seqlen, + theta=theta, + rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, + ), + persistent=False, + ) + + + def forward(self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None): + """ + Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions + Args: + seqlen (int): Contiguous sequence length + tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen + + Returns: + Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis + """ + test = (seqlen is not None) or (tok_idx is not None) + assert test, "Should provide atleast seqlen or tok_idx" + if tok_idx is not None: + return self.freqs_cis[tok_idx] + elif seqlen is not None: + return self.freqs_cis[0:seqlen] + + +class BLTSelfAttention(nn.Module): + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + rope_theta: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + self.rope_theta = rope_theta + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + # B S D + bsz, seq_len, dim = x.shape + + xq = self.wq(x.view_as(x)) + xk = self.wk(x.view_as(x)) + xv = self.wv(x.view_as(x)) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) + 
+ # This condition helps us be easily compatible + # with inference by adding a pluggable KVCache + if hasattr(self, "kv_cache"): + xk, xv = self.kv_cache.update(xk, xv, tok_idx) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + if attn_impl == "flex_attention": + assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) + output = flex_attention_comp(xq, xk, xv, block_mask=mask) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + elif attn_impl == "sdpa": + xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) + assert mask is None or isinstance(mask, (str, torch.Tensor)) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + else: + raise NotImplementedError(f"Attention implementation {attn_impl} not supported") + + output_reshaped = output.reshape(output_shape) + + output = self.wo(output_reshaped) + + return output + + +class BLTMLP(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + mp_size: int = 1, + ): + super().__init__() + + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + assert hidden_dim % mp_size == 0 + + self.dim = dim + self.hidden_dim = hidden_dim + + self.w1 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w3 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w2 = nn.Linear( + hidden_dim, + dim, + bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # B S D + x1 = self.w1(x.view_as(x)) + x3 = self.w3(x.view_as(x)) + output = self.w2(F.silu(x1) * x3) + return output + + + + +class BLTTransformerLayer(nn.Module): + def __init__(self, args): + super().__init__() + + # Extract parameters from dictionary + dim = args["dim"] + n_heads = args["n_heads"] + head_dim = args["head_dim"] + n_kv_heads = args["n_kv_heads"] + rope_theta = args["rope_theta"] + multiple_of = args["multiple_of"] + ffn_dim_multiplier = args["ffn_dim_multiplier"] + norm_eps = args["norm_eps"] + + assert (head_dim is not None) or (n_heads is not None), "Should specify at least head_dim or n_heads" + self.head_dim = head_dim or dim // n_heads + self.n_heads = n_heads or dim // head_dim + self.n_kv_heads = n_kv_heads or self.n_heads + + assert n_heads % self.n_kv_heads == 0 + assert dim % n_heads == 0 + + self.attention = BLTSelfAttention( + dim=dim, + head_dim=self.head_dim, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + rope_theta=rope_theta, + ) + self.feed_forward = BLTMLP( + dim=dim, + hidden_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + self.attention_norm = RMSNorm(dim, eps=norm_eps) + self.ffn_norm = RMSNorm(dim, eps=norm_eps) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + norm_x = self.attention_norm(x) + attn_out = self.attention( + norm_x, + freq_cis, + tok_idx=tok_idx, + mask=mask, + attn_impl=attn_impl, + ) + h = x + attn_out + h_norm = 
self.ffn_norm(h) + out = h + self.feed_forward(h_norm) + return out + +def check_non_zero_after_zero(tensor): + zero_mask = tensor == 0 + shifted_mask = torch.cat( + [ + torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device), + zero_mask[:, :-1], + ], + dim=1, + ) + non_zero_after_zero = (tensor != 0) & shifted_mask + return non_zero_after_zero.any() + +def rolling_polynomial_hash(t, hash_func_nb: int = 0): + primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, + ] + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) + prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) + return torch.sum(t * prime_powers, dim=-1) + + +def byte_group_hash_function(x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): + """ + Returns a hash of the input x and maps it to a value in the range [0, max_hash]. + + expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. + returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. + + Note: max hash can make a big difference on the number of collisions. + """ + with torch.no_grad(): + bs, seq_len = x.shape + prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) + x = torch.cat([prefix, x], dim=1) + windows = x.unfold(1, group_size, 1) + # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) + hashes = rolling_polynomial_hash(windows, hash_func_nb) + hash_values_range = hashes % max_hash + hash_values_range.requires_grad = False + return hash_values_range + + +def create_patch_mask_from_ids(patch_ids, num_patches, window=None, patches_as_queries=False): + """ + Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) + is True if the patch id at position (i, j) is less than or equal to k. + Args: + patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids. + num_patches (int): Total number of patches. + window (int): If not None, only considers patches within a window of size window. + patches_as_queries (bool): If True, the patches are used as queries + Returns: + torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask. 
+ """ + bs, seq_len = patch_ids.shape + if not patches_as_queries: + q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches) + kv_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(0) + .expand(bs, seq_len, num_patches) + ) + else: + kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len) + q_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(bs, num_patches, seq_len) + ) + if window is None: + mask = q_ids == kv_ids + else: + mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window) + return mask + + +def cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=1, + window=None, + block_mask=True, +): + bs = patch_ids.shape[0] + with torch.no_grad(): + # Create the patch mask + cross_mask = create_patch_mask_from_ids( + patch_ids, + patch_lengths.shape[1], + window=window, + patches_as_queries=patches_as_queries, + ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1) + q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N + kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k + assert cross_mask.shape == ( + bs, + q_len, + kv_len, + ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" + # block_mask = None + if block_mask: + + def patch_mask(b, h, q_idx, kv_idx): + return cross_mask[b, q_idx, kv_idx] + + block_mask = create_block_mask( + patch_mask, + B=bs, + H=None, + Q_LEN=q_len, + KV_LEN=kv_len, + _compile=True, + ) + return block_mask + else: + return torch.where(cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))).unsqueeze( + 1 + ) # [bs, 1, q_len, kv_len] + + +def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor: + if max_patch_length is None: + return patch_lengths + + batch_size = patch_lengths.size(0) + split_all = [] + max_len = 0 + + for seq in patch_lengths: + splits = [] + for length in seq[seq > 0]: + # Split long patches into max_patch_length chunks + full, rem = divmod(length.item(), max_patch_length) + splits.extend([max_patch_length] * full + ([rem] if rem else [])) + split_all.append(splits) + max_len = max(max_len, len(splits)) + + # Pad sequences to the maximum length + padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) + for i, splits in enumerate(split_all): + if splits: + padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) + + # Trim trailing columns that are all zeros + last_non_zero = (padded != 0).flip(1).int().argmax(1).min() + if last_non_zero < padded.shape[1]: + padded = padded[:, :padded.shape[1] - last_non_zero] + + return padded + +class BLTLocalModelBase(nn.Module): + def __init__(self, config: BLTConfig, component_type: str = "encoder"): + super().__init__() + + self.config = config + + if component_type == "encoder": + self.dim = config.dim_local_encoder + self.n_layers = config.n_layers_local_encoder + self.n_heads = config.n_heads_local_encoder + self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen + self.attn_bias_type = "local_block_causal" + self.sliding_window = config.local_attention_window_len + elif component_type == "decoder": + self.dim = config.dim_local_decoder + self.n_layers = config.n_layers_local_decoder + self.n_heads = config.n_heads_local_decoder + self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen + self.attn_bias_type = "local_block_causal" + 
self.sliding_window = config.local_attention_window_len + else: + raise ValueError(f"Unknown component_type: {component_type}") + + self.dropout = config.dropout + self.vocab_size = config.vocab_size + config.pm_size + self.patch_size = config.patch_size + + self.attn_impl = config.attn_impl + self.use_rope = config.use_rope + self.init_std_factor = config.init_std_factor + self.init_base_std = config.init_base_std + self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None) + self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None) + self.cross_attn_k = getattr(config, "cross_attn_k", None) + self.eos_id = config.eos_token_id + + self.boe_id = config.boe_id + + # Initialize cross attention layers as None (will be set by subclasses if needed) + self.cross_attn_layers = None + + # Create parameter dict for BLTTransformerLayers + layer_params = { + "dim": self.dim, + "n_heads": self.n_heads, + "head_dim": config.head_dim, + "n_kv_heads": getattr(config, "n_kv_heads", None), + "rope_theta": config.rope_theta, + "multiple_of": getattr(config, "multiple_of", 256), + "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None), + "norm_eps": config.norm_eps, + } + + self.layers = nn.ModuleList([BLTTransformerLayer(layer_params) for _ in range(self.n_layers)]) + + if not self.use_rope: + self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length + else: + self.rope = RotaryEmbedding( + theta=config.rope_theta, + head_dim=config.head_dim or self.dim // self.n_heads, + max_seqlen=self.max_seqlen, + rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, + ) + self.pos_embeddings = None + + # Set dimension-specific embedding dimensions + if component_type == "encoder": + self.dim_token_emb = config.encoder_dim_token_emb + self.dim_patch_emb = config.encoder_dim_patch_emb + elif component_type == "decoder": + self.dim_token_emb = config.decoder_dim_token_emb + self.dim_patch_emb = config.dim_global + + self.token_embedding_projection = ( + nn.Linear(self.dim_token_emb, self.dim, bias=False) + if self.dim_token_emb is not None and self.dim_token_emb != self.dim + else None + ) + + self.patch_embedding_projection = self._create_patch_projection(config) + + def _should_create_patch_projection(self, config: BLTConfig): + dimension_mismatch = self.dim_patch_emb is not None and self.dim_patch_emb != self.dim + + # Check cross attention conditions + cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( + config.cross_attn_decoder and config.cross_attn_init_by_pooling + ) + + return dimension_mismatch or cross_attn_conditions + + def _create_patch_projection(self, config): + if not self._should_create_patch_projection(config): + return None + + output_dim = self.dim_token_emb * (self.cross_attn_k or 1) + + return nn.Linear( + in_features=self.dim_patch_emb, + out_features=output_dim, + bias=False, + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + return embeds + else: + return self.tok_embeddings(tokens) + + +class BLTLocalEncoder(BLTLocalModelBase): + def __init__(self, config: BLTConfig): + super().__init__(config, component_type="encoder") + + self.apply_transformer = config.use_local_encoder_transformer + self.downsampling_by_pooling = config.downsampling_by_pooling + self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None + self.cross_attn_encoder = config.cross_attn_encoder + self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder + 
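# Editorial sketch, not part of the patch: the hash byte-group embeddings this
# encoder expects are produced by hashing every sliding n-gram of byte ids with a
# rolling polynomial hash and looking the bucket up in an extra embedding table
# (see rolling_polynomial_hash / byte_group_hash_function above). Standalone toy;
# the 16-dim table width is made up.
import torch

tokens = torch.tensor([[72, 101, 108, 108, 111]])     # "Hello" as byte ids
group_size, max_hash, prime = 2, 30000, 1000000007

padded = torch.cat([torch.zeros(1, group_size - 1, dtype=torch.long), tokens], dim=1)
windows = padded.unfold(1, group_size, 1)             # [1, 5, 2] sliding byte pairs
powers = torch.tensor(prime, dtype=torch.int64) ** torch.arange(group_size)
bucket_ids = (windows * powers).sum(-1) % max_hash    # [1, 5], one bucket id per position

hash_table = torch.nn.Embedding(max_hash, 16)
extra = hash_table(bucket_ids)                        # added on top of the plain byte embeddings
print(bucket_ids.shape, extra.shape)                  # torch.Size([1, 5]) torch.Size([1, 5, 16])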
self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_nheads = config.cross_attn_nheads + + self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim) + + if self.cross_attn_encoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=config.norm_eps, + ) + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + assert self.expects_hash_embeddings, "Not expecting embeddings to be passed." + return embeds + else: + return self.tok_embeddings(tokens) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + num_patches: Optional[int] = None, + patch_ids: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ """ + bs, seqlen = tokens.shape + if mask is None: + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) + + h = self.apply_embedding(tokens, embeds) + + + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + + h = F.dropout(h, p=self.dropout, training=self.training) + + for i, layer in enumerate(self.layers): + h = layer(h, freqs_cis, tok_idx=None, mask=mask, attn_impl=self.attn_impl) + # check if cross attention should be applied to either all layer or only the last layer + if self.cross_attn_encoder and (i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder): + # apply pooling and project + if self.cross_attn_init_by_pooling and patch_embeds is None: + patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids) + if self.patch_embedding_projection is not None: + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim) + + layer_idx = i if self.cross_attn_all_layers_encoder else 0 + patch_embeds_cross = self.cross_attn_layers[layer_idx]( + x=patch_embeds, + kv=h, + mask=cross_mask, + ) + patch_embeds = patch_embeds + patch_embeds_cross + + h_residual = patch_embeds if self.cross_attn_encoder else None + return (h, h_residual), cache + + def patch_reduce(self, h, max_num_patches, reduction, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. 
+ """ + bs, seq_len, emb_dim = h.shape + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + + reduced_embs = torch.zeros((bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device) + reduced_embs = reduced_embs.scatter_reduce( + src=h, + dim=1, + index=patch_ids, + reduce=reduction, + include_self=False, + ) + reduced_embs = reduced_embs[:, :max_num_patches, :] + + return reduced_embs + + +class BLTLocalDecoder(BLTLocalModelBase): + def __init__(self, config: BLTConfig): + super().__init__(config, component_type="decoder") + + # Model configuration flags + self.cross_attn_decoder = config.cross_attn_decoder + self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder + self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_nheads = config.cross_attn_nheads + + self.norm = RMSNorm(self.dim, eps=config.norm_eps) + + if self.cross_attn_decoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=config.norm_eps, + ) + ) + + self.output = nn.Linear( + self.dim, + config.vocab_size, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor], + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + bs, seqlen = tokens.shape + assert embeds is not None, "Embeddings must be provided" + + if mask is None: + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) + + h = embeds + + if self.patch_embedding_projection is not None: + assert patch_embeds is not None, "Patch embeddings must be passed." 
+ patch_embeds = self.patch_embedding_projection(patch_embeds) + if self.cross_attn_k is not None: + patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim) + + if patch_embeds is not None and not self.cross_attn_decoder: + h = h + patch_embeds + + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + for i, layer in enumerate(self.layers): + if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder): + # Use cross attention to extract info from patch_embeds into h + h_cross = self.cross_attn_layers[i]( + x=h, + kv=patch_embeds, + mask=cross_mask, + ) + h = h + h_cross + + h = layer(h, freqs_cis, tok_idx=None, mask=mask, attn_impl=self.attn_impl) + + h_preds = self.norm(h) + h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) + h_preds = self.output(h_preds) + h_preds = h_preds.float() + return h_preds, cache + + +class BLTCrossAttention(nn.Module): + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps) + self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + kv: torch.Tensor, + mask: Optional[Union[BlockMask, str]] = None, + ) -> torch.Tensor: + # B S D + bsz, seq_len, _ = x.shape + _, slen_kv, _ = kv.shape + x_norm = self.cross_attn_norm_q(x) + kv = self.cross_attn_norm_kv(kv) + + xq = self.wq(x_norm) + xk = self.wk(kv) + xv = self.wv(kv) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + # assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) + # output = flex_attention_comp(xq, xk, xv, block_mask=mask) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask if isinstance(mask, torch.Tensor) else None + mask = mask.to(dtype=xq.dtype).to(xq.device) + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + output = self.wo(output.reshape(output_shape)) + + return x + output + + +class BLTGlobalTransformer(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + + self.dim = config.dim_global + self.rope_embeddings = RotaryEmbedding( + theta=config.rope_theta, + head_dim=config.head_dim or self.config.dim_global // config.n_heads_global, + max_seqlen=config.max_seqlen, + rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, + ) + # Handle both eos_id and eos_token_id for compatibility + self.eos_id = getattr(config, "eos_id", getattr(config, "eos_token_id", 2)) + + # Create parameter dict for 
BLTTransformerLayers + layer_params = { + "dim": self.dim, + "n_heads": config.n_heads_global, + "head_dim": config.head_dim, + "n_kv_heads": getattr(config, "n_kv_heads_global", None), + "rope_theta": config.rope_theta, + "multiple_of": getattr(config, "multiple_of", 256), + "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None), + "norm_eps": config.norm_eps, + } + + self.layers = nn.ModuleList() + for _ in range(config.n_layers_global): + self.layers.append(BLTTransformerLayer(layer_params)) + + self.token_embedding_projection = None + if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim: + self.token_embedding_projection = nn.Linear( + config.global_dim_patch_emb, + config.dim_global, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + embeds: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + bs, seqlen = tokens.shape + + h = embeds + + mask = ( + mask + if mask is not None + else create_causal_mask( + seqlen, + self.config.attn_impl, + self.config.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) + ) + + if self.token_embedding_projection is not None and h.shape[-1] != self.dim: + h = self.token_embedding_projection(h) + + h = F.dropout(h, p=self.config.dropout, training=self.training) + freq_cis = self.rope_embeddings(seqlen=self.config.max_seqlen, tok_idx=tok_idx) + + for i, layer in enumerate(self.layers): + h = layer(h, freq_cis, tok_idx=None, mask=mask, attn_impl=self.config.attn_impl) + + return h, cache + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """ + Compute embeddings using hash token embeddings. 
+ + Args: + local_encoder_tokens: Input tokens tensor + local_encoder: Encoder object with tok_embeddings method + encoder_hash_tok_embedding: ModuleList of hash token embeddings + encoder_hash_byte_group_nb_functions: Number of hash functions + encoder_hash_byte_group_size: List of byte group sizes + encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings + + Returns: + torch.Tensor: Combined embeddings + """ + if encoder_hash_tok_embedding is None: + return None + + local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens) + + i = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + for byte_group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, + byte_group_size, + hash_func_nb=func_nb, + max_hash=encoder_hash_byte_group_vocab, + ) + hash_tok_embedding = encoder_hash_tok_embedding[i] + local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids) + i += 1 + + assert i == len(encoder_hash_tok_embedding) + return local_encoder_embeds + + +class BLTPreTrainedModel(PreTrainedModel): + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = False # BLT uses its own attention implementation + _supports_sdpa = True + _supports_cache_class = False + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + std = getattr(module, '_custom_std', module.in_features ** (-0.5)) + + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.Embedding): + std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5)) + + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + elif isinstance(module, (nn.RMSNorm, nn.LayerNorm)): + nn.init.ones_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, RotaryEmbedding): + module.freqs_cis[...] 
= precompute_freqs_cis( + dim=module.head_dim, + end=module.max_seqlen, + theta=module.theta, + rope_use_fp32_in_outer_product=module.rope_use_fp32_in_outer_product, + ) + + elif isinstance(module, BLTModel): + if module.encoder_hash_tok_embedding is not None: + emb_std = module.local_encoder.dim ** (-0.5) + for emb in module.encoder_hash_tok_embedding: + emb._custom_std = emb_std + + elif isinstance(module, (BLTLocalEncoder, BLTLocalDecoder)): + if module.token_embedding_projection is not None: + module.token_embedding_projection._custom_std = module.dim ** (-0.5) + + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.dim_patch_emb ** (-0.5) + + elif isinstance(module, BLTGlobalTransformer): + if module.token_embedding_projection is not None: + module.token_embedding_projection._custom_std = module.dim_token_emb ** (-0.5) + + elif isinstance(module, BLTPatcher): + emb_std = module.config.patcher_dim ** (-0.5) + module.tok_embeddings._custom_std = emb_std + module.output._custom_std = emb_std + + +class BLTModel(BLTPreTrainedModel): + def __init__(self, config: BLTConfig): + super().__init__(config) + + self.config = config + self.local_encoder = BLTLocalEncoder(config) + self.global_transformer = BLTGlobalTransformer(config) + self.local_decoder = BLTLocalDecoder(config) + + self.encoder_hash_tok_embedding = init_hash_embeddings( + config, + local_encoder_dim=self.local_encoder.dim, + encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, + ) + + if config.patch_in_forward: + self.patcher = BLTPatcher(config) + self.patcher.eval() + for param in self.patcher.parameters(): + param.requires_grad = False + else: + self.patcher = None + + def forward( + self, + tokens: torch.Tensor, + patch_lengths: Optional[torch.Tensor] = None, + ): + # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings + # are no longer used in the final BLT model + + bs, N = tokens.shape # Batch size and sequence length + + local_encoder_tokens, local_decoder_tokens = tokens, tokens + + # Patching + if patch_lengths is None: + # assert ( + # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward + # ), "Patch in forward not enabled and no patch_lengths passed." 
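# Editorial sketch, not part of the patch: _init_weights above draws every weight
# from a truncated normal clipped at +/- 3 std, where std defaults to fan_in**-0.5
# and can be overridden per module through the `_custom_std` attribute (as done for
# the hash embeddings). Standalone toy with made-up sizes.
import torch.nn as nn

emb = nn.Embedding(260, 1024)
emb._custom_std = 1024 ** (-0.5)                      # the override hook _init_weights reads
std = getattr(emb, "_custom_std", emb.embedding_dim ** (-0.5))
nn.init.trunc_normal_(emb.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
assert float(emb.weight.abs().max()) <= 3 * std + 1e-6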
+ + # PATCHER MODEL DEFINED + if self.config.patching_mode == PatchingModeEnum.entropy: + _, patch_lengths, _ = self.patcher( + local_encoder_tokens, + patch_size=self.config.patch_size, + include_next_token=True, + threshold=self.config.patching_threshold, + max_patch_length=self.config.max_patch_length, + patching_batch_size=self.config.patching_batch_size, + device=self.config.patching_device, + ) + else: + # self.config.patching_mode == PatchingModeEnum.byte + bs, seq_len = local_encoder_tokens.shape + seq_len_next_tok = seq_len + 1 # include_next_token=True + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + ) + + patch_lengths = process_patch_lengths(patch_lengths, self.config.max_patch_length) + + #assert torch.min(patch_lengths) >= 0 + # Generate patch IDs from patch_lengths + patch_ids = self._patch_ids_from_lengths(patch_lengths, local_encoder_tokens.shape[-1]) + # assert torch.max(patch_ids) + 1 <= torch.max((patch_lengths != 0).sum(dim=-1)), ( + # f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" + # ) + + cross_attn_mask_enc = None + # Cross-attention encoder + if self.config.cross_attn_encoder: + cross_attn_mask_enc = cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=True, + cross_attn_k=self.config.cross_attn_k, + window=self.config.cross_attn_window_encoder, + block_mask=self.config.cross_attn_use_flex_attention, + ) + + # Hashing and embedding + local_encoder_embeds = compute_hash_embeddings( + local_encoder_tokens=local_encoder_tokens, + local_encoder=self.local_encoder, + encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, + encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions, + encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab, + ) + + # NOTE: Frequency-based n-gram embeddings removed as per paper + # The final BLT model uses only hash-based n-gram embeddings + + # Local encoder + (h_encoder, h_cross), cache_encoder = self.local_encoder( + tokens=local_encoder_tokens, + embeds=local_encoder_embeds, + patch_embeds=None, + cross_mask=cross_attn_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + + # Downsampling + h = h_cross.view(bs, patch_lengths.shape[1], -1) + + # Global transformer + global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.config.boe_id) + rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id) + eos_patch_ids = patch_ids[rows, cols] + global_tokens[rows, eos_patch_ids] = self.config.eos_token_id + + h, _ = self.global_transformer( + embeds=h, + tokens=global_tokens, + ) + + # Unpatching + + dec_embeds = h_encoder + + # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens), so we need to map decoder positions to the remaining patches. 
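# Editorial sketch, not part of the patch: how the global token ids a few lines
# above are built - every patch slot defaults to the BOE id, and the patch that
# contains an EOS byte is overwritten with EOS. Ids and patch layout are made up.
import torch

boe_id, eos_id = 0, 2
byte_tokens = torch.tensor([[65, 66, 2, 67]])         # one EOS byte at position 2
patch_ids = torch.tensor([[0, 0, 1, 1]])              # two patches of two bytes each

global_tokens = byte_tokens.new_full((1, 2), boe_id)
rows, cols = torch.where(byte_tokens == eos_id)
global_tokens[rows, patch_ids[rows, cols]] = eos_id
assert torch.equal(global_tokens, torch.tensor([[0, 2]]))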
+ decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], local_decoder_tokens.shape[-1]) + # assert torch.max(decoder_patch_ids) + 1 <= h.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" + # assert decoder_patch_ids.shape[1] == dec_embeds.shape[1], ( + # f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" + # ) + + # Cross-attention decoder + if not self.config.cross_attn_decoder: + h = torch.gather(h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])) + cross_attn_mask_dec = None + # assert local_decoder_tokens.shape == h.shape[:-1] + else: + cross_attn_mask_dec = cross_attn_mask( + decoder_patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=self.config.cross_attn_k, + window=self.config.cross_attn_window_decoder, + block_mask=self.config.cross_attn_use_flex_attention, + ) + + # Local decoder + output, _ = self.local_decoder( + embeds=dec_embeds, + patch_embeds=h, + tokens=local_decoder_tokens, + cross_mask=cross_attn_mask_dec, + ) + return output + + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: + """ + Convert patch lengths to patch IDs for each token position. + For each token position in the sequence, determines which patch it belongs to. + + Args: + patch_lengths: [batch_size, num_patches] - length of each patch + seq_len: total sequence length + + Returns: + patch_ids: [batch_size, seq_len] - patch index for each token position + + Example: + patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1 + seq_len = 10 + Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]] + # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3 + """ + batch_size, num_patches = patch_lengths.shape + + # Create patch start positions: [0, 3, 5, 9] for the example above + patch_starts = torch.cat( + [ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1], # cumsum without the final total + ], + dim=-1, + ) + + # For each token position, find which patch it belongs to + # by finding the rightmost patch start that's <= the position + token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1] + + # Broadcasting: patch_starts[batch, patch] <= token_positions[position] + # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t + position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1) + + # Count how many patch starts are <= each position, then subtract 1 to get patch index + patch_ids = position_ge_patch_start.sum(dim=-1) - 1 + + return patch_ids + + +class BLTPatcher(BLTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.rope_embeddings = RotaryEmbedding( + theta=config.patcher_rope_theta, + head_dim=config.patcher_head_dim or config.patcher_dim // config.patcher_n_heads, + max_seqlen=config.patcher_max_seqlen, + rope_use_fp32_in_outer_product=config.patcher_rope_use_fp32_in_outer_product, + ) + + self.layers = nn.ModuleList() + for _ in range(config.patcher_n_layers): + self.layers.append( + BLTTransformerLayer( + { + "dim": config.patcher_dim, + "n_heads": config.patcher_n_heads, + "head_dim": config.patcher_head_dim, + "n_kv_heads": config.patcher_n_kv_heads, + "rope_theta": config.patcher_rope_theta, + "multiple_of": config.patcher_multiple_of, + "ffn_dim_multiplier": config.patcher_ffn_dim_multiplier, + "norm_eps": config.patcher_norm_eps, + } + ) + ) + + 
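# Editorial sketch, not part of the patch: the cumsum/broadcast trick used by
# _patch_ids_from_lengths above, checked against the example from its docstring.
import torch

patch_lengths = torch.tensor([[3, 2, 4, 1]])
seq_len = 10
starts = torch.cat(
    [torch.zeros(1, 1, dtype=torch.long), patch_lengths.cumsum(-1)[:, :-1]], dim=-1
)                                                     # [[0, 3, 5, 9]]
positions = torch.arange(seq_len)
patch_ids = (starts.unsqueeze(1) <= positions.view(1, -1, 1)).sum(-1) - 1
assert patch_ids.tolist() == [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]]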
#assert config.patcher_vocab_size > 0 + + self.tok_embeddings = torch.nn.Embedding(config.patcher_vocab_size, config.patcher_dim) + + self.norm = RMSNorm(config.patcher_dim, eps=config.patcher_norm_eps) + + self.output = nn.Linear( + config.patcher_dim, + config.patcher_vocab_size, + bias=False, + ) + + def forward( + self, + token_values: torch.Tensor, + patch_size: Optional[int] = None, + include_next_token: bool = True, + threshold: Optional[float] = None, + max_patch_length: Optional[int] = None, + patching_batch_size: int = 1, + device: Optional[str] = None, + ): + + # Handle chunked processing for entropy calculation + entropies = [] + preds = [] + max_length = self.config.patcher_max_seqlen + batch_numel = max_length * patching_batch_size + splits = torch.split(token_values.flatten(), batch_numel) + + for split in splits: + pad_size = (max_length - (split.numel() % max_length)) % max_length + pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False) + split = torch.cat((split, pad), dim=0) + split = split.reshape(-1, max_length) + if device is not None: + split = split.to(device) + + # Process chunk: embeddings -> layers -> output + bsz, seqlen = split.shape + h = self.tok_embeddings(split) + chunk_mask = create_causal_mask( + seqlen, + self.config.patcher_attn_impl , + self.config.patcher_attn_bias_type, + sliding_window=self.config.patcher_sliding_window, + tokens=split, + eos_id=self.config.eos_id, + ) + + freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None) + + for i, layer in enumerate(self.layers): + h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=self.config.patcher_attn_impl) + + pred = self.output(self.norm(h)) + pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] + preds.append(pred) + pred_entropies = self.entropy(pred) + entropies.append(pred_entropies) + + concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) + concat_preds = torch.cat(preds, dim=0).reshape(token_values.shape[0], -1) + + # Always compute patch lengths from concatenated entropies + bs, seq_len = token_values.shape + seq_len_next_tok = seq_len + 1 if include_next_token else seq_len + + # Find patch start IDs based on entropy + if patch_size is not None: + patch_start_ids = self.find_entropy_patch_start_ids( + concat_entropies, + patch_size, + include_next_token=include_next_token, + threshold=threshold + ) + patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok) + else: + # Default to byte-level patching + patch_lengths = torch.ones((bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) + + patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) + return concat_entropies, patch_lengths, concat_preds + + + @staticmethod + def entropy(scores): + """ + scores: [bs, seq_len, vocab] + returns [bs, seq_len] + + Computes the entropy for each token in the batch. + Note: uses natural log. 
+ """ + log_probs = F.log_softmax(scores, dim=-1) + probs = torch.exp(log_probs) + p_log_p = log_probs * probs + entropy = -p_log_p.sum(dim=-1) + return entropy + + @staticmethod + def patch_start_ids_from_patch_start_mask(patch_start_mask): + bs, trunc_seq_len = patch_start_mask.shape + max_patches = patch_start_mask.sum(dim=1).max() + if max_patches == 0: + patch_start_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + else: + patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(bs, 1) + extra_patch_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) + patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1) + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, trunc_seq_len)[:, :max_patches] + return patch_start_ids + + @staticmethod + def patch_lengths_from_start_ids(patch_start_ids, seq_len): + """ + Calculate patch lengths from start ids. + start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then + the rest are filled to the seq len. + seq_len: ex: 7 length of the sequence + + returns the patch lengths: + [1, 6] for the above example. + """ + last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) + patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) + patch_lengths = patch_end_ids - patch_start_ids + 1 + assert torch.all(patch_lengths >= 0), f"{patch_lengths}" + assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}" + return patch_lengths + + @staticmethod + def find_entropy_patch_start_ids( + entropies, + patch_size=None, + threshold=None, + include_next_token=True, + ): + """ + Use entropies to find the start ids of each patch. + Use patch_size or threshold to figure out the total number of patches to allocate. + + When threshold is not None the number of patches is not constant between + different sequences, but patches can be identified incrementally rather than + decided globally using the entire sequence. + """ + bs, seq_len = entropies.shape[:2] + + first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(bs, 1) + preds_truncation_len = first_ids.shape[1] # remove the first preds because they will be start of patches. 
+ entropies = entropies[:, 1:] + if threshold is None: + num_patches = seq_len // patch_size + patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices + patch_start_ids = patch_start_ids.sort(dim=1).values + else: + patch_start_mask = entropies > threshold + if not include_next_token: + patch_start_mask = patch_start_mask[:, :-1] + # patch_start_mask[1:] |= tokens[:-1] < OFFSET + patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) + + patch_start_ids = torch.cat((first_ids, patch_start_ids + preds_truncation_len), dim=1) + return patch_start_ids + +def init_hash_embeddings( + config, + local_encoder_dim: int, + encoder_hash_byte_group_size: list, +): + """Initialize hash-based token embeddings for the BLT encoder.""" + if config.encoder_hash_byte_group_size is None: + return None + + embeddings = [] + emb_dim = local_encoder_dim + encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab + + for _ in range(config.encoder_hash_byte_group_nb_functions): + for _ in encoder_hash_byte_group_size: + embeddings.append( + nn.Embedding( + encoder_hash_byte_group_vocab, + emb_dim, + ) + ) + + return nn.ModuleList(embeddings) + + +__all__ = [ + "BLTPreTrainedModel", + "BLTModel", + "BLTPatcher", + "BLTLocalEncoder", + "BLTLocalDecoder", + "BLTGlobalTransformer", +] diff --git a/backup_blt_wip copy/modular_blt.py b/backup_blt_wip copy/modular_blt.py new file mode 100644 index 0000000000000000000000000000000000000000..f433e2c8b799f87ea4e86d0a2221bd18db4acd09 --- /dev/null +++ b/backup_blt_wip copy/modular_blt.py @@ -0,0 +1,1180 @@ +# coding=utf-8 +# Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BLT model.""" + +from ...utils import is_torch_flex_attn_available, logging +from typing import Callable, List, Optional, Tuple, Union + +from ...cache_utils import Cache +from ...activations import ACT2FN + +import torch +import torch.distributions +import torch.nn +import torch.nn as nn +from torch.nn import functional as F + +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update + +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from .configuration_blt import ( + BLTConfig, + BLTLocalEncoderConfig, + BLTLocalDecoderConfig, + BLTGlobalTransformerConfig, + BLTPatcherConfig, + PatchingModeEnum, +) + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + from ...integrations.flex_attention import make_flex_block_causal_mask + +from ..mllama.modeling_mllama import repeat_kv, eager_attention_forward, MllamaRotaryEmbedding, MllamaTextRMSNorm, MllamaCrossAttentionDecoderLayer, MllamaTextCrossAttention, MllamaTextSelfAttention + +logger = logging.get_logger(__name__) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + # TODO: not exactly equivalent to other transformers implementations,, need feedback + # Extract first head_dim//2 elements which correspond to the unique frequencies + # This matches the original BLT approach which uses head_dim//2 frequency pairs + head_dim = q.shape[-1] + cos_freqs = cos[..., :head_dim//2] # [B, S, D/2] + sin_freqs = sin[..., :head_dim//2] # [B, S, D/2] + + # Expand cos/sin to match query/key tensor format [B, H, S, D/2] + cos_freqs = cos_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + sin_freqs = sin_freqs.unsqueeze(1).expand(-1, q.shape[1], -1, -1) # [B, 1, S, D/2] -> [B, H, S, D/2] + + # Split q and k into pairs for rotation: (d0, d1), (d2, d3), ... 
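    # (Editorial note, not part of the patch.) The pairing below implements, for each
    # frequency pair (x0, x1):
    #     x0' = cos * x0 - sin * x1
    #     x1' = sin * x0 + cos * x1
    # i.e. a 2-D rotation by a position-dependent angle: (1, 0) rotated by 90 degrees
    # (cos = 0, sin = 1) becomes (0, 1), and every pair keeps its norm while its phase
    # encodes the token position.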
+ q_pairs = q.view(*q.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] + k_pairs = k.view(*k.shape[:-1], head_dim//2, 2) # [B, H, S, D/2, 2] + + # Extract real and i parts + q_real, q_imag = q_pairs[..., 0], q_pairs[..., 1] # [B, H, S, D/2] + k_real, k_imag = k_pairs[..., 0], k_pairs[..., 1] # [B, H, S, D/2] + + # Apply rotation: [real', imag'] = [cos*real - sin*imag, sin*real + cos*imag] + q_real_rot = cos_freqs * q_real - sin_freqs * q_imag + q_imag_rot = sin_freqs * q_real + cos_freqs * q_imag + k_real_rot = cos_freqs * k_real - sin_freqs * k_imag + k_imag_rot = sin_freqs * k_real + cos_freqs * k_imag + + # Recombine pairs and reshape back to original format + q_rot = torch.stack([q_real_rot, q_imag_rot], dim=-1).view(*q.shape) # [B, H, S, D] + k_rot = torch.stack([k_real_rot, k_imag_rot], dim=-1).view(*k.shape) # [B, H, S, D] + + return q_rot.type_as(q), k_rot.type_as(k) + + +class BLTMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class BLTRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + BLTRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class BLTTransformerLayer(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = BLTSelfAttention(config=config, layer_idx=layer_idx) + self.mlp = BLTMLP(config) + self.input_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = BLTRMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. 
+ position_ids (`torch.LongTensor`, *optional*): + Position indices of tokens in the sequence for RoPE computation. + past_key_value (`Cache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class BLTSelfAttention(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.config = config + self.num_heads = config.num_attention_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = None + self.rope_theta = config.rope_theta + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = 
position_embeddings + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + output_attentions = False + # self.config._attn_implementation = "sdpa" + # self.scaling = None + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def rolling_polynomial_hash(token_tensor, hash_func_nb: int = 0): + primes = [ + 1000000007, 5915587277, 1500450271, 3267000013, 5754853343, + 4093082899, 9576890767, 3628273133, 2860486313, 5463458053, 3367900313, + ] + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=token_tensor.device) + powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device) + prime_powers = prime ** powers + return torch.sum(token_tensor * prime_powers, dim=-1) + + +def byte_group_hash_function(token_ids: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): + """Hash token groups and map to range [0, max_hash].""" + with torch.no_grad(): + batch_size, seq_len = token_ids.shape + # Add padding for sliding window + padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device) + padded_tokens = torch.cat([padding, token_ids], dim=1) + + # Create sliding windows and compute hashes + windows = padded_tokens.unfold(1, group_size, 1) + hashes = rolling_polynomial_hash(windows, hash_func_nb) + hash_values = hashes % max_hash + + hash_values.requires_grad = False + return hash_values + + +def init_hash_embeddings(config, local_encoder_dim: int, encoder_hash_byte_group_size: list): + """Initialize hash-based token embeddings for the BLT encoder.""" + num_embeddings = config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size) + embeddings = [ + nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim) + for _ in range(num_embeddings) + ] + return nn.ModuleList(embeddings) + + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """Compute token embeddings enhanced with hash-based embeddings.""" + embeddings = local_encoder.embed_tokens(local_encoder_tokens) + embedding_idx = 0 + for func_nb in 
range(encoder_hash_byte_group_nb_functions): + for group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, group_size, func_nb, encoder_hash_byte_group_vocab + ) + embeddings += encoder_hash_tok_embedding[embedding_idx](hash_ids) + embedding_idx += 1 + + return embeddings + + +def _prepare_patch_cross_attention_mask( + patch_ids: torch.Tensor, + num_patches: int, + sequence_length: int, + patches_as_queries: bool = False, + cross_attn_k: int = 1, + dtype: torch.dtype = torch.float32, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Prepare cross-attention mask for patch-based attention, following mllama's robust approach. + + This function creates masks that control which patches can attend to which other patches, + with support for query/key role swapping and cross-attention multipliers. + + Args: + patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids. + num_patches (int): Total number of patches. + sequence_length (int): Length of the sequence. + patches_as_queries (bool): If True, patches are used as queries, otherwise as keys. + cross_attn_k (int): Cross-attention multiplier for repeating patches. + dtype (torch.dtype): Data type for the output mask. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len] + - full_text_row_masked_out_mask: 4D tensor indicating fully masked rows + """ + batch_size, seq_len = patch_ids.shape + device = patch_ids.device + + # Determine query and key lengths based on configuration + if patches_as_queries: + q_len = num_patches * cross_attn_k + kv_len = sequence_length + # Create patch-to-sequence mapping + q_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(-1).expand( + batch_size, num_patches, seq_len + ) + kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len) + else: + q_len = sequence_length + kv_len = num_patches * cross_attn_k + # Create sequence-to-patch mapping + q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches) + kv_patch_ids = torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand( + batch_size, seq_len, num_patches + ) + + # Create base attention mask - boolean mask where True means "should attend" + # Exact patch matching + cross_attention_mask = q_patch_ids == kv_patch_ids + + # Handle cross_attn_k multiplier by repeating along appropriate dimension + repeat_dim = 1 if patches_as_queries else -1 + cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim) + + # Validate dimensions + expected_shape = (batch_size, q_len, kv_len) + if cross_attention_mask.shape != expected_shape: + raise ValueError(f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}") + + # Reshape so it can be used by attn module - add head dimension + cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len] + + # Invert the mask (following mllama pattern exactly) + # True -> 0.0 (attend), False -> 1.0 (will become -inf) + inverted_cross_attn_mask = (1.0 - cross_attention_mask.to(dtype)) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # Apply full-row bias (following mllama pattern exactly) + # Return 4D tensor of shape [B, H, S1, 1] where value is 0 if a full row in cross attn mask's + # last dimension contains negative infinity values, 
otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + +def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor: + """ + Splits patch lengths into smaller segments if they exceed `max_patch_length`. + Pads the result to uniform length across the batch. + + Args: + patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths. + max_patch_length (int, optional): Maximum allowed length per patch. + + Returns: + torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths. + """ + if max_patch_length is None: + return patch_lengths + + batch_size = patch_lengths.size(0) + processed = [] + + for seq in patch_lengths: + splits = [] + for length in seq[seq > 0]: + length = length.item() + full_chunks, remainder = divmod(length, max_patch_length) + splits.extend([max_patch_length] * full_chunks) + if remainder: + splits.append(remainder) + processed.append(splits) + + # Find max length to pad to + max_len = max(len(splits) for splits in processed) + padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) + + for i, splits in enumerate(processed): + if splits: + padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) + + # Trim zero columns + if (padded != 0).any(dim=0).sum() < padded.shape[1]: + last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1 + padded = padded[:, :last_nonzero] + + return padded + + +class BLTRotaryEmbedding(nn.Module): + def __init__(self, config, device=None): + super().__init__() + self.rope_type = config.rope_scaling["rope_type"] + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class BLTLocalEncoder(nn.Module): + def __init__(self, config: BLTLocalEncoderConfig): + super().__init__() + + self.hidden_size = config.hidden_size + self.vocab_size=config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + self.dropout = config.dropout + self.cross_attn_all_layers = config.cross_attn_all_layers + self.cross_attn_k = config.cross_attn_k + + self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(self.num_hidden_layers)]) + + self.rotary_emb = BLTRotaryEmbedding(config=config) + + self.patch_embedding_projection = nn.Linear( + in_features=config.encoder_dim_patch_emb, + out_features=config.encoder_dim_token_emb * config.cross_attn_k, + bias=False, + ) + + self.embed_tokens = nn.Embedding(self.vocab_size + config.pm_size, self.hidden_size) + + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size) + ) + + def forward( + self, + input_ids: torch.Tensor, + input_embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + num_patches: Optional[int] = None, + patch_ids: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ """ + if input_embeds is None: + input_embeds = self.embed_tokens(input_ids) + + batch_size, _, _ = input_embeds.shape + + hidden_states = nn.functional.dropout(input_embeds, p=self.dropout, training=self.training) + + position_ids = torch.arange(input_ids.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + for idx, layer in enumerate(self.layers): + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] + + if idx == len(self.layers) - 1 or self.cross_attn_all_layers: + patch_embeds = self.patch_reduce(hidden_states, num_patches, "amax", patch_ids) + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size) + + layer_idx = idx if self.cross_attn_all_layers else 0 + cross_attention_output, _, _ = self.cross_attn_layers[layer_idx]( + hidden_states=patch_embeds, + cross_attention_states=hidden_states, + attention_mask=cross_mask, + 
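+                    # Note (descriptive comment, shape derived from the helper above): cross_mask
+                    # comes from _prepare_patch_cross_attention_mask with patches_as_queries=True,
+                    # i.e. an additive mask of shape [batch_size, 1, num_patches * cross_attn_k, seq_len]
+                    # whose masked positions are set to the dtype's minimum value.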
full_text_row_masked_out_mask=full_text_row_masked_out_mask, + output_attentions=False, + use_cache=False, + cache_position=None, + ) + patch_embeds = patch_embeds + cross_attention_output + + encoder_cross_states = patch_embeds + return hidden_states, encoder_cross_states + + def patch_reduce(self, hidden_states, max_num_patches, reduction, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. + """ + batch_size, _, embedding_dim = hidden_states.shape + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) + + reduced_embeddings = torch.zeros((batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device) + reduced_embeddings = reduced_embeddings.scatter_reduce( + src=hidden_states, + dim=1, + index=patch_ids, + reduce=reduction, + include_self=False, + ) + reduced_embeddings = reduced_embeddings[:, :max_num_patches, :] + + return reduced_embeddings + + +class BLTLocalDecoder(nn.Module): + def __init__(self, config: BLTLocalDecoderConfig): + super().__init__() + + # Extract config values to instance attributes + self.hidden_size = config.hidden_size + self.vocab_size=config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + self.dropout = config.dropout + self.cross_attn_decoder = True #config.cross_attn_decoder #TODO: maybe remove + self.cross_attn_all_layers = config.cross_attn_all_layers + self.cross_attn_k = config.cross_attn_k + + self.layers = nn.ModuleList([BLTTransformerLayer(config, layer_idx) for layer_idx in range(self.num_hidden_layers)]) + + self.rotary_emb = BLTRotaryEmbedding(config=config) + + self.patch_embedding_projection = nn.Linear( + in_features=config.hidden_size_global, + out_features=config.decoder_dim_token_emb * config.cross_attn_k, + bias=False, + ) + + self.norm = BLTRMSNorm(self.hidden_size, eps=config.norm_eps) + + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.num_hidden_layers if self.cross_attn_all_layers else 1 + for layer_idx in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention(config=config, layer_idx=layer_idx, hidden_size=self.hidden_size) + ) + + self.lm_head = nn.Linear( + self.hidden_size, + self.vocab_size, + bias=False, + ) + + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor], + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + batch_size, _, _ = embeds.shape + + hidden_states = embeds + + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape(batch_size, patch_embeds.shape[1] * self.cross_attn_k, self.hidden_size) + + if patch_embeds is not None and not self.cross_attn_decoder: + hidden_states = hidden_states + patch_embeds + + position_ids = 
torch.arange(tokens.shape[1], device=embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + for i, layer in enumerate(self.layers): + if i == 0 or self.cross_attn_all_layers: + # Use cross attention to extract info from patch_embeds into hidden_states + cross_attention_output, _, _ = self.cross_attn_layers[i]( + hidden_states=hidden_states, + cross_attention_states=patch_embeds, + attention_mask=cross_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + output_attentions=False, + use_cache=False, + cache_position=None, + ) + hidden_states = hidden_states + cross_attention_output + + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] + + logits = self.lm_head(self.norm(hidden_states)) + return logits, cache + + +class BLTCrossAttention(nn.Module): + """Cross-attention module for BLT, following transformers style""" + + def __init__(self, config: BLTConfig, layer_idx: int, hidden_size: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + # Use provided hidden_size or fallback to encoder dimension + self.hidden_size = hidden_size or config.hidden_size_local_encoder + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_attention_heads # Assuming same for cross attention + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.scaling = None + self.dropout = config.dropout + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.q_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) + self.k_norm = nn.RMSNorm(self.hidden_size, eps=config.norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(hidden_states) # BLT normalizes first + query_states = self.q_proj(query_states) + + if cross_attention_states is not None: + cross_attention_states = self.k_norm(cross_attention_states) # BLT normalizes first + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + if past_key_value is not None: + # if we have a new cross attention states + new tokens, we only computed key_states on that new cross attention states + # we still update the cross key states, past_cross_states, new_cross_states. And use it! 
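+                # In other words: when fresh `cross_attention_states` arrive while a cache is
+                # present, the newly projected key/value states are written into this layer's
+                # cache and then used; when no new `cross_attention_states` are provided but the
+                # cache is already populated (`cache_position[0] != 0`), the cached key/value
+                # states are read back instead (see the `elif` branch below).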
+ key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + elif cache_position is not None and cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + if cross_attention_states is None: + raise ValueError( + "Cross attention layer can't find neither `cross_attention_states` nor cached values for key/values!" + ) + + attention_interface: Callable = eager_attention_forward + + # self.config._attn_implementation = "sdpa" + # attn = "sdpa" + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0, #if not self.training else self.dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + # Apply full row masking if provided (following mllama pattern) + if full_text_row_masked_out_mask is not None: + attn_output = full_text_row_masked_out_mask[:, 0] * attn_output + + attn_output = attn_output + hidden_states + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class BLTGlobalTransformer(nn.Module): + def __init__(self, config: BLTGlobalTransformerConfig): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.dropout = config.dropout + + self.layers = nn.ModuleList() + for layer_idx in range(self.num_hidden_layers): + self.layers.append(BLTTransformerLayer(config, layer_idx)) + + self.rotary_emb = BLTRotaryEmbedding(config=config) + + + def forward( + self, + input_embeds: torch.Tensor, + mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + batch_size, seq_len, _ = input_embeds.shape + + hidden_states = input_embeds + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + position_ids = torch.arange(seq_len, device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for i, layer in enumerate(self.layers): + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) + hidden_states = layer_outputs[0] + + return hidden_states, cache + + + + +class BLTPreTrainedModel(PreTrainedModel): + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] + _skip_keys_device_placement = 
["past_key_values"] + _supports_flash_attn_2 = False # BLT uses its own attention implementation + _supports_sdpa = True + _supports_cache_class = False + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + std = getattr(module, '_custom_std', module.in_features ** (-0.5)) + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + if module.bias is not None: + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.Embedding): + std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5)) + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + elif isinstance(module, BLTModel): + if module.encoder_hash_tok_embedding is not None: + emb_std = module.config.hidden_size_local_encoder ** (-0.5) + for emb in module.encoder_hash_tok_embedding: + emb._custom_std = emb_std + + elif isinstance(module, BLTLocalEncoder): + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.config.encoder_dim_patch_emb ** (-0.5) + + elif isinstance(module, BLTLocalDecoder): + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.config.hidden_size_global ** (-0.5) + + elif isinstance(module, BLTPatcher): + emb_std = module.config.hidden_size ** (-0.5) + module.embed_tokens._custom_std = emb_std + module.lm_head._custom_std = emb_std + + +class BLTModel(BLTPreTrainedModel): + def __init__(self, config: BLTConfig): + super().__init__(config) + + self.config = config + + self.local_encoder = BLTLocalEncoder(config.encoder_config) + self.global_transformer = BLTGlobalTransformer(config.global_config) + self.local_decoder = BLTLocalDecoder(config.decoder_config) + + self.encoder_hash_tok_embedding = init_hash_embeddings( + config, + local_encoder_dim=config.hidden_size_local_encoder, + encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, + ) + + if self.config.patch_in_forward: + self.patcher = BLTPatcher(config.patcher_config) + self.patcher.eval() + for param in self.patcher.parameters(): + param.requires_grad = False + else: + self.patcher = None + + def forward(self, tokens: torch.Tensor, patch_lengths: Optional[torch.Tensor] = None): + batch_size, sequence_length = tokens.shape + + # Handle patching + if patch_lengths is None: + if self.config.patching_mode == PatchingModeEnum.entropy: + _, patch_lengths, _ = self.patcher( + tokens, + patch_size=self.config.patch_size, + threshold=self.config.patching_threshold, + max_patch_length=self.config.max_patch_length, + patching_batch_size=self.config.patching_batch_size, + device=self.config.patching_device, + ) + else: + # Default to byte-level patching + patch_lengths = process_patch_lengths( + torch.ones((batch_size, sequence_length + 1), dtype=tokens.dtype, device=tokens.device), + self.config.max_patch_length + ) + + patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length) + cross_attn_mask_enc, full_text_row_masked_out_mask_enc = _prepare_patch_cross_attention_mask( + patch_ids, patch_lengths.shape[1], sequence_length, True, self.config.cross_attn_k, torch.float32 + ) + + encoder_embeds = compute_hash_embeddings( + tokens, self.local_encoder, self.encoder_hash_tok_embedding, + self.config.encoder_hash_byte_group_nb_functions, + self.config.encoder_hash_byte_group_size, + self.config.encoder_hash_byte_group_vocab, + ) + + encoder_hidden_states, encoder_cross_states = self.local_encoder( + input_ids=tokens, + 
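+            # encoder_embeds already holds the byte embeddings plus the summed hash-n-gram
+            # embeddings produced by compute_hash_embeddings above, so the local encoder skips
+            # its own embed_tokens lookup when input_embeds is provided.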
input_embeds=encoder_embeds, + patch_embeds=None, + cross_mask=cross_attn_mask_enc, + full_text_row_masked_out_mask=full_text_row_masked_out_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + + global_hidden_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1) + + global_hidden_states, _ = self.global_transformer( + input_embeds=global_hidden_states, + ) + + decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length) + cross_attn_mask_dec, full_text_row_masked_out_mask_dec = _prepare_patch_cross_attention_mask( + decoder_patch_ids, patch_lengths.shape[1], sequence_length, False, self.config.cross_attn_k, torch.float32 + ) + + output, _ = self.local_decoder( + tokens=tokens, + embeds=encoder_hidden_states, + patch_embeds=global_hidden_states, + mask=None, + cross_mask=cross_attn_mask_dec, + full_text_row_masked_out_mask=full_text_row_masked_out_mask_dec, + ) + + return output + + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: + """Convert patch lengths to patch IDs for each token position.""" + batch_size = patch_lengths.shape[0] + patch_starts = torch.cat([ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1] + ], dim=-1) + + token_positions = torch.arange(seq_len, device=patch_lengths.device) + return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1 + + +class BLTPatcher(BLTPreTrainedModel): + def __init__(self, config: BLTPatcherConfig): + super().__init__(config) + + self.rotary_emb = BLTRotaryEmbedding(config=self.config) + + self.layers = nn.ModuleList() + # Create transformer layers using the patcher config + for layer_idx in range(self.config.num_hidden_layers): + self.layers.append(BLTTransformerLayer(self.config, layer_idx)) + + + self.embed_tokens = torch.nn.Embedding(self.config.vocab_size, self.config.hidden_size) + + self.norm = BLTRMSNorm(self.config.hidden_size, eps=self.config.norm_eps) + + self.lm_head = nn.Linear( + self.config.hidden_size, + self.config.vocab_size, + bias=False, + ) + + def forward( + self, + token_values: torch.Tensor, + patch_size: Optional[int] = None, + threshold: Optional[float] = None, + max_patch_length: Optional[int] = None, + patching_batch_size: int = 1, + device: Optional[str] = None, + ): + + # Handle chunked processing for entropy calculation + entropies = [] + predictions = [] + max_length = self.config.max_position_embeddings + batch_numel = max_length * patching_batch_size + splits = torch.split(token_values.flatten(), batch_numel) + + for split in splits: + pad_size = (max_length - (split.numel() % max_length)) % max_length + pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False) + split = torch.cat((split, pad), dim=0) + split = split.reshape(-1, max_length) + if device is not None: + split = split.to(device) + + # Process chunk: embeddings -> layers -> output + batch_size, sequence_length = split.shape + input_embeds = self.embed_tokens(split) + + hidden_states = input_embeds + + batch_size, _, _ = input_embeds.shape + + position_ids = torch.arange(split.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) # = BLT self.rope + + for i, layer in enumerate(self.layers): + layer_outputs = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=None) #, 
attn_impl=self.config.patcher_attn_impl ) + hidden_states = layer_outputs[0] + + logits = self.lm_head(self.norm(hidden_states)) + logits = logits.reshape(-1, logits.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] + predictions.append(logits) + prediction_entropies = torch.distributions.Categorical(logits=logits).entropy() + entropies.append(prediction_entropies) + + concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) + concat_predictions = torch.cat(predictions, dim=0).reshape(token_values.shape[0], -1) + + # Always compute patch lengths from concatenated entropies + batch_size, sequence_length = token_values.shape + + # Find patch start IDs based on entropy + if patch_size is not None: + patch_lengths = self.patch_lengths_from_entropies( + entropies=concat_entropies, + sequence_length=sequence_length, + patch_size=patch_size, + threshold=threshold, + ) + else: + # Default to byte-level patching + patch_lengths = torch.ones((batch_size, sequence_length), dtype=token_values.dtype, device=token_values.device) + patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) + return concat_entropies, patch_lengths, concat_predictions + + @staticmethod + def patch_lengths_from_entropies( + entropies, + sequence_length, + patch_size=None, + threshold=None, + ): + """ + Computes patch lengths from token entropies. + + Depending on whether a threshold is provided, the function uses either: + - Top-k selection based on entropy (when `threshold` is None), or + - Thresholding the entropy values (when `threshold` is set). + """ + + batch_size = entropies.shape[0] + + # Always include token 0 and 1 as starting tokens + init_tokens = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1) + offset = init_tokens.shape[1] + + # Ignore first token entropy (BOS) + entropies = entropies[:, 1:] + + if threshold is None: + # Use top-k entropy values to define patch start points + num_patches = sequence_length // patch_size + topk_indices = entropies.topk(num_patches - 2, dim=1).indices + patch_starts = topk_indices.sort(dim=1).values + else: + # Threshold the entropy values to define patch start points + patch_mask = entropies > threshold + + seq_len = patch_mask.shape[1] + + # Create patch IDs (token indices), and add a sentinel to ensure alignment + token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1) + sentinel = torch.full_like(token_indices, seq_len) + padded_indices = torch.cat([token_indices, sentinel], dim=1) + + # Pad mask with inverse to align sentinel correctly + padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1) + + # Select indices where mask is True + patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len) + max_valid_patches = patch_mask.sum(dim=1).max() + patch_starts = patch_starts[:, :max_valid_patches] + + # Offset patch starts to account for the two initial tokens + patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1) + + # Compute patch end positions by shifting start positions + last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1) + patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1) + + patch_lengths = patch_ends - patch_start_ids + 1 + + return patch_lengths + +__all__ = [ + "BLTPreTrainedModel", + "BLTModel", + "BLTPatcher", + "BLTLocalEncoder", + "BLTLocalDecoder", + "BLTGlobalTransformer", + "BLTTransformerLayer", +] \ No newline at end of file diff --git 
a/backup_blt_wip copy/tokenization_blt.py b/backup_blt_wip copy/tokenization_blt.py new file mode 100644 index 0000000000000000000000000000000000000000..d116fae5188fcd12c741ddfcdb6dfb5973585b28 --- /dev/null +++ b/backup_blt_wip copy/tokenization_blt.py @@ -0,0 +1,271 @@ +# coding=utf-8 +# Copyright 2025 the Facebook Research and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for BLT.""" + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +if TYPE_CHECKING: + from ...tokenization_utils_base import TextInput + +logger = logging.get_logger(__name__) + +# BLT tokenizer constants +SEP = " " +BOS_ID: int = 1 +EOS_ID: int = 2 +PAD_ID: int = -1 +BOE_ID: int = 0 +BPE_ID: int = 3 +OFFSET: int = 4 +BYTE_UNITS: int = 256 + +VOCAB_FILES_NAMES = {} # BLT doesn't require external vocab files + + +class BLTTokenizer(PreTrainedTokenizer): + """ + Construct a BLT tokenizer. Based on byte-level tokenization where each byte is treated as a token. + + This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with an offset. + It supports special tokens for beginning of sequence (BOS), end of sequence (EOS), + beginning of example (BOE), and padding (PAD). + + Args: + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The padding token. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. Not used in BLT but kept for compatibility. + boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of example token, specific to BLT. + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add a `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding. + spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add spaces between special tokens. 
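+
+    Example (illustrative; the ids follow the byte-plus-offset scheme and the default
+    BOS/EOS ids defined in this file):
+
+    ```python
+    >>> tokenizer = BLTTokenizer()
+    >>> tokenizer.encode("hi")  # UTF-8 bytes 104, 105 shifted by the offset of 4, wrapped in BOS/EOS
+    [1, 108, 109, 2]
+    >>> tokenizer.decode([1, 108, 109, 2])
+    'hi'
+    ```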
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + bos_token="", + eos_token="", + pad_token="", + unk_token="", + boe_token="", + add_bos_token=True, + add_eos_token=True, + clean_up_tokenization_spaces=False, + spaces_between_special_tokens=False, + **kwargs, + ): + # Store BLT-specific parameters first + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.vocab_size_unit_1 = BYTE_UNITS + self.offsetting_special_char = OFFSET + + # BLT token IDs (exactly like original) + self.boe_id = BOE_ID + self.bos_id = BOS_ID + self.eos_id = EOS_ID + self.pad_id = PAD_ID + self.bpe_id = BPE_ID + self.n_words = self.vocab_size_unit_1 + self.offsetting_special_char + + # Convert string tokens to AddedToken objects + bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + self.boe_token = AddedToken(boe_token, normalized=False, special=True) if isinstance(boe_token, str) else boe_token + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + unk_token=unk_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + spaces_between_special_tokens=spaces_between_special_tokens, + **kwargs, + ) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.vocab_size_unit_1 + self.offsetting_special_char + + def get_vocab(self): + """Returns vocab as a dict""" + # Create a mapping for byte values + offset + vocab = {} + + # Add special tokens (with defensive checks) + if hasattr(self, 'bos_token'): + vocab[str(self.bos_token)] = self.bos_id + if hasattr(self, 'eos_token'): + vocab[str(self.eos_token)] = self.eos_id + if hasattr(self, 'pad_token'): + vocab[str(self.pad_token)] = self.pad_id + if hasattr(self, 'boe_token'): + vocab[str(self.boe_token)] = self.boe_id + + # Add byte tokens as string representations of byte values + vocab_size_unit_1 = getattr(self, 'vocab_size_unit_1', BYTE_UNITS) + offsetting_special_char = getattr(self, 'offsetting_special_char', OFFSET) + for i in range(vocab_size_unit_1): + vocab[str(i)] = i + offsetting_special_char + + # Add any additional tokens if available + if hasattr(self, 'added_tokens_encoder'): + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str, **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. For BLT, we work directly with byte values. + Returns a list of strings that represent the byte values. 
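+
+        Example (illustrative):
+
+        ```python
+        >>> tokenizer._tokenize("abc")  # UTF-8 bytes of "abc"; the offset is only applied later, in _convert_token_to_id
+        ['97', '98', '99']
+        ```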
+ """ + # Convert text to UTF-8 bytes, just like the original + try: + bytes_data = text.encode("utf-8", errors="ignore") + except UnicodeEncodeError: + bytes_data = text.encode("utf-8", errors="ignore") + + # Return string representations of byte values for the tokenizer framework + return [str(byte_val) for byte_val in bytes_data] + + def _convert_token_to_id(self, token: str) -> int: + """Converts a token (str) to an id using the vocab.""" + # Handle special tokens + if token == str(self.bos_token): + return self.bos_id + elif token == str(self.eos_token): + return self.eos_id + elif token == str(self.pad_token): + return self.pad_id + elif token == str(self.boe_token): + return self.boe_id + else: + try: + # Convert byte value string to int and add offset (like original) + byte_val = int(token) + if 0 <= byte_val <= 255: + return byte_val + self.offsetting_special_char + except ValueError: + pass + + # Check if it's in added tokens + return self.added_tokens_encoder.get(token, self.unk_token_id) + + def _convert_id_to_token(self, index: int) -> str: + """Converts an index (integer) to a token (str) using the vocab.""" + # Handle special tokens + if index == self.bos_id: + return str(self.bos_token) + elif index == self.eos_id: + return str(self.eos_token) + elif index == self.pad_id: + return str(self.pad_token) + elif index == self.boe_id: + return str(self.boe_token) + elif index >= self.offsetting_special_char and index < self.vocab_size: + # Convert back to byte value (like original) + byte_val = index - self.offsetting_special_char + return str(byte_val) + else: + # Check added tokens + for token, token_id in self.added_tokens_encoder.items(): + if token_id == index: + return token + return str(self.unk_token) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + """Converts a sequence of tokens to a single string.""" + byte_values = [] + + for token in tokens: + # Skip special tokens + if token in [str(self.bos_token), str(self.eos_token), str(self.pad_token), str(self.boe_token)]: + continue + + try: + # Convert token back to byte value (like original decode method) + byte_val = int(token) + if 0 <= byte_val <= 255: + byte_values.append(byte_val) + except ValueError: + continue + + # Convert byte values back to string (exactly like original) + try: + return bytes(byte_values).decode("utf-8", errors="ignore") + except (UnicodeDecodeError, ValueError): + return "" + + def encode(self, text: str, add_bos: bool | None = None, add_eos: bool | None = None): + """ + Encode text exactly like the original BLT tokenizer. + """ + if add_bos is None: + add_bos = self.add_bos_token + if add_eos is None: + add_eos = self.add_eos_token + + # Since bpe_delim=False, we use the simple byte encoding + tokens = bytes(text, encoding="utf-8", errors="ignore") + + # Offsetting (exactly like original) + tokens = [int(unit) + self.offsetting_special_char for unit in tokens] + + if add_bos: + tokens.insert(0, self.bos_id) + if add_eos: + tokens.append(self.eos_id) + + return tokens + + def decode(self, tokens: list[int], cut_at_eos: bool = False): + """ + Decode tokens exactly like the original BLT tokenizer. 
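+
+        Example (illustrative; ids below the byte offset, such as BOS=1 and EOS=2, are dropped):
+
+        ```python
+        >>> tokenizer.decode([1, 108, 109, 2])
+        'hi'
+        >>> tokenizer.decode([1, 108, 109, 2, 110], cut_at_eos=True)  # truncated at the first EOS
+        'hi'
+        ```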
+ """ + if cut_at_eos: + for k, t in enumerate(tokens): + if t == self.eos_id: + tokens = tokens[: k + 1] + break + return bytes( + [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0] + ).decode("utf-8", errors="ignore") + + def get_vocab_size(self) -> int: + """Get vocab size like the original tokenizer.""" + return self.vocab_size_unit_1 + self.offsetting_special_char + +#__all__ = ["BLTTokenizer"] \ No newline at end of file diff --git a/backup_blt_wip_backup/__pycache__/blt_args.cpython-312.pyc b/backup_blt_wip_backup/__pycache__/blt_args.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe1c3fac23fde31ac2fe95aa5c15d35368e766fc Binary files /dev/null and b/backup_blt_wip_backup/__pycache__/blt_args.cpython-312.pyc differ diff --git a/backup_blt_wip_backup/__pycache__/blt_one_file.cpython-312.pyc b/backup_blt_wip_backup/__pycache__/blt_one_file.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad67ee5b9c82f98e084ea14790b83f2fbbca1539 Binary files /dev/null and b/backup_blt_wip_backup/__pycache__/blt_one_file.cpython-312.pyc differ diff --git a/backup_blt_wip_backup/__pycache__/configuration_blt.cpython-312.pyc b/backup_blt_wip_backup/__pycache__/configuration_blt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29c901cfa128a76934a45a17664faae1d2ff226c Binary files /dev/null and b/backup_blt_wip_backup/__pycache__/configuration_blt.cpython-312.pyc differ diff --git a/backup_blt_wip_backup/__pycache__/modeling_blt_wip.cpython-312.pyc b/backup_blt_wip_backup/__pycache__/modeling_blt_wip.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14471f3e8bde08cd94f2ba3b8ea1dc01597af449 Binary files /dev/null and b/backup_blt_wip_backup/__pycache__/modeling_blt_wip.cpython-312.pyc differ diff --git a/backup_blt_wip_backup/__pycache__/modeling_blt_wip_backup.cpython-312.pyc b/backup_blt_wip_backup/__pycache__/modeling_blt_wip_backup.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01bb41a2f95cf657f1cb253e41a8eadbe9a4c5b7 Binary files /dev/null and b/backup_blt_wip_backup/__pycache__/modeling_blt_wip_backup.cpython-312.pyc differ diff --git a/backup_blt_wip_backup/__pycache__/tokenization_blt.cpython-312.pyc b/backup_blt_wip_backup/__pycache__/tokenization_blt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f38c87100166c7b926f6683567c5df1366316c48 Binary files /dev/null and b/backup_blt_wip_backup/__pycache__/tokenization_blt.cpython-312.pyc differ diff --git a/backup_blt_wip_backup/blt_args.py b/backup_blt_wip_backup/blt_args.py new file mode 100644 index 0000000000000000000000000000000000000000..e043d1dd20a85564ed998ec608481313b30f5406 --- /dev/null +++ b/backup_blt_wip_backup/blt_args.py @@ -0,0 +1,187 @@ +from enum import Enum +from typing import Any + +from pydantic import BaseModel, ConfigDict, model_validator +from typing_extensions import Self + + +EOS_ID: int = 2 + + +class InitStdFactor(str, Enum): + DISABLED = "disabled" # Init std is divided by 1.0 + GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers) + CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) + DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096 + + +class PatchingModeEnum(str, Enum): + entropy = "entropy" + bpe = "bpe" + bpe_patcher = "bpe_patcher" + space = "space" + static = "static" + byte = "byte" + + +class 
LMTransformerArgs(BaseModel): + """Arguments for the Language Model Transformer (used as entropy model for patching)""" + + model_config = ConfigDict() + + # Basic architecture + dim: int = 512 + n_layers: int = 8 + head_dim: int | None = None + n_heads: int | None = None + n_kv_heads: int | None = None + + # Transformer configuration + max_seqlen: int = 1024 + norm_eps: float = 1e-5 + dropout: float = 0 + vocab_size: int = -1 + sliding_window: int | None = None + + # Feedforward + ffn_dim_multiplier: float | None = None + multiple_of: int = 256 + + # Positional encoding + rope_theta: float = 10000.0 + rope_use_fp32_in_outer_product: bool = False + + # Attention + attn_impl: str = "sdpa" + attn_bias_type: str = "causal" + + # Initialization + init_base_std: float | None = None + init_std_factor: InitStdFactor = InitStdFactor.DISABLED + + # Embedding dimensions + dim_token_emb: int | None = None + + # Model behavior + weight_tying: bool = False + seed: int = 42 + + # Special token config + eos_id: int = EOS_ID + + +class ByteLatentTransformerArgs(BaseModel): + """Arguments for the Byte Latent Transformer (main BLT model)""" + + model_config = ConfigDict() + + # Basic model configuration + seed: int = 42 + vocab_size: int = -1 + + # Main architecture dimensions (these will be used for creating transformer args) + dim: int = 512 + n_layers: int = 8 + head_dim: int | None = None + n_heads: int | None = None + n_kv_heads: int | None = None + + # Component-specific dimensions + dim_global: int = 512 + dim_local_decoder: int = 512 + dim_local_encoder: int = 512 + n_layers_global: int = 8 + n_layers_local_decoder: int = 8 + n_layers_local_encoder: int = 8 + n_heads_global: int = 8 + n_heads_local_decoder: int = 8 + n_heads_local_encoder: int = 8 + n_kv_heads_global: int | None = None + + # Transformer configuration (needed by transformer components) + max_seqlen: int = 1024 + norm_eps: float = 1e-5 + dropout: float = 0 + + # Feedforward (needed by transformer components) + ffn_dim_multiplier: float = 1.0 + multiple_of: int = 256 + + # Positional encoding (needed by transformer components) + rope_theta: float = 10000.0 + rope_use_fp32_in_outer_product: bool = False + + # Attention (needed by transformer components) + attn_impl: str = "sdpa" + attn_bias_type: str = "causal" + + # Initialization (needed by transformer components) + init_base_std: float | None = None + init_std_factor: InitStdFactor = InitStdFactor.DISABLED + + # Embedding dimensions (needed by transformer components) + dim_token_emb: int | None = None + + # Patching configuration + patch_in_forward: bool = False + realtime_patching: bool = True + patch_size: float | None = None + patching_mode: str | None = None + patching_threshold: float | None = None + patching_threshold_add: float | None = None + monotonicity: bool = False + patching_batch_size: int = 1 + patching_device: str = "cuda" + max_patch_length: int | None = None + entropy_model_checkpoint_dir: str | None = None + + # Cross attention configurations + cross_attn_encoder: bool = False + cross_attn_decoder: bool = False + cross_attn_window_encoder: int | None = None + cross_attn_window_decoder: int | None = None + cross_attn_k: int | None = None + cross_attn_nheads: int | None = None + cross_attn_all_layers_decoder: bool = False + cross_attn_all_layers_encoder: bool = False + cross_attn_use_flex_attention: bool = True + cross_attn_init_by_pooling: bool = False + + # Encoder configurations + use_local_encoder_transformer: bool = False + max_encoder_seq_length: int | None = 
None + encoder_hash_byte_group_size: Any | None = None + encoder_hash_byte_group_vocab: int = 30000 + encoder_hash_byte_group_nb_functions: int = 3 + encoder_enable_byte_ngrams: bool = False + encoder_ngram_to_size_str: str | None = None + downsampling_by_pooling: str | None = None + + # Architecture and dimensions + dim_token: int | None = None + share_encoder_decoder_emb: bool = True + weight_tying: bool = False + + # Attention configuration + local_attention_window_len: int | None = None + use_rope: bool = True + + # Performance optimization + sequence_parallel: bool = False + loss_parallel: bool = False + fuse_sequence_parallel: bool = False + use_fsdp: bool = True + + # Parameter mixing + pm_size: int = 0 + + # Special token config + eos_id: int = EOS_ID + + @model_validator(mode="after") + def check_hash_byte_sizes(self) -> Self: + if self.encoder_hash_byte_group_size is not None and type(self.encoder_hash_byte_group_size) == str: + self.encoder_hash_byte_group_size = [ + int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0 + ] + return self diff --git a/backup_blt_wip_backup/configuration_blt.py b/backup_blt_wip_backup/configuration_blt.py new file mode 100644 index 0000000000000000000000000000000000000000..7c645e8b18f63496c4f3419243250465eac941e8 --- /dev/null +++ b/backup_blt_wip_backup/configuration_blt.py @@ -0,0 +1,590 @@ +# coding=utf-8 +# Copyright 2024 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BLT (Byte Latent Transformer) model configuration""" + +from enum import Enum + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class InitStdFactor(str, Enum): + DISABLED = "disabled" # Init std is divided by 1.0 + CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) + + +class PatchingModeEnum(str, Enum): + entropy = "entropy" + bpe = "bpe" + bpe_patcher = "bpe_patcher" + space = "space" + static = "static" + byte = "byte" + + +class BLTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ByteLatentTransformer`]. It is used to instantiate a + BLT model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 256): + Vocabulary size of the BLT model. Defines the number of different tokens (bytes) that can be represented. + max_seqlen (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model can handle. + + # Main architecture dimensions + dim (`int`, *optional*, defaults to 512): + Main dimension of the model. + n_layers (`int`, *optional*, defaults to 8): + Number of layers in the main transformer. 
+ n_heads (`int`, *optional*, defaults to 8): + Number of attention heads in the main transformer. + head_dim (`int`, *optional*): + Dimension of each attention head. If not specified, computed as dim // n_heads. + n_kv_heads (`int`, *optional*): + Number of key-value heads for grouped query attention. If not specified, defaults to n_heads. + + # Component-specific dimensions + dim_global (`int`, *optional*, defaults to 512): + Dimension of the global transformer component. + dim_local_decoder (`int`, *optional*, defaults to 512): + Dimension of the local decoder component. + dim_local_encoder (`int`, *optional*, defaults to 512): + Dimension of the local encoder component. + n_layers_global (`int`, *optional*, defaults to 8): + Number of layers in the global transformer. + n_layers_local_decoder (`int`, *optional*, defaults to 8): + Number of layers in the local decoder. + n_layers_local_encoder (`int`, *optional*, defaults to 8): + Number of layers in the local encoder. + n_heads_global (`int`, *optional*, defaults to 8): + Number of attention heads in the global transformer. + n_heads_local_decoder (`int`, *optional*, defaults to 8): + Number of attention heads in the local decoder. + n_heads_local_encoder (`int`, *optional*, defaults to 8): + Number of attention heads in the local encoder. + n_kv_heads_global (`int`, *optional*): + Number of key-value heads in the global transformer. + + # Transformer configuration + norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers. + ffn_dim_multiplier (`float`, *optional*, defaults to 1.0): + Multiplier for the feedforward network dimension. + multiple_of (`int`, *optional*, defaults to 256): + Make feedforward network dimension multiple of this value. + + # Positional encoding + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False): + Whether to use fp32 in RoPE outer product computation. + + # Attention configuration + attn_impl (`str`, *optional*, defaults to "sdpa"): + Attention implementation to use ("sdpa" or "flex_attention"). + attn_bias_type (`str`, *optional*, defaults to "causal"): + Type of attention bias to apply. + local_attention_window_len (`int`, *optional*): + Window length for local attention. + use_rope (`bool`, *optional*, defaults to True): + Whether to use rotary position embeddings. + + # Initialization + init_base_std (`float`, *optional*): + Base standard deviation for weight initialization. + init_std_factor (`str`, *optional*, defaults to "disabled"): + Factor for adjusting initialization standard deviation. + + # Embedding dimensions + dim_token_emb (`int`, *optional*): + Token embedding dimension. + dim_token (`int`, *optional*): + Token dimension. + + # Patching configuration + patch_in_forward (`bool`, *optional*, defaults to False): + Whether to perform patching during forward pass. + realtime_patching (`bool`, *optional*, defaults to True): + Whether to use realtime patching. + patch_size (`float`, *optional*): + Size of patches for static patching. + patching_mode (`str`, *optional*): + Mode for patching ("entropy", "static", etc.). + patching_threshold (`float`, *optional*): + Threshold for entropy-based patching. + patching_threshold_add (`float`, *optional*): + Additional threshold parameter for patching. 
+ monotonicity (`bool`, *optional*, defaults to False): + Whether to enforce monotonicity in patching. + patching_batch_size (`int`, *optional*, defaults to 1): + Batch size for patching operations. + patching_device (`str`, *optional*, defaults to "cuda"): + Device to use for patching operations. + max_patch_length (`int`, *optional*): + Maximum length of patches. + entropy_model_checkpoint_dir (`str`, *optional*): + Directory containing entropy model checkpoint. + + # Cross attention configurations + cross_attn_encoder (`bool`, *optional*, defaults to False): + Whether to use cross attention in encoder. + cross_attn_decoder (`bool`, *optional*, defaults to False): + Whether to use cross attention in decoder. + cross_attn_window_encoder (`int`, *optional*): + Cross attention window for encoder. + cross_attn_window_decoder (`int`, *optional*): + Cross attention window for decoder. + cross_attn_k (`int`, *optional*): + Number of cross attention components. + cross_attn_nheads (`int`, *optional*): + Number of heads for cross attention. + cross_attn_all_layers_decoder (`bool`, *optional*, defaults to False): + Whether to apply cross attention to all decoder layers. + cross_attn_all_layers_encoder (`bool`, *optional*, defaults to False): + Whether to apply cross attention to all encoder layers. + cross_attn_use_flex_attention (`bool`, *optional*, defaults to True): + Whether to use flexible attention for cross attention. + cross_attn_init_by_pooling (`bool`, *optional*, defaults to False): + Whether to initialize cross attention by pooling. + + # Encoder configurations + use_local_encoder_transformer (`bool`, *optional*, defaults to False): + Whether to use transformer in local encoder. + max_encoder_seq_length (`int`, *optional*): + Maximum sequence length for encoder. + encoder_hash_byte_group_size (`Any`, *optional*): + Hash byte group size for encoder. + encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 30000): + Vocabulary size for hash byte groups. + encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 3): + Number of hash functions for byte groups. + encoder_enable_byte_ngrams (`bool`, *optional*, defaults to False): + Whether to enable byte n-grams in encoder. + encoder_ngram_to_size_str (`str`, *optional*): + String defining n-gram sizes. + downsampling_by_pooling (`str`, *optional*): + Type of pooling for downsampling. + + # Model behavior + share_encoder_decoder_emb (`bool`, *optional*, defaults to True): + Whether to share encoder and decoder embeddings. + weight_tying (`bool`, *optional*, defaults to False): + Whether to tie input and output embeddings. + + # Performance optimization + sequence_parallel (`bool`, *optional*, defaults to False): + Whether to use sequence parallelism. + loss_parallel (`bool`, *optional*, defaults to False): + Whether to use loss parallelism. + fuse_sequence_parallel (`bool`, *optional*, defaults to False): + Whether to fuse sequence parallel operations. + use_fsdp (`bool`, *optional*, defaults to True): + Whether to use fully sharded data parallel. + + # Parameter mixing + pm_size (`int`, *optional*, defaults to 0): + Parameter mixing size. + + # Special tokens + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to -1): + The id of the padding token. 
+ + # Patcher/Entropy model configuration + patcher_vocab_size (`int`, *optional*, defaults to 256): + Vocabulary size for the entropy model used in patching. + patcher_dim (`int`, *optional*, defaults to 512): + Hidden dimension for the entropy model. + patcher_n_layers (`int`, *optional*, defaults to 8): + Number of layers in the entropy model. + patcher_n_heads (`int`, *optional*, defaults to 8): + Number of attention heads in the entropy model. + patcher_head_dim (`int`, *optional*): + Dimension of each attention head in the entropy model. + patcher_n_kv_heads (`int`, *optional*): + Number of key-value heads in the entropy model. + patcher_max_seqlen (`int`, *optional*, defaults to 1024): + Maximum sequence length for the entropy model. + patcher_norm_eps (`float`, *optional*, defaults to 1e-5): + Layer normalization epsilon for the entropy model. + patcher_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for the entropy model. + patcher_sliding_window (`int`, *optional*): + Sliding window size for the entropy model attention. + patcher_ffn_dim_multiplier (`float`, *optional*): + Feedforward dimension multiplier for the entropy model. + patcher_multiple_of (`int`, *optional*, defaults to 256): + Make feedforward dimension multiple of this for the entropy model. + patcher_rope_theta (`float`, *optional*, defaults to 10000.0): + RoPE theta parameter for the entropy model. + patcher_rope_use_fp32_in_outer_product (`bool`, *optional*, defaults to False): + Whether to use fp32 in RoPE outer product for the entropy model. + patcher_attn_impl (`str`, *optional*, defaults to "sdpa"): + Attention implementation for the entropy model. + patcher_attn_bias_type (`str`, *optional*, defaults to "causal"): + Attention bias type for the entropy model. + patcher_init_base_std (`float`, *optional*): + Base initialization standard deviation for the entropy model. + patcher_init_std_factor (`str`, *optional*, defaults to "disabled"): + Initialization std factor for the entropy model. + patcher_dim_token_emb (`int`, *optional*): + Token embedding dimension for the entropy model. + patcher_weight_tying (`bool`, *optional*, defaults to False): + Whether to tie embeddings in the entropy model. + patcher_bos_token_id (`int`, *optional*, defaults to 1): + Beginning of sequence token id for the entropy model. + patcher_eos_token_id (`int`, *optional*, defaults to 2): + End of sequence token id for the entropy model. 
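+
+    A patching-oriented configuration can be sketched as follows (illustrative values,
+    not tuned defaults; every keyword used here is documented in the Args above):
+
+    ```python
+    >>> patched_config = BLTConfig(
+    ...     patch_in_forward=True,
+    ...     patching_mode="entropy",
+    ...     patching_threshold=0.5,
+    ...     patcher_dim=512,
+    ...     patcher_n_layers=8,
+    ... )
+    ```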
+ + ```python + >>> from transformers import ByteLatentTransformer, BLTConfig + + >>> # Initializing a BLT configuration + >>> configuration = BLTConfig() + + >>> # Initializing a model from the configuration + >>> model = ByteLatentTransformer(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "blt" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=256, + max_seqlen=1024, + # Main architecture dimensions + dim=512, + n_layers=8, + n_heads=8, + head_dim=None, + n_kv_heads=None, + # Component-specific dimensions + dim_global=512, + dim_local_decoder=512, + dim_local_encoder=512, + n_layers_global=8, + n_layers_local_decoder=8, + n_layers_local_encoder=8, + n_heads_global=8, + n_heads_local_decoder=8, + n_heads_local_encoder=8, + n_kv_heads_global=None, + # Transformer configuration + norm_eps=1e-5, + dropout=0.0, + ffn_dim_multiplier=1.0, + multiple_of=256, + # Positional encoding + rope_theta=10000.0, + rope_use_fp32_in_outer_product=False, + # Attention configuration + attn_impl="sdpa", + attn_bias_type="causal", + local_attention_window_len=None, + use_rope=True, + # Initialization + init_base_std=None, + init_std_factor="disabled", + # Embedding dimensions + dim_token_emb=None, + dim_token=None, + # Patching configuration + patch_in_forward=False, + realtime_patching=True, + patch_size=None, + patching_mode=None, + patching_threshold=None, + patching_threshold_add=None, + monotonicity=False, + patching_batch_size=1, + patching_device="cuda", + max_patch_length=None, + entropy_model_checkpoint_dir=None, + # Cross attention configurations + cross_attn_encoder=False, + cross_attn_decoder=False, + cross_attn_window_encoder=None, + cross_attn_window_decoder=None, + cross_attn_k=None, + cross_attn_nheads=None, + cross_attn_all_layers_decoder=False, + cross_attn_all_layers_encoder=False, + cross_attn_use_flex_attention=True, + cross_attn_init_by_pooling=False, + # Encoder configurations + use_local_encoder_transformer=False, + max_encoder_seq_length=None, + encoder_hash_byte_group_size=None, + encoder_hash_byte_group_vocab=30000, + encoder_hash_byte_group_nb_functions=3, + encoder_enable_byte_ngrams=False, + encoder_ngram_to_size_str=None, + downsampling_by_pooling=None, + # Model behavior + share_encoder_decoder_emb=True, + weight_tying=False, + # Performance optimization + sequence_parallel=False, + loss_parallel=False, + fuse_sequence_parallel=False, + use_fsdp=True, + # Parameter mixing + pm_size=0, + # Special tokens + bos_token_id=1, + eos_token_id=2, + pad_token_id=-1, + # Patcher/Entropy model configuration + patcher_vocab_size=256, + patcher_dim=512, + patcher_n_layers=8, + patcher_n_heads=8, + patcher_head_dim=None, + patcher_n_kv_heads=None, + patcher_max_seqlen=1024, + patcher_norm_eps=1e-5, + patcher_dropout=0.0, + patcher_sliding_window=None, + patcher_ffn_dim_multiplier=None, + patcher_multiple_of=256, + patcher_rope_theta=10000.0, + patcher_rope_use_fp32_in_outer_product=False, + patcher_attn_impl="sdpa", + patcher_attn_bias_type="causal", + patcher_init_base_std=None, + patcher_init_std_factor="disabled", + patcher_dim_token_emb=None, + patcher_weight_tying=False, + patcher_bos_token_id=1, + patcher_eos_token_id=2, + # Inherited + **kwargs, + ): + # Basic model configuration + self.vocab_size = vocab_size + self.max_seqlen = max_seqlen + + # Main architecture dimensions + self.dim = dim + self.n_layers = n_layers + self.n_heads = n_heads + self.head_dim = head_dim + 
self.n_kv_heads = n_kv_heads + + # Component-specific dimensions + self.dim_global = dim_global + self.dim_local_decoder = dim_local_decoder + self.dim_local_encoder = dim_local_encoder + self.n_layers_global = n_layers_global + self.n_layers_local_decoder = n_layers_local_decoder + self.n_layers_local_encoder = n_layers_local_encoder + self.n_heads_global = n_heads_global + self.n_heads_local_decoder = n_heads_local_decoder + self.n_heads_local_encoder = n_heads_local_encoder + self.n_kv_heads_global = n_kv_heads_global + + # Transformer configuration + self.norm_eps = norm_eps + self.dropout = dropout + self.ffn_dim_multiplier = ffn_dim_multiplier + self.multiple_of = multiple_of + + # Positional encoding + self.rope_theta = rope_theta + self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product + + # Attention configuration + self.attn_impl = attn_impl + self.attn_bias_type = attn_bias_type + self.local_attention_window_len = local_attention_window_len + self.use_rope = use_rope + + # Initialization + self.init_base_std = init_base_std + self.init_std_factor = InitStdFactor(init_std_factor) + + # Embedding dimensions + self.dim_token_emb = dim_token_emb + self.dim_token = dim_token + + # Patching configuration + self.patch_in_forward = patch_in_forward + self.realtime_patching = realtime_patching + self.patch_size = patch_size + self.patching_mode = patching_mode + self.patching_threshold = patching_threshold + self.patching_threshold_add = patching_threshold_add + self.monotonicity = monotonicity + self.patching_batch_size = patching_batch_size + self.patching_device = patching_device + self.max_patch_length = max_patch_length + self.entropy_model_checkpoint_dir = entropy_model_checkpoint_dir + + # Cross attention configurations + self.cross_attn_encoder = cross_attn_encoder + self.cross_attn_decoder = cross_attn_decoder + self.cross_attn_window_encoder = cross_attn_window_encoder + self.cross_attn_window_decoder = cross_attn_window_decoder + self.cross_attn_k = cross_attn_k + self.cross_attn_nheads = cross_attn_nheads + self.cross_attn_all_layers_decoder = cross_attn_all_layers_decoder + self.cross_attn_all_layers_encoder = cross_attn_all_layers_encoder + self.cross_attn_use_flex_attention = cross_attn_use_flex_attention + self.cross_attn_init_by_pooling = cross_attn_init_by_pooling + + # Encoder configurations + self.use_local_encoder_transformer = use_local_encoder_transformer + self.max_encoder_seq_length = max_encoder_seq_length + self.encoder_hash_byte_group_size = encoder_hash_byte_group_size + self.encoder_hash_byte_group_vocab = encoder_hash_byte_group_vocab + self.encoder_hash_byte_group_nb_functions = encoder_hash_byte_group_nb_functions + self.encoder_enable_byte_ngrams = encoder_enable_byte_ngrams + self.encoder_ngram_to_size_str = encoder_ngram_to_size_str + self.downsampling_by_pooling = downsampling_by_pooling + + # Model behavior + self.share_encoder_decoder_emb = share_encoder_decoder_emb + self.weight_tying = weight_tying + + # Performance optimization + self.sequence_parallel = sequence_parallel + self.loss_parallel = loss_parallel + self.fuse_sequence_parallel = fuse_sequence_parallel + self.use_fsdp = use_fsdp + + # Parameter mixing + self.pm_size = pm_size + + # Patcher/Entropy model configuration + self.patcher_vocab_size = patcher_vocab_size + self.patcher_dim = patcher_dim + self.patcher_n_layers = patcher_n_layers + self.patcher_n_heads = patcher_n_heads + self.patcher_head_dim = patcher_head_dim + self.patcher_n_kv_heads = patcher_n_kv_heads + 
self.patcher_max_seqlen = patcher_max_seqlen + self.patcher_norm_eps = patcher_norm_eps + self.patcher_dropout = patcher_dropout + self.patcher_sliding_window = patcher_sliding_window + self.patcher_ffn_dim_multiplier = patcher_ffn_dim_multiplier + self.patcher_multiple_of = patcher_multiple_of + self.patcher_rope_theta = patcher_rope_theta + self.patcher_rope_use_fp32_in_outer_product = patcher_rope_use_fp32_in_outer_product + self.patcher_attn_impl = patcher_attn_impl + self.patcher_attn_bias_type = patcher_attn_bias_type + self.patcher_init_base_std = patcher_init_base_std + self.patcher_init_std_factor = InitStdFactor(patcher_init_std_factor) + self.patcher_dim_token_emb = patcher_dim_token_emb + self.patcher_weight_tying = patcher_weight_tying + self.patcher_bos_token_id = patcher_bos_token_id + self.patcher_eos_token_id = patcher_eos_token_id + + # Handle hash byte group size validation + if self.encoder_hash_byte_group_size is not None and type(self.encoder_hash_byte_group_size) == str: + self.encoder_hash_byte_group_size = [ + int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0 + ] + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + **kwargs, + ) + + @property + def encoder_dim_token_emb(self): + """Compute encoder token embedding dimension.""" + if self.dim_token is not None: + return self.dim_token + elif self.use_local_encoder_transformer: + return self.dim_local_encoder + else: + # Use default patch_size of 8 if not set + patch_size = self.patch_size if self.patch_size is not None else 8 + return self.dim_global // patch_size + + @property + def encoder_dim_patch_emb(self): + """Compute encoder patch embedding dimension.""" + if self.cross_attn_encoder: + if self.cross_attn_init_by_pooling: + return self.dim_local_encoder + else: + return self.dim_global + return None + + @property + def global_dim_patch_emb(self): + """Compute global patch embedding dimension.""" + dim_token_emb = self.encoder_dim_token_emb + if self.cross_attn_encoder: + cross_attn_k = self.cross_attn_k if self.cross_attn_k is not None else 1 + return dim_token_emb * cross_attn_k + elif ( + self.downsampling_by_pooling is None + or not self.downsampling_by_pooling + or len(self.downsampling_by_pooling) == 0 + ): + # Use default patch_size of 8 if not set + patch_size = self.patch_size if self.patch_size is not None else 8 + return dim_token_emb * patch_size + else: + return dim_token_emb * sum([pooling in self.downsampling_by_pooling for pooling in ["avg", "min", "max"]]) + + @property + def decoder_dim_token_emb(self): + """Compute decoder token embedding dimension.""" + if self.share_encoder_decoder_emb: + return self.encoder_dim_token_emb + elif self.dim_token is not None: + return self.dim_token + else: + return self.dim_local_decoder + + def get_init_std_factor(self, depth: int) -> float: + """ + Calculate the initialization standard deviation scaling factor for a given layer depth. 
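+
+        When `init_std_factor` resolves to `InitStdFactor.CURRENT_DEPTH`, the factor is
+        `sqrt(2 * (depth + 1))`; for the default `"disabled"` setting it is `1.0`.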
+ + Args: + depth: Current layer depth (0-indexed) + + Returns: + Scaling factor to divide the base initialization std by + """ + if self.init_std_factor == InitStdFactor.CURRENT_DEPTH: + return (2 * (depth + 1)) ** 0.5 + else: # DISABLED + return 1.0 + + +__all__ = ["BLTConfig", "InitStdFactor", "PatchingModeEnum"] diff --git a/backup_blt_wip_backup/convert_hf_blt_original_to_unified.py b/backup_blt_wip_backup/convert_hf_blt_original_to_unified.py new file mode 100644 index 0000000000000000000000000000000000000000..dad247b19c62d983d71526ecdc8ff6c13ca9a5c8 --- /dev/null +++ b/backup_blt_wip_backup/convert_hf_blt_original_to_unified.py @@ -0,0 +1,540 @@ +import argparse +import json +import logging +import os +from typing import Dict, Any, Optional + +import torch +from huggingface_hub import hf_hub_download, snapshot_download +from safetensors.torch import load_file, save_file + +from transformers.utils import logging as transformers_logging + +logger = transformers_logging.get_logger(__name__) +transformers_logging.set_verbosity_info() + +# For standalone execution, we'll skip the model validation to avoid import issues +# The script will create the unified config and weights files without testing model instantiation +ENABLE_MODEL_VALIDATION = False + +import sys +import os + +from transformers.models.blt_wip.modeling_blt_wip import BLTModel +from transformers.models.blt_wip.configuration_blt import BLTConfig + + +ENABLE_MODEL_VALIDATION = True + +def download_model_files(model_id: str, cache_dir: Optional[str] = None) -> Dict[str, str]: + """ + Download all necessary files from HuggingFace Hub. + + Args: + model_id: HuggingFace model ID (e.g., "facebook/blt-1b") + cache_dir: Optional cache directory + + Returns: + Dictionary with paths to downloaded files + """ + logger.info(f"Downloading model files from {model_id}...") + + try: + # Download main config + config_path = hf_hub_download( + repo_id=model_id, + filename="config.json", + cache_dir=cache_dir + ) + + # Download main model weights + weights_path = hf_hub_download( + repo_id=model_id, + filename="model.safetensors", + cache_dir=cache_dir + ) + + # Download entropy model params + entropy_params_path = hf_hub_download( + repo_id=model_id, + filename="entropy_model/params.json", + cache_dir=cache_dir + ) + + # Download entropy model weights + entropy_weights_path = hf_hub_download( + repo_id=model_id, + filename="entropy_model/consolidated.pth", + cache_dir=cache_dir + ) + + return { + "config": config_path, + "weights": weights_path, + "entropy_params": entropy_params_path, + "entropy_weights": entropy_weights_path + } + + except Exception as e: + logger.error(f"Failed to download files from {model_id}: {e}") + raise + + +def merge_configurations(config_path: str, entropy_params_path: str) -> Dict[str, Any]: + """ + Merge main configuration with entropy model parameters. 
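+
+    The result is the content of the main `config.json` plus patching defaults
+    (e.g. `patch_in_forward`, `patching_mode`) and `patcher_*` entries built from the
+    entropy model's `params.json` (its `entropy_model` block and `data.patcher_args`).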
+ + Args: + config_path: Path to main config.json + entropy_params_path: Path to entropy_model/params.json + + Returns: + Merged configuration dictionary + """ + logger.info("Merging configurations...") + + # Load main configuration + with open(config_path, 'r') as f: + main_config = json.load(f) + + # Load entropy model parameters + with open(entropy_params_path, 'r') as f: + entropy_data = json.load(f) + + # Extract entropy model and patcher parameters + entropy_model_params = entropy_data.get("entropy_model", {}) + patcher_args = entropy_data.get("data", {}).get("patcher_args", {}) + + # Create unified configuration + unified_config = main_config.copy() + + # Ensure required main model parameters are present with correct types + # Sometimes the original config may have different key names + if "vocab_size" not in unified_config: + unified_config["vocab_size"] = int(main_config.get("vocab_size", 256)) + if "dim" not in unified_config: + unified_config["dim"] = int(main_config.get("dim", main_config.get("hidden_size", main_config.get("d_model", 512)))) + if "n_layers" not in unified_config: + unified_config["n_layers"] = int(main_config.get("n_layers", main_config.get("num_layers", main_config.get("num_hidden_layers", 8)))) + if "n_heads" not in unified_config: + unified_config["n_heads"] = int(main_config.get("n_heads", main_config.get("num_attention_heads", main_config.get("num_heads", 8)))) + if "max_seqlen" not in unified_config: + unified_config["max_seqlen"] = int(main_config.get("max_seqlen", main_config.get("max_position_embeddings", main_config.get("seq_length", 1024)))) + + # Ensure other integer parameters are properly typed + for key in ["vocab_size", "dim", "n_layers", "n_heads", "max_seqlen"]: + if key in unified_config and not isinstance(unified_config[key], int): + unified_config[key] = int(unified_config[key]) + + # Convert all patch_size values to integers to avoid float/int type errors + patch_size = patcher_args.get("patch_size", 8) + if isinstance(patch_size, float): + patch_size = int(patch_size) + + # Add patching configuration + unified_config.update({ + "patch_in_forward": True, + "realtime_patching": True, + "patching_mode": "entropy", + + # Patcher arguments + "patch_size": patch_size, + "patching_threshold": patcher_args.get("threshold", 0.5), + "patching_threshold_add": patcher_args.get("threshold_add", 0.0), + "max_patch_length": patcher_args.get("max_patch_length"), + "patching_batch_size": patcher_args.get("patching_batch_size", 1), + "patching_device": patcher_args.get("patching_device", "cuda"), + "monotonicity": patcher_args.get("monotonicity", False), + + # Entropy model (patcher) architecture parameters + "patcher_vocab_size": int(entropy_model_params.get("vocab_size", 256)), + "patcher_dim": int(entropy_model_params.get("dim", 512)), + "patcher_n_layers": int(entropy_model_params.get("n_layers", 8)), + "patcher_n_heads": int(entropy_model_params.get("n_heads", 8)), + "patcher_head_dim": int(entropy_model_params.get("head_dim")) if entropy_model_params.get("head_dim") is not None else None, + "patcher_n_kv_heads": int(entropy_model_params.get("n_kv_heads")) if entropy_model_params.get("n_kv_heads") is not None else None, + "patcher_max_seqlen": int(entropy_model_params.get("max_seqlen", 1024)), + "patcher_norm_eps": entropy_model_params.get("norm_eps", 1e-5), + "patcher_dropout": entropy_model_params.get("dropout", 0.0), + "patcher_sliding_window": int(entropy_model_params.get("sliding_window", 512)) if entropy_model_params.get("sliding_window") is not 
None else None, + "patcher_ffn_dim_multiplier": entropy_model_params.get("ffn_dim_multiplier"), + "patcher_multiple_of": int(entropy_model_params.get("multiple_of", 256)), + "patcher_rope_theta": entropy_model_params.get("rope_theta", 10000.0), + "patcher_rope_use_fp32_in_outer_product": entropy_model_params.get("rope_use_fp32_in_outer_product", False), + "patcher_attn_impl": entropy_model_params.get("attn_impl", "sdpa"), + "patcher_attn_bias_type": entropy_model_params.get("attn_bias_type", "causal"), + "patcher_init_base_std": entropy_model_params.get("init_base_std"), + "patcher_init_std_factor": entropy_model_params.get("init_std_factor", "disabled"), + "patcher_dim_token_emb": entropy_model_params.get("dim_token_emb"), + "patcher_weight_tying": entropy_model_params.get("weight_tying", False), + "patcher_bos_token_id": entropy_model_params.get("bos_token_id", 1), + "patcher_eos_token_id": entropy_model_params.get("eos_token_id", 2), + }) + + logger.info(f"Merged configuration with {len(unified_config)} parameters") + return unified_config + + +def merge_weights(weights_path: str, entropy_weights_path: str) -> Dict[str, torch.Tensor]: + """ + Merge main model weights with entropy model weights. + + Args: + weights_path: Path to main model.safetensors + entropy_weights_path: Path to entropy_model/consolidated.pth + + Returns: + Merged state dictionary + """ + logger.info("Merging model weights...") + + # Load main model weights + main_weights = load_file(weights_path) + logger.info(f"Loaded main model weights: {len(main_weights)} tensors") + + # Load entropy model weights + entropy_weights = torch.load(entropy_weights_path, map_location='cpu', weights_only=True) + + # Handle nested entropy model structure + if 'model' in entropy_weights: + entropy_weights = entropy_weights['model'] + elif 'state_dict' in entropy_weights: + entropy_weights = entropy_weights['state_dict'] + + logger.info(f"Loaded entropy model weights: {len(entropy_weights)} tensors") + + # Create unified state dict + unified_weights = main_weights.copy() + + # Add entropy model weights with "patcher." prefix + for key, tensor in entropy_weights.items(): + patcher_key = f"patcher.{key}" + unified_weights[patcher_key] = tensor + + logger.info(f"Merged weights: {len(unified_weights)} tensors total") + return unified_weights + + +def create_tokenizer_config(output_dir: str, config: Dict[str, Any]): + """ + Create tokenizer configuration file. + + Args: + output_dir: Output directory + config: Model configuration + """ + logger.info("Creating tokenizer configuration...") + + tokenizer_config = { + "tokenizer_class": "BltTokenizer", + "vocab_size": config.get("vocab_size", 256), + "model_max_length": config.get("max_seqlen", 1024), + "add_bos_token": True, + "add_eos_token": True, + "bos_token": "", + "eos_token": "", + "pad_token": "", + "unk_token": "", + } + + tokenizer_path = os.path.join(output_dir, "tokenizer_config.json") + with open(tokenizer_path, 'w') as f: + json.dump(tokenizer_config, f, indent=2) + + logger.info(f"Tokenizer config saved to {tokenizer_path}") + + +def validate_unified_model(config: Dict[str, Any], weights: Dict[str, torch.Tensor]): + """ + Validate the unified model configuration and weights. 
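+
+    This is a best-effort check: it warns about missing required configuration keys,
+    reports how many "patcher."-prefixed and main-model tensors were found, and, when
+    the BLT classes are importable, attempts to instantiate the model and load the
+    weights with `strict=False`.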
+ + Args: + config: Unified configuration + weights: Unified weights + """ + logger.info("Validating unified model...") + + # Check required configuration keys + required_keys = [ + "vocab_size", "dim", "n_layers", "n_heads", + "patch_in_forward", "patcher_vocab_size", "patcher_dim" + ] + + missing_keys = [key for key in required_keys if key not in config] + if missing_keys: + logger.warning(f"Missing configuration keys: {missing_keys}") + + # Check for patcher weights + patcher_weights = [key for key in weights.keys() if key.startswith("patcher.")] + if not patcher_weights: + logger.warning("No patcher weights found in unified weights") + else: + logger.info(f"Found {len(patcher_weights)} patcher weight tensors") + + # Check for main model weights + main_weights = [key for key in weights.keys() if not key.startswith("patcher.")] + logger.info(f"Found {len(main_weights)} main model weight tensors") + + # Try to create the model with the configuration (if imports are available) + if ENABLE_MODEL_VALIDATION and BLTConfig is not None and BLTModel is not None: + try: + logger.info("Testing model instantiation...") + + # Debug: Print config keys to help diagnose issues + logger.debug(f"Config keys: {list(config.keys())}") + logger.debug(f"Config vocab_size: {config.get('vocab_size')} (type: {type(config.get('vocab_size'))})") + logger.debug(f"Config dim: {config.get('dim')} (type: {type(config.get('dim'))})") + + blt_config = BLTConfig(**config) + model = BLTModel(blt_config) + logger.info("✓ Model instantiation successful") + + # Try to load the weights + logger.info("Testing weight loading...") + try: + missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False) + if missing_keys: + logger.warning(f"Missing keys during weight loading: {missing_keys[:5]}...") # Show first 5 + if unexpected_keys: + logger.warning(f"Unexpected keys during weight loading: {unexpected_keys[:5]}...") # Show first 5 + logger.info("✓ Weight loading successful") + except Exception as weight_error: + logger.warning(f"Weight loading failed: {weight_error}") + logger.info("Model instantiation successful, but weight loading had issues") + + except Exception as e: + logger.error(f"Model validation failed: {e}") + logger.debug(f"Full error details:", exc_info=True) + logger.warning("Model may not be compatible with modeling_blt_wip.py") + logger.info("You can still use the converted files and test manually") + else: + logger.info("Skipping model instantiation test (BLT classes not available)") + logger.info("You can test the model manually after conversion") + + logger.info("Model validation completed") + + +def convert_hf_blt_to_unified( + model_id: str, + output_dir: str, + config_name: str = "config.json", + weights_name: str = "pytorch_model.bin", + safe_serialization: bool = True, + cache_dir: Optional[str] = None, + validate: bool = True, +) -> None: + """ + Convert BLT model from HuggingFace Hub format to unified format. 
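+
+    The output directory ends up with the unified config, the merged weights (entropy
+    model tensors stored under a "patcher." prefix), a tokenizer_config.json and a
+    README. A minimal call looks like (paths are illustrative):
+
+        convert_hf_blt_to_unified("facebook/blt-1b", "./unified_blt_1b")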
+ + Args: + model_id: HuggingFace model ID (e.g., "facebook/blt-1b") + output_dir: Output directory for unified model + config_name: Name for unified config file + weights_name: Name for unified weights file + safe_serialization: Whether to use safetensors format + cache_dir: Cache directory for downloads + validate: Whether to validate the unified model + """ + logger.info(f"Converting {model_id} to unified format...") + + # Download model files + file_paths = download_model_files(model_id, cache_dir) + + # Merge configurations + unified_config = merge_configurations( + file_paths["config"], + file_paths["entropy_params"] + ) + + # Merge weights + unified_weights = merge_weights( + file_paths["weights"], + file_paths["entropy_weights"] + ) + + # Validate if requested + if validate: + validate_unified_model(unified_config, unified_weights) + + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + # Save unified configuration + config_path = os.path.join(output_dir, config_name) + with open(config_path, 'w') as f: + json.dump(unified_config, f, indent=2) + logger.info(f"Unified config saved to {config_path}") + + # Save unified weights + if safe_serialization and weights_name.endswith('.bin'): + weights_name = weights_name.replace('.bin', '.safetensors') + elif not safe_serialization and weights_name.endswith('.safetensors'): + weights_name = weights_name.replace('.safetensors', '.bin') + + weights_path = os.path.join(output_dir, weights_name) + if safe_serialization: + save_file(unified_weights, weights_path) + else: + torch.save(unified_weights, weights_path) + logger.info(f"Unified weights saved to {weights_path}") + + # Create tokenizer config + create_tokenizer_config(output_dir, unified_config) + + # Create README + readme_path = os.path.join(output_dir, "README.md") + with open(readme_path, 'w') as f: + f.write(f"""# Unified BLT Model + +This model was converted from {model_id} to unified format compatible with modeling_blt_wip.py. + +## Files + +- `{config_name}`: Unified configuration (main config + entropy model params) +- `{weights_name}`: Unified weights (main model + entropy model weights with "patcher." prefix) +- `tokenizer_config.json`: Tokenizer configuration + +## Usage + +```python +import torch +import json +from modeling_blt_wip import BLTModel, BLTConfig + +# Load configuration +with open('{config_name}', 'r') as f: + config_dict = json.load(f) + +config = BLTConfig(**config_dict) + +# Load model +model = BLTModel(config) + +# Load weights +if '{weights_name}'.endswith('.safetensors'): + from safetensors.torch import load_file + state_dict = load_file('{weights_name}') +else: + state_dict = torch.load('{weights_name}', map_location='cpu') + +model.load_state_dict(state_dict, strict=False) +``` + +## Original Model + +Converted from: {model_id} +""") + + logger.info(f"Conversion completed! 
Unified model saved to: {output_dir}") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert BLT models from HuggingFace Hub format to unified format", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Convert facebook/blt-1b to unified format + python convert_hf_blt_to_unified.py \\ + --model_id facebook/blt-1b \\ + --output_dir ./unified_blt_1b + + # Convert with custom file names + python convert_hf_blt_to_unified.py \\ + --model_id facebook/blt-7b \\ + --output_dir ./unified_blt_7b \\ + --config_name unified_config.json \\ + --weights_name unified_model.safetensors + + # Convert without validation + python convert_hf_blt_to_unified.py \\ + --model_id facebook/blt-1b \\ + --output_dir ./my_blt \\ + --no_validate + """ + ) + + # Required arguments (with defaults for debugging) + parser.add_argument( + "--model_id", + type=str, + default="facebook/blt-1b", + help="HuggingFace model ID (e.g., facebook/blt-1b)" + ) + parser.add_argument( + "--output_dir", + type=str, + default="./unified_blt_debug", + help="Output directory for unified model" + ) + + # Optional arguments + parser.add_argument( + "--config_name", + type=str, + default="config.json", + help="Name for unified config file (default: config.json)" + ) + parser.add_argument( + "--weights_name", + type=str, + default="pytorch_model.bin", + help="Name for unified weights file (default: pytorch_model.bin)" + ) + parser.add_argument( + "--safe_serialization", + action="store_true", + default=True, + help="Use safetensors format for weights (default: True)" + ) + parser.add_argument( + "--no_safe_serialization", + dest="safe_serialization", + action="store_false", + help="Use .bin format instead of safetensors" + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="Cache directory for downloads" + ) + parser.add_argument( + "--no_validate", + dest="validate", + action="store_false", + default=True, + help="Skip model validation" + ) + parser.add_argument( + "--debug", + action="store_true", + default=True, # Enable debug by default for easier debugging + help="Enable debug logging" + ) + + args = parser.parse_args() + + # Setup logging + if args.debug: + transformers_logging.set_verbosity_debug() + logging.basicConfig(level=logging.DEBUG) + + # Run conversion + try: + convert_hf_blt_to_unified( + model_id=args.model_id, + output_dir=args.output_dir, + config_name=args.config_name, + weights_name=args.weights_name, + safe_serialization=args.safe_serialization, + cache_dir=args.cache_dir, + validate=args.validate, + ) + except Exception as e: + logger.error(f"Conversion failed: {e}") + raise + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/backup_blt_wip_backup/modeling_blt_wip.py b/backup_blt_wip_backup/modeling_blt_wip.py new file mode 100644 index 0000000000000000000000000000000000000000..fe846fc0ab9f7f5e794e32fd3809aea26451676c --- /dev/null +++ b/backup_blt_wip_backup/modeling_blt_wip.py @@ -0,0 +1,1836 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+ +import logging +import os +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn +import torch.nn as nn +from torch.nn import functional as F +from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention + +from ...modeling_utils import PreTrainedModel +from .configuration_blt import ( + BLTConfig, + PatchingModeEnum, +) + + +SEP = " " +BOS_ID: int = 1 +EOS_ID: int = 2 +PAD_ID: int = -1 +BOE_ID: int = 0 +BPE_ID: int = 3 +OFFSET: int = 4 + +BYTE_UNITS: int = 256 + +RMSNorm = nn.RMSNorm + +logger = logging.getLogger() + +flex_attention_comp = flex_attention + + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "sdpa": + BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) + + if attn_bias_type == "causal": + return "causal" + + if BLT_SUPPRESS_ATTN_ERROR == 1: + return "causal" + else: + raise ValueError( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1" + ) + elif attn_impl == "flex_attention": + return create_block_mask(causal_mask, None, None, seqlen, seqlen) + else: + raise NotImplementedError(f"Attention {attn_impl} with {sliding_window} sliding window not implemented") + + +def cross_entropy(pred, target, **kwargs): + return F.nll_loss( + F.log_softmax(pred.flatten(end_dim=-2).float(), -1), + target.flatten(end_dim=-1), + **kwargs, + ) + + +def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +def precompute_freqs_cis( + dim: int, + end: int, + theta: float = 10000.0, + rope_use_fp32_in_outer_product: bool = False, +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + if rope_use_fp32_in_outer_product: + t = t.to(torch.float32) + + freqs = torch.outer(t, freqs).float() + + cos, sin = freqs.cos(), freqs.sin() + + return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): + """ + Reshape frequency tensor for broadcasting it with another tensor. 
+ + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + seq_dim (int): Sequence dimension index. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= seq_dim < ndim + assert freqs_cis.shape == ( + x.shape[seq_dim], + x.shape[-3], + 2, + 2, + ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" + shape = [d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])] + [2, 2] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + seq_dim: int, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, seq_dim).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 + xq_out = (xq_ * freqs_cis).sum(5).flatten(3) + xk_out = (xk_ * freqs_cis).sum(5).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +# Rotary embedding as in xformer, see if torchtrain implementation is not better. Also might be usefull to make it work with batch*seqlen collapsed. +class RotaryEmbedding(torch.nn.Module): + """ + RotaryEmbedding Module + """ + + def __init__( + self, + theta: float, + head_dim: int, + max_seqlen: int = 1024, + rope_use_fp32_in_outer_product: bool = False, + ): + super().__init__() + + self.theta = theta + self.head_dim = head_dim + self.max_seqlen = max_seqlen + self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + dim=head_dim, + end=max_seqlen, + theta=theta, + rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, + ), + persistent=False, + ) + + + def forward(self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None): + """ + Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions + Args: + seqlen (int): Contiguous sequence length + tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen + + Returns: + Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis + """ + test = (seqlen is not None) or (tok_idx is not None) + assert test, "Should provide atleast seqlen or tok_idx" + if tok_idx is not None: + return self.freqs_cis[tok_idx] + elif seqlen is not None: + return self.freqs_cis[0:seqlen] + + +class BLTAttention(nn.Module): + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + rope_theta: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + self.rope_theta = rope_theta + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + # B S 
D + bsz, seq_len, dim = x.shape + xq = self.wq(x.view_as(x)) + xk = self.wk(x.view_as(x)) + xv = self.wv(x.view_as(x)) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) + + # This condition helps us be easily compatible + # with inference by adding a pluggable KVCache + if hasattr(self, "kv_cache"): + xk, xv = self.kv_cache.update(xk, xv, tok_idx) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + if attn_impl == "flex_attention": + assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) + output = flex_attention_comp(xq, xk, xv, block_mask=mask) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + elif attn_impl == "sdpa": + xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) + assert mask is None or isinstance(mask, (str, torch.Tensor)) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + else: + raise NotImplementedError(f"Attention implementation {attn_impl} not supported") + + output_reshaped = output.reshape(output_shape) + + output = self.wo(output_reshaped) + + return output + + + + +class BLTMLP(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + mp_size: int = 1, + ): + super().__init__() + + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + assert hidden_dim % mp_size == 0 + + self.dim = dim + self.hidden_dim = hidden_dim + + self.w1 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w3 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w2 = nn.Linear( + hidden_dim, + dim, + bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # B S D + x1 = self.w1(x.view_as(x)) + x3 = self.w3(x.view_as(x)) + output = self.w2(F.silu(x1) * x3) + return output + + + + +class BLTTransformerLayer(nn.Module): + def __init__(self, args): + super().__init__() + + # Extract parameters from dictionary + dim = args["dim"] + n_heads = args["n_heads"] + head_dim = args["head_dim"] + n_kv_heads = args["n_kv_heads"] + rope_theta = args["rope_theta"] + multiple_of = args["multiple_of"] + ffn_dim_multiplier = args["ffn_dim_multiplier"] + norm_eps = args["norm_eps"] + + assert (head_dim is not None) or (n_heads is not None), "Should specify at least head_dim or n_heads" + self.head_dim = head_dim or dim // n_heads + self.n_heads = n_heads or dim // head_dim + self.n_kv_heads = n_kv_heads or self.n_heads + + assert n_heads % self.n_kv_heads == 0 + assert dim % n_heads == 0 + + self.attention = BLTAttention( + dim=dim, + head_dim=self.head_dim, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + rope_theta=rope_theta, + ) + self.feed_forward = BLTMLP( + dim=dim, + hidden_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + self.attention_norm = RMSNorm(dim, eps=norm_eps) + self.ffn_norm = 
RMSNorm(dim, eps=norm_eps) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + norm_x = self.attention_norm(x) + attn_out = self.attention( + norm_x, + freq_cis, + tok_idx=tok_idx, + mask=mask, + attn_impl=attn_impl, + ) + h = x + attn_out + h_norm = self.ffn_norm(h) + out = h + self.feed_forward(h_norm) + return out + + + + +def rightpad(seq, pad_id, max_len): + return seq + [pad_id] * (max_len - len(seq)) + + +def check_non_zero_after_zero(tensor): + zero_mask = tensor == 0 + shifted_mask = torch.cat( + [ + torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device), + zero_mask[:, :-1], + ], + dim=1, + ) + non_zero_after_zero = (tensor != 0) & shifted_mask + return non_zero_after_zero.any() + + +def fill_tokens(tokens, patch_size, fill_id): + batch_size, seq_len = tokens.shape + if seq_len % patch_size == 0: + return tokens + else: + remaining = patch_size - seq_len % patch_size + final_padding = tokens.new(batch_size, remaining).fill_(fill_id) + return torch.cat((tokens, final_padding), dim=1) + + +def rolling_polynomial_hash(t, hash_func_nb: int = 0): + primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, + ] + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) + prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) + return torch.sum(t * prime_powers, dim=-1) + + +def byte_group_hash_function(x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): + """ + Returns a hash of the input x and maps it to a value in the range [0, max_hash]. + + expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. + returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. + + Note: max hash can make a big difference on the number of collisions. + """ + with torch.no_grad(): + bs, seq_len = x.shape + prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) + x = torch.cat([prefix, x], dim=1) + windows = x.unfold(1, group_size, 1) + # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) + hashes = rolling_polynomial_hash(windows, hash_func_nb) + hash_values_range = hashes % max_hash + hash_values_range.requires_grad = False + return hash_values_range + + +def create_patch_mask_from_ids(patch_ids, num_patches, window=None, patches_as_queries=False): + """ + Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) + is True if the patch id at position (i, j) is less than or equal to k. + Args: + patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids. + num_patches (int): Total number of patches. + window (int): If not None, only considers patches within a window of size window. + patches_as_queries (bool): If True, the patches are used as queries + Returns: + torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask. 
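+
+    Illustrative example (with `window=None` and `patches_as_queries=False`): for
+    `patch_ids = [[0, 0, 1, 2]]` and `num_patches = 3`, the mask rows are
+    [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], i.e. each byte position is matched
+    only with the patch it belongs to.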
+ """ + bs, seq_len = patch_ids.shape + if not patches_as_queries: + q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches) + kv_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(0) + .expand(bs, seq_len, num_patches) + ) + else: + kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len) + q_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(bs, num_patches, seq_len) + ) + if window is None: + mask = q_ids == kv_ids + else: + mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window) + return mask + + +def cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=1, + window=None, + block_mask=True, +): + bs = patch_ids.shape[0] + with torch.no_grad(): + # Create the patch mask + cross_mask = create_patch_mask_from_ids( + patch_ids, + patch_lengths.shape[1], + window=window, + patches_as_queries=patches_as_queries, + ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1) + q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N + kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k + assert cross_mask.shape == ( + bs, + q_len, + kv_len, + ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" + block_mask = None + if block_mask: + + def patch_mask(b, h, q_idx, kv_idx): + return cross_mask[b, q_idx, kv_idx] + + block_mask = create_block_mask( + patch_mask, + B=bs, + H=None, + Q_LEN=q_len, + KV_LEN=kv_len, + _compile=True, + ) + return block_mask + else: + return torch.where(cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))).unsqueeze( + 1 + ) # [bs, 1, q_len, kv_len] + + +def get_blt_input( + tokens: torch.Tensor, + enforce_patch_size_multiple: bool, + nb_boe: torch.Tensor, + patch_size: int, + boe_id: int, +): + """ + This function returns X_et, X_gt and X_dt, the encoder, global, and decoder + tokens respectively. + + Consider the input and target sequences: + X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13] + Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14] + with patch_size=4 + + Note 1: that there will be no special tokens introduced at the patch level. 
+ Note 2: X_e needs to be trimmed to be passed to Global + + Current without boe: + X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] + X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch + X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] + Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] + + --> lag fix: + X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]] + X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]] + X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] + Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] + + Dynamic (current): + X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos] + Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] + + entropy patching: + input: 7, bos, 9, 10 + pred (high entropy): eos, 8, 10, eos + + X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos] + X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]] + X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]] + Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] + + --> lag fix no boe (force single byte first patch): + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch + X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] + Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] + + input: 4, 7, bos, 9, 10 + pred (high entropy): 5, eos, 8, 10, eos + + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch + X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] + Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] + + Handle the last byte properly. + patch_lengths = [1, 1, 3, 2, 2 1 2 2 1] + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch + X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]] + Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]] + + + bpe delim + X_et = [[3,4,5,6,7,,eos,bos,,8,9,,10,,eos,bos,11,12] + X_g = [[3], [4,5,6,7,], [eos,bos,], .. + X_dt = [[3,4,5,6,7], [,eos,bos], [,bos,8], .. + Y = [4,5,6,7,, eos,bos, 8,9,, .. + + + Note 1: that there will be no special tokens introduced at the patch level. + Note 2: X_e needs to be trimmed to be passed to Global + """ + batch_size, seq_len = tokens.shape + local_encoder_tokens = tokens + local_decoder_tokens = tokens + + if nb_boe > 0: + padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id) + local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1) + # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id) + + # create global tokens, contains boe tokens and eos + # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size) + # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:] + # global_tokens += global_tokens.eq(0).int() * boe_id + # TODO: fix this when we want to use block causal in the global. 
+ + if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0: + local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + + return local_encoder_tokens, None, local_decoder_tokens + + +class LocalModelBase(nn.Module): + def __init__(self, config: BLTConfig, component_type: str = "encoder"): + super().__init__() + + # Store config for later use + self.config = config + + # Use component-specific dimensions + if component_type == "encoder": + self.dim = config.dim_local_encoder + self.n_layers = config.n_layers_local_encoder + self.n_heads = config.n_heads_local_encoder + self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen + self.attn_bias_type = "local_block_causal" + self.sliding_window = config.local_attention_window_len + elif component_type == "decoder": + self.dim = config.dim_local_decoder + self.n_layers = config.n_layers_local_decoder + self.n_heads = config.n_heads_local_decoder + self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen + self.attn_bias_type = "local_block_causal" + self.sliding_window = config.local_attention_window_len + else: + raise ValueError(f"Unknown component_type: {component_type}") + + self.dropout = config.dropout + self.vocab_size = config.vocab_size + config.pm_size + self.patch_size = config.patch_size + + self.attn_impl = config.attn_impl + self.use_rope = config.use_rope + self.init_std_factor = config.init_std_factor + self.init_base_std = config.init_base_std + self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None) + self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None) + self.cross_attn_k = getattr(config, "cross_attn_k", None) + self.eos_id = config.eos_token_id + + self.boe_id = BOE_ID + + # Initialize cross attention layers as None (will be set by subclasses if needed) + self.cross_attn_layers = None + + # Create parameter dict for BLTTransformerLayers + layer_params = { + "dim": self.dim, + "n_heads": self.n_heads, + "head_dim": config.head_dim, + "n_kv_heads": getattr(config, "n_kv_heads", None), + "rope_theta": config.rope_theta, + "multiple_of": getattr(config, "multiple_of", 256), + "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None), + "norm_eps": config.norm_eps, + } + + self.layers = nn.ModuleList([BLTTransformerLayer(layer_params) for _ in range(self.n_layers)]) + + if not self.use_rope: + self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length + else: + self.rope = RotaryEmbedding( + theta=config.rope_theta, + head_dim=config.head_dim or self.dim // self.n_heads, + max_seqlen=self.max_seqlen, + rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, + ) + self.pos_embeddings = None + + # Set dimension-specific embedding dimensions + if component_type == "encoder": + self.dim_token_emb = config.encoder_dim_token_emb + self.dim_patch_emb = config.encoder_dim_patch_emb + elif component_type == "decoder": + self.dim_token_emb = config.decoder_dim_token_emb + self.dim_patch_emb = config.dim_global + + self.token_embedding_projection = ( + nn.Linear(self.dim_token_emb, self.dim, bias=False) + if self.dim_token_emb is not None and self.dim_token_emb != self.dim + else None + ) + + self.patch_embedding_projection = self._create_patch_projection(config) + + def _should_create_patch_projection(self, config: BLTConfig): + dimension_mismatch = self.dim_patch_emb is not None and self.dim_patch_emb != self.dim + + # Check cross attention conditions + cross_attn_conditions = 
(config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( + config.cross_attn_decoder and config.cross_attn_init_by_pooling + ) + + return dimension_mismatch or cross_attn_conditions + + def _create_patch_projection(self, config): + if not self._should_create_patch_projection(config): + return None + + output_dim = self.dim_token_emb * (self.cross_attn_k or 1) + + return nn.Linear( + in_features=self.dim_patch_emb, + out_features=output_dim, + bias=False, + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + return embeds + else: + return self.tok_embeddings(tokens) + + + + +class LocalEncoder(LocalModelBase): + def __init__(self, config: BLTConfig): + super().__init__(config, component_type="encoder") + + self.apply_transformer = config.use_local_encoder_transformer + self.downsampling_by_pooling = config.downsampling_by_pooling + self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None + self.cross_attn_encoder = config.cross_attn_encoder + self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder + self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_nheads = config.cross_attn_nheads + + self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim) + + if self.cross_attn_encoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=config.norm_eps, + ) + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + assert self.expects_hash_embeddings, "Not expecting embeddings to be passed." 
+ return embeds + else: + return self.tok_embeddings(tokens) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + num_patches: Optional[int] = None, + patch_ids: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ """ + bs, seqlen = tokens.shape + if mask is None: + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) + + h = self.apply_embedding(tokens, embeds) + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + + for i, layer in enumerate(self.layers): + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) + # check if cross attention should be applied to either all layer or only the last layer + if self.cross_attn_encoder and (i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder): + # apply pooling and project + if self.cross_attn_init_by_pooling and patch_embeds is None: + patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids) + if self.patch_embedding_projection is not None: + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim) + + layer_idx = i if self.cross_attn_all_layers_encoder else 0 + patch_embeds_cross = self.cross_attn_layers[layer_idx]( + x=patch_embeds, + kv=h, + mask=cross_mask, + ) + patch_embeds = patch_embeds + patch_embeds_cross + + h_residual = patch_embeds if self.cross_attn_encoder else None + return (h, h_residual), cache + + def patch_reduce(self, h, max_num_patches, reduction, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. 
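+
+        Illustrative example: with `reduction="amax"` and `patch_ids = [[0, 0, 1]]`,
+        the first two byte embeddings are element-wise max-pooled into patch 0 and the
+        third embedding becomes patch 1.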
+ """ + bs, seq_len, emb_dim = h.shape + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + + reduced_embs = torch.zeros((bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device) + reduced_embs = reduced_embs.scatter_reduce( + src=h, + dim=1, + index=patch_ids, + reduce=reduction, + include_self=False, + ) + reduced_embs = reduced_embs[:, :max_num_patches, :] + + return reduced_embs + + +class LocalDecoder(LocalModelBase): + def __init__(self, config: BLTConfig): + super().__init__(config, component_type="decoder") + + # Model configuration flags + self.cross_attn_decoder = config.cross_attn_decoder + self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder + self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_nheads = config.cross_attn_nheads + + self.norm = RMSNorm(self.dim, eps=config.norm_eps) + + if self.cross_attn_decoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=config.norm_eps, + ) + ) + + self.output = nn.Linear( + self.dim, + config.vocab_size, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor], + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + bs, seqlen = tokens.shape + assert embeds is not None, "Embeddings must be provided" + + if mask is None: + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) + + h = embeds + + if self.patch_embedding_projection is not None: + assert patch_embeds is not None, "Patch embeddings must be passed." + patch_embeds = self.patch_embedding_projection(patch_embeds) + if self.cross_attn_k is not None: + patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim) + + if patch_embeds is not None and not self.cross_attn_decoder: + h = h + patch_embeds + + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + for i, layer in enumerate(self.layers): + if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder): + # Use cross attention to extract info from patch_embeds into h + h_cross = self.cross_attn_layers[i]( + x=h, + kv=patch_embeds, + mask=cross_mask, + ) + h = h + h_cross + + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) + + h_preds = self.norm(h) + h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) + h_preds = self.output(h_preds) + h_preds = h_preds.float() + return h_preds, cache + + +class BLTCrossAttention(nn.Module): + """ + BLTCrossAttention block to attend to the encoder states from the decoder. + Rope is not supported. 
+ """ + + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps) + self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + kv: torch.Tensor, + mask: Optional[Union[BlockMask, str]] = None, + ) -> torch.Tensor: + # B S D + bsz, seq_len, _ = x.shape + _, slen_kv, _ = kv.shape + x_norm = self.cross_attn_norm_q(x) + kv = self.cross_attn_norm_kv(kv) + + xq = self.wq(x_norm) + xk = self.wk(kv) + xv = self.wv(kv) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + # assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) + # output = flex_attention_comp(xq, xk, xv, block_mask=mask) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask if isinstance(mask, torch.Tensor) else None + mask = mask.to(dtype=xq.dtype).to(xq.device) + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + output = self.wo(output.reshape(output_shape)) + + return x + output + + + + +class GlobalTransformer(nn.Module): + def __init__(self, config): + super().__init__() + + # Store config for later use + self.config = config + + self.dim = config.dim_global + self.init_base_std = config.init_base_std + self.attn_impl = config.attn_impl + self.attn_bias_type = config.attn_bias_type + self.init_std_factor = config.init_std_factor + self.max_seqlen = config.max_seqlen + self.rope_embeddings = RotaryEmbedding( + theta=config.rope_theta, + head_dim=config.head_dim or config.dim_global // config.n_heads_global, + max_seqlen=config.max_seqlen, + rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, + ) + # Handle both eos_id and eos_token_id for compatibility + self.eos_id = getattr(config, "eos_id", getattr(config, "eos_token_id", 2)) + + # Create parameter dict for BLTTransformerLayers + layer_params = { + "dim": self.dim, + "n_heads": config.n_heads_global, + "head_dim": config.head_dim, + "n_kv_heads": getattr(config, "n_kv_heads_global", None), + "rope_theta": config.rope_theta, + "multiple_of": getattr(config, "multiple_of", 256), + "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None), + "norm_eps": config.norm_eps, + } + + self.layers = nn.ModuleList() + for _ in range(config.n_layers_global): + self.layers.append(BLTTransformerLayer(layer_params)) + + # GlobalTransformer specific attributes + self.dropout = config.dropout + self.dim_token_emb = config.global_dim_patch_emb + + self.token_embedding_projection = None + if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim: 
+            self.token_embedding_projection = nn.Linear(
+                config.global_dim_patch_emb,
+                config.dim_global,
+                bias=False,
+            )
+
+    def forward(
+        self,
+        tokens: torch.Tensor,
+        tok_idx: Optional[torch.Tensor] = None,
+        embeds: Optional[torch.Tensor] = None,
+        mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
+        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+    ):
+        bs, seqlen = tokens.shape
+
+        h = embeds
+
+        mask = (
+            mask
+            if mask is not None
+            else create_causal_mask(
+                seqlen,
+                self.attn_impl,
+                self.attn_bias_type,
+                tokens=tokens,
+                eos_id=self.eos_id,
+            )
+        )
+
+        if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
+            h = self.token_embedding_projection(h)
+
+        h = F.dropout(h, p=self.dropout, training=self.training)
+
+        freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
+
+        for i, layer in enumerate(self.layers):
+            h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
+
+        return h, cache
+
+
+def compute_hash_embeddings(
+    local_encoder_tokens: torch.Tensor,
+    local_encoder,
+    encoder_hash_tok_embedding: nn.ModuleList,
+    encoder_hash_byte_group_nb_functions: int,
+    encoder_hash_byte_group_size: list,
+    encoder_hash_byte_group_vocab: int,
+) -> torch.Tensor:
+    """
+    Compute embeddings using hash token embeddings.
+
+    Args:
+        local_encoder_tokens: Input tokens tensor
+        local_encoder: Encoder object with tok_embeddings method
+        encoder_hash_tok_embedding: ModuleList of hash token embeddings
+        encoder_hash_byte_group_nb_functions: Number of hash functions
+        encoder_hash_byte_group_size: List of byte group sizes
+        encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings
+
+    Returns:
+        torch.Tensor: Combined embeddings
+    """
+    if encoder_hash_tok_embedding is None:
+        return None
+
+    local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens)
+
+    i = 0
+    for func_nb in range(encoder_hash_byte_group_nb_functions):
+        for byte_group_size in encoder_hash_byte_group_size:
+            hash_ids = byte_group_hash_function(
+                local_encoder_tokens,
+                byte_group_size,
+                hash_func_nb=func_nb,
+                max_hash=encoder_hash_byte_group_vocab,
+            )
+            hash_tok_embedding = encoder_hash_tok_embedding[i]
+            local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
+            i += 1
+
+    assert i == len(encoder_hash_tok_embedding)
+    return local_encoder_embeds
+
+
+class BLTPreTrainedModel(PreTrainedModel):
+    config_class = BLTConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["BLTTransformerLayer", "LocalEncoder", "LocalDecoder", "GlobalTransformer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_2 = False  # BLT uses its own attention implementation
+    _supports_sdpa = True
+    _supports_cache_class = False
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            std = getattr(module, '_custom_std', module.in_features ** (-0.5))
+
+            nn.init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
+            )
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+
+        elif isinstance(module, nn.Embedding):
+            std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5))
+
+            nn.init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
+            )
+
+        elif isinstance(module, (nn.RMSNorm, nn.LayerNorm)):
+            nn.init.ones_(module.weight)
+            # nn.RMSNorm has no bias parameter, so guard the attribute access.
+            if getattr(module, "bias", None) is not None:
+                nn.init.zeros_(module.bias)
+
+        elif isinstance(module, RotaryEmbedding):
+            module.freqs_cis[...]
= precompute_freqs_cis( + dim=module.head_dim, + end=module.max_seqlen, + theta=module.theta, + rope_use_fp32_in_outer_product=module.rope_use_fp32_in_outer_product, + ) + + elif isinstance(module, BLTModel): + if module.encoder_hash_tok_embedding is not None: + emb_std = module.local_encoder.dim ** (-0.5) + for emb in module.encoder_hash_tok_embedding: + emb._custom_std = emb_std + + elif isinstance(module, (LocalEncoder, LocalDecoder)): + if module.token_embedding_projection is not None: + module.token_embedding_projection._custom_std = module.dim ** (-0.5) + + if module.patch_embedding_projection is not None: + module.patch_embedding_projection._custom_std = module.dim_patch_emb ** (-0.5) + + elif isinstance(module, GlobalTransformer): + if module.token_embedding_projection is not None: + module.token_embedding_projection._custom_std = module.dim_token_emb ** (-0.5) + + elif isinstance(module, BLTPatcher): + emb_std = module.config.patcher_dim ** (-0.5) + module.tok_embeddings._custom_std = emb_std + module.output._custom_std = emb_std + + +class BLTModel(BLTPreTrainedModel): + def __init__(self, config: BLTConfig): + super().__init__(config) + + self.config = config + self.local_encoder = LocalEncoder(config) + self.global_transformer = GlobalTransformer(config) + self.local_decoder = LocalDecoder(config) + + self.encoder_hash_tok_embedding = init_hash_embeddings( + config, + local_encoder_dim=self.local_encoder.dim, + encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, + ) + + if config.patch_in_forward: + self.patcher = BLTPatcher(config) + self.patcher.eval() + for param in self.patcher.parameters(): + param.requires_grad = False + else: + self.patcher = None + + + + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: + """ + Convert patch lengths to patch IDs for each token position. + + For each token position in the sequence, determines which patch it belongs to. + + Args: + patch_lengths: [batch_size, num_patches] - length of each patch + seq_len: total sequence length + + Returns: + patch_ids: [batch_size, seq_len] - patch index for each token position + + Example: + patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1 + seq_len = 10 + Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]] + # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3 + """ + batch_size, num_patches = patch_lengths.shape + + # Create patch start positions: [0, 3, 5, 9] for the example above + patch_starts = torch.cat( + [ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1], # cumsum without the final total + ], + dim=-1, + ) + + # For each token position, find which patch it belongs to + # by finding the rightmost patch start that's <= the position + token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1] + + # Broadcasting: patch_starts[batch, patch] <= token_positions[position] + # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t + position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1) + + # Count how many patch starts are <= each position, then subtract 1 to get patch index + patch_ids = position_ge_patch_start.sum(dim=-1) - 1 + + return patch_ids + + def _decoder_patch_ids_from_lengths(self, patch_lengths: torch.Tensor, nb_boe: int, seq_len: int) -> torch.Tensor: + """ + Create decoder patch IDs by skipping the first encoder patch. 
+ + The decoder starts after the first patch (which contains BOE tokens), + so we need to map decoder positions to the remaining patches. + + Args: + patch_lengths: [batch_size, num_patches] from encoder + nb_boe: number of beginning-of-example tokens in first patch + seq_len: decoder sequence length + + Returns: + decoder_patch_ids: [batch_size, seq_len] mapping decoder positions to patch indices + """ + # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens) + decoder_patch_lengths = patch_lengths[:, 1:] + + # Create patch IDs for the decoder sequence using the remaining patches + return self._patch_ids_from_lengths(decoder_patch_lengths, seq_len) + + def forward( + self, + tokens: torch.Tensor, + patch_lengths: Optional[torch.Tensor] = None, + ): + # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings + # are no longer used in the final BLT model + + bs, N = tokens.shape # Batch size and sequence length + + # Get megabyte inputs + nb_boe = int(0 if self.config.patching_mode != "" else self.config.patch_size - 1) + local_encoder_tokens, _, local_decoder_tokens = get_blt_input( + tokens=tokens, + enforce_patch_size_multiple=False, + nb_boe=nb_boe, + patch_size=self.config.patch_size, + boe_id=BOE_ID, + ) + + # Patching + if patch_lengths is None: + # assert ( + # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward + # ), "Patch in forward not enabled and no patch_lengths passed." + + # PATCHER MODEL DEFINED + if self.config.patching_mode == PatchingModeEnum.entropy: + _, patch_lengths, _ = self.patcher( + local_encoder_tokens, + patch_size=self.config.patch_size, + include_next_token=True, + threshold=self.config.patching_threshold, + threshold_add=self.config.patching_threshold_add, + monotonicity=self.config.monotonicity, + max_patch_length=self.config.max_patch_length, + patching_batch_size=self.config.patching_batch_size, + device=self.config.patching_device, + ) + else: + # self.config.patching_mode == PatchingModeEnum.byte + bs, seq_len = local_encoder_tokens.shape + seq_len_next_tok = seq_len + 1 # include_next_token=True + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + ) + + # Apply any processing to patch lengths + if self.config.max_patch_length is not None: + # TODO: avoid going back to a list here. 
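+                # Split any patch longer than max_patch_length into multiple patches,
+                # right-pad the per-sequence lists back into a rectangular tensor, then
+                # drop trailing columns that are zero for every sequence.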
+ patch_lengths = [ + BLTPatcher.split_large_numbers(pl, self.config.max_patch_length) + for pl in patch_lengths.tolist() + ] + max_len = max([len(pl) for pl in patch_lengths]) + patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] + patch_lengths = torch.tensor( + patch_lengths, dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + ) + assert not check_non_zero_after_zero(patch_lengths) + # Find the last non-zero column index using argmax on a reversed version of the tensor + last_non_zero_col_reversed = (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() + # Slice the tensor up to the last non-zero column + patch_lengths = patch_lengths[:, : patch_lengths.shape[1] - last_non_zero_col_reversed] + else: + if nb_boe > 0: + patch_lengths[:, 0] += nb_boe + + assert torch.min(patch_lengths) >= 0 + + # Generate patch IDs from patch_lengths + patch_ids = self._patch_ids_from_lengths(patch_lengths, local_encoder_tokens.shape[-1]) + assert torch.max(patch_ids) + 1 <= torch.max((patch_lengths != 0).sum(dim=-1)), ( + f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" + ) + + cross_attn_mask_enc = None + # Cross-attention encoder + if self.config.cross_attn_encoder: + cross_attn_mask_enc = cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=True, + cross_attn_k=self.config.cross_attn_k, + window=self.config.cross_attn_window_encoder, + block_mask=self.config.cross_attn_use_flex_attention, + ) + + # Hashing and embedding + local_encoder_embeds = compute_hash_embeddings( + local_encoder_tokens=local_encoder_tokens, + local_encoder=self.local_encoder, + encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, + encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions, + encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab, + ) + + # NOTE: Frequency-based n-gram embeddings removed as per paper + # The final BLT model uses only hash-based n-gram embeddings + + # Local encoder + (h_encoder, h_cross), cache_encoder = self.local_encoder( + tokens=local_encoder_tokens, + embeds=local_encoder_embeds, + patch_embeds=None, + cross_mask=cross_attn_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + + # Downsampling + h = h_cross.view(bs, patch_lengths.shape[1], -1) + + # Global transformer + global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(BOE_ID) + rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id) + eos_patch_ids = patch_ids[rows, cols] + global_tokens[rows, eos_patch_ids] = self.config.eos_token_id + + h, _ = self.global_transformer( + embeds=h, + tokens=global_tokens, + ) + + # Unpatching + dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :] + + # Generate decoder patch IDs + decoder_patch_ids = self._decoder_patch_ids_from_lengths(patch_lengths, nb_boe, local_decoder_tokens.shape[-1]) + assert torch.max(decoder_patch_ids) + 1 <= h.shape[1], f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" + assert decoder_patch_ids.shape[1] == dec_embeds.shape[1], ( + f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" + ) + + # Cross-attention decoder + if not self.config.cross_attn_decoder: + h = torch.gather(h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])) + cross_attn_mask_dec = None + assert local_decoder_tokens.shape == h.shape[:-1] + else: + cross_attn_mask_dec = cross_attn_mask( + decoder_patch_ids, + patch_lengths, + 
N, + patches_as_queries=False, + cross_attn_k=self.config.cross_attn_k, + window=self.config.cross_attn_window_decoder, + block_mask=self.config.cross_attn_use_flex_attention, + ) + + # Local decoder + output, _ = self.local_decoder( + embeds=dec_embeds, + patch_embeds=h, + tokens=local_decoder_tokens, + cross_mask=cross_attn_mask_dec, + ) + return output + + +class BLTPatcher(BLTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.rope_embeddings = RotaryEmbedding( + theta=config.patcher_rope_theta, + head_dim=config.patcher_head_dim or config.patcher_dim // config.patcher_n_heads, + max_seqlen=config.patcher_max_seqlen, + rope_use_fp32_in_outer_product=config.patcher_rope_use_fp32_in_outer_product, + ) + # Handle both eos_id and eos_token_id for compatibility + self.eos_id = config.patcher_eos_token_id + + # Extract additional parameters for BLTTransformerLayer + n_kv_heads = ( + getattr(config, "patcher_n_kv_heads", None) + if hasattr(config, "patcher_dim") + else getattr(config, "n_kv_heads", None) + ) + multiple_of = ( + getattr(config, "patcher_multiple_of", 256) + if hasattr(config, "patcher_dim") + else getattr(config, "multiple_of", 256) + ) + ffn_dim_multiplier = ( + getattr(config, "patcher_ffn_dim_multiplier", None) + if hasattr(config, "patcher_dim") + else getattr(config, "ffn_dim_multiplier", None) + ) + + self.layers = nn.ModuleList() + for _ in range(config.patcher_n_layers): + self.layers.append( + BLTTransformerLayer( + { + "dim": config.patcher_dim, + "n_heads": config.patcher_n_heads, + "head_dim": config.patcher_head_dim, + "n_kv_heads": n_kv_heads, + "rope_theta": config.patcher_rope_theta, + "multiple_of": multiple_of, + "ffn_dim_multiplier": ffn_dim_multiplier, + "norm_eps": config.patcher_norm_eps, + } + ) + ) + + # LMTransformer specific attributes + self.sliding_window = config.patcher_sliding_window + + assert config.patcher_vocab_size > 0 + + self.tok_embeddings = torch.nn.Embedding(config.patcher_vocab_size, config.patcher_dim) + + self.norm = RMSNorm(config.patcher_dim, eps=config.patcher_norm_eps) + + self.output = nn.Linear( + config.patcher_dim, + config.patcher_vocab_size, + bias=False, + ) + + def forward( + self, + token_values: torch.Tensor, + target: Optional[torch.Tensor] = None, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, + attn_impl: str | None = None, + patch_size: Optional[int] = None, + include_next_token: bool = True, + threshold: Optional[float] = None, + threshold_add: Optional[float] = None, + monotonicity: bool = False, + max_patch_length: Optional[int] = None, + patching_batch_size: int = 1, + device: Optional[str] = None, + enable_grad: bool = False, + ): + attn_impl = self.config.patcher_attn_impl if attn_impl is None else attn_impl + + # Handle chunked processing for entropy calculation + entropies = [] + preds = [] + max_length = min(getattr(self, "max_length", 8192), self.config.patcher_max_seqlen) + batch_numel = max_length * patching_batch_size + splits = torch.split(token_values.flatten(), batch_numel) + + for split in splits: + pad_size = (max_length - (split.numel() % max_length)) % max_length + pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False) + split = torch.cat((split, pad), dim=0) + split = split.reshape(-1, max_length) + if device is not None: + split = split.to(device) + + # Process chunk: embeddings -> layers -> output + bsz, seqlen = split.shape + h = self.tok_embeddings(split) + chunk_mask = 
create_causal_mask( + seqlen, + attn_impl, + self.config.patcher_attn_bias_type, + sliding_window=self.sliding_window, + tokens=split, + eos_id=self.eos_id, + ) + freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None) + + for i, layer in enumerate(self.layers): + h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=attn_impl) + + pred = self.output(self.norm(h)) + pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :] # [batch_size * seq_len, vocab] + preds.append(pred) + pred_entropies = self.entropy(pred) + entropies.append(pred_entropies) + + concat_entropies = torch.cat(entropies, dim=0) + concat_entropies = concat_entropies.reshape(token_values.shape) + concat_preds = torch.cat(preds, dim=0) + concat_preds = concat_preds.reshape(token_values.shape[0], -1) + + # Always compute patch lengths from concatenated entropies + bs, seq_len = token_values.shape + seq_len_next_tok = seq_len + 1 if include_next_token else seq_len + + # Find patch start IDs based on entropy + if patch_size is not None: + patch_start_ids = self.find_entropy_patch_start_ids( + concat_entropies, + patch_size, + include_next_token=include_next_token, + threshold=threshold, + threshold_add=threshold_add, + monotonicity=monotonicity, + ) + patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok) + else: + # Default to byte-level patching + patch_lengths = torch.ones((bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) + + # Apply any processing to patch lengths + if max_patch_length is not None: + # TODO: avoid going back to a list here. + patch_lengths = [self.split_large_numbers(pl, max_patch_length) for pl in patch_lengths.tolist()] + max_len = max([len(pl) for pl in patch_lengths]) + patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] + patch_lengths = torch.tensor(patch_lengths, dtype=token_values.dtype, device=token_values.device) + assert not check_non_zero_after_zero(patch_lengths) + # Find the last non-zero column index using argmax on a reversed version of the tensor + last_non_zero_col_reversed = (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() + # Slice the tensor up to the last non-zero column + patch_lengths = patch_lengths[:, : patch_lengths.shape[1] - last_non_zero_col_reversed] + + return concat_entropies, patch_lengths, concat_preds + + + + + + @staticmethod + def entropy(scores): + """ + scores: [bs, seq_len, vocab] + returns [bs, seq_len] + + Computes the entropy for each token in the batch. + Note: uses natural log. 
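+        For example (illustrative), a uniform distribution over V candidate bytes
+        yields the maximum entropy log(V), while a one-hot distribution yields 0.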
+ """ + log_probs = F.log_softmax(scores, dim=-1) + probs = torch.exp(log_probs) + p_log_p = log_probs * probs + entropy = -p_log_p.sum(dim=-1) + return entropy + + @staticmethod + def patch_start_ids_from_patch_start_mask(patch_start_mask): + bs, trunc_seq_len = patch_start_mask.shape + max_patches = patch_start_mask.sum(dim=1).max() + if max_patches == 0: + patch_start_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + else: + patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(bs, 1) + extra_patch_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) + patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1) + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, trunc_seq_len)[:, :max_patches] + return patch_start_ids + + @staticmethod + def patch_lengths_from_start_ids(patch_start_ids, seq_len): + """ + Calculate patch lengths from start ids. + start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then + the rest are filled to the seq len. + seq_len: ex: 7 length of the sequence + + returns the patch lengths: + [1, 6] for the above example. + """ + last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) + patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) + patch_lengths = patch_end_ids - patch_start_ids + 1 + assert torch.all(patch_lengths >= 0), f"{patch_lengths}" + assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}" + return patch_lengths + + @staticmethod + def find_entropy_patch_start_ids( + entropies, + patch_size=None, + threshold=None, + threshold_add=None, + monotonicity=False, + include_next_token=True, + ): + """ + Use entropies to find the start ids of each patch. + Use patch_size or threshold to figure out the total number of patches to allocate. + + When threshold is not None the number of patches is not constant between + different sequences, but patches can be identified incrementally rather than + decided globally using the entire sequence. + """ + bs, seq_len = entropies.shape[:2] + + first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(bs, 1) + preds_truncation_len = first_ids.shape[1] # remove the first preds because they will be start of patches. 
+ entropies = entropies[:, 1:] + if threshold is None: + num_patches = seq_len // patch_size + patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices + patch_start_ids = patch_start_ids.sort(dim=1).values + else: + patch_start_mask = entropies > threshold + if not include_next_token: + patch_start_mask = patch_start_mask[:, :-1] + # patch_start_mask[1:] |= tokens[:-1] < OFFSET + patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) + + patch_start_ids = torch.cat((first_ids, patch_start_ids + preds_truncation_len), dim=1) + return patch_start_ids + + @staticmethod + def split_large_numbers(lst, m): + new_lst = [] + for i in lst: + if i > m: + while i > m: + new_lst.append(m) + i -= m + new_lst.append(i) + else: + new_lst.append(i) + assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" + return new_lst + + +def init_hash_embeddings( + config, + local_encoder_dim: int, + encoder_hash_byte_group_size: list, +): + """Initialize hash-based token embeddings for the BLT encoder.""" + if config.encoder_hash_byte_group_size is None: + return None + + embeddings = [] + emb_dim = local_encoder_dim + encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab + + for _ in range(config.encoder_hash_byte_group_nb_functions): + for _ in encoder_hash_byte_group_size: + embeddings.append( + nn.Embedding( + encoder_hash_byte_group_vocab, + emb_dim, + ) + ) + + return nn.ModuleList(embeddings) + + +__all__ = [ + "BLTPreTrainedModel", + "BLTModel", + "BLTPatcher", + "LocalEncoder", + "LocalDecoder", + "GlobalTransformer", +] \ No newline at end of file diff --git a/backup_blt_wip_backup/modeling_blt_wip_backup.py b/backup_blt_wip_backup/modeling_blt_wip_backup.py new file mode 100644 index 0000000000000000000000000000000000000000..adc4104dcbebb38e3a72866ade92eba569f3df3c --- /dev/null +++ b/backup_blt_wip_backup/modeling_blt_wip_backup.py @@ -0,0 +1,2166 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +from enum import Enum +from typing import Any, List, Optional, Tuple, Union + +import torch +from pydantic import model_validator +from torch import nn +from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention +import json +import logging + +import torch +import torch.nn +import torch.nn as nn +from torch.nn import functional as F + +import os +from contextlib import nullcontext + +SEP = " " +BOS_ID: int = 1 +EOS_ID: int = 2 +PAD_ID: int = -1 +BOE_ID: int = 0 +BPE_ID: int = 3 +OFFSET: int = 4 + +BYTE_UNITS: int = 256 + +RMSNorm = nn.RMSNorm + +logger = logging.getLogger() + +from .configuration_blt import ( + BLTConfig, + PatchingModeEnum, + InitStdFactor, +) + +from ...modeling_utils import PreTrainedModel +from ...utils import logging as transformers_logging + +flex_attention_comp = flex_attention + + +def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + +def create_causal_mask( + seqlen, + attn_impl: str, + attn_bias_type: str | None, + *, + eos_id: int | None = None, + tokens: torch.Tensor | None = None, + sliding_window: int | None = None, +): + if attn_impl == "sdpa": + BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) + + if attn_bias_type == "causal": + return "causal" + + if BLT_SUPPRESS_ATTN_ERROR == 1: + return "causal" + else: + raise ValueError( + "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. 
To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1" + ) + elif attn_impl == "flex_attention": + return create_block_mask(causal_mask, None, None, seqlen, seqlen) + else: + raise NotImplementedError( + f"Attention {attn_impl} with {sliding_window} sliding window not implemented" + ) + +def cross_entropy(pred, target, **kwargs): + return F.nll_loss( + F.log_softmax(pred.flatten(end_dim=-2).float(), -1), + target.flatten(end_dim=-1), + **kwargs, + ) + + +def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +def precompute_freqs_cis( + dim: int, + end: int, + theta: float = 10000.0, + rope_use_fp32_in_outer_product: bool = False, +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + if rope_use_fp32_in_outer_product: + t = t.to(torch.float32) + + freqs = torch.outer(t, freqs).float() + + cos, sin = freqs.cos(), freqs.sin() + + return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + seq_dim (int): Sequence dimension index. + + Returns: + torch.Tensor: Reshaped frequency tensor. 
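+    For example (illustrative): with freqs_cis of shape (seq_len, head_dim // 2, 2, 2)
+    and x of shape (bsz, seq_len, n_heads, head_dim // 2, 1, 2), the result has shape
+    (1, seq_len, 1, head_dim // 2, 2, 2) so it broadcasts over batch and heads.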
+ """ + ndim = x.ndim + assert 0 <= seq_dim < ndim + assert freqs_cis.shape == ( + x.shape[seq_dim], + x.shape[-3], + 2, + 2, + ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" + shape = [ + d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2]) + ] + [2, 2] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + seq_dim: int, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 + freqs_cis = reshape_for_broadcast( + freqs_cis, xq_, seq_dim + ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 + xq_out = (xq_ * freqs_cis).sum(5).flatten(3) + xk_out = (xk_ * freqs_cis).sum(5).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +# Rotary embedding as in xformer, see if torchtrain implementation is not better. Also might be usefull to make it work with batch*seqlen collapsed. +class RotaryEmbedding(torch.nn.Module): + """ + RotaryEmbedding Module + """ + + def __init__( + self, + theta: float, + head_dim: int, + max_seqlen: int = 1024, + rope_use_fp32_in_outer_product: bool = False, + ): + super().__init__() + + self.theta = theta + self.head_dim = head_dim + self.max_seqlen = max_seqlen + self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + dim=head_dim, + end=max_seqlen, + theta=theta, + rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, + ), + persistent=False, + ) + + def reset_parameters(self): + self.freqs_cis[...] = precompute_freqs_cis( + dim=self.head_dim, + end=self.max_seqlen, + theta=self.theta, + rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, + ) + + def forward( + self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None + ): + """ + Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions + Args: + seqlen (int): Contiguous sequence length + tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen + + Returns: + Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis + """ + test = (seqlen is not None) or (tok_idx is not None) + assert test, "Should provide atleast seqlen or tok_idx" + if tok_idx is not None: + return self.freqs_cis[tok_idx] + elif seqlen is not None: + return self.freqs_cis[0:seqlen] + + +class BLTAttention(nn.Module): + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + rope_theta: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + self.rope_theta = rope_theta + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + # B S D + bsz, seq_len, dim = x.shape + xq = self.wq(x.view_as(x)) + xk = self.wk(x.view_as(x)) + xv = self.wv(x.view_as(x)) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, 
seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) + + # This condition helps us be easily compatible + # with inference by adding a pluggable KVCache + if hasattr(self, "kv_cache"): + xk, xv = self.kv_cache.update(xk, xv, tok_idx) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + if attn_impl == "flex_attention": + assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + output = flex_attention_comp(xq, xk, xv, block_mask=mask) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + elif attn_impl == "sdpa": + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + assert mask is None or isinstance(mask, (str, torch.Tensor)) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + else: + raise NotImplementedError( + f"Attention implementation {attn_impl} not supported" + ) + + output_reshaped = output.reshape(output_shape) + + output = self.wo(output_reshaped) + + return output + + def reset_parameters(self, init_std=None, factor=1.0): + init_std = init_std or (self.dim ** (-0.5)) / factor + + for w in [self.wq, self.wk, self.wv]: + nn.init.trunc_normal_( + w.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + nn.init.trunc_normal_( + self.wo.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + +class BLTMLP(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + mp_size: int = 1, + ): + super().__init__() + + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + assert hidden_dim % mp_size == 0 + + self.dim = dim + self.hidden_dim = hidden_dim + + self.w1 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w3 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + self.w2 = nn.Linear( + hidden_dim, + dim, + bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # B S D + x1 = self.w1(x.view_as(x)) + x3 = self.w3(x.view_as(x)) + output = self.w2(F.silu(x1) * x3) + return output + + def reset_parameters(self, init_std=None, factor=1.0): + in_init_std = init_std or (self.dim ** (-0.5)) / factor + out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor + + nn.init.trunc_normal_( + self.w1.weight, + mean=0.0, + std=in_init_std, + a=-3 * in_init_std, + b=3 * in_init_std, + ) + nn.init.trunc_normal_( + self.w2.weight, + mean=0.0, + std=out_init_std, + a=-3 * out_init_std, + b=3 * out_init_std, + ) + nn.init.trunc_normal_( + self.w3.weight, + mean=0.0, + std=in_init_std, + a=-3 * in_init_std, + b=3 * in_init_std, + ) + + +class BLTTransformerLayer(nn.Module): + def __init__(self, args): + super().__init__() + + # Extract parameters from dictionary + dim = args['dim'] + n_heads = args['n_heads'] + head_dim = args['head_dim'] + n_kv_heads = args['n_kv_heads'] + rope_theta = args['rope_theta'] + multiple_of = args['multiple_of'] + 
ffn_dim_multiplier = args['ffn_dim_multiplier'] + norm_eps = args['norm_eps'] + + assert (head_dim is not None) or ( + n_heads is not None + ), "Should specify at least head_dim or n_heads" + self.head_dim = head_dim or dim // n_heads + self.n_heads = n_heads or dim // head_dim + self.n_kv_heads = n_kv_heads or self.n_heads + + assert n_heads % self.n_kv_heads == 0 + assert dim % n_heads == 0 + + self.attention = BLTAttention( + dim=dim, + head_dim=self.head_dim, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + rope_theta=rope_theta, + ) + self.feed_forward = BLTMLP( + dim=dim, + hidden_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + self.attention_norm = RMSNorm(dim, eps=norm_eps) + self.ffn_norm = RMSNorm(dim, eps=norm_eps) + + def forward( + self, + x: torch.Tensor, + freq_cis: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, str]] = None, + attn_impl: str = "sdpa", + ) -> torch.Tensor: + norm_x = self.attention_norm(x) + attn_out = self.attention( + norm_x, + freq_cis, + tok_idx=tok_idx, + mask=mask, + attn_impl=attn_impl, + ) + h = x + attn_out + h_norm = self.ffn_norm(h) + out = h + self.feed_forward(h_norm) + return out + + def init_weights(self, init_std=None, factor=1.0): + self.attention.reset_parameters(init_std, factor) + self.attention_norm.reset_parameters() + + self.feed_forward.reset_parameters(init_std, factor) + self.ffn_norm.reset_parameters() + + +def rightpad(seq, pad_id, max_len): + return seq + [pad_id] * (max_len - len(seq)) + + +def check_non_zero_after_zero(tensor): + zero_mask = tensor == 0 + shifted_mask = torch.cat( + [ + torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device), + zero_mask[:, :-1], + ], + dim=1, + ) + non_zero_after_zero = (tensor != 0) & shifted_mask + return non_zero_after_zero.any() + + +def fill_tokens(tokens, patch_size, fill_id): + batch_size, seq_len = tokens.shape + if seq_len % patch_size == 0: + return tokens + else: + remaining = patch_size - seq_len % patch_size + final_padding = tokens.new(batch_size, remaining).fill_(fill_id) + return torch.cat((tokens, final_padding), dim=1) + + +def rolling_polynomial_hash(t, hash_func_nb: int = 0): + primes = [ + 1000000007, + 5915587277, + 1500450271, + 3267000013, + 5754853343, + 4093082899, + 9576890767, + 3628273133, + 2860486313, + 5463458053, + 3367900313, + ] + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) + prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) + return torch.sum(t * prime_powers, dim=-1) + +def byte_group_hash_function( + x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000 +): + """ + Returns a hash of the input x and maps it to a value in the range [0, max_hash]. + + expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. + returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. + + Note: max hash can make a big difference on the number of collisions. 
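+    For example (illustrative), with group_size=2 the value at position i is a rolling
+    polynomial hash of the byte pair (x[i-1], x[i]) (the sequence is left-padded with
+    zeros), taken modulo max_hash.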
+ """ + with torch.no_grad(): + bs, seq_len = x.shape + prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) + x = torch.cat([prefix, x], dim=1) + windows = x.unfold(1, group_size, 1) + # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows) + hashes = rolling_polynomial_hash(windows, hash_func_nb) + hash_values_range = hashes % max_hash + hash_values_range.requires_grad = False + return hash_values_range + + +def create_patch_mask_from_ids( + patch_ids, num_patches, window=None, patches_as_queries=False +): + """ + Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) + is True if the patch id at position (i, j) is less than or equal to k. + Args: + patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids. + num_patches (int): Total number of patches. + window (int): If not None, only considers patches within a window of size window. + patches_as_queries (bool): If True, the patches are used as queries + Returns: + torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask. + """ + bs, seq_len = patch_ids.shape + if not patches_as_queries: + q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches) + kv_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(0) + .expand(bs, seq_len, num_patches) + ) + else: + kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len) + q_ids = ( + torch.arange(num_patches, device=patch_ids.device) + .unsqueeze(0) + .unsqueeze(-1) + .expand(bs, num_patches, seq_len) + ) + if window is None: + mask = q_ids == kv_ids + else: + mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window) + return mask + + +def cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=1, + window=None, + block_mask=True, +): + bs = patch_ids.shape[0] + with torch.no_grad(): + # Create the patch mask + cross_mask = create_patch_mask_from_ids( + patch_ids, + patch_lengths.shape[1], + window=window, + patches_as_queries=patches_as_queries, + ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1) + q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N + kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k + assert cross_mask.shape == ( + bs, + q_len, + kv_len, + ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" + block_mask = None + if block_mask: + + def patch_mask(b, h, q_idx, kv_idx): + return cross_mask[b, q_idx, kv_idx] + + block_mask = create_block_mask( + patch_mask, + B=bs, + H=None, + Q_LEN=q_len, + KV_LEN=kv_len, + _compile=True, + ) + return block_mask + else: + return torch.where( + cross_mask, torch.tensor(0.0), torch.tensor(float("-inf")) + ).unsqueeze( + 1 + ) # [bs, 1, q_len, kv_len] + + +def get_blt_input( + tokens: torch.Tensor, + enforce_patch_size_multiple: bool, + nb_boe: torch.Tensor, + patch_size: int, + boe_id: int, +): + """ + This function returns X_et, X_gt and X_dt, the encoder, global, and decoder + tokens respectively. + + Consider the input and target sequences: + X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13] + Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14] + with patch_size=4 + + Note 1: that there will be no special tokens introduced at the patch level. 
+ Note 2: X_e needs to be trimmed to be passed to Global + + Current without boe: + X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] + X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch + X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] + Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] + + --> lag fix: + X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]] + X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]] + X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]] + Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]] + + Dynamic (current): + X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos] + Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] + + entropy patching: + input: 7, bos, 9, 10 + pred (high entropy): eos, 8, 10, eos + + X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos] + X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]] + X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]] + Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11] + + --> lag fix no boe (force single byte first patch): + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch + X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] + Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] + + input: 4, 7, bos, 9, 10 + pred (high entropy): 5, eos, 8, 10, eos + + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch + X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]] + Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13] + + Handle the last byte properly. + patch_lengths = [1, 1, 3, 2, 2 1 2 2 1] + X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12] + X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch + X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]] + Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]] + + + bpe delim + X_et = [[3,4,5,6,7,,eos,bos,,8,9,,10,,eos,bos,11,12] + X_g = [[3], [4,5,6,7,], [eos,bos,], .. + X_dt = [[3,4,5,6,7], [,eos,bos], [,bos,8], .. + Y = [4,5,6,7,, eos,bos, 8,9,, .. + + + Note 1: that there will be no special tokens introduced at the patch level. + Note 2: X_e needs to be trimmed to be passed to Global + """ + batch_size, seq_len = tokens.shape + local_encoder_tokens = tokens + local_decoder_tokens = tokens + + if nb_boe > 0: + padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id) + local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1) + # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id) + + # create global tokens, contains boe tokens and eos + # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size) + # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:] + # global_tokens += global_tokens.eq(0).int() * boe_id + # TODO: fix this when we want to use block causal in the global. 
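+    # The global token stream is not constructed here: BLTModel.forward builds it from
+    # the patch ids (a boe-filled tensor with eos copied into the patches containing it),
+    # so None is returned in its place below.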
+ + if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0: + local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id) + + return local_encoder_tokens, None, local_decoder_tokens + + +class LocalModelBase(nn.Module): + def __init__(self, config: BLTConfig, component_type: str = "encoder"): + super().__init__() + + # Store config for later use + self.config = config + + # Use component-specific dimensions + if component_type == "encoder": + self.dim = config.dim_local_encoder + self.n_layers = config.n_layers_local_encoder + self.n_heads = config.n_heads_local_encoder + self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen + self.attn_bias_type = "local_block_causal" + self.sliding_window = config.local_attention_window_len + elif component_type == "decoder": + self.dim = config.dim_local_decoder + self.n_layers = config.n_layers_local_decoder + self.n_heads = config.n_heads_local_decoder + self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen + self.attn_bias_type = "local_block_causal" + self.sliding_window = config.local_attention_window_len + else: + raise ValueError(f"Unknown component_type: {component_type}") + + self.dropout = config.dropout + self.vocab_size = config.vocab_size + config.pm_size + self.patch_size = config.patch_size + + self.attn_impl = config.attn_impl + self.use_rope = config.use_rope + self.init_std_factor = config.init_std_factor + self.init_base_std = config.init_base_std + self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None) + self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None) + self.cross_attn_k = getattr(config, "cross_attn_k", None) + self.eos_id = config.eos_token_id + + self.boe_id = BOE_ID + + # Initialize cross attention layers as None (will be set by subclasses if needed) + self.cross_attn_layers = None + + # Create parameter dict for BLTTransformerLayers + layer_params = { + 'dim': self.dim, + 'n_heads': self.n_heads, + 'head_dim': config.head_dim, + 'n_kv_heads': getattr(config, 'n_kv_heads', None), + 'rope_theta': config.rope_theta, + 'multiple_of': getattr(config, 'multiple_of', 256), + 'ffn_dim_multiplier': getattr(config, 'ffn_dim_multiplier', None), + 'norm_eps': config.norm_eps, + } + + self.layers = nn.ModuleList( + [BLTTransformerLayer(layer_params) for _ in range(self.n_layers)] + ) + + if not self.use_rope: + self.pos_embeddings = nn.Embedding(2048, self.dim) # fallback max_length + else: + self.rope = RotaryEmbedding( + theta=config.rope_theta, + head_dim=config.head_dim or self.dim // self.n_heads, + max_seqlen=self.max_seqlen, + rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, + ) + self.pos_embeddings = None + + # Set dimension-specific embedding dimensions + if component_type == "encoder": + self.dim_token_emb = config.encoder_dim_token_emb + self.dim_patch_emb = config.encoder_dim_patch_emb + elif component_type == "decoder": + self.dim_token_emb = config.decoder_dim_token_emb + self.dim_patch_emb = config.dim_global + + self.token_embedding_projection = ( + nn.Linear(self.dim_token_emb, self.dim, bias=False) + if self.dim_token_emb is not None and self.dim_token_emb != self.dim + else None + ) + + self.patch_embedding_projection = self._create_patch_projection(config) + + def _should_create_patch_projection(self, config: BLTConfig): + dimension_mismatch = ( + self.dim_patch_emb is not None and self.dim_patch_emb != self.dim + ) + + # Check cross attention conditions + cross_attn_conditions = ( + 
config.cross_attn_encoder and config.cross_attn_init_by_pooling + ) or (config.cross_attn_decoder and config.cross_attn_init_by_pooling) + + return dimension_mismatch or cross_attn_conditions + + def _create_patch_projection(self, config): + if not self._should_create_patch_projection(config): + return None + + output_dim = self.dim_token_emb * (self.cross_attn_k or 1) + + return nn.Linear( + in_features=self.dim_patch_emb, + out_features=output_dim, + bias=False, + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + return embeds + else: + return self.tok_embeddings(tokens) + + def init_weights(self, init_std=None): + self.rope.reset_parameters() + if hasattr(self, "norm"): + self.norm.reset_parameters() + + init_std = init_std or (self.dim ** (-0.5)) + if hasattr(self, "tok_embeddings"): + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + if self.pos_embeddings is not None: + nn.init.trunc_normal_( + self.pos_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + for depth, layer in enumerate(self.layers): + factor = self.config.get_init_std_factor(depth) + layer.init_weights(self.init_base_std, factor) + + if hasattr(self, "output"): + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + if self.token_embedding_projection is not None: + nn.init.trunc_normal_( + self.token_embedding_projection.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + if self.patch_embedding_projection is not None: + patch_emb_std = self.dim_patch_emb ** (-0.5) + nn.init.trunc_normal_( + self.patch_embedding_projection.weight, + mean=0.0, + std=patch_emb_std, + a=-3 * patch_emb_std, + b=3 * patch_emb_std, + ) + + if self.cross_attn_layers is not None: + for depth, layer in enumerate(self.cross_attn_layers): + factor = self.config.get_init_std_factor(depth) + layer.init_weights(None, factor) + + +class LocalEncoder(LocalModelBase): + def __init__(self, config: BLTConfig): + super().__init__(config, component_type="encoder") + + self.apply_transformer = config.use_local_encoder_transformer + self.downsampling_by_pooling = config.downsampling_by_pooling + self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None + self.cross_attn_encoder = config.cross_attn_encoder + self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder + self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_nheads = config.cross_attn_nheads + + self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim) + + if self.cross_attn_encoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=config.norm_eps, + ) + ) + + def apply_embedding(self, tokens, embeds): + if embeds is not None: + assert ( + self.expects_hash_embeddings + ), "Not expecting embeddings to be passed." 
+ return embeds + else: + return self.tok_embeddings(tokens) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor] = None, + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + num_patches: Optional[int] = None, + patch_ids: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + """ """ + bs, seqlen = tokens.shape + if mask is None: + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) + + h = self.apply_embedding(tokens, embeds) + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + + for i, layer in enumerate(self.layers): + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) + # check if cross attention should be applied to either all layer or only the last layer + if self.cross_attn_encoder and ( + i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder + ): + # apply pooling and project + if self.cross_attn_init_by_pooling and patch_embeds is None: + patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids) + if self.patch_embedding_projection is not None: + patch_embeds = self.patch_embedding_projection(patch_embeds) + patch_embeds = patch_embeds.reshape( + bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim + ) + + layer_idx = i if self.cross_attn_all_layers_encoder else 0 + patch_embeds_cross = self.cross_attn_layers[layer_idx]( + x=patch_embeds, + kv=h, + mask=cross_mask, + ) + patch_embeds = patch_embeds + patch_embeds_cross + + h_residual = patch_embeds if self.cross_attn_encoder else None + return (h, h_residual), cache + + + + def patch_reduce(self, h, max_num_patches, reduction, patch_ids): + """ + Reduce variable length patches to single embedding per patch + Note: this works with variable number of patches for different sequences in the batch + It handles variable length patches by assuming that patch_lengths will be 0 for any + extra patches on the *right*. Since there can be a variable number of patches + this function also return the number of patches for each sequence in the batch. + Any embeddings on the right that are not allocated to a patch + (i.e. if the sum(patch_lengths[i]) < seq_len for any i) + will be sent to a dummy patch, which is trimmed before returning. 
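+
+        For example, with patch_ids = [[0, 0, 1, 1, 1]] and reduction="amax"
+        (the reduction used by the encoder), patch 0 becomes the element-wise max
+        over positions 0-1 and patch 1 the element-wise max over positions 2-4.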
+ """ + bs, seq_len, emb_dim = h.shape + + patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + + reduced_embs = torch.zeros( + (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device + ) + reduced_embs = reduced_embs.scatter_reduce( + src=h, + dim=1, + index=patch_ids, + reduce=reduction, + include_self=False, + ) + reduced_embs = reduced_embs[:, :max_num_patches, :] + + return reduced_embs + + +class LocalDecoder(LocalModelBase): + def __init__(self, config: BLTConfig): + super().__init__(config, component_type="decoder") + + # Model configuration flags + self.cross_attn_decoder = config.cross_attn_decoder + self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder + self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling + self.cross_attn_nheads = config.cross_attn_nheads + + self.norm = RMSNorm(self.dim, eps=config.norm_eps) + + if self.cross_attn_decoder: + self.cross_attn_layers = torch.nn.ModuleList() + layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1 + for _ in range(layers_to_add): + self.cross_attn_layers.append( + BLTCrossAttention( + dim=self.dim, + head_dim=self.dim // self.cross_attn_nheads, + n_heads=self.cross_attn_nheads, + n_kv_heads=self.cross_attn_nheads, + norm_eps=config.norm_eps, + ) + ) + + self.output = nn.Linear( + self.dim, + config.vocab_size, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + embeds: Optional[torch.Tensor], + patch_embeds: Optional[torch.Tensor] = None, + mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, + cross_mask: Optional[torch.Tensor] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + bs, seqlen = tokens.shape + assert embeds is not None, "Embeddings must be provided" + + if mask is None: + mask = create_causal_mask( + seqlen, + self.attn_impl, + "local_block_causal", + sliding_window=self.sliding_window, + tokens=tokens, + eos_id=self.eos_id, + ) + + h = embeds + + if self.patch_embedding_projection is not None: + assert patch_embeds is not None, "Patch embeddings must be passed." + patch_embeds = self.patch_embedding_projection(patch_embeds) + if self.cross_attn_k is not None: + patch_embeds = patch_embeds.reshape( + bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim + ) + + if patch_embeds is not None and not self.cross_attn_decoder: + h = h + patch_embeds + + freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None + + h = F.dropout(h, p=self.dropout, training=self.training) + for i, layer in enumerate(self.layers): + if self.cross_attn_decoder and ( + i == 0 or self.cross_attn_all_layers_decoder + ): + # Use cross attention to extract info from patch_embeds into h + h_cross = self.cross_attn_layers[i]( + x=h, + kv=patch_embeds, + mask=cross_mask, + ) + h = h + h_cross + + h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl) + + h_preds = self.norm(h) + h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) + h_preds = self.output(h_preds) + h_preds = h_preds.float() + return h_preds, cache + + +class BLTCrossAttention(nn.Module): + """ + BLTCrossAttention block to attend to the encoder states from the decoder. + Rope is not supported. 
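+
+    Queries come from one stream (patch or byte representations) and keys/values
+    from the other; key/value heads are repeated to match the query heads
+    (grouped-query attention) and attention is computed with
+    torch.nn.functional.scaled_dot_product_attention. The output is added
+    residually to the query input.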
+ """ + + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + ): + super().__init__() + + self.dim = dim + self.head_dim = head_dim + + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.heads_per_group = self.n_heads // self.n_kv_heads + + self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps) + self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) + + self.wq = nn.Linear( + dim, + n_heads * head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + n_kv_heads * head_dim, + bias=False, + ) + + self.wo = nn.Linear( + n_heads * head_dim, + dim, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + kv: torch.Tensor, + mask: Optional[Union[BlockMask, str]] = None, + ) -> torch.Tensor: + # B S D + bsz, seq_len, _ = x.shape + _, slen_kv, _ = kv.shape + x_norm = self.cross_attn_norm_q(x) + kv = self.cross_attn_norm_kv(kv) + + xq = self.wq(x_norm) + xk = self.wk(kv) + xv = self.wv(kv) + + output_shape = xq.shape + # B S D -> B S H D + xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) + xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) + + xk = repeat_kv(xk, self.heads_per_group, dim=2) + xv = repeat_kv(xv, self.heads_per_group, dim=2) + + # assert mask is None or isinstance(mask, BlockMask) + xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv)) + #output = flex_attention_comp(xq, xk, xv, block_mask=mask) + is_causal = (mask == "causal") if isinstance(mask, str) else False + mask = mask if isinstance(mask, torch.Tensor) else None + mask = mask.to(dtype=xq.dtype).to(xq.device) + output = F.scaled_dot_product_attention( + xq, + xk, + xv, + is_causal=is_causal, + attn_mask=mask, + ) + output = output.transpose(1, 2).contiguous() # B H S D -> B S H D + + output = self.wo(output.reshape(output_shape)) + + return x + output + + def init_weights(self, base_std: float, factor: float = 1.0): + std = base_std or (self.dim ** (-0.5)) / factor + + nn.init.trunc_normal_( + self.wq.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wk.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wv.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + + nn.init.trunc_normal_( + self.wo.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + self.cross_attn_norm_q.reset_parameters() + self.cross_attn_norm_kv.reset_parameters() + + +class GlobalTransformer(nn.Module): + def __init__(self, config): + super().__init__() + + # Store config for later use + self.config = config + + self.dim = config.dim + self.init_base_std = config.init_base_std + self.attn_impl = config.attn_impl + self.attn_bias_type = config.attn_bias_type + self.init_std_factor = config.init_std_factor + self.max_seqlen = config.max_seqlen + self.rope_embeddings = RotaryEmbedding( + theta=config.rope_theta, + head_dim=config.head_dim or config.dim // config.n_heads, + max_seqlen=config.max_seqlen, + rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, + ) + # Handle both eos_id and eos_token_id for compatibility + self.eos_id = getattr(config, 'eos_id', getattr(config, 'eos_token_id', 2)) + + # Create parameter dict for BLTTransformerLayers + layer_params = { + 'dim': self.dim, + 'n_heads': config.n_heads, + 'head_dim': config.head_dim, + 'n_kv_heads': getattr(config, 'n_kv_heads', None), + 'rope_theta': 
config.rope_theta, + 'multiple_of': getattr(config, 'multiple_of', 256), + 'ffn_dim_multiplier': getattr(config, 'ffn_dim_multiplier', None), + 'norm_eps': config.norm_eps, + } + + self.layers = nn.ModuleList() + for _ in range(config.n_layers): + self.layers.append(BLTTransformerLayer(layer_params)) + + # GlobalTransformer specific attributes + self.dropout = config.dropout + self.dim_token_emb = config.dim_token_emb + + self.token_embedding_projection = None + if config.dim_token_emb is not None and config.dim_token_emb != self.dim: + self.token_embedding_projection = nn.Linear( + config.dim_token_emb, + config.dim, + bias=False, + ) + + def forward( + self, + tokens: torch.Tensor, + tok_idx: Optional[torch.Tensor] = None, + embeds: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, + cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, + ): + bs, seqlen = tokens.shape + + h = embeds + + mask = ( + mask + if mask is not None + else create_causal_mask( + seqlen, + self.attn_impl, + self.attn_bias_type, + tokens=tokens, + eos_id=self.eos_id, + ) + ) + + if self.token_embedding_projection is not None and h.shape[-1] != self.dim: + h = self.token_embedding_projection(h) + + h = F.dropout(h, p=self.dropout, training=self.training) + + freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx) + + for i, layer in enumerate(self.layers): + h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl) + + return h, cache + + def init_weights(self): + self.rope_embeddings.reset_parameters() + for depth, layer in enumerate(self.layers): + factor = self.config.get_init_std_factor(depth) + layer.init_weights(self.init_base_std, factor) + + # GlobalTransformer specific initialization + std = self.dim_token_emb ** (-0.5) + if self.token_embedding_projection is not None: + nn.init.trunc_normal_( + self.token_embedding_projection.weight, + mean=0.0, + std=std, + a=-3 * std, + b=3 * std, + ) + +def compute_hash_embeddings( + local_encoder_tokens: torch.Tensor, + local_encoder, + encoder_hash_tok_embedding: nn.ModuleList, + encoder_hash_byte_group_nb_functions: int, + encoder_hash_byte_group_size: list, + encoder_hash_byte_group_vocab: int, +) -> torch.Tensor: + """ + Compute embeddings using hash token embeddings. 
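+
+    The byte embeddings from local_encoder.tok_embeddings are augmented by adding,
+    for every (hash function, byte-group size) pair, an embedding looked up with a
+    hash of that byte group, so the number of tables in encoder_hash_tok_embedding
+    must equal encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size).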
+ + Args: + local_encoder_tokens: Input tokens tensor + local_encoder: Encoder object with tok_embeddings method + encoder_hash_tok_embedding: ModuleList of hash token embeddings + encoder_hash_byte_group_nb_functions: Number of hash functions + encoder_hash_byte_group_size: List of byte group sizes + encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings + + Returns: + torch.Tensor: Combined embeddings + """ + if encoder_hash_tok_embedding is None: + return None + + local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens) + + i = 0 + for func_nb in range(encoder_hash_byte_group_nb_functions): + for byte_group_size in encoder_hash_byte_group_size: + hash_ids = byte_group_hash_function( + local_encoder_tokens, + byte_group_size, + hash_func_nb=func_nb, + max_hash=encoder_hash_byte_group_vocab, + ) + hash_tok_embedding = encoder_hash_tok_embedding[i] + local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids) + i += 1 + + assert i == len(encoder_hash_tok_embedding) + return local_encoder_embeds + + +class BLTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + BLT models. + + This class provides the interface for model loading, saving, and weight initialization for all BLT model variants. + It inherits from [`PreTrainedModel`] which provides the core functionality for working with HuggingFace models. + + Args: + config ([`BLTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + """ + + config_class = BLTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BLTTransformerLayer", "LocalEncoder", "LocalDecoder", "GlobalTransformer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = False # BLT uses its own attention implementation + _supports_sdpa = True + _supports_cache_class = False + + def _init_weights(self, module): + """Initialize the weights - this is called by PreTrainedModel but we delegate to our custom init""" + # Don't do anything here - we use the custom init_weights method instead + pass + + +class BLTModel(BLTPreTrainedModel): + """ + The BLTModel (BLT) is a byte-level language model architecture that processes byte sequences + by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers, + and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for + improved performance and inference efficiency. 
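+
+    At a high level, byte embeddings (optionally augmented with hash n-gram
+    embeddings) are processed by the local encoder, pooled into patch
+    representations that the larger global transformer refines, and then
+    unpooled by the local decoder, which predicts the next byte. When
+    patch_in_forward is enabled, patch boundaries are computed on the fly from
+    the next-byte entropies of a small patcher language model (BLTPatcher).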
+ """ + + def __init__(self, config: BLTConfig): + super().__init__(config) + + # Store config reference + self.config = config + + # Create main components - they will read their parameters from config + self.local_encoder = LocalEncoder(config) + + # Create global-specific config by copying config and overriding dimensions + global_config = type(config)(**config.to_dict()) + global_config.dim = config.dim_global + global_config.n_layers = config.n_layers_global + global_config.n_heads = config.n_heads_global + global_config.n_kv_heads = config.n_kv_heads_global + global_config.dim_token_emb = config.global_dim_patch_emb + + self.global_transformer = GlobalTransformer(global_config) + self.local_decoder = LocalDecoder(config) + + # Initialize hash embeddings + self.encoder_hash_tok_embedding = init_hash_embeddings( + config, + local_encoder_dim=self.local_encoder.dim, + encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, + ) + + # Initialize patcher if needed + if config.patch_in_forward: + if config.realtime_patching and config.entropy_model_checkpoint_dir is not None: + # Load entropy model directly + entropy_model_checkpoint_dir = config.entropy_model_checkpoint_dir + + if not os.path.exists(entropy_model_checkpoint_dir): + raise FileNotFoundError(f"Entropy model checkpoint directory not found: {entropy_model_checkpoint_dir}") + + # Load entropy model parameters + params_path = os.path.join(entropy_model_checkpoint_dir, "params.json") + if not os.path.exists(params_path): + raise FileNotFoundError(f"params.json not found in: {entropy_model_checkpoint_dir}") + + with open(params_path) as fr: + reloaded = json.loads(fr.read()) + + torch.set_default_dtype(torch.bfloat16) + model_params = reloaded["entropy_model"] + logger.warning( + "Update checkpoint to load attn and sliding window args from checkpoint" + ) + + # Override patcher configuration with actual entropy model parameters from checkpoint + config.patcher_dim = model_params["dim"] + config.patcher_n_layers = model_params["n_layers"] + config.patcher_n_heads = model_params["n_heads"] + config.patcher_max_seqlen = model_params["max_seqlen"] + config.patcher_ffn_dim_multiplier = model_params["ffn_dim_multiplier"] + config.patcher_vocab_size = model_params["vocab_size"] + # Use sensible defaults for parameters not in checkpoint + config.patcher_attn_bias_type = "local_block_causal" + config.patcher_attn_impl = "sdpa" # originally xformers + config.patcher_sliding_window = 512 + + # BLTPatcher will extract patcher_ parameters from config directly + self.patcher = BLTPatcher(config) + + state_path = os.path.join( + entropy_model_checkpoint_dir, "consolidated.pth" + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.patcher.load_state_dict( + torch.load(state_path, map_location=device)["model"], strict=False + ) + self.patcher.to(device) + self.patcher = self.patcher.eval() + # no grads for the model: + for param in self.patcher.parameters(): + param.requires_grad = False + else: + self.patcher = None + + # Initialize weights and apply final processing + self.post_init() + + def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: + """ + Convert patch lengths to patch IDs for each token position. + + For each token position in the sequence, determines which patch it belongs to. 
+ + Args: + patch_lengths: [batch_size, num_patches] - length of each patch + seq_len: total sequence length + + Returns: + patch_ids: [batch_size, seq_len] - patch index for each token position + + Example: + patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1 + seq_len = 10 + Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]] + # pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3 + """ + batch_size, num_patches = patch_lengths.shape + + # Create patch start positions: [0, 3, 5, 9] for the example above + patch_starts = torch.cat([ + torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), + patch_lengths.cumsum(dim=-1)[:, :-1] # cumsum without the final total + ], dim=-1) + + # For each token position, find which patch it belongs to + # by finding the rightmost patch start that's <= the position + token_positions = torch.arange(seq_len, device=patch_lengths.device) # [0, 1, 2, ..., seq_len-1] + + # Broadcasting: patch_starts[batch, patch] <= token_positions[position] + # Result: [batch, seq_len, num_patches] where result[b,t,p] = True if patch p starts <= position t + position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1) + + # Count how many patch starts are <= each position, then subtract 1 to get patch index + patch_ids = position_ge_patch_start.sum(dim=-1) - 1 + + return patch_ids + + def _decoder_patch_ids_from_lengths(self, patch_lengths: torch.Tensor, nb_boe: int, seq_len: int) -> torch.Tensor: + """ + Create decoder patch IDs by skipping the first encoder patch. + + The decoder starts after the first patch (which contains BOE tokens), + so we need to map decoder positions to the remaining patches. + + Args: + patch_lengths: [batch_size, num_patches] from encoder + nb_boe: number of beginning-of-example tokens in first patch + seq_len: decoder sequence length + + Returns: + decoder_patch_ids: [batch_size, seq_len] mapping decoder positions to patch indices + """ + # Decoder uses patches 1,2,3,... (skipping patch 0 which contains BOE tokens) + decoder_patch_lengths = patch_lengths[:, 1:] + + # Create patch IDs for the decoder sequence using the remaining patches + return self._patch_ids_from_lengths(decoder_patch_lengths, seq_len) + + + + def forward( + self, + tokens: torch.Tensor, + patch_lengths: Optional[torch.Tensor] = None, + ): + # NOTE: ngram_ids parameter removed since frequency-based n-gram embeddings + # are no longer used in the final BLT model + + bs, N = tokens.shape # Batch size and sequence length + + # Get megabyte inputs + nb_boe = int(0 if self.config.patching_mode != "" else self.config.patch_size - 1) + local_encoder_tokens, _, local_decoder_tokens = get_blt_input( + tokens=tokens, + enforce_patch_size_multiple=False, + nb_boe=nb_boe, + patch_size=self.config.patch_size, + boe_id=BOE_ID, + ) + + # Patching + if patch_lengths is None: + # assert ( + # getattr(self.config, "patch_in_forward", None) is not None and self.config.patch_in_forward + # ), "Patch in forward not enabled and no patch_lengths passed." 
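+
+            # When patch_lengths is not provided, patch boundaries are derived here:
+            # in "entropy" mode the patcher scores every byte and a patch boundary is
+            # placed where the next-byte entropy exceeds patching_threshold; in the
+            # default byte mode every byte becomes its own length-1 patch.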
+ + # PATCHER MODEL DEFINED + if self.config.patching_mode == PatchingModeEnum.entropy: + _, patch_lengths, _ = self.patcher( + local_encoder_tokens, + patch_size=self.config.patch_size, + include_next_token=True, + threshold=self.config.patching_threshold, + threshold_add=self.config.patching_threshold_add, + monotonicity=self.config.monotonicity, + max_patch_length=self.config.max_patch_length, + patching_batch_size=self.config.patching_batch_size, + device=self.config.patching_device, + ) + else: + # self.config.patching_mode == PatchingModeEnum.byte + bs, seq_len = local_encoder_tokens.shape + seq_len_next_tok = seq_len + 1 # include_next_token=True + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + ) + + # Apply any processing to patch lengths + if self.config.max_patch_length is not None: + # TODO: avoid going back to a list here. + patch_lengths = [ + BLTPatcher.split_large_numbers(pl, self.config.max_patch_length) + for pl in patch_lengths.tolist() + ] + max_len = max([len(pl) for pl in patch_lengths]) + patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] + patch_lengths = torch.tensor( + patch_lengths, dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device + ) + assert not check_non_zero_after_zero(patch_lengths) + # Find the last non-zero column index using argmax on a reversed version of the tensor + last_non_zero_col_reversed = ( + (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() + ) + # Slice the tensor up to the last non-zero column + patch_lengths = patch_lengths[ + :, : patch_lengths.shape[1] - last_non_zero_col_reversed + ] + else: + if nb_boe > 0: + patch_lengths[:, 0] += nb_boe + + assert torch.min(patch_lengths) >= 0 + + # Generate patch IDs from patch_lengths + patch_ids = self._patch_ids_from_lengths( + patch_lengths, local_encoder_tokens.shape[-1] + ) + assert torch.max(patch_ids) + 1 <= torch.max( + (patch_lengths != 0).sum(dim=-1) + ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}" + + cross_attn_mask_enc = None + # Cross-attention encoder + if self.config.cross_attn_encoder: + cross_attn_mask_enc = cross_attn_mask( + patch_ids, + patch_lengths, + N, + patches_as_queries=True, + cross_attn_k=self.config.cross_attn_k, + window=self.config.cross_attn_window_encoder, + block_mask=self.config.cross_attn_use_flex_attention, + ) + + # Hashing and embedding + local_encoder_embeds = compute_hash_embeddings( + local_encoder_tokens=local_encoder_tokens, + local_encoder=self.local_encoder, + encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, + encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions, + encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size, + encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab, + ) + + # NOTE: Frequency-based n-gram embeddings removed as per paper + # The final BLT model uses only hash-based n-gram embeddings + + # Local encoder + (h_encoder, h_cross), cache_encoder = self.local_encoder( + tokens=local_encoder_tokens, + embeds=local_encoder_embeds, + patch_embeds=None, + cross_mask=cross_attn_mask_enc, + num_patches=patch_lengths.shape[1], + patch_ids=patch_ids, + ) + + # Downsampling + h = h_cross.view(bs, patch_lengths.shape[1], -1) + + # Global transformer + global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(BOE_ID) + rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id) + eos_patch_ids 
= patch_ids[rows, cols] + global_tokens[rows, eos_patch_ids] = self.config.eos_token_id + + h, _ = self.global_transformer( + embeds=h, + tokens=global_tokens, + ) + + # Unpatching + dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :] + + # Generate decoder patch IDs + decoder_patch_ids = self._decoder_patch_ids_from_lengths( + patch_lengths, nb_boe, local_decoder_tokens.shape[-1] + ) + assert ( + torch.max(decoder_patch_ids) + 1 <= h.shape[1] + ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}" + assert ( + decoder_patch_ids.shape[1] == dec_embeds.shape[1] + ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}" + + # Cross-attention decoder + if not self.config.cross_attn_decoder: + h = torch.gather( + h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) + ) + cross_attn_mask_dec = None + assert local_decoder_tokens.shape == h.shape[:-1] + else: + cross_attn_mask_dec = cross_attn_mask( + decoder_patch_ids, + patch_lengths, + N, + patches_as_queries=False, + cross_attn_k=self.config.cross_attn_k, + window=self.config.cross_attn_window_decoder, + block_mask=self.config.cross_attn_use_flex_attention, + ) + + # Local decoder + output, _ = self.local_decoder( + embeds=dec_embeds, + patch_embeds=h, + tokens=local_decoder_tokens, + cross_mask=cross_attn_mask_dec, + ) + return output + + def init_weights(self): + self.local_encoder.init_weights() + self.global_transformer.init_weights() + self.local_decoder.init_weights() + + if self.encoder_hash_tok_embedding is not None: + emb_std = self.local_encoder.dim ** (-0.5) + for emb in self.encoder_hash_tok_embedding: + nn.init.trunc_normal_( + emb.weight, + mean=0.0, + std=emb_std, + a=-3 * emb_std, + b=3 * emb_std, + ) + + +class BLTPatcher(BLTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + # Store config reference for later use + self.config = config + + # Extract patcher parameters from BLTConfig + self.dim = config.patcher_dim + self.init_base_std = config.patcher_init_base_std + self.attn_impl = config.patcher_attn_impl + self.attn_bias_type = config.patcher_attn_bias_type + self.init_std_factor = config.patcher_init_std_factor + self.max_seqlen = config.patcher_max_seqlen + n_layers = config.patcher_n_layers + n_heads = config.patcher_n_heads + head_dim = config.patcher_head_dim + rope_theta = config.patcher_rope_theta + rope_use_fp32_in_outer_product = config.patcher_rope_use_fp32_in_outer_product + norm_eps = config.patcher_norm_eps + vocab_size = config.patcher_vocab_size + weight_tying = config.patcher_weight_tying + sliding_window = config.patcher_sliding_window + eos_token_id = config.patcher_eos_token_id + + self.rope_embeddings = RotaryEmbedding( + theta=rope_theta, + head_dim=head_dim or self.dim // n_heads, + max_seqlen=self.max_seqlen, + rope_use_fp32_in_outer_product=rope_use_fp32_in_outer_product, + ) + # Handle both eos_id and eos_token_id for compatibility + self.eos_id = eos_token_id + + # Extract additional parameters for BLTTransformerLayer + n_kv_heads = getattr(config, 'patcher_n_kv_heads', None) if hasattr(config, 'patcher_dim') else getattr(config, 'n_kv_heads', None) + multiple_of = getattr(config, 'patcher_multiple_of', 256) if hasattr(config, 'patcher_dim') else getattr(config, 'multiple_of', 256) + ffn_dim_multiplier = getattr(config, 'patcher_ffn_dim_multiplier', None) if hasattr(config, 'patcher_dim') else getattr(config, 'ffn_dim_multiplier', None) + + # Create a simple parameter dict for BLTTransformerLayer + layer_params = { + 'dim': self.dim, + 'n_heads': 
n_heads, + 'head_dim': head_dim, + 'n_kv_heads': n_kv_heads, + 'rope_theta': rope_theta, + 'multiple_of': multiple_of, + 'ffn_dim_multiplier': ffn_dim_multiplier, + 'norm_eps': norm_eps, + } + + self.layers = nn.ModuleList() + for _ in range(n_layers): + self.layers.append(BLTTransformerLayer(layer_params)) + + # LMTransformer specific attributes + self.weight_tying = weight_tying + self.sliding_window = sliding_window + + assert vocab_size > 0 + + self.tok_embeddings = torch.nn.Embedding(vocab_size, self.dim) + + self.norm = RMSNorm(self.dim, eps=norm_eps) + + self.output = nn.Linear( + self.dim, + vocab_size, + bias=False, + ) + + if self.weight_tying: + self.output.weight = self.tok_embeddings.weight + + def forward( + self, + token_values: torch.Tensor, + target: Optional[torch.Tensor] = None, + tok_idx: Optional[torch.Tensor] = None, + mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, + attn_impl: str | None = None, + patch_size: Optional[int] = None, + include_next_token: bool = True, + threshold: Optional[float] = None, + threshold_add: Optional[float] = None, + monotonicity: bool = False, + max_patch_length: Optional[int] = None, + patching_batch_size: int = 1, # Changed from Optional[int] = None to int = 1 + device: Optional[str] = None, + enable_grad: bool = False, + ): + attn_impl = self.attn_impl if attn_impl is None else attn_impl + + # Handle chunked processing for entropy calculation + # grad_context = nullcontext() if enable_grad else torch.no_grad() + # with grad_context: + entropies = [] + preds = [] + max_length = min(getattr(self, "max_length", 8192), self.max_seqlen) + batch_numel = max_length * patching_batch_size + splits = torch.split(token_values.flatten(), batch_numel) + + for split in splits: + pad_size = (max_length - (split.numel() % max_length)) % max_length + pad = torch.zeros( + pad_size, dtype=split.dtype, device=split.device, requires_grad=False + ) + split = torch.cat((split, pad), dim=0) + split = split.reshape(-1, max_length) + if device is not None: + split = split.to(device) + + # Process chunk: embeddings -> layers -> output + bsz, seqlen = split.shape + h = self.tok_embeddings(split) + chunk_mask = create_causal_mask( + seqlen, + attn_impl, + self.attn_bias_type, + sliding_window=self.sliding_window, + tokens=split, + eos_id=self.eos_id, + ) + freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None) + + for i, layer in enumerate(self.layers): + h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=attn_impl) + + pred = self.output(self.norm(h)) + pred = pred.reshape(-1, pred.shape[-1])[ + : split.numel() - pad_size, : + ] # [batch_size * seq_len, vocab] + preds.append(pred) + pred_entropies = self.entropy(pred) + entropies.append(pred_entropies) + + concat_entropies = torch.cat(entropies, dim=0) + concat_entropies = concat_entropies.reshape(token_values.shape) + concat_preds = torch.cat(preds, dim=0) + concat_preds = concat_preds.reshape(token_values.shape[0], -1) + + # Always compute patch lengths from concatenated entropies + bs, seq_len = token_values.shape + seq_len_next_tok = seq_len + 1 if include_next_token else seq_len + + # Find patch start IDs based on entropy + if patch_size is not None: + patch_start_ids = self.find_entropy_patch_start_ids( + concat_entropies, + patch_size, + include_next_token=include_next_token, + threshold=threshold, + threshold_add=threshold_add, + monotonicity=monotonicity, + ) + patch_lengths = self.patch_lengths_from_start_ids( + patch_start_ids, seq_len_next_tok + ) + else: + # Default to 
byte-level patching + patch_lengths = torch.ones( + (bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device + ) + + # Apply any processing to patch lengths + if max_patch_length is not None: + # TODO: avoid going back to a list here. + patch_lengths = [ + self.split_large_numbers(pl, max_patch_length) + for pl in patch_lengths.tolist() + ] + max_len = max([len(pl) for pl in patch_lengths]) + patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths] + patch_lengths = torch.tensor( + patch_lengths, dtype=token_values.dtype, device=token_values.device + ) + assert not check_non_zero_after_zero(patch_lengths) + # Find the last non-zero column index using argmax on a reversed version of the tensor + last_non_zero_col_reversed = ( + (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min() + ) + # Slice the tensor up to the last non-zero column + patch_lengths = patch_lengths[ + :, : patch_lengths.shape[1] - last_non_zero_col_reversed + ] + + return concat_entropies, patch_lengths, concat_preds + + def reset_parameters(self, init_std=None): + self.norm.reset_parameters() + + def init_weights(self): + self.reset_parameters() + init_std = self.dim ** (-0.5) + nn.init.trunc_normal_( + self.tok_embeddings.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + self.rope_embeddings.reset_parameters() + for depth, layer in enumerate(self.layers): + factor = self.config.get_init_std_factor(depth) + layer.init_weights(self.init_base_std, factor) + + if not self.weight_tying: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=init_std, + a=-3 * init_std, + b=3 * init_std, + ) + + @staticmethod + def entropy(scores): + """ + scores: [bs, seq_len, vocab] + returns [bs, seq_len] + + Computes the entropy for each token in the batch. + Note: uses natural log. + """ + log_probs = F.log_softmax(scores, dim=-1) + probs = torch.exp(log_probs) + p_log_p = log_probs * probs + entropy = -p_log_p.sum(dim=-1) + return entropy + + + + @staticmethod + def patch_start_ids_from_patch_start_mask(patch_start_mask): + bs, trunc_seq_len = patch_start_mask.shape + max_patches = patch_start_mask.sum(dim=1).max() + if max_patches == 0: + patch_start_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + else: + patch_ids = ( + torch.arange(trunc_seq_len, device=patch_start_mask.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + extra_patch_ids = torch.full( + (bs, trunc_seq_len), + trunc_seq_len, + dtype=torch.long, + device=patch_start_mask.device, + ) + all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) + patch_start_mask_padded = torch.cat( + (patch_start_mask, ~patch_start_mask), dim=1 + ) + patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape( + bs, trunc_seq_len + )[:, :max_patches] + return patch_start_ids + + @staticmethod + def patch_lengths_from_start_ids(patch_start_ids, seq_len): + """ + Calculate patch lengths from start ids. + start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then + the rest are filled to the seq len. + seq_len: ex: 7 length of the sequence + + returns the patch lengths: + [1, 6] for the above example. 
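+        (Filler start ids yield zero-length entries, so the full tensor for the
+        example is [1, 6, 0, 0, 0, 0, 0].)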
+ """ + last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) + patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) + patch_lengths = patch_end_ids - patch_start_ids + 1 + assert torch.all(patch_lengths >= 0), f"{patch_lengths}" + assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}" + return patch_lengths + + @staticmethod + def find_entropy_patch_start_ids( + entropies, + patch_size=None, + threshold=None, + threshold_add=None, + monotonicity=False, + include_next_token=True, + ): + """ + Use entropies to find the start ids of each patch. + Use patch_size or threshold to figure out the total number of patches to allocate. + + When threshold is not None the number of patches is not constant between + different sequences, but patches can be identified incrementally rather than + decided globally using the entire sequence. + """ + bs, seq_len = entropies.shape[:2] + + first_ids = ( + torch.tensor([0, 1], dtype=torch.long, device=entropies.device) + .unsqueeze(0) + .repeat(bs, 1) + ) + preds_truncation_len = first_ids.shape[ + 1 + ] # remove the first preds because they will be start of patches. + entropies = entropies[:, 1:] + if threshold is None: + num_patches = seq_len // patch_size + patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices + patch_start_ids = patch_start_ids.sort(dim=1).values + else: + patch_start_mask = entropies > threshold + if not include_next_token: + patch_start_mask = patch_start_mask[:, :-1] + # patch_start_mask[1:] |= tokens[:-1] < OFFSET + patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) + + patch_start_ids = torch.cat( + (first_ids, patch_start_ids + preds_truncation_len), dim=1 + ) + return patch_start_ids + + @staticmethod + def split_large_numbers(lst, m): + new_lst = [] + for i in lst: + if i > m: + while i > m: + new_lst.append(m) + i -= m + new_lst.append(i) + else: + new_lst.append(i) + assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}" + return new_lst + + +def init_hash_embeddings( + config, + local_encoder_dim: int, + encoder_hash_byte_group_size: list, +): + """Initialize hash-based token embeddings for the BLT encoder.""" + if config.encoder_hash_byte_group_size is None: + return None + + embeddings = [] + emb_dim = local_encoder_dim + encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab + + for _ in range(config.encoder_hash_byte_group_nb_functions): + for _ in encoder_hash_byte_group_size: + embeddings.append( + nn.Embedding( + encoder_hash_byte_group_vocab, + emb_dim, + ) + ) + + return nn.ModuleList(embeddings) + + +__all__ = [ + "BLTPreTrainedModel", + "BLTModel", + "BLTPatcher", + "LocalEncoder", + "LocalDecoder", + "GlobalTransformer", +] \ No newline at end of file diff --git a/backup_blt_wip_backup/tokenization_blt.py b/backup_blt_wip_backup/tokenization_blt.py new file mode 100644 index 0000000000000000000000000000000000000000..cf57143de5dd7d594c60b05660f03548dd60689b --- /dev/null +++ b/backup_blt_wip_backup/tokenization_blt.py @@ -0,0 +1,273 @@ +# coding=utf-8 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes for BLT.""" + +import os +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +if TYPE_CHECKING: + from ...tokenization_utils_base import TextInput + +logger = logging.get_logger(__name__) + +# BLT tokenizer constants +SEP = " " +BOS_ID: int = 1 +EOS_ID: int = 2 +PAD_ID: int = -1 +BOE_ID: int = 0 +BPE_ID: int = 3 +OFFSET: int = 4 +BYTE_UNITS: int = 256 + +VOCAB_FILES_NAMES = {} # BLT doesn't require external vocab files + + +class BLTTokenizer(PreTrainedTokenizer): + """ + Construct a BLT tokenizer. Based on byte-level tokenization where each byte is treated as a token. + + This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with an offset. + It supports special tokens for beginning of sequence (BOS), end of sequence (EOS), + beginning of example (BOE), and padding (PAD). + + Args: + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The padding token. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. Not used in BLT but kept for compatibility. + boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of example token, specific to BLT. + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add a `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding. + spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add spaces between special tokens. 
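+
+    Example (illustrative; byte ids are the UTF-8 byte value plus an offset of 4):
+
+        tokenizer = BLTTokenizer()
+        tokenizer.encode("hi")              # [1, 108, 109, 2]  (BOS, 'h'=104+4, 'i'=105+4, EOS)
+        tokenizer.decode([1, 108, 109, 2])  # "hi"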
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + bos_token="", + eos_token="", + pad_token="", + unk_token="", + boe_token="", + add_bos_token=True, + add_eos_token=True, + clean_up_tokenization_spaces=False, + spaces_between_special_tokens=False, + **kwargs, + ): + # Store BLT-specific parameters first + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.vocab_size_unit_1 = BYTE_UNITS + self.offsetting_special_char = OFFSET + + # BLT token IDs (exactly like original) + self.boe_id = BOE_ID + self.bos_id = BOS_ID + self.eos_id = EOS_ID + self.pad_id = PAD_ID + self.bpe_id = BPE_ID + self.n_words = self.vocab_size_unit_1 + self.offsetting_special_char + + # Convert string tokens to AddedToken objects + bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + self.boe_token = AddedToken(boe_token, normalized=False, special=True) if isinstance(boe_token, str) else boe_token + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + unk_token=unk_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + spaces_between_special_tokens=spaces_between_special_tokens, + **kwargs, + ) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.vocab_size_unit_1 + self.offsetting_special_char + + def get_vocab(self): + """Returns vocab as a dict""" + # Create a mapping for byte values + offset + vocab = {} + + # Add special tokens (with defensive checks) + if hasattr(self, 'bos_token'): + vocab[str(self.bos_token)] = self.bos_id + if hasattr(self, 'eos_token'): + vocab[str(self.eos_token)] = self.eos_id + if hasattr(self, 'pad_token'): + vocab[str(self.pad_token)] = self.pad_id + if hasattr(self, 'boe_token'): + vocab[str(self.boe_token)] = self.boe_id + + # Add byte tokens as string representations of byte values + vocab_size_unit_1 = getattr(self, 'vocab_size_unit_1', BYTE_UNITS) + offsetting_special_char = getattr(self, 'offsetting_special_char', OFFSET) + for i in range(vocab_size_unit_1): + vocab[str(i)] = i + offsetting_special_char + + # Add any additional tokens if available + if hasattr(self, 'added_tokens_encoder'): + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str, **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. For BLT, we work directly with byte values. + Returns a list of strings that represent the byte values. 
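+        For example, "hi" is tokenized as ["104", "105"] (decimal byte values as strings).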
+ """ + # Convert text to UTF-8 bytes, just like the original + try: + bytes_data = text.encode("utf-8", errors="ignore") + except UnicodeEncodeError: + bytes_data = text.encode("utf-8", errors="ignore") + + # Return string representations of byte values for the tokenizer framework + return [str(byte_val) for byte_val in bytes_data] + + def _convert_token_to_id(self, token: str) -> int: + """Converts a token (str) to an id using the vocab.""" + # Handle special tokens + if token == str(self.bos_token): + return self.bos_id + elif token == str(self.eos_token): + return self.eos_id + elif token == str(self.pad_token): + return self.pad_id + elif token == str(self.boe_token): + return self.boe_id + else: + try: + # Convert byte value string to int and add offset (like original) + byte_val = int(token) + if 0 <= byte_val <= 255: + return byte_val + self.offsetting_special_char + except ValueError: + pass + + # Check if it's in added tokens + return self.added_tokens_encoder.get(token, self.unk_token_id) + + def _convert_id_to_token(self, index: int) -> str: + """Converts an index (integer) to a token (str) using the vocab.""" + # Handle special tokens + if index == self.bos_id: + return str(self.bos_token) + elif index == self.eos_id: + return str(self.eos_token) + elif index == self.pad_id: + return str(self.pad_token) + elif index == self.boe_id: + return str(self.boe_token) + elif index >= self.offsetting_special_char and index < self.vocab_size: + # Convert back to byte value (like original) + byte_val = index - self.offsetting_special_char + return str(byte_val) + else: + # Check added tokens + for token, token_id in self.added_tokens_encoder.items(): + if token_id == index: + return token + return str(self.unk_token) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + """Converts a sequence of tokens to a single string.""" + byte_values = [] + + for token in tokens: + # Skip special tokens + if token in [str(self.bos_token), str(self.eos_token), str(self.pad_token), str(self.boe_token)]: + continue + + try: + # Convert token back to byte value (like original decode method) + byte_val = int(token) + if 0 <= byte_val <= 255: + byte_values.append(byte_val) + except ValueError: + continue + + # Convert byte values back to string (exactly like original) + try: + return bytes(byte_values).decode("utf-8", errors="ignore") + except (UnicodeDecodeError, ValueError): + return "" + + def encode(self, text: str, add_bos: bool | None = None, add_eos: bool | None = None): + """ + Encode text exactly like the original BLT tokenizer. + """ + if add_bos is None: + add_bos = self.add_bos_token + if add_eos is None: + add_eos = self.add_eos_token + + # Since bpe_delim=False, we use the simple byte encoding + tokens = bytes(text, encoding="utf-8", errors="ignore") + + # Offsetting (exactly like original) + tokens = [int(unit) + self.offsetting_special_char for unit in tokens] + + if add_bos: + tokens.insert(0, self.bos_id) + if add_eos: + tokens.append(self.eos_id) + + return tokens + + def decode(self, tokens: list[int], cut_at_eos: bool = False): + """ + Decode tokens exactly like the original BLT tokenizer. 
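+        Ids below the byte offset (BOE/BOS/EOS/BPE and PAD) are dropped, the
+        remaining ids are shifted back by the offset, and the resulting bytes are
+        decoded as UTF-8 with errors="ignore".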
+ """ + if cut_at_eos: + for k, t in enumerate(tokens): + if t == self.eos_id: + tokens = tokens[: k + 1] + break + return bytes( + [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0] + ).decode("utf-8", errors="ignore") + + def get_vocab_size(self) -> int: + """Get vocab size like the original tokenizer.""" + return self.vocab_size_unit_1 + self.offsetting_special_char + +__all__ = ["BLTTokenizer"] \ No newline at end of file diff --git a/backup_blt_wip_backup/tokenizers/__init__.py b/backup_blt_wip_backup/tokenizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/backup_blt_wip_backup/tokenizers/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-312.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e437756c33aec5b5372596b16fd60b8732b57de5 Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-312.pyc differ diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-39.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ca0dcc4ad6d75cac3868dee0b9bf1183c2183a8 Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/__init__.cpython-39.pyc differ diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/abstract_tokenizer.cpython-312.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/abstract_tokenizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..029146f29c2b8c58f2e50bf95159751d1cf21006 Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/abstract_tokenizer.cpython-312.pyc differ diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/blt_tokenizer.cpython-312.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/blt_tokenizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1c2398abd9aa0a8bef2ac61f3ab44161d6cf055 Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/blt_tokenizer.cpython-312.pyc differ diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/build_tokenizer.cpython-312.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/build_tokenizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f344f971978a50d4cb123127a0cb65e1c8327472 Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/build_tokenizer.cpython-312.pyc differ diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-312.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c380523b8c7191a900bd6bce095df4c63ff30d4d Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-312.pyc differ diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-39.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2e7cfd7e1cc6e85e4d11e731235b383230c2149 Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/constants.cpython-39.pyc differ diff --git 
a/backup_blt_wip_backup/tokenizers/__pycache__/sentence_piece_tokenizer.cpython-312.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/sentence_piece_tokenizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4a60bb2cf0afaecb335c7b70118047ded4e34ab Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/sentence_piece_tokenizer.cpython-312.pyc differ diff --git a/backup_blt_wip_backup/tokenizers/__pycache__/tiktoken_tokenizer.cpython-312.pyc b/backup_blt_wip_backup/tokenizers/__pycache__/tiktoken_tokenizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7548c986e56ce17009c2d81c63ce2545c5cc1d5 Binary files /dev/null and b/backup_blt_wip_backup/tokenizers/__pycache__/tiktoken_tokenizer.cpython-312.pyc differ diff --git a/backup_blt_wip_backup/tokenizers/abstract_tokenizer.py b/backup_blt_wip_backup/tokenizers/abstract_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ff31d655ae3461ca1d39230e40d72e0c24246fa3 --- /dev/null +++ b/backup_blt_wip_backup/tokenizers/abstract_tokenizer.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import abc + + +class Tokenizer(abc.ABC): + @abc.abstractmethod + def encode(self, text: str, add_bos: bool, add_eos: bool): + pass + + @abc.abstractmethod + def decode(self, tokens: list[int]): + pass + + @abc.abstractmethod + def get_token_offsets(self, text: str, tokens: list[int] | None = None) -> tuple[list[str], list[int]]: + """Return the offsets of the tokens in the original text. Only used for evaluation.""" + pass + + @abc.abstractmethod + def get_vocab_size(self) -> int: + pass diff --git a/backup_blt_wip_backup/tokenizers/blt_tokenizer.py b/backup_blt_wip_backup/tokenizers/blt_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..2d018ff90ead466a284e59a18ee5839d9f1a3155 --- /dev/null +++ b/backup_blt_wip_backup/tokenizers/blt_tokenizer.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import re + +from .abstract_tokenizer import Tokenizer +from .sentence_piece_tokenizer import SentencePieceTokenizer + + +SEP = " " +BOS_ID: int = 1 +EOS_ID: int = 2 +PAD_ID: int = -1 +BOE_ID: int = 0 +BPE_ID: int = 3 +OFFSET: int = 4 + +BYTE_UNITS: int = 256 + + +def convert_to_bytes(s): + # check if the output is a bytes like object of the format <0x00> + if re.match(r"<0x[0-9a-fA-F]+>", s): + return bytes.fromhex(s[3:-1]) + else: + return bytes(s, "utf-8", errors="ignore") + + +def text2bytes_bpe_delims( + text: str, + *, + bpe_tokenizer, + bpe_id: int, + offsetting_special_char: int, + add_bos: bool, + add_eos: bool, +): + cur_bpe = bpe_tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos) + # merge the leading space tokens + leading_space_tokens = [] + other_bpe_tokens = [] + leading = True + for token in cur_bpe: + bpe_str = bpe_tokenizer.sp_model.id_to_piece(token) + if leading and all(c == "▁" for c in bpe_str): + leading_space_tokens.append(bpe_str) + else: + leading = False + other_bpe_tokens.append(bpe_str) + cur_bpe_strs = ["".join(leading_space_tokens)] + other_bpe_tokens + + # Remove the '▁' characters + bpe_strs = [] + for i, bpe_str in enumerate(cur_bpe_strs): + if len(bpe_strs) <= 1 and all([c == " " for s in bpe_strs for c in s]) and not all(c == "▁" for c in bpe_str): + # Remove leading space for first non space token. 
+ bpe_str = bpe_str.replace("▁", "") + elif i == 0 and all(c == "▁" for c in bpe_str): + bpe_str = " " * (len(text) - len(text.lstrip(" "))) + else: + bpe_str = bpe_str.replace("▁", " ") + if len(bpe_str) > 0: + bpe_strs.append(bpe_str) + ex_seq = [] + # Convert bpe tokens to bytes + for s in bpe_strs: + byte_chunk = convert_to_bytes(s) + proc_chunk = [int(unit) for unit in byte_chunk] + ex_seq.extend([bpe_id - offsetting_special_char] + proc_chunk) + + return ex_seq + + +class BltTokenizer(Tokenizer): + def __init__( + self, + *, + vocab_size_unit_1: int = BYTE_UNITS, + bpe_delim: bool = False, + bpe_tokenizer_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", + add_bos: bool = True, + add_eos: bool = True, + ): + self.add_bos = add_bos + self.add_eos = add_eos + self.vocab_size_unit_1 = vocab_size_unit_1 + self.boe_id = BOE_ID + self.bos_id = BOS_ID + self.eos_id = EOS_ID + self.pad_id = PAD_ID + self.bpe_id = BPE_ID + self.bpe_tokenizer_path = bpe_tokenizer_path + if bpe_delim: + self.bpe_tokenizer = SentencePieceTokenizer(model_path=self.bpe_tokenizer_path) + else: + self.bpe_tokenizer = None + self.bpe_delim = bpe_delim + self.offsetting_special_char = OFFSET + self.vocab_size_unit_1 = vocab_size_unit_1 + self.n_words = vocab_size_unit_1 + self.offsetting_special_char + + def get_vocab_size(self) -> int: + return self.n_words + + def encode(self, text: str, add_bos: bool | None = None, add_eos: bool | None = None): + if add_bos is None: + add_bos = self.add_bos + if add_eos is None: + add_eos = self.add_eos + + if self.bpe_delim: + tokens = text2bytes_bpe_delims( + text, + bpe_tokenizer=self.bpe_tokenizer, + bpe_id=self.bpe_id, + offsetting_special_char=self.offsetting_special_char, + add_bos=False, + add_eos=False, + ) + else: + tokens = bytes(text, encoding="utf-8", errors="ignore") + + # Offsetting + tokens = [int(unit) + self.offsetting_special_char for unit in tokens] + + if add_bos: + tokens.insert(0, self.bos_id) + if add_eos: + tokens.append(self.eos_id) + + return tokens + + def decode(self, tokens: list[int], cut_at_eos: bool = False): + if cut_at_eos: + for k, t in enumerate(tokens): + if t == self.eos_id: + tokens = tokens[: k + 1] + break + return bytes( + [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0] + ).decode("utf-8", errors="ignore") + + def get_token_offsets(self, text: str, tokens: list[int] | None = None): + # TODO: Figure out what this does + raise NotImplementedError() diff --git a/backup_blt_wip_backup/tokenizers/sentence_piece_tokenizer.py b/backup_blt_wip_backup/tokenizers/sentence_piece_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..fece8cbce7a85d6aabb3e3545cd256d17b947237 --- /dev/null +++ b/backup_blt_wip_backup/tokenizers/sentence_piece_tokenizer.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+import logging
+import os
+
+
+try:
+    from sentencepiece import SentencePieceProcessor
+
+    has_sp = True
+except ImportError:
+    has_sp = False
+
+from .abstract_tokenizer import Tokenizer
+
+
+logger = logging.getLogger(__name__)
+
+
+class SentencePieceTokenizer(Tokenizer):
+    def __init__(self, model_path: str, add_bos: bool = True, add_eos: bool = True) -> None:
+        if not has_sp:
+            raise ImportError("SentencePieceTokenizer requires the `sentencepiece` package to be installed.")
+        assert os.path.isfile(model_path), model_path
+        self.sp_model = SentencePieceProcessor(model_file=model_path)
+
+        logger.info(f"Reloaded SentencePiece model from {model_path}")
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.sp_model.vocab_size()
+        self.bos_id: int = self.sp_model.bos_id()
+        self.eos_id: int = self.sp_model.eos_id()
+        self.pad_id: int = self.sp_model.pad_id()
+        self.add_bos = add_bos
+        self.add_eos = add_eos
+        logger.info(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")
+        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+    def get_vocab_size(self) -> int:
+        return self.n_words
+
+    def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None):
+        if add_bos is None:
+            add_bos = self.add_bos
+
+        if add_eos is None:
+            add_eos = self.add_eos
+        assert type(s) is str
+        # add_bos/add_eos are booleans; multiplying the single-id lists by them
+        # prepends/appends the special tokens only when requested.
+        tokens = [self.bos_id] * add_bos + self.sp_model.encode(s) + [self.eos_id] * add_eos
+        return tokens
+
+    def decode(self, tokens: list[int]):
+        return self.sp_model.decode(tokens)
+
+    def get_token_offsets(self, text: str, tokens: list[int] | None = None) -> tuple[list[str], list[int]]:
+        # Re-tokenize the text to recover each piece's surface form and character offset.
+        pieces = self.sp_model.encode_as_immutable_proto(text).pieces
+        substrs = [p.surface for p in pieces]
+        offsets = [p.begin for p in pieces]
+        return substrs, offsets
diff --git a/backup_blt_wip_backup/unified_blt_debug/config.json b/backup_blt_wip_backup/unified_blt_debug/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..67ab54a27541379cc9caae6512155029d2adf52c
--- /dev/null
+++ b/backup_blt_wip_backup/unified_blt_debug/config.json
@@ -0,0 +1,144 @@
+{
+  "args": {
+    "alpha_depth": "disabled",
+    "architecture": "vanilla",
+    "attn_bias_type": "block_causal",
+    "attn_impl": "xformers",
+    "attn_to_keep": "all",
+    "conv_kernel_size": null,
+    "cross_attn_all_layers_decoder": true,
+    "cross_attn_all_layers_encoder": false,
+    "cross_attn_decoder": true,
+    "cross_attn_encoder": true,
+    "cross_attn_init_by_pooling": true,
+    "cross_attn_k": 2,
+    "cross_attn_nheads": 16,
+    "cross_attn_use_flex_attention": true,
+    "cross_attn_window_decoder": null,
+    "cross_attn_window_encoder": null,
+    "custom_bwd": false,
+    "dim": 512,
+    "dim_global": 2048,
+    "dim_local_decoder": 1024,
+    "dim_local_encoder": 1024,
+    "dim_patch_emb": null,
+    "dim_token": null,
+    "dim_token_emb": null,
+    "downsampling_by_pooling": "max",
+    "dropout": 0.0,
+    "encoder_enable_byte_group_hash": false,
+    "encoder_enable_byte_ngrams": false,
+    "encoder_hash_byte_group_nb_functions": 1,
+    "encoder_hash_byte_group_size": [
+      3,
+      4,
+      5,
+      6,
+      7,
+      8
+    ],
+    "encoder_hash_byte_group_vocab": 500002,
+    "encoder_lm_loss": false,
+    "encoder_ngram_table_dir": null,
+    "encoder_ngram_to_size_str": null,
+    "encoder_preds_low_entropy_toks": null,
+    "encoder_preds_random_toks": null,
+    "entropy_model_checkpoint_dir": null,
+    "entropy_model_is_ngram_model": false,
+    "eos_id": 2,
+    "ffn_dim_multiplier": 1.0,
+    "full_logging_n_layers": 4,
+    "fuse_sequence_parallel": false,
+    "global_local_decoder_residual_layer": null,
+    "head_dim": null,
+    "init_base_std": null,
+    "init_std_factor": "current_depth",
+    "init_use_depth": "current",
+    "init_use_gaussian": true,
+    "layer_ckpt": "none",
+    "local_attention_window_len": 512,
+    "log_patch_lengths": false,
+    "loss_parallel": false,
+    "max_encoder_seq_length": 24576,
+    "max_length": 256,
+    "max_patch_length": null,
+    "max_seqlen": 4096,
+    "monotonicity": false,
+    "multiple_of": 256,
+    "n_heads": 8,
+    "n_heads_global": 16,
+    "n_heads_local_decoder": 16,
+    "n_heads_local_encoder": 16,
+    "n_kv_heads": null,
+    "n_kv_heads_global": null,
+    "n_layers": 8,
+    "n_layers_global": 25,
+    "n_layers_local_decoder": 9,
+    "n_layers_local_encoder": 1,
+    "ngram_vocab_sizes": null,
+    "non_linearity": "swiglu",
+    "norm_affine": true,
+    "norm_eps": 1e-05,
+    "norm_type": "rmsnorm",
+    "output_size": -1,
+    "pad_to_max_length": true,
+    "patch_in_forward": true,
+    "patch_size": 4.5,
+    "patching_batch_size": 1,
+    "patching_device": "cuda",
+    "patching_mode": "entropy",
+    "patching_threshold": 1.335442066192627,
+    "patching_threshold_add": null,
+    "patching_thresholds_str": null,
+    "pm_size": 0,
+    "pre_norm": true,
+    "recompute_attn": false,
+    "recompute_fc1_out": false,
+    "recompute_fc3_out": false,
+    "rope_theta": 500000.0,
+    "rope_use_fp32_in_outer_product": true,
+    "seed": 42,
+    "sequence_parallel": false,
+    "share_encoder_decoder_emb": true,
+    "tie_local_encoder_decoder": false,
+    "tie_local_encoder_decoder_logits": false,
+    "tokenize_with_bpe_delimiter": false,
+    "use_fsdp": true,
+    "use_local_encoder_transformer": true,
+    "use_rope": true,
+    "vocab_size": 260,
+    "weight_tying": false
+  },
+  "patch_in_forward": true,
+  "realtime_patching": true,
+  "patching_mode": "entropy",
+  "patch_size": 4.5,
+  "patching_threshold": 1.335442066192627,
+  "patching_threshold_add": null,
+  "max_patch_length": null,
+  "patching_batch_size": 1,
+  "patching_device": "cuda",
+  "monotonicity": false,
+  "patcher_vocab_size": 260,
+  "patcher_dim": 768,
+  "patcher_n_layers": 14,
+  "patcher_n_heads": 12,
+  "patcher_head_dim": null,
+  "patcher_n_kv_heads": null,
+  "patcher_max_seqlen": 8192,
+  "patcher_norm_eps": 1e-05,
+  "patcher_dropout": 0.0,
+  "patcher_sliding_window": 512,
+  "patcher_ffn_dim_multiplier": 1.0,
+  "patcher_multiple_of": 256,
+  "patcher_rope_theta": 10000.0,
+  "patcher_rope_use_fp32_in_outer_product": false,
+  "patcher_attn_impl": "xformers",
+  "patcher_attn_bias_type": "local_block_causal",
+  "patcher_init_base_std": null,
+  "patcher_init_std_factor": "current_depth",
+  "patcher_dim_token_emb": null,
+  "patcher_weight_tying": false,
+  "patcher_bos_token_id": 1,
+  "patcher_eos_token_id": 2
+}
\ No newline at end of file
diff --git a/config.json b/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..60d10db33244fffe880b0acb16c5b5d7938002c0
--- /dev/null
+++ b/config.json
@@ -0,0 +1,100 @@
+{
+  "model_type": "blt",
+  "vocab_size": 260,
+  "max_position_embeddings": 4096,
+  "patch_in_forward": true,
+  "realtime_patching": true,
+  "patching_mode": "entropy",
+  "patch_size": 4,
+  "patching_threshold": 1.335442066192627,
+  "patching_threshold_add": null,
+  "max_patch_length": null,
+  "patching_batch_size": 1,
+  "patching_device": "cuda",
+  "monotonicity": false,
+  "cross_attn_k": 2,
+  "encoder_hash_byte_group_size": [
+    3,
+    4,
+    5,
+    6,
+    7,
+    8
+  ],
+  "encoder_hash_byte_group_vocab": 500002,
+  "encoder_hash_byte_group_nb_functions": 1,
+  "pm_size": 0,
+  "tie_word_embeddings": false,
+  "initializer_range": 0.02,
+  "rope_theta": 500000.0,
+  "rope_scaling": {
+    "type": "default"
+  },
+  "patcher_config": {
+    "vocab_size": 260,
+    "hidden_size": 768,
+    "num_hidden_layers": 14,
+    "num_attention_heads": 12,
+    "num_key_value_heads": null,
+    "max_position_embeddings": 8192,
+    "norm_eps": 1e-05,
+    "dropout": 0.0,
+    "rope_theta": 10000.0,
+    "attn_bias_type": "local_block_causal",
+    "intermediate_size": 2048
+  },
+  "encoder_config": {
+    "vocab_size": 260,
+    "cross_attn_all_layers": false,
+    "cross_attn_k": 2,
+    "hidden_size_global": 2048,
+    "pm_size": 0,
+    "hidden_size": 1024,
+    "num_attention_heads": 16,
+    "num_key_value_heads": null,
+    "num_hidden_layers": 1,
+    "norm_eps": 1e-05,
+    "dropout": 0.0,
+    "max_position_embeddings": 24576,
+    "rope_theta": 500000.0,
+    "rope_scaling": {
+      "type": "default"
+    },
+    "hidden_act": "silu",
+    "intermediate_size": 2816
+  },
+  "decoder_config": {
+    "vocab_size": 260,
+    "cross_attn_all_layers": true,
+    "cross_attn_k": 2,
+    "hidden_size_global": 2048,
+    "hidden_size": 1024,
+    "num_attention_heads": 16,
+    "num_key_value_heads": null,
+    "num_hidden_layers": 9,
+    "norm_eps": 1e-05,
+    "dropout": 0.0,
+    "max_position_embeddings": 24576,
+    "rope_theta": 500000.0,
+    "rope_scaling": {
+      "type": "default"
+    },
+    "hidden_act": "silu",
+    "intermediate_size": 2816
+  },
+  "global_config": {
+    "hidden_size": 2048,
+    "num_attention_heads": 16,
+    "num_key_value_heads": null,
+    "num_hidden_layers": 25,
+    "norm_eps": 1e-05,
+    "dropout": 0.0,
+    "max_position_embeddings": 4096,
+    "rope_theta": 500000.0,
+    "rope_scaling": {
+      "type": "default"
+    },
+    "hidden_act": "silu",
+    "intermediate_size": 5632
+  }
+}
\ No newline at end of file
diff --git a/model.safetensors b/model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..5e4e87591312fda1029d125dfeb22f37ea1e0935
--- /dev/null
+++ b/model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b42ece52607eacbb4e538c695b137a53d38ea68dcc4a03dd825a9656f476162d
+size 9266850624
diff --git a/tokenizer_config.json b/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..3b3a3cd875566cd2014531f023e025f7a69e7539
--- /dev/null
+++ b/tokenizer_config.json
@@ -0,0 +1,11 @@
+{
+  "tokenizer_class": "BLTTokenizer",
+  "vocab_size": 260,
+  "model_max_length": 4096,
+  "add_bos_token": true,
+  "add_eos_token": true,
+  "bos_token": "",
+  "eos_token": "",
+  "pad_token": "",
+  "unk_token": ""
+}
\ No newline at end of file