# blt-testing / backup_blt_wip copy /modeling_blt_old.py
# itazap's picture
# itazap HF Staff
# Upload BLT model converted
# 724be6e verified
# raw
# history blame
# 58.9 kB
#blt old
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import os
from typing import List, Optional, Tuple, Union
import torch
import torch.nn
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
from ...modeling_utils import PreTrainedModel
from .configuration_blt_og import (
BLTConfig,
PatchingModeEnum,
)
RMSNorm = nn.RMSNorm
logger = logging.getLogger()
flex_attention_comp = flex_attention
def causal_mask(b, h, q_idx, kv_idx):
    """Flex-attention ``mask_mod`` for causal attention.

    A query at position ``q_idx`` may see keys at the same or an earlier position.
    ``b`` (batch index) and ``h`` (head index) belong to the mask_mod signature and
    are ignored.
    """
    return kv_idx <= q_idx
def create_causal_mask(
seqlen,
attn_impl: str,
attn_bias_type: str | None,
*,
eos_id: int | None = None,
tokens: torch.Tensor | None = None,
sliding_window: int | None = None,
):
if attn_impl == "sdpa":
BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))
if attn_bias_type == "causal":
return "causal"
if BLT_SUPPRESS_ATTN_ERROR == 1:
return "causal"
else:
raise ValueError(
"SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
)
elif attn_impl == "flex_attention":
return create_block_mask(causal_mask, None, None, seqlen, seqlen)
else:
raise NotImplementedError(f"Attention {attn_impl} with {sliding_window} sliding window not implemented")
def cross_entropy(pred, target, **kwargs):
    """Negative log-likelihood of ``target`` ids under ``pred`` logits, in float32.

    Every leading dimension of ``pred`` is merged into one batch axis (the last axis
    holds the class scores) and ``target`` is flattened to 1-D. Any extra keyword
    arguments (e.g. ``reduction``) are forwarded to ``F.nll_loss``.
    """
    logits_2d = pred.flatten(end_dim=-2).float()
    log_probs = F.log_softmax(logits_2d, -1)
    flat_target = target.flatten()
    return F.nll_loss(log_probs, flat_target, **kwargs)
