from enum import Enum
from typing import Any
from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self

EOS_ID: int = 2

class InitStdFactor(str, Enum):
DISABLED = "disabled" # Init std is divided by 1.0
GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers)
CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)
DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096
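
# Illustrative sketch (not part of the original module): how the divisor implied by
# each InitStdFactor could be computed, following the comments above. The helper name
# and signature are assumptions for demonstration only.
def _example_init_std_divisor(
    factor: InitStdFactor, n_layers: int, depth: int, model_dim: int
) -> float:
    if factor == InitStdFactor.DISABLED:
        return 1.0
    if factor == InitStdFactor.GLOBAL_DEPTH:
        return (2 * n_layers) ** 0.5
    if factor == InitStdFactor.CURRENT_DEPTH:
        return (2 * depth) ** 0.5
    if factor == InitStdFactor.DIM_RATIO:
        return model_dim / 4096
    raise ValueError(f"unknown init_std_factor: {factor}")
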
class PatchingModeEnum(str, Enum):
entropy = "entropy"
bpe = "bpe"
bpe_patcher = "bpe_patcher"
space = "space"
static = "static"
byte = "byte"
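
# Because PatchingModeEnum mixes in str, members compare equal to their raw string
# values and can be looked up by value. Illustrative checks (not part of the
# original module):
assert PatchingModeEnum.entropy == "entropy"
assert PatchingModeEnum("space") is PatchingModeEnum.space
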
class LMTransformerArgs(BaseModel):
"""Arguments for the Language Model Transformer (used as entropy model for patching)"""
model_config = ConfigDict()
# Basic architecture
dim: int = 512
n_layers: int = 8
head_dim: int | None = None
n_heads: int | None = None
n_kv_heads: int | None = None
# Transformer configuration
max_seqlen: int = 1024
norm_eps: float = 1e-5
dropout: float = 0
    vocab_size: int = -1  # -1 is a placeholder; expected to be set to the real vocabulary size
sliding_window: int | None = None
# Feedforward
ffn_dim_multiplier: float | None = None
multiple_of: int = 256
# Positional encoding
rope_theta: float = 10000.0
rope_use_fp32_in_outer_product: bool = False
# Attention
attn_impl: str = "sdpa"
attn_bias_type: str = "causal"
# Initialization
init_base_std: float | None = None
init_std_factor: InitStdFactor = InitStdFactor.DISABLED
# Embedding dimensions
dim_token_emb: int | None = None
# Model behavior
weight_tying: bool = False
seed: int = 42
# Special token config
eos_id: int = EOS_ID
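
# Example usage (illustrative, not part of the original module): unspecified fields
# fall back to the defaults above; the vocab_size value here is a stand-in.
_example_lm_args = LMTransformerArgs(dim=768, n_layers=12, vocab_size=260)
assert _example_lm_args.attn_impl == "sdpa"
assert _example_lm_args.init_std_factor == InitStdFactor.DISABLED
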
class ByteLatentTransformerArgs(BaseModel):
"""Arguments for the Byte Latent Transformer (main BLT model)"""
model_config = ConfigDict()
# Basic model configuration
seed: int = 42
    vocab_size: int = -1  # -1 is a placeholder; expected to be set to the real vocabulary size
# Main architecture dimensions (these will be used for creating transformer args)
dim: int = 512
n_layers: int = 8
head_dim: int | None = None
n_heads: int | None = None
n_kv_heads: int | None = None
# Component-specific dimensions
dim_global: int = 512
dim_local_decoder: int = 512
dim_local_encoder: int = 512
n_layers_global: int = 8
n_layers_local_decoder: int = 8
n_layers_local_encoder: int = 8
n_heads_global: int = 8
n_heads_local_decoder: int = 8
n_heads_local_encoder: int = 8
n_kv_heads_global: int | None = None
# Transformer configuration (needed by transformer components)
max_seqlen: int = 1024
norm_eps: float = 1e-5
dropout: float = 0
# Feedforward (needed by transformer components)
ffn_dim_multiplier: float = 1.0
multiple_of: int = 256
# Positional encoding (needed by transformer components)
rope_theta: float = 10000.0
rope_use_fp32_in_outer_product: bool = False
# Attention (needed by transformer components)
attn_impl: str = "sdpa"
attn_bias_type: str = "causal"
# Initialization (needed by transformer components)
init_base_std: float | None = None
init_std_factor: InitStdFactor = InitStdFactor.DISABLED
# Embedding dimensions (needed by transformer components)
dim_token_emb: int | None = None
# Patching configuration
patch_in_forward: bool = False
realtime_patching: bool = True
patch_size: float | None = None
patching_mode: str | None = None
patching_threshold: float | None = None
patching_threshold_add: float | None = None
monotonicity: bool = False
patching_batch_size: int = 1
patching_device: str = "cuda"
max_patch_length: int | None = None
entropy_model_checkpoint_dir: str | None = None
# Cross attention configurations
cross_attn_encoder: bool = False
cross_attn_decoder: bool = False
cross_attn_window_encoder: int | None = None
cross_attn_window_decoder: int | None = None
cross_attn_k: int | None = None
cross_attn_nheads: int | None = None
cross_attn_all_layers_decoder: bool = False
cross_attn_all_layers_encoder: bool = False
cross_attn_use_flex_attention: bool = True
cross_attn_init_by_pooling: bool = False
# Encoder configurations
use_local_encoder_transformer: bool = False
max_encoder_seq_length: int | None = None
encoder_hash_byte_group_size: Any | None = None
encoder_hash_byte_group_vocab: int = 30000
encoder_hash_byte_group_nb_functions: int = 3
encoder_enable_byte_ngrams: bool = False
encoder_ngram_to_size_str: str | None = None
downsampling_by_pooling: str | None = None
# Architecture and dimensions
dim_token: int | None = None
share_encoder_decoder_emb: bool = True
weight_tying: bool = False
# Attention configuration
local_attention_window_len: int | None = None
use_rope: bool = True
# Performance optimization
sequence_parallel: bool = False
loss_parallel: bool = False
fuse_sequence_parallel: bool = False
use_fsdp: bool = True
# Parameter mixing
pm_size: int = 0
# Special token config
eos_id: int = EOS_ID

    @model_validator(mode="after")
    def check_hash_byte_sizes(self) -> Self:
        # Normalize a comma-separated string such as "3,4,5" into a list of ints.
        if isinstance(self.encoder_hash_byte_group_size, str):
            self.encoder_hash_byte_group_size = [
                int(x) for x in self.encoder_hash_byte_group_size.split(",") if x
            ]
return self
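
# Example usage (illustrative, not part of the original module): the validator above
# normalizes a comma-separated encoder_hash_byte_group_size string into a list of ints.
_example_blt_args = ByteLatentTransformerArgs(
    vocab_size=260,
    encoder_hash_byte_group_size="3,4,5",
)
assert _example_blt_args.encoder_hash_byte_group_size == [3, 4, 5]
assert _example_blt_args.eos_id == EOS_ID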