vyles committed on
Commit 1d91706 · verified · 1 Parent(s): 5f27a69

Update GPT_SoVITS/AR/modules/patched_mha_with_cache.py

GPT_SoVITS/AR/modules/patched_mha_with_cache.py CHANGED
@@ -1,465 +1,466 @@
- from torch.nn.functional import *
- from torch.nn.functional import (
-     _mha_shape_check,
-     _canonical_mask,
-     _none_or_dtype,
-     _in_projection_packed,
- )
- from torch.nn import functional as F
- import torch
- # Tensor = torch.Tensor
- # from typing import Callable, List, Optional, Tuple, Union
-
-
- def multi_head_attention_forward_patched(
-     query: Tensor,
-     key: Tensor,
-     value: Tensor,
-     embed_dim_to_check: int,
-     num_heads: int,
-     in_proj_weight: Optional[Tensor],
-     in_proj_bias: Optional[Tensor],
-     bias_k: Optional[Tensor],
-     bias_v: Optional[Tensor],
-     add_zero_attn: bool,
-     dropout_p: float,
-     out_proj_weight: Tensor,
-     out_proj_bias: Optional[Tensor],
-     training: bool = True,
-     key_padding_mask: Optional[Tensor] = None,
-     need_weights: bool = True,
-     attn_mask: Optional[Tensor] = None,
-     use_separate_proj_weight: bool = False,
-     q_proj_weight: Optional[Tensor] = None,
-     k_proj_weight: Optional[Tensor] = None,
-     v_proj_weight: Optional[Tensor] = None,
-     static_k: Optional[Tensor] = None,
-     static_v: Optional[Tensor] = None,
-     average_attn_weights: bool = True,
-     is_causal: bool = False,
-     cache=None,
- ) -> Tuple[Tensor, Optional[Tensor]]:
-     r"""
-     Args:
-         query, key, value: map a query and a set of key-value pairs to an output.
-             See "Attention Is All You Need" for more details.
-         embed_dim_to_check: total dimension of the model.
-         num_heads: parallel attention heads.
-         in_proj_weight, in_proj_bias: input projection weight and bias.
-         bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
-         add_zero_attn: add a new batch of zeros to the key and
-                        value sequences at dim=1.
-         dropout_p: probability of an element to be zeroed.
-         out_proj_weight, out_proj_bias: the output projection weight and bias.
-         training: apply dropout if is ``True``.
-         key_padding_mask: if provided, specified padding elements in the key will
-             be ignored by the attention. This is an binary mask. When the value is True,
-             the corresponding value on the attention layer will be filled with -inf.
-         need_weights: output attn_output_weights.
-             Default: `True`
-             Note: `needs_weight` defaults to `True`, but should be set to `False`
-             For best performance when attention weights are not nedeeded.
-             *Setting needs_weights to `True`
-             leads to a significant performance degradation.*
-         attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
-             the batches while a 3D mask allows to specify a different mask for the entries of each batch.
-         is_causal: If specified, applies a causal mask as attention mask, and ignores
-             attn_mask for computing scaled dot product attention.
-             Default: ``False``.
-             .. warning::
-                 is_causal is provides a hint that the attn_mask is the
-                 causal mask.Providing incorrect hints can result in
-                 incorrect execution, including forward and backward
-                 compatibility.
-         use_separate_proj_weight: the function accept the proj. weights for query, key,
-             and value in different forms. If false, in_proj_weight will be used, which is
-             a combination of q_proj_weight, k_proj_weight, v_proj_weight.
-         q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
-         static_k, static_v: static key and value used for attention operators.
-         average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
-             Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
-             when ``need_weights=True.``. Default: True
-
-
-     Shape:
-         Inputs:
-         - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
-           the embedding dimension.
-         - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
-           the embedding dimension.
-         - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
-           the embedding dimension.
-         - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
-           If a FloatTensor is provided, it will be directly added to the value.
-           If a BoolTensor is provided, the positions with the
-           value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
-         - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
-           3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
-           S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
-           positions. If a BoolTensor is provided, positions with ``True``
-           are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
-           is provided, it will be added to the attention weight.
-         - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
-           N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
-         - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
-           N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
-
-         Outputs:
-         - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
-           E is the embedding dimension.
-         - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
-           attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
-           :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
-           :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
-           head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
-     """
-     tens_ops = (
-         query,
-         key,
-         value,
-         in_proj_weight,
-         in_proj_bias,
-         bias_k,
-         bias_v,
-         out_proj_weight,
-         out_proj_bias,
-     )
-     if has_torch_function(tens_ops):
-         return handle_torch_function(
-             multi_head_attention_forward,
-             tens_ops,
-             query,
-             key,
-             value,
-             embed_dim_to_check,
-             num_heads,
-             in_proj_weight,
-             in_proj_bias,
-             bias_k,
-             bias_v,
-             add_zero_attn,
-             dropout_p,
-             out_proj_weight,
-             out_proj_bias,
-             training=training,
-             key_padding_mask=key_padding_mask,
-             need_weights=need_weights,
-             attn_mask=attn_mask,
-             is_causal=is_causal,
-             use_separate_proj_weight=use_separate_proj_weight,
-             q_proj_weight=q_proj_weight,
-             k_proj_weight=k_proj_weight,
-             v_proj_weight=v_proj_weight,
-             static_k=static_k,
-             static_v=static_v,
-             average_attn_weights=average_attn_weights,
-             cache=cache,
-         )
-
-     is_batched = _mha_shape_check(
-         query, key, value, key_padding_mask, attn_mask, num_heads
-     )
-
-     # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
-     # is batched, run the computation and before returning squeeze the
-     # batch dimension so that the output doesn't carry this temporary batch dimension.
-     if not is_batched:
-         # unsqueeze if the input is unbatched
-         query = query.unsqueeze(1)
-         key = key.unsqueeze(1)
-         value = value.unsqueeze(1)
-         if key_padding_mask is not None:
-             key_padding_mask = key_padding_mask.unsqueeze(0)
-
-     # set up shape vars
-     tgt_len, bsz, embed_dim = query.shape
-     src_len, _, _ = key.shape
-
-     key_padding_mask = _canonical_mask(
-         mask=key_padding_mask,
-         mask_name="key_padding_mask",
-         other_type=_none_or_dtype(attn_mask),
-         other_name="attn_mask",
-         target_type=query.dtype,
-     )
-
-     if is_causal and attn_mask is None:
-         raise RuntimeError(
-             "Need attn_mask if specifying the is_causal hint. "
-             "You may use the Transformer module method "
-             "`generate_square_subsequent_mask` to create this mask."
-         )
-
-     if is_causal and key_padding_mask is None and not need_weights:
-         # when we have a kpm or need weights, we need attn_mask
-         # Otherwise, we use the is_causal hint go as is_causal
-         # indicator to SDPA.
-         attn_mask = None
-     else:
-         attn_mask = _canonical_mask(
-             mask=attn_mask,
-             mask_name="attn_mask",
-             other_type=None,
-             other_name="",
-             target_type=query.dtype,
-             check_other=False,
-         )
-
-         if key_padding_mask is not None:
-             # We have the attn_mask, and use that to merge kpm into it.
-             # Turn off use of is_causal hint, as the merged mask is no
-             # longer causal.
-             is_causal = False
-
-     assert (
-         embed_dim == embed_dim_to_check
-     ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
-     if isinstance(embed_dim, torch.Tensor):
-         # embed_dim can be a tensor when JIT tracing
-         head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
-     else:
-         head_dim = embed_dim // num_heads
-     assert (
-         head_dim * num_heads == embed_dim
-     ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
-     if use_separate_proj_weight:
-         # allow MHA to have different embedding dimensions when separate projection weights are used
-         assert (
-             key.shape[:2] == value.shape[:2]
-         ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
-     else:
-         assert (
-             key.shape == value.shape
-         ), f"key shape {key.shape} does not match value shape {value.shape}"
-
-     #
-     # compute in-projection
-     #
-     if not use_separate_proj_weight:
-         assert (
-             in_proj_weight is not None
-         ), "use_separate_proj_weight is False but in_proj_weight is None"
-         q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
-     else:
-         assert (
-             q_proj_weight is not None
-         ), "use_separate_proj_weight is True but q_proj_weight is None"
-         assert (
-             k_proj_weight is not None
-         ), "use_separate_proj_weight is True but k_proj_weight is None"
-         assert (
-             v_proj_weight is not None
-         ), "use_separate_proj_weight is True but v_proj_weight is None"
-         if in_proj_bias is None:
-             b_q = b_k = b_v = None
-         else:
-             b_q, b_k, b_v = in_proj_bias.chunk(3)
-         q, k, v = _in_projection(
-             query,
-             key,
-             value,
-             q_proj_weight,
-             k_proj_weight,
-             v_proj_weight,
-             b_q,
-             b_k,
-             b_v,
-         )
-     if cache != None:
-         if cache["first_infer"] == 1:
-             cache["k"][cache["stage"]] = k
-             # print(0,cache["k"].shape)
-             cache["v"][cache["stage"]] = v
-         else:  ###12个layer每个都要留自己的cache_kv
-             # print(1,cache["k"].shape)
-             cache["k"][cache["stage"]] = torch.cat(
-                 [cache["k"][cache["stage"]], k], 0
-             )  ##本来时序是1,但是proj的时候可能transpose了所以时序到0维了
-             cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
-             # print(2, cache["k"].shape)
-             src_len = cache["k"][cache["stage"]].shape[0]
-             k = cache["k"][cache["stage"]]
-             v = cache["v"][cache["stage"]]
-             # if attn_mask is not None:
-             # attn_mask=attn_mask[-1:,]
-             # print(attn_mask.shape,attn_mask)
-         cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
-     # print(2333,cache)
-     # prep attention mask
-
-     attn_mask = _canonical_mask(
-         mask=attn_mask,
-         mask_name="attn_mask",
-         other_type=None,
-         other_name="",
-         target_type=q.dtype,
-         check_other=False,
-     )
-
-     if attn_mask is not None:
-         # ensure attn_mask's dim is 3
-         if attn_mask.dim() == 2:
-             correct_2d_size = (tgt_len, src_len)
-             if attn_mask.shape != correct_2d_size:
-                 raise RuntimeError(
-                     f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
-                 )
-             attn_mask = attn_mask.unsqueeze(0)
-         elif attn_mask.dim() == 3:
-             correct_3d_size = (bsz * num_heads, tgt_len, src_len)
-             if attn_mask.shape != correct_3d_size:
-                 raise RuntimeError(
-                     f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
-                 )
-         else:
-             raise RuntimeError(
-                 f"attn_mask's dimension {attn_mask.dim()} is not supported"
-             )
-
-     # add bias along batch dimension (currently second)
-     if bias_k is not None and bias_v is not None:
-         assert static_k is None, "bias cannot be added to static key."
-         assert static_v is None, "bias cannot be added to static value."
-         k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
-         v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
-         if attn_mask is not None:
-             attn_mask = pad(attn_mask, (0, 1))
-         if key_padding_mask is not None:
-             key_padding_mask = pad(key_padding_mask, (0, 1))
-     else:
-         assert bias_k is None
-         assert bias_v is None
-
-     #
-     # reshape q, k, v for multihead attention and make em batch first
-     #
-     q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
-     if static_k is None:
-         k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
-     else:
-         # TODO finish disentangling control flow so we don't do in-projections when statics are passed
-         assert (
-             static_k.size(0) == bsz * num_heads
-         ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
-         assert (
-             static_k.size(2) == head_dim
-         ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
-         k = static_k
-     if static_v is None:
-         v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
-     else:
-         # TODO finish disentangling control flow so we don't do in-projections when statics are passed
-         assert (
-             static_v.size(0) == bsz * num_heads
-         ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
-         assert (
-             static_v.size(2) == head_dim
-         ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
-         v = static_v
-
-     # add zero attention along batch dimension (now first)
-     if add_zero_attn:
-         zero_attn_shape = (bsz * num_heads, 1, head_dim)
-         k = torch.cat(
-             [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
-         )
-         v = torch.cat(
-             [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
-         )
-         if attn_mask is not None:
-             attn_mask = pad(attn_mask, (0, 1))
-         if key_padding_mask is not None:
-             key_padding_mask = pad(key_padding_mask, (0, 1))
-
-     # update source sequence length after adjustments
-     src_len = k.size(1)
-
-     # merge key padding and attention masks
-     if key_padding_mask is not None:
-         assert key_padding_mask.shape == (
-             bsz,
-             src_len,
-         ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
-         key_padding_mask = (
-             key_padding_mask.view(bsz, 1, 1, src_len)
-             .expand(-1, num_heads, -1, -1)
-             .reshape(bsz * num_heads, 1, src_len)
-         )
-         if attn_mask is None:
-             attn_mask = key_padding_mask
-         else:
-             attn_mask = attn_mask + key_padding_mask
-
-     # adjust dropout probability
-     if not training:
-         dropout_p = 0.0
-
-     #
-     # (deep breath) calculate attention and out projection
-     #
-
-     if need_weights:
-         B, Nt, E = q.shape
-         q_scaled = q / math.sqrt(E)
-
-         assert not (
-             is_causal and attn_mask is None
-         ), "FIXME: is_causal not implemented for need_weights"
-
-         if attn_mask is not None:
-             attn_output_weights = torch.baddbmm(
-                 attn_mask, q_scaled, k.transpose(-2, -1)
-             )
-         else:
-             attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
-         attn_output_weights = softmax(attn_output_weights, dim=-1)
-         if dropout_p > 0.0:
-             attn_output_weights = dropout(attn_output_weights, p=dropout_p)
-
-         attn_output = torch.bmm(attn_output_weights, v)
-
-         attn_output = (
-             attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
-         )
-         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
-         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
-
-         # optionally average attention weights over heads
-         attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
-         if average_attn_weights:
-             attn_output_weights = attn_output_weights.mean(dim=1)
-
-         if not is_batched:
-             # squeeze the output if input was unbatched
-             attn_output = attn_output.squeeze(1)
-             attn_output_weights = attn_output_weights.squeeze(0)
-         return attn_output, attn_output_weights
-     else:
-         # attn_mask can be either (L,S) or (N*num_heads, L, S)
-         # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
-         # in order to match the input for SDPA of (N, num_heads, L, S)
-         if attn_mask is not None:
-             if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
-                 attn_mask = attn_mask.unsqueeze(0)
-             else:
-                 attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
-
-         q = q.view(bsz, num_heads, tgt_len, head_dim)
-         k = k.view(bsz, num_heads, src_len, head_dim)
-         v = v.view(bsz, num_heads, src_len, head_dim)
-
-         # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
-         attn_output = scaled_dot_product_attention(
-             q, k, v, attn_mask, dropout_p, is_causal
-         )
-
-         attn_output = (
-             attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
-         )
-
-         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
-         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
-         if not is_batched:
-             # squeeze the output if input was unbatched
-             attn_output = attn_output.squeeze(1)
-         return attn_output, None
 
 
+ from torch.nn.functional import *
+ from torch.nn.functional import (
+     _mha_shape_check,
+     _canonical_mask,
+     _none_or_dtype,
+     _in_projection_packed,
+ )
+ from torch.nn import functional as F
+ import torch
+ from typing import Tuple, Optional, Any
+ # Tensor = torch.Tensor
+ # from typing import Callable, List, Optional, Tuple, Union
+
+
+ def multi_head_attention_forward_patched(
+     query: Tensor,
+     key: Tensor,
+     value: Tensor,
+     embed_dim_to_check: int,
+     num_heads: int,
+     in_proj_weight: Optional[Tensor],
+     in_proj_bias: Optional[Tensor],
+     bias_k: Optional[Tensor],
+     bias_v: Optional[Tensor],
+     add_zero_attn: bool,
+     dropout_p: float,
+     out_proj_weight: Tensor,
+     out_proj_bias: Optional[Tensor],
+     training: bool = True,
+     key_padding_mask: Optional[Tensor] = None,
+     need_weights: bool = True,
+     attn_mask: Optional[Tensor] = None,
+     use_separate_proj_weight: bool = False,
+     q_proj_weight: Optional[Tensor] = None,
+     k_proj_weight: Optional[Tensor] = None,
+     v_proj_weight: Optional[Tensor] = None,
+     static_k: Optional[Tensor] = None,
+     static_v: Optional[Tensor] = None,
+     average_attn_weights: bool = True,
+     is_causal: bool = False,
+     cache=None,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+     r"""
+     Args:
+         query, key, value: map a query and a set of key-value pairs to an output.
+             See "Attention Is All You Need" for more details.
+         embed_dim_to_check: total dimension of the model.
+         num_heads: parallel attention heads.
+         in_proj_weight, in_proj_bias: input projection weight and bias.
+         bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
+         add_zero_attn: add a new batch of zeros to the key and
+                        value sequences at dim=1.
+         dropout_p: probability of an element to be zeroed.
+         out_proj_weight, out_proj_bias: the output projection weight and bias.
+         training: apply dropout if is ``True``.
+         key_padding_mask: if provided, specified padding elements in the key will
+             be ignored by the attention. This is an binary mask. When the value is True,
+             the corresponding value on the attention layer will be filled with -inf.
+         need_weights: output attn_output_weights.
+             Default: `True`
+             Note: `needs_weight` defaults to `True`, but should be set to `False`
+             For best performance when attention weights are not nedeeded.
+             *Setting needs_weights to `True`
+             leads to a significant performance degradation.*
+         attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+             the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+         is_causal: If specified, applies a causal mask as attention mask, and ignores
+             attn_mask for computing scaled dot product attention.
+             Default: ``False``.
+             .. warning::
+                 is_causal is provides a hint that the attn_mask is the
+                 causal mask.Providing incorrect hints can result in
+                 incorrect execution, including forward and backward
+                 compatibility.
+         use_separate_proj_weight: the function accept the proj. weights for query, key,
+             and value in different forms. If false, in_proj_weight will be used, which is
+             a combination of q_proj_weight, k_proj_weight, v_proj_weight.
+         q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
+         static_k, static_v: static key and value used for attention operators.
+         average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
+             Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
+             when ``need_weights=True.``. Default: True
+
+
+     Shape:
+         Inputs:
+         - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+           the embedding dimension.
+         - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+           the embedding dimension.
+         - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+           the embedding dimension.
+         - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
+           If a FloatTensor is provided, it will be directly added to the value.
+           If a BoolTensor is provided, the positions with the
+           value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+         - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+           3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+           S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+           positions. If a BoolTensor is provided, positions with ``True``
+           are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+           is provided, it will be added to the attention weight.
+         - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+           N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+         - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+           N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+
+         Outputs:
+         - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+           E is the embedding dimension.
+         - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
+           attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
+           :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
+           :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
+           head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
+     """
+     tens_ops = (
+         query,
+         key,
+         value,
+         in_proj_weight,
+         in_proj_bias,
+         bias_k,
+         bias_v,
+         out_proj_weight,
+         out_proj_bias,
+     )
+     if has_torch_function(tens_ops):
+         return handle_torch_function(
+             multi_head_attention_forward,
+             tens_ops,
+             query,
+             key,
+             value,
+             embed_dim_to_check,
+             num_heads,
+             in_proj_weight,
+             in_proj_bias,
+             bias_k,
+             bias_v,
+             add_zero_attn,
+             dropout_p,
+             out_proj_weight,
+             out_proj_bias,
+             training=training,
+             key_padding_mask=key_padding_mask,
+             need_weights=need_weights,
+             attn_mask=attn_mask,
+             is_causal=is_causal,
+             use_separate_proj_weight=use_separate_proj_weight,
+             q_proj_weight=q_proj_weight,
+             k_proj_weight=k_proj_weight,
+             v_proj_weight=v_proj_weight,
+             static_k=static_k,
+             static_v=static_v,
+             average_attn_weights=average_attn_weights,
+             cache=cache,
+         )
+
+     is_batched = _mha_shape_check(
+         query, key, value, key_padding_mask, attn_mask, num_heads
+     )
+
+     # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
+     # is batched, run the computation and before returning squeeze the
+     # batch dimension so that the output doesn't carry this temporary batch dimension.
+     if not is_batched:
+         # unsqueeze if the input is unbatched
+         query = query.unsqueeze(1)
+         key = key.unsqueeze(1)
+         value = value.unsqueeze(1)
+         if key_padding_mask is not None:
+             key_padding_mask = key_padding_mask.unsqueeze(0)
+
+     # set up shape vars
+     tgt_len, bsz, embed_dim = query.shape
+     src_len, _, _ = key.shape
+
+     key_padding_mask = _canonical_mask(
+         mask=key_padding_mask,
+         mask_name="key_padding_mask",
+         other_type=_none_or_dtype(attn_mask),
+         other_name="attn_mask",
+         target_type=query.dtype,
+     )
+
+     if is_causal and attn_mask is None:
+         raise RuntimeError(
+             "Need attn_mask if specifying the is_causal hint. "
+             "You may use the Transformer module method "
+             "`generate_square_subsequent_mask` to create this mask."
+         )
+
+     if is_causal and key_padding_mask is None and not need_weights:
+         # when we have a kpm or need weights, we need attn_mask
+         # Otherwise, we use the is_causal hint go as is_causal
+         # indicator to SDPA.
+         attn_mask = None
+     else:
+         attn_mask = _canonical_mask(
+             mask=attn_mask,
+             mask_name="attn_mask",
+             other_type=None,
+             other_name="",
+             target_type=query.dtype,
+             check_other=False,
+         )
+
+         if key_padding_mask is not None:
+             # We have the attn_mask, and use that to merge kpm into it.
+             # Turn off use of is_causal hint, as the merged mask is no
+             # longer causal.
+             is_causal = False
+
+     assert (
+         embed_dim == embed_dim_to_check
+     ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
+     if isinstance(embed_dim, torch.Tensor):
+         # embed_dim can be a tensor when JIT tracing
+         head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
+     else:
+         head_dim = embed_dim // num_heads
+     assert (
+         head_dim * num_heads == embed_dim
+     ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
+     if use_separate_proj_weight:
+         # allow MHA to have different embedding dimensions when separate projection weights are used
+         assert (
+             key.shape[:2] == value.shape[:2]
+         ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
+     else:
+         assert (
+             key.shape == value.shape
+         ), f"key shape {key.shape} does not match value shape {value.shape}"
+
+     #
+     # compute in-projection
+     #
+     if not use_separate_proj_weight:
+         assert (
+             in_proj_weight is not None
+         ), "use_separate_proj_weight is False but in_proj_weight is None"
+         q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
+     else:
+         assert (
+             q_proj_weight is not None
+         ), "use_separate_proj_weight is True but q_proj_weight is None"
+         assert (
+             k_proj_weight is not None
+         ), "use_separate_proj_weight is True but k_proj_weight is None"
+         assert (
+             v_proj_weight is not None
+         ), "use_separate_proj_weight is True but v_proj_weight is None"
+         if in_proj_bias is None:
+             b_q = b_k = b_v = None
+         else:
+             b_q, b_k, b_v = in_proj_bias.chunk(3)
+         q, k, v = _in_projection(
+             query,
+             key,
+             value,
+             q_proj_weight,
+             k_proj_weight,
+             v_proj_weight,
+             b_q,
+             b_k,
+             b_v,
+         )
+     if cache != None:
+         if cache["first_infer"] == 1:
+             cache["k"][cache["stage"]] = k
+             # print(0,cache["k"].shape)
+             cache["v"][cache["stage"]] = v
+         else:  ###12个layer每个都要留自己的cache_kv
+             # print(1,cache["k"].shape)
+             cache["k"][cache["stage"]] = torch.cat(
+                 [cache["k"][cache["stage"]], k], 0
+             )  ##本来时序是1,但是proj的时候可能transpose了所以时序到0维了
+             cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
+             # print(2, cache["k"].shape)
+             src_len = cache["k"][cache["stage"]].shape[0]
+             k = cache["k"][cache["stage"]]
+             v = cache["v"][cache["stage"]]
+             # if attn_mask is not None:
+             # attn_mask=attn_mask[-1:,]
+             # print(attn_mask.shape,attn_mask)
+         cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
+     # print(2333,cache)
+     # prep attention mask
+
+     attn_mask = _canonical_mask(
+         mask=attn_mask,
+         mask_name="attn_mask",
+         other_type=None,
+         other_name="",
+         target_type=q.dtype,
+         check_other=False,
+     )
+
+     if attn_mask is not None:
+         # ensure attn_mask's dim is 3
+         if attn_mask.dim() == 2:
+             correct_2d_size = (tgt_len, src_len)
+             if attn_mask.shape != correct_2d_size:
+                 raise RuntimeError(
+                     f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
+                 )
+             attn_mask = attn_mask.unsqueeze(0)
+         elif attn_mask.dim() == 3:
+             correct_3d_size = (bsz * num_heads, tgt_len, src_len)
+             if attn_mask.shape != correct_3d_size:
+                 raise RuntimeError(
+                     f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
+                 )
+         else:
+             raise RuntimeError(
+                 f"attn_mask's dimension {attn_mask.dim()} is not supported"
+             )
+
+     # add bias along batch dimension (currently second)
+     if bias_k is not None and bias_v is not None:
+         assert static_k is None, "bias cannot be added to static key."
+         assert static_v is None, "bias cannot be added to static value."
+         k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
+         v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
+         if attn_mask is not None:
+             attn_mask = pad(attn_mask, (0, 1))
+         if key_padding_mask is not None:
+             key_padding_mask = pad(key_padding_mask, (0, 1))
+     else:
+         assert bias_k is None
+         assert bias_v is None
+
+     #
+     # reshape q, k, v for multihead attention and make em batch first
+     #
+     q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+     if static_k is None:
+         k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
+     else:
+         # TODO finish disentangling control flow so we don't do in-projections when statics are passed
+         assert (
+             static_k.size(0) == bsz * num_heads
+         ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
+         assert (
+             static_k.size(2) == head_dim
+         ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
+         k = static_k
+     if static_v is None:
+         v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
+     else:
+         # TODO finish disentangling control flow so we don't do in-projections when statics are passed
+         assert (
+             static_v.size(0) == bsz * num_heads
+         ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
+         assert (
+             static_v.size(2) == head_dim
+         ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
+         v = static_v
+
+     # add zero attention along batch dimension (now first)
+     if add_zero_attn:
+         zero_attn_shape = (bsz * num_heads, 1, head_dim)
+         k = torch.cat(
+             [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
+         )
+         v = torch.cat(
+             [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
+         )
+         if attn_mask is not None:
+             attn_mask = pad(attn_mask, (0, 1))
+         if key_padding_mask is not None:
+             key_padding_mask = pad(key_padding_mask, (0, 1))
+
+     # update source sequence length after adjustments
+     src_len = k.size(1)
+
+     # merge key padding and attention masks
+     if key_padding_mask is not None:
+         assert key_padding_mask.shape == (
+             bsz,
+             src_len,
+         ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
+         key_padding_mask = (
+             key_padding_mask.view(bsz, 1, 1, src_len)
+             .expand(-1, num_heads, -1, -1)
+             .reshape(bsz * num_heads, 1, src_len)
+         )
+         if attn_mask is None:
+             attn_mask = key_padding_mask
+         else:
+             attn_mask = attn_mask + key_padding_mask
+
+     # adjust dropout probability
+     if not training:
+         dropout_p = 0.0
+
+     #
+     # (deep breath) calculate attention and out projection
+     #
+
+     if need_weights:
+         B, Nt, E = q.shape
+         q_scaled = q / math.sqrt(E)
+
+         assert not (
+             is_causal and attn_mask is None
+         ), "FIXME: is_causal not implemented for need_weights"
+
+         if attn_mask is not None:
+             attn_output_weights = torch.baddbmm(
+                 attn_mask, q_scaled, k.transpose(-2, -1)
+             )
+         else:
+             attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
+         attn_output_weights = softmax(attn_output_weights, dim=-1)
+         if dropout_p > 0.0:
+             attn_output_weights = dropout(attn_output_weights, p=dropout_p)
+
+         attn_output = torch.bmm(attn_output_weights, v)
+
+         attn_output = (
+             attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
+         )
+         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
+
+         # optionally average attention weights over heads
+         attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+         if average_attn_weights:
+             attn_output_weights = attn_output_weights.mean(dim=1)
+
+         if not is_batched:
+             # squeeze the output if input was unbatched
+             attn_output = attn_output.squeeze(1)
+             attn_output_weights = attn_output_weights.squeeze(0)
+         return attn_output, attn_output_weights
+     else:
+         # attn_mask can be either (L,S) or (N*num_heads, L, S)
+         # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
+         # in order to match the input for SDPA of (N, num_heads, L, S)
+         if attn_mask is not None:
+             if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
+                 attn_mask = attn_mask.unsqueeze(0)
+             else:
+                 attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
+
+         q = q.view(bsz, num_heads, tgt_len, head_dim)
+         k = k.view(bsz, num_heads, src_len, head_dim)
+         v = v.view(bsz, num_heads, src_len, head_dim)
+
+         # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
+         attn_output = scaled_dot_product_attention(
+             q, k, v, attn_mask, dropout_p, is_causal
+         )
+
+         attn_output = (
+             attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
+         )
+
+         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
+         if not is_batched:
+             # squeeze the output if input was unbatched
+             attn_output = attn_output.squeeze(1)
+         return attn_output, None
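
The only functional change in this commit is the added import of Tuple, Optional and Any from typing on line 10 of the updated file; every other line is carried over unchanged. For readers of the patch, the Chinese comments in the cache branch say that each of the 12 layers keeps its own cache_kv, and that the time axis is nominally dim 1 but may have been moved to dim 0 by the input projection. The sketch below is a hypothetical illustration of the cache layout, inferred only from the keys the patched function reads ("k", "v", "stage", "all_stage", "first_infer"); the sizes and variable names are assumptions, not part of this commit.

# Hypothetical sketch (not from this commit): one way a caller could lay out
# the `cache` dict that multi_head_attention_forward_patched reads and updates.
num_layers = 12  # "all_stage": one KV slot per attention layer
cache = {
    "all_stage": num_layers,
    "stage": 0,          # slot index; the function advances it modulo all_stage
    "first_infer": 1,    # 1 on the first decoding step, 0 on subsequent steps
    "k": [None] * num_layers,   # per-layer keys, time-major shape (S, N, E)
    "v": [None] * num_layers,   # per-layer values, time-major shape (S, N, E)
}
# First step: the function stores each layer's full k/v into its slot.
# Later steps (first_infer == 0): the new k/v are concatenated along dim 0,
# the time axis after the input projection, and src_len grows accordingly.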