def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along axis 2.

    Equivalent to ``torch.repeat_interleave(x, n_rep, dim=2)`` for a 4-D input of
    shape (batch, seq, n_kv_heads, head_dim); only ``dim == 2`` is supported.
    When ``n_rep == 1`` the input tensor itself is returned.
    """
    assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
    batch, seq, kv_heads, hd = x.shape
    if n_rep == 1:
        return x
    # Insert a repeat axis right after the head axis, broadcast it, then fold it in.
    tiled = x.unsqueeze(3).expand(batch, seq, kv_heads, n_rep, hd)
    return tiled.reshape(batch, seq, kv_heads * n_rep, hd)
def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    rope_use_fp32_in_outer_product: bool = False,
):
    """
    Precompute rotary-embedding (RoPE) rotation matrices for positions ``0 .. end-1``.

    Frequency ``k`` (for ``k < dim // 2``) is ``theta ** (-2k / dim)``. Position ``t``
    gets the angle ``t * freq_k``, stored as the 2x2 rotation matrix
    ``[[cos, -sin], [sin, cos]]``. The result is a real tensor, not a complex one.

    Args:
        dim (int): Head dimension; ``dim // 2`` frequencies are produced.
        end (int): Number of positions to precompute.
        theta (float, optional): RoPE base. Defaults to 10000.0.
        rope_use_fp32_in_outer_product (bool, optional): Cast the position indices to
            float32 before the outer product with the frequencies. Defaults to False.

    Returns:
        torch.Tensor: float32 tensor of shape ``(end, dim // 2, 2, 2)``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    if rope_use_fp32_in_outer_product:
        t = t.to(torch.float32)
    freqs = torch.outer(t, freqs).float()
    cos, sin = freqs.cos(), freqs.sin()
    # Row-major stacking makes the trailing (2, 2) block [[cos, -sin], [sin, cos]].
    return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
    """
    View the rotation table so it broadcasts against ``x`` element-wise.

    ``freqs_cis`` must be shaped ``(x.shape[seq_dim], x.shape[-3], 2, 2)``. In the
    returned view, every axis of ``x.shape[:-2]`` becomes 1 except ``seq_dim`` and
    ``x.ndim - 3``, which keep their sizes. A trailing ``(2, 2)`` block follows.

    Args:
        freqs_cis (torch.Tensor): Rotation table to view.
        x (torch.Tensor): Tensor the result must broadcast with.
        seq_dim (int): Index of the sequence axis in ``x``.

    Returns:
        torch.Tensor: A view of ``freqs_cis`` with the broadcastable shape.
    """
    rank = x.ndim
    assert 0 <= seq_dim < rank
    expected = (x.shape[seq_dim], x.shape[-3], 2, 2)
    assert freqs_cis.shape == expected, f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
    kept_axes = {seq_dim, rank - 3}
    target_shape = []
    for axis, size in enumerate(x.shape[:-2]):
        target_shape.append(size if axis in kept_axes else 1)
    target_shape.extend([2, 2])
    return freqs_cis.view(*target_shape)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    seq_dim: int,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate consecutive feature pairs of ``xq`` and ``xk`` by the RoPE matrices.

    Each tensor's last axis is split into pairs (B S H D -> B S H D/2 1 2). Every pair
    is multiplied by its 2x2 matrix from ``freqs_cis``: an element-wise product, then a
    sum over the last axis. The result is flattened back to B S H D and cast to the
    input dtype.
    """
    q_pairs, k_pairs = (t.reshape(*t.shape[:-1], -1, 1, 2) for t in (xq, xk))
    # S D/2 2 2 -> 1 S 1 D/2 2 2, shaped against the query layout.
    rotation = reshape_for_broadcast(freqs_cis, q_pairs, seq_dim).float()
    q_rot, k_rot = ((p * rotation).sum(dim=5).flatten(start_dim=3) for p in (q_pairs, k_pairs))
    return q_rot.type_as(xq), k_rot.type_as(xk)
# Rotary embedding as in xformers; check whether the torchtrain implementation is better. It might also be useful to make it work with batch*seqlen collapsed.
class RotaryEmbedding(torch.nn.Module):
    """
    Holds a precomputed table of rotary-embedding rotation matrices.

    The table is the non-persistent buffer ``freqs_cis`` (excluded from the state dict).
    It has shape ``(max_seqlen, head_dim // 2, 2, 2)`` and is built by
    ``precompute_freqs_cis``.
    """

    def __init__(
        self,
        theta: float,
        head_dim: int,
        max_seqlen: int = 1024,
        rope_use_fp32_in_outer_product: bool = False,
    ):
        super().__init__()
        self.theta = theta
        self.head_dim = head_dim
        self.max_seqlen = max_seqlen
        self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                dim=head_dim,
                end=max_seqlen,
                theta=theta,
                rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
            ),
            persistent=False,
        )

    def forward(self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None):
        """
        Return rotation matrices for a contiguous prefix of positions or for explicit positions.

        Args:
            seqlen (int, optional): Number of leading positions; rows ``0 .. seqlen-1``
                are returned. Values above ``max_seqlen`` are silently clipped by slicing.
            tok_idx (torch.Tensor[int], optional): Explicit position indices; takes
                precedence over ``seqlen`` when given.

        Returns:
            torch.Tensor: A single tensor (not a tuple), either ``freqs_cis[tok_idx]`` or
            ``freqs_cis[0:seqlen]``, with trailing dims ``(head_dim // 2, 2, 2)``.

        Raises:
            AssertionError: if neither ``seqlen`` nor ``tok_idx`` is provided.
        """
        has_positions = (seqlen is not None) or (tok_idx is not None)
        assert has_positions, "Should provide atleast seqlen or tok_idx"
        if tok_idx is not None:
            return self.freqs_cis[tok_idx]
        return self.freqs_cis[0:seqlen]
class BLTSelfAttention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int,
n_heads: int,
n_kv_heads: int,
rope_theta: float,
):
super().__init__()
self.dim = dim
self.head_dim = head_dim
self.rope_theta = rope_theta
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.heads_per_group = self.n_heads // self.n_kv_heads
self.wq = nn.Linear(
dim,
n_heads * head_dim,
bias=False,
)
self.wk = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
)
self.wv = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
)
self.wo = nn.Linear(
n_heads * head_dim,
dim,
bias=False,
)
def forward(
self,
x: torch.Tensor,
freq_cis: torch.Tensor,
tok_idx: Optional[torch.Tensor] = None,
mask: Optional[Union[BlockMask, str]] = None,
attn_impl: str = "sdpa",
) -> torch.Tensor:
# B S D
bsz, seq_len, dim = x.shape
xq = self.wq(x.view_as(x))
xk = self.wk(x.view_as(x))
xv = self.wv(x.view_as(x))
output_shape = xq.shape
# B S D -> B S H D
xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
# This condition helps us be easily compatible
# with inference by adding a pluggable KVCache
if hasattr(self, "kv_cache"):
xk, xv = self.kv_cache.update(xk, xv, tok_idx)
xk = repeat_kv(xk, self.heads_per_group, dim=2)
xv = repeat_kv(xv, self.heads_per_group, dim=2)
if attn_impl == "flex_attention":
assert mask is None or isinstance(mask, BlockMask)
xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
output = flex_attention_comp(xq, xk, xv, block_mask=mask)
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
elif attn_impl == "sdpa":
xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
assert mask is None or isinstance(mask, (str, torch.Tensor))
is_causal = (mask == "causal") if isinstance(mask, str) else False
mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None
output = F.scaled_dot_product_attention(
xq,
xk,
xv,
is_causal=is_causal,
attn_mask=mask,
)
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
else:
raise NotImplementedError(f"Attention implementation {attn_impl} not supported")
output_reshaped = output.reshape(output_shape)
output = self.wo(output_reshaped)
return output
class BLTMLP(nn.Module):
    """SwiGLU feed-forward block computing ``w2(silu(w1(x)) * w3(x))``.

    The inner width starts at ``int(2 * hidden_dim / 3)``. It is optionally scaled by
    ``ffn_dim_multiplier`` and then rounded up to a multiple of ``multiple_of``. The
    result must be divisible by ``mp_size``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        mp_size: int = 1,
    ):
        super().__init__()
        inner = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            inner = int(ffn_dim_multiplier * inner)
        # Round up to the next multiple of ``multiple_of``.
        blocks = (inner + multiple_of - 1) // multiple_of
        inner = multiple_of * blocks
        assert inner % mp_size == 0

        self.dim = dim
        self.hidden_dim = inner
        self.w1 = nn.Linear(dim, inner, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, inner, bias=False)  # up projection
        self.w2 = nn.Linear(inner, dim, bias=False)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward to ``x`` (..., dim) and return (..., dim)."""
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
class BLTTransformerLayer(nn.Module):
    """Pre-norm transformer block: ``h = x + attn(norm(x))`` then ``h + ffn(norm(h))``.

    ``args`` is a dict with the keys read below: dim, n_heads, head_dim, n_kv_heads,
    rope_theta, multiple_of, ffn_dim_multiplier, norm_eps. At least one of
    ``head_dim`` / ``n_heads`` must be given; the other is derived from ``dim``.
    """

    def __init__(self, args):
        super().__init__()
        # Extract parameters from dictionary
        dim = args["dim"]
        n_heads = args["n_heads"]
        head_dim = args["head_dim"]
        n_kv_heads = args["n_kv_heads"]
        rope_theta = args["rope_theta"]
        multiple_of = args["multiple_of"]
        ffn_dim_multiplier = args["ffn_dim_multiplier"]
        norm_eps = args["norm_eps"]

        assert (head_dim is not None) or (n_heads is not None), "Should specify at least head_dim or n_heads"
        self.head_dim = head_dim or dim // n_heads
        self.n_heads = n_heads or dim // head_dim
        self.n_kv_heads = n_kv_heads or self.n_heads
        # Validate the *resolved* head counts: the raw ``n_heads`` may be None when only
        # ``head_dim`` was given, and ``None % k`` would raise TypeError.
        assert self.n_heads % self.n_kv_heads == 0
        assert dim % self.n_heads == 0

        self.attention = BLTSelfAttention(
            dim=dim,
            head_dim=self.head_dim,
            n_heads=self.n_heads,
            n_kv_heads=self.n_kv_heads,
            rope_theta=rope_theta,
        )
        self.feed_forward = BLTMLP(
            dim=dim,
            hidden_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )
        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        freq_cis: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, str]] = None,
        attn_impl: str = "sdpa",
    ) -> torch.Tensor:
        """Run attention and feed-forward sub-blocks, each with a residual connection."""
        norm_x = self.attention_norm(x)
        attn_out = self.attention(
            norm_x,
            freq_cis,
            tok_idx=tok_idx,
            mask=mask,
            attn_impl=attn_impl,
        )
        h = x + attn_out
        h_norm = self.ffn_norm(h)
        out = h + self.feed_forward(h_norm)
        return out
def check_non_zero_after_zero(tensor):
    """Return a 0-d bool tensor: does any row contain a non-zero right after a zero?

    ``tensor`` is indexed as (rows, columns); column 0 has no predecessor and never counts.
    """
    is_zero = tensor == 0
    first_column = torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device)
    # prev_is_zero[:, j] tells whether column j - 1 held a zero.
    prev_is_zero = torch.cat((first_column, is_zero[:, :-1]), dim=1)
    return (prev_is_zero & ~is_zero).any()
def rolling_polynomial_hash(t, hash_func_nb: int = 0):
    """Polynomial hash over the last axis of ``t``: ``sum_i t[..., i] * p**i``.

    ``p`` is chosen from a fixed list of primes by ``hash_func_nb``. The arithmetic is
    done in int64 tensors, so large powers wrap around.
    """
    primes = (
        1000000007,
        5915587277,
        1500450271,
        3267000013,
        5754853343,
        4093082899,
        9576890767,
        3628273133,
        2860486313,
        5463458053,
        3367900313,
    )
    base = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
    powers = torch.stack([base**exponent for exponent in range(t.shape[-1])])
    weighted = t * powers
    return weighted.sum(dim=-1)
def byte_group_hash_function(x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000):
    """
    Hash the trailing ``group_size``-id window at every position into ``[0, max_hash)``.

    For position ``j`` the window is ``x[:, j - group_size + 1 : j + 1]``. Positions
    before the start of a row are padded with id 0. Each window is hashed with
    ``rolling_polynomial_hash`` and reduced modulo ``max_hash``. Tensor ``%`` follows the
    sign of the divisor, so results are non-negative for a positive ``max_hash``.

    Args:
        x: (batch_size, seq_len) tensor of integer token ids.
        group_size: Window length (number of ids hashed together).
        hash_func_nb: Selects the prime used by the polynomial hash.
        max_hash: Size of the output range. It strongly affects the number of collisions.

    Returns:
        (batch_size, seq_len) tensor with values in ``[0, max_hash)``, not requiring grad.
    """
    with torch.no_grad():
        bs, _ = x.shape
        prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
        padded = torch.cat([prefix, x], dim=1)
        windows = padded.unfold(1, group_size, 1)
        hashes = rolling_polynomial_hash(windows, hash_func_nb)
        hash_values_range = hashes % max_hash
    hash_values_range.requires_grad = False
    return hash_values_range
def create_patch_mask_from_ids(patch_ids, num_patches, window=None, patches_as_queries=False):
    """
    Build a boolean byte-to-patch attention mask from per-position patch ids.

    Let ``q`` be the query-side id and ``kv`` the key-side id. For byte queries these
    are the position's patch id and the patch index ``k``; for patch queries the roles
    are swapped. An entry is True when:

    * ``window is None``: ``q == kv`` (a byte and the patch it belongs to);
    * otherwise: ``kv <= q < kv + window``.

    Args:
        patch_ids (torch.Tensor): (bs, seq_len) patch id of every position.
        num_patches (int): Number of patches (size of the patch axis).
        window (int, optional): Width of the id window described above.
        patches_as_queries (bool): If True, patches are queries and positions are keys.

    Returns:
        torch.Tensor: Bool tensor of shape (bs, seq_len, num_patches), or
        (bs, num_patches, seq_len) when ``patches_as_queries`` is True.
    """
    bs, seq_len = patch_ids.shape
    patch_range = torch.arange(num_patches, device=patch_ids.device)
    if not patches_as_queries:
        q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
        kv_ids = patch_range.unsqueeze(0).unsqueeze(0).expand(bs, seq_len, num_patches)
    else:
        kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
        q_ids = patch_range.unsqueeze(0).unsqueeze(-1).expand(bs, num_patches, seq_len)
    if window is None:
        return q_ids == kv_ids
    return (kv_ids <= q_ids) & (q_ids < kv_ids + window)
def cross_attn_mask(
    patch_ids,
    patch_lengths,
    N,
    patches_as_queries=False,
    cross_attn_k=1,
    window=None,
    block_mask=True,
):
    """Build the mask for cross-attention between byte positions and patches.

    Args:
        patch_ids: (batch, seq_len) patch id of every byte position.
        patch_lengths: (batch, num_patches) tensor; only its second dimension is used.
        N: Number of byte positions (byte-side length of the mask).
        patches_as_queries: If True, patches are queries and bytes are keys;
            otherwise bytes are queries and patches are keys.
        cross_attn_k: Each patch is repeated this many times along the patch axis.
        window: Optional id window forwarded to ``create_patch_mask_from_ids``.
        block_mask: If True, return a flex-attention ``BlockMask``; otherwise a dense mask.

    Returns:
        A ``BlockMask`` over (batch, q_len, kv_len) built on the same device as
        ``patch_ids``. Otherwise, a float tensor of shape (batch, 1, q_len, kv_len) with
        0.0 where attention is allowed and -inf elsewhere.
    """
    bs = patch_ids.shape[0]
    num_patches = patch_lengths.shape[1]
    with torch.no_grad():
        # Create the patch mask
        cross_mask = create_patch_mask_from_ids(
            patch_ids,
            num_patches,
            window=window,
            patches_as_queries=patches_as_queries,
        ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
        q_len = num_patches * cross_attn_k if patches_as_queries else N
        kv_len = N if patches_as_queries else num_patches * cross_attn_k
        assert cross_mask.shape == (
            bs,
            q_len,
            kv_len,
        ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"

        if block_mask:

            def patch_mask(b, h, q_idx, kv_idx):
                return cross_mask[b, q_idx, kv_idx]

            # create_block_mask defaults to CUDA; build it where the mask lives instead.
            return create_block_mask(
                patch_mask,
                B=bs,
                H=None,
                Q_LEN=q_len,
                KV_LEN=kv_len,
                device=cross_mask.device,
                _compile=True,
            )

        allowed = torch.tensor(0.0, device=cross_mask.device)
        blocked = torch.tensor(float("-inf"), device=cross_mask.device)
        return torch.where(cross_mask, allowed, blocked).unsqueeze(1)  # [bs, 1, q_len, kv_len]
def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor:
    """Split every patch longer than ``max_patch_length`` into consecutive chunks.

    Only positive lengths of each row are kept. Each length ``L`` becomes
    ``max_patch_length`` repeated ``L // max_patch_length`` times, followed by the
    remainder if it is non-zero. Rows are then right-padded with zeros to a common width.

    Args:
        patch_lengths: 2-D tensor (batch, num_patches) of patch lengths.
        max_patch_length: Maximum allowed patch length; ``None`` disables splitting.

    Returns:
        A tensor with the dtype and device of ``patch_lengths``. The input object itself
        is returned when ``max_patch_length`` is None. If no row has a positive length,
        the result has shape (batch, 0).
    """
    if max_patch_length is None:
        return patch_lengths

    batch_size = patch_lengths.size(0)
    split_all = []
    for seq in patch_lengths:
        splits = []
        for length in seq[seq > 0].tolist():
            # Split long patches into max_patch_length chunks
            full, rem = divmod(length, max_patch_length)
            splits.extend([max_patch_length] * full)
            if rem:
                splits.append(rem)
        split_all.append(splits)
    max_len = max((len(splits) for splits in split_all), default=0)

    # Pad sequences to the maximum length
    padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
    for i, splits in enumerate(split_all):
        if splits:
            padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)

    # Trim trailing columns that are all zeros. Every split is positive and the widest
    # row fills all columns, so this normally trims nothing. The numel guard avoids
    # argmax over an empty dimension, which raises.
    if padded.numel() > 0:
        trailing_zeros = int((padded != 0).flip(1).int().argmax(1).min())
        if trailing_zeros > 0:
            padded = padded[:, : padded.shape[1] - trailing_zeros]
    return padded
class BLTLocalModelBase(nn.Module):
    """Shared constructor logic for the BLT local encoder and local decoder.

    ``component_type`` ("encoder" or "decoder") selects which ``config`` fields size the
    module. Builds ``self.layers`` (a stack of ``BLTTransformerLayer``), the positional
    component (``self.rope`` or ``self.pos_embeddings``), and the optional
    ``token_embedding_projection`` / ``patch_embedding_projection`` linears.

    Raises:
        ValueError: for an unknown ``component_type``.
    """

    def __init__(self, config: BLTConfig, component_type: str = "encoder"):
        super().__init__()
        self.config = config
        if component_type == "encoder":
            self.dim = config.dim_local_encoder
            self.n_layers = config.n_layers_local_encoder
            self.n_heads = config.n_heads_local_encoder
            self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen
            self.attn_bias_type = "local_block_causal"
            self.sliding_window = config.local_attention_window_len
        elif component_type == "decoder":
            self.dim = config.dim_local_decoder
            self.n_layers = config.n_layers_local_decoder
            self.n_heads = config.n_heads_local_decoder
            # NOTE(review): the decoder also sizes its rotary table from
            # max_encoder_seq_length; confirm no decoder-specific limit is intended.
            self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen
            self.attn_bias_type = "local_block_causal"
            self.sliding_window = config.local_attention_window_len
        else:
            raise ValueError(f"Unknown component_type: {component_type}")
        self.dropout = config.dropout
        # Embedding vocabulary = config.vocab_size plus config.pm_size extra ids
        # (meaning of pm_size is defined by BLTConfig — not visible here).
        self.vocab_size = config.vocab_size + config.pm_size
        self.patch_size = config.patch_size
        self.attn_impl = config.attn_impl
        self.use_rope = config.use_rope
        self.init_std_factor = config.init_std_factor
        self.init_base_std = config.init_base_std
        self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None)
        self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None)
        self.cross_attn_k = getattr(config, "cross_attn_k", None)
        self.eos_id = config.eos_token_id
        self.boe_id = config.boe_id
        # Initialize cross attention layers as None (will be set by subclasses if needed)
        self.cross_attn_layers = None
        # Create parameter dict for BLTTransformerLayers
        layer_params = {
            "dim": self.dim,
            "n_heads": self.n_heads,
            "head_dim": config.head_dim,
            "n_kv_heads": getattr(config, "n_kv_heads", None),
            "rope_theta": config.rope_theta,
            "multiple_of": getattr(config, "multiple_of", 256),
            "ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None),
            "norm_eps": config.norm_eps,
        }
        self.layers = nn.ModuleList([BLTTransformerLayer(layer_params) for _ in range(self.n_layers)])
        if not self.use_rope:
            # NOTE(review): learned positions are capped at a hard-coded 2048 and are not
            # applied by the encoder/decoder forward passes in this file (they pass
            # freqs_cis=None to the layers when use_rope is False) — confirm this path is used.
            self.pos_embeddings = nn.Embedding(2048, self.dim)  # fallback max_length
        else:
            self.rope = RotaryEmbedding(
                theta=config.rope_theta,
                head_dim=config.head_dim or self.dim // self.n_heads,
                max_seqlen=self.max_seqlen,
                rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product,
            )
            self.pos_embeddings = None
        # Set dimension-specific embedding dimensions
        if component_type == "encoder":
            self.dim_token_emb = config.encoder_dim_token_emb
            self.dim_patch_emb = config.encoder_dim_patch_emb
        elif component_type == "decoder":
            self.dim_token_emb = config.decoder_dim_token_emb
            self.dim_patch_emb = config.dim_global
        # Project token embeddings to the model width only when their size differs.
        self.token_embedding_projection = (
            nn.Linear(self.dim_token_emb, self.dim, bias=False)
            if self.dim_token_emb is not None and self.dim_token_emb != self.dim
            else None
        )
        self.patch_embedding_projection = self._create_patch_projection(config)

    def _should_create_patch_projection(self, config: BLTConfig):
        """Return True when a patch projection is needed.

        That is, when the patch embedding width differs from ``self.dim``, or when
        cross-attention (encoder or decoder) is initialised by pooling.
        """
        dimension_mismatch = self.dim_patch_emb is not None and self.dim_patch_emb != self.dim
        # Check cross attention conditions
        cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or (
            config.cross_attn_decoder and config.cross_attn_init_by_pooling
        )
        return dimension_mismatch or cross_attn_conditions

    def _create_patch_projection(self, config):
        """Build ``Linear(dim_patch_emb -> dim_token_emb * (cross_attn_k or 1))``, or return None if not needed."""
        if not self._should_create_patch_projection(config):
            return None
        output_dim = self.dim_token_emb * (self.cross_attn_k or 1)
        return nn.Linear(
            in_features=self.dim_patch_emb,
            out_features=output_dim,
            bias=False,
        )

    def apply_embedding(self, tokens, embeds):
        """Return ``embeds`` when given, otherwise look ``tokens`` up in ``self.tok_embeddings``.

        ``tok_embeddings`` is not created by this base class; subclasses must define it
        before this method is used.
        """
        if embeds is not None:
            return embeds
        else:
            return self.tok_embeddings(tokens)
class BLTLocalEncoder(BLTLocalModelBase):
    """Byte-level local encoder of BLT.

    Embeds byte ids (or accepts precomputed embeddings) and runs the local transformer
    layers. When ``config.cross_attn_encoder`` is set, it also builds patch embeddings
    that cross-attend to the byte states.
    """

    def __init__(self, config: BLTConfig):
        super().__init__(config, component_type="encoder")
        self.apply_transformer = config.use_local_encoder_transformer
        self.downsampling_by_pooling = config.downsampling_by_pooling
        self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None
        self.cross_attn_encoder = config.cross_attn_encoder
        self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder
        self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling
        self.cross_attn_nheads = config.cross_attn_nheads
        self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim)
        if self.cross_attn_encoder:
            self.cross_attn_layers = torch.nn.ModuleList()
            # One module per local layer, or a single module used after the last layer.
            layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1
            for _ in range(layers_to_add):
                self.cross_attn_layers.append(
                    BLTCrossAttention(
                        dim=self.dim,
                        head_dim=self.dim // self.cross_attn_nheads,
                        n_heads=self.cross_attn_nheads,
                        n_kv_heads=self.cross_attn_nheads,
                        norm_eps=config.norm_eps,
                    )
                )

    def apply_embedding(self, tokens, embeds):
        """Return ``embeds`` if given (only allowed with hash embeddings), else embed ``tokens``."""
        if embeds is not None:
            assert self.expects_hash_embeddings, "Not expecting embeddings to be passed."
            return embeds
        return self.tok_embeddings(tokens)

    def forward(
        self,
        tokens: torch.Tensor,
        embeds: Optional[torch.Tensor] = None,
        patch_embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
        cross_mask: Optional[torch.Tensor] = None,
        num_patches: Optional[int] = None,
        patch_ids: Optional[torch.Tensor] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
        """Encode the (batch, seq_len) ``tokens``.

        Returns:
            ``((h, h_residual), cache)``. ``h`` holds the byte states after the local
            layers. ``h_residual`` holds the cross-attended patch embeddings, or None when
            encoder cross-attention is disabled. ``cache`` is passed through unchanged.
        """
        bs, seqlen = tokens.shape
        if mask is None:
            mask = create_causal_mask(
                seqlen,
                self.attn_impl,
                "local_block_causal",
                sliding_window=self.sliding_window,
                tokens=tokens,
                eos_id=self.eos_id,
            )

        h = self.apply_embedding(tokens, embeds)
        freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
        h = F.dropout(h, p=self.dropout, training=self.training)

        for i, layer in enumerate(self.layers):
            h = layer(h, freqs_cis, tok_idx=None, mask=mask, attn_impl=self.attn_impl)
            # Cross-attend after every layer, or only after the last one.
            if self.cross_attn_encoder and (i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder):
                # Initialise patch embeddings once by pooling byte states, then project.
                if self.cross_attn_init_by_pooling and patch_embeds is None:
                    patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids)
                    if self.patch_embedding_projection is not None:
                        patch_embeds = self.patch_embedding_projection(patch_embeds)
                        # The projection emits (cross_attn_k or 1) slots per patch
                        # (see _create_patch_projection); unfold them along the patch axis.
                        k = self.cross_attn_k or 1
                        patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * k, self.dim)
                layer_idx = i if self.cross_attn_all_layers_encoder else 0
                patch_embeds_cross = self.cross_attn_layers[layer_idx](
                    x=patch_embeds,
                    kv=h,
                    mask=cross_mask,
                )
                patch_embeds = patch_embeds + patch_embeds_cross

        h_residual = patch_embeds if self.cross_attn_encoder else None
        return (h, h_residual), cache

    def patch_reduce(self, h, max_num_patches, reduction, patch_ids):
        """
        Pool byte-level states into one embedding per patch.

        Positions are grouped by ``patch_ids`` and combined with ``Tensor.scatter_reduce``
        using ``reduction`` (e.g. ``"amax"``). ``include_self=False`` keeps the zero
        initial buffer out of the reduction; patches that receive no position stay zero.

        Args:
            h: (batch, seq_len, emb_dim) byte states.
            max_num_patches: Number of output patches per sequence.
            reduction: ``scatter_reduce`` reduction name.
            patch_ids: (batch, seq_len) integer patch index of each position. Every id
                must lie in ``[0, max_num_patches)``; there is no overflow/dummy patch.

        Returns:
            (batch, max_num_patches, emb_dim) tensor with ``h``'s dtype and device.
        """
        bs, _, emb_dim = h.shape
        index = patch_ids.unsqueeze(-1).expand(-1, -1, emb_dim)
        reduced_embs = torch.zeros((bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device)
        return reduced_embs.scatter_reduce(
            src=h,
            dim=1,
            index=index,
            reduce=reduction,
            include_self=False,
        )
class BLTLocalDecoder(BLTLocalModelBase):
    """Byte-level local decoder of BLT.

    Mixes patch states from the global model into byte embeddings, either by addition or
    by cross-attention. It then runs the local layers and returns float32 logits over
    ``config.vocab_size``.
    """

    def __init__(self, config: BLTConfig):
        super().__init__(config, component_type="decoder")
        # Flags controlling how patch information is injected into the byte stream.
        self.cross_attn_decoder = config.cross_attn_decoder
        self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder
        self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling
        self.cross_attn_nheads = config.cross_attn_nheads
        self.norm = RMSNorm(self.dim, eps=config.norm_eps)
        if self.cross_attn_decoder:
            # One module per local layer, or a single module used before layer 0.
            n_cross = self.n_layers if self.cross_attn_all_layers_decoder else 1
            cross_head_dim = self.dim // self.cross_attn_nheads
            self.cross_attn_layers = torch.nn.ModuleList(
                BLTCrossAttention(
                    dim=self.dim,
                    head_dim=cross_head_dim,
                    n_heads=self.cross_attn_nheads,
                    n_kv_heads=self.cross_attn_nheads,
                    norm_eps=config.norm_eps,
                )
                for _ in range(n_cross)
            )
        self.output = nn.Linear(self.dim, config.vocab_size, bias=False)

    def _project_patch_embeds(self, patch_embeds, bs):
        """Apply the optional patch projection and unfold ``cross_attn_k`` slots per patch."""
        if self.patch_embedding_projection is None:
            return patch_embeds
        assert patch_embeds is not None, "Patch embeddings must be passed."
        projected = self.patch_embedding_projection(patch_embeds)
        if self.cross_attn_k is None:
            return projected
        return projected.reshape(bs, projected.shape[1] * self.cross_attn_k, self.dim)

    def forward(
        self,
        tokens: torch.Tensor,
        embeds: Optional[torch.Tensor],
        patch_embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union["BlockMask", torch.Tensor, str]] = None,
        cross_mask: Optional[torch.Tensor] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
        """Decode byte states ``embeds`` into logits.

        ``tokens`` (batch, seq_len) only supplies shape and device for the mask.
        Returns ``(logits.float(), cache)``, with ``cache`` passed through unchanged.
        """
        bs, seqlen = tokens.shape
        assert embeds is not None, "Embeddings must be provided"
        if mask is None:
            mask = create_causal_mask(
                seqlen,
                self.attn_impl,
                "local_block_causal",
                sliding_window=self.sliding_window,
                tokens=tokens,
                eos_id=self.eos_id,
            )

        h = embeds
        patch_embeds = self._project_patch_embeds(patch_embeds, bs)
        # Without cross-attention, patch information is simply added to the byte stream.
        if patch_embeds is not None and not self.cross_attn_decoder:
            h = h + patch_embeds

        freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
        h = F.dropout(h, p=self.dropout, training=self.training)

        for i, layer in enumerate(self.layers):
            if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder):
                # Bytes (queries) read from patch states (keys/values).
                # NOTE(review): BLTCrossAttention in this file already returns x + output,
                # so h is added twice here; confirm this matches the trained checkpoints.
                h = h + self.cross_attn_layers[i](x=h, kv=patch_embeds, mask=cross_mask)
            h = layer(h, freqs_cis, tok_idx=None, mask=mask, attn_impl=self.attn_impl)

        normed = F.dropout(self.norm(h), p=self.dropout, training=self.training)
        logits = self.output(normed).float()
        return logits, cache
class BLTCrossAttention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int,
n_heads: int,
n_kv_heads: int,
norm_eps: float,
):
super().__init__()
self.dim = dim
self.head_dim = head_dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.heads_per_group = self.n_heads // self.n_kv_heads
self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
self.wq = nn.Linear(
dim,
n_heads * head_dim,
bias=False,
)
self.wk = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
)
self.wv = nn.Linear(
dim,
n_kv_heads * head_dim,
bias=False,
)
self.wo = nn.Linear(
n_heads * head_dim,
dim,
bias=False,
)
def forward(
self,
x: torch.Tensor,
kv: torch.Tensor,
mask: Optional[Union[BlockMask, str]] = None,
) -> torch.Tensor:
# B S D
bsz, seq_len, _ = x.shape
_, slen_kv, _ = kv.shape
x_norm = self.cross_attn_norm_q(x)
kv = self.cross_attn_norm_kv(kv)
xq = self.wq(x_norm)
xk = self.wk(kv)
xv = self.wv(kv)
output_shape = xq.shape
# B S D -> B S H D
xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
xk = repeat_kv(xk, self.heads_per_group, dim=2)
xv = repeat_kv(xv, self.heads_per_group, dim=2)
# assert mask is None or isinstance(mask, BlockMask)
xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv))
# output = flex_attention_comp(xq, xk, xv, block_mask=mask)
is_causal = (mask == "causal") if isinstance(mask, str) else False
mask = mask if isinstance(mask, torch.Tensor) else None
mask = mask.to(dtype=xq.dtype).to(xq.device)
output = F.scaled_dot_product_attention(
xq,
xk,
xv,
is_causal=is_causal,
attn_mask=mask,
)
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
output = self.wo(output.reshape(output_shape))
return x + output
class BLTGlobalTransformer(nn.Module):
    """Patch-level ("global") transformer of BLT.

    Takes patch embeddings (``embeds``) and optionally projects them from
    ``config.global_dim_patch_emb`` to ``config.dim_global``. It then runs
    ``config.n_layers_global`` ``BLTTransformerLayer`` blocks with rotary positions.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dim = config.dim_global
        rope_head_dim = config.head_dim or self.config.dim_global // config.n_heads_global
        self.rope_embeddings = RotaryEmbedding(
            theta=config.rope_theta,
            head_dim=rope_head_dim,
            max_seqlen=config.max_seqlen,
            rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product,
        )
        # Prefer ``eos_id``, then ``eos_token_id``, then 2.
        fallback_eos = getattr(config, "eos_token_id", 2)
        self.eos_id = getattr(config, "eos_id", fallback_eos)

        layer_params = dict(
            dim=self.dim,
            n_heads=config.n_heads_global,
            head_dim=config.head_dim,
            n_kv_heads=getattr(config, "n_kv_heads_global", None),
            rope_theta=config.rope_theta,
            multiple_of=getattr(config, "multiple_of", 256),
            ffn_dim_multiplier=getattr(config, "ffn_dim_multiplier", None),
            norm_eps=config.norm_eps,
        )
        self.layers = nn.ModuleList(BLTTransformerLayer(layer_params) for _ in range(config.n_layers_global))

        # Project incoming patch embeddings only when their width differs from dim_global.
        needs_projection = config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim
        self.token_embedding_projection = (
            nn.Linear(config.global_dim_patch_emb, config.dim_global, bias=False) if needs_projection else None
        )

    def forward(
        self,
        tokens: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, torch.Tensor, str]] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
        """Run the global layers over ``embeds``.

        ``tokens`` (batch, seq_len) supplies the sequence length for the default mask.
        Returns ``(h, cache)``, with ``cache`` passed through unchanged.
        """
        _, seqlen = tokens.shape
        h = embeds
        if mask is None:
            mask = create_causal_mask(
                seqlen,
                self.config.attn_impl,
                self.config.attn_bias_type,
                tokens=tokens,
                eos_id=self.eos_id,
            )

        if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
            h = self.token_embedding_projection(h)

        h = F.dropout(h, p=self.config.dropout, training=self.training)
        # The full table is fetched here; each attention layer slices it to its length.
        freq_cis = self.rope_embeddings(seqlen=self.config.max_seqlen, tok_idx=tok_idx)

        for layer in self.layers:
            h = layer(h, freq_cis, tok_idx=None, mask=mask, attn_impl=self.config.attn_impl)
        return h, cache
def compute_hash_embeddings(
    local_encoder_tokens: torch.Tensor,
    local_encoder,
    encoder_hash_tok_embedding: nn.ModuleList,
    encoder_hash_byte_group_nb_functions: int,
    encoder_hash_byte_group_size: list,
    encoder_hash_byte_group_vocab: int,
) -> torch.Tensor:
    """
    Sum plain token embeddings with one hashed n-gram embedding per (hash function, group size).

    Args:
        local_encoder_tokens: Input token ids.
        local_encoder: Object exposing a ``tok_embeddings`` module.
        encoder_hash_tok_embedding: One embedding table per (function, size) pair, in
            function-major order; ``None`` disables hashing.
        encoder_hash_byte_group_nb_functions: Number of hash functions.
        encoder_hash_byte_group_size: Byte group sizes to hash.
        encoder_hash_byte_group_vocab: Hash range (size of each hash table).

    Returns:
        The combined embeddings, or None when ``encoder_hash_tok_embedding`` is None.

    Raises:
        AssertionError: if the number of tables does not match the number of pairs.
    """
    if encoder_hash_tok_embedding is None:
        return None

    combined = local_encoder.tok_embeddings(local_encoder_tokens)
    pairs = [
        (func_nb, group_size)
        for func_nb in range(encoder_hash_byte_group_nb_functions)
        for group_size in encoder_hash_byte_group_size
    ]
    for table_idx, (func_nb, group_size) in enumerate(pairs):
        hash_ids = byte_group_hash_function(
            local_encoder_tokens,
            group_size,
            hash_func_nb=func_nb,
            max_hash=encoder_hash_byte_group_vocab,
        )
        combined = combined + encoder_hash_tok_embedding[table_idx](hash_ids)

    assert len(pairs) == len(encoder_hash_tok_embedding)
    return combined
class BLTPreTrainedModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` base for BLT: config binding and weight initialisation."""

    config_class = BLTConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = False  # BLT uses its own attention implementation
    _supports_sdpa = True
    _supports_cache_class = False

    @staticmethod
    def _blt_trunc_normal_(weight: torch.Tensor, std: float) -> None:
        """Fill ``weight`` from N(0, std^2) truncated to [-3 * std, 3 * std]."""
        nn.init.trunc_normal_(
            weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )

    def _init_weights(self, module):
        """Initialise a single module.

        Linear and Embedding leaves use ``module._custom_std`` when it is set, otherwise
        ``in_features ** -0.5`` / ``embedding_dim ** -0.5``. Container branches (BLTModel,
        the local encoder/decoder, the global transformer and the patcher) only record
        ``_custom_std`` on their children.

        NOTE(review): if the caller visits children before their parent (as
        ``nn.Module.apply`` does), a ``_custom_std`` recorded here only takes effect on a
        later (re-)initialisation of those children — confirm the intended order.
        """
        if isinstance(module, nn.Linear):
            std = getattr(module, "_custom_std", module.in_features ** (-0.5))
            self._blt_trunc_normal_(module.weight, std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            std = getattr(module, "_custom_std", module.embedding_dim ** (-0.5))
            self._blt_trunc_normal_(module.weight, std)

        elif isinstance(module, (nn.RMSNorm, nn.LayerNorm)):
            # nn.RMSNorm has no ``bias`` attribute at all, and either norm has
            # ``weight is None`` when elementwise_affine=False, so probe both.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if getattr(module, "bias", None) is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, RotaryEmbedding):
            module.freqs_cis[...] = precompute_freqs_cis(
                dim=module.head_dim,
                end=module.max_seqlen,
                theta=module.theta,
                rope_use_fp32_in_outer_product=module.rope_use_fp32_in_outer_product,
            )

        elif isinstance(module, BLTModel):
            if module.encoder_hash_tok_embedding is not None:
                emb_std = module.local_encoder.dim ** (-0.5)
                for emb in module.encoder_hash_tok_embedding:
                    emb._custom_std = emb_std

        elif isinstance(module, (BLTLocalEncoder, BLTLocalDecoder)):
            if module.token_embedding_projection is not None:
                module.token_embedding_projection._custom_std = module.dim ** (-0.5)
            if module.patch_embedding_projection is not None:
                module.patch_embedding_projection._custom_std = module.dim_patch_emb ** (-0.5)

        elif isinstance(module, BLTGlobalTransformer):
            if module.token_embedding_projection is not None:
                # BLTGlobalTransformer defines no ``dim_token_emb``; the projection's input
                # width (config.global_dim_patch_emb) is the equivalent fan-in.
                projection = module.token_embedding_projection
                projection._custom_std = projection.in_features ** (-0.5)

        elif isinstance(module, BLTPatcher):
            emb_std = module.config.patcher_dim ** (-0.5)
            module.tok_embeddings._custom_std = emb_std
            module.output._custom_std = emb_std
class BLTModel(BLTPreTrainedModel):
    """Byte Latent Transformer: local encoder -> global patch-level transformer -> local decoder.

    Bytes are grouped into variable-length patches, either supplied by the caller or computed in
    ``forward`` (optionally with the frozen entropy ``BLTPatcher``). The local encoder produces
    per-byte states plus one pooled state per patch, the global transformer runs over the patch
    states, and the local decoder combines the per-byte encoder states with the patch states.
    """

    def __init__(self, config: BLTConfig):
        super().__init__(config)
        self.config = config
        self.local_encoder = BLTLocalEncoder(config)
        self.global_transformer = BLTGlobalTransformer(config)
        self.local_decoder = BLTLocalDecoder(config)
        self.encoder_hash_tok_embedding = init_hash_embeddings(
            config,
            local_encoder_dim=self.local_encoder.dim,
            encoder_hash_byte_group_size=config.encoder_hash_byte_group_size,
        )
        if config.patch_in_forward:
            # The patcher is an auxiliary, frozen model: eval mode and no gradients.
            self.patcher = BLTPatcher(config)
            self.patcher.eval()
            for param in self.patcher.parameters():
                param.requires_grad = False
        else:
            self.patcher = None

    def forward(
        self,
        tokens: torch.Tensor,
        patch_lengths: Optional[torch.Tensor] = None,
    ):
        """Run the full BLT stack on a batch of byte token ids.

        Args:
            tokens: integer tensor of shape ``(batch, seq_len)``.
            patch_lengths: optional ``(batch, num_patches)`` tensor of patch sizes. When ``None``
                the lengths are computed here: with ``self.patcher`` if ``config.patching_mode`` is
                ``PatchingModeEnum.entropy``, otherwise one patch per byte plus one for the next
                token.

        Returns:
            The first output of ``self.local_decoder``.

        Raises:
            ValueError: if entropy patching is requested, no ``patch_lengths`` were passed and the
                model was built without a patcher (``config.patch_in_forward`` is false).
        """
        # NOTE: frequency-based n-gram embeddings are not used; only hash n-gram embeddings are.
        bs, N = tokens.shape  # batch size, sequence length
        local_encoder_tokens, local_decoder_tokens = tokens, tokens

        # Patching
        if patch_lengths is None:
            if self.config.patching_mode == PatchingModeEnum.entropy:
                if self.patcher is None:
                    raise ValueError(
                        "Entropy patching requires a patcher (set `patch_in_forward=True` in the config) "
                        "or explicit `patch_lengths`."
                    )
                _, patch_lengths, _ = self.patcher(
                    local_encoder_tokens,
                    patch_size=self.config.patch_size,
                    include_next_token=True,
                    threshold=self.config.patching_threshold,
                    max_patch_length=self.config.max_patch_length,
                    patching_batch_size=self.config.patching_batch_size,
                    device=self.config.patching_device,
                )
                # The patcher may run on `config.patching_device`; bring the result next to the
                # tokens so the indexing below never mixes devices (no-op when already there).
                patch_lengths = patch_lengths.to(local_encoder_tokens.device)
            else:
                # Byte patching: every byte is its own patch (include_next_token=True adds one).
                seq_len_next_tok = N + 1
                patch_lengths = torch.ones(
                    (bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device
                )
                patch_lengths = process_patch_lengths(patch_lengths, self.config.max_patch_length)

        num_patches = patch_lengths.shape[1]
        # Index of the patch containing each byte position: (batch, seq_len).
        patch_ids = self._patch_ids_from_lengths(patch_lengths, N)

        # Cross-attention encoder mask (patches are the queries).
        cross_attn_mask_enc = None
        if self.config.cross_attn_encoder:
            cross_attn_mask_enc = cross_attn_mask(
                patch_ids,
                patch_lengths,
                N,
                patches_as_queries=True,
                cross_attn_k=self.config.cross_attn_k,
                window=self.config.cross_attn_window_encoder,
                block_mask=self.config.cross_attn_use_flex_attention,
            )

        # Hash n-gram embeddings of the byte sequence.
        local_encoder_embeds = compute_hash_embeddings(
            local_encoder_tokens=local_encoder_tokens,
            local_encoder=self.local_encoder,
            encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
            encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions,
            encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size,
            encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab,
        )

        # Local encoder: per-byte states (h_encoder) and per-patch states (h_cross).
        (h_encoder, h_cross), _ = self.local_encoder(
            tokens=local_encoder_tokens,
            embeds=local_encoder_embeds,
            patch_embeds=None,
            cross_mask=cross_attn_mask_enc,
            num_patches=num_patches,
            patch_ids=patch_ids,
        )

        # Downsampling: one vector per patch.
        h = h_cross.view(bs, num_patches, -1)

        # Global transformer. Each patch gets token `boe_id`, except patches that contain an EOS
        # byte, which get `eos_token_id`.
        global_tokens = torch.full(
            (h.shape[0], h.shape[1]),
            self.config.boe_id,
            dtype=tokens.dtype,
            device=tokens.device,
        )
        rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id)
        global_tokens[rows, patch_ids[rows, cols]] = self.config.eos_token_id
        h, _ = self.global_transformer(
            embeds=h,
            tokens=global_tokens,
        )

        # Unpatching
        dec_embeds = h_encoder
        # Decoder patch ids are computed from `patch_lengths[:, 1:]` (first patch dropped), which
        # shifts the byte -> patch assignment relative to `patch_ids` above.
        decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], local_decoder_tokens.shape[-1])

        if not self.config.cross_attn_decoder:
            # Copy each patch's global state to every byte position assigned to it.
            h = torch.gather(h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]))
            cross_attn_mask_dec = None
        else:
            cross_attn_mask_dec = cross_attn_mask(
                decoder_patch_ids,
                patch_lengths,
                N,
                patches_as_queries=False,
                cross_attn_k=self.config.cross_attn_k,
                window=self.config.cross_attn_window_decoder,
                block_mask=self.config.cross_attn_use_flex_attention,
            )

        # Local decoder
        output, _ = self.local_decoder(
            embeds=dec_embeds,
            patch_embeds=h,
            tokens=local_decoder_tokens,
            cross_mask=cross_attn_mask_dec,
        )
        return output

    def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
        """
        Convert patch lengths to patch IDs for each token position.

        Args:
            patch_lengths: [batch_size, num_patches] - length of each patch
            seq_len: total sequence length

        Returns:
            patch_ids: [batch_size, seq_len] - index of the patch containing each token position.
            Positions at or past the total length are assigned the last patch whose start is
            <= that position.

        Example:
            patch_lengths = [[3, 2, 4, 1]]  # 4 patches of lengths 3,2,4,1
            seq_len = 10
            Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]]
            # pos 0-2 -> patch 0, pos 3-4 -> patch 1, pos 5-8 -> patch 2, pos 9 -> patch 3
        """
        batch_size, num_patches = patch_lengths.shape
        # Patch start positions: [0, 3, 5, 9] for the example above.
        patch_starts = torch.cat(
            [
                torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
                patch_lengths.cumsum(dim=-1)[:, :-1],  # cumsum without the final total
            ],
            dim=-1,
        )
        token_positions = torch.arange(seq_len, device=patch_lengths.device)  # [0, 1, ..., seq_len-1]
        # [batch, seq_len, num_patches]: True where patch p starts at or before position t.
        position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)
        # Number of patch starts at or before each position, minus one, is the patch index.
        patch_ids = position_ge_patch_start.sum(dim=-1) - 1
        return patch_ids
class BLTPatcher(BLTPreTrainedModel):
    """Small causal transformer whose prediction entropies decide BLT patch boundaries.

    ``forward`` scores every input position with the entropy of the model's predicted
    distribution and turns those entropies into patch lengths.
    """

    def __init__(self, config):
        super().__init__(config)
        self.rope_embeddings = RotaryEmbedding(
            theta=config.patcher_rope_theta,
            head_dim=config.patcher_head_dim or config.patcher_dim // config.patcher_n_heads,
            max_seqlen=config.patcher_max_seqlen,
            rope_use_fp32_in_outer_product=config.patcher_rope_use_fp32_in_outer_product,
        )
        self.layers = nn.ModuleList()
        for _ in range(config.patcher_n_layers):
            self.layers.append(
                BLTTransformerLayer(
                    {
                        "dim": config.patcher_dim,
                        "n_heads": config.patcher_n_heads,
                        "head_dim": config.patcher_head_dim,
                        "n_kv_heads": config.patcher_n_kv_heads,
                        "rope_theta": config.patcher_rope_theta,
                        "multiple_of": config.patcher_multiple_of,
                        "ffn_dim_multiplier": config.patcher_ffn_dim_multiplier,
                        "norm_eps": config.patcher_norm_eps,
                    }
                )
            )
        self.tok_embeddings = torch.nn.Embedding(config.patcher_vocab_size, config.patcher_dim)
        self.norm = RMSNorm(config.patcher_dim, eps=config.patcher_norm_eps)
        self.output = nn.Linear(
            config.patcher_dim,
            config.patcher_vocab_size,
            bias=False,
        )

    def forward(
        self,
        token_values: torch.Tensor,
        patch_size: Optional[int] = None,
        include_next_token: bool = True,
        threshold: Optional[float] = None,
        max_patch_length: Optional[int] = None,
        patching_batch_size: int = 1,
        device: Optional[str] = None,
    ):
        """Compute per-position entropies and the resulting patch lengths.

        The batch is flattened and processed in chunks of ``patching_batch_size`` rows of
        ``config.patcher_max_seqlen`` tokens each (the last chunk is zero-padded), so a chunk row
        may span the boundary between two input sequences.

        Args:
            token_values: ``(batch, seq_len)`` token ids.
            patch_size: if given, patch starts come from ``find_entropy_patch_start_ids``;
                if ``None``, every position becomes its own patch.
            include_next_token: if true, the lengths cover ``seq_len + 1`` positions.
            threshold: entropy threshold forwarded to ``find_entropy_patch_start_ids``.
            max_patch_length: forwarded to ``process_patch_lengths``.
            patching_batch_size: number of ``patcher_max_seqlen``-long rows per chunk.
            device: if given, chunks are moved to this device before the forward pass.

        Returns:
            ``(entropies, patch_lengths, preds)``: entropies shaped like ``token_values``, the
            processed patch lengths, and the logits reshaped to ``(batch, seq_len * vocab)``.
        """
        entropies = []
        preds = []
        max_length = self.config.patcher_max_seqlen
        batch_numel = max_length * patching_batch_size
        for split in torch.split(token_values.flatten(), batch_numel):
            # Zero-pad the chunk up to a multiple of max_length, then view it as rows.
            pad_size = (max_length - (split.numel() % max_length)) % max_length
            pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False)
            split = torch.cat((split, pad), dim=0).reshape(-1, max_length)
            if device is not None:
                split = split.to(device)
            # Process chunk: embeddings -> layers -> output
            seqlen = split.shape[1]
            h = self.tok_embeddings(split)
            chunk_mask = create_causal_mask(
                seqlen,
                self.config.patcher_attn_impl,
                self.config.patcher_attn_bias_type,
                sliding_window=self.config.patcher_sliding_window,
                tokens=split,
                eos_id=self.config.eos_id,
            )
            freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None)
            for layer in self.layers:
                h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=self.config.patcher_attn_impl)
            pred = self.output(self.norm(h))
            # Flatten to (positions, vocab) and drop the rows that belong to the padding.
            pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :]
            preds.append(pred)
            entropies.append(self.entropy(pred))

        concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape)
        concat_preds = torch.cat(preds, dim=0).reshape(token_values.shape[0], -1)

        bs, seq_len = token_values.shape
        seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
        if patch_size is not None:
            patch_start_ids = self.find_entropy_patch_start_ids(
                concat_entropies,
                patch_size,
                include_next_token=include_next_token,
                threshold=threshold,
            )
            patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok)
        else:
            # Default to byte-level patching.
            patch_lengths = torch.ones((bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device)
        patch_lengths = process_patch_lengths(patch_lengths, max_patch_length)
        return concat_entropies, patch_lengths, concat_preds

    @staticmethod
    def entropy(scores):
        """
        Entropy (natural log) of the softmax distribution along the last dimension.

        scores: [..., vocab] logits
        returns: [...] (the last dimension is reduced)
        """
        log_probs = F.log_softmax(scores, dim=-1)
        probs = torch.exp(log_probs)
        p_log_p = log_probs * probs
        entropy = -p_log_p.sum(dim=-1)
        return entropy

    @staticmethod
    def patch_start_ids_from_patch_start_mask(patch_start_mask):
        """Convert a boolean ``(bs, L)`` start mask into left-aligned start indices.

        Each row lists the positions where the mask is True, in order, padded with ``L``. The
        result is truncated to the largest number of starts in the batch; if no row has any
        start, a ``(bs, L)`` tensor filled with ``L`` is returned.
        """
        bs, trunc_seq_len = patch_start_mask.shape
        max_patches = patch_start_mask.sum(dim=1).max()
        if max_patches == 0:
            patch_start_ids = torch.full(
                (bs, trunc_seq_len),
                trunc_seq_len,
                dtype=torch.long,
                device=patch_start_mask.device,
            )
        else:
            patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(bs, 1)
            extra_patch_ids = torch.full(
                (bs, trunc_seq_len),
                trunc_seq_len,
                dtype=torch.long,
                device=patch_start_mask.device,
            )
            all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
            # Every row selects exactly trunc_seq_len entries: its starts, then padding ids.
            patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)
            patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, trunc_seq_len)[:, :max_patches]
        return patch_start_ids

    @staticmethod
    def patch_lengths_from_start_ids(patch_start_ids, seq_len):
        """
        Calculate patch lengths from start ids.

        start ids: e.g. [0, 1, 7, 7, 7, 7, 7] -- the start ids of the patches (here 0 and 1),
            followed by padding entries equal to seq_len.
        seq_len: e.g. 7, the length of the sequence.
        returns the patch lengths, zero for padding entries:
            [1, 6, 0, 0, 0, 0, 0] for the above example.
        """
        last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
        patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
        patch_lengths = patch_end_ids - patch_start_ids + 1
        assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
        assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
        return patch_lengths

    @staticmethod
    def find_entropy_patch_start_ids(
        entropies,
        patch_size=None,
        threshold=None,
        include_next_token=True,
    ):
        """
        Use entropies to find the start ids of each patch.

        Positions 0 and 1 always start patches. Without a threshold, the remaining
        ``seq_len // patch_size - 2`` starts (at least 0) are the highest-entropy positions.
        With a threshold, every position whose entropy exceeds it starts a patch, so the number
        of patches varies between sequences, but patches can be identified incrementally rather
        than decided globally from the entire sequence.
        """
        bs, seq_len = entropies.shape[:2]
        first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(bs, 1)
        preds_truncation_len = first_ids.shape[1]  # the first preds are always patch starts
        entropies = entropies[:, 1:]
        if threshold is None:
            num_patches = seq_len // patch_size
            # Clamp so sequences shorter than 2 * patch_size keep only the forced starts instead
            # of passing a negative k to topk.
            num_extra_starts = max(num_patches - 2, 0)
            patch_start_ids = entropies.topk(num_extra_starts, dim=1).indices
            patch_start_ids = patch_start_ids.sort(dim=1).values
        else:
            patch_start_mask = entropies > threshold
            if not include_next_token:
                patch_start_mask = patch_start_mask[:, :-1]
            patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask)
        patch_start_ids = torch.cat((first_ids, patch_start_ids + preds_truncation_len), dim=1)
        return patch_start_ids
def init_hash_embeddings(
    config,
    local_encoder_dim: int,
    encoder_hash_byte_group_size: list,
):
    """Build the hash n-gram embedding tables used by the BLT local encoder.

    One ``nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim)`` is created per
    (hash function, byte-group size) pair, iterating hash functions in the outer loop, for
    ``config.encoder_hash_byte_group_nb_functions * len(encoder_hash_byte_group_size)`` tables.

    Args:
        config: object providing ``encoder_hash_byte_group_vocab`` and
            ``encoder_hash_byte_group_nb_functions``.
        local_encoder_dim: embedding width.
        encoder_hash_byte_group_size: byte-group sizes to embed; ``None`` disables hash embeddings.

    Returns:
        An ``nn.ModuleList`` of embeddings, or ``None`` when ``encoder_hash_byte_group_size`` is None.
    """
    # Check the argument that is iterated below (not the config attribute) so the two can
    # never disagree.
    if encoder_hash_byte_group_size is None:
        return None
    return nn.ModuleList(
        nn.Embedding(config.encoder_hash_byte_group_vocab, local_encoder_dim)
        for _ in range(config.encoder_hash_byte_group_nb_functions)
        for _ in encoder_hash_byte_group_size
    )
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    # Pretrained-model base class.
    "BLTPreTrainedModel",
    # Full model and its entropy patcher.
    "BLTModel",
    "BLTPatcher",
    # Sub-modules used by BLTModel.
    "BLTLocalEncoder",
    "BLTLocalDecoder",
    "BLTGlobalTransformer",
]