import types
import random
import spaces
import os
import torch
import numpy as np
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from huggingface_hub import snapshot_download, hf_hub_download
import gradio as gr
import tempfile

from src.pipeline_wan_nag import NAGWanPipeline
from src.transformer_wan_nag import NagWanTransformer3DModel


MOD_VALUE = 32
DEFAULT_DURATION_SECONDS = 4
DEFAULT_STEPS = 4
DEFAULT_SEED = 2025
DEFAULT_H_SLIDER_VALUE = 480
DEFAULT_W_SLIDER_VALUE = 832
NEW_FORMULA_MAX_AREA = 480.0 * 832.0

SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
MAX_SEED = np.iinfo(np.int32).max

FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
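# At FIXED_FPS (16 fps), MIN_FRAMES_MODEL/MAX_FRAMES_MODEL correspond to roughly 0.5-5 seconds of video.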

DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"


# MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
# LORA_REPO_ID = "Kijai/WanVideo_comfy"
# LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
# # Additional enhancement LoRAs for FusionX-like quality
# ACCVIDEO_LORA_REPO = "alibaba-pai/Wan2.1-Fun-Reward-LoRAs"
# MPS_LORA_FILENAME = "Wan2.1-MPS-Reward-LoRA.safetensors"

# vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
# pipe = NAGWanPipeline.from_pretrained(
#     MODEL_ID, vae=vae, torch_dtype=torch.bfloat16
# )
# pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
# pipe.to("cuda")

# causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
# pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
# pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
# for name, param in pipe.transformer.named_parameters():
#     if "lora_B" in name:
#         if "blocks.0" in name:
#             param.data = param.data * 0.25
# pipe.fuse_lora()
# pipe.unload_lora_weights()


###### Working attempt 2 #########################

# --- Model and LoRA definitions ---
MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# LoRA 1: CausVid
LORA1_REPO_ID = "Kijai/WanVideo_comfy"
LORA1_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
LORA1_NAME = "causvid_lora"
LORA1_WEIGHT = 0.95

# LoRA 2: MPS Reward
LORA2_REPO_ID = "alibaba-pai/Wan2.1-Fun-Reward-LoRAs"
LORA2_FILENAME = "Wan2.1-MPS-Reward-LoRA.safetensors"
LORA2_NAME = "mps_lora"
LORA2_WEIGHT = 0.7

# # LoRA 3: (NEW) Insert actual repo/filename below
# LORA3_REPO_ID = "your-username/your-lora-repo"
# LORA3_FILENAME = "your_third_lora.safetensors"
# LORA3_NAME = "third_lora"
# LORA3_WEIGHT = 0.85

# --- Load model and VAE ---
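# The Wan VAE is loaded in float32 while the transformer runs in bfloat16; NAG is applied
# through the custom NAGWanPipeline imported from src.pipeline_wan_nag.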
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = NAGWanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
pipe.to("cuda")

# --- LoRA loading helper ---
def load_lora_from_repo(repo_id, filename, adapter_name, weight):
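    """Download a single LoRA file from the Hub and register it on the pipeline.

    snapshot_download with allow_patterns fetches only the requested file into the local
    Hugging Face cache; the adapter name, weight, and local path are returned so the
    caller can pass them on to pipe.set_adapters().
    """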
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        allow_patterns=[filename],  # Only download this file
        repo_type="model"  # or "dataset" if needed
    )
    lora_path = os.path.join(snapshot_path, filename)
    pipe.load_lora_weights(lora_path, adapter_name=adapter_name)
    return adapter_name, weight, lora_path

# --- Load all LoRAs ---
lora_adapters = []
lora_weights = []

# Load CausVid LoRA (currently disabled)
# name, weight, path = load_lora_from_repo(LORA1_REPO_ID, LORA1_FILENAME, LORA1_NAME, LORA1_WEIGHT)
# lora_adapters.append(name)
# lora_weights.append(weight)

# # Special scale adjustment for CausVid blocks.0 (part of the disabled CausVid path)
# for n, p in pipe.transformer.named_parameters():
#     if "lora_B" in n and "blocks.0" in n:
#         p.data = p.data * 0.25

# Load MPS Reward LoRA
name, weight, path = load_lora_from_repo(LORA2_REPO_ID, LORA2_FILENAME, LORA2_NAME, LORA2_WEIGHT)
lora_adapters.append(name)
lora_weights.append(weight)

# Load third LoRA (currently disabled)
# name, weight, path = load_lora_from_repo(LORA3_REPO_ID, LORA3_FILENAME, LORA3_NAME, LORA3_WEIGHT)
# lora_adapters.append(name)
# lora_weights.append(weight)

# --- Set and fuse adapters ---
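# fuse_lora() merges the active adapter weights into the base model weights, so the adapter
# copies can later be released with unload_lora_weights() (called further below).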
pipe.set_adapters(lora_adapters, adapter_weights=lora_weights)
pipe.fuse_lora()





##### Attempt 3  #####################################################

# MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# # --- Initialize pipeline ---
# vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
# pipe = NAGWanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)
# pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
# pipe.to("cuda")

# # --- LoRA config list ---
# LORA_CONFIGS = [
#     {
#         "repo_id": "Kijai/WanVideo_comfy",
#         "filename": "Wan21_CausVid_14B_T2V_lora_rank32.safetensors",
#         "adapter_name": "causvid_lora",
#         "weight": 0.95,
#         "scale_blocks": ["blocks.0"],  # special scaling
#     },
#     {
#         "repo_id": "vrgamedevgirl84/Wan14BT2VFusioniX",
#         "filename": "OtherLoRa's/DetailEnhancerV1.safetensors",
#         "adapter_name": "mps_lora",
#         "weight": 0.7
#     }
#     # {
#     #     "repo_id": "your-user/your-lora-repo",
#     #     "filename": "your_third_lora.safetensors",
#     #     "adapter_name": "third_lora",
#     #     "weight": 0.85
#     # }
# ]

# # --- LoRA loader ---
# def load_lora_from_repo(pipe, repo_id, filename, adapter_name, weight, repo_type="model", scale_blocks=None):
#     snapshot_path = snapshot_download(
#         repo_id=repo_id,
#         allow_patterns=[filename],
#         repo_type=repo_type
#     )
#     lora_path = os.path.join(snapshot_path, filename)
#     pipe.load_lora_weights(lora_path, adapter_name=adapter_name)

#     # Optional: Apply scale to certain blocks
#     if scale_blocks:
#         for n, p in pipe.transformer.named_parameters():
#             if "lora_B" in n and any(block in n for block in scale_blocks):
#                 p.data *= 0.25

#     return adapter_name, weight

# # --- Load and apply LoRAs ---
# lora_adapters = []
# lora_weights = []

# for config in LORA_CONFIGS:
#     name, weight = load_lora_from_repo(
#         pipe,
#         repo_id=config["repo_id"],
#         filename=config["filename"],
#         adapter_name=config["adapter_name"],
#         weight=config.get("weight", 1.0),
#         repo_type=config.get("repo_type", "model"),
#         scale_blocks=config.get("scale_blocks", [])
#     )
#     lora_adapters.append(name)
#     lora_weights.append(weight)

# pipe.set_adapters(lora_adapters, adapter_weights=lora_weights)
# pipe.fuse_lora()
# # pipe.unload_lora_weights()  # Optional: only needed if you want to release memory

# print(f"✅ Fused LoRAs: {lora_adapters}")





# Optional: unload after fusing
pipe.unload_lora_weights()

print(f"✅ Loaded and fused {len(lora_adapters)} LoRAs: {lora_adapters}")


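# Patch the transformer class to use the NAG-aware attention processors and forward pass
# from NagWanTransformer3DModel; the nag_* arguments in the pipeline call rely on this.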
pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
pipe.transformer.__class__.forward = NagWanTransformer3DModel.forward

examples = [
    ["A ginger cat passionately plays eletric guitar with intensity and emotion on a stage. The background is shrouded in deep darkness. Spotlights casts dramatic shadows.", DEFAULT_NAG_NEGATIVE_PROMPT, 11],
    ["A red vintage Porsche convertible flying over a rugged coastal cliff. Monstrous waves violently crashing against the rocks below. A lighthouse stands tall atop the cliff.", DEFAULT_NAG_NEGATIVE_PROMPT, 11],
    ["Enormous glowing jellyfish float slowly across a sky filled with soft clouds. Their tentacles shimmer with iridescent light as they drift above a peaceful mountain landscape. Magical and dreamlike, captured in a wide shot. Surreal realism style with detailed textures.", DEFAULT_NAG_NEGATIVE_PROMPT, 11],
]


def get_duration(
        prompt,
        nag_negative_prompt, nag_scale,
        height, width, duration_seconds,
        steps,
        seed, randomize_seed,
        compare,
):
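    # Estimated GPU time budget (in seconds) for @spaces.GPU: about 2.25 s per
    # (second of video x inference step) plus 5 s of overhead, doubled when the
    # baseline comparison video is also generated.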
    duration = int(duration_seconds) * int(steps) * 2.25 + 5
    if compare:
        duration *= 2
    return duration

@spaces.GPU(duration=get_duration)
def generate_video(
        prompt,
        nag_negative_prompt, nag_scale,
        height=DEFAULT_H_SLIDER_VALUE, width=DEFAULT_W_SLIDER_VALUE, duration_seconds=DEFAULT_DURATION_SECONDS,
        steps=DEFAULT_STEPS,
        seed=DEFAULT_SEED, randomize_seed=False,
        compare=True,
):
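    # Snap the requested resolution down to the nearest multiple of MOD_VALUE (a model
    # constraint) and convert the duration into a frame count (duration * FPS + 1),
    # clamped to the range the model supports.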
    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)

    num_frames = np.clip(int(round(int(duration_seconds) * FIXED_FPS) + 1), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)

    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    with torch.inference_mode():
        nag_output_frames_list = pipe(
            prompt=prompt,
            nag_negative_prompt=nag_negative_prompt,
            nag_scale=nag_scale,
            nag_tau=3.5,
            nag_alpha=0.5,
            height=target_h, width=target_w, num_frames=num_frames,
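            # Classifier-free guidance is disabled; the negative prompt is applied via NAG instead.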
            guidance_scale=0.,
            num_inference_steps=int(steps),
            generator=torch.Generator(device="cuda").manual_seed(current_seed)
        ).frames[0]

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        nag_video_path = tmpfile.name
    export_to_video(nag_output_frames_list, nag_video_path, fps=FIXED_FPS)

    if compare:
        baseline_output_frames_list = pipe(
            prompt=prompt,
            nag_negative_prompt=nag_negative_prompt,
            height=target_h, width=target_w, num_frames=num_frames,
            guidance_scale=0.,
            num_inference_steps=int(steps),
            generator=torch.Generator(device="cuda").manual_seed(current_seed)
        ).frames[0]

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            baseline_video_path = tmpfile.name
        export_to_video(baseline_output_frames_list, baseline_video_path, fps=FIXED_FPS)
    else:
        baseline_video_path = None

    return nag_video_path, baseline_video_path, current_seed


def generate_video_with_example(
        prompt,
        nag_negative_prompt,
        nag_scale,
):
    nag_video_path, baseline_video_path, seed = generate_video(
        prompt=prompt,
        nag_negative_prompt=nag_negative_prompt, nag_scale=nag_scale,
        height=DEFAULT_H_SLIDER_VALUE, width=DEFAULT_W_SLIDER_VALUE, duration_seconds=DEFAULT_DURATION_SECONDS,
        steps=DEFAULT_STEPS,
        seed=DEFAULT_SEED, randomize_seed=False,
        compare=True,
    )
    return nag_video_path, baseline_video_path, \
        DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE, \
        DEFAULT_DURATION_SECONDS, DEFAULT_STEPS, seed, True


with gr.Blocks() as demo:
    gr.Markdown("# Normalized Attention Guidance + Wan2.1-T2V-14B + CausVid LoRA + Detail Face LoRA")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                max_lines=3,
                placeholder="Enter your prompt",
            )
            nag_negative_prompt = gr.Textbox(
                label="Negative Prompt for NAG",
                value=DEFAULT_NAG_NEGATIVE_PROMPT,
                max_lines=3,
            )
            nag_scale = gr.Slider(label="NAG Scale", minimum=1., maximum=20., step=0.25, value=11.)
            compare = gr.Checkbox(
                label="Compare with baseline",
                info="If unchecked, only sample with NAG will be generated.", value=True,
            )

            with gr.Accordion("Advanced Settings", open=False):
                steps_slider = gr.Slider(minimum=1, maximum=8, step=1, value=DEFAULT_STEPS, label="Inference Steps")
                duration_seconds_input = gr.Slider(
                    minimum=1, maximum=5, step=1, value=DEFAULT_DURATION_SECONDS,
                    label="Duration (seconds)",
                )
                seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=DEFAULT_SEED, interactive=True)
                randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
                with gr.Row():
                    height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE,
                                             value=DEFAULT_H_SLIDER_VALUE,
                                             label=f"Output Height (multiple of {MOD_VALUE})")
                    width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE,
                                            value=DEFAULT_W_SLIDER_VALUE,
                                            label=f"Output Width (multiple of {MOD_VALUE})")

            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            nag_video_output = gr.Video(label="Video with NAG", autoplay=True, interactive=False)
            baseline_video_output = gr.Video(label="Baseline Video without NAG", autoplay=True, interactive=False)

    gr.Examples(
        examples=examples,
        fn=generate_video_with_example,
        inputs=[prompt, nag_negative_prompt, nag_scale],
        outputs=[
            nag_video_output, baseline_video_output,
            height_input, width_input, duration_seconds_input,
            steps_slider,
            seed_input,
            compare,
        ],
        cache_examples="lazy"
    )

    ui_inputs = [
        prompt,
        nag_negative_prompt, nag_scale,
        height_input, width_input, duration_seconds_input,
        steps_slider,
        seed_input, randomize_seed_checkbox,
        compare,
    ]
    generate_button.click(
        fn=generate_video,
        inputs=ui_inputs,
        outputs=[nag_video_output, baseline_video_output, seed_input],
    )

if __name__ == "__main__":
    demo.queue().launch()