Update GPT_SoVITS/AR/modules/patched_mha_with_cache.py
GPT_SoVITS/AR/modules/patched_mha_with_cache.py
CHANGED
@@ -1,465 +1,466 @@
[465 removed lines: the previous version of the file, superseded in full by the 466-line version below]
from torch.nn.functional import *
from torch.nn.functional import (
    _mha_shape_check,
    _canonical_mask,
    _none_or_dtype,
    _in_projection,  # needed by the use_separate_proj_weight branch; underscore names are not pulled in by the star import
    _in_projection_packed,
)
from torch.nn import functional as F
import torch
from typing import Tuple, Optional, Any

# Tensor = torch.Tensor
# from typing import Callable, List, Optional, Tuple, Union


def multi_head_attention_forward_patched(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
    cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
            Default: ``True``
            Note: ``need_weights`` defaults to ``True``, but should be set to ``False``
            for best performance when attention weights are not needed.
            *Setting need_weights to ``True``
            leads to a significant performance degradation.*
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        is_causal: If specified, applies a causal mask as attention mask, and ignores
            attn_mask for computing scaled dot product attention.
            Default: ``False``.
            .. warning::
                is_causal provides a hint that attn_mask is the
                causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.
        use_separate_proj_weight: the function accepts the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
            Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
            when ``need_weights=True``. Default: True


    Shape:
        Inputs:
        - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a FloatTensor is provided, it will be directly added to the value.
          If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

        Outputs:
        - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
          attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
          head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
    """
    tens_ops = (
        query,
        key,
        value,
        in_proj_weight,
        in_proj_bias,
        bias_k,
        bias_v,
        out_proj_weight,
        out_proj_bias,
    )
    if has_torch_function(tens_ops):
        return handle_torch_function(
            multi_head_attention_forward,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            is_causal=is_causal,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
            average_attn_weights=average_attn_weights,
            cache=cache,
        )

    is_batched = _mha_shape_check(
        query, key, value, key_padding_mask, attn_mask, num_heads
    )

    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
    # is batched, run the computation and before returning squeeze the
    # batch dimension so that the output doesn't carry this temporary batch dimension.
    if not is_batched:
        # unsqueeze if the input is unbatched
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    key_padding_mask = _canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_type=_none_or_dtype(attn_mask),
        other_name="attn_mask",
        target_type=query.dtype,
    )

    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )

    if is_causal and key_padding_mask is None and not need_weights:
        # when we have a kpm or need weights, we need attn_mask
        # Otherwise, we pass the is_causal hint straight through as the
        # is_causal indicator to SDPA.
        attn_mask = None
    else:
        attn_mask = _canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        if key_padding_mask is not None:
            # We have the attn_mask, and use that to merge kpm into it.
            # Turn off use of is_causal hint, as the merged mask is no
            # longer causal.
            is_causal = False

    assert (
        embed_dim == embed_dim_to_check
    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    assert (
        head_dim * num_heads == embed_dim
    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert (
            key.shape[:2] == value.shape[:2]
        ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert (
            key.shape == value.shape
        ), f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        assert (
            in_proj_weight is not None
        ), "use_separate_proj_weight is False but in_proj_weight is None"
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert (
            q_proj_weight is not None
        ), "use_separate_proj_weight is True but q_proj_weight is None"
        assert (
            k_proj_weight is not None
        ), "use_separate_proj_weight is True but k_proj_weight is None"
        assert (
            v_proj_weight is not None
        ), "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(
            query,
            key,
            value,
            q_proj_weight,
            k_proj_weight,
            v_proj_weight,
            b_q,
            b_k,
            b_v,
        )
    if cache is not None:
        if cache["first_infer"] == 1:
            # first (prompt) pass: overwrite this layer's slot with the projected k/v
            cache["k"][cache["stage"]] = k
            # print(0,cache["k"].shape)
            cache["v"][cache["stage"]] = v
        else:  # each of the 12 layers keeps its own cache_kv
            # print(1,cache["k"].shape)
            cache["k"][cache["stage"]] = torch.cat(
                [cache["k"][cache["stage"]], k], 0
            )  # the time axis was originally dim 1, but the projection may have transposed it, so it is dim 0 here
            cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
            # print(2, cache["k"].shape)
            src_len = cache["k"][cache["stage"]].shape[0]
            k = cache["k"][cache["stage"]]
            v = cache["v"][cache["stage"]]
            # if attn_mask is not None:
            #     attn_mask=attn_mask[-1:,]
            # print(attn_mask.shape,attn_mask)
        cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
    # print(2333,cache)
    # prep attention mask

    attn_mask = _canonical_mask(
        mask=attn_mask,
        mask_name="attn_mask",
        other_type=None,
        other_name="",
        target_type=q.dtype,
        check_other=False,
    )

    if attn_mask is not None:
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.dim()} is not supported"
            )

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make them batch first
    #
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_k.size(0) == bsz * num_heads
        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert (
            static_k.size(2) == head_dim
        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_v.size(0) == bsz * num_heads
        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert (
            static_v.size(2) == head_dim
        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat(
            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
        )
        v = torch.cat(
            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
        )
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (
            bsz,
            src_len,
        ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = (
            key_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, num_heads, -1, -1)
            .reshape(bsz * num_heads, 1, src_len)
        )
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #

    if need_weights:
        B, Nt, E = q.shape
        q_scaled = q / math.sqrt(E)

        assert not (
            is_causal and attn_mask is None
        ), "FIXME: is_causal not implemented for need_weights"

        if attn_mask is not None:
            attn_output_weights = torch.baddbmm(
                attn_mask, q_scaled, k.transpose(-2, -1)
            )
        else:
            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        attn_output_weights = softmax(attn_output_weights, dim=-1)
        if dropout_p > 0.0:
            attn_output_weights = dropout(attn_output_weights, p=dropout_p)

        attn_output = torch.bmm(attn_output_weights, v)

        attn_output = (
            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        )
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        # optionally average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.mean(dim=1)

        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)
        return attn_output, attn_output_weights
    else:
        # attn_mask can be either (L, S) or (N*num_heads, L, S)
        # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
        # in order to match the input for SDPA of (N, num_heads, L, S)
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

        q = q.view(bsz, num_heads, tgt_len, head_dim)
        k = k.view(bsz, num_heads, src_len, head_dim)
        v = v.view(bsz, num_heads, src_len, head_dim)

        # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
        attn_output = scaled_dot_product_attention(
            q, k, v, attn_mask, dropout_p, is_causal
        )

        attn_output = (
            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        )

        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
        return attn_output, None
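The cache contract is only implicit in the code above: cache["k"] and cache["v"] hold one slot per attention layer ("stage"), cache["first_infer"] marks the prompt pass where a slot is overwritten, and cache["stage"] wraps around cache["all_stage"] after every call. The sketch below drives the patched forward directly with such a cache, first for a prompt pass and then for one incremental decode step. It is a minimal illustration, not part of the commit: the tensor sizes, the make_cache helper, the choice of a plain Python list for the per-layer slots, and the assumption that the repo root is on sys.path are all illustrative; the real GPT-SoVITS call sites may wire the cache differently.

import torch
from GPT_SoVITS.AR.modules.patched_mha_with_cache import (
    multi_head_attention_forward_patched,
)

# Illustrative sizes only (assumptions, not taken from the commit).
embed_dim, num_heads, num_layers, bsz = 512, 8, 12, 1


def make_cache(num_layers):
    # One k/v slot per layer ("stage"); a plain list satisfies the
    # cache["k"][cache["stage"]] indexing the patched forward performs.
    return {
        "all_stage": num_layers,
        "k": [None] * num_layers,
        "v": [None] * num_layers,
        "stage": 0,
        "first_infer": 1,
    }


in_proj_w = torch.randn(3 * embed_dim, embed_dim)
in_proj_b = torch.zeros(3 * embed_dim)
out_proj_w = torch.randn(embed_dim, embed_dim)
out_proj_b = torch.zeros(embed_dim)

cache = make_cache(num_layers)

# Prompt pass: the full sequence is fed and its k/v are written into slot 0.
x = torch.randn(10, bsz, embed_dim)  # (seq_len, batch, embed_dim)
out, _ = multi_head_attention_forward_patched(
    x, x, x, embed_dim, num_heads,
    in_proj_w, in_proj_b, None, None, False, 0.0,
    out_proj_w, out_proj_b,
    training=False, need_weights=False, cache=cache,
)

# One decode step for the same layer: only the new token is fed; its k/v
# are concatenated onto the cached ones along the time dimension.
cache["first_infer"] = 0
cache["stage"] = 0  # back to this layer's slot (a full 12-layer pass would have cycled through all slots)
step = torch.randn(1, bsz, embed_dim)
out, _ = multi_head_attention_forward_patched(
    step, step, step, embed_dim, num_heads,
    in_proj_w, in_proj_b, None, None, False, 0.0,
    out_proj_w, out_proj_b,
    training=False, need_weights=False, cache=cache,
)
print(out.shape, cache["k"][0].shape)  # torch.Size([1, 1, 512]) torch.Size([11, 1, 512])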