alexnasa committed on
Commit 468a4ed · verified · 1 Parent(s): af00586

Upload 37 files

.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/audios/mushroom.wav filter=lfs diff=lfs merge=lfs -text
+ examples/audios/tape.wav filter=lfs diff=lfs merge=lfs -text
+ examples/images/female-001.png filter=lfs diff=lfs merge=lfs -text
+ examples/images/male-001.png filter=lfs diff=lfs merge=lfs -text
LICENSE.txt ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
OmniAvatar/base.py ADDED
@@ -0,0 +1,127 @@
+ import torch
+ import numpy as np
+ from PIL import Image
+ from torchvision.transforms import GaussianBlur
+
+
+
+ class BasePipeline(torch.nn.Module):
+
+     def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
+         super().__init__()
+         self.device = device
+         self.torch_dtype = torch_dtype
+         self.height_division_factor = height_division_factor
+         self.width_division_factor = width_division_factor
+         self.cpu_offload = False
+         self.model_names = []
+
+
+     def check_resize_height_width(self, height, width):
+         if height % self.height_division_factor != 0:
+             height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
+             print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
+         if width % self.width_division_factor != 0:
+             width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
+             print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
+         return height, width
+
+
+     def preprocess_image(self, image):
+         image = torch.Tensor(np.array(image, dtype=np.float16) * (2.0 / 255) - 1.0).permute(2, 0, 1).unsqueeze(0)
+         return image
+
+
+     def preprocess_images(self, images):
+         return [self.preprocess_image(image) for image in images]
+
+
+     def vae_output_to_image(self, vae_output):
+         image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
+         image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
+         return image
+
+
+     def vae_output_to_video(self, vae_output):
+         video = vae_output.cpu().permute(1, 2, 0).numpy()
+         video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
+         return video
+
+
+     def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
+         if len(latents) > 0:
+             blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
+             height, width = value.shape[-2:]
+             weight = torch.ones_like(value)
+             for latent, mask, scale in zip(latents, masks, scales):
+                 mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
+                 mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
+                 mask = blur(mask)
+                 value += latent * mask * scale
+                 weight += mask * scale
+             value /= weight
+         return value
+
+
+     def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
+         if special_kwargs is None:
+             noise_pred_global = inference_callback(prompt_emb_global)
+         else:
+             noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
+         if special_local_kwargs_list is None:
+             noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
+         else:
+             noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
+         noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
+         return noise_pred
+
+
+     def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
+         local_prompts = local_prompts or []
+         masks = masks or []
+         mask_scales = mask_scales or []
+         extended_prompt_dict = self.prompter.extend_prompt(prompt)
+         prompt = extended_prompt_dict.get("prompt", prompt)
+         local_prompts += extended_prompt_dict.get("prompts", [])
+         masks += extended_prompt_dict.get("masks", [])
+         mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
+         return prompt, local_prompts, masks, mask_scales
+
+
+     def enable_cpu_offload(self):
+         self.cpu_offload = True
+
+
+     def load_models_to_device(self, loadmodel_names=[]):
+         # only manage model placement if cpu_offload is enabled
+         if not self.cpu_offload:
+             return
+         # offload the unneeded models to cpu
+         for model_name in self.model_names:
+             if model_name not in loadmodel_names:
+                 model = getattr(self, model_name)
+                 if model is not None:
+                     if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
+                         for module in model.modules():
+                             if hasattr(module, "offload"):
+                                 module.offload()
+                     else:
+                         model.cpu()
+         # load the needed models to device
+         for model_name in loadmodel_names:
+             model = getattr(self, model_name)
+             if model is not None:
+                 if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
+                     for module in model.modules():
+                         if hasattr(module, "onload"):
+                             module.onload()
+                 else:
+                     model.to(self.device)
+         # free the CUDA cache
+         torch.cuda.empty_cache()
+
+
+     def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
+         generator = None if seed is None else torch.Generator(device).manual_seed(seed)
+         noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+         return noise
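
The utilities above are self-contained, so they can be exercised directly. Below is a minimal usage sketch, illustrative only and not part of the uploaded files; it assumes the OmniAvatar package is importable and uses CPU with float32 so no GPU is required.

import torch
from OmniAvatar.base import BasePipeline

# Instantiate on CPU with float32 so the example runs without a GPU.
pipe = BasePipeline(device="cpu", torch_dtype=torch.float32)

# Resolutions are rounded up to the nearest multiple of 64 (a notice is printed for 720).
height, width = pipe.check_resize_height_width(720, 1280)   # (768, 1280)

# Seeded noise in the (batch, channels, height, width) layout used by the pipeline.
noise = pipe.generate_noise((1, 3, height, width), seed=0, dtype=torch.float32)

# Decoder-style outputs in [-1, 1] can be converted back to a PIL image.
image = pipe.vae_output_to_image(noise.clamp(-1, 1))
print(image.size)   # (1280, 768), since PIL reports (width, height)
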
OmniAvatar/configs/__init__.py ADDED
File without changes
OmniAvatar/configs/model_config.py ADDED
@@ -0,0 +1,664 @@
1
+ from typing_extensions import Literal, TypeAlias
2
+ from ..models.wan_video_dit import WanModel
3
+ from ..models.wan_video_text_encoder import WanTextEncoder
4
+ from ..models.wan_video_vae import WanVideoVAE
5
+
6
+
7
+ model_loader_configs = [
8
+ # These configs are provided for detecting model type automatically.
9
+ # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
10
+ (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
11
+ (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
12
+ (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
13
+ (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
14
+ (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
15
+ (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
16
+ (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
17
+ ]
18
+ huggingface_model_loader_configs = [
19
+ # These configs are provided for detecting model type automatically.
20
+ # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
21
+ ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
22
+ ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
23
+ ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
24
+ ("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
25
+ # ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
26
+ ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
27
+ ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
28
+ ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
29
+ ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
30
+ ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
31
+ ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
32
+ ]
33
+
34
+ preset_models_on_huggingface = {
35
+ "HunyuanDiT": [
36
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
37
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
38
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
39
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
40
+ ],
41
+ "stable-video-diffusion-img2vid-xt": [
42
+ ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
43
+ ],
44
+ "ExVideo-SVD-128f-v1": [
45
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
46
+ ],
47
+ # Stable Diffusion
48
+ "StableDiffusion_v15": [
49
+ ("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
50
+ ],
51
+ "DreamShaper_8": [
52
+ ("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
53
+ ],
54
+ # Textual Inversion
55
+ "TextualInversion_VeryBadImageNegative_v1.3": [
56
+ ("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
57
+ ],
58
+ # Stable Diffusion XL
59
+ "StableDiffusionXL_v1": [
60
+ ("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
61
+ ],
62
+ "BluePencilXL_v200": [
63
+ ("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
64
+ ],
65
+ "StableDiffusionXL_Turbo": [
66
+ ("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
67
+ ],
68
+ # Stable Diffusion 3
69
+ "StableDiffusion3": [
70
+ ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
71
+ ],
72
+ "StableDiffusion3_without_T5": [
73
+ ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
74
+ ],
75
+ # ControlNet
76
+ "ControlNet_v11f1p_sd15_depth": [
77
+ ("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
78
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
79
+ ],
80
+ "ControlNet_v11p_sd15_softedge": [
81
+ ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
82
+ ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
83
+ ],
84
+ "ControlNet_v11f1e_sd15_tile": [
85
+ ("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
86
+ ],
87
+ "ControlNet_v11p_sd15_lineart": [
88
+ ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
89
+ ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
90
+ ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
91
+ ],
92
+ "ControlNet_union_sdxl_promax": [
93
+ ("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
94
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
95
+ ],
96
+ # AnimateDiff
97
+ "AnimateDiff_v2": [
98
+ ("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
99
+ ],
100
+ "AnimateDiff_xl_beta": [
101
+ ("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
102
+ ],
103
+
104
+ # Qwen Prompt
105
+ "QwenPrompt": [
106
+ ("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
107
+ ("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
108
+ ("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
109
+ ("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
110
+ ("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
111
+ ("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
112
+ ("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
113
+ ("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
114
+ ],
115
+ # Beautiful Prompt
116
+ "BeautifulPrompt": [
117
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
118
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
119
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
120
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
121
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
122
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
123
+ ],
124
+ # Omost prompt
125
+ "OmostPrompt":[
126
+ ("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
127
+ ("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
128
+ ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
129
+ ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
130
+ ("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
131
+ ("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
132
+ ("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
133
+ ("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
134
+ ],
135
+ # Translator
136
+ "opus-mt-zh-en": [
137
+ ("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
138
+ ("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
139
+ ("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
140
+ ("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
141
+ ("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
142
+ ("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
143
+ ("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
144
+ ("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
145
+ ],
146
+ # IP-Adapter
147
+ "IP-Adapter-SD": [
148
+ ("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
149
+ ("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
150
+ ],
151
+ "IP-Adapter-SDXL": [
152
+ ("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
153
+ ("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
154
+ ],
155
+ "SDXL-vae-fp16-fix": [
156
+ ("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
157
+ ],
158
+ # Kolors
159
+ "Kolors": [
160
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
161
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
162
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
163
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
164
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
165
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
166
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
167
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
168
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
169
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
170
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
171
+ ],
172
+ # FLUX
173
+ "FLUX.1-dev": [
174
+ ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
175
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
176
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
177
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
178
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
179
+ ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
180
+ ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
181
+ ],
182
+ "InstantX/FLUX.1-dev-IP-Adapter": {
183
+ "file_list": [
184
+ ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
185
+ ("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
186
+ ("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
187
+ ],
188
+ "load_path": [
189
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
190
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
191
+ ],
192
+ },
193
+ # RIFE
194
+ "RIFE": [
195
+ ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
196
+ ],
197
+ # CogVideo
198
+ "CogVideoX-5B": [
199
+ ("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
200
+ ("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
201
+ ("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
202
+ ("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
203
+ ("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
204
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
205
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
206
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
207
+ ("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
208
+ ],
209
+ # Stable Diffusion 3.5
210
+ "StableDiffusion3.5-large": [
211
+ ("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
212
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
213
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
214
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
215
+ ],
216
+ }
217
+ preset_models_on_modelscope = {
218
+ # Hunyuan DiT
219
+ "HunyuanDiT": [
220
+ ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
221
+ ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
222
+ ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
223
+ ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
224
+ ],
225
+ # Stable Video Diffusion
226
+ "stable-video-diffusion-img2vid-xt": [
227
+ ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
228
+ ],
229
+ # ExVideo
230
+ "ExVideo-SVD-128f-v1": [
231
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
232
+ ],
233
+ "ExVideo-CogVideoX-LoRA-129f-v1": [
234
+ ("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
235
+ ],
236
+ # Stable Diffusion
237
+ "StableDiffusion_v15": [
238
+ ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
239
+ ],
240
+ "DreamShaper_8": [
241
+ ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
242
+ ],
243
+ "AingDiffusion_v12": [
244
+ ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
245
+ ],
246
+ "Flat2DAnimerge_v45Sharp": [
247
+ ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
248
+ ],
249
+ # Textual Inversion
250
+ "TextualInversion_VeryBadImageNegative_v1.3": [
251
+ ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
252
+ ],
253
+ # Stable Diffusion XL
254
+ "StableDiffusionXL_v1": [
255
+ ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
256
+ ],
257
+ "BluePencilXL_v200": [
258
+ ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
259
+ ],
260
+ "StableDiffusionXL_Turbo": [
261
+ ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
262
+ ],
263
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
264
+ ("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
265
+ ],
266
+ # Stable Diffusion 3
267
+ "StableDiffusion3": [
268
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
269
+ ],
270
+ "StableDiffusion3_without_T5": [
271
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
272
+ ],
273
+ # ControlNet
274
+ "ControlNet_v11f1p_sd15_depth": [
275
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
276
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
277
+ ],
278
+ "ControlNet_v11p_sd15_softedge": [
279
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
280
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
281
+ ],
282
+ "ControlNet_v11f1e_sd15_tile": [
283
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
284
+ ],
285
+ "ControlNet_v11p_sd15_lineart": [
286
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
287
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
288
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
289
+ ],
290
+ "ControlNet_union_sdxl_promax": [
291
+ ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
292
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
293
+ ],
294
+ "Annotators:Depth": [
295
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
296
+ ],
297
+ "Annotators:Softedge": [
298
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
299
+ ],
300
+ "Annotators:Lineart": [
301
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
302
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
303
+ ],
304
+ "Annotators:Normal": [
305
+ ("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
306
+ ],
307
+ "Annotators:Openpose": [
308
+ ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
309
+ ("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
310
+ ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
311
+ ],
312
+ # AnimateDiff
313
+ "AnimateDiff_v2": [
314
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
315
+ ],
316
+ "AnimateDiff_xl_beta": [
317
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
318
+ ],
319
+ # RIFE
320
+ "RIFE": [
321
+ ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
322
+ ],
323
+ # Qwen Prompt
324
+ "QwenPrompt": {
325
+ "file_list": [
326
+ ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
327
+ ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
328
+ ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
329
+ ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
330
+ ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
331
+ ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
332
+ ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
333
+ ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
334
+ ],
335
+ "load_path": [
336
+ "models/QwenPrompt/qwen2-1.5b-instruct",
337
+ ],
338
+ },
339
+ # Beautiful Prompt
340
+ "BeautifulPrompt": {
341
+ "file_list": [
342
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
343
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
344
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
345
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
346
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
347
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
348
+ ],
349
+ "load_path": [
350
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
351
+ ],
352
+ },
353
+ # Omost prompt
354
+ "OmostPrompt": {
355
+ "file_list": [
356
+ ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
357
+ ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
358
+ ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
359
+ ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
360
+ ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
361
+ ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
362
+ ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
363
+ ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
364
+ ],
365
+ "load_path": [
366
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
367
+ ],
368
+ },
369
+ # Translator
370
+ "opus-mt-zh-en": {
371
+ "file_list": [
372
+ ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
373
+ ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
374
+ ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
375
+ ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
376
+ ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
377
+ ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
378
+ ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
379
+ ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
380
+ ],
381
+ "load_path": [
382
+ "models/translator/opus-mt-zh-en",
383
+ ],
384
+ },
385
+ # IP-Adapter
386
+ "IP-Adapter-SD": [
387
+ ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
388
+ ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
389
+ ],
390
+ "IP-Adapter-SDXL": [
391
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
392
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
393
+ ],
394
+ # Kolors
395
+ "Kolors": {
396
+ "file_list": [
397
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
398
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
399
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
400
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
401
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
402
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
403
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
404
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
405
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
406
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
407
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
408
+ ],
409
+ "load_path": [
410
+ "models/kolors/Kolors/text_encoder",
411
+ "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
412
+ "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
413
+ ],
414
+ },
415
+ "SDXL-vae-fp16-fix": [
416
+ ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
417
+ ],
418
+ # FLUX
419
+ "FLUX.1-dev": {
420
+ "file_list": [
421
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
422
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
423
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
424
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
425
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
426
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
427
+ ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
428
+ ],
429
+ "load_path": [
430
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
431
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
432
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
433
+ "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
434
+ ],
435
+ },
436
+ "FLUX.1-schnell": {
437
+ "file_list": [
438
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
439
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
440
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
441
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
442
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
443
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
444
+ ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
445
+ ],
446
+ "load_path": [
447
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
448
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
449
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
450
+ "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
451
+ ],
452
+ },
453
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
454
+ ("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
455
+ ],
456
+ "jasperai/Flux.1-dev-Controlnet-Depth": [
457
+ ("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
458
+ ],
459
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
460
+ ("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
461
+ ],
462
+ "jasperai/Flux.1-dev-Controlnet-Upscaler": [
463
+ ("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
464
+ ],
465
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
466
+ ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
467
+ ],
468
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
469
+ ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
470
+ ],
471
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
472
+ ("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
473
+ ],
474
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
475
+ ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
476
+ ],
477
+ "InstantX/FLUX.1-dev-IP-Adapter": {
478
+ "file_list": [
479
+ ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
480
+ ("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
481
+ ("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
482
+ ],
483
+ "load_path": [
484
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
485
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
486
+ ],
487
+ },
488
+ # ESRGAN
489
+ "ESRGAN_x4": [
490
+ ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
491
+ ],
492
+ # RIFE
493
+ "RIFE": [
494
+ ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
495
+ ],
496
+ # Omnigen
497
+ "OmniGen-v1": {
498
+ "file_list": [
499
+ ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
500
+ ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
501
+ ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
502
+ ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
503
+ ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
504
+ ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
505
+ ],
506
+ "load_path": [
507
+ "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
508
+ "models/OmniGen/OmniGen-v1/model.safetensors",
509
+ ]
510
+ },
511
+ # CogVideo
512
+ "CogVideoX-5B": {
513
+ "file_list": [
514
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
515
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
516
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
517
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
518
+ ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
519
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
520
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
521
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
522
+ ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
523
+ ],
524
+ "load_path": [
525
+ "models/CogVideo/CogVideoX-5b/text_encoder",
526
+ "models/CogVideo/CogVideoX-5b/transformer",
527
+ "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
528
+ ],
529
+ },
530
+ # Stable Diffusion 3.5
531
+ "StableDiffusion3.5-large": [
532
+ ("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
533
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
534
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
535
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
536
+ ],
537
+ "StableDiffusion3.5-medium": [
538
+ ("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
539
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
540
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
541
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
542
+ ],
543
+ "StableDiffusion3.5-large-turbo": [
544
+ ("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
545
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
546
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
547
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
548
+ ],
549
+ "HunyuanVideo":{
550
+ "file_list": [
551
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
552
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
553
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
554
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
555
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
556
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
557
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
558
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
559
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
560
+ ],
561
+ "load_path": [
562
+ "models/HunyuanVideo/text_encoder/model.safetensors",
563
+ "models/HunyuanVideo/text_encoder_2",
564
+ "models/HunyuanVideo/vae/pytorch_model.pt",
565
+ "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
566
+ ],
567
+ },
568
+ "HunyuanVideoI2V":{
569
+ "file_list": [
570
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
571
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
572
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
573
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
574
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
575
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
576
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
577
+ ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
578
+ ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
579
+ ],
580
+ "load_path": [
581
+ "models/HunyuanVideoI2V/text_encoder/model.safetensors",
582
+ "models/HunyuanVideoI2V/text_encoder_2",
583
+ "models/HunyuanVideoI2V/vae/pytorch_model.pt",
584
+ "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
585
+ ],
586
+ },
587
+ "HunyuanVideo-fp8":{
588
+ "file_list": [
589
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
590
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
591
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
592
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
593
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
594
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
595
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
596
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
597
+ ("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
598
+ ],
599
+ "load_path": [
600
+ "models/HunyuanVideo/text_encoder/model.safetensors",
601
+ "models/HunyuanVideo/text_encoder_2",
602
+ "models/HunyuanVideo/vae/pytorch_model.pt",
603
+ "models/HunyuanVideo/transformers/model.fp8.safetensors"
604
+ ],
605
+ },
606
+ }
607
+ Preset_model_id: TypeAlias = Literal[
608
+ "HunyuanDiT",
609
+ "stable-video-diffusion-img2vid-xt",
610
+ "ExVideo-SVD-128f-v1",
611
+ "ExVideo-CogVideoX-LoRA-129f-v1",
612
+ "StableDiffusion_v15",
613
+ "DreamShaper_8",
614
+ "AingDiffusion_v12",
615
+ "Flat2DAnimerge_v45Sharp",
616
+ "TextualInversion_VeryBadImageNegative_v1.3",
617
+ "StableDiffusionXL_v1",
618
+ "BluePencilXL_v200",
619
+ "StableDiffusionXL_Turbo",
620
+ "ControlNet_v11f1p_sd15_depth",
621
+ "ControlNet_v11p_sd15_softedge",
622
+ "ControlNet_v11f1e_sd15_tile",
623
+ "ControlNet_v11p_sd15_lineart",
624
+ "AnimateDiff_v2",
625
+ "AnimateDiff_xl_beta",
626
+ "RIFE",
627
+ "BeautifulPrompt",
628
+ "opus-mt-zh-en",
629
+ "IP-Adapter-SD",
630
+ "IP-Adapter-SDXL",
631
+ "StableDiffusion3",
632
+ "StableDiffusion3_without_T5",
633
+ "Kolors",
634
+ "SDXL-vae-fp16-fix",
635
+ "ControlNet_union_sdxl_promax",
636
+ "FLUX.1-dev",
637
+ "FLUX.1-schnell",
638
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
639
+ "jasperai/Flux.1-dev-Controlnet-Depth",
640
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
641
+ "jasperai/Flux.1-dev-Controlnet-Upscaler",
642
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
643
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
644
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
645
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
646
+ "InstantX/FLUX.1-dev-IP-Adapter",
647
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
648
+ "QwenPrompt",
649
+ "OmostPrompt",
650
+ "ESRGAN_x4",
651
+ "RIFE",
652
+ "OmniGen-v1",
653
+ "CogVideoX-5B",
654
+ "Annotators:Depth",
655
+ "Annotators:Softedge",
656
+ "Annotators:Lineart",
657
+ "Annotators:Normal",
658
+ "Annotators:Openpose",
659
+ "StableDiffusion3.5-large",
660
+ "StableDiffusion3.5-medium",
661
+ "HunyuanVideo",
662
+ "HunyuanVideo-fp8",
663
+ "HunyuanVideoI2V",
664
+ ]
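Each preset above maps a Preset_model_id to either a flat list of (model_id, origin_file_path, local_dir) download tuples or, for multi-part models such as HunyuanVideo, a dict with "file_list" and "load_path" keys. A minimal sketch of how such an entry can be resolved into download jobs; download_file is a hypothetical stand-in for the ModelScope/HuggingFace download call used by the actual loader.

def resolve_preset(preset_models: dict, preset_id: str):
    # Entries are either a plain list of (model_id, origin_path, local_dir)
    # tuples or a dict carrying "file_list" and "load_path".
    entry = preset_models[preset_id]
    file_list = entry["file_list"] if isinstance(entry, dict) else entry
    load_path = entry.get("load_path") if isinstance(entry, dict) else None
    for model_id, origin_path, local_dir in file_list:
        download_file(model_id, origin_path, local_dir)  # hypothetical helper
    return load_path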
OmniAvatar/distributed/__init__.py ADDED
File without changes
OmniAvatar/distributed/fsdp.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ from functools import partial
4
+
5
+ import torch
6
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
8
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
9
+ from torch.distributed.utils import _free_storage
10
+
11
+
12
+ def shard_model(
13
+ model,
14
+ device_id,
15
+ param_dtype=torch.bfloat16,
16
+ reduce_dtype=torch.float32,
17
+ buffer_dtype=torch.float32,
18
+ process_group=None,
19
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
20
+ sync_module_states=True,
21
+ ):
22
+ model = FSDP(
23
+ module=model,
24
+ process_group=process_group,
25
+ sharding_strategy=sharding_strategy,
26
+ auto_wrap_policy=partial(
27
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
28
+ mixed_precision=MixedPrecision(
29
+ param_dtype=param_dtype,
30
+ reduce_dtype=reduce_dtype,
31
+ buffer_dtype=buffer_dtype),
32
+ device_id=device_id,
33
+ sync_module_states=sync_module_states)
34
+ return model
35
+
36
+
37
+ def free_model(model):
38
+ for m in model.modules():
39
+ if isinstance(m, FSDP):
40
+ _free_storage(m._handle.flat_param.data)
41
+ del model
42
+ gc.collect()
43
+ torch.cuda.empty_cache()
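shard_model wraps a module in FSDP with FULL_SHARD sharding, bf16 parameters and fp32 reduction/buffers, auto-wrapping each entry of model.blocks; free_model releases the flat-parameter storage afterwards. A minimal usage sketch, assuming torch.distributed has been initialized (e.g. via torchrun) and that the model exposes a blocks ModuleList as the wrap policy requires:

import torch
import torch.distributed as dist
from OmniAvatar.distributed.fsdp import shard_model, free_model

dist.init_process_group("nccl")
device_id = dist.get_rank() % torch.cuda.device_count()
model = build_dit()                              # hypothetical: any module with a .blocks ModuleList
model = shard_model(model, device_id=device_id)  # FULL_SHARD, bf16 params, fp32 reduce
# ... run inference or training on the sharded model ...
free_model(model)                                # release flat-param storage when finished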
OmniAvatar/distributed/xdit_context_parallel.py ADDED
@@ -0,0 +1,134 @@
1
+ import torch
2
+ from typing import Optional
3
+ from einops import rearrange
4
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
5
+ get_sequence_parallel_world_size,
6
+ get_sp_group)
7
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
8
+ from yunchang import LongContextAttention
9
+
10
+ def sinusoidal_embedding_1d(dim, position):
11
+ sinusoid = torch.outer(position.type(torch.float64), torch.pow(
12
+ 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
13
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
14
+ return x.to(position.dtype)
15
+
16
+ def pad_freqs(original_tensor, target_len):
17
+ seq_len, s1, s2 = original_tensor.shape
18
+ pad_size = target_len - seq_len
19
+ padding_tensor = torch.ones(
20
+ pad_size,
21
+ s1,
22
+ s2,
23
+ dtype=original_tensor.dtype,
24
+ device=original_tensor.device)
25
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
26
+ return padded_tensor
27
+
28
+ def rope_apply(x, freqs, num_heads):
29
+ x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
30
+ s_per_rank = x.shape[1]
31
+ s_per_rank = get_sp_group().broadcast_object_list([s_per_rank], src=0)[0] # TODO: the size should be divided by sp_size
32
+
33
+ x_out = torch.view_as_complex(x.to(torch.float64).reshape(
34
+ x.shape[0], x.shape[1], x.shape[2], -1, 2))
35
+
36
+ sp_size = get_sequence_parallel_world_size()
37
+ sp_rank = get_sequence_parallel_rank()
38
+ if freqs.shape[0] % sp_size != 0 and freqs.shape[0] // sp_size == s_per_rank:
39
+ s_per_rank = s_per_rank + 1
40
+ freqs = pad_freqs(freqs, s_per_rank * sp_size)
41
+ freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
42
+ freqs_rank = freqs_rank[:x.shape[1]]
43
+ x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
44
+ return x_out.to(x.dtype)
45
+
46
+ def usp_dit_forward(self,
47
+ x: torch.Tensor,
48
+ timestep: torch.Tensor,
49
+ context: torch.Tensor,
50
+ clip_feature: Optional[torch.Tensor] = None,
51
+ y: Optional[torch.Tensor] = None,
52
+ use_gradient_checkpointing: bool = False,
53
+ use_gradient_checkpointing_offload: bool = False,
54
+ **kwargs,
55
+ ):
56
+ t = self.time_embedding(
57
+ sinusoidal_embedding_1d(self.freq_dim, timestep))
58
+ t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
59
+ context = self.text_embedding(context)
60
+
61
+ if self.has_image_input:
62
+ x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
63
+ clip_embedding = self.img_emb(clip_feature)
64
+ context = torch.cat([clip_embedding, context], dim=1)
65
+
66
+ x, (f, h, w) = self.patchify(x)
67
+
68
+ freqs = torch.cat([
69
+ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
70
+ self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
71
+ self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
72
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
73
+
74
+ def create_custom_forward(module):
75
+ def custom_forward(*inputs):
76
+ return module(*inputs)
77
+ return custom_forward
78
+
79
+ # Context Parallel
80
+ x = torch.chunk(
81
+ x, get_sequence_parallel_world_size(),
82
+ dim=1)[get_sequence_parallel_rank()]
83
+
84
+ for block in self.blocks:
85
+ if self.training and use_gradient_checkpointing:
86
+ if use_gradient_checkpointing_offload:
87
+ with torch.autograd.graph.save_on_cpu():
88
+ x = torch.utils.checkpoint.checkpoint(
89
+ create_custom_forward(block),
90
+ x, context, t_mod, freqs,
91
+ use_reentrant=False,
92
+ )
93
+ else:
94
+ x = torch.utils.checkpoint.checkpoint(
95
+ create_custom_forward(block),
96
+ x, context, t_mod, freqs,
97
+ use_reentrant=False,
98
+ )
99
+ else:
100
+ x = block(x, context, t_mod, freqs)
101
+
102
+ x = self.head(x, t)
103
+
104
+ # Context Parallel
105
+ if x.shape[1] * get_sequence_parallel_world_size() < freqs.shape[0]:
106
+ x = torch.cat([x, x[:, -1:]], 1) # TODO: this may introduce some bias; using sp_size=2 would avoid it
107
+ x = get_sp_group().all_gather(x, dim=1) # TODO: the size should be divided by sp_size
108
+ x = x[:, :freqs.shape[0]]
109
+
110
+ # unpatchify
111
+ x = self.unpatchify(x, (f, h, w))
112
+ return x
113
+
114
+
115
+ def usp_attn_forward(self, x, freqs):
116
+ q = self.norm_q(self.q(x))
117
+ k = self.norm_k(self.k(x))
118
+ v = self.v(x)
119
+
120
+ q = rope_apply(q, freqs, self.num_heads)
121
+ k = rope_apply(k, freqs, self.num_heads)
122
+ q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
123
+ k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
124
+ v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
125
+
126
+ x = xFuserLongContextAttention()(
127
+ None,
128
+ query=q,
129
+ key=k,
130
+ value=v,
131
+ )
132
+ x = x.flatten(2)
133
+
134
+ return self.o(x)
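usp_dit_forward splits the token sequence across sequence-parallel ranks with torch.chunk(..., dim=1) and reassembles it with all_gather after the final block, cropping back to the original length (pad_freqs keeps the RoPE table aligned when the length is not divisible by the world size). A toy single-process sketch of that chunk/pad/gather bookkeeping, with sp_size standing in for get_sequence_parallel_world_size():

import torch

sp_size, seq_len, dim = 4, 10, 8
x = torch.randn(1, seq_len, dim)

pad = (sp_size - seq_len % sp_size) % sp_size              # 2 padding tokens needed
x_padded = torch.cat([x, torch.zeros(1, pad, dim)], dim=1)
shards = torch.chunk(x_padded, sp_size, dim=1)             # one equal shard per rank
assert all(s.shape[1] == (seq_len + pad) // sp_size for s in shards)

x_gathered = torch.cat(shards, dim=1)[:, :seq_len]         # emulates all_gather + crop
assert torch.equal(x_gathered, x)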
OmniAvatar/models/audio_pack.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch
2
+ from typing import Tuple, Union
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import nn
6
+
7
+
8
+ def make_triple(value: Union[int, Tuple[int, int, int]]) -> Tuple[int, int, int]:
9
+ value = (value,) * 3 if isinstance(value, int) else value
10
+ assert len(value) == 3
11
+ return value
12
+
13
+
14
+ class AudioPack(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_channels: int,
18
+ patch_size: Union[int, Tuple[int, int, int]],
19
+ dim: int,
20
+ layernorm=False,
21
+ ):
22
+ super().__init__()
23
+ t, h, w = make_triple(patch_size)
24
+ self.patch_size = t, h, w
25
+ self.proj = nn.Linear(in_channels * t * h * w, dim)
26
+ if layernorm:
27
+ self.norm_out = nn.LayerNorm(dim)
28
+ else:
29
+ self.norm_out = None
30
+
31
+ def forward(
32
+ self,
33
+ vid: torch.Tensor,
34
+ ) -> torch.Tensor:
35
+ t, h, w = self.patch_size
36
+ vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w)
37
+ vid = self.proj(vid)
38
+ if self.norm_out is not None:
39
+ vid = self.norm_out(vid)
40
+ return vid
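AudioPack folds patch_size[0] consecutive frames (and any spatial patching) of a (b, c, T, H, W) feature map into one token before the linear projection. A shape sketch with illustrative values only:

import torch
from OmniAvatar.models.audio_pack import AudioPack

pack = AudioPack(in_channels=32, patch_size=(4, 1, 1), dim=128, layernorm=True)
feat = torch.randn(2, 32, 44, 1, 1)   # (b, c, T, H, W); T must be divisible by 4
out = pack(feat)
print(out.shape)                      # torch.Size([2, 11, 1, 1, 128])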
OmniAvatar/models/model_manager.py ADDED
@@ -0,0 +1,474 @@
1
+ import os, torch, json, importlib
2
+ from typing import List
3
+ import torch.nn as nn
4
+ from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs
5
+ from ..utils.io_utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix, smart_load_weights
6
+
7
+ class GeneralLoRAFromPeft:
8
+
9
+ def get_name_dict(self, lora_state_dict):
10
+ lora_name_dict = {}
11
+ for key in lora_state_dict:
12
+ if ".lora_B." not in key:
13
+ continue
14
+ keys = key.split(".")
15
+ if len(keys) > keys.index("lora_B") + 2:
16
+ keys.pop(keys.index("lora_B") + 1)
17
+ keys.pop(keys.index("lora_B"))
18
+ if keys[0] == "diffusion_model":
19
+ keys.pop(0)
20
+ target_name = ".".join(keys)
21
+ lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
22
+ return lora_name_dict
23
+
24
+
25
+ def match(self, model: torch.nn.Module, state_dict_lora):
26
+ lora_name_dict = self.get_name_dict(state_dict_lora)
27
+ model_name_dict = {name: None for name, _ in model.named_parameters()}
28
+ matched_num = sum([i in model_name_dict for i in lora_name_dict])
29
+ if matched_num == len(lora_name_dict):
30
+ return "", ""
31
+ else:
32
+ return None
33
+
34
+
35
+ def fetch_device_and_dtype(self, state_dict):
36
+ device, dtype = None, None
37
+ for name, param in state_dict.items():
38
+ device, dtype = param.device, param.dtype
39
+ break
40
+ computation_device = device
41
+ computation_dtype = dtype
42
+ if computation_device == torch.device("cpu"):
43
+ if torch.cuda.is_available():
44
+ computation_device = torch.device("cuda")
45
+ if computation_dtype == torch.float8_e4m3fn:
46
+ computation_dtype = torch.float32
47
+ return device, dtype, computation_device, computation_dtype
48
+
49
+
50
+ def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
51
+ state_dict_model = model.state_dict()
52
+ device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
53
+ lora_name_dict = self.get_name_dict(state_dict_lora)
54
+ for name in lora_name_dict:
55
+ weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
56
+ weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
57
+ if len(weight_up.shape) == 4:
58
+ weight_up = weight_up.squeeze(3).squeeze(2)
59
+ weight_down = weight_down.squeeze(3).squeeze(2)
60
+ weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
61
+ else:
62
+ weight_lora = alpha * torch.mm(weight_up, weight_down)
63
+ weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
64
+ weight_patched = weight_model + weight_lora
65
+ state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
66
+ print(f" {len(lora_name_dict)} tensors are updated.")
67
+ model.load_state_dict(state_dict_model)
68
+
69
+
70
+ def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device, infer):
71
+ loaded_model_names, loaded_models = [], []
72
+ for model_name, model_class in zip(model_names, model_classes):
73
+ print(f" model_name: {model_name} model_class: {model_class.__name__}")
74
+ state_dict_converter = model_class.state_dict_converter()
75
+ if model_resource == "civitai":
76
+ state_dict_results = state_dict_converter.from_civitai(state_dict)
77
+ elif model_resource == "diffusers":
78
+ state_dict_results = state_dict_converter.from_diffusers(state_dict)
79
+ if isinstance(state_dict_results, tuple):
80
+ model_state_dict, extra_kwargs = state_dict_results
81
+ print(f" This model is initialized with extra kwargs: {extra_kwargs}")
82
+ else:
83
+ model_state_dict, extra_kwargs = state_dict_results, {}
84
+ torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
85
+ with init_weights_on_device():
86
+ model = model_class(**extra_kwargs)
87
+ if hasattr(model, "eval"):
88
+ model = model.eval()
89
+ if not infer: # only initialize weights when training
90
+ model = model.to_empty(device=torch.device("cuda"))
91
+ for name, param in model.named_parameters():
92
+ if param.dim() > 1: # typically only initialize weight matrices, not biases
93
+ nn.init.xavier_uniform_(param, gain=0.05)
94
+ else:
95
+ nn.init.zeros_(param)
96
+ else:
97
+ model = model.to_empty(device=device)
98
+ model, _, _ = smart_load_weights(model, model_state_dict)
99
+ # model.load_state_dict(model_state_dict, assign=True, strict=False)
100
+ model = model.to(dtype=torch_dtype, device=device)
101
+ loaded_model_names.append(model_name)
102
+ loaded_models.append(model)
103
+ return loaded_model_names, loaded_models
104
+
105
+
106
+ def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
107
+ loaded_model_names, loaded_models = [], []
108
+ for model_name, model_class in zip(model_names, model_classes):
109
+ if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
110
+ model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
111
+ else:
112
+ model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
113
+ if torch_dtype == torch.float16 and hasattr(model, "half"):
114
+ model = model.half()
115
+ try:
116
+ model = model.to(device=device)
117
+ except Exception:
118
+ pass
119
+ loaded_model_names.append(model_name)
120
+ loaded_models.append(model)
121
+ return loaded_model_names, loaded_models
122
+
123
+
124
+ def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
125
+ print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
126
+ base_state_dict = base_model.state_dict()
127
+ base_model.to("cpu")
128
+ del base_model
129
+ model = model_class(**extra_kwargs)
130
+ model.load_state_dict(base_state_dict, strict=False)
131
+ model.load_state_dict(state_dict, strict=False)
132
+ model.to(dtype=torch_dtype, device=device)
133
+ return model
134
+
135
+
136
+ def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
137
+ loaded_model_names, loaded_models = [], []
138
+ for model_name, model_class in zip(model_names, model_classes):
139
+ while True:
140
+ for model_id in range(len(model_manager.model)):
141
+ base_model_name = model_manager.model_name[model_id]
142
+ if base_model_name == model_name:
143
+ base_model_path = model_manager.model_path[model_id]
144
+ base_model = model_manager.model[model_id]
145
+ print(f" Adding patch model to {base_model_name} ({base_model_path})")
146
+ patched_model = load_single_patch_model_from_single_file(
147
+ state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
148
+ loaded_model_names.append(base_model_name)
149
+ loaded_models.append(patched_model)
150
+ model_manager.model.pop(model_id)
151
+ model_manager.model_path.pop(model_id)
152
+ model_manager.model_name.pop(model_id)
153
+ break
154
+ else:
155
+ break
156
+ return loaded_model_names, loaded_models
157
+
158
+
159
+
160
+ class ModelDetectorTemplate:
161
+ def __init__(self):
162
+ pass
163
+
164
+ def match(self, file_path="", state_dict={}):
165
+ return False
166
+
167
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
168
+ return [], []
169
+
170
+
171
+
172
+ class ModelDetectorFromSingleFile:
173
+ def __init__(self, model_loader_configs=[]):
174
+ self.keys_hash_with_shape_dict = {}
175
+ self.keys_hash_dict = {}
176
+ for metadata in model_loader_configs:
177
+ self.add_model_metadata(*metadata)
178
+
179
+
180
+ def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
181
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
182
+ if keys_hash is not None:
183
+ self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
184
+
185
+
186
+ def match(self, file_path="", state_dict={}):
187
+ if isinstance(file_path, str) and os.path.isdir(file_path):
188
+ return False
189
+ if len(state_dict) == 0:
190
+ state_dict = load_state_dict(file_path)
191
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
192
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
193
+ return True
194
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
195
+ if keys_hash in self.keys_hash_dict:
196
+ return True
197
+ return False
198
+
199
+
200
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, infer=False, **kwargs):
201
+ if len(state_dict) == 0:
202
+ state_dict = load_state_dict(file_path)
203
+
204
+ # Load models with strict matching
205
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
206
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
207
+ model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
208
+ loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device, infer)
209
+ return loaded_model_names, loaded_models
210
+
211
+ # Load models without strict matching
212
+ # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
213
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
214
+ if keys_hash in self.keys_hash_dict:
215
+ model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
216
+ loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device, infer)
217
+ return loaded_model_names, loaded_models
218
+
219
+ return [], []  # nothing matched; avoid referencing unbound locals
220
+
221
+
222
+
223
+ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
224
+ def __init__(self, model_loader_configs=[]):
225
+ super().__init__(model_loader_configs)
226
+
227
+
228
+ def match(self, file_path="", state_dict={}):
229
+ if isinstance(file_path, str) and os.path.isdir(file_path):
230
+ return False
231
+ if len(state_dict) == 0:
232
+ state_dict = load_state_dict(file_path)
233
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
234
+ for sub_state_dict in splited_state_dict:
235
+ if super().match(file_path, sub_state_dict):
236
+ return True
237
+ return False
238
+
239
+
240
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
241
+ # Split the state_dict and load from each component
242
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
243
+ valid_state_dict = {}
244
+ for sub_state_dict in splited_state_dict:
245
+ if super().match(file_path, sub_state_dict):
246
+ valid_state_dict.update(sub_state_dict)
247
+ if super().match(file_path, valid_state_dict):
248
+ loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
249
+ else:
250
+ loaded_model_names, loaded_models = [], []
251
+ for sub_state_dict in splited_state_dict:
252
+ if super().match(file_path, sub_state_dict):
253
+ loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
254
+ loaded_model_names += loaded_model_names_
255
+ loaded_models += loaded_models_
256
+ return loaded_model_names, loaded_models
257
+
258
+
259
+
260
+ class ModelDetectorFromHuggingfaceFolder:
261
+ def __init__(self, model_loader_configs=[]):
262
+ self.architecture_dict = {}
263
+ for metadata in model_loader_configs:
264
+ self.add_model_metadata(*metadata)
265
+
266
+
267
+ def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
268
+ self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
269
+
270
+
271
+ def match(self, file_path="", state_dict={}):
272
+ if not isinstance(file_path, str) or os.path.isfile(file_path):
273
+ return False
274
+ file_list = os.listdir(file_path)
275
+ if "config.json" not in file_list:
276
+ return False
277
+ with open(os.path.join(file_path, "config.json"), "r") as f:
278
+ config = json.load(f)
279
+ if "architectures" not in config and "_class_name" not in config:
280
+ return False
281
+ return True
282
+
283
+
284
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
285
+ with open(os.path.join(file_path, "config.json"), "r") as f:
286
+ config = json.load(f)
287
+ loaded_model_names, loaded_models = [], []
288
+ architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
289
+ for architecture in architectures:
290
+ huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
291
+ if redirected_architecture is not None:
292
+ architecture = redirected_architecture
293
+ model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
294
+ loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
295
+ loaded_model_names += loaded_model_names_
296
+ loaded_models += loaded_models_
297
+ return loaded_model_names, loaded_models
298
+
299
+
300
+
301
+ class ModelDetectorFromPatchedSingleFile:
302
+ def __init__(self, model_loader_configs=[]):
303
+ self.keys_hash_with_shape_dict = {}
304
+ for metadata in model_loader_configs:
305
+ self.add_model_metadata(*metadata)
306
+
307
+
308
+ def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
309
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
310
+
311
+
312
+ def match(self, file_path="", state_dict={}):
313
+ if not isinstance(file_path, str) or os.path.isdir(file_path):
314
+ return False
315
+ if len(state_dict) == 0:
316
+ state_dict = load_state_dict(file_path)
317
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
318
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
319
+ return True
320
+ return False
321
+
322
+
323
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
324
+ if len(state_dict) == 0:
325
+ state_dict = load_state_dict(file_path)
326
+
327
+ # Load models with strict matching
328
+ loaded_model_names, loaded_models = [], []
329
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
330
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
331
+ model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
332
+ loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
333
+ state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
334
+ loaded_model_names += loaded_model_names_
335
+ loaded_models += loaded_models_
336
+ return loaded_model_names, loaded_models
337
+
338
+
339
+
340
+ class ModelManager:
341
+ def __init__(
342
+ self,
343
+ torch_dtype=torch.float16,
344
+ device="cuda",
345
+ model_id_list: List = [],
346
+ downloading_priority: List = ["ModelScope", "HuggingFace"],
347
+ file_path_list: List[str] = [],
348
+ infer: bool = False
349
+ ):
350
+ self.torch_dtype = torch_dtype
351
+ self.device = device
352
+ self.model = []
353
+ self.model_path = []
354
+ self.model_name = []
355
+ self.infer = infer
356
+ downloaded_files = []
357
+ self.model_detector = [
358
+ ModelDetectorFromSingleFile(model_loader_configs),
359
+ ModelDetectorFromSplitedSingleFile(model_loader_configs),
360
+ ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
361
+ ]
362
+ self.load_models(downloaded_files + file_path_list)
363
+
364
+ def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
365
+ if isinstance(file_path, list):
366
+ for file_path_ in file_path:
367
+ self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
368
+ else:
369
+ print(f"Loading LoRA models from file: {file_path}")
370
+ is_loaded = False
371
+ if len(state_dict) == 0:
372
+ state_dict = load_state_dict(file_path)
373
+ for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
374
+ lora = GeneralLoRAFromPeft()
375
+ match_results = lora.match(model, state_dict)
376
+ if match_results is not None:
377
+ print(f" Adding LoRA to {model_name} ({model_path}).")
378
+ lora_prefix, model_resource = match_results
379
+ lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
380
+
381
+
382
+
383
+ def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
384
+ print(f"Loading models from file: {file_path}")
385
+ if len(state_dict) == 0:
386
+ state_dict = load_state_dict(file_path)
387
+ model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device, self.infer)
388
+ for model_name, model in zip(model_names, models):
389
+ self.model.append(model)
390
+ self.model_path.append(file_path)
391
+ self.model_name.append(model_name)
392
+ print(f" The following models are loaded: {model_names}.")
393
+
394
+
395
+ def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
396
+ print(f"Loading models from folder: {file_path}")
397
+ model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
398
+ for model_name, model in zip(model_names, models):
399
+ self.model.append(model)
400
+ self.model_path.append(file_path)
401
+ self.model_name.append(model_name)
402
+ print(f" The following models are loaded: {model_names}.")
403
+
404
+
405
+ def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
406
+ print(f"Loading patch models from file: {file_path}")
407
+ model_names, models = load_patch_model_from_single_file(
408
+ state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
409
+ for model_name, model in zip(model_names, models):
410
+ self.model.append(model)
411
+ self.model_path.append(file_path)
412
+ self.model_name.append(model_name)
413
+ print(f" The following patched models are loaded: {model_names}.")
414
+
415
+ def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
416
+ print(f"Loading models from: {file_path}")
417
+ if device is None: device = self.device
418
+ if torch_dtype is None: torch_dtype = self.torch_dtype
419
+ if isinstance(file_path, list):
420
+ state_dict = {}
421
+ for path in file_path:
422
+ state_dict.update(load_state_dict(path))
423
+ elif os.path.isfile(file_path):
424
+ state_dict = load_state_dict(file_path)
425
+ else:
426
+ state_dict = None
427
+ for model_detector in self.model_detector:
428
+ if model_detector.match(file_path, state_dict):
429
+ model_names, models = model_detector.load(
430
+ file_path, state_dict,
431
+ device=device, torch_dtype=torch_dtype,
432
+ allowed_model_names=model_names, model_manager=self, infer=self.infer
433
+ )
434
+ for model_name, model in zip(model_names, models):
435
+ self.model.append(model)
436
+ self.model_path.append(file_path)
437
+ self.model_name.append(model_name)
438
+ print(f" The following models are loaded: {model_names}.")
439
+ break
440
+ else:
441
+ print(f" We cannot detect the model type. No models are loaded.")
442
+
443
+
444
+ def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
445
+ for file_path in file_path_list:
446
+ self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
447
+
448
+
449
+ def fetch_model(self, model_name, file_path=None, require_model_path=False):
450
+ fetched_models = []
451
+ fetched_model_paths = []
452
+ for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
453
+ if file_path is not None and file_path != model_path:
454
+ continue
455
+ if model_name == model_name_:
456
+ fetched_models.append(model)
457
+ fetched_model_paths.append(model_path)
458
+ if len(fetched_models) == 0:
459
+ print(f"No {model_name} models available.")
460
+ return None
461
+ if len(fetched_models) == 1:
462
+ print(f"Using {model_name} from {fetched_model_paths[0]}.")
463
+ else:
464
+ print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
465
+ if require_model_path:
466
+ return fetched_models[0], fetched_model_paths[0]
467
+ else:
468
+ return fetched_models[0]
469
+
470
+
471
+ def to(self, device):
472
+ for model in self.model:
473
+ model.to(device)
474
+
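ModelManager identifies each checkpoint by hashing its state-dict keys, loads it through the matching detector/converter, and hands the result back via fetch_model. A minimal sketch of the intended call pattern; the file paths are placeholders, the model-name string must match whatever is registered in configs/model_config.py, and the package's global args/config is assumed to be initialized since some model modules read it at import time.

import torch
from OmniAvatar.models.model_manager import ModelManager

manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu", infer=True)
manager.load_models([
    "pretrained_models/dit.safetensors",           # placeholder path
    "pretrained_models/text_encoder.safetensors",  # placeholder path
])
dit = manager.fetch_model("wan_video_dit")         # assumed registry name from model_loader_configs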
OmniAvatar/models/wan_video_dit.py ADDED
@@ -0,0 +1,577 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from typing import Tuple, Optional
6
+ from einops import rearrange
7
+ from ..utils.io_utils import hash_state_dict_keys
8
+ from .audio_pack import AudioPack
9
+ from ..utils.args_config import args
10
+
11
+ if args.sp_size > 1:
12
+ # Context Parallel
13
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
14
+ get_sequence_parallel_world_size,
15
+ get_sp_group)
16
+
17
+ try:
18
+ import flash_attn_interface
19
+ print('using flash_attn_interface')
20
+ FLASH_ATTN_3_AVAILABLE = True
21
+ except ModuleNotFoundError:
22
+ FLASH_ATTN_3_AVAILABLE = False
23
+
24
+ try:
25
+ import flash_attn
26
+ print('using flash_attn')
27
+ FLASH_ATTN_2_AVAILABLE = True
28
+ except ModuleNotFoundError:
29
+ FLASH_ATTN_2_AVAILABLE = False
30
+
31
+ try:
32
+ from sageattention import sageattn
33
+ print('using sageattention')
34
+ SAGE_ATTN_AVAILABLE = True
35
+ except ModuleNotFoundError:
36
+ SAGE_ATTN_AVAILABLE = False
37
+
38
+
39
+ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
40
+ if compatibility_mode:
41
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
42
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
43
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
44
+ x = F.scaled_dot_product_attention(q, k, v)
45
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
46
+ elif FLASH_ATTN_3_AVAILABLE:
47
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
48
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
49
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
50
+ x = flash_attn_interface.flash_attn_func(q, k, v)
51
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
52
+ elif FLASH_ATTN_2_AVAILABLE:
53
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
54
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
55
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
56
+ x = flash_attn.flash_attn_func(q, k, v)
57
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
58
+ elif SAGE_ATTN_AVAILABLE:
59
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
60
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
61
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
62
+ x = sageattn(q, k, v)
63
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
64
+ else:
65
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
66
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
67
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
68
+ x = F.scaled_dot_product_attention(q, k, v)
69
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
70
+ return x
71
+
72
+
73
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
74
+ return (x * (1 + scale) + shift)
75
+
76
+
77
+ def sinusoidal_embedding_1d(dim, position):
78
+ sinusoid = torch.outer(position.type(torch.float64), torch.pow(
79
+ 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
80
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
81
+ return x.to(position.dtype)
82
+
83
+
84
+ def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
85
+ # 3d rope precompute
86
+ f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
87
+ h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
88
+ w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
89
+ return f_freqs_cis, h_freqs_cis, w_freqs_cis
90
+
91
+
92
+ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
93
+ # 1d rope precompute
94
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
95
+ [: (dim // 2)].double() / dim))
96
+ freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)
97
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
98
+ return freqs_cis
99
+
100
+
101
+ def rope_apply(x, freqs, num_heads):
102
+ x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
103
+ x_out = torch.view_as_complex(x.to(torch.float64).reshape(
104
+ x.shape[0], x.shape[1], x.shape[2], -1, 2))
105
+ x_out = torch.view_as_real(x_out * freqs).flatten(2)
106
+ return x_out.to(x.dtype)
107
+
108
+
109
+ class RMSNorm(nn.Module):
110
+ def __init__(self, dim, eps=1e-5):
111
+ super().__init__()
112
+ self.eps = eps
113
+ self.weight = nn.Parameter(torch.ones(dim))
114
+
115
+ def norm(self, x):
116
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
117
+
118
+ def forward(self, x):
119
+ dtype = x.dtype
120
+ return self.norm(x.float()).to(dtype) * self.weight
121
+
122
+
123
+ class AttentionModule(nn.Module):
124
+ def __init__(self, num_heads):
125
+ super().__init__()
126
+ self.num_heads = num_heads
127
+
128
+ def forward(self, q, k, v):
129
+ x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
130
+ return x
131
+
132
+
133
+ class SelfAttention(nn.Module):
134
+ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
135
+ super().__init__()
136
+ self.dim = dim
137
+ self.num_heads = num_heads
138
+ self.head_dim = dim // num_heads
139
+
140
+ self.q = nn.Linear(dim, dim)
141
+ self.k = nn.Linear(dim, dim)
142
+ self.v = nn.Linear(dim, dim)
143
+ self.o = nn.Linear(dim, dim)
144
+ self.norm_q = RMSNorm(dim, eps=eps)
145
+ self.norm_k = RMSNorm(dim, eps=eps)
146
+
147
+ self.attn = AttentionModule(self.num_heads)
148
+
149
+ def forward(self, x, freqs):
150
+ q = self.norm_q(self.q(x))
151
+ k = self.norm_k(self.k(x))
152
+ v = self.v(x)
153
+ q = rope_apply(q, freqs, self.num_heads)
154
+ k = rope_apply(k, freqs, self.num_heads)
155
+ x = self.attn(q, k, v)
156
+ return self.o(x)
157
+
158
+
159
+ class CrossAttention(nn.Module):
160
+ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
161
+ super().__init__()
162
+ self.dim = dim
163
+ self.num_heads = num_heads
164
+ self.head_dim = dim // num_heads
165
+
166
+ self.q = nn.Linear(dim, dim)
167
+ self.k = nn.Linear(dim, dim)
168
+ self.v = nn.Linear(dim, dim)
169
+ self.o = nn.Linear(dim, dim)
170
+ self.norm_q = RMSNorm(dim, eps=eps)
171
+ self.norm_k = RMSNorm(dim, eps=eps)
172
+ self.has_image_input = has_image_input
173
+ if has_image_input:
174
+ self.k_img = nn.Linear(dim, dim)
175
+ self.v_img = nn.Linear(dim, dim)
176
+ self.norm_k_img = RMSNorm(dim, eps=eps)
177
+
178
+ self.attn = AttentionModule(self.num_heads)
179
+
180
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
181
+ if self.has_image_input:
182
+ img = y[:, :257]
183
+ ctx = y[:, 257:]
184
+ else:
185
+ ctx = y
186
+ q = self.norm_q(self.q(x))
187
+ k = self.norm_k(self.k(ctx))
188
+ v = self.v(ctx)
189
+ x = self.attn(q, k, v)
190
+ if self.has_image_input:
191
+ k_img = self.norm_k_img(self.k_img(img))
192
+ v_img = self.v_img(img)
193
+ y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
194
+ x = x + y
195
+ return self.o(x)
196
+
197
+
198
+ class GateModule(nn.Module):
199
+ def __init__(self,):
200
+ super().__init__()
201
+
202
+ def forward(self, x, gate, residual):
203
+ return x + gate * residual
204
+
205
+ class DiTBlock(nn.Module):
206
+ def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
207
+ super().__init__()
208
+ self.dim = dim
209
+ self.num_heads = num_heads
210
+ self.ffn_dim = ffn_dim
211
+
212
+ self.self_attn = SelfAttention(dim, num_heads, eps)
213
+ self.cross_attn = CrossAttention(
214
+ dim, num_heads, eps, has_image_input=has_image_input)
215
+ self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
216
+ self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
217
+ self.norm3 = nn.LayerNorm(dim, eps=eps)
218
+ self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
219
+ approximate='tanh'), nn.Linear(ffn_dim, dim))
220
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
221
+ self.gate = GateModule()
222
+
223
+ def forward(self, x, context, t_mod, freqs):
224
+ # msa: multi-head self-attention mlp: multi-layer perceptron
225
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
226
+ self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
227
+ input_x = modulate(self.norm1(x), shift_msa, scale_msa)
228
+ x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
229
+ x = x + self.cross_attn(self.norm3(x), context)
230
+ input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
231
+ x = self.gate(x, gate_mlp, self.ffn(input_x))
232
+ return x
233
+
234
+
235
+ class MLP(torch.nn.Module):
236
+ def __init__(self, in_dim, out_dim):
237
+ super().__init__()
238
+ self.proj = torch.nn.Sequential(
239
+ nn.LayerNorm(in_dim),
240
+ nn.Linear(in_dim, in_dim),
241
+ nn.GELU(),
242
+ nn.Linear(in_dim, out_dim),
243
+ nn.LayerNorm(out_dim)
244
+ )
245
+
246
+ def forward(self, x):
247
+ return self.proj(x)
248
+
249
+
250
+ class Head(nn.Module):
251
+ def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
252
+ super().__init__()
253
+ self.dim = dim
254
+ self.patch_size = patch_size
255
+ self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
256
+ self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
257
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
258
+
259
+ def forward(self, x, t_mod):
260
+ shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
261
+ x = (self.head(self.norm(x) * (1 + scale) + shift))
262
+ return x
263
+
264
+
265
+
266
+ class WanModel(torch.nn.Module):
267
+ def __init__(
268
+ self,
269
+ dim: int,
270
+ in_dim: int,
271
+ ffn_dim: int,
272
+ out_dim: int,
273
+ text_dim: int,
274
+ freq_dim: int,
275
+ eps: float,
276
+ patch_size: Tuple[int, int, int],
277
+ num_heads: int,
278
+ num_layers: int,
279
+ has_image_input: bool,
280
+ audio_hidden_size: int=32,
281
+ ):
282
+ super().__init__()
283
+ self.dim = dim
284
+ self.freq_dim = freq_dim
285
+ self.has_image_input = has_image_input
286
+ self.patch_size = patch_size
287
+
288
+ self.patch_embedding = nn.Conv3d(
289
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
290
+ # nn.LayerNorm(dim)
291
+ self.text_embedding = nn.Sequential(
292
+ nn.Linear(text_dim, dim),
293
+ nn.GELU(approximate='tanh'),
294
+ nn.Linear(dim, dim)
295
+ )
296
+ self.time_embedding = nn.Sequential(
297
+ nn.Linear(freq_dim, dim),
298
+ nn.SiLU(),
299
+ nn.Linear(dim, dim)
300
+ )
301
+ self.time_projection = nn.Sequential(
302
+ nn.SiLU(), nn.Linear(dim, dim * 6))
303
+ self.blocks = nn.ModuleList([
304
+ DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
305
+ for _ in range(num_layers)
306
+ ])
307
+ self.head = Head(dim, out_dim, patch_size, eps)
308
+ head_dim = dim // num_heads
309
+ self.freqs = precompute_freqs_cis_3d(head_dim)
310
+
311
+ if has_image_input:
312
+ self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
313
+
314
+ if 'use_audio' in args:
315
+ self.use_audio = args.use_audio
316
+ else:
317
+ self.use_audio = False
318
+ if self.use_audio:
319
+ audio_input_dim = 10752
320
+ audio_out_dim = dim
321
+ self.audio_proj = AudioPack(audio_input_dim, [4, 1, 1], audio_hidden_size, layernorm=True)
322
+ self.audio_cond_projs = nn.ModuleList()
323
+ for d in range(num_layers // 2 - 1):
324
+ l = nn.Linear(audio_hidden_size, audio_out_dim)
325
+ self.audio_cond_projs.append(l)
326
+
327
+ def patchify(self, x: torch.Tensor):
328
+ grid_size = x.shape[2:]
329
+ x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
330
+ return x, grid_size # x, grid_size: (f, h, w)
331
+
332
+ def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
333
+ return rearrange(
334
+ x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
335
+ f=grid_size[0], h=grid_size[1], w=grid_size[2],
336
+ x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
337
+ )
338
+
339
+ def forward(self,
340
+ x: torch.Tensor,
341
+ timestep: torch.Tensor,
342
+ context: torch.Tensor,
343
+ clip_feature: Optional[torch.Tensor] = None,
344
+ y: Optional[torch.Tensor] = None,
345
+ use_gradient_checkpointing: bool = False,
346
+ audio_emb: Optional[torch.Tensor] = None,
347
+ use_gradient_checkpointing_offload: bool = False,
348
+ tea_cache = None,
349
+ **kwargs,
350
+ ):
351
+ t = self.time_embedding(
352
+ sinusoidal_embedding_1d(self.freq_dim, timestep))
353
+ t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
354
+ context = self.text_embedding(context)
355
+ lat_h, lat_w = x.shape[-2], x.shape[-1]
356
+
357
+ if audio_emb is not None and self.use_audio: # TODO cache
358
+ audio_emb = audio_emb.permute(0, 2, 1)[:, :, :, None, None]
359
+ audio_emb = torch.cat([audio_emb[:, :, :1].repeat(1, 1, 3, 1, 1), audio_emb], 2) # 1, 768, 44, 1, 1
360
+ audio_emb = self.audio_proj(audio_emb)
361
+
362
+ audio_emb = torch.concat([audio_cond_proj(audio_emb) for audio_cond_proj in self.audio_cond_projs], 0)
363
+
364
+ x = torch.cat([x, y], dim=1)
365
+ x = self.patch_embedding(x)
366
+ x, (f, h, w) = self.patchify(x)
367
+
368
+ freqs = torch.cat([
369
+ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
370
+ self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
371
+ self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
372
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
373
+
374
+ def create_custom_forward(module):
375
+ def custom_forward(*inputs):
376
+ return module(*inputs)
377
+ return custom_forward
378
+
379
+ if tea_cache is not None:
380
+ tea_cache_update = tea_cache.check(self, x, t_mod)
381
+ else:
382
+ tea_cache_update = False
383
+ ori_x_len = x.shape[1]
384
+ if tea_cache_update:
385
+ x = tea_cache.update(x)
386
+ else:
387
+ if args.sp_size > 1:
388
+ # Context Parallel
389
+ sp_size = get_sequence_parallel_world_size()
390
+ pad_size = 0
391
+ if ori_x_len % sp_size != 0:
392
+ pad_size = sp_size - ori_x_len % sp_size
393
+ x = torch.cat([x, torch.zeros_like(x[:, -1:]).repeat(1, pad_size, 1)], 1)
394
+ x = torch.chunk(x, sp_size, dim=1)[get_sequence_parallel_rank()]
395
+
396
+ audio_emb = audio_emb.reshape(x.shape[0], audio_emb.shape[0] // x.shape[0], -1, *audio_emb.shape[2:])
397
+
398
+ for layer_i, block in enumerate(self.blocks):
399
+ # audio cond
400
+ if self.use_audio:
401
+ au_idx = None
402
+ if (layer_i <= len(self.blocks) // 2 and layer_i > 1): # < len(self.blocks) - 1:
403
+ au_idx = layer_i - 2
404
+ audio_emb_tmp = audio_emb[:, au_idx].repeat(1, 1, lat_h // 2, lat_w // 2, 1) # 1, 11, 45, 25, 128
405
+ audio_cond_tmp = self.patchify(audio_emb_tmp.permute(0, 4, 1, 2, 3))[0]
406
+ if args.sp_size > 1:
407
+ if pad_size > 0:
408
+ audio_cond_tmp = torch.cat([audio_cond_tmp, torch.zeros_like(audio_cond_tmp[:, -1:]).repeat(1, pad_size, 1)], 1)
409
+ audio_cond_tmp = torch.chunk(audio_cond_tmp, sp_size, dim=1)[get_sequence_parallel_rank()]
410
+ x = audio_cond_tmp + x
411
+
412
+ if self.training and use_gradient_checkpointing:
413
+ if use_gradient_checkpointing_offload:
414
+ with torch.autograd.graph.save_on_cpu():
415
+ x = torch.utils.checkpoint.checkpoint(
416
+ create_custom_forward(block),
417
+ x, context, t_mod, freqs,
418
+ use_reentrant=False,
419
+ )
420
+ else:
421
+ x = torch.utils.checkpoint.checkpoint(
422
+ create_custom_forward(block),
423
+ x, context, t_mod, freqs,
424
+ use_reentrant=False,
425
+ )
426
+ else:
427
+ x = block(x, context, t_mod, freqs)
428
+ if tea_cache is not None:
429
+ x_cache = get_sp_group().all_gather(x, dim=1) # TODO: the size should be divided by sp_size
430
+ x_cache = x_cache[:, :ori_x_len]
431
+ tea_cache.store(x_cache)
432
+
433
+ x = self.head(x, t)
434
+ if args.sp_size > 1:
435
+ # Context Parallel
436
+ x = get_sp_group().all_gather(x, dim=1) # TODO: the size should be divided by sp_size
437
+ x = x[:, :ori_x_len]
438
+
439
+ x = self.unpatchify(x, (f, h, w))
440
+ return x
441
+
442
+ @staticmethod
443
+ def state_dict_converter():
444
+ return WanModelStateDictConverter()
445
+
446
+
447
+ class WanModelStateDictConverter:
448
+ def __init__(self):
449
+ pass
450
+
451
+ def from_diffusers(self, state_dict):
452
+ rename_dict = {
453
+ "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
454
+ "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
455
+ "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
456
+ "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
457
+ "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
458
+ "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
459
+ "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
460
+ "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
461
+ "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
462
+ "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
463
+ "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
464
+ "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
465
+ "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
466
+ "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
467
+ "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
468
+ "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
469
+ "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
470
+ "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
471
+ "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
472
+ "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
473
+ "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
474
+ "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
475
+ "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
476
+ "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
477
+ "blocks.0.norm2.bias": "blocks.0.norm3.bias",
478
+ "blocks.0.norm2.weight": "blocks.0.norm3.weight",
479
+ "blocks.0.scale_shift_table": "blocks.0.modulation",
480
+ "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
481
+ "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
482
+ "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
483
+ "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
484
+ "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
485
+ "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
486
+ "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
487
+ "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
488
+ "condition_embedder.time_proj.bias": "time_projection.1.bias",
489
+ "condition_embedder.time_proj.weight": "time_projection.1.weight",
490
+ "patch_embedding.bias": "patch_embedding.bias",
491
+ "patch_embedding.weight": "patch_embedding.weight",
492
+ "scale_shift_table": "head.modulation",
493
+ "proj_out.bias": "head.head.bias",
494
+ "proj_out.weight": "head.head.weight",
495
+ }
496
+ state_dict_ = {}
497
+ for name, param in state_dict.items():
498
+ if name in rename_dict:
499
+ state_dict_[rename_dict[name]] = param
500
+ else:
501
+ name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
502
+ if name_ in rename_dict:
503
+ name_ = rename_dict[name_]
504
+ name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
505
+ state_dict_[name_] = param
506
+ if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
507
+ config = {
508
+ "model_type": "t2v",
509
+ "patch_size": (1, 2, 2),
510
+ "text_len": 512,
511
+ "in_dim": 16,
512
+ "dim": 5120,
513
+ "ffn_dim": 13824,
514
+ "freq_dim": 256,
515
+ "text_dim": 4096,
516
+ "out_dim": 16,
517
+ "num_heads": 40,
518
+ "num_layers": 40,
519
+ "window_size": (-1, -1),
520
+ "qk_norm": True,
521
+ "cross_attn_norm": True,
522
+ "eps": 1e-6,
523
+ }
524
+ else:
525
+ config = {}
526
+ return state_dict_, config
527
+
528
+ def from_civitai(self, state_dict):
529
+ if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
530
+ config = {
531
+ "has_image_input": False,
532
+ "patch_size": [1, 2, 2],
533
+ "in_dim": 16,
534
+ "dim": 1536,
535
+ "ffn_dim": 8960,
536
+ "freq_dim": 256,
537
+ "text_dim": 4096,
538
+ "out_dim": 16,
539
+ "num_heads": 12,
540
+ "num_layers": 30,
541
+ "eps": 1e-6
542
+ }
543
+ elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
544
+ config = {
545
+ "has_image_input": False,
546
+ "patch_size": [1, 2, 2],
547
+ "in_dim": 16,
548
+ "dim": 5120,
549
+ "ffn_dim": 13824,
550
+ "freq_dim": 256,
551
+ "text_dim": 4096,
552
+ "out_dim": 16,
553
+ "num_heads": 40,
554
+ "num_layers": 40,
555
+ "eps": 1e-6
556
+ }
557
+ elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
558
+ config = {
559
+ "has_image_input": True,
560
+ "patch_size": [1, 2, 2],
561
+ "in_dim": 36,
562
+ "dim": 5120,
563
+ "ffn_dim": 13824,
564
+ "freq_dim": 256,
565
+ "text_dim": 4096,
566
+ "out_dim": 16,
567
+ "num_heads": 40,
568
+ "num_layers": 40,
569
+ "eps": 1e-6
570
+ }
571
+ else:
572
+ config = {}
573
+ if hasattr(args, "model_config"):
574
+ model_config = args.model_config
575
+ if model_config is not None:
576
+ config.update(model_config)
577
+ return state_dict, config
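flash_attention above dispatches to FlashAttention-3/2, SageAttention, or PyTorch SDPA depending on what is installed, always consuming and returning (b, s, num_heads * head_dim) tensors. A standalone shape sketch of that layout contract, mirroring the compatibility_mode branch with plain PyTorch so no optional backend is needed:

import torch
import torch.nn.functional as F
from einops import rearrange

b, s, n, d = 1, 16, 4, 32
q = k = v = torch.randn(b, s, n * d)
qh = rearrange(q, "b s (n d) -> b n s d", n=n)
kh = rearrange(k, "b s (n d) -> b n s d", n=n)
vh = rearrange(v, "b s (n d) -> b n s d", n=n)
out = rearrange(F.scaled_dot_product_attention(qh, kh, vh), "b n s d -> b s (n d)")
print(out.shape)   # torch.Size([1, 16, 128])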
OmniAvatar/models/wan_video_text_encoder.py ADDED
@@ -0,0 +1,269 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def fp16_clamp(x):
9
+ if x.dtype == torch.float16 and torch.isinf(x).any():
10
+ clamp = torch.finfo(x.dtype).max - 1000
11
+ x = torch.clamp(x, min=-clamp, max=clamp)
12
+ return x
13
+
14
+
15
+ class GELU(nn.Module):
16
+
17
+ def forward(self, x):
18
+ return 0.5 * x * (1.0 + torch.tanh(
19
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
20
+
21
+
22
+ class T5LayerNorm(nn.Module):
23
+
24
+ def __init__(self, dim, eps=1e-6):
25
+ super(T5LayerNorm, self).__init__()
26
+ self.dim = dim
27
+ self.eps = eps
28
+ self.weight = nn.Parameter(torch.ones(dim))
29
+
30
+ def forward(self, x):
31
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
32
+ self.eps)
33
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
34
+ x = x.type_as(self.weight)
35
+ return self.weight * x
36
+
37
+
38
+ class T5Attention(nn.Module):
39
+
40
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
41
+ assert dim_attn % num_heads == 0
42
+ super(T5Attention, self).__init__()
43
+ self.dim = dim
44
+ self.dim_attn = dim_attn
45
+ self.num_heads = num_heads
46
+ self.head_dim = dim_attn // num_heads
47
+
48
+ # layers
49
+ self.q = nn.Linear(dim, dim_attn, bias=False)
50
+ self.k = nn.Linear(dim, dim_attn, bias=False)
51
+ self.v = nn.Linear(dim, dim_attn, bias=False)
52
+ self.o = nn.Linear(dim_attn, dim, bias=False)
53
+ self.dropout = nn.Dropout(dropout)
54
+
55
+ def forward(self, x, context=None, mask=None, pos_bias=None):
56
+ """
57
+ x: [B, L1, C].
58
+ context: [B, L2, C] or None.
59
+ mask: [B, L2] or [B, L1, L2] or None.
60
+ """
61
+ # check inputs
62
+ context = x if context is None else context
63
+ b, n, c = x.size(0), self.num_heads, self.head_dim
64
+
65
+ # compute query, key, value
66
+ q = self.q(x).view(b, -1, n, c)
67
+ k = self.k(context).view(b, -1, n, c)
68
+ v = self.v(context).view(b, -1, n, c)
69
+
70
+ # attention bias
71
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
72
+ if pos_bias is not None:
73
+ attn_bias += pos_bias
74
+ if mask is not None:
75
+ assert mask.ndim in [2, 3]
76
+ mask = mask.view(b, 1, 1,
77
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
78
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
79
+
80
+ # compute attention (T5 does not use scaling)
81
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
82
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
83
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
84
+
85
+ # output
86
+ x = x.reshape(b, -1, n * c)
87
+ x = self.o(x)
88
+ x = self.dropout(x)
89
+ return x
90
+
91
+
92
+ class T5FeedForward(nn.Module):
93
+
94
+ def __init__(self, dim, dim_ffn, dropout=0.1):
95
+ super(T5FeedForward, self).__init__()
96
+ self.dim = dim
97
+ self.dim_ffn = dim_ffn
98
+
99
+ # layers
100
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
101
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
102
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
103
+ self.dropout = nn.Dropout(dropout)
104
+
105
+ def forward(self, x):
106
+ x = self.fc1(x) * self.gate(x)
107
+ x = self.dropout(x)
108
+ x = self.fc2(x)
109
+ x = self.dropout(x)
110
+ return x
111
+
112
+
113
+ class T5SelfAttention(nn.Module):
114
+
115
+ def __init__(self,
116
+ dim,
117
+ dim_attn,
118
+ dim_ffn,
119
+ num_heads,
120
+ num_buckets,
121
+ shared_pos=True,
122
+ dropout=0.1):
123
+ super(T5SelfAttention, self).__init__()
124
+ self.dim = dim
125
+ self.dim_attn = dim_attn
126
+ self.dim_ffn = dim_ffn
127
+ self.num_heads = num_heads
128
+ self.num_buckets = num_buckets
129
+ self.shared_pos = shared_pos
130
+
131
+ # layers
132
+ self.norm1 = T5LayerNorm(dim)
133
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
134
+ self.norm2 = T5LayerNorm(dim)
135
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
136
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
137
+ num_buckets, num_heads, bidirectional=True)
138
+
139
+ def forward(self, x, mask=None, pos_bias=None):
140
+ e = pos_bias if self.shared_pos else self.pos_embedding(
141
+ x.size(1), x.size(1))
142
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
143
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
144
+ return x
145
+
146
+
147
+ class T5RelativeEmbedding(nn.Module):
148
+
149
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
150
+ super(T5RelativeEmbedding, self).__init__()
151
+ self.num_buckets = num_buckets
152
+ self.num_heads = num_heads
153
+ self.bidirectional = bidirectional
154
+ self.max_dist = max_dist
155
+
156
+ # layers
157
+ self.embedding = nn.Embedding(num_buckets, num_heads)
158
+
159
+ def forward(self, lq, lk):
160
+ device = self.embedding.weight.device
161
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
162
+ # torch.arange(lq).unsqueeze(1).to(device)
163
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
164
+ torch.arange(lq, device=device).unsqueeze(1)
165
+ rel_pos = self._relative_position_bucket(rel_pos)
166
+ rel_pos_embeds = self.embedding(rel_pos)
167
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
168
+ 0) # [1, N, Lq, Lk]
169
+ return rel_pos_embeds.contiguous()
170
+
171
+ def _relative_position_bucket(self, rel_pos):
172
+ # preprocess
173
+ if self.bidirectional:
174
+ num_buckets = self.num_buckets // 2
175
+ rel_buckets = (rel_pos > 0).long() * num_buckets
176
+ rel_pos = torch.abs(rel_pos)
177
+ else:
178
+ num_buckets = self.num_buckets
179
+ rel_buckets = 0
180
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
181
+
182
+ # embeddings for small and large positions
183
+ max_exact = num_buckets // 2
184
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
185
+ math.log(self.max_dist / max_exact) *
186
+ (num_buckets - max_exact)).long()
187
+ rel_pos_large = torch.min(
188
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
189
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
190
+ return rel_buckets
191
+
192
+ def init_weights(m):
193
+ if isinstance(m, T5LayerNorm):
194
+ nn.init.ones_(m.weight)
195
+ elif isinstance(m, T5FeedForward):
196
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
197
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
198
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
199
+ elif isinstance(m, T5Attention):
200
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
201
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
202
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
203
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
204
+ elif isinstance(m, T5RelativeEmbedding):
205
+ nn.init.normal_(
206
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
207
+
208
+
209
+ class WanTextEncoder(torch.nn.Module):
210
+
211
+ def __init__(self,
212
+ vocab=256384,
213
+ dim=4096,
214
+ dim_attn=4096,
215
+ dim_ffn=10240,
216
+ num_heads=64,
217
+ num_layers=24,
218
+ num_buckets=32,
219
+ shared_pos=False,
220
+ dropout=0.1):
221
+ super(WanTextEncoder, self).__init__()
222
+ self.dim = dim
223
+ self.dim_attn = dim_attn
224
+ self.dim_ffn = dim_ffn
225
+ self.num_heads = num_heads
226
+ self.num_layers = num_layers
227
+ self.num_buckets = num_buckets
228
+ self.shared_pos = shared_pos
229
+
230
+ # layers
231
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
232
+ else nn.Embedding(vocab, dim)
233
+ self.pos_embedding = T5RelativeEmbedding(
234
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
235
+ self.dropout = nn.Dropout(dropout)
236
+ self.blocks = nn.ModuleList([
237
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
238
+ shared_pos, dropout) for _ in range(num_layers)
239
+ ])
240
+ self.norm = T5LayerNorm(dim)
241
+
242
+ # initialize weights
243
+ self.apply(init_weights)
244
+
245
+ def forward(self, ids, mask=None):
246
+ x = self.token_embedding(ids)
247
+ x = self.dropout(x)
248
+ e = self.pos_embedding(x.size(1),
249
+ x.size(1)) if self.shared_pos else None
250
+ for block in self.blocks:
251
+ x = block(x, mask, pos_bias=e)
252
+ x = self.norm(x)
253
+ x = self.dropout(x)
254
+ return x
255
+
256
+ @staticmethod
257
+ def state_dict_converter():
258
+ return WanTextEncoderStateDictConverter()
259
+
260
+
261
+ class WanTextEncoderStateDictConverter:
262
+ def __init__(self):
263
+ pass
264
+
265
+ def from_diffusers(self, state_dict):
266
+ return state_dict
267
+
268
+ def from_civitai(self, state_dict):
269
+ return state_dict
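As a quick shape check for the encoder defined above, it can be instantiated with a deliberately tiny configuration (the defaults, dim=4096 and 24 layers, are the full-size T5-style encoder and impractical for a smoke test). This is an illustrative sketch only.

import torch

# Tiny, randomly initialized WanTextEncoder, just to confirm tensor shapes.
enc = WanTextEncoder(vocab=1000, dim=64, dim_attn=64, dim_ffn=128,
                     num_heads=4, num_layers=2, num_buckets=8)
ids = torch.randint(0, 1000, (2, 16))        # [batch, seq_len] token ids
mask = torch.ones(2, 16, dtype=torch.long)   # attention mask
out = enc(ids, mask)
print(out.shape)                             # torch.Size([2, 16, 64])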
OmniAvatar/models/wan_video_vae.py ADDED
@@ -0,0 +1,807 @@
1
+ from einops import rearrange, repeat
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from tqdm import tqdm
7
+
8
+ CACHE_T = 2
9
+
10
+
11
+ def check_is_instance(model, module_class):
12
+ if isinstance(model, module_class):
13
+ return True
14
+ if hasattr(model, "module") and isinstance(model.module, module_class):
15
+ return True
16
+ return False
17
+
18
+
19
+ def block_causal_mask(x, block_size):
20
+ # params
21
+ b, n, s, _, device = *x.size(), x.device
22
+ assert s % block_size == 0
23
+ num_blocks = s // block_size
24
+
25
+ # build mask
26
+ mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
27
+ for i in range(num_blocks):
28
+ mask[:, :,
29
+ i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
30
+ return mask
31
+
32
+
33
+ class CausalConv3d(nn.Conv3d):
34
+ """
35
+ Causal 3D convolution.
36
+ """
37
+
38
+ def __init__(self, *args, **kwargs):
39
+ super().__init__(*args, **kwargs)
40
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
41
+ self.padding[1], 2 * self.padding[0], 0)
42
+ self.padding = (0, 0, 0)
43
+
44
+ def forward(self, x, cache_x=None):
45
+ padding = list(self._padding)
46
+ if cache_x is not None and self._padding[4] > 0:
47
+ cache_x = cache_x.to(x.device)
48
+ x = torch.cat([cache_x, x], dim=2)
49
+ padding[4] -= cache_x.shape[2]
50
+ x = F.pad(x, padding)
51
+
52
+ return super().forward(x)
53
+
54
+
55
+ class RMS_norm(nn.Module):
56
+
57
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
58
+ super().__init__()
59
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
60
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
61
+
62
+ self.channel_first = channel_first
63
+ self.scale = dim**0.5
64
+ self.gamma = nn.Parameter(torch.ones(shape))
65
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
66
+
67
+ def forward(self, x):
68
+ return F.normalize(
69
+ x, dim=(1 if self.channel_first else
70
+ -1)) * self.scale * self.gamma + self.bias
71
+
72
+
73
+ class Upsample(nn.Upsample):
74
+
75
+ def forward(self, x):
76
+ """
77
+ Fix bfloat16 support for nearest neighbor interpolation.
78
+ """
79
+ return super().forward(x.float()).type_as(x)
80
+
81
+
82
+ class Resample(nn.Module):
83
+
84
+ def __init__(self, dim, mode):
85
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
86
+ 'downsample3d')
87
+ super().__init__()
88
+ self.dim = dim
89
+ self.mode = mode
90
+
91
+ # layers
92
+ if mode == 'upsample2d':
93
+ self.resample = nn.Sequential(
94
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
95
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
96
+ elif mode == 'upsample3d':
97
+ self.resample = nn.Sequential(
98
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
99
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
100
+ self.time_conv = CausalConv3d(dim,
101
+ dim * 2, (3, 1, 1),
102
+ padding=(1, 0, 0))
103
+
104
+ elif mode == 'downsample2d':
105
+ self.resample = nn.Sequential(
106
+ nn.ZeroPad2d((0, 1, 0, 1)),
107
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
108
+ elif mode == 'downsample3d':
109
+ self.resample = nn.Sequential(
110
+ nn.ZeroPad2d((0, 1, 0, 1)),
111
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
112
+ self.time_conv = CausalConv3d(dim,
113
+ dim, (3, 1, 1),
114
+ stride=(2, 1, 1),
115
+ padding=(0, 0, 0))
116
+
117
+ else:
118
+ self.resample = nn.Identity()
119
+
120
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
121
+ b, c, t, h, w = x.size()
122
+ if self.mode == 'upsample3d':
123
+ if feat_cache is not None:
124
+ idx = feat_idx[0]
125
+ if feat_cache[idx] is None:
126
+ feat_cache[idx] = 'Rep'
127
+ feat_idx[0] += 1
128
+ else:
129
+
130
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
131
+ if cache_x.shape[2] < 2 and feat_cache[
132
+ idx] is not None and feat_cache[idx] != 'Rep':
133
+ # cache last frame of last two chunk
134
+ cache_x = torch.cat([
135
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
136
+ cache_x.device), cache_x
137
+ ],
138
+ dim=2)
139
+ if cache_x.shape[2] < 2 and feat_cache[
140
+ idx] is not None and feat_cache[idx] == 'Rep':
141
+ cache_x = torch.cat([
142
+ torch.zeros_like(cache_x).to(cache_x.device),
143
+ cache_x
144
+ ],
145
+ dim=2)
146
+ if feat_cache[idx] == 'Rep':
147
+ x = self.time_conv(x)
148
+ else:
149
+ x = self.time_conv(x, feat_cache[idx])
150
+ feat_cache[idx] = cache_x
151
+ feat_idx[0] += 1
152
+
153
+ x = x.reshape(b, 2, c, t, h, w)
154
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
155
+ 3)
156
+ x = x.reshape(b, c, t * 2, h, w)
157
+ t = x.shape[2]
158
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
159
+ x = self.resample(x)
160
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
161
+
162
+ if self.mode == 'downsample3d':
163
+ if feat_cache is not None:
164
+ idx = feat_idx[0]
165
+ if feat_cache[idx] is None:
166
+ feat_cache[idx] = x.clone()
167
+ feat_idx[0] += 1
168
+ else:
169
+ cache_x = x[:, :, -1:, :, :].clone()
170
+ x = self.time_conv(
171
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
172
+ feat_cache[idx] = cache_x
173
+ feat_idx[0] += 1
174
+ return x
175
+
176
+ def init_weight(self, conv):
177
+ conv_weight = conv.weight
178
+ nn.init.zeros_(conv_weight)
179
+ c1, c2, t, h, w = conv_weight.size()
180
+ one_matrix = torch.eye(c1, c2)
181
+ init_matrix = one_matrix
182
+ nn.init.zeros_(conv_weight)
183
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix
184
+ conv.weight.data.copy_(conv_weight)
185
+ nn.init.zeros_(conv.bias.data)
186
+
187
+ def init_weight2(self, conv):
188
+ conv_weight = conv.weight.data
189
+ nn.init.zeros_(conv_weight)
190
+ c1, c2, t, h, w = conv_weight.size()
191
+ init_matrix = torch.eye(c1 // 2, c2)
192
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
193
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
194
+ conv.weight.data.copy_(conv_weight)
195
+ nn.init.zeros_(conv.bias.data)
196
+
197
+
198
+ class ResidualBlock(nn.Module):
199
+
200
+ def __init__(self, in_dim, out_dim, dropout=0.0):
201
+ super().__init__()
202
+ self.in_dim = in_dim
203
+ self.out_dim = out_dim
204
+
205
+ # layers
206
+ self.residual = nn.Sequential(
207
+ RMS_norm(in_dim, images=False), nn.SiLU(),
208
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
209
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
210
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
211
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
212
+ if in_dim != out_dim else nn.Identity()
213
+
214
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
215
+ h = self.shortcut(x)
216
+ for layer in self.residual:
217
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
218
+ idx = feat_idx[0]
219
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
220
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
221
+ # cache last frame of last two chunk
222
+ cache_x = torch.cat([
223
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
224
+ cache_x.device), cache_x
225
+ ],
226
+ dim=2)
227
+ x = layer(x, feat_cache[idx])
228
+ feat_cache[idx] = cache_x
229
+ feat_idx[0] += 1
230
+ else:
231
+ x = layer(x)
232
+ return x + h
233
+
234
+
235
+ class AttentionBlock(nn.Module):
236
+ """
237
+ Causal self-attention with a single head.
238
+ """
239
+
240
+ def __init__(self, dim):
241
+ super().__init__()
242
+ self.dim = dim
243
+
244
+ # layers
245
+ self.norm = RMS_norm(dim)
246
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
247
+ self.proj = nn.Conv2d(dim, dim, 1)
248
+
249
+ # zero out the last layer params
250
+ nn.init.zeros_(self.proj.weight)
251
+
252
+ def forward(self, x):
253
+ identity = x
254
+ b, c, t, h, w = x.size()
255
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
256
+ x = self.norm(x)
257
+ # compute query, key, value
258
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
259
+ 0, 1, 3, 2).contiguous().chunk(3, dim=-1)
260
+
261
+ # apply attention
262
+ x = F.scaled_dot_product_attention(
263
+ q,
264
+ k,
265
+ v,
266
+ #attn_mask=block_causal_mask(q, block_size=h * w)
267
+ )
268
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
269
+
270
+ # output
271
+ x = self.proj(x)
272
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
273
+ return x + identity
274
+
275
+
276
+ class Encoder3d(nn.Module):
277
+
278
+ def __init__(self,
279
+ dim=128,
280
+ z_dim=4,
281
+ dim_mult=[1, 2, 4, 4],
282
+ num_res_blocks=2,
283
+ attn_scales=[],
284
+ temperal_downsample=[True, True, False],
285
+ dropout=0.0):
286
+ super().__init__()
287
+ self.dim = dim
288
+ self.z_dim = z_dim
289
+ self.dim_mult = dim_mult
290
+ self.num_res_blocks = num_res_blocks
291
+ self.attn_scales = attn_scales
292
+ self.temperal_downsample = temperal_downsample
293
+
294
+ # dimensions
295
+ dims = [dim * u for u in [1] + dim_mult]
296
+ scale = 1.0
297
+
298
+ # init block
299
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
300
+
301
+ # downsample blocks
302
+ downsamples = []
303
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
304
+ # residual (+attention) blocks
305
+ for _ in range(num_res_blocks):
306
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
307
+ if scale in attn_scales:
308
+ downsamples.append(AttentionBlock(out_dim))
309
+ in_dim = out_dim
310
+
311
+ # downsample block
312
+ if i != len(dim_mult) - 1:
313
+ mode = 'downsample3d' if temperal_downsample[
314
+ i] else 'downsample2d'
315
+ downsamples.append(Resample(out_dim, mode=mode))
316
+ scale /= 2.0
317
+ self.downsamples = nn.Sequential(*downsamples)
318
+
319
+ # middle blocks
320
+ self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
321
+ AttentionBlock(out_dim),
322
+ ResidualBlock(out_dim, out_dim, dropout))
323
+
324
+ # output blocks
325
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
326
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
327
+
328
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
329
+ if feat_cache is not None:
330
+ idx = feat_idx[0]
331
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
332
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
333
+ # cache last frame of last two chunk
334
+ cache_x = torch.cat([
335
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
336
+ cache_x.device), cache_x
337
+ ],
338
+ dim=2)
339
+ x = self.conv1(x, feat_cache[idx])
340
+ feat_cache[idx] = cache_x
341
+ feat_idx[0] += 1
342
+ else:
343
+ x = self.conv1(x)
344
+
345
+ ## downsamples
346
+ for layer in self.downsamples:
347
+ if feat_cache is not None:
348
+ x = layer(x, feat_cache, feat_idx)
349
+ else:
350
+ x = layer(x)
351
+
352
+ ## middle
353
+ for layer in self.middle:
354
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
355
+ x = layer(x, feat_cache, feat_idx)
356
+ else:
357
+ x = layer(x)
358
+
359
+ ## head
360
+ for layer in self.head:
361
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
362
+ idx = feat_idx[0]
363
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
364
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
365
+ # cache last frame of last two chunk
366
+ cache_x = torch.cat([
367
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
368
+ cache_x.device), cache_x
369
+ ],
370
+ dim=2)
371
+ x = layer(x, feat_cache[idx])
372
+ feat_cache[idx] = cache_x
373
+ feat_idx[0] += 1
374
+ else:
375
+ x = layer(x)
376
+ return x
377
+
378
+
379
+ class Decoder3d(nn.Module):
380
+
381
+ def __init__(self,
382
+ dim=128,
383
+ z_dim=4,
384
+ dim_mult=[1, 2, 4, 4],
385
+ num_res_blocks=2,
386
+ attn_scales=[],
387
+ temperal_upsample=[False, True, True],
388
+ dropout=0.0):
389
+ super().__init__()
390
+ self.dim = dim
391
+ self.z_dim = z_dim
392
+ self.dim_mult = dim_mult
393
+ self.num_res_blocks = num_res_blocks
394
+ self.attn_scales = attn_scales
395
+ self.temperal_upsample = temperal_upsample
396
+
397
+ # dimensions
398
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
399
+ scale = 1.0 / 2**(len(dim_mult) - 2)
400
+
401
+ # init block
402
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
403
+
404
+ # middle blocks
405
+ self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
406
+ AttentionBlock(dims[0]),
407
+ ResidualBlock(dims[0], dims[0], dropout))
408
+
409
+ # upsample blocks
410
+ upsamples = []
411
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
412
+ # residual (+attention) blocks
413
+ if i == 1 or i == 2 or i == 3:
414
+ in_dim = in_dim // 2
415
+ for _ in range(num_res_blocks + 1):
416
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
417
+ if scale in attn_scales:
418
+ upsamples.append(AttentionBlock(out_dim))
419
+ in_dim = out_dim
420
+
421
+ # upsample block
422
+ if i != len(dim_mult) - 1:
423
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
424
+ upsamples.append(Resample(out_dim, mode=mode))
425
+ scale *= 2.0
426
+ self.upsamples = nn.Sequential(*upsamples)
427
+
428
+ # output blocks
429
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
430
+ CausalConv3d(out_dim, 3, 3, padding=1))
431
+
432
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
433
+ ## conv1
434
+ if feat_cache is not None:
435
+ idx = feat_idx[0]
436
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
437
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
438
+ # cache last frame of last two chunk
439
+ cache_x = torch.cat([
440
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
441
+ cache_x.device), cache_x
442
+ ],
443
+ dim=2)
444
+ x = self.conv1(x, feat_cache[idx])
445
+ feat_cache[idx] = cache_x
446
+ feat_idx[0] += 1
447
+ else:
448
+ x = self.conv1(x)
449
+
450
+ ## middle
451
+ for layer in self.middle:
452
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
453
+ x = layer(x, feat_cache, feat_idx)
454
+ else:
455
+ x = layer(x)
456
+
457
+ ## upsamples
458
+ for layer in self.upsamples:
459
+ if feat_cache is not None:
460
+ x = layer(x, feat_cache, feat_idx)
461
+ else:
462
+ x = layer(x)
463
+
464
+ ## head
465
+ for layer in self.head:
466
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
467
+ idx = feat_idx[0]
468
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
469
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
470
+ # cache last frame of last two chunk
471
+ cache_x = torch.cat([
472
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
473
+ cache_x.device), cache_x
474
+ ],
475
+ dim=2)
476
+ x = layer(x, feat_cache[idx])
477
+ feat_cache[idx] = cache_x
478
+ feat_idx[0] += 1
479
+ else:
480
+ x = layer(x)
481
+ return x
482
+
483
+
484
+ def count_conv3d(model):
485
+ count = 0
486
+ for m in model.modules():
487
+ if check_is_instance(m, CausalConv3d):
488
+ count += 1
489
+ return count
490
+
491
+
492
+ class VideoVAE_(nn.Module):
493
+
494
+ def __init__(self,
495
+ dim=96,
496
+ z_dim=16,
497
+ dim_mult=[1, 2, 4, 4],
498
+ num_res_blocks=2,
499
+ attn_scales=[],
500
+ temperal_downsample=[False, True, True],
501
+ dropout=0.0):
502
+ super().__init__()
503
+ self.dim = dim
504
+ self.z_dim = z_dim
505
+ self.dim_mult = dim_mult
506
+ self.num_res_blocks = num_res_blocks
507
+ self.attn_scales = attn_scales
508
+ self.temperal_downsample = temperal_downsample
509
+ self.temperal_upsample = temperal_downsample[::-1]
510
+
511
+ # modules
512
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
513
+ attn_scales, self.temperal_downsample, dropout)
514
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
515
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
516
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
517
+ attn_scales, self.temperal_upsample, dropout)
518
+
519
+ def forward(self, x):
520
+ mu, log_var = self.encode(x)
521
+ z = self.reparameterize(mu, log_var)
522
+ x_recon = self.decode(z)
523
+ return x_recon, mu, log_var
524
+
525
+ def encode(self, x, scale):
526
+ self.clear_cache()
527
+ ## cache
528
+ t = x.shape[2]
529
+ iter_ = 1 + (t - 1) // 4
530
+
531
+ for i in range(iter_):
532
+ self._enc_conv_idx = [0]
533
+ if i == 0:
534
+ out = self.encoder(x[:, :, :1, :, :],
535
+ feat_cache=self._enc_feat_map,
536
+ feat_idx=self._enc_conv_idx)
537
+ else:
538
+ out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
539
+ feat_cache=self._enc_feat_map,
540
+ feat_idx=self._enc_conv_idx)
541
+ out = torch.cat([out, out_], 2)
542
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
543
+ if isinstance(scale[0], torch.Tensor):
544
+ scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
545
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
546
+ 1, self.z_dim, 1, 1, 1)
547
+ else:
548
+ scale = scale.to(dtype=mu.dtype, device=mu.device)
549
+ mu = (mu - scale[0]) * scale[1]
550
+ return mu
551
+
552
+ def decode(self, z, scale):
553
+ self.clear_cache()
554
+ # z: [b,c,t,h,w]
555
+ if isinstance(scale[0], torch.Tensor):
556
+ scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
557
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
558
+ 1, self.z_dim, 1, 1, 1)
559
+ else:
560
+ scale = scale.to(dtype=z.dtype, device=z.device)
561
+ z = z / scale[1] + scale[0]
562
+ iter_ = z.shape[2]
563
+ x = self.conv2(z)
564
+ for i in range(iter_):
565
+ self._conv_idx = [0]
566
+ if i == 0:
567
+ out = self.decoder(x[:, :, i:i + 1, :, :],
568
+ feat_cache=self._feat_map,
569
+ feat_idx=self._conv_idx)
570
+ else:
571
+ out_ = self.decoder(x[:, :, i:i + 1, :, :],
572
+ feat_cache=self._feat_map,
573
+ feat_idx=self._conv_idx)
574
+ out = torch.cat([out, out_], 2) # may add tensor offload
575
+ return out
576
+
577
+ def reparameterize(self, mu, log_var):
578
+ std = torch.exp(0.5 * log_var)
579
+ eps = torch.randn_like(std)
580
+ return eps * std + mu
581
+
582
+ def sample(self, imgs, deterministic=False):
583
+ mu, log_var = self.encode(imgs)
584
+ if deterministic:
585
+ return mu
586
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
587
+ return mu + std * torch.randn_like(std)
588
+
589
+ def clear_cache(self):
590
+ self._conv_num = count_conv3d(self.decoder)
591
+ self._conv_idx = [0]
592
+ self._feat_map = [None] * self._conv_num
593
+ # cache encode
594
+ self._enc_conv_num = count_conv3d(self.encoder)
595
+ self._enc_conv_idx = [0]
596
+ self._enc_feat_map = [None] * self._enc_conv_num
597
+
598
+
599
+ class WanVideoVAE(nn.Module):
600
+
601
+ def __init__(self, z_dim=16):
602
+ super().__init__()
603
+
604
+ mean = [
605
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
606
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
607
+ ]
608
+ std = [
609
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
610
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
611
+ ]
612
+ self.mean = torch.tensor(mean)
613
+ self.std = torch.tensor(std)
614
+ self.scale = [self.mean, 1.0 / self.std]
615
+
616
+ # init model
617
+ self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
618
+ self.upsampling_factor = 8
619
+
620
+
621
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
622
+ x = torch.ones((length,))
623
+ if not left_bound:
624
+ x[:border_width] = (torch.arange(border_width) + 1) / border_width
625
+ if not right_bound:
626
+ x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
627
+ return x
628
+
629
+
630
+ def build_mask(self, data, is_bound, border_width):
631
+ _, _, _, H, W = data.shape
632
+ h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
633
+ w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
634
+
635
+ h = repeat(h, "H -> H W", H=H, W=W)
636
+ w = repeat(w, "W -> H W", H=H, W=W)
637
+
638
+ mask = torch.stack([h, w]).min(dim=0).values
639
+ mask = rearrange(mask, "H W -> 1 1 1 H W")
640
+ return mask
641
+
642
+
643
+ def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
644
+ _, _, T, H, W = hidden_states.shape
645
+ size_h, size_w = tile_size
646
+ stride_h, stride_w = tile_stride
647
+
648
+ # Split tasks
649
+ tasks = []
650
+ for h in range(0, H, stride_h):
651
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
652
+ for w in range(0, W, stride_w):
653
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
654
+ h_, w_ = h + size_h, w + size_w
655
+ tasks.append((h, h_, w, w_))
656
+
657
+ data_device = "cpu"
658
+ computation_device = device
659
+
660
+ out_T = T * 4 - 3
661
+ weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
662
+ values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
663
+
664
+ for h, h_, w, w_ in tasks:
665
+ hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
666
+ hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
667
+
668
+ mask = self.build_mask(
669
+ hidden_states_batch,
670
+ is_bound=(h==0, h_>=H, w==0, w_>=W),
671
+ border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
672
+ ).to(dtype=hidden_states.dtype, device=data_device)
673
+
674
+ target_h = h * self.upsampling_factor
675
+ target_w = w * self.upsampling_factor
676
+ values[
677
+ :,
678
+ :,
679
+ :,
680
+ target_h:target_h + hidden_states_batch.shape[3],
681
+ target_w:target_w + hidden_states_batch.shape[4],
682
+ ] += hidden_states_batch * mask
683
+ weight[
684
+ :,
685
+ :,
686
+ :,
687
+ target_h: target_h + hidden_states_batch.shape[3],
688
+ target_w: target_w + hidden_states_batch.shape[4],
689
+ ] += mask
690
+ values = values / weight
691
+ values = values.clamp_(-1, 1)
692
+ return values
693
+
694
+
695
+ def tiled_encode(self, video, device, tile_size, tile_stride):
696
+ _, _, T, H, W = video.shape
697
+ size_h, size_w = tile_size
698
+ stride_h, stride_w = tile_stride
699
+
700
+ # Split tasks
701
+ tasks = []
702
+ for h in range(0, H, stride_h):
703
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
704
+ for w in range(0, W, stride_w):
705
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
706
+ h_, w_ = h + size_h, w + size_w
707
+ tasks.append((h, h_, w, w_))
708
+
709
+ data_device = "cpu"
710
+ computation_device = device
711
+
712
+ out_T = (T + 3) // 4
713
+ weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
714
+ values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
715
+
716
+ for h, h_, w, w_ in tasks:
717
+ hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
718
+ hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
719
+
720
+ mask = self.build_mask(
721
+ hidden_states_batch,
722
+ is_bound=(h==0, h_>=H, w==0, w_>=W),
723
+ border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
724
+ ).to(dtype=video.dtype, device=data_device)
725
+
726
+ target_h = h // self.upsampling_factor
727
+ target_w = w // self.upsampling_factor
728
+ values[
729
+ :,
730
+ :,
731
+ :,
732
+ target_h:target_h + hidden_states_batch.shape[3],
733
+ target_w:target_w + hidden_states_batch.shape[4],
734
+ ] += hidden_states_batch * mask
735
+ weight[
736
+ :,
737
+ :,
738
+ :,
739
+ target_h: target_h + hidden_states_batch.shape[3],
740
+ target_w: target_w + hidden_states_batch.shape[4],
741
+ ] += mask
742
+ values = values / weight
743
+ return values
744
+
745
+
746
+ def single_encode(self, video, device):
747
+ video = video.to(device)
748
+ x = self.model.encode(video, self.scale)
749
+ return x
750
+
751
+
752
+ def single_decode(self, hidden_state, device):
753
+ hidden_state = hidden_state.to(device)
754
+ video = self.model.decode(hidden_state, self.scale)
755
+ return video.clamp_(-1, 1)
756
+
757
+
758
+ def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
759
+
760
+ videos = [video.to("cpu") for video in videos]
761
+ hidden_states = []
762
+ for video in videos:
763
+ video = video.unsqueeze(0)
764
+ if tiled:
765
+ tile_size = (tile_size[0] * 8, tile_size[1] * 8)
766
+ tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
767
+ hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
768
+ else:
769
+ hidden_state = self.single_encode(video, device)
770
+ hidden_state = hidden_state.squeeze(0)
771
+ hidden_states.append(hidden_state)
772
+ hidden_states = torch.stack(hidden_states)
773
+ return hidden_states
774
+
775
+
776
+ def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
777
+ hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
778
+ videos = []
779
+ for hidden_state in hidden_states:
780
+ hidden_state = hidden_state.unsqueeze(0)
781
+ if tiled:
782
+ video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
783
+ else:
784
+ video = self.single_decode(hidden_state, device)
785
+ video = video.squeeze(0)
786
+ videos.append(video)
787
+ videos = torch.stack(videos)
788
+ return videos
789
+
790
+
791
+ @staticmethod
792
+ def state_dict_converter():
793
+ return WanVideoVAEStateDictConverter()
794
+
795
+
796
+ class WanVideoVAEStateDictConverter:
797
+
798
+ def __init__(self):
799
+ pass
800
+
801
+ def from_civitai(self, state_dict):
802
+ state_dict_ = {}
803
+ if 'model_state' in state_dict:
804
+ state_dict = state_dict['model_state']
805
+ for name in state_dict:
806
+ state_dict_['model.' + name] = state_dict[name]
807
+ return state_dict_
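A rough sketch of how the VAE wrapper above is exercised once real weights are loaded; the CUDA device, the [-1, 1] video range, and the 128x128, 21-frame clip are assumptions chosen to match the clamping and the 1 + 4k frame chunking visible in the code.

import torch

# Illustrative round trip; random weights run mechanically but produce meaningless pixels.
vae = WanVideoVAE(z_dim=16)
video = torch.rand(3, 21, 128, 128) * 2 - 1        # [C, T, H, W], values in [-1, 1]
latents = vae.encode([video], device="cuda", tiled=True,
                     tile_size=(34, 34), tile_stride=(18, 16))
print(latents.shape)                                # [1, 16, (T+3)//4, H/8, W/8]
frames = vae.decode(latents, device="cuda", tiled=True)
print(frames.shape)                                 # [1, 3, 4*t-3, H, W] == [1, 3, 21, 128, 128]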
OmniAvatar/models/wav2vec.py ADDED
@@ -0,0 +1,209 @@
1
+ # pylint: disable=R0901
2
+ # src/models/wav2vec.py
3
+
4
+ """
5
+ This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding.
6
+ It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities
7
+ such as feature extraction and encoding.
8
+
9
+ Classes:
10
+ Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
11
+
12
+ Functions:
13
+ linear_interpolation: Interpolates the features based on the sequence length.
14
+ """
15
+
16
+ import torch.nn.functional as F
17
+ from transformers import Wav2Vec2Model
18
+ from transformers.modeling_outputs import BaseModelOutput
19
+
20
+
21
+ class Wav2VecModel(Wav2Vec2Model):
22
+ """
23
+ Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
24
+ It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
25
+ ...
26
+
27
+ Attributes:
28
+ base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
29
+
30
+ Methods:
31
+ forward(input_values, seq_len, attention_mask=None, mask_time_indices=None
32
+ , output_attentions=None, output_hidden_states=None, return_dict=None):
33
+ Forward pass of the Wav2VecModel.
34
+ It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
35
+
36
+ feature_extract(input_values, seq_len):
37
+ Extracts features from the input_values using the base model.
38
+
39
+ encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
40
+ Encodes the extracted features using the base model and returns the encoded features.
41
+ """
42
+ def forward(
43
+ self,
44
+ input_values,
45
+ seq_len,
46
+ attention_mask=None,
47
+ mask_time_indices=None,
48
+ output_attentions=None,
49
+ output_hidden_states=None,
50
+ return_dict=None,
51
+ ):
52
+ """
53
+ Forward pass of the Wav2Vec model.
54
+
55
+ Args:
56
+ self: The instance of the model.
57
+ input_values: The input values (waveform) to the model.
58
+ seq_len: The sequence length of the input values.
59
+ attention_mask: Attention mask to be used for the model.
60
+ mask_time_indices: Mask indices to be used for the model.
61
+ output_attentions: If set to True, returns attentions.
62
+ output_hidden_states: If set to True, returns hidden states.
63
+ return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
64
+
65
+ Returns:
66
+ The output of the Wav2Vec model.
67
+ """
68
+ self.config.output_attentions = True
69
+
70
+ output_hidden_states = (
71
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
72
+ )
73
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
74
+
75
+ extract_features = self.feature_extractor(input_values)
76
+ extract_features = extract_features.transpose(1, 2)
77
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
78
+
79
+ if attention_mask is not None:
80
+ # compute reduced attention_mask corresponding to feature vectors
81
+ attention_mask = self._get_feature_vector_attention_mask(
82
+ extract_features.shape[1], attention_mask, add_adapter=False
83
+ )
84
+
85
+ hidden_states, extract_features = self.feature_projection(extract_features)
86
+ hidden_states = self._mask_hidden_states(
87
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
88
+ )
89
+
90
+ encoder_outputs = self.encoder(
91
+ hidden_states,
92
+ attention_mask=attention_mask,
93
+ output_attentions=output_attentions,
94
+ output_hidden_states=output_hidden_states,
95
+ return_dict=return_dict,
96
+ )
97
+
98
+ hidden_states = encoder_outputs[0]
99
+
100
+ if self.adapter is not None:
101
+ hidden_states = self.adapter(hidden_states)
102
+
103
+ if not return_dict:
104
+ return (hidden_states, ) + encoder_outputs[1:]
105
+ return BaseModelOutput(
106
+ last_hidden_state=hidden_states,
107
+ hidden_states=encoder_outputs.hidden_states,
108
+ attentions=encoder_outputs.attentions,
109
+ )
110
+
111
+
112
+ def feature_extract(
113
+ self,
114
+ input_values,
115
+ seq_len,
116
+ ):
117
+ """
118
+ Extracts features from the input values and returns the extracted features.
119
+
120
+ Parameters:
121
+ input_values (torch.Tensor): The input values to be processed.
122
+ seq_len (torch.Tensor): The sequence lengths of the input values.
123
+
124
+ Returns:
125
+ extracted_features (torch.Tensor): The extracted features from the input values.
126
+ """
127
+ extract_features = self.feature_extractor(input_values)
128
+ extract_features = extract_features.transpose(1, 2)
129
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
130
+
131
+ return extract_features
132
+
133
+ def encode(
134
+ self,
135
+ extract_features,
136
+ attention_mask=None,
137
+ mask_time_indices=None,
138
+ output_attentions=None,
139
+ output_hidden_states=None,
140
+ return_dict=None,
141
+ ):
142
+ """
143
+ Encodes the input features into the output space.
144
+
145
+ Args:
146
+ extract_features (torch.Tensor): The extracted features from the audio signal.
147
+ attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
148
+ mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
149
+ output_attentions (bool, optional): If set to True, returns the attention weights.
150
+ output_hidden_states (bool, optional): If set to True, returns all hidden states.
151
+ return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple.
152
+
153
+ Returns:
154
+ The encoded output features.
155
+ """
156
+ self.config.output_attentions = True
157
+
158
+ output_hidden_states = (
159
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
160
+ )
161
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
162
+
163
+ if attention_mask is not None:
164
+ # compute reduced attention_mask corresponding to feature vectors
165
+ attention_mask = self._get_feature_vector_attention_mask(
166
+ extract_features.shape[1], attention_mask, add_adapter=False
167
+ )
168
+
169
+ hidden_states, extract_features = self.feature_projection(extract_features)
170
+ hidden_states = self._mask_hidden_states(
171
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
172
+ )
173
+
174
+ encoder_outputs = self.encoder(
175
+ hidden_states,
176
+ attention_mask=attention_mask,
177
+ output_attentions=output_attentions,
178
+ output_hidden_states=output_hidden_states,
179
+ return_dict=return_dict,
180
+ )
181
+
182
+ hidden_states = encoder_outputs[0]
183
+
184
+ if self.adapter is not None:
185
+ hidden_states = self.adapter(hidden_states)
186
+
187
+ if not return_dict:
188
+ return (hidden_states, ) + encoder_outputs[1:]
189
+ return BaseModelOutput(
190
+ last_hidden_state=hidden_states,
191
+ hidden_states=encoder_outputs.hidden_states,
192
+ attentions=encoder_outputs.attentions,
193
+ )
194
+
195
+
196
+ def linear_interpolation(features, seq_len):
197
+ """
198
+ Linearly interpolate the features along the time dimension to the target sequence length.
199
+
200
+ Args:
201
+ features (torch.Tensor): The extracted features to be interpolated.
202
+ seq_len (torch.Tensor): The sequence lengths of the features.
203
+
204
+ Returns:
205
+ torch.Tensor: The interpolated features.
206
+ """
207
+ features = features.transpose(1, 2)
208
+ output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
209
+ return output_features.transpose(1, 2)
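A hedged sketch of how this wrapper is used to obtain audio features aligned to a target number of video frames; the checkpoint name is an assumption (any 16 kHz Wav2Vec2-compatible checkpoint would behave the same way).

import torch

# Illustrative only: align wav2vec features to a fixed number of video frames.
model = Wav2VecModel.from_pretrained("facebook/wav2vec2-base-960h")   # assumed checkpoint
model.eval()

waveform = torch.randn(1, 16000 * 2)    # stand-in for 2 s of 16 kHz audio
num_frames = 50                         # e.g. 25 fps video over the same 2 s
with torch.no_grad():
    out = model(waveform, seq_len=num_frames)
print(out.last_hidden_state.shape)      # [1, num_frames, hidden_size]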
OmniAvatar/prompters/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .wan_prompter import WanPrompter
OmniAvatar/prompters/base_prompter.py ADDED
@@ -0,0 +1,70 @@
1
+ from ..models.model_manager import ModelManager
2
+ import torch
3
+
4
+
5
+
6
+ def tokenize_long_prompt(tokenizer, prompt, max_length=None):
7
+ # Get model_max_length from self.tokenizer
8
+ length = tokenizer.model_max_length if max_length is None else max_length
9
+
10
+ # To avoid the warning. set self.tokenizer.model_max_length to +oo.
11
+ tokenizer.model_max_length = 99999999
12
+
13
+ # Tokenize it!
14
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
15
+
16
+ # Determine the real length.
17
+ max_length = (input_ids.shape[1] + length - 1) // length * length
18
+
19
+ # Restore tokenizer.model_max_length
20
+ tokenizer.model_max_length = length
21
+
22
+ # Tokenize it again with fixed length.
23
+ input_ids = tokenizer(
24
+ prompt,
25
+ return_tensors="pt",
26
+ padding="max_length",
27
+ max_length=max_length,
28
+ truncation=True
29
+ ).input_ids
30
+
31
+ # Reshape input_ids to fit the text encoder.
32
+ num_sentence = input_ids.shape[1] // length
33
+ input_ids = input_ids.reshape((num_sentence, length))
34
+
35
+ return input_ids
36
+
37
+
38
+
39
+ class BasePrompter:
40
+ def __init__(self):
41
+ self.refiners = []
42
+ self.extenders = []
43
+
44
+
45
+ def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
46
+ for refiner_class in refiner_classes:
47
+ refiner = refiner_class.from_model_manager(model_manager)
48
+ self.refiners.append(refiner)
49
+
50
+ def load_prompt_extenders(self,model_manager:ModelManager,extender_classes=[]):
51
+ for extender_class in extender_classes:
52
+ extender = extender_class.from_model_manager(model_manager)
53
+ self.extenders.append(extender)
54
+
55
+
56
+ @torch.no_grad()
57
+ def process_prompt(self, prompt, positive=True):
58
+ if isinstance(prompt, list):
59
+ prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
60
+ else:
61
+ for refiner in self.refiners:
62
+ prompt = refiner(prompt, positive=positive)
63
+ return prompt
64
+
65
+ @torch.no_grad()
66
+ def extend_prompt(self, prompt:str, positive=True):
67
+ extended_prompt = dict(prompt=prompt)
68
+ for extender in self.extenders:
69
+ extended_prompt = extender(extended_prompt)
70
+ return extended_prompt
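To make the padding behaviour of tokenize_long_prompt concrete, a small sketch follows; the T5 tokenizer is only an example of a tokenizer with model_max_length set. The prompt is measured once, then re-tokenized padded up to the next multiple of model_max_length and reshaped into fixed-length chunks.

from transformers import AutoTokenizer

# Illustrative only; any tokenizer that defines model_max_length behaves the same way.
tok = AutoTokenizer.from_pretrained("google-t5/t5-small")    # model_max_length == 512
ids = tokenize_long_prompt(tok, "a portrait of a person talking, studio lighting")
print(ids.shape)    # (1, 512): a short prompt becomes a single padded chunk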
OmniAvatar/prompters/wan_prompter.py ADDED
@@ -0,0 +1,109 @@
1
+ from .base_prompter import BasePrompter
2
+ from ..models.wan_video_text_encoder import WanTextEncoder
3
+ from transformers import AutoTokenizer
4
+ import os, torch
5
+ import ftfy
6
+ import html
7
+ import string
8
+ import regex as re
9
+
10
+
11
+ def basic_clean(text):
12
+ text = ftfy.fix_text(text)
13
+ text = html.unescape(html.unescape(text))
14
+ return text.strip()
15
+
16
+
17
+ def whitespace_clean(text):
18
+ text = re.sub(r'\s+', ' ', text)
19
+ text = text.strip()
20
+ return text
21
+
22
+
23
+ def canonicalize(text, keep_punctuation_exact_string=None):
24
+ text = text.replace('_', ' ')
25
+ if keep_punctuation_exact_string:
26
+ text = keep_punctuation_exact_string.join(
27
+ part.translate(str.maketrans('', '', string.punctuation))
28
+ for part in text.split(keep_punctuation_exact_string))
29
+ else:
30
+ text = text.translate(str.maketrans('', '', string.punctuation))
31
+ text = text.lower()
32
+ text = re.sub(r'\s+', ' ', text)
33
+ return text.strip()
34
+
35
+
36
+ class HuggingfaceTokenizer:
37
+
38
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
39
+ assert clean in (None, 'whitespace', 'lower', 'canonicalize')
40
+ self.name = name
41
+ self.seq_len = seq_len
42
+ self.clean = clean
43
+
44
+ # init tokenizer
45
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
46
+ self.vocab_size = self.tokenizer.vocab_size
47
+
48
+ def __call__(self, sequence, **kwargs):
49
+ return_mask = kwargs.pop('return_mask', False)
50
+
51
+ # arguments
52
+ _kwargs = {'return_tensors': 'pt'}
53
+ if self.seq_len is not None:
54
+ _kwargs.update({
55
+ 'padding': 'max_length',
56
+ 'truncation': True,
57
+ 'max_length': self.seq_len
58
+ })
59
+ _kwargs.update(**kwargs)
60
+
61
+ # tokenization
62
+ if isinstance(sequence, str):
63
+ sequence = [sequence]
64
+ if self.clean:
65
+ sequence = [self._clean(u) for u in sequence]
66
+ ids = self.tokenizer(sequence, **_kwargs)
67
+
68
+ # output
69
+ if return_mask:
70
+ return ids.input_ids, ids.attention_mask
71
+ else:
72
+ return ids.input_ids
73
+
74
+ def _clean(self, text):
75
+ if self.clean == 'whitespace':
76
+ text = whitespace_clean(basic_clean(text))
77
+ elif self.clean == 'lower':
78
+ text = whitespace_clean(basic_clean(text)).lower()
79
+ elif self.clean == 'canonicalize':
80
+ text = canonicalize(basic_clean(text))
81
+ return text
82
+
83
+
84
+ class WanPrompter(BasePrompter):
85
+
86
+ def __init__(self, tokenizer_path=None, text_len=512):
87
+ super().__init__()
88
+ self.text_len = text_len
89
+ self.text_encoder = None
90
+ self.fetch_tokenizer(tokenizer_path)
91
+
92
+ def fetch_tokenizer(self, tokenizer_path=None):
93
+ if tokenizer_path is not None:
94
+ self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')
95
+
96
+ def fetch_models(self, text_encoder: WanTextEncoder = None):
97
+ self.text_encoder = text_encoder
98
+
99
+ def encode_prompt(self, prompt, positive=True, device="cuda"):
100
+ prompt = self.process_prompt(prompt, positive=positive)
101
+
102
+ ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
103
+ ids = ids.to(device)
104
+ mask = mask.to(device)
105
+ seq_lens = mask.gt(0).sum(dim=1).long()
106
+ prompt_emb = self.text_encoder(ids, mask)
107
+ for i, v in enumerate(seq_lens):
108
+ prompt_emb[:, v:] = 0
109
+ return prompt_emb
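A minimal sketch of wiring the prompter to the text encoder defined earlier; the tokenizer path is an assumption (in practice the umT5 tokenizer directory shipped with the Wan checkpoints), and real encoder weights are needed for meaningful embeddings.

import torch

# Illustrative only; the tokenizer path and device are assumptions.
prompter = WanPrompter(tokenizer_path="pretrained_models/Wan2.1-T2V-14B/google/umt5-xxl")
text_encoder = WanTextEncoder().to("cuda")       # load real weights in practice
prompter.fetch_models(text_encoder)

with torch.no_grad():
    prompt_emb = prompter.encode_prompt("A woman speaks directly to the camera", device="cuda")
print(prompt_emb.shape)                          # [1, text_len=512, dim=4096] with default sizes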
OmniAvatar/schedulers/flow_match.py ADDED
@@ -0,0 +1,79 @@
1
+ import torch
2
+
3
+
4
+
5
+ class FlowMatchScheduler():
6
+
7
+ def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
8
+ self.num_train_timesteps = num_train_timesteps
9
+ self.shift = shift
10
+ self.sigma_max = sigma_max
11
+ self.sigma_min = sigma_min
12
+ self.inverse_timesteps = inverse_timesteps
13
+ self.extra_one_step = extra_one_step
14
+ self.reverse_sigmas = reverse_sigmas
15
+ self.set_timesteps(num_inference_steps)
16
+
17
+
18
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
19
+ if shift is not None:
20
+ self.shift = shift
21
+ sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
22
+ if self.extra_one_step:
23
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
24
+ else:
25
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
26
+ if self.inverse_timesteps:
27
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
28
+ self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
29
+ if self.reverse_sigmas:
30
+ self.sigmas = 1 - self.sigmas
31
+ self.timesteps = self.sigmas * self.num_train_timesteps
32
+ if training:
33
+ x = self.timesteps
34
+ y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
35
+ y_shifted = y - y.min()
36
+ bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
37
+ self.linear_timesteps_weights = bsmntw_weighing
38
+
39
+
40
+ def step(self, model_output, timestep, sample, to_final=False, **kwargs):
41
+ if isinstance(timestep, torch.Tensor):
42
+ timestep = timestep.cpu()
43
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
44
+ sigma = self.sigmas[timestep_id]
45
+ if to_final or timestep_id + 1 >= len(self.timesteps):
46
+ sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
47
+ else:
48
+ sigma_ = self.sigmas[timestep_id + 1]
49
+ prev_sample = sample + model_output * (sigma_ - sigma)
50
+ return prev_sample
51
+
52
+
53
+ def return_to_timestep(self, timestep, sample, sample_stablized):
54
+ if isinstance(timestep, torch.Tensor):
55
+ timestep = timestep.cpu()
56
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
57
+ sigma = self.sigmas[timestep_id]
58
+ model_output = (sample - sample_stablized) / sigma
59
+ return model_output
60
+
61
+
62
+ def add_noise(self, original_samples, noise, timestep):
63
+ if isinstance(timestep, torch.Tensor):
64
+ timestep = timestep.cpu()
65
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
66
+ sigma = self.sigmas[timestep_id]
67
+ sample = (1 - sigma) * original_samples + sigma * noise
68
+ return sample
69
+
70
+
71
+ def training_target(self, sample, noise, timestep):
72
+ target = noise - sample
73
+ return target
74
+
75
+
76
+ def training_weight(self, timestep):
77
+ timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
78
+ weights = self.linear_timesteps_weights[timestep_id]
79
+ return weights
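To show how the scheduler above drives sampling, here is a hedged sketch of a flow-matching denoising loop; `model` is a placeholder for any denoiser that predicts the velocity given noisy latents and a timestep, and the latent shape and shift value are assumptions rather than the actual OmniAvatar pipeline settings.

import torch

# Illustrative sampling loop; `model` and its call signature are placeholders.
scheduler = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
scheduler.set_timesteps(num_inference_steps=50)

latents = torch.randn(1, 16, 21, 60, 104)          # start from pure noise
for timestep in scheduler.timesteps:
    velocity = model(latents, timestep)            # placeholder denoiser call
    latents = scheduler.step(velocity, timestep, latents)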
OmniAvatar/utils/args_config.py ADDED
@@ -0,0 +1,123 @@
1
+ import json
2
+ import os
3
+ import argparse
4
+ import yaml
5
+ args = None
6
+
7
+ def set_global_args(local_args):
8
+ global args
9
+
10
+ args = local_args
11
+
12
+ def parse_hp_string(hp_string):
13
+ result = {}
14
+ for pair in hp_string.split(','):
15
+ if not pair:
16
+ continue
17
+ key, value = pair.split('=')
18
+ try:
19
+ # 自动转换为 int / float / str
20
+ ori_value = value
21
+ value = float(value)
22
+ if '.' not in str(ori_value):
23
+ value = int(value)
24
+ except ValueError:
25
+ pass
26
+
27
+ if value in ['true', 'True']:
28
+ value = True
29
+ if value in ['false', 'False']:
30
+ value = False
31
+ if '.' in key:
32
+ keys = key.split('.')
33
+ keys = keys
34
+ current = result
35
+ for key in keys[:-1]:
36
+ if key not in current or not isinstance(current[key], dict):
37
+ current[key] = {}
38
+ current = current[key]
39
+ current[keys[-1]] = value
40
+ else:
41
+ result[key.strip()] = value
42
+ return result
43
+
44
+ def parse_args():
45
+ global args
46
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
47
+ parser.add_argument("--config", type=str, required=True, help="Path to YAML config file.")
48
+
49
+ # 定义 argparse 参数
50
+ parser.add_argument("--exp_path", type=str, help="Path to save the model.")
51
+ parser.add_argument("--input_file", type=str, help="Path to inference txt.")
52
+ parser.add_argument("--debug", action='store_true', default=None)
53
+ parser.add_argument("--infer", action='store_true')
54
+ parser.add_argument("-hp", "--hparams", type=str, default="")
55
+
56
+ args = parser.parse_args()
57
+
58
+ # 读取 YAML 配置(如果提供了 --config 参数)
59
+ if args.config:
60
+ with open(args.config, "r") as f:
61
+ yaml_config = yaml.safe_load(f)
62
+
63
+ # 遍历 YAML 配置,将其添加到 args(如果 argparse 里没有定义)
64
+ for key, value in yaml_config.items():
65
+ if not hasattr(args, key): # argparse 没有的参数
66
+ setattr(args, key, value)
67
+ elif getattr(args, key) is None: # argparse 有但值为空
68
+ setattr(args, key, value)
69
+
70
+ args.rank = int(os.getenv("RANK", "0"))
71
+ args.world_size = int(os.getenv("WORLD_SIZE", "1"))
72
+ args.local_rank = int(os.getenv("LOCAL_RANK", "0")) # torchrun
73
+ args.device = 'cuda'
74
+ debug = args.debug
75
+ if not os.path.exists(args.exp_path):
76
+ args.exp_path = f'checkpoints/{args.exp_path}'
77
+
78
+ if hasattr(args, 'reload_cfg') and args.reload_cfg:
79
+ # 重新加载配置文件
80
+ conf_path = os.path.join(args.exp_path, "config.json")
81
+ if os.path.exists(conf_path):
82
+ print('| Reloading config from:', conf_path)
83
+ args = reload(args, conf_path)
84
+ if len(args.hparams) > 0:
85
+ hp_dict = parse_hp_string(args.hparams)
86
+ for key, value in hp_dict.items():
87
+ if not hasattr(args, key):
88
+ setattr(args, key, value)
89
+ else:
90
+ if isinstance(value, dict):
91
+ ori_v = getattr(args, key)
92
+ ori_v.update(value)
93
+ setattr(args, key, ori_v)
94
+ else:
95
+ setattr(args, key, value)
96
+ args.debug = debug
97
+ dict_args = convert_namespace_to_dict(args)
98
+ if args.local_rank == 0:
99
+ print(dict_args)
100
+ return args
101
+
102
+ def reload(args, conf_path):
103
+ """重新加载配置文件,不覆盖已有的参数"""
104
+ with open(conf_path, "r") as f:
105
+ yaml_config = yaml.safe_load(f)
106
+ # 遍历 YAML 配置,将其添加到 args(如果 argparse 里没有定义)
107
+ for key, value in yaml_config.items():
108
+ if not hasattr(args, key): # parameter not defined by argparse
109
+ setattr(args, key, value)
110
+ elif getattr(args, key) is None: # defined by argparse but left unset
111
+ setattr(args, key, value)
112
+ return args
113
+
114
+ def convert_namespace_to_dict(namespace):
115
+ """将 argparse.Namespace 转为字典,并处理不可序列化对象"""
116
+ result = {}
117
+ for key, value in vars(namespace).items():
118
+ try:
119
+ json.dumps(value) # check that the value is JSON-serializable
120
+ result[key] = value
121
+ except (TypeError, OverflowError):
122
+ result[key] = str(value) # fall back to the string representation of non-serializable objects
123
+ return result
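
A quick illustration (not part of the upload) of how parse_hp_string above expands a comma-separated override string; the override keys and values are hypothetical:

    from OmniAvatar.utils.args_config import parse_hp_string

    hp = parse_hp_string("lr=0.0001,max_steps=1000,use_audio=true,scheduler.shift=5.0")
    # Values with a decimal point stay float, integer-looking values become int,
    # 'true'/'false' become booleans, and dotted keys build nested dicts:
    # {'lr': 0.0001, 'max_steps': 1000, 'use_audio': True, 'scheduler': {'shift': 5.0}}
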
OmniAvatar/utils/audio_preprocess.py ADDED
@@ -0,0 +1,18 @@
1
+ import os
2
+ import subprocess
3
+
4
+ def add_silence_to_audio_ffmpeg(audio_path, tmp_audio_path, silence_duration_s=0.5):
5
+ # Use ffmpeg to prepend silence to the audio
6
+ command = [
7
+ 'ffmpeg',
8
+ '-i', audio_path, # input audio file path
9
+ '-f', 'lavfi', # use the lavfi virtual input to generate silence
10
+ '-t', str(silence_duration_s), # silence duration in seconds
11
+ '-i', 'anullsrc=r=16000:cl=stereo', # create the silent segment (stereo, 16 kHz sample rate)
12
+ '-filter_complex', '[1][0]concat=n=2:v=0:a=1[out]', # concatenate the silence and the original audio
13
+ '-map', '[out]', # map the concatenated audio to the output
14
+ '-y', tmp_audio_path, # output file path
15
+ '-loglevel', 'quiet'
16
+ ]
17
+
18
+ subprocess.run(command, check=True)
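
A minimal usage sketch for the helper above; the output path is hypothetical and ffmpeg must be available on the system:

    from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg

    # Prepend 0.5 s of silence to one of the bundled example clips.
    add_silence_to_audio_ffmpeg(
        "examples/audios/mushroom.wav",   # input audio shipped with the Space
        "/tmp/mushroom_padded.wav",       # hypothetical output path
        silence_duration_s=0.5,
    )
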
OmniAvatar/utils/io_utils.py ADDED
@@ -0,0 +1,256 @@
1
+ import subprocess
2
+ import torch, os
3
+ from safetensors import safe_open
4
+ from OmniAvatar.utils.args_config import args
5
+ from contextlib import contextmanager
6
+
7
+ import re
8
+ import tempfile
9
+ import numpy as np
10
+ import imageio
11
+ from glob import glob
12
+ import soundfile as sf
13
+ from einops import rearrange
14
+ import hashlib
15
+
16
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
17
+
18
+ @contextmanager
19
+ def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
20
+
21
+ old_register_parameter = torch.nn.Module.register_parameter
22
+ if include_buffers:
23
+ old_register_buffer = torch.nn.Module.register_buffer
24
+
25
+ def register_empty_parameter(module, name, param):
26
+ old_register_parameter(module, name, param)
27
+ if param is not None:
28
+ param_cls = type(module._parameters[name])
29
+ kwargs = module._parameters[name].__dict__
30
+ kwargs["requires_grad"] = param.requires_grad
31
+ module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
32
+
33
+ def register_empty_buffer(module, name, buffer, persistent=True):
34
+ old_register_buffer(module, name, buffer, persistent=persistent)
35
+ if buffer is not None:
36
+ module._buffers[name] = module._buffers[name].to(device)
37
+
38
+ def patch_tensor_constructor(fn):
39
+ def wrapper(*args, **kwargs):
40
+ kwargs["device"] = device
41
+ return fn(*args, **kwargs)
42
+
43
+ return wrapper
44
+
45
+ if include_buffers:
46
+ tensor_constructors_to_patch = {
47
+ torch_function_name: getattr(torch, torch_function_name)
48
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
49
+ }
50
+ else:
51
+ tensor_constructors_to_patch = {}
52
+
53
+ try:
54
+ torch.nn.Module.register_parameter = register_empty_parameter
55
+ if include_buffers:
56
+ torch.nn.Module.register_buffer = register_empty_buffer
57
+ for torch_function_name in tensor_constructors_to_patch.keys():
58
+ setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
59
+ yield
60
+ finally:
61
+ torch.nn.Module.register_parameter = old_register_parameter
62
+ if include_buffers:
63
+ torch.nn.Module.register_buffer = old_register_buffer
64
+ for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
65
+ setattr(torch, torch_function_name, old_torch_function)
66
+
67
+ def load_state_dict_from_folder(file_path, torch_dtype=None):
68
+ state_dict = {}
69
+ for file_name in os.listdir(file_path):
70
+ if "." in file_name and file_name.split(".")[-1] in [
71
+ "safetensors", "bin", "ckpt", "pth", "pt"
72
+ ]:
73
+ state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
74
+ return state_dict
75
+
76
+
77
+ def load_state_dict(file_path, torch_dtype=None):
78
+ if file_path.endswith(".safetensors"):
79
+ return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
80
+ else:
81
+ return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
82
+
83
+
84
+ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
85
+ state_dict = {}
86
+ with safe_open(file_path, framework="pt", device="cpu") as f:
87
+ for k in f.keys():
88
+ state_dict[k] = f.get_tensor(k)
89
+ if torch_dtype is not None:
90
+ state_dict[k] = state_dict[k].to(torch_dtype)
91
+ return state_dict
92
+
93
+
94
+ def load_state_dict_from_bin(file_path, torch_dtype=None):
95
+ state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
96
+ if torch_dtype is not None:
97
+ for i in state_dict:
98
+ if isinstance(state_dict[i], torch.Tensor):
99
+ state_dict[i] = state_dict[i].to(torch_dtype)
100
+ return state_dict
101
+
102
+ def smart_load_weights(model, ckpt_state_dict):
103
+ model_state_dict = model.state_dict()
104
+ new_state_dict = {}
105
+
106
+ for name, param in model_state_dict.items():
107
+ if name in ckpt_state_dict:
108
+ ckpt_param = ckpt_state_dict[name]
109
+ if param.shape == ckpt_param.shape:
110
+ new_state_dict[name] = ckpt_param
111
+ else:
112
+ # Automatically adjust the dimensions to match
113
+ if all(p >= c for p, c in zip(param.shape, ckpt_param.shape)):
114
+ print(f"[Truncate] {name}: ckpt {ckpt_param.shape} -> model {param.shape}")
115
+ # Create a new tensor and copy the checkpoint data into it
116
+ new_param = param.clone()
117
+ slices = tuple(slice(0, s) for s in ckpt_param.shape)
118
+ new_param[slices] = ckpt_param
119
+ new_state_dict[name] = new_param
120
+ else:
121
+ print(f"[Skip] {name}: ckpt {ckpt_param.shape} is larger than model {param.shape}")
122
+
123
+ # Update the state_dict with the matching entries only
124
+ missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, assign=True, strict=False)
125
+ return model, missing_keys, unexpected_keys
126
+
127
+ def save_wav(audio, audio_path):
128
+ if isinstance(audio, torch.Tensor):
129
+ audio = audio.float().detach().cpu().numpy()
130
+
131
+ if audio.ndim == 1:
132
+ audio = np.expand_dims(audio, axis=0) # (1, samples)
133
+
134
+ sf.write(audio_path, audio.T, 16000)
135
+
136
+ return True
137
+
138
+ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: float = 5, prompt=None, prompt_path=None, audio=None, audio_path=None, prefix=None):
139
+ os.makedirs(save_path, exist_ok=True)
140
+ out_videos = []
141
+
142
+ with tempfile.TemporaryDirectory() as tmp_path:
143
+
144
+ print(f'video batch shape:{video_batch.shape}')
145
+
146
+ for i, vid in enumerate(video_batch):
147
+ gif_frames = []
148
+
149
+ for frame in vid:
150
+ ft = frame.detach().cpu().clone()
151
+ ft = rearrange(ft, "c h w -> h w c")
152
+ arr = (255.0 * ft).numpy().astype(np.uint8)
153
+ gif_frames.append(arr)
154
+
155
+ if prefix is not None:
156
+ now_save_path = os.path.join(save_path, f"{prefix}_{i:03d}.mp4")
157
+ tmp_save_path = os.path.join(tmp_path, f"{prefix}_{i:03d}.mp4")
158
+ else:
159
+ now_save_path = os.path.join(save_path, f"{i:03d}.mp4")
160
+ tmp_save_path = os.path.join(tmp_path, f"{i:03d}.mp4")
161
+ with imageio.get_writer(tmp_save_path, fps=fps) as writer:
162
+ for frame in gif_frames:
163
+ writer.append_data(frame)
164
+ subprocess.run([f"cp {tmp_save_path} {now_save_path}"], check=True, shell=True)
165
+ print(f'save res video to : {now_save_path}')
166
+ final_video_path = now_save_path
167
+
168
+ if audio is not None or audio_path is not None:
169
+ if audio is not None:
170
+ audio_path = os.path.join(tmp_path, f"{i:06d}.mp3")
171
+ save_wav(audio[i], audio_path)
172
+ # cmd = f'/usr/bin/ffmpeg -i {tmp_save_path} -i {audio_path} -v quiet -c:v copy -c:a libmp3lame -strict experimental {tmp_save_path[:-4]}_wav.mp4 -y'
173
+ cmd = f'/usr/bin/ffmpeg -i {tmp_save_path} -i {audio_path} -v quiet -map 0:v:0 -map 1:a:0 -c:v copy -c:a aac {tmp_save_path[:-4]}_wav.mp4 -y'
174
+ subprocess.check_call(cmd, stdout=None, stdin=subprocess.PIPE, shell=True)
175
+ final_video_path = f"{now_save_path[:-4]}_wav.mp4"
176
+ subprocess.run([f"cp {tmp_save_path[:-4]}_wav.mp4 {final_video_path}"], check=True, shell=True)
177
+ os.remove(now_save_path)
178
+ if prompt is not None and prompt_path is not None:
179
+ with open(prompt_path, "w") as f:
180
+ f.write(prompt)
181
+ out_videos.append(final_video_path)
182
+
183
+ return out_videos
184
+
185
+ def is_zero_stage_3(trainer):
186
+ strategy = getattr(trainer, "strategy", None)
187
+ if strategy and hasattr(strategy, "model"):
188
+ ds_engine = strategy.model
189
+ stage = ds_engine.config.get("zero_optimization", {}).get("stage", 0)
190
+ return stage == 3
191
+ return False
192
+
193
+ def hash_state_dict_keys(state_dict, with_shape=True):
194
+ keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
195
+ keys_str = keys_str.encode(encoding="UTF-8")
196
+ return hashlib.md5(keys_str).hexdigest()
197
+
198
+ def split_state_dict_with_prefix(state_dict):
199
+ keys = sorted([key for key in state_dict if isinstance(key, str)])
200
+ prefix_dict = {}
201
+ for key in keys:
202
+ prefix = key if "." not in key else key.split(".")[0]
203
+ if prefix not in prefix_dict:
204
+ prefix_dict[prefix] = []
205
+ prefix_dict[prefix].append(key)
206
+ state_dicts = []
207
+ for prefix, keys in prefix_dict.items():
208
+ sub_state_dict = {key: state_dict[key] for key in keys}
209
+ state_dicts.append(sub_state_dict)
210
+ return state_dicts
211
+
212
+ def hash_state_dict_keys(state_dict, with_shape=True):
213
+ keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
214
+ keys_str = keys_str.encode(encoding="UTF-8")
215
+ return hashlib.md5(keys_str).hexdigest()
216
+
217
+ def split_state_dict_with_prefix(state_dict):
218
+ keys = sorted([key for key in state_dict if isinstance(key, str)])
219
+ prefix_dict = {}
220
+ for key in keys:
221
+ prefix = key if "." not in key else key.split(".")[0]
222
+ if prefix not in prefix_dict:
223
+ prefix_dict[prefix] = []
224
+ prefix_dict[prefix].append(key)
225
+ state_dicts = []
226
+ for prefix, keys in prefix_dict.items():
227
+ sub_state_dict = {key: state_dict[key] for key in keys}
228
+ state_dicts.append(sub_state_dict)
229
+ return state_dicts
230
+
231
+ def search_for_files(folder, extensions):
232
+ files = []
233
+ if os.path.isdir(folder):
234
+ for file in sorted(os.listdir(folder)):
235
+ files += search_for_files(os.path.join(folder, file), extensions)
236
+ elif os.path.isfile(folder):
237
+ for extension in extensions:
238
+ if folder.endswith(extension):
239
+ files.append(folder)
240
+ break
241
+ return files
242
+
243
+ def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
244
+ keys = []
245
+ for key, value in state_dict.items():
246
+ if isinstance(key, str):
247
+ if isinstance(value, torch.Tensor):
248
+ if with_shape:
249
+ shape = "_".join(map(str, list(value.shape)))
250
+ keys.append(key + ":" + shape)
251
+ keys.append(key)
252
+ elif isinstance(value, dict):
253
+ keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
254
+ keys.sort()
255
+ keys_str = ",".join(keys)
256
+ return keys_str
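
A hedged sketch of how the loaders above are typically used together; the checkpoint path matches the one app.py builds from exp_path, but the snippet itself is only an example:

    import torch
    from OmniAvatar.utils.io_utils import load_state_dict, hash_state_dict_keys

    # load_state_dict dispatches to safetensors or torch.load based on the file extension.
    sd = load_state_dict("pretrained_models/OmniAvatar-14B/pytorch_model.pt",
                         torch_dtype=torch.bfloat16)
    # MD5 fingerprint over the sorted "key:shape" strings, as defined above.
    print(hash_state_dict_keys(sd))
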
OmniAvatar/vram_management/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .layers import *
OmniAvatar/vram_management/layers.py ADDED
@@ -0,0 +1,95 @@
1
+ import torch, copy
2
+ from ..utils.io_utils import init_weights_on_device
3
+
4
+
5
+ def cast_to(weight, dtype, device):
6
+ r = torch.empty_like(weight, dtype=dtype, device=device)
7
+ r.copy_(weight)
8
+ return r
9
+
10
+
11
+ class AutoWrappedModule(torch.nn.Module):
12
+ def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
13
+ super().__init__()
14
+ self.module = module.to(dtype=offload_dtype, device=offload_device)
15
+ self.offload_dtype = offload_dtype
16
+ self.offload_device = offload_device
17
+ self.onload_dtype = onload_dtype
18
+ self.onload_device = onload_device
19
+ self.computation_dtype = computation_dtype
20
+ self.computation_device = computation_device
21
+ self.state = 0
22
+
23
+ def offload(self):
24
+ if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
25
+ self.module.to(dtype=self.offload_dtype, device=self.offload_device)
26
+ self.state = 0
27
+
28
+ def onload(self):
29
+ if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
30
+ self.module.to(dtype=self.onload_dtype, device=self.onload_device)
31
+ self.state = 1
32
+
33
+ def forward(self, *args, **kwargs):
34
+ if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
35
+ module = self.module
36
+ else:
37
+ module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
38
+ return module(*args, **kwargs)
39
+
40
+
41
+ class AutoWrappedLinear(torch.nn.Linear):
42
+ def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
43
+ with init_weights_on_device(device=torch.device("meta")):
44
+ super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
45
+ self.weight = module.weight
46
+ self.bias = module.bias
47
+ self.offload_dtype = offload_dtype
48
+ self.offload_device = offload_device
49
+ self.onload_dtype = onload_dtype
50
+ self.onload_device = onload_device
51
+ self.computation_dtype = computation_dtype
52
+ self.computation_device = computation_device
53
+ self.state = 0
54
+
55
+ def offload(self):
56
+ if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
57
+ self.to(dtype=self.offload_dtype, device=self.offload_device)
58
+ self.state = 0
59
+
60
+ def onload(self):
61
+ if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
62
+ self.to(dtype=self.onload_dtype, device=self.onload_device)
63
+ self.state = 1
64
+
65
+ def forward(self, x, *args, **kwargs):
66
+ if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
67
+ weight, bias = self.weight, self.bias
68
+ else:
69
+ weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
70
+ bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
71
+ return torch.nn.functional.linear(x, weight, bias)
72
+
73
+
74
+ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
75
+ for name, module in model.named_children():
76
+ for source_module, target_module in module_map.items():
77
+ if isinstance(module, source_module):
78
+ num_param = sum(p.numel() for p in module.parameters())
79
+ if max_num_param is not None and total_num_param + num_param > max_num_param:
80
+ module_config_ = overflow_module_config
81
+ else:
82
+ module_config_ = module_config
83
+ module_ = target_module(module, **module_config_)
84
+ setattr(model, name, module_)
85
+ total_num_param += num_param
86
+ break
87
+ else:
88
+ total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
89
+ return total_num_param
90
+
91
+
92
+ def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
93
+ enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
94
+ model.vram_management_enabled = True
95
+
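
A self-contained sketch of the wrapping that enable_vram_management above performs, applied to a toy model; the dtype/device choices mirror the module_config dictionaries used in wan_video.py and are illustrative only:

    import torch
    from OmniAvatar.vram_management import enable_vram_management, AutoWrappedLinear

    toy = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))
    enable_vram_management(
        toy,
        module_map={torch.nn.Linear: AutoWrappedLinear},
        module_config=dict(
            offload_dtype=torch.float32, offload_device="cpu",
            onload_dtype=torch.float32, onload_device="cpu",
            computation_dtype=torch.float32, computation_device="cpu",  # "cuda" when a GPU is available
        ),
    )
    # Each AutoWrappedLinear casts its weight to the computation dtype/device per call.
    y = toy(torch.randn(1, 64))
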
OmniAvatar/wan_video.py ADDED
@@ -0,0 +1,340 @@
1
+ import types
2
+ from .models.model_manager import ModelManager
3
+ from .models.wan_video_dit import WanModel
4
+ from .models.wan_video_text_encoder import WanTextEncoder
5
+ from .models.wan_video_vae import WanVideoVAE
6
+ from .schedulers.flow_match import FlowMatchScheduler
7
+ from .base import BasePipeline
8
+ from .prompters import WanPrompter
9
+ import torch, os
10
+ from einops import rearrange
11
+ import numpy as np
12
+ from PIL import Image
13
+ from tqdm import tqdm
14
+ from typing import Optional
15
+ from .vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
16
+ from .models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
17
+ from .models.wan_video_dit import RMSNorm
18
+ from .models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
19
+
20
+
21
+ class WanVideoPipeline(BasePipeline):
22
+
23
+ def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
24
+ super().__init__(device=device, torch_dtype=torch_dtype)
25
+ self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
26
+ self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
27
+ self.text_encoder: WanTextEncoder = None
28
+ self.image_encoder = None
29
+ self.dit: WanModel = None
30
+ self.vae: WanVideoVAE = None
31
+ self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
32
+ self.height_division_factor = 16
33
+ self.width_division_factor = 16
34
+ self.use_unified_sequence_parallel = False
35
+ self.sp_size = 1
36
+
37
+
38
+ def enable_vram_management(self, num_persistent_param_in_dit=None):
39
+ dtype = next(iter(self.text_encoder.parameters())).dtype
40
+ enable_vram_management(
41
+ self.text_encoder,
42
+ module_map = {
43
+ torch.nn.Linear: AutoWrappedLinear,
44
+ torch.nn.Embedding: AutoWrappedModule,
45
+ T5RelativeEmbedding: AutoWrappedModule,
46
+ T5LayerNorm: AutoWrappedModule,
47
+ },
48
+ module_config = dict(
49
+ offload_dtype=dtype,
50
+ offload_device="cpu",
51
+ onload_dtype=dtype,
52
+ onload_device="cpu",
53
+ computation_dtype=self.torch_dtype,
54
+ computation_device=self.device,
55
+ ),
56
+ )
57
+ dtype = next(iter(self.dit.parameters())).dtype
58
+ enable_vram_management(
59
+ self.dit,
60
+ module_map = {
61
+ torch.nn.Linear: AutoWrappedLinear,
62
+ torch.nn.Conv3d: AutoWrappedModule,
63
+ torch.nn.LayerNorm: AutoWrappedModule,
64
+ RMSNorm: AutoWrappedModule,
65
+ },
66
+ module_config = dict(
67
+ offload_dtype=dtype,
68
+ offload_device="cpu",
69
+ onload_dtype=dtype,
70
+ onload_device=self.device,
71
+ computation_dtype=self.torch_dtype,
72
+ computation_device=self.device,
73
+ ),
74
+ max_num_param=num_persistent_param_in_dit,
75
+ overflow_module_config = dict(
76
+ offload_dtype=dtype,
77
+ offload_device="cpu",
78
+ onload_dtype=dtype,
79
+ onload_device="cpu",
80
+ computation_dtype=self.torch_dtype,
81
+ computation_device=self.device,
82
+ ),
83
+ )
84
+ dtype = next(iter(self.vae.parameters())).dtype
85
+ enable_vram_management(
86
+ self.vae,
87
+ module_map = {
88
+ torch.nn.Linear: AutoWrappedLinear,
89
+ torch.nn.Conv2d: AutoWrappedModule,
90
+ RMS_norm: AutoWrappedModule,
91
+ CausalConv3d: AutoWrappedModule,
92
+ Upsample: AutoWrappedModule,
93
+ torch.nn.SiLU: AutoWrappedModule,
94
+ torch.nn.Dropout: AutoWrappedModule,
95
+ },
96
+ module_config = dict(
97
+ offload_dtype=dtype,
98
+ offload_device="cpu",
99
+ onload_dtype=dtype,
100
+ onload_device=self.device,
101
+ computation_dtype=self.torch_dtype,
102
+ computation_device=self.device,
103
+ ),
104
+ )
105
+ if self.image_encoder is not None:
106
+ dtype = next(iter(self.image_encoder.parameters())).dtype
107
+ enable_vram_management(
108
+ self.image_encoder,
109
+ module_map = {
110
+ torch.nn.Linear: AutoWrappedLinear,
111
+ torch.nn.Conv2d: AutoWrappedModule,
112
+ torch.nn.LayerNorm: AutoWrappedModule,
113
+ },
114
+ module_config = dict(
115
+ offload_dtype=dtype,
116
+ offload_device="cpu",
117
+ onload_dtype=dtype,
118
+ onload_device="cpu",
119
+ computation_dtype=dtype,
120
+ computation_device=self.device,
121
+ ),
122
+ )
123
+ self.enable_cpu_offload()
124
+
125
+
126
+ def fetch_models(self, model_manager: ModelManager):
127
+ text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
128
+ if text_encoder_model_and_path is not None:
129
+ self.text_encoder, tokenizer_path = text_encoder_model_and_path
130
+ self.prompter.fetch_models(self.text_encoder)
131
+ self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
132
+ self.dit = model_manager.fetch_model("wan_video_dit")
133
+ self.vae = model_manager.fetch_model("wan_video_vae")
134
+ self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
135
+
136
+
137
+ @staticmethod
138
+ def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False, infer=False):
139
+ if device is None: device = model_manager.device
140
+ if torch_dtype is None: torch_dtype = model_manager.torch_dtype
141
+ pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
142
+ pipe.fetch_models(model_manager)
143
+ if use_usp:
144
+ from xfuser.core.distributed import get_sequence_parallel_world_size, get_sp_group
145
+ from OmniAvatar.distributed.xdit_context_parallel import usp_attn_forward
146
+ for block in pipe.dit.blocks:
147
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
148
+ pipe.sp_size = get_sequence_parallel_world_size()
149
+ pipe.use_unified_sequence_parallel = True
150
+ pipe.sp_group = get_sp_group()
151
+ return pipe
152
+
153
+
154
+ def denoising_model(self):
155
+ return self.dit
156
+
157
+
158
+ def encode_prompt(self, prompt, positive=True):
159
+ prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
160
+ return {"context": prompt_emb}
161
+
162
+
163
+ def encode_image(self, image, num_frames, height, width):
164
+ image = self.preprocess_image(image.resize((width, height))).to(self.device, dtype=self.torch_dtype)
165
+ clip_context = self.image_encoder.encode_image([image])
166
+ clip_context = clip_context.to(dtype=self.torch_dtype)
167
+ msk = torch.ones(1, num_frames, height//8, width//8, device=self.device, dtype=self.torch_dtype)
168
+ msk[:, 1:] = 0
169
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
170
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
171
+ msk = msk.transpose(1, 2)[0]
172
+
173
+ vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device, dtype=self.torch_dtype)], dim=1)
174
+ y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
175
+ y = torch.concat([msk, y])
176
+ y = y.unsqueeze(0)
177
+ clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
178
+ y = y.to(dtype=self.torch_dtype, device=self.device)
179
+ return {"clip_feature": clip_context, "y": y}
180
+
181
+
182
+ def tensor2video(self, frames):
183
+ frames = rearrange(frames, "C T H W -> T H W C")
184
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
185
+ frames = [Image.fromarray(frame) for frame in frames]
186
+ return frames
187
+
188
+
189
+ def prepare_extra_input(self, latents=None):
190
+ return {}
191
+
192
+
193
+ def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
194
+ latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
195
+ return latents
196
+
197
+
198
+ def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
199
+ frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
200
+ return frames
201
+
202
+
203
+ def prepare_unified_sequence_parallel(self):
204
+ return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
205
+
206
+
207
+ @torch.no_grad()
208
+ def log_video(
209
+ self,
210
+ lat,
211
+ prompt,
212
+ fixed_frame=0, # lat frames
213
+ image_emb={},
214
+ audio_emb={},
215
+ negative_prompt="",
216
+ cfg_scale=5.0,
217
+ audio_cfg_scale=5.0,
218
+ num_inference_steps=50,
219
+ denoising_strength=1.0,
220
+ sigma_shift=5.0,
221
+ tiled=True,
222
+ tile_size=(30, 52),
223
+ tile_stride=(15, 26),
224
+ tea_cache_l1_thresh=None,
225
+ tea_cache_model_id="",
226
+ progress_bar_cmd=tqdm,
227
+ return_latent=False,
228
+ ):
229
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
230
+ # Scheduler
231
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
232
+
233
+ lat = lat.to(dtype=self.torch_dtype)
234
+ latents = lat.clone()
235
+ latents = torch.randn_like(latents, dtype=self.torch_dtype)
236
+
237
+ # Encode prompts
238
+ self.load_models_to_device(["text_encoder"])
239
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True)
240
+ if cfg_scale != 1.0:
241
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
242
+
243
+ # Extra input
244
+ extra_input = self.prepare_extra_input(latents)
245
+
246
+ # TeaCache
247
+ tea_cache_posi = {"tea_cache": None}
248
+ tea_cache_nega = {"tea_cache": None}
249
+
250
+ # Denoise
251
+ self.load_models_to_device(["dit"])
252
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
253
+ if fixed_frame > 0: # new
254
+ latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
255
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
256
+
257
+ # Inference
258
+ noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **audio_emb, **tea_cache_posi, **extra_input)
259
+ print(f'noise_pred_posi:{noise_pred_posi.dtype}')
260
+ if cfg_scale != 1.0:
261
+ audio_emb_uc = {}
262
+ for key in audio_emb.keys():
263
+ audio_emb_uc[key] = torch.zeros_like(audio_emb[key], dtype=self.torch_dtype)
264
+ if audio_cfg_scale == cfg_scale:
265
+ noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **audio_emb_uc, **tea_cache_nega, **extra_input)
266
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
267
+ else:
268
+ tea_cache_nega_audio = {"tea_cache": None}
269
+ audio_noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **audio_emb_uc, **tea_cache_nega_audio, **extra_input)
270
+ text_noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **audio_emb_uc, **tea_cache_nega, **extra_input)
271
+ noise_pred = text_noise_pred_nega + cfg_scale * (audio_noise_pred_nega - text_noise_pred_nega) + audio_cfg_scale * (noise_pred_posi - audio_noise_pred_nega)
272
+ else:
273
+ noise_pred = noise_pred_posi
274
+ # Scheduler
275
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
276
+
277
+ if fixed_frame > 0: # new
278
+ latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
279
+ # Decode
280
+ self.load_models_to_device(['vae'])
281
+ frames = self.decode_video(latents, **tiler_kwargs)
282
+ recons = self.decode_video(lat, **tiler_kwargs)
283
+ self.load_models_to_device([])
284
+ frames = (frames.permute(0, 2, 1, 3, 4).float() + 1.0) / 2.0
285
+ recons = (recons.permute(0, 2, 1, 3, 4).float() + 1.0) / 2.0
286
+ if return_latent:
287
+ return frames, recons, latents
288
+ return frames, recons
289
+
290
+
291
+ class TeaCache:
292
+ def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
293
+ self.num_inference_steps = num_inference_steps
294
+ self.step = 0
295
+ self.accumulated_rel_l1_distance = 0
296
+ self.previous_modulated_input = None
297
+ self.rel_l1_thresh = rel_l1_thresh
298
+ self.previous_residual = None
299
+ self.previous_hidden_states = None
300
+
301
+ self.coefficients_dict = {
302
+ "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
303
+ "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
304
+ "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
305
+ "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
306
+ }
307
+ if model_id not in self.coefficients_dict:
308
+ supported_model_ids = ", ".join([i for i in self.coefficients_dict])
309
+ raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
310
+ self.coefficients = self.coefficients_dict[model_id]
311
+
312
+ def check(self, dit: WanModel, x, t_mod):
313
+ modulated_inp = t_mod.clone()
314
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
315
+ should_calc = True
316
+ self.accumulated_rel_l1_distance = 0
317
+ else:
318
+ coefficients = self.coefficients
319
+ rescale_func = np.poly1d(coefficients)
320
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
321
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
322
+ should_calc = False
323
+ else:
324
+ should_calc = True
325
+ self.accumulated_rel_l1_distance = 0
326
+ self.previous_modulated_input = modulated_inp
327
+ self.step += 1
328
+ if self.step == self.num_inference_steps:
329
+ self.step = 0
330
+ if should_calc:
331
+ self.previous_hidden_states = x.clone()
332
+ return not should_calc
333
+
334
+ def store(self, hidden_states):
335
+ self.previous_residual = hidden_states - self.previous_hidden_states
336
+ self.previous_hidden_states = None
337
+
338
+ def update(self, hidden_states):
339
+ hidden_states = hidden_states + self.previous_residual
340
+ return hidden_states
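
TeaCache is defined above but left disabled in log_video (tea_cache is passed as None), so the following is only a hedged sketch of the intended check/store/update cycle inside a DiT forward pass; dit, x, t_mod and run_blocks stand in for the model, hidden states, modulated timestep embedding and transformer-block computation, and the threshold is a made-up value:

    tea_cache = TeaCache(num_inference_steps=50, rel_l1_thresh=0.05, model_id="Wan2.1-T2V-14B")

    def forward_with_tea_cache(dit, x, t_mod, run_blocks):
        if tea_cache.check(dit, x, t_mod):   # True -> accumulated change is small, skip the blocks
            return tea_cache.update(x)       # reuse the cached residual
        hidden_states = run_blocks(x)        # full computation
        tea_cache.store(hidden_states)       # cache the residual w.r.t. the stored input
        return hidden_states
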
README.md CHANGED
@@ -1,13 +1,12 @@
1
- ---
2
- title: OmniAvatar
3
- emoji: 🚀
4
- colorFrom: gray
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 5.39.0
8
- app_file: app.py
9
- pinned: false
10
- short_description: Avatar Video Generation with Adaptive Body Animation
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: OmniAvatar
3
+ emoji: 🏆🏆🏆🏆🏆🏆
4
+ colorFrom: gray
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 5.36.2
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
app.py ADDED
@@ -0,0 +1,688 @@
1
+ import spaces
2
+ import subprocess
3
+ import gradio as gr
4
+
5
+ import os, sys
6
+ from glob import glob
7
+ from datetime import datetime
8
+ import math
9
+ import random
10
+ import librosa
11
+ import numpy as np
12
+ import uuid
13
+ import shutil
14
+
15
+ import importlib, site, sys
16
+
17
+ import torch
18
+
19
+ print(f'torch version:{torch.__version__}')
20
+
21
+
22
+ import torch.nn as nn
23
+ from tqdm import tqdm
24
+ from functools import partial
25
+ from omegaconf import OmegaConf
26
+ from argparse import Namespace
27
+
28
+ # load the one true config you dumped
29
+ _args_cfg = OmegaConf.load("args_config.yaml")
30
+ args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
31
+
32
+ from OmniAvatar.utils.args_config import set_global_args
33
+
34
+ set_global_args(args)
35
+ # args = parse_args()
36
+
37
+ from OmniAvatar.utils.io_utils import load_state_dict
38
+ from peft import LoraConfig, inject_adapter_in_model
39
+ from OmniAvatar.models.model_manager import ModelManager
40
+ from OmniAvatar.schedulers.flow_match import FlowMatchScheduler
41
+ from OmniAvatar.wan_video import WanVideoPipeline
42
+ from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
43
+ import torchvision.transforms as TT
44
+ from transformers import Wav2Vec2FeatureExtractor
45
+ import torchvision.transforms as transforms
46
+ import torch.nn.functional as F
47
+ from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
48
+ from huggingface_hub import hf_hub_download, snapshot_download
49
+
50
+ os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"
51
+
52
+ def tensor_to_pil(tensor):
53
+ """
54
+ Args:
55
+ tensor: torch.Tensor with shape like
56
+ (1, C, H, W), (1, C, 1, H, W), (C, H, W), etc.
57
+ values in [-1, 1], on any device.
58
+ Returns:
59
+ A PIL.Image in RGB mode.
60
+ """
61
+ # 1) Remove batch dim if it exists
62
+ if tensor.dim() > 3 and tensor.shape[0] == 1:
63
+ tensor = tensor[0]
64
+
65
+ # 2) Squeeze out any other singleton dims (e.g. that extra frame axis)
66
+ tensor = tensor.squeeze()
67
+
68
+ # Now we should have exactly 3 dims: (C, H, W)
69
+ if tensor.dim() != 3:
70
+ raise ValueError(f"Expected 3 dims after squeeze, got {tensor.dim()}")
71
+
72
+ # 3) Move to CPU float32
73
+ tensor = tensor.cpu().float()
74
+
75
+ # 4) Undo normalization from [-1,1] -> [0,1]
76
+ tensor = (tensor + 1.0) / 2.0
77
+
78
+ # 5) Clamp to [0,1]
79
+ tensor = torch.clamp(tensor, 0.0, 1.0)
80
+
81
+ # 6) To NumPy H×W×C in [0,255]
82
+ np_img = (tensor.permute(1, 2, 0).numpy() * 255.0).round().astype("uint8")
83
+
84
+ # 7) Build PIL Image
85
+ return Image.fromarray(np_img)
86
+
87
+
88
+ def set_seed(seed: int = 42):
89
+ random.seed(seed)
90
+ np.random.seed(seed)
91
+ torch.manual_seed(seed)
92
+ torch.cuda.manual_seed(seed) # seed the current GPU
93
+ torch.cuda.manual_seed_all(seed) # seed all GPUs
94
+
95
+ def read_from_file(p):
96
+ with open(p, "r") as fin:
97
+ for l in fin:
98
+ yield l.strip()
99
+
100
+ def match_size(image_size, h, w):
101
+ ratio_ = 9999
102
+ size_ = 9999
103
+ select_size = None
104
+ for image_s in image_size:
105
+ ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
106
+ size_tmp = abs(max(image_s) - max(w, h))
107
+ if ratio_tmp < ratio_:
108
+ ratio_ = ratio_tmp
109
+ size_ = size_tmp
110
+ select_size = image_s
111
+ if ratio_ == ratio_tmp:
112
+ if size_ == size_tmp:
113
+ select_size = image_s
114
+ return select_size
115
+
116
+ def resize_pad(image, ori_size, tgt_size):
117
+ h, w = ori_size
118
+ scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
119
+ scale_h = int(h * scale_ratio)
120
+ scale_w = int(w * scale_ratio)
121
+
122
+ image = transforms.Resize(size=[scale_h, scale_w])(image)
123
+
124
+ padding_h = tgt_size[0] - scale_h
125
+ padding_w = tgt_size[1] - scale_w
126
+ pad_top = padding_h // 2
127
+ pad_bottom = padding_h - pad_top
128
+ pad_left = padding_w // 2
129
+ pad_right = padding_w - pad_left
130
+
131
+ image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
132
+ return image
133
+
134
+ class WanInferencePipeline(nn.Module):
135
+ def __init__(self, args):
136
+ super().__init__()
137
+ self.args = args
138
+ self.device = torch.device(f"cuda")
139
+ self.dtype = torch.bfloat16
140
+ self.pipe = self.load_model()
141
+ chained_transforms = []
142
+ chained_transforms.append(TT.ToTensor())
143
+ self.transform = TT.Compose(chained_transforms)
144
+
145
+ if self.args.use_audio:
146
+ from OmniAvatar.models.wav2vec import Wav2VecModel
147
+ self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
148
+ self.args.wav2vec_path
149
+ )
150
+ self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device, dtype=self.dtype)
151
+ self.audio_encoder.feature_extractor._freeze_parameters()
152
+
153
+
154
+ def load_model(self):
155
+ ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
156
+ assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
157
+ if self.args.train_architecture == 'lora':
158
+ self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
159
+ else:
160
+ resume_path = ckpt_path
161
+
162
+ self.step = 0
163
+
164
+ # Load models
165
+ model_manager = ModelManager(device="cuda", infer=True)
166
+
167
+ model_manager.load_models(
168
+ [
169
+ self.args.dit_path.split(","),
170
+ self.args.vae_path,
171
+ self.args.text_encoder_path
172
+ ],
173
+ torch_dtype=self.dtype,
174
+ device='cuda',
175
+ )
176
+
177
+ pipe = WanVideoPipeline.from_model_manager(model_manager,
178
+ torch_dtype=self.dtype,
179
+ device="cuda",
180
+ use_usp=False,
181
+ infer=True)
182
+
183
+ if self.args.train_architecture == "lora":
184
+ print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
185
+ self.add_lora_to_model(
186
+ pipe.denoising_model(),
187
+ lora_rank=self.args.lora_rank,
188
+ lora_alpha=self.args.lora_alpha,
189
+ lora_target_modules=self.args.lora_target_modules,
190
+ init_lora_weights=self.args.init_lora_weights,
191
+ pretrained_lora_path=pretrained_lora_path,
192
+ )
193
+ print(next(pipe.denoising_model().parameters()).device)
194
+ else:
195
+ missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
196
+ print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
197
+ pipe.requires_grad_(False)
198
+ pipe.eval()
199
+ # pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
200
+ return pipe
201
+
202
+ def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
203
+ # Add LoRA to UNet
204
+
205
+ self.lora_alpha = lora_alpha
206
+ if init_lora_weights == "kaiming":
207
+ init_lora_weights = True
208
+
209
+ lora_config = LoraConfig(
210
+ r=lora_rank,
211
+ lora_alpha=lora_alpha,
212
+ init_lora_weights=init_lora_weights,
213
+ target_modules=lora_target_modules.split(","),
214
+ )
215
+ model = inject_adapter_in_model(lora_config, model)
216
+
217
+ # Lora pretrained lora weights
218
+ if pretrained_lora_path is not None:
219
+ state_dict = load_state_dict(pretrained_lora_path, torch_dtype=self.dtype)
220
+ if state_dict_converter is not None:
221
+ state_dict = state_dict_converter(state_dict)
222
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
223
+ all_keys = [i for i, _ in model.named_parameters()]
224
+ num_updated_keys = len(all_keys) - len(missing_keys)
225
+ num_unexpected_keys = len(unexpected_keys)
226
+
227
+ print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
228
+
229
+ def get_times(self, prompt,
230
+ image_path=None,
231
+ audio_path=None,
232
+ seq_len=101, # not used while audio_path is not None
233
+ height=720,
234
+ width=720,
235
+ overlap_frame=None,
236
+ num_steps=None,
237
+ negative_prompt=None,
238
+ guidance_scale=None,
239
+ audio_scale=None):
240
+
241
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
242
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
243
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
244
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
245
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
246
+
247
+ if image_path is not None:
248
+ from PIL import Image
249
+ image = Image.open(image_path).convert("RGB")
250
+
251
+ image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
252
+
253
+ _, _, h, w = image.shape
254
+ select_size = match_size(getattr( self.args, f'image_sizes_{ self.args.max_hw}'), h, w)
255
+ image = resize_pad(image, (h, w), select_size)
256
+ image = image * 2.0 - 1.0
257
+ image = image[:, :, None]
258
+
259
+ else:
260
+ image = None
261
+ select_size = [height, width]
262
+ num = self.args.max_tokens * 16 * 16 * 4
263
+ den = select_size[0] * select_size[1]
264
+ L0 = num // den
265
+ diff = (L0 - 1) % 4
266
+ L = L0 - diff
267
+ if L < 1:
268
+ L = 1
269
+ T = (L + 3) // 4
270
+
271
+
272
+ if self.args.random_prefix_frames:
273
+ fixed_frame = overlap_frame
274
+ assert fixed_frame % 4 == 1
275
+ else:
276
+ fixed_frame = 1
277
+ prefix_lat_frame = (3 + fixed_frame) // 4
278
+ first_fixed_frame = 1
279
+
280
+
281
+ audio, sr = librosa.load(audio_path, sr= self.args.sample_rate)
282
+
283
+ input_values = np.squeeze(
284
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
285
+ )
286
+ input_values = torch.from_numpy(input_values).float().to(dtype=self.dtype)
287
+ audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
288
+
289
+ if audio_len < L - first_fixed_frame:
290
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
291
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
292
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
293
+
294
+ seq_len = audio_len
295
+
296
+ times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
297
+ if times * (L-fixed_frame) + fixed_frame < seq_len:
298
+ times += 1
299
+
300
+ return times
301
+
302
+ @torch.no_grad()
303
+ def forward(self, prompt,
304
+ image_path=None,
305
+ audio_path=None,
306
+ seq_len=101, # not used while audio_path is not None
307
+ height=720,
308
+ width=720,
309
+ overlap_frame=None,
310
+ num_steps=None,
311
+ negative_prompt=None,
312
+ guidance_scale=None,
313
+ audio_scale=None):
314
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
315
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
316
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
317
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
318
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
319
+
320
+ if image_path is not None:
321
+ from PIL import Image
322
+ image = Image.open(image_path).convert("RGB")
323
+
324
+ image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
325
+
326
+ _, _, h, w = image.shape
327
+ select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
328
+ image = resize_pad(image, (h, w), select_size)
329
+ image = image * 2.0 - 1.0
330
+ image = image[:, :, None]
331
+
332
+ else:
333
+ image = None
334
+ select_size = [height, width]
335
+ # L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
336
+ # L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
337
+ # T = (L + 3) // 4 # latent frames
338
+
339
+ # step 1: numerator and denominator as ints
340
+ num = args.max_tokens * 16 * 16 * 4
341
+ den = select_size[0] * select_size[1]
342
+
343
+ # step 2: integer division
344
+ L0 = num // den # exact floor division, no float in sight
345
+
346
+ # step 3: make it ≡ 1 mod 4
347
+ # if L0 % 4 == 1, keep L0;
348
+ # otherwise subtract the difference so that (L0 - diff) % 4 == 1,
349
+ # but ensure the result stays positive.
350
+ diff = (L0 - 1) % 4
351
+ L = L0 - diff
352
+ if L < 1:
353
+ L = 1 # or whatever your minimal frame count is
354
+
355
+ # step 4: latent frames
356
+ T = (L + 3) // 4
357
+
358
+
359
+ if self.args.i2v:
360
+ if self.args.random_prefix_frames:
361
+ fixed_frame = overlap_frame
362
+ assert fixed_frame % 4 == 1
363
+ else:
364
+ fixed_frame = 1
365
+ prefix_lat_frame = (3 + fixed_frame) // 4
366
+ first_fixed_frame = 1
367
+ else:
368
+ fixed_frame = 0
369
+ prefix_lat_frame = 0
370
+ first_fixed_frame = 0
371
+
372
+
373
+ if audio_path is not None and self.args.use_audio:
374
+ audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
375
+ input_values = np.squeeze(
376
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
377
+ )
378
+ input_values = torch.from_numpy(input_values).float().to(device=self.device, dtype=self.dtype)
379
+ ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
380
+ input_values = input_values.unsqueeze(0)
381
+ # padding audio
382
+ if audio_len < L - first_fixed_frame:
383
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
384
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
385
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
386
+ input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
387
+ with torch.no_grad():
388
+ hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
389
+ audio_embeddings = hidden_states.last_hidden_state
390
+ for mid_hidden_states in hidden_states.hidden_states:
391
+ audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
392
+ seq_len = audio_len
393
+ audio_embeddings = audio_embeddings.squeeze(0)
394
+ audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
395
+ else:
396
+ audio_embeddings = None
397
+
398
+ # loop
399
+ times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
400
+ if times * (L-fixed_frame) + fixed_frame < seq_len:
401
+ times += 1
402
+ video = []
403
+ image_emb = {}
404
+ img_lat = None
405
+ if self.args.i2v:
406
+ self.pipe.load_models_to_device(['vae'])
407
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
408
+
409
+ msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:,:1], dtype=self.dtype)
410
+ image_cat = img_lat.repeat(1, 1, T, 1, 1)
411
+ msk[:, :, 1:] = 1
412
+ image_emb["y"] = torch.cat([image_cat, msk], dim=1)
413
+
414
+ for t in range(times):
415
+ print(f"[{t+1}/{times}]")
416
+ audio_emb = {}
417
+ if t == 0:
418
+ overlap = first_fixed_frame
419
+ else:
420
+ overlap = fixed_frame
421
+ image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # 第一次推理是mask只有1,往后都是mask overlap
422
+ prefix_overlap = (3 + overlap) // 4
423
+ if audio_embeddings is not None:
424
+ if t == 0:
425
+ audio_tensor = audio_embeddings[
426
+ :min(L - overlap, audio_embeddings.shape[0])
427
+ ]
428
+ else:
429
+ audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
430
+ audio_tensor = audio_embeddings[
431
+ audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
432
+ ]
433
+
434
+ audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
435
+ audio_prefix = audio_tensor[-fixed_frame:]
436
+ audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
437
+ audio_emb["audio_emb"] = audio_tensor
438
+ else:
439
+ audio_prefix = None
440
+ if image is not None and img_lat is None:
441
+ self.pipe.load_models_to_device(['vae'])
442
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
443
+ assert img_lat.shape[2] == prefix_overlap
444
+ img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
445
+ frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
446
+ negative_prompt, num_inference_steps=num_steps,
447
+ cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
448
+ return_latent=True,
449
+ tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B")
450
+
451
+ torch.cuda.empty_cache()
452
+ img_lat = None
453
+ image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
454
+
455
+ if t == 0:
456
+ video.append(frames)
457
+ else:
458
+ video.append(frames[:, overlap:])
459
+ video = torch.cat(video, dim=1)
460
+ video = video[:, :ori_audio_len + 1]
461
+
462
+ return video
463
+
464
+
465
+ snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="./pretrained_models/Wan2.1-T2V-14B")
466
+ snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
467
+ snapshot_download(repo_id="OmniAvatar/OmniAvatar-14B", local_dir="./pretrained_models/OmniAvatar-14B")
468
+
469
+
470
+ # snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./pretrained_models/Wan2.1-T2V-1.3B")
471
+ # snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
472
+ # snapshot_download(repo_id="OmniAvatar/OmniAvatar-1.3B", local_dir="./pretrained_models/OmniAvatar-1.3B")
473
+
474
+ import tempfile
475
+
476
+ from PIL import Image
477
+
478
+
479
+ set_seed(args.seed)
480
+ seq_len = args.seq_len
481
+ inferpipe = WanInferencePipeline(args)
482
+
483
+
484
+ def update_generate_button(image_path, audio_path, text, num_steps):
485
+
486
+ if image_path is None or audio_path is None:
487
+ return gr.update(value="⌚ Zero GPU Required: --")
488
+
489
+ duration_s = get_duration(image_path, audio_path, text, num_steps, None, None)
490
+
491
+ return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s")
492
+
493
+ def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
494
+
495
+ audio_chunks = inferpipe.get_times(
496
+ prompt=text,
497
+ image_path=image_path,
498
+ audio_path=audio_path,
499
+ seq_len=args.seq_len,
500
+ num_steps=num_steps
501
+ )
502
+
503
+ warmup_s = 30
504
+ duration_s = (20 * num_steps) + warmup_s
505
+
506
+ if audio_chunks > 1:
507
+ duration_s = (20 * num_steps * audio_chunks) + warmup_s
508
+
509
+ print(f'for {audio_chunks} times, might take {duration_s}')
510
+
511
+ return int(duration_s)
512
+
513
+ def preprocess_img(image_path, session_id = None):
514
+
515
+ if session_id is None:
516
+ session_id = uuid.uuid4().hex
517
+
518
+ image = Image.open(image_path).convert("RGB")
519
+
520
+ image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
521
+
522
+ _, _, h, w = image.shape
523
+ select_size = match_size(getattr( args, f'image_sizes_{ args.max_hw}'), h, w)
524
+ image = resize_pad(image, (h, w), select_size)
525
+ image = image * 2.0 - 1.0
526
+ image = image[:, :, None]
527
+
528
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
529
+
530
+ img_dir = output_dir + '/image'
531
+ os.makedirs(img_dir, exist_ok=True)
532
+ input_img_path = os.path.join(img_dir, f"img_input.jpg")
533
+
534
+ image = tensor_to_pil(image)
535
+ image.save(input_img_path)
536
+
537
+ return input_img_path
538
+
539
+
540
+ @spaces.GPU(duration=get_duration)
541
+ def infer(image_path, audio_path, text, num_steps, session_id = None, progress=gr.Progress(track_tqdm=True),):
542
+
543
+ if session_id is None:
544
+ session_id = uuid.uuid4().hex
545
+
546
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
547
+
548
+ audio_dir = output_dir + '/audio'
549
+ os.makedirs(audio_dir, exist_ok=True)
550
+ if args.silence_duration_s > 0:
551
+ input_audio_path = os.path.join(audio_dir, f"audio_input.wav")
552
+ else:
553
+ input_audio_path = audio_path
554
+ prompt_dir = output_dir + '/prompt'
555
+ os.makedirs(prompt_dir, exist_ok=True)
556
+
557
+ if args.silence_duration_s > 0:
558
+ add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)
559
+
560
+ tmp2_audio_path = os.path.join(audio_dir, f"audio_out.wav")
561
+ prompt_path = os.path.join(prompt_dir, f"prompt.txt")
562
+
563
+ video = inferpipe(
564
+ prompt=text,
565
+ image_path=image_path,
566
+ audio_path=input_audio_path,
567
+ seq_len=args.seq_len,
568
+ num_steps=num_steps
569
+ )
570
+
571
+ torch.cuda.empty_cache()
572
+
573
+ add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
574
+ video_paths = save_video_as_grid_and_mp4(video,
575
+ output_dir,
576
+ args.fps,
577
+ prompt=text,
578
+ prompt_path = prompt_path,
579
+ audio_path=tmp2_audio_path if args.use_audio else None,
580
+ prefix=f'result')
581
+
582
+ return video_paths[0]
583
+
584
+ def cleanup(request: gr.Request):
585
+
586
+ sid = request.session_hash
587
+ if sid:
588
+ d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
589
+ shutil.rmtree(d1, ignore_errors=True)
590
+
591
+ def start_session(request: gr.Request):
592
+
593
+ return request.session_hash
594
+
595
+ css = """
596
+ #col-container {
597
+ margin: 0 auto;
598
+ max-width: 1560px;
599
+ }
600
+ """
601
+ theme = gr.themes.Ocean()
602
+
603
+ with gr.Blocks(css=css, theme=theme) as demo:
604
+
605
+ session_state = gr.State()
606
+ demo.load(start_session, outputs=[session_state])
607
+
608
+
609
+ with gr.Column(elem_id="col-container"):
610
+ gr.HTML(
611
+ """
612
+ <div style="text-align: left;">
613
+ <p style="font-size:16px; display: inline; margin: 0;">
614
+ <strong>OmniAvatar</strong> – Efficient Audio-Driven Avatar Video Generation with Adaptive Body Animation
615
+ </p>
616
+ <a href="https://github.com/Omni-Avatar/OmniAvatar" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
617
+ <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
618
+ </a>
619
+ </div>
620
+ <div style="text-align: left;">
621
+ HF Space by :<a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
622
+ <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
623
+ </a>
624
+ </div>
625
+
626
+ <div style="text-align: left;">
627
+ <a href="https://huggingface.co/alexnasa">
628
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
629
+ </a>
630
+ </div>
631
+
632
+ """
633
+ )
634
+
635
+ with gr.Row():
636
+
637
+ with gr.Column():
638
+
639
+ image_input = gr.Image(label="Reference Image", type="filepath", height=512)
640
+ audio_input = gr.Audio(label="Input Audio", type="filepath")
641
+
642
+
643
+ with gr.Column():
644
+
645
+ output_video = gr.Video(label="Avatar", height=512)
646
+ num_steps = gr.Slider(1, 50, value=8, step=1, label="Steps")
647
+ time_required = gr.Text(value="⌚ Zero GPU Required: --", show_label=False)
648
+ infer_btn = gr.Button("🦜 Avatar Me", variant="primary")
649
+ with gr.Accordion("Advanced Settings", open=False):
650
+ text_input = gr.Textbox(label="Prompt Text", lines=4, value="A realistic video of a person speaking directly to the camera on a sofa, with dynamic and rhythmic hand gestures that complement their speech. Their hands are clearly visible, independent, and unobstructed. Their facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.")
651
+
652
+ with gr.Column():
653
+
654
+ examples = gr.Examples(
655
+ examples=[
656
+ [
657
+ "examples/images/female-001.png",
658
+ "examples/audios/mushroom.wav",
659
+ "A realistic video of a woman speaking and sometimes looking directly at the camera, sitting on a sofa, with dynamic, rhythmic, and extensive hand gestures that complement her speech. Her hands are clearly visible, independent, and unobstructed. Her facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
660
+ 12
661
+ ],
662
+ [
663
+ "examples/images/male-001.png",
664
+ "examples/audios/tape.wav",
665
+ "A realistic video of a man moving his hands extensively and speaking. The motion of his hands matches his speech. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.",
666
+ 8
667
+ ],
668
+ ],
669
+ inputs=[image_input, audio_input, text_input, num_steps],
670
+ outputs=[output_video],
671
+ fn=infer,
672
+ cache_examples=True
673
+ )
674
+
675
+ infer_btn.click(
676
+ fn=infer,
677
+ inputs=[image_input, audio_input, text_input, num_steps, session_state],
678
+ outputs=[output_video]
679
+ )
680
+ image_input.upload(fn=preprocess_img, inputs=[image_input, session_state], outputs=[image_input]).then(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
681
+ audio_input.upload(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
682
+ num_steps.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
683
+
684
+
685
+ if __name__ == "__main__":
686
+ demo.unload(cleanup)
687
+ demo.queue()
688
+ demo.launch(ssr_mode=False)
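The ZeroGPU decorator above is given duration=get_duration, a callable that is invoked with the same arguments as infer and returns how many seconds of GPU time to reserve; update_generate_button presumably reuses the same estimate for the "⌚ Zero GPU Required" label. The real get_duration is defined earlier in app.py and is outside this hunk, so the sketch below is only a hypothetical stand-in showing the expected signature and a simple per-step estimate:

# Hypothetical sketch only: the real get_duration lives earlier in app.py.
def get_duration(image_path, audio_path, text, num_steps,
                 session_id=None, progress=None):
    warmup_s = 60        # assumed fixed overhead (model loading, VAE encode/decode)
    per_step_s = 10      # assumed cost of one denoising step
    return warmup_s + per_step_s * int(num_steps)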
args_config.yaml ADDED
@@ -0,0 +1,77 @@
1
+ config: configs/inference.yaml
2
+
3
+ input_file: examples/infer_samples.txt
4
+ debug: null
5
+ infer: false
6
+ hparams: ''
7
+ dtype: bf16
8
+
9
+ exp_path: pretrained_models/OmniAvatar-14B
10
+ text_encoder_path: pretrained_models/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth
11
+ image_encoder_path: None
12
+ dit_path: pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors
13
+ vae_path: pretrained_models/Wan2.1-T2V-14B/Wan2.1_VAE.pth
14
+
15
+ # exp_path: pretrained_models/OmniAvatar-1.3B
16
+ # text_encoder_path: pretrained_models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
17
+ # image_encoder_path: None
18
+ # dit_path: pretrained_models/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors
19
+ # vae_path: pretrained_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
20
+
21
+ wav2vec_path: pretrained_models/wav2vec2-base-960h
22
+ num_persistent_param_in_dit:
23
+ reload_cfg: true
24
+ sp_size: 1
25
+ seed: 42
26
+ image_sizes_720:
27
+ - - 400
28
+ - 720
29
+ # - - 720 commented out due to the duration limit on HF
30
+ # - 720
31
+ - - 720
32
+ - 400
33
+ image_sizes_1280:
34
+ - - 720
35
+ - 720
36
+ - - 528
37
+ - 960
38
+ - - 960
39
+ - 528
40
+ - - 720
41
+ - 1280
42
+ - - 1280
43
+ - 720
44
+ max_hw: 720
45
+ max_tokens: 40000
46
+ seq_len: 200
47
+ overlap_frame: 13
48
+ guidance_scale: 4.5
49
+ audio_scale: null
50
+ num_steps: 8
51
+ fps: 24
52
+ sample_rate: 16000
53
+ negative_prompt: Vivid color tones, background/camera moving quickly, screen switching,
54
+ subtitles and special effects, mutation, overexposed, static, blurred details, subtitles,
55
+ style, work, painting, image, still, overall grayish, worst quality, low quality,
56
+ JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly
57
+ drawn face, deformed, disfigured, malformed limbs, fingers merging, motionless image,
58
+ chaotic background, three legs, crowded background with many people, walking backward
59
+ silence_duration_s: 0.0
60
+ use_fsdp: false
61
+ tea_cache_l1_thresh: 0
62
+ rank: 0
63
+ world_size: 1
64
+ local_rank: 0
65
+ device: cuda
66
+ num_nodes: 1
67
+ i2v: true
68
+ use_audio: true
69
+ random_prefix_frames: true
70
+ model_config:
71
+ in_dim: 33
72
+ audio_hidden_size: 32
73
+ train_architecture: lora
74
+ lora_target_modules: q,k,v,o,ffn.0,ffn.2
75
+ init_lora_weights: kaiming
76
+ lora_rank: 128
77
+ lora_alpha: 64.0
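args_config.yaml is a dump of the fully resolved runtime arguments; the commented-out block near the top of scripts/inference.py (further down in this commit) loads exactly such a dump back into a Namespace. A minimal sketch of that round-trip, assuming the file sits at the repository root:

from argparse import Namespace
from omegaconf import OmegaConf

# Load the dumped arguments back into a plain Namespace, mirroring the
# commented-out block at the top of scripts/inference.py.
cfg = OmegaConf.load("args_config.yaml")                    # path is an assumption
args = Namespace(**OmegaConf.to_container(cfg, resolve=True))
print(args.num_steps, args.fps, args.max_hw)                # 8 24 720 per this file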
assets/material/pipeline.png ADDED
assets/material/teaser.png ADDED
configs/inference.yaml ADDED
@@ -0,0 +1,37 @@
1
+ # Pretrained model paths
2
+ dtype: "bf16"
3
+ text_encoder_path: pretrained_models/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth
4
+ image_encoder_path: None
5
+ dit_path: pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,pretrained_models/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors
6
+ vae_path: pretrained_models/Wan2.1-T2V-14B/Wan2.1_VAE.pth
7
+ wav2vec_path: pretrained_models/wav2vec2-base-960h
8
+ exp_path: pretrained_models/OmniAvatar-14B
9
+ num_persistent_param_in_dit: # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
10
+
11
+ reload_cfg: True
12
+ sp_size: 1
13
+
14
+ # Data parameters
15
+ seed: 42
16
+ image_sizes_720: [[400, 720],
17
+ [720, 720],
18
+ [720, 400]]
19
+ image_sizes_1280: [
20
+ [720, 720],
21
+ [528, 960],
22
+ [960, 528],
23
+ [720, 1280],
24
+ [1280, 720]]
25
+ max_hw: 720 # 720: 480p; 1280: 720p
26
+ max_tokens: 30000
27
+ seq_len: 200
28
+ overlap_frame: 13 # must be 1 + 4*n
29
+ guidance_scale: 4.5
30
+ audio_scale:
31
+ num_steps: 16
32
+ fps: 25
33
+ sample_rate: 16000
34
+ negative_prompt: "Vivid color tones, background/camera moving quickly, screen switching, subtitles and special effects, mutation, overexposed, static, blurred details, subtitles, style, work, painting, image, still, overall grayish, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fingers merging, motionless image, chaotic background, three legs, crowded background with many people, walking backward"
35
+ silence_duration_s: 0.3
36
+ use_fsdp: False
37
+ tea_cache_l1_thresh: 0 # e.g. 0.14; the larger this value, the faster the inference but the worse the visual quality. TODO: verify value
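max_tokens, overlap_frame and the image sizes interact through the frame bookkeeping in WanInferencePipeline.forward (scripts/inference.py below): max_tokens caps the number of video frames L for the selected resolution, the latent length is T = (L + 3) // 4, and overlap_frame must have the form 1 + 4*n. A small sketch using this config's 720x720 size and max_tokens of 30000:

max_tokens, h, w = 30000, 720, 720
L = int(max_tokens * 16 * 16 * 4 / h / w)        # token budget -> candidate frame count
L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3      # snap to 4*n + 1 video frames
T = (L + 3) // 4                                 # latent frames
overlap_frame = 13
assert overlap_frame % 4 == 1, "overlap_frame must be 1 + 4*n"
print(L, T)                                      # 57 video frames, 15 latent frames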
configs/inference_1.3B.yaml ADDED
@@ -0,0 +1,37 @@
1
+ # Pretrained model paths
2
+ dtype: "bf16"
3
+ text_encoder_path: pretrained_models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
4
+ image_encoder_path: None
5
+ dit_path: pretrained_models/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors
6
+ vae_path: pretrained_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
7
+ wav2vec_path: pretrained_models/wav2vec2-base-960h
8
+ exp_path: pretrained_models/OmniAvatar-1.3B
9
+ num_persistent_param_in_dit: # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
10
+
11
+ reload_cfg: True
12
+ sp_size: 1
13
+
14
+ # Data parameters
15
+ seed: 42
16
+ image_sizes_720: [[400, 720],
17
+ [720, 720],
18
+ [720, 400]]
19
+ image_sizes_1280: [
20
+ [720, 720],
21
+ [528, 960],
22
+ [960, 528],
23
+ [720, 1280],
24
+ [1280, 720]]
25
+ max_hw: 720 # 720: 480p; 1280: 720p
26
+ max_tokens: 30000
27
+ seq_len: 200
28
+ overlap_frame: 13 # must be 1 + 4*n
29
+ guidance_scale: 4.5
30
+ audio_scale:
31
+ num_steps: 10
32
+ fps: 25
33
+ sample_rate: 16000
34
+ negative_prompt: "Vivid color tones, background/camera moving quickly, screen switching, subtitles and special effects, mutation, overexposed, static, blurred details, subtitles, style, work, painting, image, still, overall grayish, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fingers merging, motionless image, chaotic background, three legs, crowded background with many people, walking backward"
35
+ silence_duration_s: 0.3
36
+ use_fsdp: False
37
+ tea_cache_l1_thresh: 0 # e.g. 0.14; the larger this value, the faster the inference but the worse the visual quality. TODO: verify value
examples/audios/mushroom.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d10e51bb6169d206f4eccbbde14868a1c9a09e07bd7a2f18258cc2b265620226
3
+ size 460878
examples/audios/tape.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba4e3861b6b124ecde7dc621612957613c495c64d4ba8c0f26e9c87fe5c1566c
3
+ size 270764
examples/images/female-001.png ADDED

Git LFS Details

  • SHA256: a7bef121857c09664e4890420b983d81a0efa3c7ef5b6dbeb4564e60a4cbdaad
  • Pointer size: 132 Bytes
  • Size of remote file: 2.65 MB
examples/images/male-001.png ADDED

Git LFS Details

  • SHA256: df0b06323cc249d5a3acc56a1a766a0481e19ba278a0c7d2047044bde8ff02d1
  • Pointer size: 132 Bytes
  • Size of remote file: 2.76 MB
examples/infer_samples.txt ADDED
@@ -0,0 +1 @@
1
+ A realistic video of a man speaking directly to the camera on a sofa, with dynamic and rhythmic hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence.@@examples/images/0000.jpeg@@examples/audios/0000.MP3
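Each line of examples/infer_samples.txt packs a prompt, a reference image and an audio clip separated by "@@". read_from_file in scripts/inference.py only yields the raw lines (and the current main() hardcodes its sample), so splitting is left to the caller; a minimal parser under that assumption:

def parse_sample(line: str):
    # "prompt@@image_path@@audio_path", as in examples/infer_samples.txt
    prompt, image_path, audio_path = line.strip().split("@@")
    return prompt, image_path, audio_path

with open("examples/infer_samples.txt") as fin:
    samples = [parse_sample(l) for l in fin if l.strip()]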
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ tqdm
2
+ librosa==0.10.2.post1
3
+ peft==0.15.1
4
+ transformers==4.52.3
5
+ scipy==1.14.0
6
+ numpy==1.26.4
7
+ xfuser==0.4.1
8
+ ftfy
9
+ einops
10
+ omegaconf
11
+ torchvision
12
+ ninja
13
+ flash-attn-3 @ https://github.com/OutofAi/PyTorch3D-wheels/releases/download/0.7.8/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
scripts/inference.py ADDED
@@ -0,0 +1,383 @@
1
+ import subprocess
2
+ import os, sys
3
+ from glob import glob
4
+ from datetime import datetime
5
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
6
+ import math
7
+ import random
8
+ import librosa
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ from tqdm import tqdm
13
+ from functools import partial
14
+ from omegaconf import OmegaConf
15
+ from argparse import Namespace
16
+
17
+ # # load the one true config you dumped
18
+ # _args_cfg = OmegaConf.load("demo_out/config/args_config.yaml")
19
+ # args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
20
+
21
+ # from OmniAvatar.utils.args_config import set_global_args
22
+
23
+ # set_global_args(args)
24
+
25
+ from OmniAvatar.utils.args_config import parse_args
26
+ args = parse_args()
27
+
28
+ from OmniAvatar.utils.io_utils import load_state_dict
29
+ from peft import LoraConfig, inject_adapter_in_model
30
+ from OmniAvatar.models.model_manager import ModelManager
31
+ from OmniAvatar.wan_video import WanVideoPipeline
32
+ from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
33
+ import torchvision.transforms as TT
34
+ from transformers import Wav2Vec2FeatureExtractor
35
+ import torchvision.transforms as transforms
36
+ import torch.nn.functional as F
37
+ from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
38
+ from huggingface_hub import hf_hub_download
39
+
40
+ def set_seed(seed: int = 42):
41
+ random.seed(seed)
42
+ np.random.seed(seed)
43
+ torch.manual_seed(seed)
44
+ torch.cuda.manual_seed(seed) # seed the current GPU
45
+ torch.cuda.manual_seed_all(seed) # seed all GPUs
46
+
47
+ def read_from_file(p):
48
+ with open(p, "r") as fin:
49
+ for l in fin:
50
+ yield l.strip()
51
+
52
+ def match_size(image_size, h, w):
53
+ ratio_ = 9999
54
+ size_ = 9999
55
+ select_size = None
56
+ for image_s in image_size:
57
+ ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
58
+ size_tmp = abs(max(image_s) - max(w, h))
59
+ if ratio_tmp < ratio_:
60
+ ratio_ = ratio_tmp
61
+ size_ = size_tmp
62
+ select_size = image_s
63
+ if ratio_ == ratio_tmp:
64
+ if size_ == size_tmp:
65
+ select_size = image_s
66
+ return select_size
67
+
68
+ def resize_pad(image, ori_size, tgt_size):
69
+ h, w = ori_size
70
+ scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
71
+ scale_h = int(h * scale_ratio)
72
+ scale_w = int(w * scale_ratio)
73
+
74
+ image = transforms.Resize(size=[scale_h, scale_w])(image)
75
+
76
+ padding_h = tgt_size[0] - scale_h
77
+ padding_w = tgt_size[1] - scale_w
78
+ pad_top = padding_h // 2
79
+ pad_bottom = padding_h - pad_top
80
+ pad_left = padding_w // 2
81
+ pad_right = padding_w - pad_left
82
+
83
+ image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
84
+ return image
85
+
86
+ class WanInferencePipeline(nn.Module):
87
+ def __init__(self, args):
88
+ super().__init__()
89
+ self.args = args
90
+ self.device = torch.device(f"cuda")
91
+ if self.args.dtype=='bf16':
92
+ self.dtype = torch.bfloat16
93
+ elif self.args.dtype=='fp16':
94
+ self.dtype = torch.float16
95
+ else:
96
+ self.dtype = torch.float32
97
+ self.pipe = self.load_model()
98
+ if self.args.i2v:
99
+ chained_transforms = []
100
+ chained_transforms.append(TT.ToTensor())
101
+ self.transform = TT.Compose(chained_transforms)
102
+ if self.args.use_audio:
103
+ from OmniAvatar.models.wav2vec import Wav2VecModel
104
+ self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
105
+ self.args.wav2vec_path
106
+ )
107
+ self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device)
108
+ self.audio_encoder.feature_extractor._freeze_parameters()
109
+
110
+ def load_model(self):
111
+ torch.cuda.set_device(0)
112
+ ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
113
+ assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
114
+ if self.args.train_architecture == 'lora':
115
+ self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
116
+ else:
117
+ resume_path = ckpt_path
118
+
119
+ self.step = 0
120
+
121
+ # Load models
122
+ model_manager = ModelManager(device="cpu", infer=True)
123
+ model_manager.load_models(
124
+ [
125
+ self.args.dit_path.split(","),
126
+ self.args.text_encoder_path,
127
+ self.args.vae_path
128
+ ],
129
+ torch_dtype=self.dtype, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
130
+ device='cpu',
131
+ )
132
+ LORA_REPO_ID = "Kijai/WanVideo_comfy"
133
+ LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
134
+ causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
135
+ model_manager.load_lora(causvid_path, lora_alpha=1.0)
136
+ pipe = WanVideoPipeline.from_model_manager(model_manager,
137
+ torch_dtype=self.dtype,
138
+ device=f"cuda",
139
+ use_usp=True if self.args.sp_size > 1 else False,
140
+ infer=True)
141
+ if self.args.train_architecture == "lora":
142
+ print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
143
+ self.add_lora_to_model(
144
+ pipe.denoising_model(),
145
+ lora_rank=self.args.lora_rank,
146
+ lora_alpha=self.args.lora_alpha,
147
+ lora_target_modules=self.args.lora_target_modules,
148
+ init_lora_weights=self.args.init_lora_weights,
149
+ pretrained_lora_path=pretrained_lora_path,
150
+ )
151
+ else:
152
+ missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
153
+ print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
154
+ pipe.requires_grad_(False)
155
+ pipe.eval()
156
+ pipe.enable_vram_management(num_persistent_param_in_dit=self.args.num_persistent_param_in_dit) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
157
+ return pipe
158
+
159
+ def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
160
+ # Add LoRA to the denoising model (DiT)
161
+ self.lora_alpha = lora_alpha
162
+ if init_lora_weights == "kaiming":
163
+ init_lora_weights = True
164
+
165
+ lora_config = LoraConfig(
166
+ r=lora_rank,
167
+ lora_alpha=lora_alpha,
168
+ init_lora_weights=init_lora_weights,
169
+ target_modules=lora_target_modules.split(","),
170
+ )
171
+ model = inject_adapter_in_model(lora_config, model)
172
+
173
+ # Load pretrained LoRA weights
174
+ if pretrained_lora_path is not None:
175
+ state_dict = load_state_dict(pretrained_lora_path)
176
+ if state_dict_converter is not None:
177
+ state_dict = state_dict_converter(state_dict)
178
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
179
+ all_keys = [i for i, _ in model.named_parameters()]
180
+ num_updated_keys = len(all_keys) - len(missing_keys)
181
+ num_unexpected_keys = len(unexpected_keys)
182
+ print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
183
+
184
+
185
+ def forward(self, prompt,
186
+ image_path=None,
187
+ audio_path=None,
188
+ seq_len=101, # ignored when audio_path is provided
189
+ height=720,
190
+ width=720,
191
+ overlap_frame=None,
192
+ num_steps=None,
193
+ negative_prompt=None,
194
+ guidance_scale=None,
195
+ audio_scale=None):
196
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
197
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
198
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
199
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
200
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
201
+
202
+ if image_path is not None:
203
+ from PIL import Image
204
+ image = Image.open(image_path).convert("RGB")
205
+ image = self.transform(image).unsqueeze(0).to(self.device)
206
+ _, _, h, w = image.shape
207
+ select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
208
+ image = resize_pad(image, (h, w), select_size)
209
+ image = image * 2.0 - 1.0
210
+ image = image[:, :, None]
211
+ else:
212
+ image = None
213
+ select_size = [height, width]
214
+ L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
215
+ L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
216
+ T = (L + 3) // 4 # latent frames
217
+
218
+ if self.args.i2v:
219
+ if self.args.random_prefix_frames:
220
+ fixed_frame = overlap_frame
221
+ assert fixed_frame % 4 == 1
222
+ else:
223
+ fixed_frame = 1
224
+ prefix_lat_frame = (3 + fixed_frame) // 4
225
+ first_fixed_frame = 1
226
+ else:
227
+ fixed_frame = 0
228
+ prefix_lat_frame = 0
229
+ first_fixed_frame = 0
230
+
231
+
232
+ if audio_path is not None and self.args.use_audio:
233
+ audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
234
+ input_values = np.squeeze(
235
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
236
+ )
237
+ input_values = torch.from_numpy(input_values).float().to(device=self.device)
238
+ ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
239
+ input_values = input_values.unsqueeze(0)
240
+ # padding audio
241
+ if audio_len < L - first_fixed_frame:
242
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
243
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
244
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
245
+ input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
246
+ with torch.no_grad():
247
+ hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
248
+ audio_embeddings = hidden_states.last_hidden_state
249
+ for mid_hidden_states in hidden_states.hidden_states:
250
+ audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
251
+ seq_len = audio_len
252
+ audio_embeddings = audio_embeddings.squeeze(0)
253
+ audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
254
+ else:
255
+ audio_embeddings = None
256
+
257
+ # loop
258
+ times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
259
+ if times * (L-fixed_frame) + fixed_frame < seq_len:
260
+ times += 1
261
+ video = []
262
+ image_emb = {}
263
+ img_lat = None
264
+ if self.args.i2v:
265
+ self.pipe.load_models_to_device(['vae'])
266
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device)
267
+
268
+ msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:,:1])
269
+ image_cat = img_lat.repeat(1, 1, T, 1, 1)
270
+ msk[:, :, 1:] = 1
271
+ image_emb["y"] = torch.cat([image_cat, msk], dim=1)
272
+ for t in range(times):
273
+ print(f"[{t+1}/{times}]")
274
+ audio_emb = {}
275
+ if t == 0:
276
+ overlap = first_fixed_frame
277
+ else:
278
+ overlap = fixed_frame
279
+ image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # on the first pass only 1 frame is masked; later passes mask the overlap frames
280
+ prefix_overlap = (3 + overlap) // 4
281
+ if audio_embeddings is not None:
282
+ if t == 0:
283
+ audio_tensor = audio_embeddings[
284
+ :min(L - overlap, audio_embeddings.shape[0])
285
+ ]
286
+ else:
287
+ audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
288
+ audio_tensor = audio_embeddings[
289
+ audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
290
+ ]
291
+
292
+ audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
293
+ audio_prefix = audio_tensor[-fixed_frame:]
294
+ audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
295
+ audio_emb["audio_emb"] = audio_tensor
296
+ else:
297
+ audio_prefix = None
298
+ if image is not None and img_lat is None:
299
+ self.pipe.load_models_to_device(['vae'])
300
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device)
301
+ assert img_lat.shape[2] == prefix_overlap
302
+ img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1))], dim=2)
303
+ frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
304
+ negative_prompt, num_inference_steps=num_steps,
305
+ cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
306
+ return_latent=True,
307
+ tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B")
308
+ img_lat = None
309
+ image = (frames[:, -fixed_frame:].clip(0, 1) * 2 - 1).permute(0, 2, 1, 3, 4).contiguous()
310
+ if t == 0:
311
+ video.append(frames)
312
+ else:
313
+ video.append(frames[:, overlap:])
314
+ video = torch.cat(video, dim=1)
315
+ video = video[:, :ori_audio_len + 1]
316
+ return video
317
+
318
+
319
+ def main():
320
+
321
+ # os.makedirs("demo_out/config", exist_ok=True)
322
+ # OmegaConf.save(config=OmegaConf.create(vars(args)),
323
+ # f="demo_out/config/args_config.yaml")
324
+ # print("Saved merged args to demo_out/config/args_config.yaml")
325
+
326
+ set_seed(args.seed)
327
+ # load data
328
+ data_iter = read_from_file(args.input_file)
329
+ exp_name = os.path.basename(args.exp_path)
330
+ seq_len = args.seq_len
331
+
332
+ # Text-to-video
333
+ inferpipe = WanInferencePipeline(args)
334
+
335
+ output_dir = f'demo_out'
336
+
337
+ idx = 0
338
+ text = "A realistic video of a man speaking directly to the camera on a sofa, with dynamic and rhythmic hand gestures that complement his speech. His hands are clearly visible, independent, and unobstructed. His facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence."
339
+ image_path = "examples/images/0000.jpeg"
340
+ audio_path = "examples/audios/0000.MP3"
341
+ audio_dir = output_dir + '/audio'
342
+ os.makedirs(audio_dir, exist_ok=True)
343
+ if args.silence_duration_s > 0:
344
+ input_audio_path = os.path.join(audio_dir, f"audio_input_{idx:03d}.wav")
345
+ else:
346
+ input_audio_path = audio_path
347
+ prompt_dir = output_dir + '/prompt'
348
+ os.makedirs(prompt_dir, exist_ok=True)
349
+
350
+ if args.silence_duration_s > 0:
351
+ add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)
352
+
353
+ video = inferpipe(
354
+ prompt=text,
355
+ image_path=image_path,
356
+ audio_path=input_audio_path,
357
+ seq_len=seq_len
358
+ )
359
+ tmp2_audio_path = os.path.join(audio_dir, f"audio_out_{idx:03d}.wav") # the first frame is the reference frame, so the audio needs to be shifted forward by 1/25 s
360
+ prompt_path = os.path.join(prompt_dir, f"prompt_{idx:03d}.txt")
361
+
362
+
363
+ add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
364
+ save_video_as_grid_and_mp4(video,
365
+ output_dir,
366
+ args.fps,
367
+ prompt=text,
368
+ prompt_path = prompt_path,
369
+ audio_path=tmp2_audio_path if args.use_audio else None,
370
+ prefix=f'result_{idx:03d}')
371
+
372
+
373
+ class NoPrint:
374
+ def write(self, x):
375
+ pass
376
+ def flush(self):
377
+ pass
378
+
379
+ if __name__ == '__main__':
380
+ if not args.debug:
381
+ if args.local_rank != 0: # silence output from all ranks except rank 0
382
+ sys.stdout = NoPrint()
383
+ main()
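For audio longer than one window, WanInferencePipeline.forward generates the video in overlapping chunks: the first chunk reuses only the single reference frame, every later chunk reuses overlap_frame frames, and the audio is padded so the chunks tile exactly. The sketch below mirrors that arithmetic with the 720x720 defaults (L = 57, overlap_frame = 13); the 200-frame input is just an example:

L, fixed_frame, first_fixed_frame = 57, 13, 1
audio_len = 200                                   # example: 8 s of audio at 25 fps

# pad so the first chunk adds L - 1 new frames and every later chunk L - 13
if audio_len < L - first_fixed_frame:
    audio_len += (L - first_fixed_frame) - audio_len % (L - first_fixed_frame)
elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
    audio_len += (L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame)

times = (audio_len - L + first_fixed_frame) // (L - fixed_frame) + 1
if times * (L - fixed_frame) + fixed_frame < audio_len:
    times += 1
print(audio_len, times)                           # 232 padded frames -> 5 chunks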