Spaces: Running on Zero
Upload 37 files

- .gitattributes +4 -0
- LICENSE.txt +201 -0
- OmniAvatar/base.py +127 -0
- OmniAvatar/configs/__init__.py +0 -0
- OmniAvatar/configs/model_config.py +664 -0
- OmniAvatar/distributed/__init__.py +0 -0
- OmniAvatar/distributed/fsdp.py +43 -0
- OmniAvatar/distributed/xdit_context_parallel.py +134 -0
- OmniAvatar/models/audio_pack.py +40 -0
- OmniAvatar/models/model_manager.py +474 -0
- OmniAvatar/models/wan_video_dit.py +577 -0
- OmniAvatar/models/wan_video_text_encoder.py +269 -0
- OmniAvatar/models/wan_video_vae.py +807 -0
- OmniAvatar/models/wav2vec.py +209 -0
- OmniAvatar/prompters/__init__.py +1 -0
- OmniAvatar/prompters/base_prompter.py +70 -0
- OmniAvatar/prompters/wan_prompter.py +109 -0
- OmniAvatar/schedulers/flow_match.py +79 -0
- OmniAvatar/utils/args_config.py +123 -0
- OmniAvatar/utils/audio_preprocess.py +18 -0
- OmniAvatar/utils/io_utils.py +256 -0
- OmniAvatar/vram_management/__init__.py +1 -0
- OmniAvatar/vram_management/layers.py +95 -0
- OmniAvatar/wan_video.py +340 -0
- README.md +12 -13
- app.py +688 -0
- args_config.yaml +77 -0
- assets/material/pipeline.png +3 -0
- assets/material/teaser.png +3 -0
- configs/inference.yaml +37 -0
- configs/inference_1.3B.yaml +37 -0
- examples/audios/mushroom.wav +3 -0
- examples/audios/tape.wav +3 -0
- examples/images/female-001.png +3 -0
- examples/images/male-001.png +3 -0
- examples/infer_samples.txt +1 -0
- requirements.txt +13 -0
- scripts/inference.py +383 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/audios/mushroom.wav filter=lfs diff=lfs merge=lfs -text
+examples/audios/tape.wav filter=lfs diff=lfs merge=lfs -text
+examples/images/female-001.png filter=lfs diff=lfs merge=lfs -text
+examples/images/male-001.png filter=lfs diff=lfs merge=lfs -text
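The four new rules route the binary example assets under examples/ through Git LFS, mirroring the existing patterns. As a rough illustration (an assumption, not part of this commit), the same kind of rule can be appended programmatically; this is what `git lfs track` writes for you:

# Illustrative sketch only: append an LFS rule for a new asset pattern.
# Equivalent in effect to running `git lfs track "examples/audios/*.wav"`.
from pathlib import Path

def track_with_lfs(pattern: str, gitattributes: str = ".gitattributes") -> None:
    rule = f"{pattern} filter=lfs diff=lfs merge=lfs -text"
    path = Path(gitattributes)
    existing = path.read_text().splitlines() if path.exists() else []
    if rule not in existing:  # avoid writing duplicate rules
        with path.open("a") as f:
            f.write(rule + "\n")

track_with_lfs("examples/audios/*.wav")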
LICENSE.txt
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
OmniAvatar/base.py
ADDED
@@ -0,0 +1,127 @@
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import GaussianBlur


class BasePipeline(torch.nn.Module):

    def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
        super().__init__()
        self.device = device
        self.torch_dtype = torch_dtype
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.cpu_offload = False
        self.model_names = []

    def check_resize_height_width(self, height, width):
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
        return height, width

    def preprocess_image(self, image):
        image = torch.Tensor(np.array(image, dtype=np.float16) * (2.0 / 255) - 1.0).permute(2, 0, 1).unsqueeze(0)
        return image

    def preprocess_images(self, images):
        return [self.preprocess_image(image) for image in images]

    def vae_output_to_image(self, vae_output):
        image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image

    def vae_output_to_video(self, vae_output):
        video = vae_output.cpu().permute(1, 2, 0).numpy()
        video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
        return video

    def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
        if len(latents) > 0:
            blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
            height, width = value.shape[-2:]
            weight = torch.ones_like(value)
            for latent, mask, scale in zip(latents, masks, scales):
                mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
                mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
                mask = blur(mask)
                value += latent * mask * scale
                weight += mask * scale
            value /= weight
        return value

    def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
        if special_kwargs is None:
            noise_pred_global = inference_callback(prompt_emb_global)
        else:
            noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
        if special_local_kwargs_list is None:
            noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
        else:
            noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
        noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
        return noise_pred

    def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
        local_prompts = local_prompts or []
        masks = masks or []
        mask_scales = mask_scales or []
        extended_prompt_dict = self.prompter.extend_prompt(prompt)
        prompt = extended_prompt_dict.get("prompt", prompt)
        local_prompts += extended_prompt_dict.get("prompts", [])
        masks += extended_prompt_dict.get("masks", [])
        mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
        return prompt, local_prompts, masks, mask_scales

    def enable_cpu_offload(self):
        self.cpu_offload = True

    def load_models_to_device(self, loadmodel_names=[]):
        # only load models to device if cpu_offload is enabled
        if not self.cpu_offload:
            return
        # offload the unneeded models to cpu
        for model_name in self.model_names:
            if model_name not in loadmodel_names:
                model = getattr(self, model_name)
                if model is not None:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        for module in model.modules():
                            if hasattr(module, "offload"):
                                module.offload()
                    else:
                        model.cpu()
        # load the needed models to device
        for model_name in loadmodel_names:
            model = getattr(self, model_name)
            if model is not None:
                if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                    for module in model.modules():
                        if hasattr(module, "onload"):
                            module.onload()
                else:
                    model.to(self.device)
        # free the cuda cache
        torch.cuda.empty_cache()

    def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
        generator = None if seed is None else torch.Generator(device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        return noise
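BasePipeline collects the plumbing the OmniAvatar pipelines build on: rounding heights and widths to the VAE/DiT division factor, PIL-to-tensor preprocessing in the [-1, 1] range, decoding VAE outputs back to images, mask-blended latent merging for local prompts, optional CPU offload, and seeded noise generation. A minimal usage sketch of the helper methods, assuming the repository root is on the import path (CPU is used here so no GPU is needed):

import torch
from PIL import Image
from OmniAvatar.base import BasePipeline  # assumes the repo root is on sys.path

pipe = BasePipeline(device="cpu", torch_dtype=torch.float16)

# Heights and widths are rounded up to the nearest multiple of 64.
height, width = pipe.check_resize_height_width(720, 1280)  # -> (768, 1280)

# PIL image -> [1, 3, H, W] tensor in [-1, 1], then back to a PIL image.
image = Image.new("RGB", (width, height), color=(128, 64, 32))
tensor = pipe.preprocess_image(image)
restored = pipe.vae_output_to_image(tensor)

# Reproducible noise for the sampler (shape chosen arbitrarily for the demo).
noise = pipe.generate_noise((1, 16, 21, height // 8, width // 8), seed=42, device="cpu")
print(tensor.shape, restored.size, noise.shape)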
OmniAvatar/configs/__init__.py
ADDED
File without changes
OmniAvatar/configs/model_config.py
ADDED
@@ -0,0 +1,664 @@
from typing_extensions import Literal, TypeAlias
from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_vae import WanVideoVAE


model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
    (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
    (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
    (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
    (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
]
huggingface_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
    ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
    ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
    ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
    ("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
    # ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
    ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
    ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
    ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
    ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
    ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
    ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
]

preset_models_on_huggingface = {
    "HunyuanDiT": [
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    "stable-video-diffusion-img2vid-xt": [
        ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
    # Stable Diffusion
    "StableDiffusion_v15": [
        ("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
    ],
    "DreamShaper_8": [
        ("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
    ],
    # Textual Inversion
    "TextualInversion_VeryBadImageNegative_v1.3": [
        ("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
    ],
    # Stable Diffusion XL
    "StableDiffusionXL_v1": [
        ("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
    ],
    "BluePencilXL_v200": [
        ("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
    ],
    "StableDiffusionXL_Turbo": [
        ("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
    ],
    # Stable Diffusion 3
    "StableDiffusion3": [
        ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
    ],
    "StableDiffusion3_without_T5": [
        ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
    ],
    # ControlNet
    "ControlNet_v11f1p_sd15_depth": [
        ("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
        ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    "ControlNet_v11p_sd15_softedge": [
        ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
        ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
    ],
    "ControlNet_v11f1e_sd15_tile": [
        ("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
    ],
    "ControlNet_v11p_sd15_lineart": [
        ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
        ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
        ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
    ],
    "ControlNet_union_sdxl_promax": [
        ("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
        ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    # AnimateDiff
    "AnimateDiff_v2": [
        ("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
    ],
    "AnimateDiff_xl_beta": [
        ("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
    ],

    # Qwen Prompt
    "QwenPrompt": [
        ("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
    ],
    # Beautiful Prompt
    "BeautifulPrompt": [
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
    ],
    # Omost prompt
    "OmostPrompt":[
        ("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
    ],
    # Translator
    "opus-mt-zh-en": [
        ("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
        ("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
    ],
    # IP-Adapter
    "IP-Adapter-SD": [
        ("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
        ("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
    ],
    "IP-Adapter-SDXL": [
        ("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
        ("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
    ],
    "SDXL-vae-fp16-fix": [
        ("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
    ],
    # Kolors
    "Kolors": [
        ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
        ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
        ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
    ],
    # FLUX
    "FLUX.1-dev": [
        ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
        ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
        ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
    ],
    "InstantX/FLUX.1-dev-IP-Adapter": {
        "file_list": [
            ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
            ("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
            ("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
        ],
        "load_path": [
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
        ],
    },
    # RIFE
    "RIFE": [
        ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
    ],
    # CogVideo
    "CogVideoX-5B": [
        ("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
        ("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
        ("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
    ],
    # Stable Diffusion 3.5
    "StableDiffusion3.5-large": [
        ("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
        ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
    ],
}
preset_models_on_modelscope = {
    # Hunyuan DiT
    "HunyuanDiT": [
        ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
        ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
        ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
        ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
    ],
    # Stable Video Diffusion
    "stable-video-diffusion-img2vid-xt": [
        ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
    ],
    # ExVideo
    "ExVideo-SVD-128f-v1": [
        ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
    ],
    "ExVideo-CogVideoX-LoRA-129f-v1": [
        ("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
    ],
    # Stable Diffusion
    "StableDiffusion_v15": [
        ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
    ],
    "DreamShaper_8": [
        ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
    ],
    "AingDiffusion_v12": [
        ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
    ],
    "Flat2DAnimerge_v45Sharp": [
        ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
    ],
    # Textual Inversion
    "TextualInversion_VeryBadImageNegative_v1.3": [
        ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
    ],
    # Stable Diffusion XL
    "StableDiffusionXL_v1": [
        ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
    ],
    "BluePencilXL_v200": [
        ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
    ],
    "StableDiffusionXL_Turbo": [
        ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
    ],
    "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
        ("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
    ],
    # Stable Diffusion 3
    "StableDiffusion3": [
        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
    ],
    "StableDiffusion3_without_T5": [
        ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
    ],
    # ControlNet
    "ControlNet_v11f1p_sd15_depth": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    "ControlNet_v11p_sd15_softedge": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
    ],
    "ControlNet_v11f1e_sd15_tile": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
    ],
    "ControlNet_v11p_sd15_lineart": [
        ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
        ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
    ],
    "ControlNet_union_sdxl_promax": [
        ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
    ],
    "Annotators:Depth": [
        ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
    ],
    "Annotators:Softedge": [
        ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
    ],
    "Annotators:Lineart": [
        ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
    ],
    "Annotators:Normal": [
        ("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
    ],
    "Annotators:Openpose": [
        ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
        ("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
        ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
    ],
    # AnimateDiff
    "AnimateDiff_v2": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
    ],
    "AnimateDiff_xl_beta": [
        ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
    ],
    # RIFE
    "RIFE": [
        ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
    ],
    # Qwen Prompt
    "QwenPrompt": {
        "file_list": [
            ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
            ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
        ],
        "load_path": [
            "models/QwenPrompt/qwen2-1.5b-instruct",
        ],
    },
    # Beautiful Prompt
    "BeautifulPrompt": {
        "file_list": [
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
            ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
        ],
        "load_path": [
            "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
        ],
    },
    # Omost prompt
    "OmostPrompt": {
        "file_list": [
            ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
            ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
        ],
        "load_path": [
            "models/OmostPrompt/omost-llama-3-8b-4bits",
        ],
    },
    # Translator
    "opus-mt-zh-en": {
        "file_list": [
            ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
            ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
        ],
        "load_path": [
            "models/translator/opus-mt-zh-en",
        ],
    },
    # IP-Adapter
    "IP-Adapter-SD": [
        ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
        ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
    ],
    "IP-Adapter-SDXL": [
        ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
        ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
    ],
    # Kolors
    "Kolors": {
        "file_list": [
            ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
            ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
            ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
        ],
        "load_path": [
            "models/kolors/Kolors/text_encoder",
            "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
            "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
        ],
    },
    "SDXL-vae-fp16-fix": [
        ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
    ],
    # FLUX
    "FLUX.1-dev": {
        "file_list": [
            ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
            ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
        ],
        "load_path": [
            "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
            "models/FLUX/FLUX.1-dev/text_encoder_2",
            "models/FLUX/FLUX.1-dev/ae.safetensors",
            "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
        ],
    },
    "FLUX.1-schnell": {
        "file_list": [
            ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
            ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
            ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
        ],
        "load_path": [
            "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
            "models/FLUX/FLUX.1-dev/text_encoder_2",
            "models/FLUX/FLUX.1-dev/ae.safetensors",
            "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
        ],
    },
    "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
        ("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
    ],
    "jasperai/Flux.1-dev-Controlnet-Depth": [
        ("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
    ],
    "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
        ("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
    ],
    "jasperai/Flux.1-dev-Controlnet-Upscaler": [
        ("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
    ],
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
        ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
    ],
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
        ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
    ],
    "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
        ("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
    ],
    "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
        ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
    ],
    "InstantX/FLUX.1-dev-IP-Adapter": {
        "file_list": [
            ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
            ("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
            ("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
        ],
        "load_path": [
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
            "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
        ],
    },
    # ESRGAN
    "ESRGAN_x4": [
        ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
    ],
    # RIFE
    "RIFE": [
        ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
    ],
    # Omnigen
    "OmniGen-v1": {
        "file_list": [
            ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
            ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
            ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
        ],
        "load_path": [
            "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
            "models/OmniGen/OmniGen-v1/model.safetensors",
        ]
    },
    # CogVideo
    "CogVideoX-5B": {
        "file_list": [
            ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
            ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
            ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
        ],
        "load_path": [
            "models/CogVideo/CogVideoX-5b/text_encoder",
            "models/CogVideo/CogVideoX-5b/transformer",
            "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
        ],
    },
    # Stable Diffusion 3.5
    "StableDiffusion3.5-large": [
        ("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
    ],
    "StableDiffusion3.5-medium": [
        ("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
        ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
542 |
+
],
|
543 |
+
"StableDiffusion3.5-large-turbo": [
|
544 |
+
("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
|
545 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
546 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
547 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
548 |
+
],
|
549 |
+
"HunyuanVideo":{
|
550 |
+
"file_list": [
|
551 |
+
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
552 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
553 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
554 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
555 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
556 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
557 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
558 |
+
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
559 |
+
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
|
560 |
+
],
|
561 |
+
"load_path": [
|
562 |
+
"models/HunyuanVideo/text_encoder/model.safetensors",
|
563 |
+
"models/HunyuanVideo/text_encoder_2",
|
564 |
+
"models/HunyuanVideo/vae/pytorch_model.pt",
|
565 |
+
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
566 |
+
],
|
567 |
+
},
|
568 |
+
"HunyuanVideoI2V":{
|
569 |
+
"file_list": [
|
570 |
+
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
|
571 |
+
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
572 |
+
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
573 |
+
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
574 |
+
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
575 |
+
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
|
576 |
+
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
|
577 |
+
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
|
578 |
+
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
|
579 |
+
],
|
580 |
+
"load_path": [
|
581 |
+
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
|
582 |
+
"models/HunyuanVideoI2V/text_encoder_2",
|
583 |
+
"models/HunyuanVideoI2V/vae/pytorch_model.pt",
|
584 |
+
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
|
585 |
+
],
|
586 |
+
},
|
587 |
+
"HunyuanVideo-fp8":{
|
588 |
+
"file_list": [
|
589 |
+
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
590 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
591 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
592 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
593 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
594 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
595 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
596 |
+
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
597 |
+
("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
|
598 |
+
],
|
599 |
+
"load_path": [
|
600 |
+
"models/HunyuanVideo/text_encoder/model.safetensors",
|
601 |
+
"models/HunyuanVideo/text_encoder_2",
|
602 |
+
"models/HunyuanVideo/vae/pytorch_model.pt",
|
603 |
+
"models/HunyuanVideo/transformers/model.fp8.safetensors"
|
604 |
+
],
|
605 |
+
},
|
606 |
+
}
|
607 |
+
Preset_model_id: TypeAlias = Literal[
|
608 |
+
"HunyuanDiT",
|
609 |
+
"stable-video-diffusion-img2vid-xt",
|
610 |
+
"ExVideo-SVD-128f-v1",
|
611 |
+
"ExVideo-CogVideoX-LoRA-129f-v1",
|
612 |
+
"StableDiffusion_v15",
|
613 |
+
"DreamShaper_8",
|
614 |
+
"AingDiffusion_v12",
|
615 |
+
"Flat2DAnimerge_v45Sharp",
|
616 |
+
"TextualInversion_VeryBadImageNegative_v1.3",
|
617 |
+
"StableDiffusionXL_v1",
|
618 |
+
"BluePencilXL_v200",
|
619 |
+
"StableDiffusionXL_Turbo",
|
620 |
+
"ControlNet_v11f1p_sd15_depth",
|
621 |
+
"ControlNet_v11p_sd15_softedge",
|
622 |
+
"ControlNet_v11f1e_sd15_tile",
|
623 |
+
"ControlNet_v11p_sd15_lineart",
|
624 |
+
"AnimateDiff_v2",
|
625 |
+
"AnimateDiff_xl_beta",
|
626 |
+
"RIFE",
|
627 |
+
"BeautifulPrompt",
|
628 |
+
"opus-mt-zh-en",
|
629 |
+
"IP-Adapter-SD",
|
630 |
+
"IP-Adapter-SDXL",
|
631 |
+
"StableDiffusion3",
|
632 |
+
"StableDiffusion3_without_T5",
|
633 |
+
"Kolors",
|
634 |
+
"SDXL-vae-fp16-fix",
|
635 |
+
"ControlNet_union_sdxl_promax",
|
636 |
+
"FLUX.1-dev",
|
637 |
+
"FLUX.1-schnell",
|
638 |
+
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
|
639 |
+
"jasperai/Flux.1-dev-Controlnet-Depth",
|
640 |
+
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
|
641 |
+
"jasperai/Flux.1-dev-Controlnet-Upscaler",
|
642 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
|
643 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
|
644 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
645 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
646 |
+
"InstantX/FLUX.1-dev-IP-Adapter",
|
647 |
+
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
648 |
+
"QwenPrompt",
|
649 |
+
"OmostPrompt",
|
650 |
+
"ESRGAN_x4",
|
651 |
+
"RIFE",
|
652 |
+
"OmniGen-v1",
|
653 |
+
"CogVideoX-5B",
|
654 |
+
"Annotators:Depth",
|
655 |
+
"Annotators:Softedge",
|
656 |
+
"Annotators:Lineart",
|
657 |
+
"Annotators:Normal",
|
658 |
+
"Annotators:Openpose",
|
659 |
+
"StableDiffusion3.5-large",
|
660 |
+
"StableDiffusion3.5-medium",
|
661 |
+
"HunyuanVideo",
|
662 |
+
"HunyuanVideo-fp8",
|
663 |
+
"HunyuanVideoI2V",
|
664 |
+
]
|
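Each preset above maps a model name to a list of (repo_id, path_in_repo, local_dir) download tuples, optionally wrapped in a dict that also carries a "load_path" list for the model manager. A minimal sketch of resolving one preset into download jobs, assuming a hypothetical download_file callable as a stand-in for whichever hub client (ModelScope or Hugging Face) actually fetches the files:

import os

def resolve_preset(preset, download_file):
    # Presets come in two shapes: a bare list of tuples, or a dict with
    # "file_list" plus an explicit "load_path" for the model manager.
    file_list = preset["file_list"] if isinstance(preset, dict) else preset
    for repo_id, path_in_repo, local_dir in file_list:
        os.makedirs(local_dir, exist_ok=True)
        download_file(repo_id, path_in_repo, local_dir)  # assumed downloader
    return preset.get("load_path") if isinstance(preset, dict) else None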
OmniAvatar/distributed/__init__.py
ADDED
File without changes
|
OmniAvatar/distributed/fsdp.py
ADDED
@@ -0,0 +1,43 @@
1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
+
import gc
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
7 |
+
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
|
8 |
+
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
|
9 |
+
from torch.distributed.utils import _free_storage
|
10 |
+
|
11 |
+
|
12 |
+
def shard_model(
|
13 |
+
model,
|
14 |
+
device_id,
|
15 |
+
param_dtype=torch.bfloat16,
|
16 |
+
reduce_dtype=torch.float32,
|
17 |
+
buffer_dtype=torch.float32,
|
18 |
+
process_group=None,
|
19 |
+
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
20 |
+
sync_module_states=True,
|
21 |
+
):
|
22 |
+
model = FSDP(
|
23 |
+
module=model,
|
24 |
+
process_group=process_group,
|
25 |
+
sharding_strategy=sharding_strategy,
|
26 |
+
auto_wrap_policy=partial(
|
27 |
+
lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
|
28 |
+
mixed_precision=MixedPrecision(
|
29 |
+
param_dtype=param_dtype,
|
30 |
+
reduce_dtype=reduce_dtype,
|
31 |
+
buffer_dtype=buffer_dtype),
|
32 |
+
device_id=device_id,
|
33 |
+
sync_module_states=sync_module_states)
|
34 |
+
return model
|
35 |
+
|
36 |
+
|
37 |
+
def free_model(model):
|
38 |
+
for m in model.modules():
|
39 |
+
if isinstance(m, FSDP):
|
40 |
+
_free_storage(m._handle.flat_param.data)
|
41 |
+
del model
|
42 |
+
gc.collect()
|
43 |
+
torch.cuda.empty_cache()
|
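shard_model wraps only the submodules found in model.blocks (via lambda_auto_wrap_policy), so it expects a transformer-style module exposing a .blocks list. A minimal usage sketch, assuming torch.distributed has already been initialized (e.g. by torchrun) and that dit is such a module; the function name is illustrative:

import torch
import torch.distributed as dist

def wrap_dit_with_fsdp(dit):
    # One GPU per local rank; the FSDP wrapper shards parameters onto that device.
    local_rank = dist.get_rank() % torch.cuda.device_count()
    return shard_model(dit, device_id=local_rank)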
OmniAvatar/distributed/xdit_context_parallel.py
ADDED
@@ -0,0 +1,134 @@
1 |
+
import torch
|
2 |
+
from typing import Optional
|
3 |
+
from einops import rearrange
|
4 |
+
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
5 |
+
get_sequence_parallel_world_size,
|
6 |
+
get_sp_group)
|
7 |
+
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
8 |
+
from yunchang import LongContextAttention
|
9 |
+
|
10 |
+
def sinusoidal_embedding_1d(dim, position):
|
11 |
+
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
|
12 |
+
10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
|
13 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
14 |
+
return x.to(position.dtype)
|
15 |
+
|
16 |
+
def pad_freqs(original_tensor, target_len):
|
17 |
+
seq_len, s1, s2 = original_tensor.shape
|
18 |
+
pad_size = target_len - seq_len
|
19 |
+
padding_tensor = torch.ones(
|
20 |
+
pad_size,
|
21 |
+
s1,
|
22 |
+
s2,
|
23 |
+
dtype=original_tensor.dtype,
|
24 |
+
device=original_tensor.device)
|
25 |
+
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
26 |
+
return padded_tensor
|
27 |
+
|
28 |
+
def rope_apply(x, freqs, num_heads):
|
29 |
+
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
30 |
+
s_per_rank = x.shape[1]
|
31 |
+
s_per_rank = get_sp_group().broadcast_object_list([s_per_rank], src=0)[0] # TODO: the size should be divided by sp_size
|
32 |
+
|
33 |
+
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
34 |
+
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
35 |
+
|
36 |
+
sp_size = get_sequence_parallel_world_size()
|
37 |
+
sp_rank = get_sequence_parallel_rank()
|
38 |
+
if freqs.shape[0] % sp_size != 0 and freqs.shape[0] // sp_size == s_per_rank:
|
39 |
+
s_per_rank = s_per_rank + 1
|
40 |
+
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
41 |
+
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
42 |
+
freqs_rank = freqs_rank[:x.shape[1]]
|
43 |
+
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
44 |
+
return x_out.to(x.dtype)
|
45 |
+
|
46 |
+
def usp_dit_forward(self,
|
47 |
+
x: torch.Tensor,
|
48 |
+
timestep: torch.Tensor,
|
49 |
+
context: torch.Tensor,
|
50 |
+
clip_feature: Optional[torch.Tensor] = None,
|
51 |
+
y: Optional[torch.Tensor] = None,
|
52 |
+
use_gradient_checkpointing: bool = False,
|
53 |
+
use_gradient_checkpointing_offload: bool = False,
|
54 |
+
**kwargs,
|
55 |
+
):
|
56 |
+
t = self.time_embedding(
|
57 |
+
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
58 |
+
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
59 |
+
context = self.text_embedding(context)
|
60 |
+
|
61 |
+
if self.has_image_input:
|
62 |
+
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
63 |
+
clip_embedding = self.img_emb(clip_feature)
|
64 |
+
context = torch.cat([clip_embedding, context], dim=1)
|
65 |
+
|
66 |
+
x, (f, h, w) = self.patchify(x)
|
67 |
+
|
68 |
+
freqs = torch.cat([
|
69 |
+
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
70 |
+
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
71 |
+
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
72 |
+
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
73 |
+
|
74 |
+
def create_custom_forward(module):
|
75 |
+
def custom_forward(*inputs):
|
76 |
+
return module(*inputs)
|
77 |
+
return custom_forward
|
78 |
+
|
79 |
+
# Context Parallel
|
80 |
+
x = torch.chunk(
|
81 |
+
x, get_sequence_parallel_world_size(),
|
82 |
+
dim=1)[get_sequence_parallel_rank()]
|
83 |
+
|
84 |
+
for block in self.blocks:
|
85 |
+
if self.training and use_gradient_checkpointing:
|
86 |
+
if use_gradient_checkpointing_offload:
|
87 |
+
with torch.autograd.graph.save_on_cpu():
|
88 |
+
x = torch.utils.checkpoint.checkpoint(
|
89 |
+
create_custom_forward(block),
|
90 |
+
x, context, t_mod, freqs,
|
91 |
+
use_reentrant=False,
|
92 |
+
)
|
93 |
+
else:
|
94 |
+
x = torch.utils.checkpoint.checkpoint(
|
95 |
+
create_custom_forward(block),
|
96 |
+
x, context, t_mod, freqs,
|
97 |
+
use_reentrant=False,
|
98 |
+
)
|
99 |
+
else:
|
100 |
+
x = block(x, context, t_mod, freqs)
|
101 |
+
|
102 |
+
x = self.head(x, t)
|
103 |
+
|
104 |
+
# Context Parallel
|
105 |
+
if x.shape[1] * get_sequence_parallel_world_size() < freqs.shape[0]:
|
106 |
+
x = torch.cat([x, x[:, -1:]], 1) # TODO: this may cause some bias, the best way is to use sp_size=2
|
107 |
+
x = get_sp_group().all_gather(x, dim=1) # TODO: the size should be divided by sp_size
|
108 |
+
x = x[:, :freqs.shape[0]]
|
109 |
+
|
110 |
+
# unpatchify
|
111 |
+
x = self.unpatchify(x, (f, h, w))
|
112 |
+
return x
|
113 |
+
|
114 |
+
|
115 |
+
def usp_attn_forward(self, x, freqs):
|
116 |
+
q = self.norm_q(self.q(x))
|
117 |
+
k = self.norm_k(self.k(x))
|
118 |
+
v = self.v(x)
|
119 |
+
|
120 |
+
q = rope_apply(q, freqs, self.num_heads)
|
121 |
+
k = rope_apply(k, freqs, self.num_heads)
|
122 |
+
q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
|
123 |
+
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
|
124 |
+
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
|
125 |
+
|
126 |
+
x = xFuserLongContextAttention()(
|
127 |
+
None,
|
128 |
+
query=q,
|
129 |
+
key=k,
|
130 |
+
value=v,
|
131 |
+
)
|
132 |
+
x = x.flatten(2)
|
133 |
+
|
134 |
+
return self.o(x)
|
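The context-parallel flow in usp_dit_forward is: chunk the token sequence across sequence-parallel ranks before the DiT blocks, then all-gather and truncate after the head. A single-process sketch of that shape round trip using plain torch, with a fixed world_size standing in for the xfuser helpers:

import torch

def simulate_sequence_parallel(x):
    world_size = 4                                       # stand-in for get_sequence_parallel_world_size()
    shards = list(torch.chunk(x, world_size, dim=1))     # what each rank keeps
    # Pad the short last shard so every rank gathers the same number of tokens,
    # mirroring the extra-token trick before all_gather in usp_dit_forward.
    while shards[-1].shape[1] < shards[0].shape[1]:
        shards[-1] = torch.cat([shards[-1], shards[-1][:, -1:]], dim=1)
    gathered = torch.cat(shards, dim=1)                  # what all_gather(dim=1) would return
    return gathered[:, :x.shape[1]]                      # drop the padding tokens again

def _demo():
    x = torch.randn(1, 10, 8)
    assert simulate_sequence_parallel(x).shape == x.shape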
OmniAvatar/models/audio_pack.py
ADDED
@@ -0,0 +1,40 @@
1 |
+
import torch
|
2 |
+
from typing import Tuple, Union
|
3 |
+
import torch
|
4 |
+
from einops import rearrange
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
|
8 |
+
def make_triple(value: Union[int, Tuple[int, int, int]]) -> Tuple[int, int, int]:
|
9 |
+
value = (value,) * 3 if isinstance(value, int) else value
|
10 |
+
assert len(value) == 3
|
11 |
+
return value
|
12 |
+
|
13 |
+
|
14 |
+
class AudioPack(nn.Module):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
in_channels: int,
|
18 |
+
patch_size: Union[int, Tuple[int, int, int]],
|
19 |
+
dim: int,
|
20 |
+
layernorm=False,
|
21 |
+
):
|
22 |
+
super().__init__()
|
23 |
+
t, h, w = make_triple(patch_size)
|
24 |
+
self.patch_size = t, h, w
|
25 |
+
self.proj = nn.Linear(in_channels * t * h * w, dim)
|
26 |
+
if layernorm:
|
27 |
+
self.norm_out = nn.LayerNorm(dim)
|
28 |
+
else:
|
29 |
+
self.norm_out = None
|
30 |
+
|
31 |
+
def forward(
|
32 |
+
self,
|
33 |
+
vid: torch.Tensor,
|
34 |
+
) -> torch.Tensor:
|
35 |
+
t, h, w = self.patch_size
|
36 |
+
vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w)
|
37 |
+
vid = self.proj(vid)
|
38 |
+
if self.norm_out is not None:
|
39 |
+
vid = self.norm_out(vid)
|
40 |
+
return vid
|
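AudioPack folds each (t, h, w) patch into the channel axis and projects it with one Linear layer, so the time axis shrinks by the temporal patch size. A quick shape check with the patch size used by the DiT below ([4, 1, 1]); the channel and hidden sizes here are arbitrary illustration values:

import torch

audio_pack = AudioPack(in_channels=32, patch_size=(4, 1, 1), dim=64, layernorm=True)
audio_feat = torch.randn(1, 32, 8, 1, 1)    # (batch, channels, frames, height, width)
packed = audio_pack(audio_feat)
print(packed.shape)                         # torch.Size([1, 2, 1, 1, 64]): 8 frames / patch 4 -> 2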
OmniAvatar/models/model_manager.py
ADDED
@@ -0,0 +1,474 @@
1 |
+
import os, torch, json, importlib
|
2 |
+
from typing import List
|
3 |
+
import torch.nn as nn
|
4 |
+
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs
|
5 |
+
from ..utils.io_utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix, smart_load_weights
|
6 |
+
|
7 |
+
class GeneralLoRAFromPeft:
|
8 |
+
|
9 |
+
def get_name_dict(self, lora_state_dict):
|
10 |
+
lora_name_dict = {}
|
11 |
+
for key in lora_state_dict:
|
12 |
+
if ".lora_B." not in key:
|
13 |
+
continue
|
14 |
+
keys = key.split(".")
|
15 |
+
if len(keys) > keys.index("lora_B") + 2:
|
16 |
+
keys.pop(keys.index("lora_B") + 1)
|
17 |
+
keys.pop(keys.index("lora_B"))
|
18 |
+
if keys[0] == "diffusion_model":
|
19 |
+
keys.pop(0)
|
20 |
+
target_name = ".".join(keys)
|
21 |
+
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
22 |
+
return lora_name_dict
|
23 |
+
|
24 |
+
|
25 |
+
def match(self, model: torch.nn.Module, state_dict_lora):
|
26 |
+
lora_name_dict = self.get_name_dict(state_dict_lora)
|
27 |
+
model_name_dict = {name: None for name, _ in model.named_parameters()}
|
28 |
+
matched_num = sum([i in model_name_dict for i in lora_name_dict])
|
29 |
+
if matched_num == len(lora_name_dict):
|
30 |
+
return "", ""
|
31 |
+
else:
|
32 |
+
return None
|
33 |
+
|
34 |
+
|
35 |
+
def fetch_device_and_dtype(self, state_dict):
|
36 |
+
device, dtype = None, None
|
37 |
+
for name, param in state_dict.items():
|
38 |
+
device, dtype = param.device, param.dtype
|
39 |
+
break
|
40 |
+
computation_device = device
|
41 |
+
computation_dtype = dtype
|
42 |
+
if computation_device == torch.device("cpu"):
|
43 |
+
if torch.cuda.is_available():
|
44 |
+
computation_device = torch.device("cuda")
|
45 |
+
if computation_dtype == torch.float8_e4m3fn:
|
46 |
+
computation_dtype = torch.float32
|
47 |
+
return device, dtype, computation_device, computation_dtype
|
48 |
+
|
49 |
+
|
50 |
+
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
|
51 |
+
state_dict_model = model.state_dict()
|
52 |
+
device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
|
53 |
+
lora_name_dict = self.get_name_dict(state_dict_lora)
|
54 |
+
for name in lora_name_dict:
|
55 |
+
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
|
56 |
+
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
|
57 |
+
if len(weight_up.shape) == 4:
|
58 |
+
weight_up = weight_up.squeeze(3).squeeze(2)
|
59 |
+
weight_down = weight_down.squeeze(3).squeeze(2)
|
60 |
+
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
61 |
+
else:
|
62 |
+
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
63 |
+
weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
|
64 |
+
weight_patched = weight_model + weight_lora
|
65 |
+
state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
|
66 |
+
print(f" {len(lora_name_dict)} tensors are updated.")
|
67 |
+
model.load_state_dict(state_dict_model)
|
68 |
+
|
69 |
+
|
70 |
+
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device, infer):
|
71 |
+
loaded_model_names, loaded_models = [], []
|
72 |
+
for model_name, model_class in zip(model_names, model_classes):
|
73 |
+
print(f" model_name: {model_name} model_class: {model_class.__name__}")
|
74 |
+
state_dict_converter = model_class.state_dict_converter()
|
75 |
+
if model_resource == "civitai":
|
76 |
+
state_dict_results = state_dict_converter.from_civitai(state_dict)
|
77 |
+
elif model_resource == "diffusers":
|
78 |
+
state_dict_results = state_dict_converter.from_diffusers(state_dict)
|
79 |
+
if isinstance(state_dict_results, tuple):
|
80 |
+
model_state_dict, extra_kwargs = state_dict_results
|
81 |
+
print(f" This model is initialized with extra kwargs: {extra_kwargs}")
|
82 |
+
else:
|
83 |
+
model_state_dict, extra_kwargs = state_dict_results, {}
|
84 |
+
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
|
85 |
+
with init_weights_on_device():
|
86 |
+
model = model_class(**extra_kwargs)
|
87 |
+
if hasattr(model, "eval"):
|
88 |
+
model = model.eval()
|
89 |
+
if not infer: # initialize weights only for training
|
90 |
+
model = model.to_empty(device=torch.device("cuda"))
|
91 |
+
for name, param in model.named_parameters():
|
92 |
+
if param.dim() > 1: # usually only weight matrices are initialized, not biases
|
93 |
+
nn.init.xavier_uniform_(param, gain=0.05)
|
94 |
+
else:
|
95 |
+
nn.init.zeros_(param)
|
96 |
+
else:
|
97 |
+
model = model.to_empty(device=device)
|
98 |
+
model, _, _ = smart_load_weights(model, model_state_dict)
|
99 |
+
# model.load_state_dict(model_state_dict, assign=True, strict=False)
|
100 |
+
model = model.to(dtype=torch_dtype, device=device)
|
101 |
+
loaded_model_names.append(model_name)
|
102 |
+
loaded_models.append(model)
|
103 |
+
return loaded_model_names, loaded_models
|
104 |
+
|
105 |
+
|
106 |
+
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
|
107 |
+
loaded_model_names, loaded_models = [], []
|
108 |
+
for model_name, model_class in zip(model_names, model_classes):
|
109 |
+
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
110 |
+
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
111 |
+
else:
|
112 |
+
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
|
113 |
+
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
114 |
+
model = model.half()
|
115 |
+
try:
|
116 |
+
model = model.to(device=device)
|
117 |
+
except:
|
118 |
+
pass
|
119 |
+
loaded_model_names.append(model_name)
|
120 |
+
loaded_models.append(model)
|
121 |
+
return loaded_model_names, loaded_models
|
122 |
+
|
123 |
+
|
124 |
+
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
|
125 |
+
print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
|
126 |
+
base_state_dict = base_model.state_dict()
|
127 |
+
base_model.to("cpu")
|
128 |
+
del base_model
|
129 |
+
model = model_class(**extra_kwargs)
|
130 |
+
model.load_state_dict(base_state_dict, strict=False)
|
131 |
+
model.load_state_dict(state_dict, strict=False)
|
132 |
+
model.to(dtype=torch_dtype, device=device)
|
133 |
+
return model
|
134 |
+
|
135 |
+
|
136 |
+
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
|
137 |
+
loaded_model_names, loaded_models = [], []
|
138 |
+
for model_name, model_class in zip(model_names, model_classes):
|
139 |
+
while True:
|
140 |
+
for model_id in range(len(model_manager.model)):
|
141 |
+
base_model_name = model_manager.model_name[model_id]
|
142 |
+
if base_model_name == model_name:
|
143 |
+
base_model_path = model_manager.model_path[model_id]
|
144 |
+
base_model = model_manager.model[model_id]
|
145 |
+
print(f" Adding patch model to {base_model_name} ({base_model_path})")
|
146 |
+
patched_model = load_single_patch_model_from_single_file(
|
147 |
+
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
|
148 |
+
loaded_model_names.append(base_model_name)
|
149 |
+
loaded_models.append(patched_model)
|
150 |
+
model_manager.model.pop(model_id)
|
151 |
+
model_manager.model_path.pop(model_id)
|
152 |
+
model_manager.model_name.pop(model_id)
|
153 |
+
break
|
154 |
+
else:
|
155 |
+
break
|
156 |
+
return loaded_model_names, loaded_models
|
157 |
+
|
158 |
+
|
159 |
+
|
160 |
+
class ModelDetectorTemplate:
|
161 |
+
def __init__(self):
|
162 |
+
pass
|
163 |
+
|
164 |
+
def match(self, file_path="", state_dict={}):
|
165 |
+
return False
|
166 |
+
|
167 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
168 |
+
return [], []
|
169 |
+
|
170 |
+
|
171 |
+
|
172 |
+
class ModelDetectorFromSingleFile:
|
173 |
+
def __init__(self, model_loader_configs=[]):
|
174 |
+
self.keys_hash_with_shape_dict = {}
|
175 |
+
self.keys_hash_dict = {}
|
176 |
+
for metadata in model_loader_configs:
|
177 |
+
self.add_model_metadata(*metadata)
|
178 |
+
|
179 |
+
|
180 |
+
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
|
181 |
+
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
|
182 |
+
if keys_hash is not None:
|
183 |
+
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
|
184 |
+
|
185 |
+
|
186 |
+
def match(self, file_path="", state_dict={}):
|
187 |
+
if isinstance(file_path, str) and os.path.isdir(file_path):
|
188 |
+
return False
|
189 |
+
if len(state_dict) == 0:
|
190 |
+
state_dict = load_state_dict(file_path)
|
191 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
192 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
193 |
+
return True
|
194 |
+
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
195 |
+
if keys_hash in self.keys_hash_dict:
|
196 |
+
return True
|
197 |
+
return False
|
198 |
+
|
199 |
+
|
200 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, infer=False, **kwargs):
|
201 |
+
if len(state_dict) == 0:
|
202 |
+
state_dict = load_state_dict(file_path)
|
203 |
+
|
204 |
+
# Load models with strict matching
|
205 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
206 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
207 |
+
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
208 |
+
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device, infer)
|
209 |
+
return loaded_model_names, loaded_models
|
210 |
+
|
211 |
+
# Load models without strict matching
|
212 |
+
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
|
213 |
+
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
214 |
+
if keys_hash in self.keys_hash_dict:
|
215 |
+
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
|
216 |
+
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device, infer)
|
217 |
+
return loaded_model_names, loaded_models
|
218 |
+
|
219 |
+
return loaded_model_names, loaded_models
|
220 |
+
|
221 |
+
|
222 |
+
|
223 |
+
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
224 |
+
def __init__(self, model_loader_configs=[]):
|
225 |
+
super().__init__(model_loader_configs)
|
226 |
+
|
227 |
+
|
228 |
+
def match(self, file_path="", state_dict={}):
|
229 |
+
if isinstance(file_path, str) and os.path.isdir(file_path):
|
230 |
+
return False
|
231 |
+
if len(state_dict) == 0:
|
232 |
+
state_dict = load_state_dict(file_path)
|
233 |
+
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
234 |
+
for sub_state_dict in splited_state_dict:
|
235 |
+
if super().match(file_path, sub_state_dict):
|
236 |
+
return True
|
237 |
+
return False
|
238 |
+
|
239 |
+
|
240 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
241 |
+
# Split the state_dict and load from each component
|
242 |
+
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
243 |
+
valid_state_dict = {}
|
244 |
+
for sub_state_dict in splited_state_dict:
|
245 |
+
if super().match(file_path, sub_state_dict):
|
246 |
+
valid_state_dict.update(sub_state_dict)
|
247 |
+
if super().match(file_path, valid_state_dict):
|
248 |
+
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
|
249 |
+
else:
|
250 |
+
loaded_model_names, loaded_models = [], []
|
251 |
+
for sub_state_dict in splited_state_dict:
|
252 |
+
if super().match(file_path, sub_state_dict):
|
253 |
+
loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
|
254 |
+
loaded_model_names += loaded_model_names_
|
255 |
+
loaded_models += loaded_models_
|
256 |
+
return loaded_model_names, loaded_models
|
257 |
+
|
258 |
+
|
259 |
+
|
260 |
+
class ModelDetectorFromHuggingfaceFolder:
|
261 |
+
def __init__(self, model_loader_configs=[]):
|
262 |
+
self.architecture_dict = {}
|
263 |
+
for metadata in model_loader_configs:
|
264 |
+
self.add_model_metadata(*metadata)
|
265 |
+
|
266 |
+
|
267 |
+
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
|
268 |
+
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
|
269 |
+
|
270 |
+
|
271 |
+
def match(self, file_path="", state_dict={}):
|
272 |
+
if not isinstance(file_path, str) or os.path.isfile(file_path):
|
273 |
+
return False
|
274 |
+
file_list = os.listdir(file_path)
|
275 |
+
if "config.json" not in file_list:
|
276 |
+
return False
|
277 |
+
with open(os.path.join(file_path, "config.json"), "r") as f:
|
278 |
+
config = json.load(f)
|
279 |
+
if "architectures" not in config and "_class_name" not in config:
|
280 |
+
return False
|
281 |
+
return True
|
282 |
+
|
283 |
+
|
284 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
285 |
+
with open(os.path.join(file_path, "config.json"), "r") as f:
|
286 |
+
config = json.load(f)
|
287 |
+
loaded_model_names, loaded_models = [], []
|
288 |
+
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
|
289 |
+
for architecture in architectures:
|
290 |
+
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
|
291 |
+
if redirected_architecture is not None:
|
292 |
+
architecture = redirected_architecture
|
293 |
+
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
|
294 |
+
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
|
295 |
+
loaded_model_names += loaded_model_names_
|
296 |
+
loaded_models += loaded_models_
|
297 |
+
return loaded_model_names, loaded_models
|
298 |
+
|
299 |
+
|
300 |
+
|
301 |
+
class ModelDetectorFromPatchedSingleFile:
|
302 |
+
def __init__(self, model_loader_configs=[]):
|
303 |
+
self.keys_hash_with_shape_dict = {}
|
304 |
+
for metadata in model_loader_configs:
|
305 |
+
self.add_model_metadata(*metadata)
|
306 |
+
|
307 |
+
|
308 |
+
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
|
309 |
+
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
|
310 |
+
|
311 |
+
|
312 |
+
def match(self, file_path="", state_dict={}):
|
313 |
+
if not isinstance(file_path, str) or os.path.isdir(file_path):
|
314 |
+
return False
|
315 |
+
if len(state_dict) == 0:
|
316 |
+
state_dict = load_state_dict(file_path)
|
317 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
318 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
319 |
+
return True
|
320 |
+
return False
|
321 |
+
|
322 |
+
|
323 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
|
324 |
+
if len(state_dict) == 0:
|
325 |
+
state_dict = load_state_dict(file_path)
|
326 |
+
|
327 |
+
# Load models with strict matching
|
328 |
+
loaded_model_names, loaded_models = [], []
|
329 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
330 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
331 |
+
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
332 |
+
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
|
333 |
+
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
|
334 |
+
loaded_model_names += loaded_model_names_
|
335 |
+
loaded_models += loaded_models_
|
336 |
+
return loaded_model_names, loaded_models
|
337 |
+
|
338 |
+
|
339 |
+
|
340 |
+
class ModelManager:
|
341 |
+
def __init__(
|
342 |
+
self,
|
343 |
+
torch_dtype=torch.float16,
|
344 |
+
device="cuda",
|
345 |
+
model_id_list: List = [],
|
346 |
+
downloading_priority: List = ["ModelScope", "HuggingFace"],
|
347 |
+
file_path_list: List[str] = [],
|
348 |
+
infer: bool = False
|
349 |
+
):
|
350 |
+
self.torch_dtype = torch_dtype
|
351 |
+
self.device = device
|
352 |
+
self.model = []
|
353 |
+
self.model_path = []
|
354 |
+
self.model_name = []
|
355 |
+
self.infer = infer
|
356 |
+
downloaded_files = []
|
357 |
+
self.model_detector = [
|
358 |
+
ModelDetectorFromSingleFile(model_loader_configs),
|
359 |
+
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
360 |
+
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
|
361 |
+
]
|
362 |
+
self.load_models(downloaded_files + file_path_list)
|
363 |
+
|
364 |
+
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
|
365 |
+
if isinstance(file_path, list):
|
366 |
+
for file_path_ in file_path:
|
367 |
+
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
368 |
+
else:
|
369 |
+
print(f"Loading LoRA models from file: {file_path}")
|
370 |
+
is_loaded = False
|
371 |
+
if len(state_dict) == 0:
|
372 |
+
state_dict = load_state_dict(file_path)
|
373 |
+
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
374 |
+
lora = GeneralLoRAFromPeft()
|
375 |
+
match_results = lora.match(model, state_dict)
|
376 |
+
if match_results is not None:
|
377 |
+
print(f" Adding LoRA to {model_name} ({model_path}).")
|
378 |
+
lora_prefix, model_resource = match_results
|
379 |
+
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
380 |
+
|
381 |
+
|
382 |
+
|
383 |
+
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
|
384 |
+
print(f"Loading models from file: {file_path}")
|
385 |
+
if len(state_dict) == 0:
|
386 |
+
state_dict = load_state_dict(file_path)
|
387 |
+
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device, self.infer)
|
388 |
+
for model_name, model in zip(model_names, models):
|
389 |
+
self.model.append(model)
|
390 |
+
self.model_path.append(file_path)
|
391 |
+
self.model_name.append(model_name)
|
392 |
+
print(f" The following models are loaded: {model_names}.")
|
393 |
+
|
394 |
+
|
395 |
+
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
|
396 |
+
print(f"Loading models from folder: {file_path}")
|
397 |
+
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
|
398 |
+
for model_name, model in zip(model_names, models):
|
399 |
+
self.model.append(model)
|
400 |
+
self.model_path.append(file_path)
|
401 |
+
self.model_name.append(model_name)
|
402 |
+
print(f" The following models are loaded: {model_names}.")
|
403 |
+
|
404 |
+
|
405 |
+
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
|
406 |
+
print(f"Loading patch models from file: {file_path}")
|
407 |
+
model_names, models = load_patch_model_from_single_file(
|
408 |
+
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
|
409 |
+
for model_name, model in zip(model_names, models):
|
410 |
+
self.model.append(model)
|
411 |
+
self.model_path.append(file_path)
|
412 |
+
self.model_name.append(model_name)
|
413 |
+
print(f" The following patched models are loaded: {model_names}.")
|
414 |
+
|
415 |
+
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
416 |
+
print(f"Loading models from: {file_path}")
|
417 |
+
if device is None: device = self.device
|
418 |
+
if torch_dtype is None: torch_dtype = self.torch_dtype
|
419 |
+
if isinstance(file_path, list):
|
420 |
+
state_dict = {}
|
421 |
+
for path in file_path:
|
422 |
+
state_dict.update(load_state_dict(path))
|
423 |
+
elif os.path.isfile(file_path):
|
424 |
+
state_dict = load_state_dict(file_path)
|
425 |
+
else:
|
426 |
+
state_dict = None
|
427 |
+
for model_detector in self.model_detector:
|
428 |
+
if model_detector.match(file_path, state_dict):
|
429 |
+
model_names, models = model_detector.load(
|
430 |
+
file_path, state_dict,
|
431 |
+
device=device, torch_dtype=torch_dtype,
|
432 |
+
allowed_model_names=model_names, model_manager=self, infer=self.infer
|
433 |
+
)
|
434 |
+
for model_name, model in zip(model_names, models):
|
435 |
+
self.model.append(model)
|
436 |
+
self.model_path.append(file_path)
|
437 |
+
self.model_name.append(model_name)
|
438 |
+
print(f" The following models are loaded: {model_names}.")
|
439 |
+
break
|
440 |
+
else:
|
441 |
+
print(f" We cannot detect the model type. No models are loaded.")
|
442 |
+
|
443 |
+
|
444 |
+
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
|
445 |
+
for file_path in file_path_list:
|
446 |
+
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
|
447 |
+
|
448 |
+
|
449 |
+
def fetch_model(self, model_name, file_path=None, require_model_path=False):
|
450 |
+
fetched_models = []
|
451 |
+
fetched_model_paths = []
|
452 |
+
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
453 |
+
if file_path is not None and file_path != model_path:
|
454 |
+
continue
|
455 |
+
if model_name == model_name_:
|
456 |
+
fetched_models.append(model)
|
457 |
+
fetched_model_paths.append(model_path)
|
458 |
+
if len(fetched_models) == 0:
|
459 |
+
print(f"No {model_name} models available.")
|
460 |
+
return None
|
461 |
+
if len(fetched_models) == 1:
|
462 |
+
print(f"Using {model_name} from {fetched_model_paths[0]}.")
|
463 |
+
else:
|
464 |
+
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
|
465 |
+
if require_model_path:
|
466 |
+
return fetched_models[0], fetched_model_paths[0]
|
467 |
+
else:
|
468 |
+
return fetched_models[0]
|
469 |
+
|
470 |
+
|
471 |
+
def to(self, device):
|
472 |
+
for model in self.model:
|
473 |
+
model.to(device)
|
474 |
+
|
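A minimal sketch of driving ModelManager directly. The checkpoint paths are placeholders, and the name passed to fetch_model is an assumption: detection and naming depend on the entries registered in model_loader_configs, so the real names come from that table.

import torch

manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu", infer=True)
manager.load_models([
    "pretrained_models/dit.safetensors",           # placeholder checkpoint path
    "pretrained_models/text_encoder.safetensors",  # placeholder checkpoint path
])
dit = manager.fetch_model("wan_video_dit")         # model name assumed; see model_loader_configs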
OmniAvatar/models/wan_video_dit.py
ADDED
@@ -0,0 +1,577 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import math
|
5 |
+
from typing import Tuple, Optional
|
6 |
+
from einops import rearrange
|
7 |
+
from ..utils.io_utils import hash_state_dict_keys
|
8 |
+
from .audio_pack import AudioPack
|
9 |
+
from ..utils.args_config import args
|
10 |
+
|
11 |
+
if args.sp_size > 1:
|
12 |
+
# Context Parallel
|
13 |
+
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
14 |
+
get_sequence_parallel_world_size,
|
15 |
+
get_sp_group)
|
16 |
+
|
17 |
+
try:
|
18 |
+
import flash_attn_interface
|
19 |
+
print('using flash_attn_interface')
|
20 |
+
FLASH_ATTN_3_AVAILABLE = True
|
21 |
+
except ModuleNotFoundError:
|
22 |
+
FLASH_ATTN_3_AVAILABLE = False
|
23 |
+
|
24 |
+
try:
|
25 |
+
import flash_attn
|
26 |
+
print('using flash_attn')
|
27 |
+
FLASH_ATTN_2_AVAILABLE = True
|
28 |
+
except ModuleNotFoundError:
|
29 |
+
FLASH_ATTN_2_AVAILABLE = False
|
30 |
+
|
31 |
+
try:
|
32 |
+
from sageattention import sageattn
|
33 |
+
print('using sageattention')
|
34 |
+
SAGE_ATTN_AVAILABLE = True
|
35 |
+
except ModuleNotFoundError:
|
36 |
+
SAGE_ATTN_AVAILABLE = False
|
37 |
+
|
38 |
+
|
39 |
+
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
|
40 |
+
if compatibility_mode:
|
41 |
+
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
42 |
+
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
43 |
+
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
44 |
+
x = F.scaled_dot_product_attention(q, k, v)
|
45 |
+
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
46 |
+
elif FLASH_ATTN_3_AVAILABLE:
|
47 |
+
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
48 |
+
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
49 |
+
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
50 |
+
x = flash_attn_interface.flash_attn_func(q, k, v)
|
51 |
+
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
52 |
+
elif FLASH_ATTN_2_AVAILABLE:
|
53 |
+
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
54 |
+
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
55 |
+
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
56 |
+
x = flash_attn.flash_attn_func(q, k, v)
|
57 |
+
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
58 |
+
elif SAGE_ATTN_AVAILABLE:
|
59 |
+
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
60 |
+
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
61 |
+
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
62 |
+
x = sageattn(q, k, v)
|
63 |
+
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
64 |
+
else:
|
65 |
+
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
66 |
+
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
67 |
+
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
68 |
+
x = F.scaled_dot_product_attention(q, k, v)
|
69 |
+
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
70 |
+
return x
|
71 |
+
|
72 |
+
|
73 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
74 |
+
return (x * (1 + scale) + shift)
|
75 |
+
|
76 |
+
|
77 |
+
def sinusoidal_embedding_1d(dim, position):
|
78 |
+
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
|
79 |
+
10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
|
80 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
81 |
+
return x.to(position.dtype)
|
82 |
+
|
83 |
+
|
84 |
+
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
|
85 |
+
# 3d rope precompute
|
86 |
+
f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
|
87 |
+
h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
|
88 |
+
w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
|
89 |
+
return f_freqs_cis, h_freqs_cis, w_freqs_cis
|
90 |
+
|
91 |
+
|
92 |
+
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
|
93 |
+
# 1d rope precompute
|
94 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
|
95 |
+
[: (dim // 2)].double() / dim))
|
96 |
+
freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)
|
97 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
98 |
+
return freqs_cis
|
99 |
+
|
100 |
+
|
101 |
+
def rope_apply(x, freqs, num_heads):
|
102 |
+
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
103 |
+
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
104 |
+
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
105 |
+
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
106 |
+
return x_out.to(x.dtype)
|
107 |
+
|
108 |
+
|
109 |
+
class RMSNorm(nn.Module):
|
110 |
+
def __init__(self, dim, eps=1e-5):
|
111 |
+
super().__init__()
|
112 |
+
self.eps = eps
|
113 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
114 |
+
|
115 |
+
def norm(self, x):
|
116 |
+
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
117 |
+
|
118 |
+
def forward(self, x):
|
119 |
+
dtype = x.dtype
|
120 |
+
return self.norm(x.float()).to(dtype) * self.weight
|
121 |
+
|
122 |
+
|
123 |
+
class AttentionModule(nn.Module):
|
124 |
+
def __init__(self, num_heads):
|
125 |
+
super().__init__()
|
126 |
+
self.num_heads = num_heads
|
127 |
+
|
128 |
+
def forward(self, q, k, v):
|
129 |
+
x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
|
130 |
+
return x
|
131 |
+
|
132 |
+
|
133 |
+
class SelfAttention(nn.Module):
|
134 |
+
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
|
135 |
+
super().__init__()
|
136 |
+
self.dim = dim
|
137 |
+
self.num_heads = num_heads
|
138 |
+
self.head_dim = dim // num_heads
|
139 |
+
|
140 |
+
self.q = nn.Linear(dim, dim)
|
141 |
+
self.k = nn.Linear(dim, dim)
|
142 |
+
self.v = nn.Linear(dim, dim)
|
143 |
+
self.o = nn.Linear(dim, dim)
|
144 |
+
self.norm_q = RMSNorm(dim, eps=eps)
|
145 |
+
self.norm_k = RMSNorm(dim, eps=eps)
|
146 |
+
|
147 |
+
self.attn = AttentionModule(self.num_heads)
|
148 |
+
|
149 |
+
def forward(self, x, freqs):
|
150 |
+
q = self.norm_q(self.q(x))
|
151 |
+
k = self.norm_k(self.k(x))
|
152 |
+
v = self.v(x)
|
153 |
+
q = rope_apply(q, freqs, self.num_heads)
|
154 |
+
k = rope_apply(k, freqs, self.num_heads)
|
155 |
+
x = self.attn(q, k, v)
|
156 |
+
return self.o(x)
|
157 |
+
|
158 |
+
|
159 |
+
class CrossAttention(nn.Module):
|
160 |
+
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
|
161 |
+
super().__init__()
|
162 |
+
self.dim = dim
|
163 |
+
self.num_heads = num_heads
|
164 |
+
self.head_dim = dim // num_heads
|
165 |
+
|
166 |
+
self.q = nn.Linear(dim, dim)
|
167 |
+
self.k = nn.Linear(dim, dim)
|
168 |
+
self.v = nn.Linear(dim, dim)
|
169 |
+
self.o = nn.Linear(dim, dim)
|
170 |
+
self.norm_q = RMSNorm(dim, eps=eps)
|
171 |
+
self.norm_k = RMSNorm(dim, eps=eps)
|
172 |
+
self.has_image_input = has_image_input
|
173 |
+
if has_image_input:
|
174 |
+
self.k_img = nn.Linear(dim, dim)
|
175 |
+
self.v_img = nn.Linear(dim, dim)
|
176 |
+
self.norm_k_img = RMSNorm(dim, eps=eps)
|
177 |
+
|
178 |
+
self.attn = AttentionModule(self.num_heads)
|
179 |
+
|
180 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
181 |
+
if self.has_image_input:
|
182 |
+
img = y[:, :257]
|
183 |
+
ctx = y[:, 257:]
|
184 |
+
else:
|
185 |
+
ctx = y
|
186 |
+
q = self.norm_q(self.q(x))
|
187 |
+
k = self.norm_k(self.k(ctx))
|
188 |
+
v = self.v(ctx)
|
189 |
+
x = self.attn(q, k, v)
|
190 |
+
if self.has_image_input:
|
191 |
+
k_img = self.norm_k_img(self.k_img(img))
|
192 |
+
v_img = self.v_img(img)
|
193 |
+
y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
|
194 |
+
x = x + y
|
195 |
+
return self.o(x)
|
196 |
+
|
197 |
+
|
198 |
+
class GateModule(nn.Module):
|
199 |
+
def __init__(self,):
|
200 |
+
super().__init__()
|
201 |
+
|
202 |
+
def forward(self, x, gate, residual):
|
203 |
+
return x + gate * residual
|
204 |
+
|
205 |
+
class DiTBlock(nn.Module):
|
206 |
+
def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
|
207 |
+
super().__init__()
|
208 |
+
self.dim = dim
|
209 |
+
self.num_heads = num_heads
|
210 |
+
self.ffn_dim = ffn_dim
|
211 |
+
|
212 |
+
self.self_attn = SelfAttention(dim, num_heads, eps)
|
213 |
+
self.cross_attn = CrossAttention(
|
214 |
+
dim, num_heads, eps, has_image_input=has_image_input)
|
215 |
+
self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
216 |
+
self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
217 |
+
self.norm3 = nn.LayerNorm(dim, eps=eps)
|
218 |
+
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
|
219 |
+
approximate='tanh'), nn.Linear(ffn_dim, dim))
|
220 |
+
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
221 |
+
self.gate = GateModule()
|
222 |
+
|
223 |
+
def forward(self, x, context, t_mod, freqs):
|
224 |
+
# msa: multi-head self-attention; mlp: multi-layer perceptron
|
225 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
226 |
+
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
227 |
+
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
228 |
+
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
|
229 |
+
x = x + self.cross_attn(self.norm3(x), context)
|
230 |
+
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
|
231 |
+
x = self.gate(x, gate_mlp, self.ffn(input_x))
|
232 |
+
return x
|
233 |
+
|
234 |
+
|
235 |
+
class MLP(torch.nn.Module):
|
236 |
+
def __init__(self, in_dim, out_dim):
|
237 |
+
super().__init__()
|
238 |
+
self.proj = torch.nn.Sequential(
|
239 |
+
nn.LayerNorm(in_dim),
|
240 |
+
nn.Linear(in_dim, in_dim),
|
241 |
+
nn.GELU(),
|
242 |
+
nn.Linear(in_dim, out_dim),
|
243 |
+
nn.LayerNorm(out_dim)
|
244 |
+
)
|
245 |
+
|
246 |
+
def forward(self, x):
|
247 |
+
return self.proj(x)
|
248 |
+
|
249 |
+
|
250 |
+
class Head(nn.Module):
|
251 |
+
def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
|
252 |
+
super().__init__()
|
253 |
+
self.dim = dim
|
254 |
+
self.patch_size = patch_size
|
255 |
+
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
256 |
+
self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
|
257 |
+
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
258 |
+
|
259 |
+
def forward(self, x, t_mod):
|
260 |
+
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
|
261 |
+
x = (self.head(self.norm(x) * (1 + scale) + shift))
|
262 |
+
return x
|
263 |
+
|
264 |
+
|
265 |
+
|
266 |
+
class WanModel(torch.nn.Module):
|
267 |
+
def __init__(
|
268 |
+
self,
|
269 |
+
dim: int,
|
270 |
+
in_dim: int,
|
271 |
+
ffn_dim: int,
|
272 |
+
out_dim: int,
|
273 |
+
text_dim: int,
|
274 |
+
freq_dim: int,
|
275 |
+
eps: float,
|
276 |
+
patch_size: Tuple[int, int, int],
|
277 |
+
num_heads: int,
|
278 |
+
num_layers: int,
|
279 |
+
has_image_input: bool,
|
280 |
+
audio_hidden_size: int=32,
|
281 |
+
):
|
282 |
+
super().__init__()
|
283 |
+
self.dim = dim
|
284 |
+
self.freq_dim = freq_dim
|
285 |
+
self.has_image_input = has_image_input
|
286 |
+
self.patch_size = patch_size
|
287 |
+
|
288 |
+
self.patch_embedding = nn.Conv3d(
|
289 |
+
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
290 |
+
# nn.LayerNorm(dim)
|
291 |
+
self.text_embedding = nn.Sequential(
|
292 |
+
nn.Linear(text_dim, dim),
|
293 |
+
nn.GELU(approximate='tanh'),
|
294 |
+
nn.Linear(dim, dim)
|
295 |
+
)
|
296 |
+
self.time_embedding = nn.Sequential(
|
297 |
+
nn.Linear(freq_dim, dim),
|
298 |
+
nn.SiLU(),
|
299 |
+
nn.Linear(dim, dim)
|
300 |
+
)
|
301 |
+
self.time_projection = nn.Sequential(
|
302 |
+
nn.SiLU(), nn.Linear(dim, dim * 6))
|
303 |
+
self.blocks = nn.ModuleList([
|
304 |
+
DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
|
305 |
+
for _ in range(num_layers)
|
306 |
+
])
|
307 |
+
self.head = Head(dim, out_dim, patch_size, eps)
|
308 |
+
head_dim = dim // num_heads
|
309 |
+
self.freqs = precompute_freqs_cis_3d(head_dim)
|
310 |
+
|
311 |
+
if has_image_input:
|
312 |
+
self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
|
313 |
+
|
314 |
+
if 'use_audio' in args:
|
315 |
+
self.use_audio = args.use_audio
|
316 |
+
else:
|
317 |
+
self.use_audio = False
|
318 |
+
if self.use_audio:
|
319 |
+
audio_input_dim = 10752
|
320 |
+
audio_out_dim = dim
|
321 |
+
self.audio_proj = AudioPack(audio_input_dim, [4, 1, 1], audio_hidden_size, layernorm=True)
|
322 |
+
self.audio_cond_projs = nn.ModuleList()
|
323 |
+
for d in range(num_layers // 2 - 1):
|
324 |
+
l = nn.Linear(audio_hidden_size, audio_out_dim)
|
325 |
+
self.audio_cond_projs.append(l)
|
326 |
+
|
327 |
+
def patchify(self, x: torch.Tensor):
|
328 |
+
grid_size = x.shape[2:]
|
329 |
+
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
330 |
+
return x, grid_size # x, grid_size: (f, h, w)
|
331 |
+
|
332 |
+
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
|
333 |
+
return rearrange(
|
334 |
+
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
|
335 |
+
f=grid_size[0], h=grid_size[1], w=grid_size[2],
|
336 |
+
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
|
337 |
+
)
|
338 |
+
|
339 |
+
def forward(self,
|
340 |
+
x: torch.Tensor,
|
341 |
+
timestep: torch.Tensor,
|
342 |
+
context: torch.Tensor,
|
343 |
+
clip_feature: Optional[torch.Tensor] = None,
|
344 |
+
y: Optional[torch.Tensor] = None,
|
345 |
+
use_gradient_checkpointing: bool = False,
|
346 |
+
audio_emb: Optional[torch.Tensor] = None,
|
347 |
+
use_gradient_checkpointing_offload: bool = False,
|
348 |
+
tea_cache = None,
|
349 |
+
**kwargs,
|
350 |
+
):
|
351 |
+
t = self.time_embedding(
|
352 |
+
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
353 |
+
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
354 |
+
context = self.text_embedding(context)
|
355 |
+
lat_h, lat_w = x.shape[-2], x.shape[-1]
|
356 |
+
|
357 |
+
if audio_emb is not None and self.use_audio: # TODO cache
|
358 |
+
audio_emb = audio_emb.permute(0, 2, 1)[:, :, :, None, None]
|
359 |
+
audio_emb = torch.cat([audio_emb[:, :, :1].repeat(1, 1, 3, 1, 1), audio_emb], 2) # 1, 768, 44, 1, 1
|
360 |
+
audio_emb = self.audio_proj(audio_emb)
|
361 |
+
|
362 |
+
audio_emb = torch.concat([audio_cond_proj(audio_emb) for audio_cond_proj in self.audio_cond_projs], 0)
|
363 |
+
|
364 |
+
x = torch.cat([x, y], dim=1)
|
365 |
+
x = self.patch_embedding(x)
|
366 |
+
x, (f, h, w) = self.patchify(x)
|
367 |
+
|
368 |
+
freqs = torch.cat([
|
369 |
+
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
370 |
+
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
371 |
+
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
372 |
+
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
373 |
+
|
374 |
+
def create_custom_forward(module):
|
375 |
+
def custom_forward(*inputs):
|
376 |
+
return module(*inputs)
|
377 |
+
return custom_forward
|
378 |
+
|
379 |
+
if tea_cache is not None:
|
380 |
+
tea_cache_update = tea_cache.check(self, x, t_mod)
|
381 |
+
else:
|
382 |
+
tea_cache_update = False
|
383 |
+
ori_x_len = x.shape[1]
|
384 |
+
if tea_cache_update:
|
385 |
+
x = tea_cache.update(x)
|
386 |
+
else:
|
387 |
+
if args.sp_size > 1:
|
388 |
+
# Context Parallel
|
389 |
+
sp_size = get_sequence_parallel_world_size()
|
390 |
+
pad_size = 0
|
391 |
+
if ori_x_len % sp_size != 0:
|
392 |
+
pad_size = sp_size - ori_x_len % sp_size
|
393 |
+
x = torch.cat([x, torch.zeros_like(x[:, -1:]).repeat(1, pad_size, 1)], 1)
|
394 |
+
x = torch.chunk(x, sp_size, dim=1)[get_sequence_parallel_rank()]
|
395 |
+
|
396 |
+
audio_emb = audio_emb.reshape(x.shape[0], audio_emb.shape[0] // x.shape[0], -1, *audio_emb.shape[2:])
|
397 |
+
|
398 |
+
for layer_i, block in enumerate(self.blocks):
|
399 |
+
# audio cond
|
400 |
+
if self.use_audio:
|
401 |
+
au_idx = None
|
402 |
+
if (layer_i <= len(self.blocks) // 2 and layer_i > 1): # < len(self.blocks) - 1:
|
403 |
+
au_idx = layer_i - 2
|
404 |
+
audio_emb_tmp = audio_emb[:, au_idx].repeat(1, 1, lat_h // 2, lat_w // 2, 1) # 1, 11, 45, 25, 128
|
405 |
+
audio_cond_tmp = self.patchify(audio_emb_tmp.permute(0, 4, 1, 2, 3))[0]
|
406 |
+
if args.sp_size > 1:
|
407 |
+
if pad_size > 0:
|
408 |
+
audio_cond_tmp = torch.cat([audio_cond_tmp, torch.zeros_like(audio_cond_tmp[:, -1:]).repeat(1, pad_size, 1)], 1)
|
409 |
+
audio_cond_tmp = torch.chunk(audio_cond_tmp, sp_size, dim=1)[get_sequence_parallel_rank()]
|
410 |
+
x = audio_cond_tmp + x
|
411 |
+
|
412 |
+
if self.training and use_gradient_checkpointing:
|
413 |
+
if use_gradient_checkpointing_offload:
|
414 |
+
with torch.autograd.graph.save_on_cpu():
|
415 |
+
x = torch.utils.checkpoint.checkpoint(
|
416 |
+
create_custom_forward(block),
|
417 |
+
x, context, t_mod, freqs,
|
418 |
+
use_reentrant=False,
|
419 |
+
)
|
420 |
+
else:
|
421 |
+
x = torch.utils.checkpoint.checkpoint(
|
422 |
+
create_custom_forward(block),
|
423 |
+
x, context, t_mod, freqs,
|
424 |
+
use_reentrant=False,
|
425 |
+
)
|
426 |
+
else:
|
427 |
+
x = block(x, context, t_mod, freqs)
|
428 |
+
if tea_cache is not None:
|
429 |
+
x_cache = get_sp_group().all_gather(x, dim=1) # TODO: the size should be divided by sp_size
|
430 |
+
x_cache = x_cache[:, :ori_x_len]
|
431 |
+
tea_cache.store(x_cache)
|
432 |
+
|
433 |
+
x = self.head(x, t)
|
434 |
+
if args.sp_size > 1:
|
435 |
+
# Context Parallel
|
436 |
+
x = get_sp_group().all_gather(x, dim=1) # TODO: the size should be divided by sp_size
|
437 |
+
x = x[:, :ori_x_len]
|
438 |
+
|
439 |
+
x = self.unpatchify(x, (f, h, w))
|
440 |
+
return x
|
441 |
+
|
442 |
+
@staticmethod
|
443 |
+
def state_dict_converter():
|
444 |
+
return WanModelStateDictConverter()
|
445 |
+
|
446 |
+
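WanModel.patchify flattens the latent video into a token sequence and unpatchify folds the Head output back into latent space. A shape-only sketch of that round trip (the sizes below are made up for illustration):

# Hedged sketch (illustration only): the patchify/unpatchify shape round trip in WanModel.
import torch
import torch.nn as nn
from einops import rearrange

patch_size = (1, 2, 2)
in_dim, dim, out_dim = 16, 32, 16
x = torch.randn(1, in_dim, 4, 8, 8)                 # (b, c, f, h, w) latent video

patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
x = patch_embedding(x)                              # (1, 32, 4, 4, 4)
f, h, w = x.shape[2:]
tokens = rearrange(x, 'b c f h w -> b (f h w) c')   # (1, 64, 32) token sequence

# after the transformer, Head maps each token to out_dim * prod(patch_size) values
head_out = torch.randn(1, f * h * w, out_dim * patch_size[0] * patch_size[1] * patch_size[2])
latent = rearrange(
    head_out, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
    f=f, h=h, w=w, x=patch_size[0], y=patch_size[1], z=patch_size[2])
print(tokens.shape, latent.shape)                   # (1, 64, 32) (1, 16, 4, 8, 8)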
|
447 |
+
class WanModelStateDictConverter:
|
448 |
+
def __init__(self):
|
449 |
+
pass
|
450 |
+
|
451 |
+
def from_diffusers(self, state_dict):
|
452 |
+
rename_dict = {
|
453 |
+
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
|
454 |
+
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
|
455 |
+
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
|
456 |
+
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
|
457 |
+
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
|
458 |
+
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
|
459 |
+
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
|
460 |
+
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
|
461 |
+
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
|
462 |
+
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
|
463 |
+
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
|
464 |
+
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
|
465 |
+
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
|
466 |
+
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
|
467 |
+
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
|
468 |
+
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
|
469 |
+
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
|
470 |
+
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
471 |
+
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
472 |
+
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
473 |
+
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
474 |
+
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
475 |
+
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
476 |
+
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
|
477 |
+
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
|
478 |
+
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
|
479 |
+
"blocks.0.scale_shift_table": "blocks.0.modulation",
|
480 |
+
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
|
481 |
+
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
|
482 |
+
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
|
483 |
+
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
|
484 |
+
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
|
485 |
+
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
|
486 |
+
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
|
487 |
+
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
488 |
+
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
489 |
+
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
490 |
+
"patch_embedding.bias": "patch_embedding.bias",
|
491 |
+
"patch_embedding.weight": "patch_embedding.weight",
|
492 |
+
"scale_shift_table": "head.modulation",
|
493 |
+
"proj_out.bias": "head.head.bias",
|
494 |
+
"proj_out.weight": "head.head.weight",
|
495 |
+
}
|
496 |
+
state_dict_ = {}
|
497 |
+
for name, param in state_dict.items():
|
498 |
+
if name in rename_dict:
|
499 |
+
state_dict_[rename_dict[name]] = param
|
500 |
+
else:
|
501 |
+
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
|
502 |
+
if name_ in rename_dict:
|
503 |
+
name_ = rename_dict[name_]
|
504 |
+
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
|
505 |
+
state_dict_[name_] = param
|
506 |
+
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
|
507 |
+
config = {
|
508 |
+
"model_type": "t2v",
|
509 |
+
"patch_size": (1, 2, 2),
|
510 |
+
"text_len": 512,
|
511 |
+
"in_dim": 16,
|
512 |
+
"dim": 5120,
|
513 |
+
"ffn_dim": 13824,
|
514 |
+
"freq_dim": 256,
|
515 |
+
"text_dim": 4096,
|
516 |
+
"out_dim": 16,
|
517 |
+
"num_heads": 40,
|
518 |
+
"num_layers": 40,
|
519 |
+
"window_size": (-1, -1),
|
520 |
+
"qk_norm": True,
|
521 |
+
"cross_attn_norm": True,
|
522 |
+
"eps": 1e-6,
|
523 |
+
}
|
524 |
+
else:
|
525 |
+
config = {}
|
526 |
+
return state_dict_, config
|
527 |
+
|
528 |
+
def from_civitai(self, state_dict):
|
529 |
+
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
530 |
+
config = {
|
531 |
+
"has_image_input": False,
|
532 |
+
"patch_size": [1, 2, 2],
|
533 |
+
"in_dim": 16,
|
534 |
+
"dim": 1536,
|
535 |
+
"ffn_dim": 8960,
|
536 |
+
"freq_dim": 256,
|
537 |
+
"text_dim": 4096,
|
538 |
+
"out_dim": 16,
|
539 |
+
"num_heads": 12,
|
540 |
+
"num_layers": 30,
|
541 |
+
"eps": 1e-6
|
542 |
+
}
|
543 |
+
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
|
544 |
+
config = {
|
545 |
+
"has_image_input": False,
|
546 |
+
"patch_size": [1, 2, 2],
|
547 |
+
"in_dim": 16,
|
548 |
+
"dim": 5120,
|
549 |
+
"ffn_dim": 13824,
|
550 |
+
"freq_dim": 256,
|
551 |
+
"text_dim": 4096,
|
552 |
+
"out_dim": 16,
|
553 |
+
"num_heads": 40,
|
554 |
+
"num_layers": 40,
|
555 |
+
"eps": 1e-6
|
556 |
+
}
|
557 |
+
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
558 |
+
config = {
|
559 |
+
"has_image_input": True,
|
560 |
+
"patch_size": [1, 2, 2],
|
561 |
+
"in_dim": 36,
|
562 |
+
"dim": 5120,
|
563 |
+
"ffn_dim": 13824,
|
564 |
+
"freq_dim": 256,
|
565 |
+
"text_dim": 4096,
|
566 |
+
"out_dim": 16,
|
567 |
+
"num_heads": 40,
|
568 |
+
"num_layers": 40,
|
569 |
+
"eps": 1e-6
|
570 |
+
}
|
571 |
+
else:
|
572 |
+
config = {}
|
573 |
+
if hasattr(args, "model_config"):
|
574 |
+
model_config = args.model_config
|
575 |
+
if model_config is not None:
|
576 |
+
config.update(model_config)
|
577 |
+
return state_dict, config
|
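The from_diffusers converter above stores a rename table only for block 0 and rebuilds the key for every other layer by substituting the layer index. A small self-contained sketch of that lookup, with a two-entry stand-in for the full rename table:

# Hedged sketch (illustration only): generalising the block-0 rename table to any layer.
rename_dict = {  # two entries as a stand-in for the full table above
    "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
    "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
}

def convert_key(name: str) -> str:
    if name in rename_dict:
        return rename_dict[name]
    parts = name.split(".")
    template = ".".join(parts[:1] + ["0"] + parts[2:])          # look the key up via block 0
    if template in rename_dict:
        mapped = rename_dict[template].split(".")
        return ".".join(mapped[:1] + [parts[1]] + mapped[2:])   # restore the real layer index
    return name

print(convert_key("blocks.7.attn1.to_q.weight"))                # blocks.7.self_attn.q.weight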
OmniAvatar/models/wan_video_text_encoder.py
ADDED
@@ -0,0 +1,269 @@
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
def fp16_clamp(x):
|
9 |
+
if x.dtype == torch.float16 and torch.isinf(x).any():
|
10 |
+
clamp = torch.finfo(x.dtype).max - 1000
|
11 |
+
x = torch.clamp(x, min=-clamp, max=clamp)
|
12 |
+
return x
|
13 |
+
|
14 |
+
|
15 |
+
class GELU(nn.Module):
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
return 0.5 * x * (1.0 + torch.tanh(
|
19 |
+
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
20 |
+
|
21 |
+
|
22 |
+
class T5LayerNorm(nn.Module):
|
23 |
+
|
24 |
+
def __init__(self, dim, eps=1e-6):
|
25 |
+
super(T5LayerNorm, self).__init__()
|
26 |
+
self.dim = dim
|
27 |
+
self.eps = eps
|
28 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
|
32 |
+
self.eps)
|
33 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
34 |
+
x = x.type_as(self.weight)
|
35 |
+
return self.weight * x
|
36 |
+
|
37 |
+
|
38 |
+
class T5Attention(nn.Module):
|
39 |
+
|
40 |
+
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
|
41 |
+
assert dim_attn % num_heads == 0
|
42 |
+
super(T5Attention, self).__init__()
|
43 |
+
self.dim = dim
|
44 |
+
self.dim_attn = dim_attn
|
45 |
+
self.num_heads = num_heads
|
46 |
+
self.head_dim = dim_attn // num_heads
|
47 |
+
|
48 |
+
# layers
|
49 |
+
self.q = nn.Linear(dim, dim_attn, bias=False)
|
50 |
+
self.k = nn.Linear(dim, dim_attn, bias=False)
|
51 |
+
self.v = nn.Linear(dim, dim_attn, bias=False)
|
52 |
+
self.o = nn.Linear(dim_attn, dim, bias=False)
|
53 |
+
self.dropout = nn.Dropout(dropout)
|
54 |
+
|
55 |
+
def forward(self, x, context=None, mask=None, pos_bias=None):
|
56 |
+
"""
|
57 |
+
x: [B, L1, C].
|
58 |
+
context: [B, L2, C] or None.
|
59 |
+
mask: [B, L2] or [B, L1, L2] or None.
|
60 |
+
"""
|
61 |
+
# check inputs
|
62 |
+
context = x if context is None else context
|
63 |
+
b, n, c = x.size(0), self.num_heads, self.head_dim
|
64 |
+
|
65 |
+
# compute query, key, value
|
66 |
+
q = self.q(x).view(b, -1, n, c)
|
67 |
+
k = self.k(context).view(b, -1, n, c)
|
68 |
+
v = self.v(context).view(b, -1, n, c)
|
69 |
+
|
70 |
+
# attention bias
|
71 |
+
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
|
72 |
+
if pos_bias is not None:
|
73 |
+
attn_bias += pos_bias
|
74 |
+
if mask is not None:
|
75 |
+
assert mask.ndim in [2, 3]
|
76 |
+
mask = mask.view(b, 1, 1,
|
77 |
+
-1) if mask.ndim == 2 else mask.unsqueeze(1)
|
78 |
+
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
|
79 |
+
|
80 |
+
# compute attention (T5 does not use scaling)
|
81 |
+
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
|
82 |
+
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
83 |
+
x = torch.einsum('bnij,bjnc->binc', attn, v)
|
84 |
+
|
85 |
+
# output
|
86 |
+
x = x.reshape(b, -1, n * c)
|
87 |
+
x = self.o(x)
|
88 |
+
x = self.dropout(x)
|
89 |
+
return x
|
90 |
+
|
91 |
+
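T5Attention above follows the T5 convention: no 1/sqrt(d) scaling of the logits, and the relative-position bias is added directly to them before the softmax. A shape-only sketch of the two einsum steps, with made-up sizes:

# Hedged sketch (illustration only): the attention shapes inside T5Attention.forward.
import torch

b, n, c, lq, lk = 1, 2, 8, 5, 7                # batch, heads, head_dim, query/key lengths
q = torch.randn(b, lq, n, c)
k = torch.randn(b, lk, n, c)
v = torch.randn(b, lk, n, c)
pos_bias = torch.randn(b, n, lq, lk)           # relative-position bias per head

attn = torch.einsum('binc,bjnc->bnij', q, k) + pos_bias   # raw logits, no 1/sqrt(d)
attn = attn.softmax(dim=-1)
out = torch.einsum('bnij,bjnc->binc', attn, v)            # (b, lq, n, c)
print(out.reshape(b, lq, n * c).shape)                    # torch.Size([1, 5, 16])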
|
92 |
+
class T5FeedForward(nn.Module):
|
93 |
+
|
94 |
+
def __init__(self, dim, dim_ffn, dropout=0.1):
|
95 |
+
super(T5FeedForward, self).__init__()
|
96 |
+
self.dim = dim
|
97 |
+
self.dim_ffn = dim_ffn
|
98 |
+
|
99 |
+
# layers
|
100 |
+
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
|
101 |
+
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
102 |
+
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
103 |
+
self.dropout = nn.Dropout(dropout)
|
104 |
+
|
105 |
+
def forward(self, x):
|
106 |
+
x = self.fc1(x) * self.gate(x)
|
107 |
+
x = self.dropout(x)
|
108 |
+
x = self.fc2(x)
|
109 |
+
x = self.dropout(x)
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class T5SelfAttention(nn.Module):
|
114 |
+
|
115 |
+
def __init__(self,
|
116 |
+
dim,
|
117 |
+
dim_attn,
|
118 |
+
dim_ffn,
|
119 |
+
num_heads,
|
120 |
+
num_buckets,
|
121 |
+
shared_pos=True,
|
122 |
+
dropout=0.1):
|
123 |
+
super(T5SelfAttention, self).__init__()
|
124 |
+
self.dim = dim
|
125 |
+
self.dim_attn = dim_attn
|
126 |
+
self.dim_ffn = dim_ffn
|
127 |
+
self.num_heads = num_heads
|
128 |
+
self.num_buckets = num_buckets
|
129 |
+
self.shared_pos = shared_pos
|
130 |
+
|
131 |
+
# layers
|
132 |
+
self.norm1 = T5LayerNorm(dim)
|
133 |
+
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
134 |
+
self.norm2 = T5LayerNorm(dim)
|
135 |
+
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
136 |
+
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
137 |
+
num_buckets, num_heads, bidirectional=True)
|
138 |
+
|
139 |
+
def forward(self, x, mask=None, pos_bias=None):
|
140 |
+
e = pos_bias if self.shared_pos else self.pos_embedding(
|
141 |
+
x.size(1), x.size(1))
|
142 |
+
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
|
143 |
+
x = fp16_clamp(x + self.ffn(self.norm2(x)))
|
144 |
+
return x
|
145 |
+
|
146 |
+
|
147 |
+
class T5RelativeEmbedding(nn.Module):
|
148 |
+
|
149 |
+
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
|
150 |
+
super(T5RelativeEmbedding, self).__init__()
|
151 |
+
self.num_buckets = num_buckets
|
152 |
+
self.num_heads = num_heads
|
153 |
+
self.bidirectional = bidirectional
|
154 |
+
self.max_dist = max_dist
|
155 |
+
|
156 |
+
# layers
|
157 |
+
self.embedding = nn.Embedding(num_buckets, num_heads)
|
158 |
+
|
159 |
+
def forward(self, lq, lk):
|
160 |
+
device = self.embedding.weight.device
|
161 |
+
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
|
162 |
+
# torch.arange(lq).unsqueeze(1).to(device)
|
163 |
+
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
|
164 |
+
torch.arange(lq, device=device).unsqueeze(1)
|
165 |
+
rel_pos = self._relative_position_bucket(rel_pos)
|
166 |
+
rel_pos_embeds = self.embedding(rel_pos)
|
167 |
+
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
|
168 |
+
0) # [1, N, Lq, Lk]
|
169 |
+
return rel_pos_embeds.contiguous()
|
170 |
+
|
171 |
+
def _relative_position_bucket(self, rel_pos):
|
172 |
+
# preprocess
|
173 |
+
if self.bidirectional:
|
174 |
+
num_buckets = self.num_buckets // 2
|
175 |
+
rel_buckets = (rel_pos > 0).long() * num_buckets
|
176 |
+
rel_pos = torch.abs(rel_pos)
|
177 |
+
else:
|
178 |
+
num_buckets = self.num_buckets
|
179 |
+
rel_buckets = 0
|
180 |
+
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
|
181 |
+
|
182 |
+
# embeddings for small and large positions
|
183 |
+
max_exact = num_buckets // 2
|
184 |
+
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
|
185 |
+
math.log(self.max_dist / max_exact) *
|
186 |
+
(num_buckets - max_exact)).long()
|
187 |
+
rel_pos_large = torch.min(
|
188 |
+
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
|
189 |
+
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
|
190 |
+
return rel_buckets
|
191 |
+
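For reference, the offset grid that T5RelativeEmbedding buckets looks like this for a length-4 sequence; each offset is mapped to a bucket and then to one bias value per attention head.

# Hedged sketch (illustration only): the offset grid bucketed by T5RelativeEmbedding.
import torch

lq = lk = 4
rel_pos = torch.arange(lk).unsqueeze(0) - torch.arange(lq).unsqueeze(1)
print(rel_pos)
# tensor([[ 0,  1,  2,  3],
#         [-1,  0,  1,  2],
#         [-2, -1,  0,  1],
#         [-3, -2, -1,  0]])
# _relative_position_bucket maps these offsets to bucket indices (exact for small
# offsets, logarithmic for large ones); nn.Embedding then yields one bias per head,
# giving the [1, N, Lq, Lk] tensor returned by forward above.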
|
192 |
+
def init_weights(m):
|
193 |
+
if isinstance(m, T5LayerNorm):
|
194 |
+
nn.init.ones_(m.weight)
|
195 |
+
elif isinstance(m, T5FeedForward):
|
196 |
+
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
|
197 |
+
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
|
198 |
+
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
|
199 |
+
elif isinstance(m, T5Attention):
|
200 |
+
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
|
201 |
+
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
|
202 |
+
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
|
203 |
+
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
|
204 |
+
elif isinstance(m, T5RelativeEmbedding):
|
205 |
+
nn.init.normal_(
|
206 |
+
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
|
207 |
+
|
208 |
+
|
209 |
+
class WanTextEncoder(torch.nn.Module):
|
210 |
+
|
211 |
+
def __init__(self,
|
212 |
+
vocab=256384,
|
213 |
+
dim=4096,
|
214 |
+
dim_attn=4096,
|
215 |
+
dim_ffn=10240,
|
216 |
+
num_heads=64,
|
217 |
+
num_layers=24,
|
218 |
+
num_buckets=32,
|
219 |
+
shared_pos=False,
|
220 |
+
dropout=0.1):
|
221 |
+
super(WanTextEncoder, self).__init__()
|
222 |
+
self.dim = dim
|
223 |
+
self.dim_attn = dim_attn
|
224 |
+
self.dim_ffn = dim_ffn
|
225 |
+
self.num_heads = num_heads
|
226 |
+
self.num_layers = num_layers
|
227 |
+
self.num_buckets = num_buckets
|
228 |
+
self.shared_pos = shared_pos
|
229 |
+
|
230 |
+
# layers
|
231 |
+
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
|
232 |
+
else nn.Embedding(vocab, dim)
|
233 |
+
self.pos_embedding = T5RelativeEmbedding(
|
234 |
+
num_buckets, num_heads, bidirectional=True) if shared_pos else None
|
235 |
+
self.dropout = nn.Dropout(dropout)
|
236 |
+
self.blocks = nn.ModuleList([
|
237 |
+
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
238 |
+
shared_pos, dropout) for _ in range(num_layers)
|
239 |
+
])
|
240 |
+
self.norm = T5LayerNorm(dim)
|
241 |
+
|
242 |
+
# initialize weights
|
243 |
+
self.apply(init_weights)
|
244 |
+
|
245 |
+
def forward(self, ids, mask=None):
|
246 |
+
x = self.token_embedding(ids)
|
247 |
+
x = self.dropout(x)
|
248 |
+
e = self.pos_embedding(x.size(1),
|
249 |
+
x.size(1)) if self.shared_pos else None
|
250 |
+
for block in self.blocks:
|
251 |
+
x = block(x, mask, pos_bias=e)
|
252 |
+
x = self.norm(x)
|
253 |
+
x = self.dropout(x)
|
254 |
+
return x
|
255 |
+
|
256 |
+
@staticmethod
|
257 |
+
def state_dict_converter():
|
258 |
+
return WanTextEncoderStateDictConverter()
|
259 |
+
|
260 |
+
|
261 |
+
class WanTextEncoderStateDictConverter:
|
262 |
+
def __init__(self):
|
263 |
+
pass
|
264 |
+
|
265 |
+
def from_diffusers(self, state_dict):
|
266 |
+
return state_dict
|
267 |
+
|
268 |
+
def from_civitai(self, state_dict):
|
269 |
+
return state_dict
|
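A hedged usage sketch of the encoder interface (illustration only, with random weights and a tiny configuration; real token ids would come from the matching tokenizer and real weights from the released checkpoint):

# Hedged usage sketch (not part of the upload): interface of WanTextEncoder with a
# tiny made-up configuration and random ids, just to show the expected shapes.
import torch

encoder = WanTextEncoder(vocab=1000, dim=64, dim_attn=64, dim_ffn=128,
                         num_heads=4, num_layers=2, num_buckets=8)
ids = torch.randint(0, 1000, (1, 16))          # (batch, sequence length)
mask = torch.ones(1, 16, dtype=torch.long)     # 1 = real token, 0 = padding
context = encoder(ids, mask)
print(context.shape)                           # torch.Size([1, 16, 64])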
OmniAvatar/models/wan_video_vae.py
ADDED
@@ -0,0 +1,807 @@
1 |
+
from einops import rearrange, repeat
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
CACHE_T = 2
|
9 |
+
|
10 |
+
|
11 |
+
def check_is_instance(model, module_class):
|
12 |
+
if isinstance(model, module_class):
|
13 |
+
return True
|
14 |
+
if hasattr(model, "module") and isinstance(model.module, module_class):
|
15 |
+
return True
|
16 |
+
return False
|
17 |
+
|
18 |
+
|
19 |
+
def block_causal_mask(x, block_size):
|
20 |
+
# params
|
21 |
+
b, n, s, _, device = *x.size(), x.device
|
22 |
+
assert s % block_size == 0
|
23 |
+
num_blocks = s // block_size
|
24 |
+
|
25 |
+
# build mask
|
26 |
+
mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
|
27 |
+
for i in range(num_blocks):
|
28 |
+
mask[:, :,
|
29 |
+
i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
|
30 |
+
return mask
|
31 |
+
|
32 |
+
|
33 |
+
class CausalConv3d(nn.Conv3d):
|
34 |
+
"""
|
35 |
+
Causal 3D convolution.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(self, *args, **kwargs):
|
39 |
+
super().__init__(*args, **kwargs)
|
40 |
+
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
41 |
+
self.padding[1], 2 * self.padding[0], 0)
|
42 |
+
self.padding = (0, 0, 0)
|
43 |
+
|
44 |
+
def forward(self, x, cache_x=None):
|
45 |
+
padding = list(self._padding)
|
46 |
+
if cache_x is not None and self._padding[4] > 0:
|
47 |
+
cache_x = cache_x.to(x.device)
|
48 |
+
x = torch.cat([cache_x, x], dim=2)
|
49 |
+
padding[4] -= cache_x.shape[2]
|
50 |
+
x = F.pad(x, padding)
|
51 |
+
|
52 |
+
return super().forward(x)
|
53 |
+
|
54 |
+
|
55 |
+
class RMS_norm(nn.Module):
|
56 |
+
|
57 |
+
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
58 |
+
super().__init__()
|
59 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
60 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
61 |
+
|
62 |
+
self.channel_first = channel_first
|
63 |
+
self.scale = dim**0.5
|
64 |
+
self.gamma = nn.Parameter(torch.ones(shape))
|
65 |
+
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
return F.normalize(
|
69 |
+
x, dim=(1 if self.channel_first else
|
70 |
+
-1)) * self.scale * self.gamma + self.bias
|
71 |
+
|
72 |
+
|
73 |
+
class Upsample(nn.Upsample):
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
"""
|
77 |
+
Fix bfloat16 support for nearest neighbor interpolation.
|
78 |
+
"""
|
79 |
+
return super().forward(x.float()).type_as(x)
|
80 |
+
|
81 |
+
|
82 |
+
class Resample(nn.Module):
|
83 |
+
|
84 |
+
def __init__(self, dim, mode):
|
85 |
+
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
86 |
+
'downsample3d')
|
87 |
+
super().__init__()
|
88 |
+
self.dim = dim
|
89 |
+
self.mode = mode
|
90 |
+
|
91 |
+
# layers
|
92 |
+
if mode == 'upsample2d':
|
93 |
+
self.resample = nn.Sequential(
|
94 |
+
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
95 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
96 |
+
elif mode == 'upsample3d':
|
97 |
+
self.resample = nn.Sequential(
|
98 |
+
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
99 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
100 |
+
self.time_conv = CausalConv3d(dim,
|
101 |
+
dim * 2, (3, 1, 1),
|
102 |
+
padding=(1, 0, 0))
|
103 |
+
|
104 |
+
elif mode == 'downsample2d':
|
105 |
+
self.resample = nn.Sequential(
|
106 |
+
nn.ZeroPad2d((0, 1, 0, 1)),
|
107 |
+
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
108 |
+
elif mode == 'downsample3d':
|
109 |
+
self.resample = nn.Sequential(
|
110 |
+
nn.ZeroPad2d((0, 1, 0, 1)),
|
111 |
+
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
112 |
+
self.time_conv = CausalConv3d(dim,
|
113 |
+
dim, (3, 1, 1),
|
114 |
+
stride=(2, 1, 1),
|
115 |
+
padding=(0, 0, 0))
|
116 |
+
|
117 |
+
else:
|
118 |
+
self.resample = nn.Identity()
|
119 |
+
|
120 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
121 |
+
b, c, t, h, w = x.size()
|
122 |
+
if self.mode == 'upsample3d':
|
123 |
+
if feat_cache is not None:
|
124 |
+
idx = feat_idx[0]
|
125 |
+
if feat_cache[idx] is None:
|
126 |
+
feat_cache[idx] = 'Rep'
|
127 |
+
feat_idx[0] += 1
|
128 |
+
else:
|
129 |
+
|
130 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
131 |
+
if cache_x.shape[2] < 2 and feat_cache[
|
132 |
+
idx] is not None and feat_cache[idx] != 'Rep':
|
133 |
+
# cache the last frame of each of the last two chunks
|
134 |
+
cache_x = torch.cat([
|
135 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
136 |
+
cache_x.device), cache_x
|
137 |
+
],
|
138 |
+
dim=2)
|
139 |
+
if cache_x.shape[2] < 2 and feat_cache[
|
140 |
+
idx] is not None and feat_cache[idx] == 'Rep':
|
141 |
+
cache_x = torch.cat([
|
142 |
+
torch.zeros_like(cache_x).to(cache_x.device),
|
143 |
+
cache_x
|
144 |
+
],
|
145 |
+
dim=2)
|
146 |
+
if feat_cache[idx] == 'Rep':
|
147 |
+
x = self.time_conv(x)
|
148 |
+
else:
|
149 |
+
x = self.time_conv(x, feat_cache[idx])
|
150 |
+
feat_cache[idx] = cache_x
|
151 |
+
feat_idx[0] += 1
|
152 |
+
|
153 |
+
x = x.reshape(b, 2, c, t, h, w)
|
154 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
155 |
+
3)
|
156 |
+
x = x.reshape(b, c, t * 2, h, w)
|
157 |
+
t = x.shape[2]
|
158 |
+
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
159 |
+
x = self.resample(x)
|
160 |
+
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
161 |
+
|
162 |
+
if self.mode == 'downsample3d':
|
163 |
+
if feat_cache is not None:
|
164 |
+
idx = feat_idx[0]
|
165 |
+
if feat_cache[idx] is None:
|
166 |
+
feat_cache[idx] = x.clone()
|
167 |
+
feat_idx[0] += 1
|
168 |
+
else:
|
169 |
+
cache_x = x[:, :, -1:, :, :].clone()
|
170 |
+
x = self.time_conv(
|
171 |
+
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
172 |
+
feat_cache[idx] = cache_x
|
173 |
+
feat_idx[0] += 1
|
174 |
+
return x
|
175 |
+
|
176 |
+
def init_weight(self, conv):
|
177 |
+
conv_weight = conv.weight
|
178 |
+
nn.init.zeros_(conv_weight)
|
179 |
+
c1, c2, t, h, w = conv_weight.size()
|
180 |
+
one_matrix = torch.eye(c1, c2)
|
181 |
+
init_matrix = one_matrix
|
182 |
+
nn.init.zeros_(conv_weight)
|
183 |
+
conv_weight.data[:, :, 1, 0, 0] = init_matrix
|
184 |
+
conv.weight.data.copy_(conv_weight)
|
185 |
+
nn.init.zeros_(conv.bias.data)
|
186 |
+
|
187 |
+
def init_weight2(self, conv):
|
188 |
+
conv_weight = conv.weight.data
|
189 |
+
nn.init.zeros_(conv_weight)
|
190 |
+
c1, c2, t, h, w = conv_weight.size()
|
191 |
+
init_matrix = torch.eye(c1 // 2, c2)
|
192 |
+
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
193 |
+
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
194 |
+
conv.weight.data.copy_(conv_weight)
|
195 |
+
nn.init.zeros_(conv.bias.data)
|
196 |
+
|
197 |
+
|
198 |
+
class ResidualBlock(nn.Module):
|
199 |
+
|
200 |
+
def __init__(self, in_dim, out_dim, dropout=0.0):
|
201 |
+
super().__init__()
|
202 |
+
self.in_dim = in_dim
|
203 |
+
self.out_dim = out_dim
|
204 |
+
|
205 |
+
# layers
|
206 |
+
self.residual = nn.Sequential(
|
207 |
+
RMS_norm(in_dim, images=False), nn.SiLU(),
|
208 |
+
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
209 |
+
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
210 |
+
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
211 |
+
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
212 |
+
if in_dim != out_dim else nn.Identity()
|
213 |
+
|
214 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
215 |
+
h = self.shortcut(x)
|
216 |
+
for layer in self.residual:
|
217 |
+
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
218 |
+
idx = feat_idx[0]
|
219 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
220 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
221 |
+
# cache last frame of last two chunk
|
222 |
+
cache_x = torch.cat([
|
223 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
224 |
+
cache_x.device), cache_x
|
225 |
+
],
|
226 |
+
dim=2)
|
227 |
+
x = layer(x, feat_cache[idx])
|
228 |
+
feat_cache[idx] = cache_x
|
229 |
+
feat_idx[0] += 1
|
230 |
+
else:
|
231 |
+
x = layer(x)
|
232 |
+
return x + h
|
233 |
+
|
234 |
+
|
235 |
+
class AttentionBlock(nn.Module):
|
236 |
+
"""
|
237 |
+
Causal self-attention with a single head.
|
238 |
+
"""
|
239 |
+
|
240 |
+
def __init__(self, dim):
|
241 |
+
super().__init__()
|
242 |
+
self.dim = dim
|
243 |
+
|
244 |
+
# layers
|
245 |
+
self.norm = RMS_norm(dim)
|
246 |
+
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
247 |
+
self.proj = nn.Conv2d(dim, dim, 1)
|
248 |
+
|
249 |
+
# zero out the last layer params
|
250 |
+
nn.init.zeros_(self.proj.weight)
|
251 |
+
|
252 |
+
def forward(self, x):
|
253 |
+
identity = x
|
254 |
+
b, c, t, h, w = x.size()
|
255 |
+
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
256 |
+
x = self.norm(x)
|
257 |
+
# compute query, key, value
|
258 |
+
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
|
259 |
+
0, 1, 3, 2).contiguous().chunk(3, dim=-1)
|
260 |
+
|
261 |
+
# apply attention
|
262 |
+
x = F.scaled_dot_product_attention(
|
263 |
+
q,
|
264 |
+
k,
|
265 |
+
v,
|
266 |
+
#attn_mask=block_causal_mask(q, block_size=h * w)
|
267 |
+
)
|
268 |
+
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
269 |
+
|
270 |
+
# output
|
271 |
+
x = self.proj(x)
|
272 |
+
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
|
273 |
+
return x + identity
|
274 |
+
|
275 |
+
|
276 |
+
class Encoder3d(nn.Module):
|
277 |
+
|
278 |
+
def __init__(self,
|
279 |
+
dim=128,
|
280 |
+
z_dim=4,
|
281 |
+
dim_mult=[1, 2, 4, 4],
|
282 |
+
num_res_blocks=2,
|
283 |
+
attn_scales=[],
|
284 |
+
temperal_downsample=[True, True, False],
|
285 |
+
dropout=0.0):
|
286 |
+
super().__init__()
|
287 |
+
self.dim = dim
|
288 |
+
self.z_dim = z_dim
|
289 |
+
self.dim_mult = dim_mult
|
290 |
+
self.num_res_blocks = num_res_blocks
|
291 |
+
self.attn_scales = attn_scales
|
292 |
+
self.temperal_downsample = temperal_downsample
|
293 |
+
|
294 |
+
# dimensions
|
295 |
+
dims = [dim * u for u in [1] + dim_mult]
|
296 |
+
scale = 1.0
|
297 |
+
|
298 |
+
# init block
|
299 |
+
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
300 |
+
|
301 |
+
# downsample blocks
|
302 |
+
downsamples = []
|
303 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
304 |
+
# residual (+attention) blocks
|
305 |
+
for _ in range(num_res_blocks):
|
306 |
+
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
307 |
+
if scale in attn_scales:
|
308 |
+
downsamples.append(AttentionBlock(out_dim))
|
309 |
+
in_dim = out_dim
|
310 |
+
|
311 |
+
# downsample block
|
312 |
+
if i != len(dim_mult) - 1:
|
313 |
+
mode = 'downsample3d' if temperal_downsample[
|
314 |
+
i] else 'downsample2d'
|
315 |
+
downsamples.append(Resample(out_dim, mode=mode))
|
316 |
+
scale /= 2.0
|
317 |
+
self.downsamples = nn.Sequential(*downsamples)
|
318 |
+
|
319 |
+
# middle blocks
|
320 |
+
self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
|
321 |
+
AttentionBlock(out_dim),
|
322 |
+
ResidualBlock(out_dim, out_dim, dropout))
|
323 |
+
|
324 |
+
# output blocks
|
325 |
+
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
326 |
+
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
327 |
+
|
328 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
329 |
+
if feat_cache is not None:
|
330 |
+
idx = feat_idx[0]
|
331 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
332 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
333 |
+
# cache the last frame of each of the last two chunks
|
334 |
+
cache_x = torch.cat([
|
335 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
336 |
+
cache_x.device), cache_x
|
337 |
+
],
|
338 |
+
dim=2)
|
339 |
+
x = self.conv1(x, feat_cache[idx])
|
340 |
+
feat_cache[idx] = cache_x
|
341 |
+
feat_idx[0] += 1
|
342 |
+
else:
|
343 |
+
x = self.conv1(x)
|
344 |
+
|
345 |
+
## downsamples
|
346 |
+
for layer in self.downsamples:
|
347 |
+
if feat_cache is not None:
|
348 |
+
x = layer(x, feat_cache, feat_idx)
|
349 |
+
else:
|
350 |
+
x = layer(x)
|
351 |
+
|
352 |
+
## middle
|
353 |
+
for layer in self.middle:
|
354 |
+
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
355 |
+
x = layer(x, feat_cache, feat_idx)
|
356 |
+
else:
|
357 |
+
x = layer(x)
|
358 |
+
|
359 |
+
## head
|
360 |
+
for layer in self.head:
|
361 |
+
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
362 |
+
idx = feat_idx[0]
|
363 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
364 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
365 |
+
# cache the last frame of each of the last two chunks
|
366 |
+
cache_x = torch.cat([
|
367 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
368 |
+
cache_x.device), cache_x
|
369 |
+
],
|
370 |
+
dim=2)
|
371 |
+
x = layer(x, feat_cache[idx])
|
372 |
+
feat_cache[idx] = cache_x
|
373 |
+
feat_idx[0] += 1
|
374 |
+
else:
|
375 |
+
x = layer(x)
|
376 |
+
return x
|
377 |
+
|
378 |
+
|
379 |
+
class Decoder3d(nn.Module):
|
380 |
+
|
381 |
+
def __init__(self,
|
382 |
+
dim=128,
|
383 |
+
z_dim=4,
|
384 |
+
dim_mult=[1, 2, 4, 4],
|
385 |
+
num_res_blocks=2,
|
386 |
+
attn_scales=[],
|
387 |
+
temperal_upsample=[False, True, True],
|
388 |
+
dropout=0.0):
|
389 |
+
super().__init__()
|
390 |
+
self.dim = dim
|
391 |
+
self.z_dim = z_dim
|
392 |
+
self.dim_mult = dim_mult
|
393 |
+
self.num_res_blocks = num_res_blocks
|
394 |
+
self.attn_scales = attn_scales
|
395 |
+
self.temperal_upsample = temperal_upsample
|
396 |
+
|
397 |
+
# dimensions
|
398 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
399 |
+
scale = 1.0 / 2**(len(dim_mult) - 2)
|
400 |
+
|
401 |
+
# init block
|
402 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
403 |
+
|
404 |
+
# middle blocks
|
405 |
+
self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
|
406 |
+
AttentionBlock(dims[0]),
|
407 |
+
ResidualBlock(dims[0], dims[0], dropout))
|
408 |
+
|
409 |
+
# upsample blocks
|
410 |
+
upsamples = []
|
411 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
412 |
+
# residual (+attention) blocks
|
413 |
+
if i == 1 or i == 2 or i == 3:
|
414 |
+
in_dim = in_dim // 2
|
415 |
+
for _ in range(num_res_blocks + 1):
|
416 |
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
417 |
+
if scale in attn_scales:
|
418 |
+
upsamples.append(AttentionBlock(out_dim))
|
419 |
+
in_dim = out_dim
|
420 |
+
|
421 |
+
# upsample block
|
422 |
+
if i != len(dim_mult) - 1:
|
423 |
+
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
424 |
+
upsamples.append(Resample(out_dim, mode=mode))
|
425 |
+
scale *= 2.0
|
426 |
+
self.upsamples = nn.Sequential(*upsamples)
|
427 |
+
|
428 |
+
# output blocks
|
429 |
+
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
430 |
+
CausalConv3d(out_dim, 3, 3, padding=1))
|
431 |
+
|
432 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
433 |
+
## conv1
|
434 |
+
if feat_cache is not None:
|
435 |
+
idx = feat_idx[0]
|
436 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
437 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
438 |
+
# cache the last frame of each of the last two chunks
|
439 |
+
cache_x = torch.cat([
|
440 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
441 |
+
cache_x.device), cache_x
|
442 |
+
],
|
443 |
+
dim=2)
|
444 |
+
x = self.conv1(x, feat_cache[idx])
|
445 |
+
feat_cache[idx] = cache_x
|
446 |
+
feat_idx[0] += 1
|
447 |
+
else:
|
448 |
+
x = self.conv1(x)
|
449 |
+
|
450 |
+
## middle
|
451 |
+
for layer in self.middle:
|
452 |
+
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
453 |
+
x = layer(x, feat_cache, feat_idx)
|
454 |
+
else:
|
455 |
+
x = layer(x)
|
456 |
+
|
457 |
+
## upsamples
|
458 |
+
for layer in self.upsamples:
|
459 |
+
if feat_cache is not None:
|
460 |
+
x = layer(x, feat_cache, feat_idx)
|
461 |
+
else:
|
462 |
+
x = layer(x)
|
463 |
+
|
464 |
+
## head
|
465 |
+
for layer in self.head:
|
466 |
+
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
467 |
+
idx = feat_idx[0]
|
468 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
469 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
470 |
+
# cache the last frame of each of the last two chunks
|
471 |
+
cache_x = torch.cat([
|
472 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
473 |
+
cache_x.device), cache_x
|
474 |
+
],
|
475 |
+
dim=2)
|
476 |
+
x = layer(x, feat_cache[idx])
|
477 |
+
feat_cache[idx] = cache_x
|
478 |
+
feat_idx[0] += 1
|
479 |
+
else:
|
480 |
+
x = layer(x)
|
481 |
+
return x
|
482 |
+
|
483 |
+
|
484 |
+
def count_conv3d(model):
|
485 |
+
count = 0
|
486 |
+
for m in model.modules():
|
487 |
+
if check_is_instance(m, CausalConv3d):
|
488 |
+
count += 1
|
489 |
+
return count
|
490 |
+
|
491 |
+
|
492 |
+
class VideoVAE_(nn.Module):
|
493 |
+
|
494 |
+
def __init__(self,
|
495 |
+
dim=96,
|
496 |
+
z_dim=16,
|
497 |
+
dim_mult=[1, 2, 4, 4],
|
498 |
+
num_res_blocks=2,
|
499 |
+
attn_scales=[],
|
500 |
+
temperal_downsample=[False, True, True],
|
501 |
+
dropout=0.0):
|
502 |
+
super().__init__()
|
503 |
+
self.dim = dim
|
504 |
+
self.z_dim = z_dim
|
505 |
+
self.dim_mult = dim_mult
|
506 |
+
self.num_res_blocks = num_res_blocks
|
507 |
+
self.attn_scales = attn_scales
|
508 |
+
self.temperal_downsample = temperal_downsample
|
509 |
+
self.temperal_upsample = temperal_downsample[::-1]
|
510 |
+
|
511 |
+
# modules
|
512 |
+
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
513 |
+
attn_scales, self.temperal_downsample, dropout)
|
514 |
+
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
515 |
+
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
516 |
+
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
517 |
+
attn_scales, self.temperal_upsample, dropout)
|
518 |
+
|
519 |
+
def forward(self, x):
|
520 |
+
mu, log_var = self.encode(x)
|
521 |
+
z = self.reparameterize(mu, log_var)
|
522 |
+
x_recon = self.decode(z)
|
523 |
+
return x_recon, mu, log_var
|
524 |
+
|
525 |
+
def encode(self, x, scale):
|
526 |
+
self.clear_cache()
|
527 |
+
## cache
|
528 |
+
t = x.shape[2]
|
529 |
+
iter_ = 1 + (t - 1) // 4
|
530 |
+
|
531 |
+
for i in range(iter_):
|
532 |
+
self._enc_conv_idx = [0]
|
533 |
+
if i == 0:
|
534 |
+
out = self.encoder(x[:, :, :1, :, :],
|
535 |
+
feat_cache=self._enc_feat_map,
|
536 |
+
feat_idx=self._enc_conv_idx)
|
537 |
+
else:
|
538 |
+
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
539 |
+
feat_cache=self._enc_feat_map,
|
540 |
+
feat_idx=self._enc_conv_idx)
|
541 |
+
out = torch.cat([out, out_], 2)
|
542 |
+
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
543 |
+
if isinstance(scale[0], torch.Tensor):
|
544 |
+
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
|
545 |
+
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
546 |
+
1, self.z_dim, 1, 1, 1)
|
547 |
+
else:
|
548 |
+
scale = scale.to(dtype=mu.dtype, device=mu.device)
|
549 |
+
mu = (mu - scale[0]) * scale[1]
|
550 |
+
return mu
|
551 |
+
|
552 |
+
def decode(self, z, scale):
|
553 |
+
self.clear_cache()
|
554 |
+
# z: [b,c,t,h,w]
|
555 |
+
if isinstance(scale[0], torch.Tensor):
|
556 |
+
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
|
557 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
558 |
+
1, self.z_dim, 1, 1, 1)
|
559 |
+
else:
|
560 |
+
scale = scale.to(dtype=z.dtype, device=z.device)
|
561 |
+
z = z / scale[1] + scale[0]
|
562 |
+
iter_ = z.shape[2]
|
563 |
+
x = self.conv2(z)
|
564 |
+
for i in range(iter_):
|
565 |
+
self._conv_idx = [0]
|
566 |
+
if i == 0:
|
567 |
+
out = self.decoder(x[:, :, i:i + 1, :, :],
|
568 |
+
feat_cache=self._feat_map,
|
569 |
+
feat_idx=self._conv_idx)
|
570 |
+
else:
|
571 |
+
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
572 |
+
feat_cache=self._feat_map,
|
573 |
+
feat_idx=self._conv_idx)
|
574 |
+
out = torch.cat([out, out_], 2) # may add tensor offload
|
575 |
+
return out
|
576 |
+
|
577 |
+
def reparameterize(self, mu, log_var):
|
578 |
+
std = torch.exp(0.5 * log_var)
|
579 |
+
eps = torch.randn_like(std)
|
580 |
+
return eps * std + mu
|
581 |
+
|
582 |
+
def sample(self, imgs, deterministic=False):
|
583 |
+
mu, log_var = self.encode(imgs)
|
584 |
+
if deterministic:
|
585 |
+
return mu
|
586 |
+
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
587 |
+
return mu + std * torch.randn_like(std)
|
588 |
+
|
589 |
+
def clear_cache(self):
|
590 |
+
self._conv_num = count_conv3d(self.decoder)
|
591 |
+
self._conv_idx = [0]
|
592 |
+
self._feat_map = [None] * self._conv_num
|
593 |
+
# cache encode
|
594 |
+
self._enc_conv_num = count_conv3d(self.encoder)
|
595 |
+
self._enc_conv_idx = [0]
|
596 |
+
self._enc_feat_map = [None] * self._enc_conv_num
|
597 |
+
|
598 |
+
|
599 |
+
class WanVideoVAE(nn.Module):
|
600 |
+
|
601 |
+
def __init__(self, z_dim=16):
|
602 |
+
super().__init__()
|
603 |
+
|
604 |
+
mean = [
|
605 |
+
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
606 |
+
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
607 |
+
]
|
608 |
+
std = [
|
609 |
+
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
610 |
+
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
611 |
+
]
|
612 |
+
self.mean = torch.tensor(mean)
|
613 |
+
self.std = torch.tensor(std)
|
614 |
+
self.scale = [self.mean, 1.0 / self.std]
|
615 |
+
|
616 |
+
# init model
|
617 |
+
self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
|
618 |
+
self.upsampling_factor = 8
|
619 |
+
|
620 |
+
|
621 |
+
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
622 |
+
x = torch.ones((length,))
|
623 |
+
if not left_bound:
|
624 |
+
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
625 |
+
if not right_bound:
|
626 |
+
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
627 |
+
return x
|
628 |
+
|
629 |
+
|
630 |
+
def build_mask(self, data, is_bound, border_width):
|
631 |
+
_, _, _, H, W = data.shape
|
632 |
+
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
|
633 |
+
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
|
634 |
+
|
635 |
+
h = repeat(h, "H -> H W", H=H, W=W)
|
636 |
+
w = repeat(w, "W -> H W", H=H, W=W)
|
637 |
+
|
638 |
+
mask = torch.stack([h, w]).min(dim=0).values
|
639 |
+
mask = rearrange(mask, "H W -> 1 1 1 H W")
|
640 |
+
return mask
|
641 |
+
|
642 |
+
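build_1d_mask produces the feathering ramp that lets overlapping tiles blend smoothly: interior tile borders ramp up linearly, while borders that touch the frame edge keep full weight. A concrete example of the ramp it returns:

# Hedged sketch (illustration only): the feathering ramp computed by build_1d_mask.
import torch

def ramp(length, left_bound, right_bound, border_width):
    x = torch.ones(length)
    if not left_bound:       # tile border lies inside the frame: fade in
        x[:border_width] = (torch.arange(border_width) + 1) / border_width
    if not right_bound:      # same on the right side: fade out
        x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
    return x

print(ramp(8, left_bound=False, right_bound=True, border_width=4))
# tensor([0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])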
|
643 |
+
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
|
644 |
+
_, _, T, H, W = hidden_states.shape
|
645 |
+
size_h, size_w = tile_size
|
646 |
+
stride_h, stride_w = tile_stride
|
647 |
+
|
648 |
+
# Split tasks
|
649 |
+
tasks = []
|
650 |
+
for h in range(0, H, stride_h):
|
651 |
+
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
652 |
+
for w in range(0, W, stride_w):
|
653 |
+
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
654 |
+
h_, w_ = h + size_h, w + size_w
|
655 |
+
tasks.append((h, h_, w, w_))
|
656 |
+
|
657 |
+
data_device = "cpu"
|
658 |
+
computation_device = device
|
659 |
+
|
660 |
+
out_T = T * 4 - 3
|
661 |
+
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
662 |
+
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
663 |
+
|
664 |
+
for h, h_, w, w_ in tasks:
|
665 |
+
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
|
666 |
+
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
|
667 |
+
|
668 |
+
mask = self.build_mask(
|
669 |
+
hidden_states_batch,
|
670 |
+
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
671 |
+
border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
|
672 |
+
).to(dtype=hidden_states.dtype, device=data_device)
|
673 |
+
|
674 |
+
target_h = h * self.upsampling_factor
|
675 |
+
target_w = w * self.upsampling_factor
|
676 |
+
values[
|
677 |
+
:,
|
678 |
+
:,
|
679 |
+
:,
|
680 |
+
target_h:target_h + hidden_states_batch.shape[3],
|
681 |
+
target_w:target_w + hidden_states_batch.shape[4],
|
682 |
+
] += hidden_states_batch * mask
|
683 |
+
weight[
|
684 |
+
:,
|
685 |
+
:,
|
686 |
+
:,
|
687 |
+
target_h: target_h + hidden_states_batch.shape[3],
|
688 |
+
target_w: target_w + hidden_states_batch.shape[4],
|
689 |
+
] += mask
|
690 |
+
values = values / weight
|
691 |
+
values = values.clamp_(-1, 1)
|
692 |
+
return values
|
693 |
+
|
694 |
+
|
695 |
+
def tiled_encode(self, video, device, tile_size, tile_stride):
|
696 |
+
_, _, T, H, W = video.shape
|
697 |
+
size_h, size_w = tile_size
|
698 |
+
stride_h, stride_w = tile_stride
|
699 |
+
|
700 |
+
# Split tasks
|
701 |
+
tasks = []
|
702 |
+
for h in range(0, H, stride_h):
|
703 |
+
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
704 |
+
for w in range(0, W, stride_w):
|
705 |
+
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
706 |
+
h_, w_ = h + size_h, w + size_w
|
707 |
+
tasks.append((h, h_, w, w_))
|
708 |
+
|
709 |
+
data_device = "cpu"
|
710 |
+
computation_device = device
|
711 |
+
|
712 |
+
out_T = (T + 3) // 4
|
713 |
+
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
714 |
+
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
715 |
+
|
716 |
+
for h, h_, w, w_ in tasks:
|
717 |
+
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
|
718 |
+
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
|
719 |
+
|
720 |
+
mask = self.build_mask(
|
721 |
+
hidden_states_batch,
|
722 |
+
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
723 |
+
border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
|
724 |
+
).to(dtype=video.dtype, device=data_device)
|
725 |
+
|
726 |
+
target_h = h // self.upsampling_factor
|
727 |
+
target_w = w // self.upsampling_factor
|
728 |
+
values[
|
729 |
+
:,
|
730 |
+
:,
|
731 |
+
:,
|
732 |
+
target_h:target_h + hidden_states_batch.shape[3],
|
733 |
+
target_w:target_w + hidden_states_batch.shape[4],
|
734 |
+
] += hidden_states_batch * mask
|
735 |
+
weight[
|
736 |
+
:,
|
737 |
+
:,
|
738 |
+
:,
|
739 |
+
target_h: target_h + hidden_states_batch.shape[3],
|
740 |
+
target_w: target_w + hidden_states_batch.shape[4],
|
741 |
+
] += mask
|
742 |
+
values = values / weight
|
743 |
+
return values
|
744 |
+
|
745 |
+
|
746 |
+
def single_encode(self, video, device):
|
747 |
+
video = video.to(device)
|
748 |
+
x = self.model.encode(video, self.scale)
|
749 |
+
return x
|
750 |
+
|
751 |
+
|
752 |
+
def single_decode(self, hidden_state, device):
|
753 |
+
hidden_state = hidden_state.to(device)
|
754 |
+
video = self.model.decode(hidden_state, self.scale)
|
755 |
+
return video.clamp_(-1, 1)
|
756 |
+
|
757 |
+
|
758 |
+
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
759 |
+
|
760 |
+
videos = [video.to("cpu") for video in videos]
|
761 |
+
hidden_states = []
|
762 |
+
for video in videos:
|
763 |
+
video = video.unsqueeze(0)
|
764 |
+
if tiled:
|
765 |
+
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
|
766 |
+
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
|
767 |
+
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
|
768 |
+
else:
|
769 |
+
hidden_state = self.single_encode(video, device)
|
770 |
+
hidden_state = hidden_state.squeeze(0)
|
771 |
+
hidden_states.append(hidden_state)
|
772 |
+
hidden_states = torch.stack(hidden_states)
|
773 |
+
return hidden_states
|
774 |
+
|
775 |
+
|
776 |
+
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
777 |
+
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
778 |
+
videos = []
|
779 |
+
for hidden_state in hidden_states:
|
780 |
+
hidden_state = hidden_state.unsqueeze(0)
|
781 |
+
if tiled:
|
782 |
+
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
783 |
+
else:
|
784 |
+
video = self.single_decode(hidden_state, device)
|
785 |
+
video = video.squeeze(0)
|
786 |
+
videos.append(video)
|
787 |
+
videos = torch.stack(videos)
|
788 |
+
return videos
|
789 |
+
|
790 |
+
|
791 |
+
@staticmethod
|
792 |
+
def state_dict_converter():
|
793 |
+
return WanVideoVAEStateDictConverter()
|
794 |
+
|
795 |
+
|
796 |
+
class WanVideoVAEStateDictConverter:
|
797 |
+
|
798 |
+
def __init__(self):
|
799 |
+
pass
|
800 |
+
|
801 |
+
def from_civitai(self, state_dict):
|
802 |
+
state_dict_ = {}
|
803 |
+
if 'model_state' in state_dict:
|
804 |
+
state_dict = state_dict['model_state']
|
805 |
+
for name in state_dict:
|
806 |
+
state_dict_['model.' + name] = state_dict[name]
|
807 |
+
return state_dict_
|
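The tiled encoder above hides tile seams by accumulating `hidden_states * mask` into `values` and `mask` into `weight`, then dividing at the end. A minimal self-contained sketch of that weighted-overlap blending on a 2-D tensor (toy sizes, an identity stand-in for the per-tile encode call, and a constant mask; all names here are illustrative, not part of the repo):

import torch

def blend_tiles(x, tile=8, stride=4):
    # Accumulate value*mask and mask separately, then normalize (as in tiled_encode above).
    H, W = x.shape
    values = torch.zeros_like(x)
    weight = torch.zeros_like(x)
    for h in range(0, H, stride):
        for w in range(0, W, stride):
            h_, w_ = min(h + tile, H), min(w + tile, W)
            patch = x[h:h_, w:w_]          # stand-in for the per-tile encode call
            mask = torch.ones_like(patch)  # the real build_mask ramps down at non-boundary edges
            values[h:h_, w:w_] += patch * mask
            weight[h:h_, w:w_] += mask
    return values / weight

x = torch.rand(16, 16)
assert torch.allclose(blend_tiles(x), x)  # identity encode + constant mask reproduces the input

In the real pipeline the ramped border mask is what blends overlapping tiles smoothly; the constant mask here only verifies the normalization.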
OmniAvatar/models/wav2vec.py
ADDED
@@ -0,0 +1,209 @@
1 |
+
# pylint: disable=R0901
|
2 |
+
# src/models/wav2vec.py
|
3 |
+
|
4 |
+
"""
|
5 |
+
This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding.
|
6 |
+
It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities
|
7 |
+
such as feature extraction and encoding.
|
8 |
+
|
9 |
+
Classes:
|
10 |
+
Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
|
11 |
+
|
12 |
+
Functions:
|
13 |
+
linear_interpolation: Interpolates the features based on the sequence length.
|
14 |
+
"""
|
15 |
+
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from transformers import Wav2Vec2Model
|
18 |
+
from transformers.modeling_outputs import BaseModelOutput
|
19 |
+
|
20 |
+
|
21 |
+
class Wav2VecModel(Wav2Vec2Model):
|
22 |
+
"""
|
23 |
+
Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
|
24 |
+
It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
|
25 |
+
...
|
26 |
+
|
27 |
+
Attributes:
|
28 |
+
base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
|
29 |
+
|
30 |
+
Methods:
|
31 |
+
forward(input_values, seq_len, attention_mask=None, mask_time_indices=None
|
32 |
+
, output_attentions=None, output_hidden_states=None, return_dict=None):
|
33 |
+
Forward pass of the Wav2VecModel.
|
34 |
+
It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
|
35 |
+
|
36 |
+
feature_extract(input_values, seq_len):
|
37 |
+
Extracts features from the input_values using the base model.
|
38 |
+
|
39 |
+
encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
|
40 |
+
Encodes the extracted features using the base model and returns the encoded features.
|
41 |
+
"""
|
42 |
+
def forward(
|
43 |
+
self,
|
44 |
+
input_values,
|
45 |
+
seq_len,
|
46 |
+
attention_mask=None,
|
47 |
+
mask_time_indices=None,
|
48 |
+
output_attentions=None,
|
49 |
+
output_hidden_states=None,
|
50 |
+
return_dict=None,
|
51 |
+
):
|
52 |
+
"""
|
53 |
+
Forward pass of the Wav2Vec model.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
self: The instance of the model.
|
57 |
+
input_values: The input values (waveform) to the model.
|
58 |
+
seq_len: The sequence length of the input values.
|
59 |
+
attention_mask: Attention mask to be used for the model.
|
60 |
+
mask_time_indices: Mask indices to be used for the model.
|
61 |
+
output_attentions: If set to True, returns attentions.
|
62 |
+
output_hidden_states: If set to True, returns hidden states.
|
63 |
+
return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
The output of the Wav2Vec model.
|
67 |
+
"""
|
68 |
+
self.config.output_attentions = True
|
69 |
+
|
70 |
+
output_hidden_states = (
|
71 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
72 |
+
)
|
73 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
74 |
+
|
75 |
+
extract_features = self.feature_extractor(input_values)
|
76 |
+
extract_features = extract_features.transpose(1, 2)
|
77 |
+
extract_features = linear_interpolation(extract_features, seq_len=seq_len)
|
78 |
+
|
79 |
+
if attention_mask is not None:
|
80 |
+
# compute reduced attention_mask corresponding to feature vectors
|
81 |
+
attention_mask = self._get_feature_vector_attention_mask(
|
82 |
+
extract_features.shape[1], attention_mask, add_adapter=False
|
83 |
+
)
|
84 |
+
|
85 |
+
hidden_states, extract_features = self.feature_projection(extract_features)
|
86 |
+
hidden_states = self._mask_hidden_states(
|
87 |
+
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
|
88 |
+
)
|
89 |
+
|
90 |
+
encoder_outputs = self.encoder(
|
91 |
+
hidden_states,
|
92 |
+
attention_mask=attention_mask,
|
93 |
+
output_attentions=output_attentions,
|
94 |
+
output_hidden_states=output_hidden_states,
|
95 |
+
return_dict=return_dict,
|
96 |
+
)
|
97 |
+
|
98 |
+
hidden_states = encoder_outputs[0]
|
99 |
+
|
100 |
+
if self.adapter is not None:
|
101 |
+
hidden_states = self.adapter(hidden_states)
|
102 |
+
|
103 |
+
if not return_dict:
|
104 |
+
return (hidden_states, ) + encoder_outputs[1:]
|
105 |
+
return BaseModelOutput(
|
106 |
+
last_hidden_state=hidden_states,
|
107 |
+
hidden_states=encoder_outputs.hidden_states,
|
108 |
+
attentions=encoder_outputs.attentions,
|
109 |
+
)
|
110 |
+
|
111 |
+
|
112 |
+
def feature_extract(
|
113 |
+
self,
|
114 |
+
input_values,
|
115 |
+
seq_len,
|
116 |
+
):
|
117 |
+
"""
|
118 |
+
Extracts features from the input values and returns the extracted features.
|
119 |
+
|
120 |
+
Parameters:
|
121 |
+
input_values (torch.Tensor): The input values to be processed.
|
122 |
+
seq_len (torch.Tensor): The sequence lengths of the input values.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
extracted_features (torch.Tensor): The extracted features from the input values.
|
126 |
+
"""
|
127 |
+
extract_features = self.feature_extractor(input_values)
|
128 |
+
extract_features = extract_features.transpose(1, 2)
|
129 |
+
extract_features = linear_interpolation(extract_features, seq_len=seq_len)
|
130 |
+
|
131 |
+
return extract_features
|
132 |
+
|
133 |
+
def encode(
|
134 |
+
self,
|
135 |
+
extract_features,
|
136 |
+
attention_mask=None,
|
137 |
+
mask_time_indices=None,
|
138 |
+
output_attentions=None,
|
139 |
+
output_hidden_states=None,
|
140 |
+
return_dict=None,
|
141 |
+
):
|
142 |
+
"""
|
143 |
+
Encodes the input features into the output space.
|
144 |
+
|
145 |
+
Args:
|
146 |
+
extract_features (torch.Tensor): The extracted features from the audio signal.
|
147 |
+
attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
|
148 |
+
mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
|
149 |
+
output_attentions (bool, optional): If set to True, returns the attention weights.
|
150 |
+
output_hidden_states (bool, optional): If set to True, returns all hidden states.
|
151 |
+
return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple.
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
The encoded output features.
|
155 |
+
"""
|
156 |
+
self.config.output_attentions = True
|
157 |
+
|
158 |
+
output_hidden_states = (
|
159 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
160 |
+
)
|
161 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
162 |
+
|
163 |
+
if attention_mask is not None:
|
164 |
+
# compute reduced attention_mask corresponding to feature vectors
|
165 |
+
attention_mask = self._get_feature_vector_attention_mask(
|
166 |
+
extract_features.shape[1], attention_mask, add_adapter=False
|
167 |
+
)
|
168 |
+
|
169 |
+
hidden_states, extract_features = self.feature_projection(extract_features)
|
170 |
+
hidden_states = self._mask_hidden_states(
|
171 |
+
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
|
172 |
+
)
|
173 |
+
|
174 |
+
encoder_outputs = self.encoder(
|
175 |
+
hidden_states,
|
176 |
+
attention_mask=attention_mask,
|
177 |
+
output_attentions=output_attentions,
|
178 |
+
output_hidden_states=output_hidden_states,
|
179 |
+
return_dict=return_dict,
|
180 |
+
)
|
181 |
+
|
182 |
+
hidden_states = encoder_outputs[0]
|
183 |
+
|
184 |
+
if self.adapter is not None:
|
185 |
+
hidden_states = self.adapter(hidden_states)
|
186 |
+
|
187 |
+
if not return_dict:
|
188 |
+
return (hidden_states, ) + encoder_outputs[1:]
|
189 |
+
return BaseModelOutput(
|
190 |
+
last_hidden_state=hidden_states,
|
191 |
+
hidden_states=encoder_outputs.hidden_states,
|
192 |
+
attentions=encoder_outputs.attentions,
|
193 |
+
)
|
194 |
+
|
195 |
+
|
196 |
+
def linear_interpolation(features, seq_len):
|
197 |
+
"""
|
198 |
+
Transpose the features to interpolate linearly.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
features (torch.Tensor): The extracted features to be interpolated.
|
202 |
+
seq_len (torch.Tensor): The sequence lengths of the features.
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
torch.Tensor: The interpolated features.
|
206 |
+
"""
|
207 |
+
features = features.transpose(1, 2)
|
208 |
+
output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
|
209 |
+
return output_features.transpose(1, 2)
|
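The `linear_interpolation` helper above is what aligns the wav2vec feature stream (roughly 50 feature vectors per second) with the video frame rate. A quick shape check with dummy tensors, restating the helper verbatim so the snippet runs on its own (the feature dimension and frame rate below are illustrative):

import torch
import torch.nn.functional as F

def linear_interpolation(features, seq_len):
    # (B, T, C) -> (B, seq_len, C): interpolate along the time axis
    features = features.transpose(1, 2)
    output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)

audio_feats = torch.randn(1, 499, 768)   # ~10 s of audio features from the wav2vec encoder
num_video_frames = 250                   # e.g. 10 s of video at 25 fps
aligned = linear_interpolation(audio_feats, num_video_frames)
print(aligned.shape)                     # torch.Size([1, 250, 768])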
OmniAvatar/prompters/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
from .wan_prompter import WanPrompter
|
OmniAvatar/prompters/base_prompter.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
from ..models.model_manager import ModelManager
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
def tokenize_long_prompt(tokenizer, prompt, max_length=None):
|
7 |
+
# Get model_max_length from self.tokenizer
|
8 |
+
length = tokenizer.model_max_length if max_length is None else max_length
|
9 |
+
|
10 |
+
# To avoid the truncation warning, temporarily set tokenizer.model_max_length to a very large value.
|
11 |
+
tokenizer.model_max_length = 99999999
|
12 |
+
|
13 |
+
# Tokenize it!
|
14 |
+
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
15 |
+
|
16 |
+
# Determine the real length.
|
17 |
+
max_length = (input_ids.shape[1] + length - 1) // length * length
|
18 |
+
|
19 |
+
# Restore tokenizer.model_max_length
|
20 |
+
tokenizer.model_max_length = length
|
21 |
+
|
22 |
+
# Tokenize it again with fixed length.
|
23 |
+
input_ids = tokenizer(
|
24 |
+
prompt,
|
25 |
+
return_tensors="pt",
|
26 |
+
padding="max_length",
|
27 |
+
max_length=max_length,
|
28 |
+
truncation=True
|
29 |
+
).input_ids
|
30 |
+
|
31 |
+
# Reshape input_ids to fit the text encoder.
|
32 |
+
num_sentence = input_ids.shape[1] // length
|
33 |
+
input_ids = input_ids.reshape((num_sentence, length))
|
34 |
+
|
35 |
+
return input_ids
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
class BasePrompter:
|
40 |
+
def __init__(self):
|
41 |
+
self.refiners = []
|
42 |
+
self.extenders = []
|
43 |
+
|
44 |
+
|
45 |
+
def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
|
46 |
+
for refiner_class in refiner_classes:
|
47 |
+
refiner = refiner_class.from_model_manager(model_manager)
|
48 |
+
self.refiners.append(refiner)
|
49 |
+
|
50 |
+
def load_prompt_extenders(self,model_manager:ModelManager,extender_classes=[]):
|
51 |
+
for extender_class in extender_classes:
|
52 |
+
extender = extender_class.from_model_manager(model_manager)
|
53 |
+
self.extenders.append(extender)
|
54 |
+
|
55 |
+
|
56 |
+
@torch.no_grad()
|
57 |
+
def process_prompt(self, prompt, positive=True):
|
58 |
+
if isinstance(prompt, list):
|
59 |
+
prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
|
60 |
+
else:
|
61 |
+
for refiner in self.refiners:
|
62 |
+
prompt = refiner(prompt, positive=positive)
|
63 |
+
return prompt
|
64 |
+
|
65 |
+
@torch.no_grad()
|
66 |
+
def extend_prompt(self, prompt:str, positive=True):
|
67 |
+
extended_prompt = dict(prompt=prompt)
|
68 |
+
for extender in self.extenders:
|
69 |
+
extended_prompt = extender(extended_prompt)
|
70 |
+
return extended_prompt
|
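`tokenize_long_prompt` above pads the token sequence up to the next multiple of the tokenizer's `model_max_length` and then reshapes it into fixed-length chunks for the text encoder. The rounding arithmetic on its own, with hypothetical token counts and no tokenizer required:

def round_up_to_multiple(num_tokens, length):
    # Same expression as in tokenize_long_prompt: ceil(num_tokens / length) * length
    return (num_tokens + length - 1) // length * length

length = 512  # a typical model_max_length
for num_tokens in (7, 512, 513, 1100):
    padded = round_up_to_multiple(num_tokens, length)
    print(f"{num_tokens} tokens -> pad to {padded} -> {padded // length} chunk(s) of {length}")
# 7 -> 512 (1 chunk), 512 -> 512 (1 chunk), 513 -> 1024 (2 chunks), 1100 -> 1536 (3 chunks)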
OmniAvatar/prompters/wan_prompter.py
ADDED
@@ -0,0 +1,109 @@
1 |
+
from .base_prompter import BasePrompter
|
2 |
+
from ..models.wan_video_text_encoder import WanTextEncoder
|
3 |
+
from transformers import AutoTokenizer
|
4 |
+
import os, torch
|
5 |
+
import ftfy
|
6 |
+
import html
|
7 |
+
import string
|
8 |
+
import regex as re
|
9 |
+
|
10 |
+
|
11 |
+
def basic_clean(text):
|
12 |
+
text = ftfy.fix_text(text)
|
13 |
+
text = html.unescape(html.unescape(text))
|
14 |
+
return text.strip()
|
15 |
+
|
16 |
+
|
17 |
+
def whitespace_clean(text):
|
18 |
+
text = re.sub(r'\s+', ' ', text)
|
19 |
+
text = text.strip()
|
20 |
+
return text
|
21 |
+
|
22 |
+
|
23 |
+
def canonicalize(text, keep_punctuation_exact_string=None):
|
24 |
+
text = text.replace('_', ' ')
|
25 |
+
if keep_punctuation_exact_string:
|
26 |
+
text = keep_punctuation_exact_string.join(
|
27 |
+
part.translate(str.maketrans('', '', string.punctuation))
|
28 |
+
for part in text.split(keep_punctuation_exact_string))
|
29 |
+
else:
|
30 |
+
text = text.translate(str.maketrans('', '', string.punctuation))
|
31 |
+
text = text.lower()
|
32 |
+
text = re.sub(r'\s+', ' ', text)
|
33 |
+
return text.strip()
|
34 |
+
|
35 |
+
|
36 |
+
class HuggingfaceTokenizer:
|
37 |
+
|
38 |
+
def __init__(self, name, seq_len=None, clean=None, **kwargs):
|
39 |
+
assert clean in (None, 'whitespace', 'lower', 'canonicalize')
|
40 |
+
self.name = name
|
41 |
+
self.seq_len = seq_len
|
42 |
+
self.clean = clean
|
43 |
+
|
44 |
+
# init tokenizer
|
45 |
+
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
|
46 |
+
self.vocab_size = self.tokenizer.vocab_size
|
47 |
+
|
48 |
+
def __call__(self, sequence, **kwargs):
|
49 |
+
return_mask = kwargs.pop('return_mask', False)
|
50 |
+
|
51 |
+
# arguments
|
52 |
+
_kwargs = {'return_tensors': 'pt'}
|
53 |
+
if self.seq_len is not None:
|
54 |
+
_kwargs.update({
|
55 |
+
'padding': 'max_length',
|
56 |
+
'truncation': True,
|
57 |
+
'max_length': self.seq_len
|
58 |
+
})
|
59 |
+
_kwargs.update(**kwargs)
|
60 |
+
|
61 |
+
# tokenization
|
62 |
+
if isinstance(sequence, str):
|
63 |
+
sequence = [sequence]
|
64 |
+
if self.clean:
|
65 |
+
sequence = [self._clean(u) for u in sequence]
|
66 |
+
ids = self.tokenizer(sequence, **_kwargs)
|
67 |
+
|
68 |
+
# output
|
69 |
+
if return_mask:
|
70 |
+
return ids.input_ids, ids.attention_mask
|
71 |
+
else:
|
72 |
+
return ids.input_ids
|
73 |
+
|
74 |
+
def _clean(self, text):
|
75 |
+
if self.clean == 'whitespace':
|
76 |
+
text = whitespace_clean(basic_clean(text))
|
77 |
+
elif self.clean == 'lower':
|
78 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
79 |
+
elif self.clean == 'canonicalize':
|
80 |
+
text = canonicalize(basic_clean(text))
|
81 |
+
return text
|
82 |
+
|
83 |
+
|
84 |
+
class WanPrompter(BasePrompter):
|
85 |
+
|
86 |
+
def __init__(self, tokenizer_path=None, text_len=512):
|
87 |
+
super().__init__()
|
88 |
+
self.text_len = text_len
|
89 |
+
self.text_encoder = None
|
90 |
+
self.fetch_tokenizer(tokenizer_path)
|
91 |
+
|
92 |
+
def fetch_tokenizer(self, tokenizer_path=None):
|
93 |
+
if tokenizer_path is not None:
|
94 |
+
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')
|
95 |
+
|
96 |
+
def fetch_models(self, text_encoder: WanTextEncoder = None):
|
97 |
+
self.text_encoder = text_encoder
|
98 |
+
|
99 |
+
def encode_prompt(self, prompt, positive=True, device="cuda"):
|
100 |
+
prompt = self.process_prompt(prompt, positive=positive)
|
101 |
+
|
102 |
+
ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
|
103 |
+
ids = ids.to(device)
|
104 |
+
mask = mask.to(device)
|
105 |
+
seq_lens = mask.gt(0).sum(dim=1).long()
|
106 |
+
prompt_emb = self.text_encoder(ids, mask)
|
107 |
+
for i, v in enumerate(seq_lens):
|
108 |
+
prompt_emb[:, v:] = 0
|
109 |
+
return prompt_emb
|
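The cleaning helpers above normalize prompts before tokenization; `canonicalize`, for instance, strips punctuation, lowercases, and collapses whitespace. The same logic restated as a standalone snippet, using the standard-library `re` instead of the `regex` package (equivalent for this pattern) and omitting the `keep_punctuation_exact_string` branch:

import re
import string

def canonicalize(text):
    text = text.replace('_', ' ')
    text = text.translate(str.maketrans('', '', string.punctuation))  # drop punctuation
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)                                  # collapse whitespace
    return text.strip()

print(canonicalize("A  Woman_Talking, smiling!"))  # -> "a woman talking smiling"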
OmniAvatar/schedulers/flow_match.py
ADDED
@@ -0,0 +1,79 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
|
5 |
+
class FlowMatchScheduler():
|
6 |
+
|
7 |
+
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
|
8 |
+
self.num_train_timesteps = num_train_timesteps
|
9 |
+
self.shift = shift
|
10 |
+
self.sigma_max = sigma_max
|
11 |
+
self.sigma_min = sigma_min
|
12 |
+
self.inverse_timesteps = inverse_timesteps
|
13 |
+
self.extra_one_step = extra_one_step
|
14 |
+
self.reverse_sigmas = reverse_sigmas
|
15 |
+
self.set_timesteps(num_inference_steps)
|
16 |
+
|
17 |
+
|
18 |
+
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
|
19 |
+
if shift is not None:
|
20 |
+
self.shift = shift
|
21 |
+
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
22 |
+
if self.extra_one_step:
|
23 |
+
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
24 |
+
else:
|
25 |
+
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
26 |
+
if self.inverse_timesteps:
|
27 |
+
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
28 |
+
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
29 |
+
if self.reverse_sigmas:
|
30 |
+
self.sigmas = 1 - self.sigmas
|
31 |
+
self.timesteps = self.sigmas * self.num_train_timesteps
|
32 |
+
if training:
|
33 |
+
x = self.timesteps
|
34 |
+
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
|
35 |
+
y_shifted = y - y.min()
|
36 |
+
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
37 |
+
self.linear_timesteps_weights = bsmntw_weighing
|
38 |
+
|
39 |
+
|
40 |
+
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
41 |
+
if isinstance(timestep, torch.Tensor):
|
42 |
+
timestep = timestep.cpu()
|
43 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
44 |
+
sigma = self.sigmas[timestep_id]
|
45 |
+
if to_final or timestep_id + 1 >= len(self.timesteps):
|
46 |
+
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
|
47 |
+
else:
|
48 |
+
sigma_ = self.sigmas[timestep_id + 1]
|
49 |
+
prev_sample = sample + model_output * (sigma_ - sigma)
|
50 |
+
return prev_sample
|
51 |
+
|
52 |
+
|
53 |
+
def return_to_timestep(self, timestep, sample, sample_stablized):
|
54 |
+
if isinstance(timestep, torch.Tensor):
|
55 |
+
timestep = timestep.cpu()
|
56 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
57 |
+
sigma = self.sigmas[timestep_id]
|
58 |
+
model_output = (sample - sample_stablized) / sigma
|
59 |
+
return model_output
|
60 |
+
|
61 |
+
|
62 |
+
def add_noise(self, original_samples, noise, timestep):
|
63 |
+
if isinstance(timestep, torch.Tensor):
|
64 |
+
timestep = timestep.cpu()
|
65 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
66 |
+
sigma = self.sigmas[timestep_id]
|
67 |
+
sample = (1 - sigma) * original_samples + sigma * noise
|
68 |
+
return sample
|
69 |
+
|
70 |
+
|
71 |
+
def training_target(self, sample, noise, timestep):
|
72 |
+
target = noise - sample
|
73 |
+
return target
|
74 |
+
|
75 |
+
|
76 |
+
def training_weight(self, timestep):
|
77 |
+
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
|
78 |
+
weights = self.linear_timesteps_weights[timestep_id]
|
79 |
+
return weights
|
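The scheduler above builds a linearly spaced sigma grid, warps it with `shift * sigma / (1 + (shift - 1) * sigma)`, and advances each denoising step with the Euler update `x + v * (sigma_next - sigma)`. A self-contained sketch of those two pieces with toy tensors (not the project class; a random stand-in replaces the DiT velocity prediction):

import torch

def shifted_sigmas(num_steps=10, shift=5.0, sigma_min=0.003 / 1.002, sigma_max=1.0):
    # Linear grid from sigma_max down to sigma_min (extra_one_step=True), then the shift warp
    sigmas = torch.linspace(sigma_max, sigma_min, num_steps + 1)[:-1]
    return shift * sigmas / (1 + (shift - 1) * sigmas)

sigmas = shifted_sigmas()
x = torch.randn(4)                        # toy "latent"
for i, sigma in enumerate(sigmas):
    v = torch.randn(4)                    # stand-in for the model's predicted velocity
    sigma_next = sigmas[i + 1] if i + 1 < len(sigmas) else torch.tensor(0.0)
    x = x + v * (sigma_next - sigma)      # Euler step, as in FlowMatchScheduler.step
print(sigmas)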
OmniAvatar/utils/args_config.py
ADDED
@@ -0,0 +1,123 @@
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
import yaml
|
5 |
+
args = None
|
6 |
+
|
7 |
+
def set_global_args(local_args):
|
8 |
+
global args
|
9 |
+
|
10 |
+
args = local_args
|
11 |
+
|
12 |
+
def parse_hp_string(hp_string):
|
13 |
+
result = {}
|
14 |
+
for pair in hp_string.split(','):
|
15 |
+
if not pair:
|
16 |
+
continue
|
17 |
+
key, value = pair.split('=')
|
18 |
+
try:
|
19 |
+
# Automatically cast to int / float / str
|
20 |
+
ori_value = value
|
21 |
+
value = float(value)
|
22 |
+
if '.' not in str(ori_value):
|
23 |
+
value = int(value)
|
24 |
+
except ValueError:
|
25 |
+
pass
|
26 |
+
|
27 |
+
if value in ['true', 'True']:
|
28 |
+
value = True
|
29 |
+
if value in ['false', 'False']:
|
30 |
+
value = False
|
31 |
+
if '.' in key:
|
32 |
+
keys = key.split('.')
|
33 |
+
keys = keys
|
34 |
+
current = result
|
35 |
+
for key in keys[:-1]:
|
36 |
+
if key not in current or not isinstance(current[key], dict):
|
37 |
+
current[key] = {}
|
38 |
+
current = current[key]
|
39 |
+
current[keys[-1]] = value
|
40 |
+
else:
|
41 |
+
result[key.strip()] = value
|
42 |
+
return result
|
43 |
+
|
44 |
+
def parse_args():
|
45 |
+
global args
|
46 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
47 |
+
parser.add_argument("--config", type=str, required=True, help="Path to YAML config file.")
|
48 |
+
|
49 |
+
# Define the argparse arguments
|
50 |
+
parser.add_argument("--exp_path", type=str, help="Path to save the model.")
|
51 |
+
parser.add_argument("--input_file", type=str, help="Path to inference txt.")
|
52 |
+
parser.add_argument("--debug", action='store_true', default=None)
|
53 |
+
parser.add_argument("--infer", action='store_true')
|
54 |
+
parser.add_argument("-hp", "--hparams", type=str, default="")
|
55 |
+
|
56 |
+
args = parser.parse_args()
|
57 |
+
|
58 |
+
# Read the YAML config (if --config was provided)
|
59 |
+
if args.config:
|
60 |
+
with open(args.config, "r") as f:
|
61 |
+
yaml_config = yaml.safe_load(f)
|
62 |
+
|
63 |
+
# Walk the YAML config and add entries to args (only if argparse did not define them)
|
64 |
+
for key, value in yaml_config.items():
|
65 |
+
if not hasattr(args, key): # parameter not defined in argparse
|
66 |
+
setattr(args, key, value)
|
67 |
+
elif getattr(args, key) is None: # defined in argparse but currently None
|
68 |
+
setattr(args, key, value)
|
69 |
+
|
70 |
+
args.rank = int(os.getenv("RANK", "0"))
|
71 |
+
args.world_size = int(os.getenv("WORLD_SIZE", "1"))
|
72 |
+
args.local_rank = int(os.getenv("LOCAL_RANK", "0")) # torchrun
|
73 |
+
args.device = 'cuda'
|
74 |
+
debug = args.debug
|
75 |
+
if not os.path.exists(args.exp_path):
|
76 |
+
args.exp_path = f'checkpoints/{args.exp_path}'
|
77 |
+
|
78 |
+
if hasattr(args, 'reload_cfg') and args.reload_cfg:
|
79 |
+
# Reload the config file
|
80 |
+
conf_path = os.path.join(args.exp_path, "config.json")
|
81 |
+
if os.path.exists(conf_path):
|
82 |
+
print('| Reloading config from:', conf_path)
|
83 |
+
args = reload(args, conf_path)
|
84 |
+
if len(args.hparams) > 0:
|
85 |
+
hp_dict = parse_hp_string(args.hparams)
|
86 |
+
for key, value in hp_dict.items():
|
87 |
+
if not hasattr(args, key):
|
88 |
+
setattr(args, key, value)
|
89 |
+
else:
|
90 |
+
if isinstance(value, dict):
|
91 |
+
ori_v = getattr(args, key)
|
92 |
+
ori_v.update(value)
|
93 |
+
setattr(args, key, ori_v)
|
94 |
+
else:
|
95 |
+
setattr(args, key, value)
|
96 |
+
args.debug = debug
|
97 |
+
dict_args = convert_namespace_to_dict(args)
|
98 |
+
if args.local_rank == 0:
|
99 |
+
print(dict_args)
|
100 |
+
return args
|
101 |
+
|
102 |
+
def reload(args, conf_path):
|
103 |
+
"""Reload the config file without overwriting parameters that are already set."""
|
104 |
+
with open(conf_path, "r") as f:
|
105 |
+
yaml_config = yaml.safe_load(f)
|
106 |
+
# Walk the YAML config and add entries to args (only if argparse did not define them)
|
107 |
+
for key, value in yaml_config.items():
|
108 |
+
if not hasattr(args, key): # parameter not defined in argparse
|
109 |
+
setattr(args, key, value)
|
110 |
+
elif getattr(args, key) is None: # defined in argparse but currently None
|
111 |
+
setattr(args, key, value)
|
112 |
+
return args
|
113 |
+
|
114 |
+
def convert_namespace_to_dict(namespace):
|
115 |
+
"""Convert an argparse.Namespace to a dict, handling non-serializable objects."""
|
116 |
+
result = {}
|
117 |
+
for key, value in vars(namespace).items():
|
118 |
+
try:
|
119 |
+
json.dumps(value) # check whether the value is JSON-serializable
|
120 |
+
result[key] = value
|
121 |
+
except (TypeError, OverflowError):
|
122 |
+
result[key] = str(value) # fall back to a string representation for non-serializable objects
|
123 |
+
return result
|
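`parse_hp_string` converts the `-hp` override string into a (possibly nested) dict, auto-casting ints, floats, and booleans. A usage sketch, assuming the repository root is on PYTHONPATH so the module imports under the path shown above; the override keys are made up for illustration:

from OmniAvatar.utils.args_config import parse_hp_string

overrides = parse_hp_string("num_steps=25,guidance_scale=4.5,use_audio=true,sample.seed=42")
print(overrides)
# -> {'num_steps': 25, 'guidance_scale': 4.5, 'use_audio': True, 'sample': {'seed': 42}}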
OmniAvatar/utils/audio_preprocess.py
ADDED
@@ -0,0 +1,18 @@
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
|
4 |
+
def add_silence_to_audio_ffmpeg(audio_path, tmp_audio_path, silence_duration_s=0.5):
|
5 |
+
# Use ffmpeg to prepend silence to the audio
|
6 |
+
command = [
|
7 |
+
'ffmpeg',
|
8 |
+
'-i', audio_path, # input audio file path
|
9 |
+
'-f', 'lavfi', # use the lavfi virtual input device to generate silence
|
10 |
+
'-t', str(silence_duration_s), # silence duration in seconds
|
11 |
+
'-i', 'anullsrc=r=16000:cl=stereo', # create the silent segment (stereo, 16 kHz sample rate)
|
12 |
+
'-filter_complex', '[1][0]concat=n=2:v=0:a=1[out]', # concatenate the silence and the original audio
|
13 |
+
'-map', '[out]', # map the concatenated audio to the output
|
14 |
+
'-y', tmp_audio_path, # output file path
|
15 |
+
'-loglevel', 'quiet'
|
16 |
+
]
|
17 |
+
|
18 |
+
subprocess.run(command, check=True)
|
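`add_silence_to_audio_ffmpeg` prepends a silent `anullsrc` segment to the input (presumably so the very first phonemes are not clipped at frame 0). A usage sketch; it requires `ffmpeg` on the PATH, the input is one of the example clips bundled with this Space, and the output path is hypothetical:

from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg

add_silence_to_audio_ffmpeg(
    audio_path="examples/audios/mushroom.wav",  # sample shipped with this Space
    tmp_audio_path="/tmp/mushroom_padded.wav",  # hypothetical output location
    silence_duration_s=0.5,
)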
OmniAvatar/utils/io_utils.py
ADDED
@@ -0,0 +1,256 @@
1 |
+
import subprocess
|
2 |
+
import torch, os
|
3 |
+
from safetensors import safe_open
|
4 |
+
from OmniAvatar.utils.args_config import args
|
5 |
+
from contextlib import contextmanager
|
6 |
+
|
7 |
+
import re
|
8 |
+
import tempfile
|
9 |
+
import numpy as np
|
10 |
+
import imageio
|
11 |
+
from glob import glob
|
12 |
+
import soundfile as sf
|
13 |
+
from einops import rearrange
|
14 |
+
import hashlib
|
15 |
+
|
16 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
17 |
+
|
18 |
+
@contextmanager
|
19 |
+
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
|
20 |
+
|
21 |
+
old_register_parameter = torch.nn.Module.register_parameter
|
22 |
+
if include_buffers:
|
23 |
+
old_register_buffer = torch.nn.Module.register_buffer
|
24 |
+
|
25 |
+
def register_empty_parameter(module, name, param):
|
26 |
+
old_register_parameter(module, name, param)
|
27 |
+
if param is not None:
|
28 |
+
param_cls = type(module._parameters[name])
|
29 |
+
kwargs = module._parameters[name].__dict__
|
30 |
+
kwargs["requires_grad"] = param.requires_grad
|
31 |
+
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
32 |
+
|
33 |
+
def register_empty_buffer(module, name, buffer, persistent=True):
|
34 |
+
old_register_buffer(module, name, buffer, persistent=persistent)
|
35 |
+
if buffer is not None:
|
36 |
+
module._buffers[name] = module._buffers[name].to(device)
|
37 |
+
|
38 |
+
def patch_tensor_constructor(fn):
|
39 |
+
def wrapper(*args, **kwargs):
|
40 |
+
kwargs["device"] = device
|
41 |
+
return fn(*args, **kwargs)
|
42 |
+
|
43 |
+
return wrapper
|
44 |
+
|
45 |
+
if include_buffers:
|
46 |
+
tensor_constructors_to_patch = {
|
47 |
+
torch_function_name: getattr(torch, torch_function_name)
|
48 |
+
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
49 |
+
}
|
50 |
+
else:
|
51 |
+
tensor_constructors_to_patch = {}
|
52 |
+
|
53 |
+
try:
|
54 |
+
torch.nn.Module.register_parameter = register_empty_parameter
|
55 |
+
if include_buffers:
|
56 |
+
torch.nn.Module.register_buffer = register_empty_buffer
|
57 |
+
for torch_function_name in tensor_constructors_to_patch.keys():
|
58 |
+
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
59 |
+
yield
|
60 |
+
finally:
|
61 |
+
torch.nn.Module.register_parameter = old_register_parameter
|
62 |
+
if include_buffers:
|
63 |
+
torch.nn.Module.register_buffer = old_register_buffer
|
64 |
+
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
65 |
+
setattr(torch, torch_function_name, old_torch_function)
|
66 |
+
|
67 |
+
def load_state_dict_from_folder(file_path, torch_dtype=None):
|
68 |
+
state_dict = {}
|
69 |
+
for file_name in os.listdir(file_path):
|
70 |
+
if "." in file_name and file_name.split(".")[-1] in [
|
71 |
+
"safetensors", "bin", "ckpt", "pth", "pt"
|
72 |
+
]:
|
73 |
+
state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
|
74 |
+
return state_dict
|
75 |
+
|
76 |
+
|
77 |
+
def load_state_dict(file_path, torch_dtype=None):
|
78 |
+
if file_path.endswith(".safetensors"):
|
79 |
+
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
80 |
+
else:
|
81 |
+
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
82 |
+
|
83 |
+
|
84 |
+
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
85 |
+
state_dict = {}
|
86 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
87 |
+
for k in f.keys():
|
88 |
+
state_dict[k] = f.get_tensor(k)
|
89 |
+
if torch_dtype is not None:
|
90 |
+
state_dict[k] = state_dict[k].to(torch_dtype)
|
91 |
+
return state_dict
|
92 |
+
|
93 |
+
|
94 |
+
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
95 |
+
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
|
96 |
+
if torch_dtype is not None:
|
97 |
+
for i in state_dict:
|
98 |
+
if isinstance(state_dict[i], torch.Tensor):
|
99 |
+
state_dict[i] = state_dict[i].to(torch_dtype)
|
100 |
+
return state_dict
|
101 |
+
|
102 |
+
def smart_load_weights(model, ckpt_state_dict):
|
103 |
+
model_state_dict = model.state_dict()
|
104 |
+
new_state_dict = {}
|
105 |
+
|
106 |
+
for name, param in model_state_dict.items():
|
107 |
+
if name in ckpt_state_dict:
|
108 |
+
ckpt_param = ckpt_state_dict[name]
|
109 |
+
if param.shape == ckpt_param.shape:
|
110 |
+
new_state_dict[name] = ckpt_param
|
111 |
+
else:
|
112 |
+
# Automatically trim dimensions to match
|
113 |
+
if all(p >= c for p, c in zip(param.shape, ckpt_param.shape)):
|
114 |
+
print(f"[Truncate] {name}: ckpt {ckpt_param.shape} -> model {param.shape}")
|
115 |
+
# Create a new tensor and copy in the old data
|
116 |
+
new_param = param.clone()
|
117 |
+
slices = tuple(slice(0, s) for s in ckpt_param.shape)
|
118 |
+
new_param[slices] = ckpt_param
|
119 |
+
new_state_dict[name] = new_param
|
120 |
+
else:
|
121 |
+
print(f"[Skip] {name}: ckpt {ckpt_param.shape} is larger than model {param.shape}")
|
122 |
+
|
123 |
+
# Update the state_dict, only for the entries that matched
|
124 |
+
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, assign=True, strict=False)
|
125 |
+
return model, missing_keys, unexpected_keys
|
126 |
+
|
127 |
+
def save_wav(audio, audio_path):
|
128 |
+
if isinstance(audio, torch.Tensor):
|
129 |
+
audio = audio.float().detach().cpu().numpy()
|
130 |
+
|
131 |
+
if audio.ndim == 1:
|
132 |
+
audio = np.expand_dims(audio, axis=0) # (1, samples)
|
133 |
+
|
134 |
+
sf.write(audio_path, audio.T, 16000)
|
135 |
+
|
136 |
+
return True
|
137 |
+
|
138 |
+
def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: float = 5,prompt=None, prompt_path=None, audio=None, audio_path=None, prefix=None):
|
139 |
+
os.makedirs(save_path, exist_ok=True)
|
140 |
+
out_videos = []
|
141 |
+
|
142 |
+
with tempfile.TemporaryDirectory() as tmp_path:
|
143 |
+
|
144 |
+
print(f'video batch shape:{video_batch.shape}')
|
145 |
+
|
146 |
+
for i, vid in enumerate(video_batch):
|
147 |
+
gif_frames = []
|
148 |
+
|
149 |
+
for frame in vid:
|
150 |
+
ft = frame.detach().cpu().clone()
|
151 |
+
ft = rearrange(ft, "c h w -> h w c")
|
152 |
+
arr = (255.0 * ft).numpy().astype(np.uint8)
|
153 |
+
gif_frames.append(arr)
|
154 |
+
|
155 |
+
if prefix is not None:
|
156 |
+
now_save_path = os.path.join(save_path, f"{prefix}_{i:03d}.mp4")
|
157 |
+
tmp_save_path = os.path.join(tmp_path, f"{prefix}_{i:03d}.mp4")
|
158 |
+
else:
|
159 |
+
now_save_path = os.path.join(save_path, f"{i:03d}.mp4")
|
160 |
+
tmp_save_path = os.path.join(tmp_path, f"{i:03d}.mp4")
|
161 |
+
with imageio.get_writer(tmp_save_path, fps=fps) as writer:
|
162 |
+
for frame in gif_frames:
|
163 |
+
writer.append_data(frame)
|
164 |
+
subprocess.run([f"cp {tmp_save_path} {now_save_path}"], check=True, shell=True)
|
165 |
+
print(f'save res video to : {now_save_path}')
|
166 |
+
final_video_path = now_save_path
|
167 |
+
|
168 |
+
if audio is not None or audio_path is not None:
|
169 |
+
if audio is not None:
|
170 |
+
audio_path = os.path.join(tmp_path, f"{i:06d}.mp3")
|
171 |
+
save_wav(audio[i], audio_path)
|
172 |
+
# cmd = f'/usr/bin/ffmpeg -i {tmp_save_path} -i {audio_path} -v quiet -c:v copy -c:a libmp3lame -strict experimental {tmp_save_path[:-4]}_wav.mp4 -y'
|
173 |
+
cmd = f'/usr/bin/ffmpeg -i {tmp_save_path} -i {audio_path} -v quiet -map 0:v:0 -map 1:a:0 -c:v copy -c:a aac {tmp_save_path[:-4]}_wav.mp4 -y'
|
174 |
+
subprocess.check_call(cmd, stdout=None, stdin=subprocess.PIPE, shell=True)
|
175 |
+
final_video_path = f"{now_save_path[:-4]}_wav.mp4"
|
176 |
+
subprocess.run([f"cp {tmp_save_path[:-4]}_wav.mp4 {final_video_path}"], check=True, shell=True)
|
177 |
+
os.remove(now_save_path)
|
178 |
+
if prompt is not None and prompt_path is not None:
|
179 |
+
with open(prompt_path, "w") as f:
|
180 |
+
f.write(prompt)
|
181 |
+
out_videos.append(final_video_path)
|
182 |
+
|
183 |
+
return out_videos
|
184 |
+
|
185 |
+
def is_zero_stage_3(trainer):
|
186 |
+
strategy = getattr(trainer, "strategy", None)
|
187 |
+
if strategy and hasattr(strategy, "model"):
|
188 |
+
ds_engine = strategy.model
|
189 |
+
stage = ds_engine.config.get("zero_optimization", {}).get("stage", 0)
|
190 |
+
return stage == 3
|
191 |
+
return False
|
192 |
+
|
193 |
+
def hash_state_dict_keys(state_dict, with_shape=True):
|
194 |
+
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
195 |
+
keys_str = keys_str.encode(encoding="UTF-8")
|
196 |
+
return hashlib.md5(keys_str).hexdigest()
|
197 |
+
|
198 |
+
def split_state_dict_with_prefix(state_dict):
|
199 |
+
keys = sorted([key for key in state_dict if isinstance(key, str)])
|
200 |
+
prefix_dict = {}
|
201 |
+
for key in keys:
|
202 |
+
prefix = key if "." not in key else key.split(".")[0]
|
203 |
+
if prefix not in prefix_dict:
|
204 |
+
prefix_dict[prefix] = []
|
205 |
+
prefix_dict[prefix].append(key)
|
206 |
+
state_dicts = []
|
207 |
+
for prefix, keys in prefix_dict.items():
|
208 |
+
sub_state_dict = {key: state_dict[key] for key in keys}
|
209 |
+
state_dicts.append(sub_state_dict)
|
210 |
+
return state_dicts
|
211 |
+
|
212 |
+
def hash_state_dict_keys(state_dict, with_shape=True):
|
213 |
+
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
214 |
+
keys_str = keys_str.encode(encoding="UTF-8")
|
215 |
+
return hashlib.md5(keys_str).hexdigest()
|
216 |
+
|
217 |
+
def split_state_dict_with_prefix(state_dict):
|
218 |
+
keys = sorted([key for key in state_dict if isinstance(key, str)])
|
219 |
+
prefix_dict = {}
|
220 |
+
for key in keys:
|
221 |
+
prefix = key if "." not in key else key.split(".")[0]
|
222 |
+
if prefix not in prefix_dict:
|
223 |
+
prefix_dict[prefix] = []
|
224 |
+
prefix_dict[prefix].append(key)
|
225 |
+
state_dicts = []
|
226 |
+
for prefix, keys in prefix_dict.items():
|
227 |
+
sub_state_dict = {key: state_dict[key] for key in keys}
|
228 |
+
state_dicts.append(sub_state_dict)
|
229 |
+
return state_dicts
|
230 |
+
|
231 |
+
def search_for_files(folder, extensions):
|
232 |
+
files = []
|
233 |
+
if os.path.isdir(folder):
|
234 |
+
for file in sorted(os.listdir(folder)):
|
235 |
+
files += search_for_files(os.path.join(folder, file), extensions)
|
236 |
+
elif os.path.isfile(folder):
|
237 |
+
for extension in extensions:
|
238 |
+
if folder.endswith(extension):
|
239 |
+
files.append(folder)
|
240 |
+
break
|
241 |
+
return files
|
242 |
+
|
243 |
+
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
244 |
+
keys = []
|
245 |
+
for key, value in state_dict.items():
|
246 |
+
if isinstance(key, str):
|
247 |
+
if isinstance(value, torch.Tensor):
|
248 |
+
if with_shape:
|
249 |
+
shape = "_".join(map(str, list(value.shape)))
|
250 |
+
keys.append(key + ":" + shape)
|
251 |
+
keys.append(key)
|
252 |
+
elif isinstance(value, dict):
|
253 |
+
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
254 |
+
keys.sort()
|
255 |
+
keys_str = ",".join(keys)
|
256 |
+
return keys_str
|
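`hash_state_dict_keys` fingerprints a checkpoint from its key names and tensor shapes only, so two files with identical structure hash the same regardless of weight values. A small check with made-up keys, assuming the repository root is on PYTHONPATH and its requirements (safetensors, soundfile, imageio, einops) are installed so `io_utils` imports cleanly:

import torch
from OmniAvatar.utils.io_utils import hash_state_dict_keys

sd_a = {"blocks.0.attn.q.weight": torch.zeros(8, 8), "blocks.0.attn.k.weight": torch.zeros(8, 8)}
sd_b = {k: torch.randn_like(v) for k, v in sd_a.items()}  # same keys/shapes, different values
assert hash_state_dict_keys(sd_a) == hash_state_dict_keys(sd_b)
print(hash_state_dict_keys(sd_a))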
OmniAvatar/vram_management/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
from .layers import *
|
OmniAvatar/vram_management/layers.py
ADDED
@@ -0,0 +1,95 @@
1 |
+
import torch, copy
|
2 |
+
from ..utils.io_utils import init_weights_on_device
|
3 |
+
|
4 |
+
|
5 |
+
def cast_to(weight, dtype, device):
|
6 |
+
r = torch.empty_like(weight, dtype=dtype, device=device)
|
7 |
+
r.copy_(weight)
|
8 |
+
return r
|
9 |
+
|
10 |
+
|
11 |
+
class AutoWrappedModule(torch.nn.Module):
|
12 |
+
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
13 |
+
super().__init__()
|
14 |
+
self.module = module.to(dtype=offload_dtype, device=offload_device)
|
15 |
+
self.offload_dtype = offload_dtype
|
16 |
+
self.offload_device = offload_device
|
17 |
+
self.onload_dtype = onload_dtype
|
18 |
+
self.onload_device = onload_device
|
19 |
+
self.computation_dtype = computation_dtype
|
20 |
+
self.computation_device = computation_device
|
21 |
+
self.state = 0
|
22 |
+
|
23 |
+
def offload(self):
|
24 |
+
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
25 |
+
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
|
26 |
+
self.state = 0
|
27 |
+
|
28 |
+
def onload(self):
|
29 |
+
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
30 |
+
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
|
31 |
+
self.state = 1
|
32 |
+
|
33 |
+
def forward(self, *args, **kwargs):
|
34 |
+
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
35 |
+
module = self.module
|
36 |
+
else:
|
37 |
+
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
|
38 |
+
return module(*args, **kwargs)
|
39 |
+
|
40 |
+
|
41 |
+
class AutoWrappedLinear(torch.nn.Linear):
|
42 |
+
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
43 |
+
with init_weights_on_device(device=torch.device("meta")):
|
44 |
+
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
|
45 |
+
self.weight = module.weight
|
46 |
+
self.bias = module.bias
|
47 |
+
self.offload_dtype = offload_dtype
|
48 |
+
self.offload_device = offload_device
|
49 |
+
self.onload_dtype = onload_dtype
|
50 |
+
self.onload_device = onload_device
|
51 |
+
self.computation_dtype = computation_dtype
|
52 |
+
self.computation_device = computation_device
|
53 |
+
self.state = 0
|
54 |
+
|
55 |
+
def offload(self):
|
56 |
+
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
57 |
+
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
58 |
+
self.state = 0
|
59 |
+
|
60 |
+
def onload(self):
|
61 |
+
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
62 |
+
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
63 |
+
self.state = 1
|
64 |
+
|
65 |
+
def forward(self, x, *args, **kwargs):
|
66 |
+
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
67 |
+
weight, bias = self.weight, self.bias
|
68 |
+
else:
|
69 |
+
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
70 |
+
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
71 |
+
return torch.nn.functional.linear(x, weight, bias)
|
72 |
+
|
73 |
+
|
74 |
+
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
|
75 |
+
for name, module in model.named_children():
|
76 |
+
for source_module, target_module in module_map.items():
|
77 |
+
if isinstance(module, source_module):
|
78 |
+
num_param = sum(p.numel() for p in module.parameters())
|
79 |
+
if max_num_param is not None and total_num_param + num_param > max_num_param:
|
80 |
+
module_config_ = overflow_module_config
|
81 |
+
else:
|
82 |
+
module_config_ = module_config
|
83 |
+
module_ = target_module(module, **module_config_)
|
84 |
+
setattr(model, name, module_)
|
85 |
+
total_num_param += num_param
|
86 |
+
break
|
87 |
+
else:
|
88 |
+
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
|
89 |
+
return total_num_param
|
90 |
+
|
91 |
+
|
92 |
+
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
|
93 |
+
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
|
94 |
+
model.vram_management_enabled = True
|
95 |
+
|
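`enable_vram_management` walks a module tree and swaps matched layers for wrapped versions whose weights live at an offload dtype/device and are only cast for computation. A minimal CPU-only sketch on a toy model (float32 everywhere so it runs without a GPU; assumes the package and its requirements are installed and importable as laid out above):

import torch
from OmniAvatar.vram_management.layers import enable_vram_management, AutoWrappedLinear, AutoWrappedModule

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.SiLU(), torch.nn.Linear(32, 4))
enable_vram_management(
    model,
    module_map={torch.nn.Linear: AutoWrappedLinear, torch.nn.SiLU: AutoWrappedModule},
    module_config=dict(
        offload_dtype=torch.float32, offload_device="cpu",
        onload_dtype=torch.float32, onload_device="cpu",
        computation_dtype=torch.float32, computation_device="cpu",
    ),
)
print(type(model[0]).__name__)           # AutoWrappedLinear
print(model(torch.randn(2, 16)).shape)   # torch.Size([2, 4])

With offload and computation settings identical, the wrappers are pass-throughs; in the pipeline above they differ, which is what keeps most weights on the CPU between forward calls.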
OmniAvatar/wan_video.py
ADDED
@@ -0,0 +1,340 @@
1 |
+
import types
|
2 |
+
from .models.model_manager import ModelManager
|
3 |
+
from .models.wan_video_dit import WanModel
|
4 |
+
from .models.wan_video_text_encoder import WanTextEncoder
|
5 |
+
from .models.wan_video_vae import WanVideoVAE
|
6 |
+
from .schedulers.flow_match import FlowMatchScheduler
|
7 |
+
from .base import BasePipeline
|
8 |
+
from .prompters import WanPrompter
|
9 |
+
import torch, os
|
10 |
+
from einops import rearrange
|
11 |
+
import numpy as np
|
12 |
+
from PIL import Image
|
13 |
+
from tqdm import tqdm
|
14 |
+
from typing import Optional
|
15 |
+
from .vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
16 |
+
from .models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
17 |
+
from .models.wan_video_dit import RMSNorm
|
18 |
+
from .models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
19 |
+
|
20 |
+
|
21 |
+
class WanVideoPipeline(BasePipeline):
|
22 |
+
|
23 |
+
def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
|
24 |
+
super().__init__(device=device, torch_dtype=torch_dtype)
|
25 |
+
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
26 |
+
self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
|
27 |
+
self.text_encoder: WanTextEncoder = None
|
28 |
+
self.image_encoder = None
|
29 |
+
self.dit: WanModel = None
|
30 |
+
self.vae: WanVideoVAE = None
|
31 |
+
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
|
32 |
+
self.height_division_factor = 16
|
33 |
+
self.width_division_factor = 16
|
34 |
+
self.use_unified_sequence_parallel = False
|
35 |
+
self.sp_size = 1
|
36 |
+
|
37 |
+
|
38 |
+
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
39 |
+
dtype = next(iter(self.text_encoder.parameters())).dtype
|
40 |
+
enable_vram_management(
|
41 |
+
self.text_encoder,
|
42 |
+
module_map = {
|
43 |
+
torch.nn.Linear: AutoWrappedLinear,
|
44 |
+
torch.nn.Embedding: AutoWrappedModule,
|
45 |
+
T5RelativeEmbedding: AutoWrappedModule,
|
46 |
+
T5LayerNorm: AutoWrappedModule,
|
47 |
+
},
|
48 |
+
module_config = dict(
|
49 |
+
offload_dtype=dtype,
|
50 |
+
offload_device="cpu",
|
51 |
+
onload_dtype=dtype,
|
52 |
+
onload_device="cpu",
|
53 |
+
computation_dtype=self.torch_dtype,
|
54 |
+
computation_device=self.device,
|
55 |
+
),
|
56 |
+
)
|
57 |
+
dtype = next(iter(self.dit.parameters())).dtype
|
58 |
+
enable_vram_management(
|
59 |
+
self.dit,
|
60 |
+
module_map = {
|
61 |
+
torch.nn.Linear: AutoWrappedLinear,
|
62 |
+
torch.nn.Conv3d: AutoWrappedModule,
|
63 |
+
torch.nn.LayerNorm: AutoWrappedModule,
|
64 |
+
RMSNorm: AutoWrappedModule,
|
65 |
+
},
|
66 |
+
module_config = dict(
|
67 |
+
offload_dtype=dtype,
|
68 |
+
offload_device="cpu",
|
69 |
+
onload_dtype=dtype,
|
70 |
+
onload_device=self.device,
|
71 |
+
computation_dtype=self.torch_dtype,
|
72 |
+
computation_device=self.device,
|
73 |
+
),
|
74 |
+
max_num_param=num_persistent_param_in_dit,
|
75 |
+
overflow_module_config = dict(
|
76 |
+
offload_dtype=dtype,
|
77 |
+
offload_device="cpu",
|
78 |
+
onload_dtype=dtype,
|
79 |
+
onload_device="cpu",
|
80 |
+
computation_dtype=self.torch_dtype,
|
81 |
+
computation_device=self.device,
|
82 |
+
),
|
83 |
+
)
|
84 |
+
dtype = next(iter(self.vae.parameters())).dtype
|
85 |
+
enable_vram_management(
|
86 |
+
self.vae,
|
87 |
+
module_map = {
|
88 |
+
torch.nn.Linear: AutoWrappedLinear,
|
89 |
+
torch.nn.Conv2d: AutoWrappedModule,
|
90 |
+
RMS_norm: AutoWrappedModule,
|
91 |
+
CausalConv3d: AutoWrappedModule,
|
92 |
+
Upsample: AutoWrappedModule,
|
93 |
+
torch.nn.SiLU: AutoWrappedModule,
|
94 |
+
torch.nn.Dropout: AutoWrappedModule,
|
95 |
+
},
|
96 |
+
module_config = dict(
|
97 |
+
offload_dtype=dtype,
|
98 |
+
offload_device="cpu",
|
99 |
+
onload_dtype=dtype,
|
100 |
+
onload_device=self.device,
|
101 |
+
computation_dtype=self.torch_dtype,
|
102 |
+
computation_device=self.device,
|
103 |
+
),
|
104 |
+
)
|
105 |
+
if self.image_encoder is not None:
|
106 |
+
dtype = next(iter(self.image_encoder.parameters())).dtype
|
107 |
+
enable_vram_management(
|
108 |
+
self.image_encoder,
|
109 |
+
module_map = {
|
110 |
+
torch.nn.Linear: AutoWrappedLinear,
|
111 |
+
torch.nn.Conv2d: AutoWrappedModule,
|
112 |
+
torch.nn.LayerNorm: AutoWrappedModule,
|
113 |
+
},
|
114 |
+
module_config = dict(
|
115 |
+
offload_dtype=dtype,
|
116 |
+
offload_device="cpu",
|
117 |
+
onload_dtype=dtype,
|
118 |
+
onload_device="cpu",
|
119 |
+
computation_dtype=dtype,
|
120 |
+
computation_device=self.device,
|
121 |
+
),
|
122 |
+
)
|
123 |
+
self.enable_cpu_offload()
|
124 |
+
|
125 |
+
|
126 |
+
def fetch_models(self, model_manager: ModelManager):
|
127 |
+
text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
|
128 |
+
if text_encoder_model_and_path is not None:
|
129 |
+
self.text_encoder, tokenizer_path = text_encoder_model_and_path
|
130 |
+
self.prompter.fetch_models(self.text_encoder)
|
131 |
+
self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
|
132 |
+
self.dit = model_manager.fetch_model("wan_video_dit")
|
133 |
+
self.vae = model_manager.fetch_model("wan_video_vae")
|
134 |
+
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
135 |
+
|
136 |
+
|
137 |
+
@staticmethod
|
138 |
+
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False, infer=False):
|
139 |
+
if device is None: device = model_manager.device
|
140 |
+
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
141 |
+
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
142 |
+
pipe.fetch_models(model_manager)
|
143 |
+
if use_usp:
|
144 |
+
from xfuser.core.distributed import get_sequence_parallel_world_size, get_sp_group
|
145 |
+
from OmniAvatar.distributed.xdit_context_parallel import usp_attn_forward
|
146 |
+
for block in pipe.dit.blocks:
|
147 |
+
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
148 |
+
pipe.sp_size = get_sequence_parallel_world_size()
|
149 |
+
pipe.use_unified_sequence_parallel = True
|
150 |
+
pipe.sp_group = get_sp_group()
|
151 |
+
return pipe
|
152 |
+
|
153 |
+
|
154 |
+
def denoising_model(self):
|
155 |
+
return self.dit
|
156 |
+
|
157 |
+
|
158 |
+
def encode_prompt(self, prompt, positive=True):
|
159 |
+
prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
|
160 |
+
return {"context": prompt_emb}
|
161 |
+
|
162 |
+
|
163 |
+
def encode_image(self, image, num_frames, height, width):
|
164 |
+
image = self.preprocess_image(image.resize((width, height))).to(self.device, dtype=self.torch_dtype)
|
165 |
+
clip_context = self.image_encoder.encode_image([image])
|
166 |
+
clip_context = clip_context.to(dtype=self.torch_dtype)
|
167 |
+
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device, dtype=self.torch_dtype)
|
168 |
+
msk[:, 1:] = 0
|
169 |
+
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
170 |
+
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
171 |
+
msk = msk.transpose(1, 2)[0]
|
172 |
+
|
173 |
+
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device, dtype=self.torch_dtype)], dim=1)
|
174 |
+
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
|
175 |
+
y = torch.concat([msk, y])
|
176 |
+
y = y.unsqueeze(0)
|
177 |
+
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
|
178 |
+
y = y.to(dtype=self.torch_dtype, device=self.device)
|
179 |
+
return {"clip_feature": clip_context, "y": y}
|
180 |
+
|
181 |
+
|
182 |
+
def tensor2video(self, frames):
|
183 |
+
frames = rearrange(frames, "C T H W -> T H W C")
|
184 |
+
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
185 |
+
frames = [Image.fromarray(frame) for frame in frames]
|
186 |
+
return frames
|
187 |
+
|
188 |
+
|
189 |
+
def prepare_extra_input(self, latents=None):
|
190 |
+
return {}
|
191 |
+
|
192 |
+
|
193 |
+
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
194 |
+
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
195 |
+
return latents
|
196 |
+
|
197 |
+
|
198 |
+
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
199 |
+
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
200 |
+
return frames
|
201 |
+
|
202 |
+
|
203 |
+
def prepare_unified_sequence_parallel(self):
|
204 |
+
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
205 |
+
|
206 |
+
|
207 |
+
@torch.no_grad()
|
208 |
+
def log_video(
|
209 |
+
self,
|
210 |
+
lat,
|
211 |
+
prompt,
|
212 |
+
fixed_frame=0, # lat frames
|
213 |
+
image_emb={},
|
214 |
+
audio_emb={},
|
215 |
+
negative_prompt="",
|
216 |
+
cfg_scale=5.0,
|
217 |
+
audio_cfg_scale=5.0,
|
218 |
+
num_inference_steps=50,
|
219 |
+
denoising_strength=1.0,
|
220 |
+
sigma_shift=5.0,
|
221 |
+
tiled=True,
|
222 |
+
tile_size=(30, 52),
|
223 |
+
tile_stride=(15, 26),
|
224 |
+
tea_cache_l1_thresh=None,
|
225 |
+
tea_cache_model_id="",
|
226 |
+
progress_bar_cmd=tqdm,
|
227 |
+
return_latent=False,
|
228 |
+
):
|
229 |
+
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
230 |
+
# Scheduler
|
231 |
+
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
232 |
+
|
233 |
+
lat = lat.to(dtype=self.torch_dtype)
|
234 |
+
latents = lat.clone()
|
235 |
+
latents = torch.randn_like(latents, dtype=self.torch_dtype)
|
236 |
+
|
237 |
+
# Encode prompts
|
238 |
+
self.load_models_to_device(["text_encoder"])
|
239 |
+
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
240 |
+
if cfg_scale != 1.0:
|
241 |
+
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
242 |
+
|
243 |
+
# Extra input
|
244 |
+
extra_input = self.prepare_extra_input(latents)
|
245 |
+
|
246 |
+
# TeaCache
|
247 |
+
tea_cache_posi = {"tea_cache": None}
|
248 |
+
tea_cache_nega = {"tea_cache": None}
|
249 |
+
|
250 |
+
# Denoise
|
251 |
+
self.load_models_to_device(["dit"])
|
252 |
+
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
253 |
+
if fixed_frame > 0: # new
|
254 |
+
latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
|
255 |
+
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
256 |
+
|
257 |
+
# Inference
|
258 |
+
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **audio_emb, **tea_cache_posi, **extra_input)
|
259 |
+
print(f'noise_pred_posi:{noise_pred_posi.dtype}')
|
260 |
+
if cfg_scale != 1.0:
|
261 |
+
audio_emb_uc = {}
|
262 |
+
for key in audio_emb.keys():
|
263 |
+
audio_emb_uc[key] = torch.zeros_like(audio_emb[key], dtype=self.torch_dtype)
|
264 |
+
if audio_cfg_scale == cfg_scale:
|
265 |
+
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **audio_emb_uc, **tea_cache_nega, **extra_input)
|
266 |
+
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
267 |
+
else:
|
268 |
+
tea_cache_nega_audio = {"tea_cache": None}
|
269 |
+
audio_noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **audio_emb_uc, **tea_cache_nega_audio, **extra_input)
|
270 |
+
text_noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **audio_emb_uc, **tea_cache_nega, **extra_input)
|
271 |
+
noise_pred = text_noise_pred_nega + cfg_scale * (audio_noise_pred_nega - text_noise_pred_nega) + audio_cfg_scale * (noise_pred_posi - audio_noise_pred_nega)
|
272 |
+
else:
|
273 |
+
noise_pred = noise_pred_posi
|
274 |
+
# Scheduler
|
275 |
+
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
276 |
+
|
277 |
+
if fixed_frame > 0: # new
|
278 |
+
latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
|
279 |
+
# Decode
|
280 |
+
self.load_models_to_device(['vae'])
|
281 |
+
frames = self.decode_video(latents, **tiler_kwargs)
|
282 |
+
recons = self.decode_video(lat, **tiler_kwargs)
|
283 |
+
self.load_models_to_device([])
|
284 |
+
frames = (frames.permute(0, 2, 1, 3, 4).float() + 1.0) / 2.0
|
285 |
+
recons = (recons.permute(0, 2, 1, 3, 4).float() + 1.0) / 2.0
|
286 |
+
if return_latent:
|
287 |
+
return frames, recons, latents
|
288 |
+
return frames, recons
|
289 |
+
|
290 |
+
|
291 |
+
class TeaCache:
|
292 |
+
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
293 |
+
self.num_inference_steps = num_inference_steps
|
294 |
+
self.step = 0
|
295 |
+
self.accumulated_rel_l1_distance = 0
|
296 |
+
self.previous_modulated_input = None
|
297 |
+
self.rel_l1_thresh = rel_l1_thresh
|
298 |
+
self.previous_residual = None
|
299 |
+
self.previous_hidden_states = None
|
300 |
+
|
301 |
+
self.coefficients_dict = {
|
302 |
+
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
|
303 |
+
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
|
304 |
+
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
|
305 |
+
"Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
|
306 |
+
}
|
307 |
+
if model_id not in self.coefficients_dict:
|
308 |
+
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
309 |
+
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
|
310 |
+
self.coefficients = self.coefficients_dict[model_id]
|
311 |
+
|
312 |
+
def check(self, dit: WanModel, x, t_mod):
|
313 |
+
modulated_inp = t_mod.clone()
|
314 |
+
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
315 |
+
should_calc = True
|
316 |
+
self.accumulated_rel_l1_distance = 0
|
317 |
+
else:
|
318 |
+
coefficients = self.coefficients
|
319 |
+
rescale_func = np.poly1d(coefficients)
|
320 |
+
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
321 |
+
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
322 |
+
should_calc = False
|
323 |
+
else:
|
324 |
+
should_calc = True
|
325 |
+
self.accumulated_rel_l1_distance = 0
|
326 |
+
self.previous_modulated_input = modulated_inp
|
327 |
+
self.step += 1
|
328 |
+
if self.step == self.num_inference_steps:
|
329 |
+
self.step = 0
|
330 |
+
if should_calc:
|
331 |
+
self.previous_hidden_states = x.clone()
|
332 |
+
return not should_calc
|
333 |
+
|
334 |
+
def store(self, hidden_states):
|
335 |
+
self.previous_residual = hidden_states - self.previous_hidden_states
|
336 |
+
self.previous_hidden_states = None
|
337 |
+
|
338 |
+
def update(self, hidden_states):
|
339 |
+
hidden_states = hidden_states + self.previous_residual
|
340 |
+
return hidden_states
|
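Aside (not part of the uploaded file): when `audio_cfg_scale` differs from `cfg_scale`, `log_video` above runs three DiT passes per denoising step and combines them with a two-branch classifier-free guidance rule, guiding text and audio separately; when the two scales are equal it falls back to a single extra pass. A minimal standalone sketch of that arithmetic, with illustrative tensor names, is:

import torch

def combine_guidance(pred_full, pred_no_audio, pred_uncond, cfg_scale, audio_cfg_scale):
    # pred_full     : DiT output with text prompt + audio embedding
    # pred_no_audio : DiT output with text prompt, audio zeroed
    # pred_uncond   : DiT output with negative prompt, audio zeroed
    if audio_cfg_scale == cfg_scale:
        # single-branch CFG: only one extra forward pass per step
        return pred_uncond + cfg_scale * (pred_full - pred_uncond)
    # two-branch CFG: text and audio guided independently
    return (pred_uncond
            + cfg_scale * (pred_no_audio - pred_uncond)
            + audio_cfg_scale * (pred_full - pred_no_audio))

# toy shapes only; real latents are (B, C, T, H, W)
preds = [torch.randn(1, 4) for _ in range(3)]
print(combine_guidance(*preds, cfg_scale=4.5, audio_cfg_scale=4.5).shape)

Using equal scales saves one DiT forward pass per step, which is why the Space defaults to reusing `guidance_scale` when no separate `audio_scale` is set.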
README.md
CHANGED
@@ -1,13 +1,12 @@
----
-title: OmniAvatar
-emoji:
-colorFrom: gray
-colorTo: yellow
-sdk: gradio
-sdk_version: 5.
-app_file: app.py
-pinned: false
-
-
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
+title: OmniAvatar
+emoji: 🏆🏆🏆🏆🏆🏆
+colorFrom: gray
+colorTo: yellow
+sdk: gradio
+sdk_version: 5.36.2
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,688 @@
1 |
+
import spaces
|
2 |
+
import subprocess
|
3 |
+
import gradio as gr
|
4 |
+
|
5 |
+
import os, sys
|
6 |
+
from glob import glob
|
7 |
+
from datetime import datetime
|
8 |
+
import math
|
9 |
+
import random
|
10 |
+
import librosa
|
11 |
+
import numpy as np
|
12 |
+
import uuid
|
13 |
+
import shutil
|
14 |
+
|
15 |
+
import importlib, site, sys
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
print(f'torch version:{torch.__version__}')
|
20 |
+
|
21 |
+
|
22 |
+
import torch.nn as nn
|
23 |
+
from tqdm import tqdm
|
24 |
+
from functools import partial
|
25 |
+
from omegaconf import OmegaConf
|
26 |
+
from argparse import Namespace
|
27 |
+
|
28 |
+
# load the one true config you dumped
|
29 |
+
_args_cfg = OmegaConf.load("args_config.yaml")
|
30 |
+
args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
|
31 |
+
|
32 |
+
from OmniAvatar.utils.args_config import set_global_args
|
33 |
+
|
34 |
+
set_global_args(args)
|
35 |
+
# args = parse_args()
|
36 |
+
|
37 |
+
from OmniAvatar.utils.io_utils import load_state_dict
|
38 |
+
from peft import LoraConfig, inject_adapter_in_model
|
39 |
+
from OmniAvatar.models.model_manager import ModelManager
|
40 |
+
from OmniAvatar.schedulers.flow_match import FlowMatchScheduler
|
41 |
+
from OmniAvatar.wan_video import WanVideoPipeline
|
42 |
+
from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
|
43 |
+
import torchvision.transforms as TT
|
44 |
+
from transformers import Wav2Vec2FeatureExtractor
|
45 |
+
import torchvision.transforms as transforms
|
46 |
+
import torch.nn.functional as F
|
47 |
+
from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
|
48 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
49 |
+
|
50 |
+
os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"
|
51 |
+
|
52 |
+
def tensor_to_pil(tensor):
|
53 |
+
"""
|
54 |
+
Args:
|
55 |
+
tensor: torch.Tensor with shape like
|
56 |
+
(1, C, H, W), (1, C, 1, H, W), (C, H, W), etc.
|
57 |
+
values in [-1, 1], on any device.
|
58 |
+
Returns:
|
59 |
+
A PIL.Image in RGB mode.
|
60 |
+
"""
|
61 |
+
# 1) Remove batch dim if it exists
|
62 |
+
if tensor.dim() > 3 and tensor.shape[0] == 1:
|
63 |
+
tensor = tensor[0]
|
64 |
+
|
65 |
+
# 2) Squeeze out any other singleton dims (e.g. that extra frame axis)
|
66 |
+
tensor = tensor.squeeze()
|
67 |
+
|
68 |
+
# Now we should have exactly 3 dims: (C, H, W)
|
69 |
+
if tensor.dim() != 3:
|
70 |
+
raise ValueError(f"Expected 3 dims after squeeze, got {tensor.dim()}")
|
71 |
+
|
72 |
+
# 3) Move to CPU float32
|
73 |
+
tensor = tensor.cpu().float()
|
74 |
+
|
75 |
+
# 4) Undo normalization from [-1,1] -> [0,1]
|
76 |
+
tensor = (tensor + 1.0) / 2.0
|
77 |
+
|
78 |
+
# 5) Clamp to [0,1]
|
79 |
+
tensor = torch.clamp(tensor, 0.0, 1.0)
|
80 |
+
|
81 |
+
# 6) To NumPy H×W×C in [0,255]
|
82 |
+
np_img = (tensor.permute(1, 2, 0).numpy() * 255.0).round().astype("uint8")
|
83 |
+
|
84 |
+
# 7) Build PIL Image
|
85 |
+
return Image.fromarray(np_img)
|
86 |
+
|
87 |
+
|
88 |
+
def set_seed(seed: int = 42):
|
89 |
+
random.seed(seed)
|
90 |
+
np.random.seed(seed)
|
91 |
+
torch.manual_seed(seed)
|
92 |
+
torch.cuda.manual_seed(seed)  # seed the current GPU
|
93 |
+
torch.cuda.manual_seed_all(seed)  # seed all GPUs
|
94 |
+
|
95 |
+
def read_from_file(p):
|
96 |
+
with open(p, "r") as fin:
|
97 |
+
for l in fin:
|
98 |
+
yield l.strip()
|
99 |
+
|
100 |
+
def match_size(image_size, h, w):
|
101 |
+
ratio_ = 9999
|
102 |
+
size_ = 9999
|
103 |
+
select_size = None
|
104 |
+
for image_s in image_size:
|
105 |
+
ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
|
106 |
+
size_tmp = abs(max(image_s) - max(w, h))
|
107 |
+
if ratio_tmp < ratio_:
|
108 |
+
ratio_ = ratio_tmp
|
109 |
+
size_ = size_tmp
|
110 |
+
select_size = image_s
|
111 |
+
if ratio_ == ratio_tmp:
|
112 |
+
if size_ == size_tmp:
|
113 |
+
select_size = image_s
|
114 |
+
return select_size
|
115 |
+
|
116 |
+
def resize_pad(image, ori_size, tgt_size):
|
117 |
+
h, w = ori_size
|
118 |
+
scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
|
119 |
+
scale_h = int(h * scale_ratio)
|
120 |
+
scale_w = int(w * scale_ratio)
|
121 |
+
|
122 |
+
image = transforms.Resize(size=[scale_h, scale_w])(image)
|
123 |
+
|
124 |
+
padding_h = tgt_size[0] - scale_h
|
125 |
+
padding_w = tgt_size[1] - scale_w
|
126 |
+
pad_top = padding_h // 2
|
127 |
+
pad_bottom = padding_h - pad_top
|
128 |
+
pad_left = padding_w // 2
|
129 |
+
pad_right = padding_w - pad_left
|
130 |
+
|
131 |
+
image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
|
132 |
+
return image
|
133 |
+
|
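Aside (not part of app.py): `match_size` above picks, from the configured resolution buckets, the (H, W) pair whose aspect ratio is closest to the input image's; `resize_pad` then scales by the larger of the two target/source ratios so the image covers the target, and centers it with F.pad (since the scaled image is at least as large as the target, the computed paddings are non-positive, and F.pad's negative-padding behaviour effectively center-crops). A standalone re-statement of the bucket selection, ignoring the size tie-break in `match_size`, using the `image_sizes_720` buckets from the shipped config:

def pick_bucket(buckets, h, w):
    # closest aspect ratio wins (illustrative; match_size above additionally tie-breaks on absolute size)
    return min(buckets, key=lambda s: abs(s[0] / s[1] - h / w))

buckets_720 = [[400, 720], [720, 720], [720, 400]]
print(pick_bucket(buckets_720, h=768, w=1024))   # [400, 720] for a landscape input
print(pick_bucket(buckets_720, h=1280, w=720))   # [720, 400] for a portrait input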
134 |
+
class WanInferencePipeline(nn.Module):
|
135 |
+
def __init__(self, args):
|
136 |
+
super().__init__()
|
137 |
+
self.args = args
|
138 |
+
self.device = torch.device(f"cuda")
|
139 |
+
self.dtype = torch.bfloat16
|
140 |
+
self.pipe = self.load_model()
|
141 |
+
chained_trainsforms = []
|
142 |
+
chained_trainsforms.append(TT.ToTensor())
|
143 |
+
self.transform = TT.Compose(chained_trainsforms)
|
144 |
+
|
145 |
+
if self.args.use_audio:
|
146 |
+
from OmniAvatar.models.wav2vec import Wav2VecModel
|
147 |
+
self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
148 |
+
self.args.wav2vec_path
|
149 |
+
)
|
150 |
+
self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device, dtype=self.dtype)
|
151 |
+
self.audio_encoder.feature_extractor._freeze_parameters()
|
152 |
+
|
153 |
+
|
154 |
+
def load_model(self):
|
155 |
+
ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
|
156 |
+
assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
|
157 |
+
if self.args.train_architecture == 'lora':
|
158 |
+
self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
|
159 |
+
else:
|
160 |
+
resume_path = ckpt_path
|
161 |
+
|
162 |
+
self.step = 0
|
163 |
+
|
164 |
+
# Load models
|
165 |
+
model_manager = ModelManager(device="cuda", infer=True)
|
166 |
+
|
167 |
+
model_manager.load_models(
|
168 |
+
[
|
169 |
+
self.args.dit_path.split(","),
|
170 |
+
self.args.vae_path,
|
171 |
+
self.args.text_encoder_path
|
172 |
+
],
|
173 |
+
torch_dtype=self.dtype,
|
174 |
+
device='cuda',
|
175 |
+
)
|
176 |
+
|
177 |
+
pipe = WanVideoPipeline.from_model_manager(model_manager,
|
178 |
+
torch_dtype=self.dtype,
|
179 |
+
device="cuda",
|
180 |
+
use_usp=False,
|
181 |
+
infer=True)
|
182 |
+
|
183 |
+
if self.args.train_architecture == "lora":
|
184 |
+
print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
|
185 |
+
self.add_lora_to_model(
|
186 |
+
pipe.denoising_model(),
|
187 |
+
lora_rank=self.args.lora_rank,
|
188 |
+
lora_alpha=self.args.lora_alpha,
|
189 |
+
lora_target_modules=self.args.lora_target_modules,
|
190 |
+
init_lora_weights=self.args.init_lora_weights,
|
191 |
+
pretrained_lora_path=pretrained_lora_path,
|
192 |
+
)
|
193 |
+
print(next(pipe.denoising_model().parameters()).device)
|
194 |
+
else:
|
195 |
+
missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
|
196 |
+
print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
|
197 |
+
pipe.requires_grad_(False)
|
198 |
+
pipe.eval()
|
199 |
+
# pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
|
200 |
+
return pipe
|
201 |
+
|
202 |
+
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
|
203 |
+
# Add LoRA to UNet
|
204 |
+
|
205 |
+
self.lora_alpha = lora_alpha
|
206 |
+
if init_lora_weights == "kaiming":
|
207 |
+
init_lora_weights = True
|
208 |
+
|
209 |
+
lora_config = LoraConfig(
|
210 |
+
r=lora_rank,
|
211 |
+
lora_alpha=lora_alpha,
|
212 |
+
init_lora_weights=init_lora_weights,
|
213 |
+
target_modules=lora_target_modules.split(","),
|
214 |
+
)
|
215 |
+
model = inject_adapter_in_model(lora_config, model)
|
216 |
+
|
217 |
+
# Lora pretrained lora weights
|
218 |
+
if pretrained_lora_path is not None:
|
219 |
+
state_dict = load_state_dict(pretrained_lora_path, torch_dtype=self.dtype)
|
220 |
+
if state_dict_converter is not None:
|
221 |
+
state_dict = state_dict_converter(state_dict)
|
222 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
223 |
+
all_keys = [i for i, _ in model.named_parameters()]
|
224 |
+
num_updated_keys = len(all_keys) - len(missing_keys)
|
225 |
+
num_unexpected_keys = len(unexpected_keys)
|
226 |
+
|
227 |
+
print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
|
228 |
+
|
229 |
+
def get_times(self, prompt,
|
230 |
+
image_path=None,
|
231 |
+
audio_path=None,
|
232 |
+
seq_len=101, # not used while audio_path is not None
|
233 |
+
height=720,
|
234 |
+
width=720,
|
235 |
+
overlap_frame=None,
|
236 |
+
num_steps=None,
|
237 |
+
negative_prompt=None,
|
238 |
+
guidance_scale=None,
|
239 |
+
audio_scale=None):
|
240 |
+
|
241 |
+
overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
|
242 |
+
num_steps = num_steps if num_steps is not None else self.args.num_steps
|
243 |
+
negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
|
244 |
+
guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
|
245 |
+
audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
|
246 |
+
|
247 |
+
if image_path is not None:
|
248 |
+
from PIL import Image
|
249 |
+
image = Image.open(image_path).convert("RGB")
|
250 |
+
|
251 |
+
image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
|
252 |
+
|
253 |
+
_, _, h, w = image.shape
|
254 |
+
select_size = match_size(getattr( self.args, f'image_sizes_{ self.args.max_hw}'), h, w)
|
255 |
+
image = resize_pad(image, (h, w), select_size)
|
256 |
+
image = image * 2.0 - 1.0
|
257 |
+
image = image[:, :, None]
|
258 |
+
|
259 |
+
else:
|
260 |
+
image = None
|
261 |
+
select_size = [height, width]
|
262 |
+
num = self.args.max_tokens * 16 * 16 * 4
|
263 |
+
den = select_size[0] * select_size[1]
|
264 |
+
L0 = num // den
|
265 |
+
diff = (L0 - 1) % 4
|
266 |
+
L = L0 - diff
|
267 |
+
if L < 1:
|
268 |
+
L = 1
|
269 |
+
T = (L + 3) // 4
|
270 |
+
|
271 |
+
|
272 |
+
if self.args.random_prefix_frames:
|
273 |
+
fixed_frame = overlap_frame
|
274 |
+
assert fixed_frame % 4 == 1
|
275 |
+
else:
|
276 |
+
fixed_frame = 1
|
277 |
+
prefix_lat_frame = (3 + fixed_frame) // 4
|
278 |
+
first_fixed_frame = 1
|
279 |
+
|
280 |
+
|
281 |
+
audio, sr = librosa.load(audio_path, sr= self.args.sample_rate)
|
282 |
+
|
283 |
+
input_values = np.squeeze(
|
284 |
+
self.wav_feature_extractor(audio, sampling_rate=16000).input_values
|
285 |
+
)
|
286 |
+
input_values = torch.from_numpy(input_values).float().to(dtype=self.dtype)
|
287 |
+
audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
|
288 |
+
|
289 |
+
if audio_len < L - first_fixed_frame:
|
290 |
+
audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
|
291 |
+
elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
|
292 |
+
audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
|
293 |
+
|
294 |
+
seq_len = audio_len
|
295 |
+
|
296 |
+
times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
|
297 |
+
if times * (L-fixed_frame) + fixed_frame < seq_len:
|
298 |
+
times += 1
|
299 |
+
|
300 |
+
return times
|
301 |
+
|
302 |
+
@torch.no_grad()
|
303 |
+
def forward(self, prompt,
|
304 |
+
image_path=None,
|
305 |
+
audio_path=None,
|
306 |
+
seq_len=101, # not used while audio_path is not None
|
307 |
+
height=720,
|
308 |
+
width=720,
|
309 |
+
overlap_frame=None,
|
310 |
+
num_steps=None,
|
311 |
+
negative_prompt=None,
|
312 |
+
guidance_scale=None,
|
313 |
+
audio_scale=None):
|
314 |
+
overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
|
315 |
+
num_steps = num_steps if num_steps is not None else self.args.num_steps
|
316 |
+
negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
|
317 |
+
guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
|
318 |
+
audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
|
319 |
+
|
320 |
+
if image_path is not None:
|
321 |
+
from PIL import Image
|
322 |
+
image = Image.open(image_path).convert("RGB")
|
323 |
+
|
324 |
+
image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
|
325 |
+
|
326 |
+
_, _, h, w = image.shape
|
327 |
+
select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
|
328 |
+
image = resize_pad(image, (h, w), select_size)
|
329 |
+
image = image * 2.0 - 1.0
|
330 |
+
image = image[:, :, None]
|
331 |
+
|
332 |
+
else:
|
333 |
+
image = None
|
334 |
+
select_size = [height, width]
|
335 |
+
# L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
|
336 |
+
# L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
|
337 |
+
# T = (L + 3) // 4 # latent frames
|
338 |
+
|
339 |
+
# step 1: numerator and denominator as ints
|
340 |
+
num = args.max_tokens * 16 * 16 * 4
|
341 |
+
den = select_size[0] * select_size[1]
|
342 |
+
|
343 |
+
# step 2: integer division
|
344 |
+
L0 = num // den # exact floor division, no float in sight
|
345 |
+
|
346 |
+
# step 3: make it ≡ 1 mod 4
|
347 |
+
# if L0 % 4 == 1, keep L0;
|
348 |
+
# otherwise subtract the difference so that (L0 - diff) % 4 == 1,
|
349 |
+
# but ensure the result stays positive.
|
350 |
+
diff = (L0 - 1) % 4
|
351 |
+
L = L0 - diff
|
352 |
+
if L < 1:
|
353 |
+
L = 1 # or whatever your minimal frame count is
|
354 |
+
|
355 |
+
# step 4: latent frames
|
356 |
+
T = (L + 3) // 4
|
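For concreteness (an aside, not part of app.py): with the values shipped in args_config.yaml (max_tokens: 40000) and the 400x720 bucket, the step 1-4 arithmetic above works out as follows; the numbers are only an illustration.

max_tokens, H, W = 40000, 400, 720
L0 = (max_tokens * 16 * 16 * 4) // (H * W)   # 40_960_000 // 288_000 = 142
L = max(L0 - (L0 - 1) % 4, 1)                # 141, so L % 4 == 1
T = (L + 3) // 4                             # 36 latent frames
print(L, T)                                  # 141 36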
357 |
+
|
358 |
+
|
359 |
+
if self.args.i2v:
|
360 |
+
if self.args.random_prefix_frames:
|
361 |
+
fixed_frame = overlap_frame
|
362 |
+
assert fixed_frame % 4 == 1
|
363 |
+
else:
|
364 |
+
fixed_frame = 1
|
365 |
+
prefix_lat_frame = (3 + fixed_frame) // 4
|
366 |
+
first_fixed_frame = 1
|
367 |
+
else:
|
368 |
+
fixed_frame = 0
|
369 |
+
prefix_lat_frame = 0
|
370 |
+
first_fixed_frame = 0
|
371 |
+
|
372 |
+
|
373 |
+
if audio_path is not None and self.args.use_audio:
|
374 |
+
audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
|
375 |
+
input_values = np.squeeze(
|
376 |
+
self.wav_feature_extractor(audio, sampling_rate=16000).input_values
|
377 |
+
)
|
378 |
+
input_values = torch.from_numpy(input_values).float().to(device=self.device, dtype=self.dtype)
|
379 |
+
ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
|
380 |
+
input_values = input_values.unsqueeze(0)
|
381 |
+
# padding audio
|
382 |
+
if audio_len < L - first_fixed_frame:
|
383 |
+
audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
|
384 |
+
elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
|
385 |
+
audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
|
386 |
+
input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
|
387 |
+
with torch.no_grad():
|
388 |
+
hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
|
389 |
+
audio_embeddings = hidden_states.last_hidden_state
|
390 |
+
for mid_hidden_states in hidden_states.hidden_states:
|
391 |
+
audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
|
392 |
+
seq_len = audio_len
|
393 |
+
audio_embeddings = audio_embeddings.squeeze(0)
|
394 |
+
audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
|
395 |
+
else:
|
396 |
+
audio_embeddings = None
|
397 |
+
|
398 |
+
# loop
|
399 |
+
times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
|
400 |
+
if times * (L-fixed_frame) + fixed_frame < seq_len:
|
401 |
+
times += 1
|
402 |
+
video = []
|
403 |
+
image_emb = {}
|
404 |
+
img_lat = None
|
405 |
+
if self.args.i2v:
|
406 |
+
self.pipe.load_models_to_device(['vae'])
|
407 |
+
img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
|
408 |
+
|
409 |
+
msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:,:1], dtype=self.dtype)
|
410 |
+
image_cat = img_lat.repeat(1, 1, T, 1, 1)
|
411 |
+
msk[:, :, 1:] = 1
|
412 |
+
image_emb["y"] = torch.cat([image_cat, msk], dim=1)
|
413 |
+
|
414 |
+
for t in range(times):
|
415 |
+
print(f"[{t+1}/{times}]")
|
416 |
+
audio_emb = {}
|
417 |
+
if t == 0:
|
418 |
+
overlap = first_fixed_frame
|
419 |
+
else:
|
420 |
+
overlap = fixed_frame
|
421 |
+
image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # 第一次推理是mask只有1,往后都是mask overlap
|
422 |
+
prefix_overlap = (3 + overlap) // 4
|
423 |
+
if audio_embeddings is not None:
|
424 |
+
if t == 0:
|
425 |
+
audio_tensor = audio_embeddings[
|
426 |
+
:min(L - overlap, audio_embeddings.shape[0])
|
427 |
+
]
|
428 |
+
else:
|
429 |
+
audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
|
430 |
+
audio_tensor = audio_embeddings[
|
431 |
+
audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
|
432 |
+
]
|
433 |
+
|
434 |
+
audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
|
435 |
+
audio_prefix = audio_tensor[-fixed_frame:]
|
436 |
+
audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
|
437 |
+
audio_emb["audio_emb"] = audio_tensor
|
438 |
+
else:
|
439 |
+
audio_prefix = None
|
440 |
+
if image is not None and img_lat is None:
|
441 |
+
self.pipe.load_models_to_device(['vae'])
|
442 |
+
img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
|
443 |
+
assert img_lat.shape[2] == prefix_overlap
|
444 |
+
img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
|
445 |
+
frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
|
446 |
+
negative_prompt, num_inference_steps=num_steps,
|
447 |
+
cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
|
448 |
+
return_latent=True,
|
449 |
+
tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B")
|
450 |
+
|
451 |
+
torch.cuda.empty_cache()
|
452 |
+
img_lat = None
|
453 |
+
image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
|
454 |
+
|
455 |
+
if t == 0:
|
456 |
+
video.append(frames)
|
457 |
+
else:
|
458 |
+
video.append(frames[:, overlap:])
|
459 |
+
video = torch.cat(video, dim=1)
|
460 |
+
video = video[:, :ori_audio_len + 1]
|
461 |
+
|
462 |
+
return video
|
463 |
+
|
464 |
+
|
465 |
+
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="./pretrained_models/Wan2.1-T2V-14B")
|
466 |
+
snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
|
467 |
+
snapshot_download(repo_id="OmniAvatar/OmniAvatar-14B", local_dir="./pretrained_models/OmniAvatar-14B")
|
468 |
+
|
469 |
+
|
470 |
+
# snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./pretrained_models/Wan2.1-T2V-1.3B")
|
471 |
+
# snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
|
472 |
+
# snapshot_download(repo_id="OmniAvatar/OmniAvatar-1.3B", local_dir="./pretrained_models/OmniAvatar-1.3B")
|
473 |
+
|
474 |
+
import tempfile
|
475 |
+
|
476 |
+
from PIL import Image
|
477 |
+
|
478 |
+
|
479 |
+
set_seed(args.seed)
|
480 |
+
seq_len = args.seq_len
|
481 |
+
inferpipe = WanInferencePipeline(args)
|
482 |
+
|
483 |
+
|
484 |
+
def update_generate_button(image_path, audio_path, text, num_steps):
|
485 |
+
|
486 |
+
if image_path is None or audio_path is None:
|
487 |
+
return gr.update(value="⌚ Zero GPU Required: --")
|
488 |
+
|
489 |
+
duration_s = get_duration(image_path, audio_path, text, num_steps, None, None)
|
490 |
+
|
491 |
+
return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s")
|
492 |
+
|
493 |
+
def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
|
494 |
+
|
495 |
+
audio_chunks = inferpipe.get_times(
|
496 |
+
prompt=text,
|
497 |
+
image_path=image_path,
|
498 |
+
audio_path=audio_path,
|
499 |
+
seq_len=args.seq_len,
|
500 |
+
num_steps=num_steps
|
501 |
+
)
|
502 |
+
|
503 |
+
warmup_s = 30
|
504 |
+
duration_s = (20 * num_steps) + warmup_s
|
505 |
+
|
506 |
+
if audio_chunks > 1:
|
507 |
+
duration_s = (20 * num_steps * audio_chunks) + warmup_s
|
508 |
+
|
509 |
+
print(f'{audio_chunks} audio chunk(s), estimated {duration_s}s')
|
510 |
+
|
511 |
+
return int(duration_s)
|
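Aside (not part of app.py): `get_duration` above feeds the `@spaces.GPU(duration=...)` hint, assuming roughly 20 s per denoising step per audio chunk plus about 30 s of warmup; those constants come from the code above and are rough estimates, not measurements. A quick re-statement:

def estimated_seconds(num_steps, audio_chunks, per_step_s=20, warmup_s=30):
    # one chunk: per_step_s * num_steps + warmup; more chunks scale the step term
    return per_step_s * num_steps * max(audio_chunks, 1) + warmup_s

print(estimated_seconds(num_steps=8, audio_chunks=1))  # 190
print(estimated_seconds(num_steps=8, audio_chunks=2))  # 350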
512 |
+
|
513 |
+
def preprocess_img(image_path, session_id = None):
|
514 |
+
|
515 |
+
if session_id is None:
|
516 |
+
session_id = uuid.uuid4().hex
|
517 |
+
|
518 |
+
image = Image.open(image_path).convert("RGB")
|
519 |
+
|
520 |
+
image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
|
521 |
+
|
522 |
+
_, _, h, w = image.shape
|
523 |
+
select_size = match_size(getattr( args, f'image_sizes_{ args.max_hw}'), h, w)
|
524 |
+
image = resize_pad(image, (h, w), select_size)
|
525 |
+
image = image * 2.0 - 1.0
|
526 |
+
image = image[:, :, None]
|
527 |
+
|
528 |
+
output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
|
529 |
+
|
530 |
+
img_dir = output_dir + '/image'
|
531 |
+
os.makedirs(img_dir, exist_ok=True)
|
532 |
+
input_img_path = os.path.join(img_dir, f"img_input.jpg")
|
533 |
+
|
534 |
+
image = tensor_to_pil(image)
|
535 |
+
image.save(input_img_path)
|
536 |
+
|
537 |
+
return input_img_path
|
538 |
+
|
539 |
+
|
540 |
+
@spaces.GPU(duration=get_duration)
|
541 |
+
def infer(image_path, audio_path, text, num_steps, session_id = None, progress=gr.Progress(track_tqdm=True),):
|
542 |
+
|
543 |
+
if session_id is None:
|
544 |
+
session_id = uuid.uuid4().hex
|
545 |
+
|
546 |
+
output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
|
547 |
+
|
548 |
+
audio_dir = output_dir + '/audio'
|
549 |
+
os.makedirs(audio_dir, exist_ok=True)
|
550 |
+
if args.silence_duration_s > 0:
|
551 |
+
input_audio_path = os.path.join(audio_dir, f"audio_input.wav")
|
552 |
+
else:
|
553 |
+
input_audio_path = audio_path
|
554 |
+
prompt_dir = output_dir + '/prompt'
|
555 |
+
os.makedirs(prompt_dir, exist_ok=True)
|
556 |
+
|
557 |
+
if args.silence_duration_s > 0:
|
558 |
+
add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)
|
559 |
+
|
560 |
+
tmp2_audio_path = os.path.join(audio_dir, f"audio_out.wav")
|
561 |
+
prompt_path = os.path.join(prompt_dir, f"prompt.txt")
|
562 |
+
|
563 |
+
video = inferpipe(
|
564 |
+
prompt=text,
|
565 |
+
image_path=image_path,
|
566 |
+
audio_path=input_audio_path,
|
567 |
+
seq_len=args.seq_len,
|
568 |
+
num_steps=num_steps
|
569 |
+
)
|
570 |
+
|
571 |
+
torch.cuda.empty_cache()
|
572 |
+
|
573 |
+
add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
|
574 |
+
video_paths = save_video_as_grid_and_mp4(video,
|
575 |
+
output_dir,
|
576 |
+
args.fps,
|
577 |
+
prompt=text,
|
578 |
+
prompt_path = prompt_path,
|
579 |
+
audio_path=tmp2_audio_path if args.use_audio else None,
|
580 |
+
prefix=f'result')
|
581 |
+
|
582 |
+
return video_paths[0]
|
583 |
+
|
584 |
+
def cleanup(request: gr.Request):
|
585 |
+
|
586 |
+
sid = request.session_hash
|
587 |
+
if sid:
|
588 |
+
d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
|
589 |
+
shutil.rmtree(d1, ignore_errors=True)
|
590 |
+
|
591 |
+
def start_session(request: gr.Request):
|
592 |
+
|
593 |
+
return request.session_hash
|
594 |
+
|
595 |
+
css = """
|
596 |
+
#col-container {
|
597 |
+
margin: 0 auto;
|
598 |
+
max-width: 1560px;
|
599 |
+
}
|
600 |
+
"""
|
601 |
+
theme = gr.themes.Ocean()
|
602 |
+
|
603 |
+
with gr.Blocks(css=css, theme=theme) as demo:
|
604 |
+
|
605 |
+
session_state = gr.State()
|
606 |
+
demo.load(start_session, outputs=[session_state])
|
607 |
+
|
608 |
+
|
609 |
+
with gr.Column(elem_id="col-container"):
|
610 |
+
gr.HTML(
|
611 |
+
"""
|
612 |
+
<div style="text-align: left;">
|
613 |
+
<p style="font-size:16px; display: inline; margin: 0;">
|
614 |
+
<strong>OmniAvatar</strong> – Efficient Audio-Driven Avatar Video Generation with Adaptive Body Animation
|
615 |
+
</p>
|
616 |
+
<a href="https://github.com/Omni-Avatar/OmniAvatar" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
|
617 |
+
<img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
|
618 |
+
</a>
|
619 |
+
</div>
|
620 |
+
<div style="text-align: left;">
|
621 |
+
HF Space by: <a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
|
622 |
+
<img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
|
623 |
+
</a>
|
624 |
+
</div>
|
625 |
+
|
626 |
+
<div style="text-align: left;">
|
627 |
+
<a href="https://huggingface.co/alexnasa">
|
628 |
+
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
|
629 |
+
</a>
|
630 |
+
</div>
|
631 |
+
|
632 |
+
"""
|
633 |
+
)
|
634 |
+
|
635 |
+
with gr.Row():
|
636 |
+
|
637 |
+
with gr.Column():
|
638 |
+
|
639 |
+
image_input = gr.Image(label="Reference Image", type="filepath", height=512)
|
640 |
+
audio_input = gr.Audio(label="Input Audio", type="filepath")
|
641 |
+
|
642 |
+
|
643 |
+
with gr.Column():
|
644 |
+
|
645 |
+
output_video = gr.Video(label="Avatar", height=512)
|
646 |
+
num_steps = gr.Slider(1, 50, value=8, step=1, label="Steps")
|
647 |
+
time_required = gr.Text(value="⌚ Zero GPU Required: --", show_label=False)
|
648 |
+
infer_btn = gr.Button("🦜 Avatar Me", variant="primary")
|
649 |
+
with gr.Accordion("Advanced Settings", open=False):
|
650 |
+
text_input = gr.Textbox(label="Prompt Text", lines=4, value="A realistic video of a person speaking directly to the camera on a sofa, with dynamic and rhythmic hand gestures that complement their speech. Their hands are clearly visible, independent, and unobstructed. Their facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.")
|
651 |
+
|
652 |
+
with gr.Column():
|
653 |
+
|
654 |
+
examples = gr.Examples(
|
655 |
+
examples=[
|
656 |
+
[
|
657 |
+
"examples/images/female-001.png",
|
658 |
+
"examples/audios/mushroom.wav",
|
659 |
+
"A realistic video of a woman speaking and sometimes looking directly at the camera, sitting on a sofa, with dynamic, rhythmic, and extensive hand gestures that complement her speech. Her hands are clearly visible, independent, and unobstructed. Her facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
|
660 |
+
12
|
661 |
+
],
|
662 |
+
[
|
663 |
+
"examples/images/male-001.png",
|
664 |
+
"examples/audios/tape.wav",
|
665 |
+
"A realistic video of a man moving his hands extensively and speaking. The motion of his hands matches his speech. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
|
666 |
+
8
|
667 |
+
],
|
668 |
+
],
|
669 |
+
inputs=[image_input, audio_input, text_input, num_steps],
|
670 |
+
outputs=[output_video],
|
671 |
+
fn=infer,
|
672 |
+
cache_examples=True
|
673 |
+
)
|
674 |
+
|
675 |
+
infer_btn.click(
|
676 |
+
fn=infer,
|
677 |
+
inputs=[image_input, audio_input, text_input, num_steps, session_state],
|
678 |
+
outputs=[output_video]
|
679 |
+
)
|
680 |
+
image_input.upload(fn=preprocess_img, inputs=[image_input, session_state], outputs=[image_input]).then(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
|
681 |
+
audio_input.upload(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
|
682 |
+
num_steps.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
|
683 |
+
|
684 |
+
|
685 |
+
if __name__ == "__main__":
|
686 |
+
demo.unload(cleanup)
|
687 |
+
demo.queue()
|
688 |
+
demo.launch(ssr_mode=False)
|
args_config.yaml
ADDED
@@ -0,0 +1,77 @@
config: configs/inference.yaml

input_file: examples/infer_samples.txt
debug: null
infer: false
hparams: ''
dtype: bf16

exp_path: pretrained_models/OmniAvatar-14B
text_encoder_path: pretrained_models/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth
image_encoder_path: None
dit_path: pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors
vae_path: pretrained_models/Wan2.1-T2V-14B/Wan2.1_VAE.pth

# exp_path: pretrained_models/OmniAvatar-1.3B
# text_encoder_path: pretrained_models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
# image_encoder_path: None
# dit_path: pretrained_models/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors
# vae_path: pretrained_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth

wav2vec_path: pretrained_models/wav2vec2-base-960h
num_persistent_param_in_dit:
reload_cfg: true
sp_size: 1
seed: 42
image_sizes_720:
- - 400
  - 720
# - - 720   # commented out due to the duration needed on HF
#   - 720
- - 720
  - 400
image_sizes_1280:
- - 720
  - 720
- - 528
  - 960
- - 960
  - 528
- - 720
  - 1280
- - 1280
  - 720
max_hw: 720
max_tokens: 40000
seq_len: 200
overlap_frame: 13
guidance_scale: 4.5
audio_scale: null
num_steps: 8
fps: 24
sample_rate: 16000
negative_prompt: Vivid color tones, background/camera moving quickly, screen switching,
  subtitles and special effects, mutation, overexposed, static, blurred details, subtitles,
  style, work, painting, image, still, overall grayish, worst quality, low quality,
  JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly
  drawn face, deformed, disfigured, malformed limbs, fingers merging, motionless image,
  chaotic background, three legs, crowded background with many people, walking backward
silence_duration_s: 0.0
use_fsdp: false
tea_cache_l1_thresh: 0
rank: 0
world_size: 1
local_rank: 0
device: cuda
num_nodes: 1
i2v: true
use_audio: true
random_prefix_frames: true
model_config:
  in_dim: 33
  audio_hidden_size: 32
train_architecture: lora
lora_target_modules: q,k,v,o,ffn.0,ffn.2
init_lora_weights: kaiming
lora_rank: 128
lora_alpha: 64.0
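Aside (not part of the upload): app.py consumes this file by loading it with OmegaConf and wrapping it in an argparse Namespace. A minimal sketch of overriding a single field before building the pipeline (the override shown is hypothetical):

from argparse import Namespace
from omegaconf import OmegaConf

cfg = OmegaConf.load("args_config.yaml")
cfg.num_steps = 12  # hypothetical override: trade speed for quality
args = Namespace(**OmegaConf.to_container(cfg, resolve=True))
print(args.num_steps, args.max_tokens)  # 12 40000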
assets/material/pipeline.png
ADDED
assets/material/teaser.png
ADDED
configs/inference.yaml
ADDED
@@ -0,0 +1,37 @@
# Pretrained model paths
dtype: "bf16"
text_encoder_path: pretrained_models/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth
image_encoder_path: None
dit_path: pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors
vae_path: pretrained_models/Wan2.1-T2V-14B/Wan2.1_VAE.pth
wav2vec_path: pretrained_models/wav2vec2-base-960h
exp_path: pretrained_models/OmniAvatar-14B
num_persistent_param_in_dit: # You can set `num_persistent_param_in_dit` to a small number to reduce the VRAM required.

reload_cfg: True
sp_size: 1

# Data parameters
seed: 42
image_sizes_720: [[400, 720],
                  [720, 720],
                  [720, 400]]
image_sizes_1280: [
                  [720, 720],
                  [528, 960],
                  [960, 528],
                  [720, 1280],
                  [1280, 720]]
max_hw: 720 # 720: 480p; 1280: 720p
max_tokens: 30000
seq_len: 200
overlap_frame: 13 # must be 1 + 4*n
guidance_scale: 4.5
audio_scale:
num_steps: 16
fps: 25
sample_rate: 16000
negative_prompt: "Vivid color tones, background/camera moving quickly, screen switching, subtitles and special effects, mutation, overexposed, static, blurred details, subtitles, style, work, painting, image, still, overall grayish, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fingers merging, motionless image, chaotic background, three legs, crowded background with many people, walking backward"
silence_duration_s: 0.3
use_fsdp: False
tea_cache_l1_thresh: 0 # 0.14 The larger this value, the faster the inference but the worse the visual quality. TODO check value
configs/inference_1.3B.yaml
ADDED
@@ -0,0 +1,37 @@
# Pretrained model paths
dtype: "bf16"
text_encoder_path: pretrained_models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
image_encoder_path: None
dit_path: pretrained_models/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors
vae_path: pretrained_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
wav2vec_path: pretrained_models/wav2vec2-base-960h
exp_path: pretrained_models/OmniAvatar-1.3B
num_persistent_param_in_dit: # You can set `num_persistent_param_in_dit` to a small number to reduce the VRAM required.

reload_cfg: True
sp_size: 1

# Data parameters
seed: 42
image_sizes_720: [[400, 720],
                  [720, 720],
                  [720, 400]]
image_sizes_1280: [
                  [720, 720],
                  [528, 960],
                  [960, 528],
                  [720, 1280],
                  [1280, 720]]
max_hw: 720 # 720: 480p; 1280: 720p
max_tokens: 30000
seq_len: 200
overlap_frame: 13 # must be 1 + 4*n
guidance_scale: 4.5
audio_scale:
num_steps: 10
fps: 25
sample_rate: 16000
negative_prompt: "Vivid color tones, background/camera moving quickly, screen switching, subtitles and special effects, mutation, overexposed, static, blurred details, subtitles, style, work, painting, image, still, overall grayish, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fingers merging, motionless image, chaotic background, three legs, crowded background with many people, walking backward"
silence_duration_s: 0.3
use_fsdp: False
tea_cache_l1_thresh: 0 # 0.14 The larger this value, the faster the inference but the worse the visual quality. TODO check value
examples/audios/mushroom.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d10e51bb6169d206f4eccbbde14868a1c9a09e07bd7a2f18258cc2b265620226
size 460878
examples/audios/tape.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba4e3861b6b124ecde7dc621612957613c495c64d4ba8c0f26e9c87fe5c1566c
size 270764
examples/images/female-001.png
ADDED
Git LFS Details
|
examples/images/male-001.png
ADDED
Git LFS Details
|
examples/infer_samples.txt
ADDED
@@ -0,0 +1 @@
A realistic video of a man speaking directly to the camera on a sofa, with dynamic and rhythmic hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.@@examples/images/0000.jpeg@@examples/audios/0000.MP3
requirements.txt
ADDED
@@ -0,0 +1,13 @@
tqdm
librosa==0.10.2.post1
peft==0.15.1
transformers==4.52.3
scipy==1.14.0
numpy==1.26.4
xfuser==0.4.1
ftfy
einops
omegaconf
torchvision
ninja
flash-attn-3 @ https://github.com/OutofAi/PyTorch3D-wheels/releases/download/0.7.8/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
scripts/inference.py
ADDED
@@ -0,0 +1,383 @@
1 |
+
import subprocess
|
2 |
+
import os, sys
|
3 |
+
from glob import glob
|
4 |
+
from datetime import datetime
|
5 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
import librosa
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from tqdm import tqdm
|
13 |
+
from functools import partial
|
14 |
+
from omegaconf import OmegaConf
|
15 |
+
from argparse import Namespace
|
16 |
+
|
17 |
+
# # load the one true config you dumped
|
18 |
+
# _args_cfg = OmegaConf.load("demo_out/config/args_config.yaml")
|
19 |
+
# args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
|
20 |
+
|
21 |
+
# from OmniAvatar.utils.args_config import set_global_args
|
22 |
+
|
23 |
+
# set_global_args(args)
|
24 |
+
|
25 |
+
from OmniAvatar.utils.args_config import parse_args
|
26 |
+
args = parse_args()
|
27 |
+
|
28 |
+
from OmniAvatar.utils.io_utils import load_state_dict
|
29 |
+
from peft import LoraConfig, inject_adapter_in_model
|
30 |
+
from OmniAvatar.models.model_manager import ModelManager
|
31 |
+
from OmniAvatar.wan_video import WanVideoPipeline
|
32 |
+
from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
|
33 |
+
import torchvision.transforms as TT
|
34 |
+
from transformers import Wav2Vec2FeatureExtractor
|
35 |
+
import torchvision.transforms as transforms
|
36 |
+
import torch.nn.functional as F
|
37 |
+
from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
|
38 |
+
from huggingface_hub import hf_hub_download
|
39 |
+
|
40 |
+
def set_seed(seed: int = 42):
|
41 |
+
random.seed(seed)
|
42 |
+
np.random.seed(seed)
|
43 |
+
torch.manual_seed(seed)
|
44 |
+
torch.cuda.manual_seed(seed)  # seed the current GPU
|
45 |
+
torch.cuda.manual_seed_all(seed)  # seed all GPUs
|
46 |
+
|
47 |
+
def read_from_file(p):
|
48 |
+
with open(p, "r") as fin:
|
49 |
+
for l in fin:
|
50 |
+
yield l.strip()
|
51 |
+
|
52 |
+
def match_size(image_size, h, w):
|
53 |
+
ratio_ = 9999
|
54 |
+
size_ = 9999
|
55 |
+
select_size = None
|
56 |
+
for image_s in image_size:
|
57 |
+
ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
|
58 |
+
size_tmp = abs(max(image_s) - max(w, h))
|
59 |
+
if ratio_tmp < ratio_:
|
60 |
+
ratio_ = ratio_tmp
|
61 |
+
size_ = size_tmp
|
62 |
+
select_size = image_s
|
63 |
+
if ratio_ == ratio_tmp:
|
64 |
+
if size_ == size_tmp:
|
65 |
+
select_size = image_s
|
66 |
+
return select_size
|
67 |
+
|
68 |
+
def resize_pad(image, ori_size, tgt_size):
|
69 |
+
h, w = ori_size
|
70 |
+
scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
|
71 |
+
scale_h = int(h * scale_ratio)
|
72 |
+
scale_w = int(w * scale_ratio)
|
73 |
+
|
74 |
+
image = transforms.Resize(size=[scale_h, scale_w])(image)
|
75 |
+
|
76 |
+
padding_h = tgt_size[0] - scale_h
|
77 |
+
padding_w = tgt_size[1] - scale_w
|
78 |
+
pad_top = padding_h // 2
|
79 |
+
pad_bottom = padding_h - pad_top
|
80 |
+
pad_left = padding_w // 2
|
81 |
+
pad_right = padding_w - pad_left
|
82 |
+
|
83 |
+
image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
|
84 |
+
return image
|
85 |
+
|
86 |
+
class WanInferencePipeline(nn.Module):
|
87 |
+
def __init__(self, args):
|
88 |
+
super().__init__()
|
89 |
+
self.args = args
|
90 |
+
self.device = torch.device(f"cuda")
|
91 |
+
if self.args.dtype=='bf16':
|
92 |
+
self.dtype = torch.bfloat16
|
93 |
+
elif self.args.dtype=='fp16':
|
94 |
+
self.dtype = torch.float16
|
95 |
+
else:
|
96 |
+
self.dtype = torch.float32
|
97 |
+
self.pipe = self.load_model()
|
98 |
+
if self.args.i2v:
|
99 |
+
chained_trainsforms = []
|
100 |
+
chained_trainsforms.append(TT.ToTensor())
|
101 |
+
self.transform = TT.Compose(chained_trainsforms)
|
102 |
+
if self.args.use_audio:
|
103 |
+
from OmniAvatar.models.wav2vec import Wav2VecModel
|
104 |
+
self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
105 |
+
self.args.wav2vec_path
|
106 |
+
)
|
107 |
+
self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device)
|
108 |
+
self.audio_encoder.feature_extractor._freeze_parameters()
|
109 |
+
|
110 |
+
def load_model(self):
|
111 |
+
torch.cuda.set_device(0)
|
112 |
+
ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
|
113 |
+
assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
|
114 |
+
if self.args.train_architecture == 'lora':
|
115 |
+
self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
|
116 |
+
else:
|
117 |
+
resume_path = ckpt_path
|
118 |
+
|
119 |
+
self.step = 0
|
120 |
+
|
121 |
+
# Load models
|
122 |
+
model_manager = ModelManager(device="cpu", infer=True)
|
123 |
+
model_manager.load_models(
|
124 |
+
[
|
125 |
+
self.args.dit_path.split(","),
|
126 |
+
self.args.text_encoder_path,
|
127 |
+
self.args.vae_path
|
128 |
+
],
|
129 |
+
torch_dtype=self.dtype, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
|
130 |
+
device='cpu',
|
131 |
+
)
|
132 |
+
LORA_REPO_ID = "Kijai/WanVideo_comfy"
|
133 |
+
LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
|
134 |
+
causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
|
135 |
+
model_manager.load_lora(causvid_path, lora_alpha=1.0)
|
136 |
+
pipe = WanVideoPipeline.from_model_manager(model_manager,
|
137 |
+
torch_dtype=self.dtype,
|
138 |
+
device=f"cuda",
|
139 |
+
use_usp=True if self.args.sp_size > 1 else False,
|
140 |
+
infer=True)
|
141 |
+
if self.args.train_architecture == "lora":
|
142 |
+
print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
|
143 |
+
self.add_lora_to_model(
|
144 |
+
pipe.denoising_model(),
|
145 |
+
lora_rank=self.args.lora_rank,
|
146 |
+
lora_alpha=self.args.lora_alpha,
|
147 |
+
lora_target_modules=self.args.lora_target_modules,
|
148 |
+
init_lora_weights=self.args.init_lora_weights,
|
149 |
+
pretrained_lora_path=pretrained_lora_path,
|
150 |
+
)
|
151 |
+
else:
|
152 |
+
missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
|
153 |
+
print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
|
154 |
+
pipe.requires_grad_(False)
|
155 |
+
pipe.eval()
|
156 |
+
pipe.enable_vram_management(num_persistent_param_in_dit=self.args.num_persistent_param_in_dit) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
|
157 |
+
return pipe
|
158 |
+
|
159 |
+
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
|
160 |
+
# Add LoRA to UNet
|
161 |
+
self.lora_alpha = lora_alpha
|
162 |
+
if init_lora_weights == "kaiming":
|
163 |
+
init_lora_weights = True
|
164 |
+
|
165 |
+
lora_config = LoraConfig(
|
166 |
+
r=lora_rank,
|
167 |
+
lora_alpha=lora_alpha,
|
168 |
+
init_lora_weights=init_lora_weights,
|
169 |
+
target_modules=lora_target_modules.split(","),
|
170 |
+
)
|
171 |
+
model = inject_adapter_in_model(lora_config, model)
|
172 |
+
|
173 |
+
# Lora pretrained lora weights
|
174 |
+
if pretrained_lora_path is not None:
|
175 |
+
state_dict = load_state_dict(pretrained_lora_path)
|
176 |
+
if state_dict_converter is not None:
|
177 |
+
state_dict = state_dict_converter(state_dict)
|
178 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
179 |
+
all_keys = [i for i, _ in model.named_parameters()]
|
180 |
+
num_updated_keys = len(all_keys) - len(missing_keys)
|
181 |
+
num_unexpected_keys = len(unexpected_keys)
|
182 |
+
print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
|
183 |
+
|
184 |
+
|
185 |
+
def forward(self, prompt,
|
186 |
+
image_path=None,
|
187 |
+
audio_path=None,
|
188 |
+
seq_len=101, # not used while audio_path is not None
|
189 |
+
height=720,
|
190 |
+
width=720,
|
191 |
+
overlap_frame=None,
|
192 |
+
num_steps=None,
|
193 |
+
negative_prompt=None,
|
194 |
+
guidance_scale=None,
|
195 |
+
audio_scale=None):
|
196 |
+
overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
|
197 |
+
num_steps = num_steps if num_steps is not None else self.args.num_steps
|
198 |
+
negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
|
199 |
+
guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
|
200 |
+
audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
|
201 |
+
|
202 |
+
        if image_path is not None:
            from PIL import Image
            image = Image.open(image_path).convert("RGB")
            image = self.transform(image).unsqueeze(0).to(self.device)
            _, _, h, w = image.shape
            select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
            image = resize_pad(image, (h, w), select_size)
            image = image * 2.0 - 1.0
            image = image[:, :, None]
        else:
            image = None
            select_size = [height, width]
        L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
        L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3  # video frames
        T = (L + 3) // 4  # latent frames

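        # Worked example with illustrative values (not taken from the config): for
        # max_tokens = 30000 and a 720x720 output, int(30000 * 16 * 16 * 4 / 720 / 720) = 59;
        # 59 % 4 != 0, so L = 59 // 4 * 4 + 1 = 57 video frames and T = (57 + 3) // 4 = 15
        # latent frames, consistent with a VAE temporal compression of 4x plus a
        # standalone first frame.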
        if self.args.i2v:
            if self.args.random_prefix_frames:
                fixed_frame = overlap_frame
                assert fixed_frame % 4 == 1
            else:
                fixed_frame = 1
            prefix_lat_frame = (3 + fixed_frame) // 4
            first_fixed_frame = 1
        else:
            fixed_frame = 0
            prefix_lat_frame = 0
            first_fixed_frame = 0

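        # Illustrative: with overlap_frame = 13 (which satisfies fixed_frame % 4 == 1),
        # 13 video frames are carried over between chunks, i.e.
        # prefix_lat_frame = (3 + 13) // 4 = 4 latent frames; first_fixed_frame = 1
        # because the very first chunk is conditioned only on the reference image.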
        if audio_path is not None and self.args.use_audio:
            audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
            input_values = np.squeeze(
                self.wav_feature_extractor(audio, sampling_rate=16000).input_values
            )
            input_values = torch.from_numpy(input_values).float().to(device=self.device)
            ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
            input_values = input_values.unsqueeze(0)
            # Pad the audio so its frame count fills an integer number of generation chunks
            if audio_len < L - first_fixed_frame:
                audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
            elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
                audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
            input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
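            # Continuing the illustrative example (L = 57, first_fixed_frame = 1,
            # fixed_frame = 13): a clip of 120 audio frames exceeds L - 1 = 56 and
            # (120 - 56) % (57 - 13) = 20, so 44 - 20 = 24 silent frames are appended
            # and audio_len becomes 144: one first chunk of 56 new frames plus two
            # follow-up chunks of 44 new frames each.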
            with torch.no_grad():
                hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
                audio_embeddings = hidden_states.last_hidden_state
                for mid_hidden_states in hidden_states.hidden_states:
                    audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
            seq_len = audio_len
            audio_embeddings = audio_embeddings.squeeze(0)
            audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
        else:
            audio_embeddings = None

        # Sliding-window generation loop
        times = (seq_len - L + first_fixed_frame) // (L - fixed_frame) + 1
        if times * (L - fixed_frame) + fixed_frame < seq_len:
            times += 1
        video = []
        image_emb = {}
        img_lat = None
        if self.args.i2v:
            self.pipe.load_models_to_device(['vae'])
            img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device)

            msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:, :1])
            image_cat = img_lat.repeat(1, 1, T, 1, 1)
            msk[:, :, 1:] = 1
            image_emb["y"] = torch.cat([image_cat, msk], dim=1)
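        # Continuing the example: with seq_len = 144, times = (144 - 57 + 1) // 44 + 1 = 3
        # chunks. The extra mask channel appended to image_emb["y"] is 0 for latent
        # frames that are given (the reference frame, and later the overlap prefix)
        # and 1 for latent frames that are to be generated.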
        for t in range(times):
            print(f"[{t+1}/{times}]")
            audio_emb = {}
            if t == 0:
                overlap = first_fixed_frame
            else:
                overlap = fixed_frame
                # On the first pass only the reference frame is marked as known; on
                # later passes the whole overlap prefix is marked as known (mask = 0).
                image_emb["y"][:, -1:, :prefix_lat_frame] = 0
            prefix_overlap = (3 + overlap) // 4
            if audio_embeddings is not None:
                if t == 0:
                    audio_tensor = audio_embeddings[
                        :min(L - overlap, audio_embeddings.shape[0])
                    ]
                else:
                    audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
                    audio_tensor = audio_embeddings[
                        audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
                    ]

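                # Continuing the example: chunk 0 consumes audio frames [0, 56); for
                # t >= 1, audio_start = 57 - 1 + (t - 1) * 44, so chunk 1 covers
                # [56, 100) and chunk 2 covers [100, 144). Each slice is then prefixed
                # with the trailing `fixed_frame` embeddings of the previous chunk
                # (zeros for the first chunk).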
                audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
                audio_prefix = audio_tensor[-fixed_frame:]
                audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
                audio_emb["audio_emb"] = audio_tensor
            else:
                audio_prefix = None
            if image is not None and img_lat is None:
                self.pipe.load_models_to_device(['vae'])
                img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device)
                assert img_lat.shape[2] == prefix_overlap
            img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1))], dim=2)
            frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
                                                     negative_prompt, num_inference_steps=num_steps,
                                                     cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
                                                     return_latent=True,
                                                     tea_cache_l1_thresh=self.args.tea_cache_l1_thresh, tea_cache_model_id="Wan2.1-T2V-14B")
            img_lat = None
            image = (frames[:, -fixed_frame:].clip(0, 1) * 2 - 1).permute(0, 2, 1, 3, 4).contiguous()
            if t == 0:
                video.append(frames)
            else:
                video.append(frames[:, overlap:])
        video = torch.cat(video, dim=1)
        video = video[:, :ori_audio_len + 1]
        return video


def main():
    # os.makedirs("demo_out/config", exist_ok=True)
    # OmegaConf.save(config=OmegaConf.create(vars(args)),
    #                f="demo_out/config/args_config.yaml")
    # print("Saved merged args to demo_out/config/args_config.yaml")

    set_seed(args.seed)
    # Load data
    data_iter = read_from_file(args.input_file)
    exp_name = os.path.basename(args.exp_path)
    seq_len = args.seq_len

    # Build the audio-driven video generation pipeline
    inferpipe = WanInferencePipeline(args)

    output_dir = 'demo_out'

    idx = 0
    text = "A realistic video of a man speaking directly to the camera on a sofa, with dynamic and rhythmic hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence."
    image_path = "examples/images/0000.jpeg"
    audio_path = "examples/audios/0000.MP3"
    audio_dir = output_dir + '/audio'
    os.makedirs(audio_dir, exist_ok=True)
    if args.silence_duration_s > 0:
        input_audio_path = os.path.join(audio_dir, f"audio_input_{idx:03d}.wav")
    else:
        input_audio_path = audio_path
    prompt_dir = output_dir + '/prompt'
    os.makedirs(prompt_dir, exist_ok=True)

    if args.silence_duration_s > 0:
        add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)

    video = inferpipe(
        prompt=text,
        image_path=image_path,
        audio_path=input_audio_path,
        seq_len=seq_len
    )
    # Because the first frame is the reference frame, the exported audio has to be
    # shifted by one frame (1/25 s at 25 fps) via the leading silence added below.
    tmp2_audio_path = os.path.join(audio_dir, f"audio_out_{idx:03d}.wav")
    prompt_path = os.path.join(prompt_dir, f"prompt_{idx:03d}.txt")

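    # For example (assuming fps = 25, as the original 1/25 s comment implies, and
    # silence_duration_s = 0), the call below prepends 1.0 / 25 = 0.04 s of silence,
    # exactly one frame, so speech stays aligned with the generated frames.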
    add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
    save_video_as_grid_and_mp4(video,
                               output_dir,
                               args.fps,
                               prompt=text,
                               prompt_path=prompt_path,
                               audio_path=tmp2_audio_path if args.use_audio else None,
                               prefix=f'result_{idx:03d}')


class NoPrint:
    def write(self, x):
        pass

    def flush(self):
        pass


if __name__ == '__main__':
    if not args.debug:
        if args.local_rank != 0:  # Silence stdout on all ranks except rank 0
            sys.stdout = NoPrint()
    main()