Commit 4237495 (verified), committed by paulpak58
Parent: 4e15566

Bump transformers source v4.54.0.dev0
Files changed (3):
  1. config.json +1 -5
  2. modeling_lfm2.py +0 -945
  3. requirements.txt +0 -2
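The net effect of this commit is that the repository no longer bundles its own modeling code: the `auto_map` entries and `modeling_lfm2.py` are removed, on the assumption that the LFM2 classes ship with the newer `transformers` source build named in the commit message. A minimal loading sketch under that assumption (the repo id below is a placeholder for illustration, not taken from this commit):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "LiquidAI/LFM2-1.2B"  # placeholder repo id, assumed for illustration

    # With an LFM2-aware transformers build, no trust_remote_code or repo-local
    # modeling file should be needed; the Auto classes are expected to resolve
    # to the in-library implementation (assumption based on this commit).
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id)

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))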
config.json CHANGED
@@ -42,9 +42,5 @@
   "transformers_version": "4.53.0.dev0",
   "use_cache": true,
   "use_pos_enc": true,
-  "vocab_size": 65536,
-  "auto_map": {
-    "AutoConfig": "modeling_lfm2.LFM2Config",
-    "AutoModelForCausalLM": "modeling_lfm2.LFM2ForCausalLM"
-  }
+  "vocab_size": 65536
 }
modeling_lfm2.py DELETED
@@ -1,945 +0,0 @@
-from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from transformers.cache_utils import DynamicCache
-from transformers.configuration_utils import PretrainedConfig
-from transformers.generation import GenerationMixin
-from transformers.masking_utils import create_causal_mask
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.modeling_layers import GradientCheckpointingLayer
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-)
-from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from transformers.processing_utils import Unpack
-from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
-from transformers.utils.import_utils import is_causal_conv1d_available
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_fn, causal_conv1d_update = None, None
-
-
-kernel_modules = (causal_conv1d_fn, causal_conv1d_update)
-is_fast_path_available = all(kernel_modules)
-
-logger = logging.get_logger(__name__)
-
-
-# ========================================================
-# Config Class (to be removed) once integrated into
-# `transformers`. For now, allows for dynamic importing.
-# ========================================================s
-# from .configuration_lfm2 import LFM2Config
-
-
-class LFM2Config(PretrainedConfig):
-    model_type = "lfm2"
-    keys_to_ignore_at_inference: ClassVar = ["past_key_values"]
-
-    def __init__(
-        self,
-        vocab_size: int = 65536,
-        hidden_size: int = 2560,
-        num_hidden_layers: int = 32,
-        pad_token_id: int = 0,
-        bos_token_id: int = 1,
-        eos_token_id: int = 2,
-        tie_embedding: bool = True,
-        theta: float = 1000000.0,
-        max_position_embeddings: int = 128_000,
-        use_cache: bool = True,
-        norm_eps: float = 0.00001,
-        initializer_range: float = 0.02,
-        num_attention_heads: int = 32,
-        num_key_value_heads: int = 8,
-        conv_bias: bool = False,
-        conv_dim: int = 2560,
-        conv_L_cache: int = 3,
-        block_dim: int = 2560,
-        block_ff_dim: int = 12288,
-        block_multiple_of: int = 256,
-        block_ffn_dim_multiplier: float = 1.0,
-        block_auto_adjust_ff_dim: bool = True,
-        full_attn_idxs: Optional[list[int]] = None,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.rope_theta = theta
-        self.max_position_embeddings = max_position_embeddings
-        self.use_cache = use_cache
-        self.norm_eps = norm_eps
-        self.initializer_range = initializer_range
-
-        # attn operator config
-        self.num_attention_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.full_attn_idxs = full_attn_idxs
-
-        # custom operator config
-        self.conv_bias = conv_bias
-        self.conv_dim = conv_dim
-        self.conv_L_cache = conv_L_cache
-
-        # block config
-        self.block_dim = block_dim
-        self.block_ff_dim = block_ff_dim
-        self.block_multiple_of = block_multiple_of
-        self.block_ffn_dim_multiplier = block_ffn_dim_multiplier
-        self.block_auto_adjust_ff_dim = block_auto_adjust_ff_dim
-
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_embedding,
-            **kwargs,
-        )
-
-    @property
-    def layers_block_type(self):
-        return [
-            "attention" if i in self.full_attn_idxs else "conv"
-            for i in range(self.num_hidden_layers)
-        ]
-
-
-class LFM2RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float())
-        return output.type_as(x) * self.weight
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors."""
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-class LFM2RotaryEmbedding(nn.Module):
-    def __init__(self, config: LFM2Config, device=None):
-        super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
-        self.max_seq_len_cached = config.max_position_embeddings
-        self.original_max_seq_len = config.max_position_embeddings
-
-        self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
-
-    @torch.no_grad()
-    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-    def forward(self, x, position_ids):
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
-        position_ids_expanded = position_ids[:, None, :].float()
-
-        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
-
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def eager_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: float,
-    dropout: float = 0.0,
-    **kwargs,
-):
-    num_key_value_groups = query.shape[1] // key.shape[1]
-    key_states = repeat_kv(key, num_key_value_groups)
-    value_states = repeat_kv(value, num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
-    else:
-        seq_len = key_states.shape[-2]
-        causal_mask = torch.triu(
-            torch.full((seq_len, seq_len), float("-inf"), device=attn_weights.device),
-            diagonal=1,
-        )
-        attn_weights = attn_weights + causal_mask
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attn_weights
-
-
-class LFM2MLP(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        ff_dim: int,
-        multiple_of: int,
-        auto_adjust_ff_dim: bool,
-        ffn_dim_multiplier: Optional[float],
-    ):
-        super().__init__()
-        if auto_adjust_ff_dim:
-            ff_dim = int(2 * ff_dim / 3)
-            # custom dim factor multiplier
-            if ffn_dim_multiplier is not None:
-                ff_dim = int(ffn_dim_multiplier * ff_dim)
-            ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
-
-        self.w1 = nn.Linear(dim, ff_dim, bias=False)
-        self.w3 = nn.Linear(dim, ff_dim, bias=False)
-        self.w2 = nn.Linear(ff_dim, dim, bias=False)
-
-    def forward(self, x):
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-
-class LFM2Cache(DynamicCache):
-    """
-    Attention and conv cache for LFM2.
-
-    It stores the Key and Value states as a list of tensors, one for each layer.
-    Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`.
-    Conv layer cache shape: `[batch_size, conv_dim, L_cache-1]`.
-    """
-
-    def __init__(
-        self,
-        config: LFM2Config,
-        max_batch_size: int,
-        dtype: torch.dtype = torch.float32,
-        device: Union[torch.device, str, None] = None,
-    ):
-        super().__init__()  # initialize key and value cache
-        self.max_batch_size = max_batch_size
-        self.full_attn_idxs = config.full_attn_idxs
-        self.conv_L_cache = config.conv_L_cache
-        self._dtype = dtype
-
-        self.conv_cache: List[torch.Tensor] = []
-        device = torch.device(device) if device is not None else None
-
-        for _ in range(config.num_hidden_layers):
-            conv_state = torch.zeros(
-                self.max_batch_size,
-                config.conv_dim,
-                self.conv_L_cache,
-                dtype=self._dtype,
-                device=device,
-            )
-            torch._dynamo.mark_static_address(conv_state)
-            self.conv_cache.append(conv_state)
-
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-        # Update the number of seen tokens
-        # if layer_idx == 0:
-        if layer_idx == self.full_attn_idxs[0]:
-            self._seen_tokens += key_states.shape[-2]
-
-        # Update the cache
-        if key_states is not None:
-            if len(self.key_cache) <= layer_idx:
-                # There may be skipped layers, fill them with empty lists
-                for _ in range(len(self.key_cache), layer_idx):
-                    self.key_cache.append(torch.tensor([]))
-                    self.value_cache.append(torch.tensor([]))
-                self.key_cache.append(key_states)
-                self.value_cache.append(value_states)
-            elif (
-                not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
-            ):  # fills previously skipped layers; checking for tensor causes errors
-                self.key_cache[layer_idx] = key_states
-                self.value_cache[layer_idx] = value_states
-            else:
-                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
-            device = self.conv_cache[layer_idx].device
-            self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        # take any layer that contains cache and not empty tensor
-        layer_idx = self.full_attn_idxs[0] if layer_idx not in self.full_attn_idxs else layer_idx
-        if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0:
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
-
-    def reset(self):
-        for layer_idx in range(len(self.conv_cache)):
-            # In-place ops prevent breaking the static address
-            self.conv_cache[layer_idx].zero_()
-
-
-class LFM2Attention(nn.Module):
-    def __init__(self, config: LFM2Config, layer_idx: Optional[int] = None, **kwargs):
-        super().__init__()
-        self.config = config
-        self.layer_idx = layer_idx
-        if layer_idx is None:
-            logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and "
-                "will lead to errors during the forward call if caching is used. Please make sure to provide a "
-                "`layer_idx` when creating this class."
-            )
-        self.head_dim = config.hidden_size // config.num_attention_heads
-        self.num_key_value_heads = config.num_key_value_heads
-        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
-        self.scaling = self.head_dim**-0.5
-        self.is_causal = True
-
-        self.q_layernorm = LFM2RMSNorm(self.head_dim, eps=config.norm_eps)
-        self.k_layernorm = LFM2RMSNorm(self.head_dim, eps=config.norm_eps)
-
-        self.q_proj = nn.Linear(
-            config.hidden_size, config.num_attention_heads * self.head_dim, bias=False
-        )
-        self.k_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
-        )
-        self.v_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
-        )
-        self.out_proj = nn.Linear(
-            config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        position_embeddings: tuple[torch.Tensor, torch.Tensor],
-        attention_mask: Optional[torch.Tensor],
-        past_key_value: Optional[LFM2Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
-        input_shape = hidden_states.shape[:-1]
-        hidden_shape = (*input_shape, -1, self.head_dim)
-
-        q = self.q_layernorm(self.q_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
-        k = self.k_layernorm(self.k_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
-        v = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
-
-        cos, sin = position_embeddings
-        q, k = apply_rotary_pos_emb(q, k, cos, sin)
-
-        if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            k, v = past_key_value.update(key_states=k, value_states=v, layer_idx=self.layer_idx, cache_kwargs=cache_kwargs)
-
-        attention_interface: Callable = eager_attention_forward
-        if self.config._attn_implementation != "eager":
-            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
-        attn_output, attn_weights = attention_interface(
-            self,
-            q,
-            k,
-            v,
-            attention_mask,
-            dropout=0.0,
-            scaling=self.scaling,
-            **kwargs,
-        )
-        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        output = self.out_proj(attn_output)
-        return output, attn_weights
-
-
-class LFM2ShortConv(nn.Module):
-    def __init__(
-        self,
-        config: LFM2Config,
-        dim: int,
-        layer_idx: int,
-    ):
-        super().__init__()
-        self.config = config
-        self.layer_idx = layer_idx
-        self.L_cache = config.conv_L_cache
-        self.bias = config.conv_bias
-
-        self.conv = nn.Conv1d(
-            in_channels=dim,
-            out_channels=dim,
-            kernel_size=self.L_cache,
-            groups=dim,
-            bias=self.bias,
-            padding=self.L_cache - 1,
-        )
-        self.in_proj = nn.Linear(dim, 3 * dim, bias=self.bias)
-        self.out_proj = nn.Linear(dim, dim, bias=self.bias)
-
-    def cuda_kernels_forward(
-        self,
-        x: torch.Tensor,
-        cache_params: Optional[LFM2Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-    ):
-        BCx = self.in_proj(x).transpose(-1, -2)
-        B, C, x = BCx.chunk(3, dim=-2)
-
-        Bx = B * x
-
-        conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
-        if cache_params is not None and cache_position[0] > 0:
-            conv_out = causal_conv1d_update(
-                Bx.squeeze(-1),
-                cache_params.conv_cache[self.layer_idx],
-                conv_weights,
-                self.conv.bias,
-                None,
-            )
-            conv_out = conv_out.unsqueeze(-1)
-        else:
-            if cache_params is not None:
-                conv_state = nn.functional.pad(
-                    Bx,
-                    (self.L_cache - Bx.shape[-1], 0)
-                )
-                cache_params.conv_cache[self.layer_idx].copy_(conv_state)
-
-            conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None)
-
-        y = C * conv_out
-        y = self.out_proj(y.transpose(-1, -2).contiguous())
-        return y
-
-    def slow_forward(
-        self,
-        x: torch.Tensor,
-        cache_params: Optional[LFM2Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-    ):
-        seqlen = x.shape[1]
-        BCx = self.in_proj(x).transpose(-1, -2)
-        B, C, x = BCx.chunk(3, dim=-2)
-
-        Bx = B * x
-
-        if cache_params is not None and cache_position[0] > 0:
-            conv_state = cache_params.conv_cache[self.layer_idx]
-            cache_position = cache_position.clamp(0, self.L_cache - 1)
-            conv_state = conv_state.roll(shifts=-1, dims=-1)
-            conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype)
-            cache_params.conv_cache[self.layer_idx].copy_(conv_state)
-            conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1)
-            if self.bias:
-                conv_out += self.conv.bias
-
-            conv_out = conv_out.unsqueeze(-1)
-        else:
-            if cache_params is not None:
-                conv_state = nn.functional.pad(
-                    Bx,
-                    (self.L_cache - Bx.shape[-1], 0)
-                )
-                cache_params.conv_cache[self.layer_idx].copy_(conv_state)
-
-            conv_out = self.conv(Bx)[..., :seqlen]
-
-        y = C * conv_out
-        y = y.transpose(-1, -2).contiguous()
-        y = self.out_proj(y)
-        return y
-
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        cache_params: Optional[LFM2Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-    ):
-        if is_fast_path_available and "cuda" in x.device.type and not torch._dynamo.is_compiling():
-            return self.cuda_kernels_forward(x, cache_params, cache_position, attention_mask)
-        return self.slow_forward(x, cache_params, cache_position, attention_mask)
-
-
-class LFM2AttentionDecoderLayer(GradientCheckpointingLayer):
-    def __init__(self, config: LFM2Config, layer_idx: int):
-        super().__init__()
-        self.self_attn = LFM2Attention(config, layer_idx)
-        self.feed_forward = LFM2MLP(
-            dim=config.block_dim,
-            ff_dim=config.block_ff_dim,
-            multiple_of=config.block_multiple_of,
-            auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
-            ffn_dim_multiplier=config.block_ffn_dim_multiplier,
-        )
-        self.operator_norm = LFM2RMSNorm(config.hidden_size, eps=config.norm_eps)
-        self.ffn_norm = LFM2RMSNorm(config.hidden_size, eps=config.norm_eps)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        position_embeddings: tuple[torch.Tensor, torch.Tensor],
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        h, self_attn_weights = self.self_attn(
-            hidden_states=self.operator_norm(hidden_states),
-            position_embeddings=position_embeddings,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            cache_position=cache_position,
-            **kwargs,
-        )
-        h += hidden_states
-        out = h + self.feed_forward.forward(self.ffn_norm(h))
-
-        outputs = (out,)
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
-
-
-class LFM2ShortConvDecoderLayer(GradientCheckpointingLayer):
-    def __init__(self, config: LFM2Config, layer_idx: int):
-        super().__init__()
-        self.conv = LFM2ShortConv(
-            config=config,
-            dim=config.conv_dim,
-            layer_idx=layer_idx,
-        )
-        self.feed_forward = LFM2MLP(
-            dim=config.block_dim,
-            ff_dim=config.block_ff_dim,
-            multiple_of=config.block_multiple_of,
-            auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
-            ffn_dim_multiplier=config.block_ffn_dim_multiplier,
-        )
-        self.operator_norm = LFM2RMSNorm(config.hidden_size, eps=config.norm_eps)
-        self.ffn_norm = LFM2RMSNorm(config.hidden_size, eps=config.norm_eps)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        past_key_value: Optional[LFM2Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = False,
-        **kwargs,
-    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        h = self.conv(
-            self.operator_norm(hidden_states),
-            cache_params=past_key_value,
-            cache_position=cache_position,
-            attention_mask=attention_mask,
-        )
-        self_attn_weights = None
-
-        h += hidden_states
-        out = h + self.feed_forward.forward(self.ffn_norm(h))
-
-        outputs = (out,)
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
-
-
-@auto_docstring
-class LFM2PretrainedModel(PreTrainedModel):
-    config_class = LFM2Config
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _no_split_modules: ClassVar = ["LFM2AttentionDecoderLayer", "LFM2ShortConvDecoderLayer"]
-    _skip_keys_device_placement = "past_key_values"
-    _supports_flash_attn_2 = True
-    _supports_sdpa = True
-    _supports_flex_attn = True
-    _supports_cache_class = True
-    _supports_quantized_cache = True
-    _supports_static_cache = True
-    _supports_attention_backend = True
-
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, (nn.Linear, nn.Conv1d)):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, LFM2RMSNorm):
-            module.weight.data.fill_(1.0)
-
-
-class LFM2Model(LFM2PretrainedModel):
-    def __init__(self, config: LFM2Config):
-        super().__init__(config)
-        self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
-
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-
-        self.pos_emb = LFM2RotaryEmbedding(config)
-
-        decoder_layers = []
-        for i in range(config.num_hidden_layers):
-            if i in config.full_attn_idxs:
-                decoder_layers.append(LFM2AttentionDecoderLayer(config, layer_idx=i))
-            else:
-                decoder_layers.append(LFM2ShortConvDecoderLayer(config, layer_idx=i))
-        self.layers = nn.ModuleList(decoder_layers)
-
-        self.embedding_norm = LFM2RMSNorm(config.hidden_size, eps=config.norm_eps)
-
-        self.gradient_checkpointing = False
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.embed_tokens = value
-
-    @can_return_tuple
-    @auto_docstring
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[LFM2Cache] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
-    ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-            )
-            use_cache = False
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        if use_cache and past_key_values is None:
-            batch_size = inputs_embeds.shape[0]
-            past_key_values = LFM2Cache(
-                config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device
-            )
-
-        if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-            )
-
-        if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
-
-        causal_mask = create_causal_mask(
-            config=self.config,
-            input_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-            past_key_values=past_key_values,
-        )
-        hidden_states = inputs_embeds
-
-        position_embeddings = self.pos_emb(hidden_states, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        for decoder_layer in self.layers:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=causal_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-                position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
-            )
-
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.embedding_norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        output = BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
-        return output if return_dict else output.to_tuple()
-
-
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
-@auto_docstring
-class LFM2ForCausalLM(LFM2PretrainedModel, GenerationMixin):
-    _tied_weights_keys = ["lm_head.weight"]
-
-    def __init__(self, config: LFM2Config):
-        super().__init__(config)
-        self.model = LFM2Model(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.model.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.model.embed_tokens = value
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
-    def set_decoder(self, decoder):
-        self.model = decoder
-
-    def get_decoder(self):
-        return self.model
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[LFM2Cache] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[KwargsForCausalLM],
-    ) -> Union[tuple, CausalLMOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        outputs: BaseModelOutputWithPast = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            cache_position=cache_position,
-            return_dict=return_dict,
-            **kwargs,
-        )
-
-        hidden_states = outputs.last_hidden_state
-        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
-
-        loss = None
-        if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        past_key_values=None,
-        attention_mask=None,
-        inputs_embeds=None,
-        cache_position=None,
-        position_ids=None,
-        use_cache=True,
-        **kwargs,
-    ):
-        # Overwritten -- Support custom LFM2Cache.
-
-        empty_past_kv = past_key_values is None or (
-            isinstance(past_key_values, DynamicCache) and past_key_values._seen_tokens == 0
-        )
-
-        # Omit tokens covered by past_key_values.
-        if not empty_past_kv:
-            # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-            # Exception 1: when passing input_embeds, input_ids may be missing entries
-            # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-            # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
-            #              (we can't check exception 3 while compiling)
-            if (
-                inputs_embeds is not None  # Exception 1
-                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
-            ):
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif (
-                input_ids.shape[1] != cache_position.shape[0]
-            ):  # Default case (the "else", a no op, is Exception 2)
-                input_ids = input_ids[:, cache_position]
-        else:
-            past_key_values = LFM2Cache(self.config, input_ids.shape[0], dtype=self.dtype, device=self.device)
-
-        # if attention_mask is not None and position_ids is None:
-        #     # create position_ids on the fly for batch generation
-        #     position_ids = attention_mask.long().cumsum(-1) - 1
-        #     position_ids.masked_fill_(attention_mask == 0, 1)
-        #     if not empty_past_kv:
-        #         position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and empty_past_kv:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
-
-        model_inputs.update(
-            {
-                # "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": use_cache,
-                "attention_mask": attention_mask,
-                "cache_position": cache_position,
-            }
-        )
-        return model_inputs
-
-
-__all__ = ["LFM2ForCausalLM", "LFM2Model", "LFM2PretrainedModel"]
requirements.txt DELETED
@@ -1,2 +0,0 @@
-transformers==4.53.0.dev0
-tokenizers==0.21.1
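With the version pins deleted, the environment is expected to provide a sufficiently new transformers on its own. A small runtime sanity-check sketch, assuming the 4.54.0.dev0 threshold taken from the commit message rather than from requirements.txt:

    import transformers
    from packaging import version

    # Assumed minimum version, inferred from the commit message only.
    assert version.parse(transformers.__version__) >= version.parse("4.54.0.dev0"), (
        f"found transformers {transformers.__version__}; a 4.54.0.dev0+ source build is assumed"
    )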