Update optimization.py
optimization.py  CHANGED  +3 -2
@@ -10,6 +10,7 @@ import torch
 from torch.utils._pytree import tree_map_only
 from torchao.quantization import quantize_
 from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
+from torchao.quantization import Int8WeightOnlyConfig
 
 from optimization_utils import capture_component_call
 from optimization_utils import aoti_compile
@@ -42,6 +43,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
     @spaces.GPU(duration=1500)
     def compile_transformer():
 
+        quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())  # Just to free-up some GPU memory
+
         with capture_component_call(pipeline, 'transformer') as call:
             pipeline(*args, **kwargs)
 
@@ -86,9 +89,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
         compiled_portrait_2,
     )
 
-    pipeline.text_encoder.to('cpu')
     cl1, cl2, cp1, cp2 = compile_transformer()
-    pipeline.text_encoder.to('cuda')
 
     def combined_transformer_1(*args, **kwargs):
         hidden_states: torch.Tensor = kwargs['hidden_states']