|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn
|
from torch.nn import functional as F |
|
from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention |
|
|
|
from ...modeling_utils import PreTrainedModel |
|
from .configuration_blt_og import ( |
|
BLTConfig, |
|
PatchingModeEnum, |
|
) |
|
|
|
RMSNorm = nn.RMSNorm |
|
|
|
logger = logging.getLogger(__name__)
|
|
|
flex_attention_comp = flex_attention |
|
|
|
|
|
def causal_mask(b, h, q_idx, kv_idx): |
|
return q_idx >= kv_idx |
|
|
|
|
|
def create_causal_mask( |
|
seqlen, |
|
attn_impl: str, |
|
attn_bias_type: str | None, |
|
*, |
|
eos_id: int | None = None, |
|
tokens: torch.Tensor | None = None, |
|
sliding_window: int | None = None, |
|
): |
|
if attn_impl == "sdpa": |
|
BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0)) |
|
|
|
if attn_bias_type == "causal": |
|
return "causal" |
|
|
|
if BLT_SUPPRESS_ATTN_ERROR == 1: |
|
return "causal" |
|
else: |
|
            raise ValueError(
                "SDPA attention is being used, but it has no specialized implementation for "
                "block_causal and local_block_causal attention. To suppress this error and fall back "
                "to plain causal attention, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
            )
|
elif attn_impl == "flex_attention": |
|
return create_block_mask(causal_mask, None, None, seqlen, seqlen) |
|
else: |
|
raise NotImplementedError(f"Attention {attn_impl} with {sliding_window} sliding window not implemented") |
|
|
|
|
|
def cross_entropy(pred, target, **kwargs): |
|
return F.nll_loss( |
|
F.log_softmax(pred.flatten(end_dim=-2).float(), -1), |
|
target.flatten(end_dim=-1), |
|
**kwargs, |
|
) |
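# Note: this is equivalent (up to the explicit float32 cast of the logits) to
# F.cross_entropy(pred.flatten(end_dim=-2), target.flatten(), **kwargs).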
|
|
|
|
|
def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor: |
|
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)""" |
|
assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims." |
|
bs, slen, n_kv_heads, head_dim = x.shape |
|
if n_rep == 1: |
|
return x |
|
return ( |
|
x[:, :, :, None, :] |
|
.expand(bs, slen, n_kv_heads, n_rep, head_dim) |
|
.reshape(bs, slen, n_kv_heads * n_rep, head_dim) |
|
) |
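# Example (illustrative): for x of shape (bs=1, slen=4, n_kv_heads=2, head_dim=8) and n_rep=3,
# repeat_kv returns a tensor of shape (1, 4, 6, 8): each kv head is repeated 3 times along
# dim=2 so that grouped-query attention lines up with the query heads.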
|
|
|
|
|
def precompute_freqs_cis( |
|
dim: int, |
|
end: int, |
|
theta: float = 10000.0, |
|
rope_use_fp32_in_outer_product: bool = False, |
|
): |
|
""" |
|
Precompute the frequency tensor for complex exponentials (cis) with given dimensions. |
|
|
|
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' |
|
and the end index 'end'. The 'theta' parameter scales the frequencies. |
|
The returned tensor contains complex values in complex64 data type. |
|
|
|
Args: |
|
dim (int): Dimension of the frequency tensor. |
|
end (int): End index for precomputing frequencies. |
|
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. |
|
|
|
Returns: |
|
torch.Tensor: Precomputed frequency tensor with complex exponentials. |
|
""" |
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) |
|
t = torch.arange(end, device=freqs.device) |
|
if rope_use_fp32_in_outer_product: |
|
t = t.to(torch.float32) |
|
|
|
freqs = torch.outer(t, freqs).float() |
|
|
|
cos, sin = freqs.cos(), freqs.sin() |
|
|
|
return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2) |
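# Example (illustrative): precompute_freqs_cis(dim=64, end=1024) returns a tensor of shape
# (1024, 32, 2, 2); entry [p, f] is the 2x2 rotation matrix [[cos, -sin], [sin, cos]] for
# position p and frequency f, consumed by apply_rotary_emb below.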
|
|
|
|
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): |
|
""" |
|
Reshape frequency tensor for broadcasting it with another tensor. |
|
|
|
This function reshapes the frequency tensor to have the same shape as the target tensor 'x' |
|
for the purpose of broadcasting the frequency tensor during element-wise operations. |
|
|
|
Args: |
|
freqs_cis (torch.Tensor): Frequency tensor to be reshaped. |
|
x (torch.Tensor): Target tensor for broadcasting compatibility. |
|
seq_dim (int): Sequence dimension index. |
|
|
|
Returns: |
|
torch.Tensor: Reshaped frequency tensor. |
|
""" |
|
ndim = x.ndim |
|
assert 0 <= seq_dim < ndim |
|
assert freqs_cis.shape == ( |
|
x.shape[seq_dim], |
|
x.shape[-3], |
|
2, |
|
2, |
|
), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" |
|
shape = [d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])] + [2, 2] |
|
return freqs_cis.view(*shape) |
|
|
|
|
|
def apply_rotary_emb( |
|
xq: torch.Tensor, |
|
xk: torch.Tensor, |
|
seq_dim: int, |
|
freqs_cis: torch.Tensor, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) |
|
xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) |
|
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, seq_dim).float() |
|
xq_out = (xq_ * freqs_cis).sum(5).flatten(3) |
|
xk_out = (xk_ * freqs_cis).sum(5).flatten(3) |
|
return xq_out.type_as(xq), xk_out.type_as(xk) |
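# Shape summary: xq/xk of shape (bs, slen, n_heads, head_dim) are viewed as
# (bs, slen, n_heads, head_dim // 2, 1, 2) pairs, multiplied by the broadcast (2, 2) rotation
# matrices from precompute_freqs_cis, summed over the last axis, and flattened back to
# (bs, slen, n_heads, head_dim).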
|
|
|
|
|
|
|
class RotaryEmbedding(torch.nn.Module): |
|
""" |
|
RotaryEmbedding Module |
|
""" |
|
|
|
def __init__( |
|
self, |
|
theta: float, |
|
head_dim: int, |
|
max_seqlen: int = 1024, |
|
rope_use_fp32_in_outer_product: bool = False, |
|
): |
|
super().__init__() |
|
|
|
self.theta = theta |
|
self.head_dim = head_dim |
|
self.max_seqlen = max_seqlen |
|
self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product |
|
|
|
self.register_buffer( |
|
"freqs_cis", |
|
precompute_freqs_cis( |
|
dim=head_dim, |
|
end=max_seqlen, |
|
theta=theta, |
|
rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, |
|
), |
|
persistent=False, |
|
) |
|
|
|
|
|
def forward(self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None): |
|
""" |
|
Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions |
|
Args: |
|
seqlen (int): Contiguous sequence length |
|
tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen |
|
|
|
Returns: |
|
Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis |
|
""" |
|
test = (seqlen is not None) or (tok_idx is not None) |
|
assert test, "Should provide atleast seqlen or tok_idx" |
|
if tok_idx is not None: |
|
return self.freqs_cis[tok_idx] |
|
elif seqlen is not None: |
|
return self.freqs_cis[0:seqlen] |
|
|
|
|
|
class BLTSelfAttention(nn.Module): |
|
def __init__( |
|
self, |
|
dim: int, |
|
head_dim: int, |
|
n_heads: int, |
|
n_kv_heads: int, |
|
rope_theta: float, |
|
): |
|
super().__init__() |
|
|
|
self.dim = dim |
|
self.head_dim = head_dim |
|
self.rope_theta = rope_theta |
|
|
|
self.n_heads = n_heads |
|
self.n_kv_heads = n_kv_heads |
|
self.heads_per_group = self.n_heads // self.n_kv_heads |
|
|
|
self.wq = nn.Linear( |
|
dim, |
|
n_heads * head_dim, |
|
bias=False, |
|
) |
|
self.wk = nn.Linear( |
|
dim, |
|
n_kv_heads * head_dim, |
|
bias=False, |
|
) |
|
self.wv = nn.Linear( |
|
dim, |
|
n_kv_heads * head_dim, |
|
bias=False, |
|
) |
|
|
|
self.wo = nn.Linear( |
|
n_heads * head_dim, |
|
dim, |
|
bias=False, |
|
) |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
freq_cis: torch.Tensor, |
|
tok_idx: Optional[torch.Tensor] = None, |
|
mask: Optional[Union[BlockMask, str]] = None, |
|
attn_impl: str = "sdpa", |
|
) -> torch.Tensor: |
|
|
|
bsz, seq_len, dim = x.shape |
|
|
|
        xq = self.wq(x)
        xk = self.wk(x)
        xv = self.wv(x)
|
|
|
output_shape = xq.shape |
|
|
|
xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) |
|
xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim) |
|
xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim) |
|
|
|
xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) |
|
|
|
|
|
|
|
if hasattr(self, "kv_cache"): |
|
xk, xv = self.kv_cache.update(xk, xv, tok_idx) |
|
|
|
xk = repeat_kv(xk, self.heads_per_group, dim=2) |
|
xv = repeat_kv(xv, self.heads_per_group, dim=2) |
|
|
|
if attn_impl == "flex_attention": |
|
assert mask is None or isinstance(mask, BlockMask) |
|
xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) |
|
output = flex_attention_comp(xq, xk, xv, block_mask=mask) |
|
output = output.transpose(1, 2).contiguous() |
|
|
|
elif attn_impl == "sdpa": |
|
xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) |
|
assert mask is None or isinstance(mask, (str, torch.Tensor)) |
|
is_causal = (mask == "causal") if isinstance(mask, str) else False |
|
mask = mask.to(xq.device) if isinstance(mask, torch.Tensor) else None |
|
output = F.scaled_dot_product_attention( |
|
xq, |
|
xk, |
|
xv, |
|
is_causal=is_causal, |
|
attn_mask=mask, |
|
) |
|
output = output.transpose(1, 2).contiguous() |
|
else: |
|
raise NotImplementedError(f"Attention implementation {attn_impl} not supported") |
|
|
|
output_reshaped = output.reshape(output_shape) |
|
|
|
output = self.wo(output_reshaped) |
|
|
|
return output |
|
|
|
|
|
class BLTMLP(nn.Module): |
|
def __init__( |
|
self, |
|
dim: int, |
|
hidden_dim: int, |
|
multiple_of: int, |
|
ffn_dim_multiplier: Optional[float], |
|
mp_size: int = 1, |
|
): |
|
super().__init__() |
|
|
|
hidden_dim = int(2 * hidden_dim / 3) |
|
if ffn_dim_multiplier is not None: |
|
hidden_dim = int(ffn_dim_multiplier * hidden_dim) |
|
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) |
|
assert hidden_dim % mp_size == 0 |
|
|
|
self.dim = dim |
|
self.hidden_dim = hidden_dim |
|
|
|
self.w1 = nn.Linear( |
|
dim, |
|
hidden_dim, |
|
bias=False, |
|
) |
|
self.w3 = nn.Linear( |
|
dim, |
|
hidden_dim, |
|
bias=False, |
|
) |
|
self.w2 = nn.Linear( |
|
hidden_dim, |
|
dim, |
|
bias=False, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
        x1 = self.w1(x)
        x3 = self.w3(x)
|
output = self.w2(F.silu(x1) * x3) |
|
return output |
|
|
|
|
|
|
|
|
|
class BLTTransformerLayer(nn.Module): |
|
def __init__(self, args): |
|
super().__init__() |
|
|
|
|
|
dim = args["dim"] |
|
n_heads = args["n_heads"] |
|
head_dim = args["head_dim"] |
|
n_kv_heads = args["n_kv_heads"] |
|
rope_theta = args["rope_theta"] |
|
multiple_of = args["multiple_of"] |
|
ffn_dim_multiplier = args["ffn_dim_multiplier"] |
|
norm_eps = args["norm_eps"] |
|
|
|
assert (head_dim is not None) or (n_heads is not None), "Should specify at least head_dim or n_heads" |
|
self.head_dim = head_dim or dim // n_heads |
|
self.n_heads = n_heads or dim // head_dim |
|
self.n_kv_heads = n_kv_heads or self.n_heads |
|
|
|
assert n_heads % self.n_kv_heads == 0 |
|
assert dim % n_heads == 0 |
|
|
|
self.attention = BLTSelfAttention( |
|
dim=dim, |
|
head_dim=self.head_dim, |
|
n_heads=self.n_heads, |
|
n_kv_heads=self.n_kv_heads, |
|
rope_theta=rope_theta, |
|
) |
|
self.feed_forward = BLTMLP( |
|
dim=dim, |
|
hidden_dim=4 * dim, |
|
multiple_of=multiple_of, |
|
ffn_dim_multiplier=ffn_dim_multiplier, |
|
) |
|
self.attention_norm = RMSNorm(dim, eps=norm_eps) |
|
self.ffn_norm = RMSNorm(dim, eps=norm_eps) |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
freq_cis: torch.Tensor, |
|
tok_idx: Optional[torch.Tensor] = None, |
|
mask: Optional[Union[BlockMask, str]] = None, |
|
attn_impl: str = "sdpa", |
|
) -> torch.Tensor: |
|
norm_x = self.attention_norm(x) |
|
attn_out = self.attention( |
|
norm_x, |
|
freq_cis, |
|
tok_idx=tok_idx, |
|
mask=mask, |
|
attn_impl=attn_impl, |
|
) |
|
h = x + attn_out |
|
h_norm = self.ffn_norm(h) |
|
out = h + self.feed_forward(h_norm) |
|
return out |
|
|
|
def check_non_zero_after_zero(tensor): |
|
zero_mask = tensor == 0 |
|
shifted_mask = torch.cat( |
|
[ |
|
torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device), |
|
zero_mask[:, :-1], |
|
], |
|
dim=1, |
|
) |
|
non_zero_after_zero = (tensor != 0) & shifted_mask |
|
return non_zero_after_zero.any() |
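# check_non_zero_after_zero flags malformed patch-length rows: for example, a row [3, 0, 2]
# returns True (a non-zero value follows a zero along dim 1), while [3, 2, 0] returns False.
# It is used to validate that zero padding only appears on the right.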
|
|
|
def rolling_polynomial_hash(t, hash_func_nb: int = 0): |
|
primes = [ |
|
1000000007, |
|
5915587277, |
|
1500450271, |
|
3267000013, |
|
5754853343, |
|
4093082899, |
|
9576890767, |
|
3628273133, |
|
2860486313, |
|
5463458053, |
|
3367900313, |
|
] |
|
prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) |
|
prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) |
|
return torch.sum(t * prime_powers, dim=-1) |
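# The hash of a window t = (t_0, ..., t_{k-1}) is sum_i t_i * prime**i, computed in int64
# arithmetic (overflow wraps around), with a different large prime per hash_func_nb.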
|
|
|
|
|
def byte_group_hash_function(x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000): |
|
""" |
|
Returns a hash of the input x and maps it to a value in the range [0, max_hash]. |
|
|
|
expects: x of shape (batch_size, seq_len) with values as ids in the token vocab. |
|
returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash]. |
|
|
|
Note: max hash can make a big difference on the number of collisions. |
|
""" |
|
with torch.no_grad(): |
|
bs, seq_len = x.shape |
|
prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device) |
|
x = torch.cat([prefix, x], dim=1) |
|
windows = x.unfold(1, group_size, 1) |
|
|
|
hashes = rolling_polynomial_hash(windows, hash_func_nb) |
|
hash_values_range = hashes % max_hash |
|
hash_values_range.requires_grad = False |
|
return hash_values_range |
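# Example (illustrative): for x of shape (bs, seq_len) and group_size=2, each position j is
# hashed over the window (x[:, j-1], x[:, j]) (left-padded with zeros), so the result keeps
# the shape (bs, seq_len) and can be looked up directly in an nn.Embedding of size max_hash.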
|
|
|
|
|
def create_patch_mask_from_ids(patch_ids, num_patches, window=None, patches_as_queries=False): |
|
""" |
|
Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k) |
|
is True if the patch id at position (i, j) is less than or equal to k. |
|
Args: |
|
patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids. |
|
num_patches (int): Total number of patches. |
|
window (int): If not None, only considers patches within a window of size window. |
|
patches_as_queries (bool): If True, the patches are used as queries |
|
Returns: |
|
torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask. |
|
""" |
|
bs, seq_len = patch_ids.shape |
|
if not patches_as_queries: |
|
q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches) |
|
kv_ids = ( |
|
torch.arange(num_patches, device=patch_ids.device) |
|
.unsqueeze(0) |
|
.unsqueeze(0) |
|
.expand(bs, seq_len, num_patches) |
|
) |
|
else: |
|
kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len) |
|
q_ids = ( |
|
torch.arange(num_patches, device=patch_ids.device) |
|
.unsqueeze(0) |
|
.unsqueeze(-1) |
|
.expand(bs, num_patches, seq_len) |
|
) |
|
if window is None: |
|
mask = q_ids == kv_ids |
|
else: |
|
mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window) |
|
return mask |
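# Example (illustrative): patch_ids = torch.tensor([[0, 0, 1, 2]]), num_patches=3, window=None,
# patches_as_queries=False yields a (1, 4, 3) boolean mask where mask[0, j, k] is True exactly
# when token j belongs to patch k, e.g. mask[0, 2] == [False, True, False].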
|
|
|
|
|
def cross_attn_mask( |
|
patch_ids, |
|
patch_lengths, |
|
N, |
|
patches_as_queries=False, |
|
cross_attn_k=1, |
|
window=None, |
|
block_mask=True, |
|
): |
|
bs = patch_ids.shape[0] |
|
with torch.no_grad(): |
|
|
|
cross_mask = create_patch_mask_from_ids( |
|
patch_ids, |
|
patch_lengths.shape[1], |
|
window=window, |
|
patches_as_queries=patches_as_queries, |
|
).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1) |
|
q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N |
|
kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k |
|
assert cross_mask.shape == ( |
|
bs, |
|
q_len, |
|
kv_len, |
|
), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" |
|
|
|
if block_mask: |
|
|
|
def patch_mask(b, h, q_idx, kv_idx): |
|
return cross_mask[b, q_idx, kv_idx] |
|
|
|
block_mask = create_block_mask( |
|
patch_mask, |
|
B=bs, |
|
H=None, |
|
Q_LEN=q_len, |
|
KV_LEN=kv_len, |
|
_compile=True, |
|
) |
|
return block_mask |
|
else: |
|
return torch.where(cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))).unsqueeze( |
|
1 |
|
) |
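# cross_attn_mask returns either a flex-attention BlockMask (block_mask=True) or an additive
# float mask of shape (bs, 1, q_len, kv_len) with 0.0 where attention is allowed and -inf
# elsewhere, suitable for the attn_mask argument of F.scaled_dot_product_attention.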
|
|
|
|
|
def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int) -> torch.Tensor: |
|
if max_patch_length is None: |
|
return patch_lengths |
|
|
|
batch_size = patch_lengths.size(0) |
|
split_all = [] |
|
max_len = 0 |
|
|
|
for seq in patch_lengths: |
|
splits = [] |
|
for length in seq[seq > 0]: |
|
|
|
full, rem = divmod(length.item(), max_patch_length) |
|
splits.extend([max_patch_length] * full + ([rem] if rem else [])) |
|
split_all.append(splits) |
|
max_len = max(max_len, len(splits)) |
|
|
|
|
|
padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device) |
|
for i, splits in enumerate(split_all): |
|
if splits: |
|
padded[i, :len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device) |
|
|
|
|
|
    # Trim trailing all-zero columns: min_trailing_zeros is the smallest number of trailing
    # zeros across the batch, so only columns that are zero for every sequence are dropped.
    min_trailing_zeros = (padded != 0).flip(1).int().argmax(1).min()
    if min_trailing_zeros > 0:
        padded = padded[:, : padded.shape[1] - min_trailing_zeros]
|
|
|
return padded |
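# Example (illustrative): process_patch_lengths(torch.tensor([[5, 2]]), max_patch_length=3)
# splits the length-5 patch into [3, 2], giving tensor([[3, 2, 2]]); zero padding is added on
# the right when sequences in the batch end up with different numbers of patches.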
|
|
|
class BLTLocalModelBase(nn.Module): |
|
def __init__(self, config: BLTConfig, component_type: str = "encoder"): |
|
super().__init__() |
|
|
|
self.config = config |
|
|
|
if component_type == "encoder": |
|
self.dim = config.dim_local_encoder |
|
self.n_layers = config.n_layers_local_encoder |
|
self.n_heads = config.n_heads_local_encoder |
|
self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen |
|
self.attn_bias_type = "local_block_causal" |
|
self.sliding_window = config.local_attention_window_len |
|
elif component_type == "decoder": |
|
self.dim = config.dim_local_decoder |
|
self.n_layers = config.n_layers_local_decoder |
|
self.n_heads = config.n_heads_local_decoder |
|
self.max_seqlen = config.max_encoder_seq_length or config.max_seqlen |
|
self.attn_bias_type = "local_block_causal" |
|
self.sliding_window = config.local_attention_window_len |
|
else: |
|
raise ValueError(f"Unknown component_type: {component_type}") |
|
|
|
self.dropout = config.dropout |
|
self.vocab_size = config.vocab_size + config.pm_size |
|
self.patch_size = config.patch_size |
|
|
|
self.attn_impl = config.attn_impl |
|
self.use_rope = config.use_rope |
|
self.init_std_factor = config.init_std_factor |
|
self.init_base_std = config.init_base_std |
|
self.cross_attn_encoder = getattr(config, "cross_attn_encoder", None) |
|
self.cross_attn_decoder = getattr(config, "cross_attn_decoder", None) |
|
self.cross_attn_k = getattr(config, "cross_attn_k", None) |
|
self.eos_id = config.eos_token_id |
|
|
|
self.boe_id = config.boe_id |
|
|
|
|
|
self.cross_attn_layers = None |
|
|
|
|
|
layer_params = { |
|
"dim": self.dim, |
|
"n_heads": self.n_heads, |
|
"head_dim": config.head_dim, |
|
"n_kv_heads": getattr(config, "n_kv_heads", None), |
|
"rope_theta": config.rope_theta, |
|
"multiple_of": getattr(config, "multiple_of", 256), |
|
"ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None), |
|
"norm_eps": config.norm_eps, |
|
} |
|
|
|
self.layers = nn.ModuleList([BLTTransformerLayer(layer_params) for _ in range(self.n_layers)]) |
|
|
|
if not self.use_rope: |
|
self.pos_embeddings = nn.Embedding(2048, self.dim) |
|
else: |
|
self.rope = RotaryEmbedding( |
|
theta=config.rope_theta, |
|
head_dim=config.head_dim or self.dim // self.n_heads, |
|
max_seqlen=self.max_seqlen, |
|
rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, |
|
) |
|
self.pos_embeddings = None |
|
|
|
|
|
if component_type == "encoder": |
|
self.dim_token_emb = config.encoder_dim_token_emb |
|
self.dim_patch_emb = config.encoder_dim_patch_emb |
|
elif component_type == "decoder": |
|
self.dim_token_emb = config.decoder_dim_token_emb |
|
self.dim_patch_emb = config.dim_global |
|
|
|
self.token_embedding_projection = ( |
|
nn.Linear(self.dim_token_emb, self.dim, bias=False) |
|
if self.dim_token_emb is not None and self.dim_token_emb != self.dim |
|
else None |
|
) |
|
|
|
self.patch_embedding_projection = self._create_patch_projection(config) |
|
|
|
def _should_create_patch_projection(self, config: BLTConfig): |
|
dimension_mismatch = self.dim_patch_emb is not None and self.dim_patch_emb != self.dim |
|
|
|
|
|
cross_attn_conditions = (config.cross_attn_encoder and config.cross_attn_init_by_pooling) or ( |
|
config.cross_attn_decoder and config.cross_attn_init_by_pooling |
|
) |
|
|
|
return dimension_mismatch or cross_attn_conditions |
|
|
|
def _create_patch_projection(self, config): |
|
if not self._should_create_patch_projection(config): |
|
return None |
|
|
|
output_dim = self.dim_token_emb * (self.cross_attn_k or 1) |
|
|
|
return nn.Linear( |
|
in_features=self.dim_patch_emb, |
|
out_features=output_dim, |
|
bias=False, |
|
) |
|
|
|
def apply_embedding(self, tokens, embeds): |
|
if embeds is not None: |
|
return embeds |
|
else: |
|
return self.tok_embeddings(tokens) |
|
|
|
|
|
class BLTLocalEncoder(BLTLocalModelBase): |
|
def __init__(self, config: BLTConfig): |
|
super().__init__(config, component_type="encoder") |
|
|
|
self.apply_transformer = config.use_local_encoder_transformer |
|
self.downsampling_by_pooling = config.downsampling_by_pooling |
|
self.expects_hash_embeddings = config.encoder_hash_byte_group_size is not None |
|
self.cross_attn_encoder = config.cross_attn_encoder |
|
self.cross_attn_all_layers_encoder = config.cross_attn_all_layers_encoder |
|
self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling |
|
self.cross_attn_nheads = config.cross_attn_nheads |
|
|
|
self.tok_embeddings = nn.Embedding(self.vocab_size, self.dim) |
|
|
|
if self.cross_attn_encoder: |
|
self.cross_attn_layers = torch.nn.ModuleList() |
|
layers_to_add = self.n_layers if self.cross_attn_all_layers_encoder else 1 |
|
for _ in range(layers_to_add): |
|
self.cross_attn_layers.append( |
|
BLTCrossAttention( |
|
dim=self.dim, |
|
head_dim=self.dim // self.cross_attn_nheads, |
|
n_heads=self.cross_attn_nheads, |
|
n_kv_heads=self.cross_attn_nheads, |
|
norm_eps=config.norm_eps, |
|
) |
|
) |
|
|
|
def apply_embedding(self, tokens, embeds): |
|
if embeds is not None: |
|
assert self.expects_hash_embeddings, "Not expecting embeddings to be passed." |
|
return embeds |
|
else: |
|
return self.tok_embeddings(tokens) |
|
|
|
def forward( |
|
self, |
|
tokens: torch.Tensor, |
|
embeds: Optional[torch.Tensor] = None, |
|
patch_embeds: Optional[torch.Tensor] = None, |
|
mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, |
|
cross_mask: Optional[torch.Tensor] = None, |
|
num_patches: Optional[int] = None, |
|
patch_ids: Optional[torch.Tensor] = None, |
|
cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, |
|
): |
|
""" """ |
|
bs, seqlen = tokens.shape |
|
if mask is None: |
|
mask = create_causal_mask( |
|
seqlen, |
|
self.attn_impl, |
|
"local_block_causal", |
|
sliding_window=self.sliding_window, |
|
tokens=tokens, |
|
eos_id=self.eos_id, |
|
) |
|
|
|
h = self.apply_embedding(tokens, embeds) |
|
|
|
|
|
freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None |
|
|
|
|
|
h = F.dropout(h, p=self.dropout, training=self.training) |
|
|
|
for i, layer in enumerate(self.layers): |
|
h = layer(h, freqs_cis, tok_idx=None, mask=mask, attn_impl=self.attn_impl) |
|
|
|
if self.cross_attn_encoder and (i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder): |
|
|
|
if self.cross_attn_init_by_pooling and patch_embeds is None: |
|
patch_embeds = self.patch_reduce(h, num_patches, "amax", patch_ids) |
|
if self.patch_embedding_projection is not None: |
|
patch_embeds = self.patch_embedding_projection(patch_embeds) |
|
patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim) |
|
|
|
layer_idx = i if self.cross_attn_all_layers_encoder else 0 |
|
patch_embeds_cross = self.cross_attn_layers[layer_idx]( |
|
x=patch_embeds, |
|
kv=h, |
|
mask=cross_mask, |
|
) |
|
patch_embeds = patch_embeds + patch_embeds_cross |
|
|
|
h_residual = patch_embeds if self.cross_attn_encoder else None |
|
return (h, h_residual), cache |
|
|
|
def patch_reduce(self, h, max_num_patches, reduction, patch_ids): |
|
""" |
|
Reduce variable length patches to single embedding per patch |
|
Note: this works with variable number of patches for different sequences in the batch |
|
It handles variable length patches by assuming that patch_lengths will be 0 for any |
|
extra patches on the *right*. Since there can be a variable number of patches |
|
this function also return the number of patches for each sequence in the batch. |
|
Any embeddings on the right that are not allocated to a patch |
|
(i.e. if the sum(patch_lengths[i]) < seq_len for any i) |
|
will be sent to a dummy patch, which is trimmed before returning. |
|
""" |
|
bs, seq_len, emb_dim = h.shape |
|
|
|
patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1]) |
|
|
|
reduced_embs = torch.zeros((bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device) |
|
reduced_embs = reduced_embs.scatter_reduce( |
|
src=h, |
|
dim=1, |
|
index=patch_ids, |
|
reduce=reduction, |
|
include_self=False, |
|
) |
|
reduced_embs = reduced_embs[:, :max_num_patches, :] |
|
|
|
return reduced_embs |
|
|
|
|
|
class BLTLocalDecoder(BLTLocalModelBase): |
|
def __init__(self, config: BLTConfig): |
|
super().__init__(config, component_type="decoder") |
|
|
|
|
|
self.cross_attn_decoder = config.cross_attn_decoder |
|
self.cross_attn_all_layers_decoder = config.cross_attn_all_layers_decoder |
|
self.cross_attn_init_by_pooling = config.cross_attn_init_by_pooling |
|
self.cross_attn_nheads = config.cross_attn_nheads |
|
|
|
self.norm = RMSNorm(self.dim, eps=config.norm_eps) |
|
|
|
if self.cross_attn_decoder: |
|
self.cross_attn_layers = torch.nn.ModuleList() |
|
layers_to_add = self.n_layers if self.cross_attn_all_layers_decoder else 1 |
|
for _ in range(layers_to_add): |
|
self.cross_attn_layers.append( |
|
BLTCrossAttention( |
|
dim=self.dim, |
|
head_dim=self.dim // self.cross_attn_nheads, |
|
n_heads=self.cross_attn_nheads, |
|
n_kv_heads=self.cross_attn_nheads, |
|
norm_eps=config.norm_eps, |
|
) |
|
) |
|
|
|
self.output = nn.Linear( |
|
self.dim, |
|
config.vocab_size, |
|
bias=False, |
|
) |
|
|
|
def forward( |
|
self, |
|
tokens: torch.Tensor, |
|
embeds: Optional[torch.Tensor], |
|
patch_embeds: Optional[torch.Tensor] = None, |
|
mask: Optional[Union["BlockMask", torch.Tensor, str]] = None, |
|
cross_mask: Optional[torch.Tensor] = None, |
|
cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, |
|
): |
|
bs, seqlen = tokens.shape |
|
assert embeds is not None, "Embeddings must be provided" |
|
|
|
if mask is None: |
|
mask = create_causal_mask( |
|
seqlen, |
|
self.attn_impl, |
|
"local_block_causal", |
|
sliding_window=self.sliding_window, |
|
tokens=tokens, |
|
eos_id=self.eos_id, |
|
) |
|
|
|
h = embeds |
|
|
|
if self.patch_embedding_projection is not None: |
|
assert patch_embeds is not None, "Patch embeddings must be passed." |
|
patch_embeds = self.patch_embedding_projection(patch_embeds) |
|
if self.cross_attn_k is not None: |
|
patch_embeds = patch_embeds.reshape(bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim) |
|
|
|
if patch_embeds is not None and not self.cross_attn_decoder: |
|
h = h + patch_embeds |
|
|
|
freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None |
|
|
|
h = F.dropout(h, p=self.dropout, training=self.training) |
|
for i, layer in enumerate(self.layers): |
|
if self.cross_attn_decoder and (i == 0 or self.cross_attn_all_layers_decoder): |
|
|
|
h_cross = self.cross_attn_layers[i]( |
|
x=h, |
|
kv=patch_embeds, |
|
mask=cross_mask, |
|
) |
|
h = h + h_cross |
|
|
|
h = layer(h, freqs_cis, tok_idx=None, mask=mask, attn_impl=self.attn_impl) |
|
|
|
h_preds = self.norm(h) |
|
h_preds = F.dropout(h_preds, p=self.dropout, training=self.training) |
|
h_preds = self.output(h_preds) |
|
h_preds = h_preds.float() |
|
return h_preds, cache |
|
|
|
|
|
class BLTCrossAttention(nn.Module): |
|
def __init__( |
|
self, |
|
dim: int, |
|
head_dim: int, |
|
n_heads: int, |
|
n_kv_heads: int, |
|
norm_eps: float, |
|
): |
|
super().__init__() |
|
|
|
self.dim = dim |
|
self.head_dim = head_dim |
|
|
|
self.n_heads = n_heads |
|
self.n_kv_heads = n_kv_heads |
|
self.heads_per_group = self.n_heads // self.n_kv_heads |
|
|
|
        self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
        self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
|
|
|
self.wq = nn.Linear( |
|
dim, |
|
n_heads * head_dim, |
|
bias=False, |
|
) |
|
self.wk = nn.Linear( |
|
dim, |
|
n_kv_heads * head_dim, |
|
bias=False, |
|
) |
|
self.wv = nn.Linear( |
|
dim, |
|
n_kv_heads * head_dim, |
|
bias=False, |
|
) |
|
|
|
self.wo = nn.Linear( |
|
n_heads * head_dim, |
|
dim, |
|
bias=False, |
|
) |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
kv: torch.Tensor, |
|
mask: Optional[Union[BlockMask, str]] = None, |
|
) -> torch.Tensor: |
|
|
|
bsz, seq_len, _ = x.shape |
|
_, slen_kv, _ = kv.shape |
|
x_norm = self.cross_attn_norm_q(x) |
|
kv = self.cross_attn_norm_kv(kv) |
|
|
|
xq = self.wq(x_norm) |
|
xk = self.wk(kv) |
|
xv = self.wv(kv) |
|
|
|
output_shape = xq.shape |
|
|
|
xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim) |
|
xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) |
|
xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim) |
|
|
|
xk = repeat_kv(xk, self.heads_per_group, dim=2) |
|
xv = repeat_kv(xv, self.heads_per_group, dim=2) |
|
|
|
|
|
xq, xk, xv = (e.transpose(1, 2) for e in (xq, xk, xv)) |
|
|
|
is_causal = (mask == "causal") if isinstance(mask, str) else False |
|
mask = mask if isinstance(mask, torch.Tensor) else None |
|
mask = mask.to(dtype=xq.dtype).to(xq.device) |
|
output = F.scaled_dot_product_attention( |
|
xq, |
|
xk, |
|
xv, |
|
is_causal=is_causal, |
|
attn_mask=mask, |
|
) |
|
output = output.transpose(1, 2).contiguous() |
|
|
|
output = self.wo(output.reshape(output_shape)) |
|
|
|
return x + output |
|
|
|
|
|
class BLTGlobalTransformer(nn.Module): |
|
def __init__(self, config): |
|
super().__init__() |
|
|
|
self.config = config |
|
|
|
self.dim = config.dim_global |
|
self.rope_embeddings = RotaryEmbedding( |
|
theta=config.rope_theta, |
|
head_dim=config.head_dim or self.config.dim_global // config.n_heads_global, |
|
max_seqlen=config.max_seqlen, |
|
rope_use_fp32_in_outer_product=config.rope_use_fp32_in_outer_product, |
|
) |
|
|
|
self.eos_id = getattr(config, "eos_id", getattr(config, "eos_token_id", 2)) |
|
|
|
|
|
layer_params = { |
|
"dim": self.dim, |
|
"n_heads": config.n_heads_global, |
|
"head_dim": config.head_dim, |
|
"n_kv_heads": getattr(config, "n_kv_heads_global", None), |
|
"rope_theta": config.rope_theta, |
|
"multiple_of": getattr(config, "multiple_of", 256), |
|
"ffn_dim_multiplier": getattr(config, "ffn_dim_multiplier", None), |
|
"norm_eps": config.norm_eps, |
|
} |
|
|
|
self.layers = nn.ModuleList() |
|
for _ in range(config.n_layers_global): |
|
self.layers.append(BLTTransformerLayer(layer_params)) |
|
|
|
self.token_embedding_projection = None |
|
if config.global_dim_patch_emb is not None and config.global_dim_patch_emb != self.dim: |
|
self.token_embedding_projection = nn.Linear( |
|
config.global_dim_patch_emb, |
|
config.dim_global, |
|
bias=False, |
|
) |
|
|
|
def forward( |
|
self, |
|
tokens: torch.Tensor, |
|
tok_idx: Optional[torch.Tensor] = None, |
|
embeds: Optional[torch.Tensor] = None, |
|
mask: Optional[Union[BlockMask, torch.Tensor, str]] = None, |
|
cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None, |
|
): |
|
bs, seqlen = tokens.shape |
|
|
|
h = embeds |
|
|
|
mask = ( |
|
mask |
|
if mask is not None |
|
else create_causal_mask( |
|
seqlen, |
|
self.config.attn_impl, |
|
self.config.attn_bias_type, |
|
tokens=tokens, |
|
eos_id=self.eos_id, |
|
) |
|
) |
|
|
|
if self.token_embedding_projection is not None and h.shape[-1] != self.dim: |
|
h = self.token_embedding_projection(h) |
|
|
|
h = F.dropout(h, p=self.config.dropout, training=self.training) |
|
freq_cis = self.rope_embeddings(seqlen=self.config.max_seqlen, tok_idx=tok_idx) |
|
|
|
        for layer in self.layers:
            h = layer(h, freq_cis, tok_idx=None, mask=mask, attn_impl=self.config.attn_impl)
|
|
|
return h, cache |
|
|
|
|
|
def compute_hash_embeddings( |
|
local_encoder_tokens: torch.Tensor, |
|
local_encoder, |
|
encoder_hash_tok_embedding: nn.ModuleList, |
|
encoder_hash_byte_group_nb_functions: int, |
|
encoder_hash_byte_group_size: list, |
|
encoder_hash_byte_group_vocab: int, |
|
) -> torch.Tensor: |
|
""" |
|
Compute embeddings using hash token embeddings. |
|
|
|
Args: |
|
local_encoder_tokens: Input tokens tensor |
|
local_encoder: Encoder object with tok_embeddings method |
|
encoder_hash_tok_embedding: ModuleList of hash token embeddings |
|
encoder_hash_byte_group_nb_functions: Number of hash functions |
|
encoder_hash_byte_group_size: List of byte group sizes |
|
encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings |
|
|
|
Returns: |
|
torch.Tensor: Combined embeddings |
|
""" |
|
if encoder_hash_tok_embedding is None: |
|
return None |
|
|
|
local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens) |
|
|
|
i = 0 |
|
for func_nb in range(encoder_hash_byte_group_nb_functions): |
|
for byte_group_size in encoder_hash_byte_group_size: |
|
hash_ids = byte_group_hash_function( |
|
local_encoder_tokens, |
|
byte_group_size, |
|
hash_func_nb=func_nb, |
|
max_hash=encoder_hash_byte_group_vocab, |
|
) |
|
hash_tok_embedding = encoder_hash_tok_embedding[i] |
|
local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids) |
|
i += 1 |
|
|
|
assert i == len(encoder_hash_tok_embedding) |
|
return local_encoder_embeds |
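# In effect, the combined embedding is:
#     embeds = tok_embeddings(tokens)
#              + sum over (hash function, group size) of hash_embedding[i](hash_ids)
# so the number of entries in encoder_hash_tok_embedding must equal
# nb_functions * len(encoder_hash_byte_group_size) (checked by the assert above).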
|
|
|
|
|
class BLTPreTrainedModel(PreTrainedModel): |
|
config_class = BLTConfig |
|
base_model_prefix = "model" |
|
supports_gradient_checkpointing = True |
|
_no_split_modules = ["BLTTransformerLayer", "BLTLocalEncoder", "BLTLocalDecoder", "BLTGlobalTransformer"] |
|
_skip_keys_device_placement = ["past_key_values"] |
|
_supports_flash_attn_2 = False |
|
_supports_sdpa = True |
|
_supports_cache_class = False |
|
|
|
def _init_weights(self, module): |
|
if isinstance(module, nn.Linear): |
|
std = getattr(module, '_custom_std', module.in_features ** (-0.5)) |
|
|
|
nn.init.trunc_normal_( |
|
module.weight, |
|
mean=0.0, |
|
std=std, |
|
a=-3 * std, |
|
b=3 * std, |
|
) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
|
|
elif isinstance(module, nn.Embedding): |
|
std = getattr(module, '_custom_std', module.embedding_dim ** (-0.5)) |
|
|
|
nn.init.trunc_normal_( |
|
module.weight, |
|
mean=0.0, |
|
std=std, |
|
a=-3 * std, |
|
b=3 * std, |
|
) |
|
|
|
elif isinstance(module, (nn.RMSNorm, nn.LayerNorm)): |
|
nn.init.ones_(module.weight) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
|
|
elif isinstance(module, RotaryEmbedding): |
|
module.freqs_cis[...] = precompute_freqs_cis( |
|
dim=module.head_dim, |
|
end=module.max_seqlen, |
|
theta=module.theta, |
|
rope_use_fp32_in_outer_product=module.rope_use_fp32_in_outer_product, |
|
) |
|
|
|
elif isinstance(module, BLTModel): |
|
if module.encoder_hash_tok_embedding is not None: |
|
emb_std = module.local_encoder.dim ** (-0.5) |
|
for emb in module.encoder_hash_tok_embedding: |
|
emb._custom_std = emb_std |
|
|
|
elif isinstance(module, (BLTLocalEncoder, BLTLocalDecoder)): |
|
if module.token_embedding_projection is not None: |
|
module.token_embedding_projection._custom_std = module.dim ** (-0.5) |
|
|
|
if module.patch_embedding_projection is not None: |
|
module.patch_embedding_projection._custom_std = module.dim_patch_emb ** (-0.5) |
|
|
|
elif isinstance(module, BLTGlobalTransformer): |
|
if module.token_embedding_projection is not None: |
|
module.token_embedding_projection._custom_std = module.dim_token_emb ** (-0.5) |
|
|
|
elif isinstance(module, BLTPatcher): |
|
emb_std = module.config.patcher_dim ** (-0.5) |
|
module.tok_embeddings._custom_std = emb_std |
|
module.output._custom_std = emb_std |
|
|
|
|
|
class BLTModel(BLTPreTrainedModel): |
|
def __init__(self, config: BLTConfig): |
|
super().__init__(config) |
|
|
|
self.config = config |
|
self.local_encoder = BLTLocalEncoder(config) |
|
self.global_transformer = BLTGlobalTransformer(config) |
|
self.local_decoder = BLTLocalDecoder(config) |
|
|
|
self.encoder_hash_tok_embedding = init_hash_embeddings( |
|
config, |
|
local_encoder_dim=self.local_encoder.dim, |
|
encoder_hash_byte_group_size=config.encoder_hash_byte_group_size, |
|
) |
|
|
|
if config.patch_in_forward: |
|
self.patcher = BLTPatcher(config) |
|
self.patcher.eval() |
|
for param in self.patcher.parameters(): |
|
param.requires_grad = False |
|
else: |
|
self.patcher = None |
|
|
|
def forward( |
|
self, |
|
tokens: torch.Tensor, |
|
patch_lengths: Optional[torch.Tensor] = None, |
|
): |
|
|
|
|
|
|
|
bs, N = tokens.shape |
|
|
|
local_encoder_tokens, local_decoder_tokens = tokens, tokens |
|
|
|
|
|
if patch_lengths is None: |
|
|
|
|
|
|
|
|
|
|
|
if self.config.patching_mode == PatchingModeEnum.entropy: |
|
_, patch_lengths, _ = self.patcher( |
|
local_encoder_tokens, |
|
patch_size=self.config.patch_size, |
|
include_next_token=True, |
|
threshold=self.config.patching_threshold, |
|
max_patch_length=self.config.max_patch_length, |
|
patching_batch_size=self.config.patching_batch_size, |
|
device=self.config.patching_device, |
|
) |
|
else: |
|
|
|
bs, seq_len = local_encoder_tokens.shape |
|
seq_len_next_tok = seq_len + 1 |
|
patch_lengths = torch.ones( |
|
(bs, seq_len_next_tok), dtype=local_encoder_tokens.dtype, device=local_encoder_tokens.device |
|
) |
|
|
|
patch_lengths = process_patch_lengths(patch_lengths, self.config.max_patch_length) |
|
|
|
|
|
|
|
patch_ids = self._patch_ids_from_lengths(patch_lengths, local_encoder_tokens.shape[-1]) |
|
|
|
|
|
|
|
|
|
cross_attn_mask_enc = None |
|
|
|
if self.config.cross_attn_encoder: |
|
cross_attn_mask_enc = cross_attn_mask( |
|
patch_ids, |
|
patch_lengths, |
|
N, |
|
patches_as_queries=True, |
|
cross_attn_k=self.config.cross_attn_k, |
|
window=self.config.cross_attn_window_encoder, |
|
block_mask=self.config.cross_attn_use_flex_attention, |
|
) |
|
|
|
|
|
local_encoder_embeds = compute_hash_embeddings( |
|
local_encoder_tokens=local_encoder_tokens, |
|
local_encoder=self.local_encoder, |
|
encoder_hash_tok_embedding=self.encoder_hash_tok_embedding, |
|
encoder_hash_byte_group_nb_functions=self.config.encoder_hash_byte_group_nb_functions, |
|
encoder_hash_byte_group_size=self.config.encoder_hash_byte_group_size, |
|
encoder_hash_byte_group_vocab=self.config.encoder_hash_byte_group_vocab, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
(h_encoder, h_cross), cache_encoder = self.local_encoder( |
|
tokens=local_encoder_tokens, |
|
embeds=local_encoder_embeds, |
|
patch_embeds=None, |
|
cross_mask=cross_attn_mask_enc, |
|
num_patches=patch_lengths.shape[1], |
|
patch_ids=patch_ids, |
|
) |
|
|
|
|
|
h = h_cross.view(bs, patch_lengths.shape[1], -1) |
|
|
|
|
|
        global_tokens = tokens.new_full((h.shape[0], h.shape[1]), self.config.boe_id)
|
rows, cols = torch.where(local_encoder_tokens == self.config.eos_token_id) |
|
eos_patch_ids = patch_ids[rows, cols] |
|
global_tokens[rows, eos_patch_ids] = self.config.eos_token_id |
|
|
|
h, _ = self.global_transformer( |
|
embeds=h, |
|
tokens=global_tokens, |
|
) |
|
|
|
|
|
|
|
dec_embeds = h_encoder |
|
|
|
|
|
decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], local_decoder_tokens.shape[-1]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
if not self.config.cross_attn_decoder: |
|
h = torch.gather(h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])) |
|
cross_attn_mask_dec = None |
|
|
|
else: |
|
cross_attn_mask_dec = cross_attn_mask( |
|
decoder_patch_ids, |
|
patch_lengths, |
|
N, |
|
patches_as_queries=False, |
|
cross_attn_k=self.config.cross_attn_k, |
|
window=self.config.cross_attn_window_decoder, |
|
block_mask=self.config.cross_attn_use_flex_attention, |
|
) |
|
|
|
|
|
output, _ = self.local_decoder( |
|
embeds=dec_embeds, |
|
patch_embeds=h, |
|
tokens=local_decoder_tokens, |
|
cross_mask=cross_attn_mask_dec, |
|
) |
|
return output |
|
|
|
def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor: |
|
""" |
|
Convert patch lengths to patch IDs for each token position. |
|
For each token position in the sequence, determines which patch it belongs to. |
|
|
|
Args: |
|
patch_lengths: [batch_size, num_patches] - length of each patch |
|
seq_len: total sequence length |
|
|
|
Returns: |
|
patch_ids: [batch_size, seq_len] - patch index for each token position |
|
|
|
Example: |
|
patch_lengths = [[3, 2, 4, 1]] # 4 patches of lengths 3,2,4,1 |
|
seq_len = 10 |
|
Returns: [[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]] |
|
# pos 0-2→patch 0, pos 3-4→patch 1, pos 5-8→patch 2, pos 9→patch 3 |
|
""" |
|
batch_size, num_patches = patch_lengths.shape |
|
|
|
|
|
patch_starts = torch.cat( |
|
[ |
|
torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device), |
|
patch_lengths.cumsum(dim=-1)[:, :-1], |
|
], |
|
dim=-1, |
|
) |
|
|
|
|
|
|
|
token_positions = torch.arange(seq_len, device=patch_lengths.device) |
|
|
|
|
|
|
|
position_ge_patch_start = patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1) |
|
|
|
|
|
patch_ids = position_ge_patch_start.sum(dim=-1) - 1 |
|
|
|
return patch_ids |
|
|
|
|
|
class BLTPatcher(BLTPreTrainedModel): |
|
def __init__(self, config): |
|
super().__init__(config) |
|
|
|
self.rope_embeddings = RotaryEmbedding( |
|
theta=config.patcher_rope_theta, |
|
head_dim=config.patcher_head_dim or config.patcher_dim // config.patcher_n_heads, |
|
max_seqlen=config.patcher_max_seqlen, |
|
rope_use_fp32_in_outer_product=config.patcher_rope_use_fp32_in_outer_product, |
|
) |
|
|
|
self.layers = nn.ModuleList() |
|
for _ in range(config.patcher_n_layers): |
|
self.layers.append( |
|
BLTTransformerLayer( |
|
{ |
|
"dim": config.patcher_dim, |
|
"n_heads": config.patcher_n_heads, |
|
"head_dim": config.patcher_head_dim, |
|
"n_kv_heads": config.patcher_n_kv_heads, |
|
"rope_theta": config.patcher_rope_theta, |
|
"multiple_of": config.patcher_multiple_of, |
|
"ffn_dim_multiplier": config.patcher_ffn_dim_multiplier, |
|
"norm_eps": config.patcher_norm_eps, |
|
} |
|
) |
|
) |
|
|
|
|
|
|
|
self.tok_embeddings = torch.nn.Embedding(config.patcher_vocab_size, config.patcher_dim) |
|
|
|
self.norm = RMSNorm(config.patcher_dim, eps=config.patcher_norm_eps) |
|
|
|
self.output = nn.Linear( |
|
config.patcher_dim, |
|
config.patcher_vocab_size, |
|
bias=False, |
|
) |
|
|
|
def forward( |
|
self, |
|
token_values: torch.Tensor, |
|
patch_size: Optional[int] = None, |
|
include_next_token: bool = True, |
|
threshold: Optional[float] = None, |
|
max_patch_length: Optional[int] = None, |
|
patching_batch_size: int = 1, |
|
device: Optional[str] = None, |
|
): |
|
|
|
|
|
entropies = [] |
|
preds = [] |
|
max_length = self.config.patcher_max_seqlen |
|
batch_numel = max_length * patching_batch_size |
|
splits = torch.split(token_values.flatten(), batch_numel) |
|
|
|
for split in splits: |
|
pad_size = (max_length - (split.numel() % max_length)) % max_length |
|
pad = torch.zeros(pad_size, dtype=split.dtype, device=split.device, requires_grad=False) |
|
split = torch.cat((split, pad), dim=0) |
|
split = split.reshape(-1, max_length) |
|
if device is not None: |
|
split = split.to(device) |
|
|
|
|
|
bsz, seqlen = split.shape |
|
h = self.tok_embeddings(split) |
|
chunk_mask = create_causal_mask( |
|
seqlen, |
|
                self.config.patcher_attn_impl,
|
self.config.patcher_attn_bias_type, |
|
sliding_window=self.config.patcher_sliding_window, |
|
tokens=split, |
|
eos_id=self.config.eos_id, |
|
) |
|
|
|
freq_cis = self.rope_embeddings(seqlen=seqlen, tok_idx=None) |
|
|
|
            for layer in self.layers:
                h = layer(h, freq_cis, tok_idx=None, mask=chunk_mask, attn_impl=self.config.patcher_attn_impl)
|
|
|
pred = self.output(self.norm(h)) |
|
pred = pred.reshape(-1, pred.shape[-1])[: split.numel() - pad_size, :] |
|
preds.append(pred) |
|
pred_entropies = self.entropy(pred) |
|
entropies.append(pred_entropies) |
|
|
|
concat_entropies = torch.cat(entropies, dim=0).reshape(token_values.shape) |
|
concat_preds = torch.cat(preds, dim=0).reshape(token_values.shape[0], -1) |
|
|
|
|
|
bs, seq_len = token_values.shape |
|
seq_len_next_tok = seq_len + 1 if include_next_token else seq_len |
|
|
|
|
|
if patch_size is not None: |
|
patch_start_ids = self.find_entropy_patch_start_ids( |
|
concat_entropies, |
|
patch_size, |
|
include_next_token=include_next_token, |
|
threshold=threshold |
|
) |
|
patch_lengths = self.patch_lengths_from_start_ids(patch_start_ids, seq_len_next_tok) |
|
else: |
|
|
|
patch_lengths = torch.ones((bs, seq_len_next_tok), dtype=token_values.dtype, device=token_values.device) |
|
|
|
patch_lengths = process_patch_lengths(patch_lengths, max_patch_length) |
|
return concat_entropies, patch_lengths, concat_preds |
|
|
|
|
|
@staticmethod |
|
def entropy(scores): |
|
""" |
|
scores: [bs, seq_len, vocab] |
|
returns [bs, seq_len] |
|
|
|
Computes the entropy for each token in the batch. |
|
Note: uses natural log. |
|
""" |
|
log_probs = F.log_softmax(scores, dim=-1) |
|
probs = torch.exp(log_probs) |
|
p_log_p = log_probs * probs |
|
entropy = -p_log_p.sum(dim=-1) |
|
return entropy |
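    # For reference: a uniform next-byte distribution over a vocabulary of size V gives an
    # entropy of ln(V) nats, e.g. ln(256) is roughly 5.55; the patcher compares these per-token
    # entropies against `threshold` (or takes the top-k) to decide where patches start.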
|
|
|
@staticmethod |
|
def patch_start_ids_from_patch_start_mask(patch_start_mask): |
|
bs, trunc_seq_len = patch_start_mask.shape |
|
max_patches = patch_start_mask.sum(dim=1).max() |
|
if max_patches == 0: |
|
patch_start_ids = torch.full( |
|
(bs, trunc_seq_len), |
|
trunc_seq_len, |
|
dtype=torch.long, |
|
device=patch_start_mask.device, |
|
) |
|
else: |
|
patch_ids = torch.arange(trunc_seq_len, device=patch_start_mask.device).unsqueeze(0).repeat(bs, 1) |
|
extra_patch_ids = torch.full( |
|
(bs, trunc_seq_len), |
|
trunc_seq_len, |
|
dtype=torch.long, |
|
device=patch_start_mask.device, |
|
) |
|
all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1) |
|
patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1) |
|
patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, trunc_seq_len)[:, :max_patches] |
|
return patch_start_ids |
|
|
|
@staticmethod |
|
def patch_lengths_from_start_ids(patch_start_ids, seq_len): |
|
""" |
|
Calculate patch lengths from start ids. |
|
        patch_start_ids: e.g. [0, 1, 7, 7, 7, 7, 7] contains the start ids of the patches
        (here 0 and 1); the remaining entries are padded up to the sequence length.
        seq_len: e.g. 7, the length of the sequence.

        returns the patch lengths, [1, 6] for the above example.
|
""" |
|
last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1) |
|
patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1) |
|
patch_lengths = patch_end_ids - patch_start_ids + 1 |
|
assert torch.all(patch_lengths >= 0), f"{patch_lengths}" |
|
assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}" |
|
return patch_lengths |
|
|
|
@staticmethod |
|
def find_entropy_patch_start_ids( |
|
entropies, |
|
patch_size=None, |
|
threshold=None, |
|
include_next_token=True, |
|
): |
|
""" |
|
Use entropies to find the start ids of each patch. |
|
Use patch_size or threshold to figure out the total number of patches to allocate. |
|
|
|
When threshold is not None the number of patches is not constant between |
|
different sequences, but patches can be identified incrementally rather than |
|
decided globally using the entire sequence. |
|
""" |
|
bs, seq_len = entropies.shape[:2] |
|
|
|
first_ids = torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(bs, 1) |
|
preds_truncation_len = first_ids.shape[1] |
|
entropies = entropies[:, 1:] |
|
if threshold is None: |
|
num_patches = seq_len // patch_size |
|
patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices |
|
patch_start_ids = patch_start_ids.sort(dim=1).values |
|
else: |
|
patch_start_mask = entropies > threshold |
|
if not include_next_token: |
|
patch_start_mask = patch_start_mask[:, :-1] |
|
|
|
patch_start_ids = BLTPatcher.patch_start_ids_from_patch_start_mask(patch_start_mask) |
|
|
|
patch_start_ids = torch.cat((first_ids, patch_start_ids + preds_truncation_len), dim=1) |
|
return patch_start_ids |
|
|
|
def init_hash_embeddings( |
|
config, |
|
local_encoder_dim: int, |
|
encoder_hash_byte_group_size: list, |
|
): |
|
"""Initialize hash-based token embeddings for the BLT encoder.""" |
|
if config.encoder_hash_byte_group_size is None: |
|
return None |
|
|
|
embeddings = [] |
|
emb_dim = local_encoder_dim |
|
encoder_hash_byte_group_vocab = config.encoder_hash_byte_group_vocab |
|
|
|
for _ in range(config.encoder_hash_byte_group_nb_functions): |
|
for _ in encoder_hash_byte_group_size: |
|
embeddings.append( |
|
nn.Embedding( |
|
encoder_hash_byte_group_vocab, |
|
emb_dim, |
|
) |
|
) |
|
|
|
return nn.ModuleList(embeddings) |
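# Example (illustrative): with encoder_hash_byte_group_nb_functions=2 and
# encoder_hash_byte_group_size=[3, 4], this returns an nn.ModuleList of 4 nn.Embedding modules,
# iterated in the same (function, group_size) order by compute_hash_embeddings above.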
|
|
|
|
|
__all__ = [ |
|
"BLTPreTrainedModel", |
|
"BLTModel", |
|
"BLTPatcher", |
|
"BLTLocalEncoder", |
|
"BLTLocalDecoder", |
|
"BLTGlobalTransformer", |
|
] |
|
|