cbensimon HF Staff commited on
Commit
879ee4e
·
verified ·
1 Parent(s): a651c76

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +3 -2
optimization.py CHANGED
@@ -10,6 +10,7 @@ import torch
10
  from torch.utils._pytree import tree_map_only
11
  from torchao.quantization import quantize_
12
  from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
 
13
 
14
  from optimization_utils import capture_component_call
15
  from optimization_utils import aoti_compile
@@ -42,6 +43,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
42
  @spaces.GPU(duration=1500)
43
  def compile_transformer():
44
 
 
 
45
  with capture_component_call(pipeline, 'transformer') as call:
46
  pipeline(*args, **kwargs)
47
 
@@ -86,9 +89,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
86
  compiled_portrait_2,
87
  )
88
 
89
- pipeline.text_encoder.to('cpu')
90
  cl1, cl2, cp1, cp2 = compile_transformer()
91
- pipeline.text_encoder.to('cuda')
92
 
93
  def combined_transformer_1(*args, **kwargs):
94
  hidden_states: torch.Tensor = kwargs['hidden_states']
 
10
  from torch.utils._pytree import tree_map_only
11
  from torchao.quantization import quantize_
12
  from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
13
+ from torchao.quantization import Int8WeightOnlyConfig
14
 
15
  from optimization_utils import capture_component_call
16
  from optimization_utils import aoti_compile
 
43
  @spaces.GPU(duration=1500)
44
  def compile_transformer():
45
 
46
+ quantize_(pipeline.text_encoder, Int8WeightOnlyConfig()) # Just to free-up some GPU memory
47
+
48
  with capture_component_call(pipeline, 'transformer') as call:
49
  pipeline(*args, **kwargs)
50
 
 
89
  compiled_portrait_2,
90
  )
91
 
 
92
  cl1, cl2, cp1, cp2 = compile_transformer()
 
93
 
94
  def combined_transformer_1(*args, **kwargs):
95
  hidden_states: torch.Tensor = kwargs['hidden_states']