from enum import Enum
from typing import Any

from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self


EOS_ID: int = 2


class InitStdFactor(str, Enum):
    DISABLED = "disabled"  # Init std is divided by 1.0
    GLOBAL_DEPTH = "global_depth"  # Init std is divided by sqrt(2*n_layers)
    CURRENT_DEPTH = "current_depth"  # Init std is divided by sqrt(2*depth)
    DIM_RATIO = "dim_ratio"  # Init std is divided by model_dim/4096


class PatchingModeEnum(str, Enum):
    entropy = "entropy"
    bpe = "bpe"
    bpe_patcher = "bpe_patcher"
    space = "space"
    static = "static"
    byte = "byte"


class LMTransformerArgs(BaseModel):
    """Arguments for the Language Model Transformer (used as entropy model for patching)"""

    model_config = ConfigDict()

    # Basic architecture
    dim: int = 512
    n_layers: int = 8
    head_dim: int | None = None
    n_heads: int | None = None
    n_kv_heads: int | None = None

    # Transformer configuration
    max_seqlen: int = 1024
    norm_eps: float = 1e-5
    dropout: float = 0
    vocab_size: int = -1
    sliding_window: int | None = None

    # Feedforward
    ffn_dim_multiplier: float | None = None
    multiple_of: int = 256

    # Positional encoding
    rope_theta: float = 10000.0
    rope_use_fp32_in_outer_product: bool = False

    # Attention
    attn_impl: str = "sdpa"
    attn_bias_type: str = "causal"

    # Initialization
    init_base_std: float | None = None
    init_std_factor: InitStdFactor = InitStdFactor.DISABLED

    # Embedding dimensions
    dim_token_emb: int | None = None

    # Model behavior
    weight_tying: bool = False
    seed: int = 42

    # Special token config
    eos_id: int = EOS_ID


class ByteLatentTransformerArgs(BaseModel):
    """Arguments for the Byte Latent Transformer (main BLT model)"""

    model_config = ConfigDict()

    # Basic model configuration
    seed: int = 42
    vocab_size: int = -1

    # Main architecture dimensions (these will be used for creating transformer args)
    dim: int = 512
    n_layers: int = 8
    head_dim: int | None = None
    n_heads: int | None = None
    n_kv_heads: int | None = None

    # Component-specific dimensions
    dim_global: int = 512
    dim_local_decoder: int = 512
    dim_local_encoder: int = 512
    n_layers_global: int = 8
    n_layers_local_decoder: int = 8
    n_layers_local_encoder: int = 8
    n_heads_global: int = 8
    n_heads_local_decoder: int = 8
    n_heads_local_encoder: int = 8
    n_kv_heads_global: int | None = None

    # Transformer configuration (needed by transformer components)
    max_seqlen: int = 1024
    norm_eps: float = 1e-5
    dropout: float = 0

    # Feedforward (needed by transformer components)
    ffn_dim_multiplier: float = 1.0
    multiple_of: int = 256

    # Positional encoding (needed by transformer components)
    rope_theta: float = 10000.0
    rope_use_fp32_in_outer_product: bool = False

    # Attention (needed by transformer components)
    attn_impl: str = "sdpa"
    attn_bias_type: str = "causal"

    # Initialization (needed by transformer components)
    init_base_std: float | None = None
    init_std_factor: InitStdFactor = InitStdFactor.DISABLED

    # Embedding dimensions (needed by transformer components)
    dim_token_emb: int | None = None

    # Patching configuration
    patch_in_forward: bool = False
    realtime_patching: bool = True
    patch_size: float | None = None
    patching_mode: str | None = None
    patching_threshold: float | None = None
    patching_threshold_add: float | None = None
    monotonicity: bool = False
    patching_batch_size: int = 1
    patching_device: str = "cuda"
    max_patch_length: int | None = None
    entropy_model_checkpoint_dir: str | None = None

    # Cross attention configurations
    cross_attn_encoder: bool = False
    cross_attn_decoder: bool = False
    cross_attn_window_encoder: int | None = None
    cross_attn_window_decoder: int | None = None
    cross_attn_k: int | None = None
    cross_attn_nheads: int | None = None
    cross_attn_all_layers_decoder: bool = False
    cross_attn_all_layers_encoder: bool = False
    cross_attn_use_flex_attention: bool = True
    cross_attn_init_by_pooling: bool = False

    # Encoder configurations
    use_local_encoder_transformer: bool = False
    max_encoder_seq_length: int | None = None
    encoder_hash_byte_group_size: Any | None = None
    encoder_hash_byte_group_vocab: int = 30000
    encoder_hash_byte_group_nb_functions: int = 3
    encoder_enable_byte_ngrams: bool = False
    encoder_ngram_to_size_str: str | None = None
    downsampling_by_pooling: str | None = None

    # Architecture and dimensions
    dim_token: int | None = None
    share_encoder_decoder_emb: bool = True
    weight_tying: bool = False

    # Attention configuration
    local_attention_window_len: int | None = None
    use_rope: bool = True

    # Performance optimization
    sequence_parallel: bool = False
    loss_parallel: bool = False
    fuse_sequence_parallel: bool = False
    use_fsdp: bool = True

    # Parameter mixing
    pm_size: int = 0

    # Special token config
    eos_id: int = EOS_ID

    @model_validator(mode="after")
    def check_hash_byte_sizes(self) -> Self:
        # Allow encoder_hash_byte_group_size to be passed as a comma-separated
        # string (e.g. "3,4,5") and coerce it into a list of ints.
        if isinstance(self.encoder_hash_byte_group_size, str):
            self.encoder_hash_byte_group_size = [
                int(x) for x in self.encoder_hash_byte_group_size.split(",") if len(x) > 0
            ]
        return self
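

# Illustrative usage sketch (not part of the original file): it only exercises the
# argument classes defined above, and the field values are arbitrary examples.
if __name__ == "__main__":
    # Entropy-model arguments with an explicit vocabulary size.
    lm_args = LMTransformerArgs(vocab_size=260)
    print(lm_args.dim, lm_args.n_layers, lm_args.eos_id)

    # BLT arguments; check_hash_byte_sizes coerces the comma-separated string of
    # hash byte group sizes into a list of ints.
    blt_args = ByteLatentTransformerArgs(
        vocab_size=260,
        patching_mode=PatchingModeEnum.entropy.value,
        encoder_hash_byte_group_size="3,4,5",
    )
    print(blt_args.encoder_hash_byte_group_size)  # [3, 4, 5]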