rahul7star commited on
Commit
f1eaac4
·
verified ·
1 Parent(s): 40cd568

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +25 -10
optimization.py CHANGED
@@ -15,7 +15,7 @@ from torchao.quantization import Int8WeightOnlyConfig
15
  from optimization_utils import capture_component_call
16
  from optimization_utils import aoti_compile
17
  from optimization_utils import ZeroGPUCompiledModel
18
-
19
 
20
  P = ParamSpec('P')
21
 
@@ -43,6 +43,26 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
43
  @spaces.GPU(duration=1500)
44
  def compile_transformer():
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  with capture_component_call(pipeline, 'transformer') as call:
47
  pipeline(*args, **kwargs)
48
 
@@ -105,13 +125,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
105
  else:
106
  return cp2(*args, **kwargs)
107
 
108
- transformer_config = pipeline.transformer.config
109
- transformer_dtype = pipeline.transformer.dtype
110
-
111
- pipeline.transformer = combined_transformer_1
112
- pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
113
- pipeline.transformer.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
114
 
115
- pipeline.transformer_2 = combined_transformer_2
116
- pipeline.transformer_2.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
117
- pipeline.transformer_2.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
 
15
  from optimization_utils import capture_component_call
16
  from optimization_utils import aoti_compile
17
  from optimization_utils import ZeroGPUCompiledModel
18
+ from optimization_utils import drain_module_parameters
19
 
20
  P = ParamSpec('P')
21
 
 
43
  @spaces.GPU(duration=1500)
44
  def compile_transformer():
45
 
46
+ pipeline.load_lora_weights(
47
+ "Kijai/WanVideo_comfy",
48
+ weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
49
+ adapter_name="lightning"
50
+ )
51
+ kwargs_lora = {}
52
+ kwargs_lora["load_into_transformer_2"] = True
53
+ pipeline.load_lora_weights(
54
+ "Kijai/WanVideo_comfy",
55
+ weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
56
+ #weight_name="Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors",
57
+ adapter_name="lightning_2", **kwargs_lora
58
+ )
59
+ pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1., 1.])
60
+
61
+
62
+ pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3., components=["transformer"])
63
+ pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1., components=["transformer_2"])
64
+ pipeline.unload_lora_weights()
65
+
66
  with capture_component_call(pipeline, 'transformer') as call:
67
  pipeline(*args, **kwargs)
68
 
 
125
  else:
126
  return cp2(*args, **kwargs)
127
 
128
+ pipeline.transformer.forward = combined_transformer_1
129
+ drain_module_parameters(pipeline.transformer)
 
 
 
 
130
 
131
+ pipeline.transformer_2.forward = combined_transformer_2
132
+ drain_module_parameters(pipeline.transformer_2)