Commit 5b58ac7
1 Parent(s): fe25f9c

fix: Update Gradio interface to fix type handling issues and improve memory management

src/api/app.py  CHANGED  (+67 -16)
@@ -4,6 +4,9 @@ Gradio interface for the LLaVA model.
 
 import gradio as gr
 from PIL import Image
+import os
+import tempfile
+import torch
 
 from ..configs.settings import (
     GRADIO_THEME,
@@ -25,7 +28,30 @@ setup_logging()
 logger = get_logger(__name__)
 
 # Initialize model
-model =
+model = None
+
+def initialize_model():
+    global model
+    try:
+        logger.info("Initializing LLaVA model...")
+        # Use a smaller model variant and enable memory optimizations
+        model = LLaVAModel(
+            vision_model_path="openai/clip-vit-base-patch32",  # Smaller vision model
+            language_model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # Smaller language model
+            device="cpu",  # Force CPU for Hugging Face Spaces
+            projection_hidden_dim=2048  # Reduce projection layer size
+        )
+
+        # Enable memory optimizations
+        torch.cuda.empty_cache()  # Clear any cached memory
+        if hasattr(model, 'language_model'):
+            model.language_model.config.use_cache = False  # Disable KV cache
+
+        logger.info(f"Model initialized on {model.device}")
+        return True
+    except Exception as e:
+        logger.error(f"Error initializing model: {e}")
+        return False
 
 def process_image(
     image: Image.Image,
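
Note: since the model is pinned to device="cpu" for Spaces, the torch.cuda.empty_cache() calls in this commit are effectively no-ops; they only release blocks held by the CUDA caching allocator and do nothing when no CUDA device has been used. A minimal sketch of a guard, should GPU support return (the helper name is illustrative, not from this file):

    import torch

    def free_accelerator_memory() -> None:
        # Illustrative helper, not part of this commit: only touch the
        # CUDA allocator when a CUDA device is actually available.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()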
@@ -47,15 +73,36 @@ def process_image(
     Returns:
         str: Model response
     """
+    if not model:
+        return "Error: Model not initialized"
+
     try:
         logger.info(f"Processing image with prompt: {prompt[:100]}...")
-
-
-
-
-
-
-
+
+        # Save the uploaded image temporarily
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
+            image.save(temp_file.name)
+            temp_path = temp_file.name
+
+        # Clear memory before processing
+        torch.cuda.empty_cache()
+
+        # Generate response with reduced memory usage
+        with torch.inference_mode():  # More memory efficient than no_grad
+            response = model(
+                image=image,
+                prompt=prompt,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                top_p=top_p
+            )
+
+        # Clean up temporary file
+        os.unlink(temp_path)
+
+        # Clear memory after processing
+        torch.cuda.empty_cache()
+
         logger.info("Successfully generated response")
         return response
     except Exception as e:
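
Note: in this hunk the temporary PNG is written but never read; the model call still receives the in-memory image, and os.unlink(temp_path) runs only on the success path, so the file leaks if generation raises. A sketch of exception-safe cleanup, assuming the on-disk copy is actually needed (run_model is a hypothetical stand-in for the real call):

    import os
    import tempfile
    from PIL import Image

    def process_with_temp_copy(image: Image.Image) -> str:
        # Sketch, not the commit's code: remove the temporary file even
        # when inference raises, via a finally block.
        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
                image.save(temp_file.name)
                temp_path = temp_file.name
            return run_model(temp_path)  # hypothetical model call that reads the file
        finally:
            if temp_path is not None and os.path.exists(temp_path):
                os.unlink(temp_path)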
@@ -129,17 +176,17 @@ Try these prompts to get started:
            show_copy_button=True
        )
 
-        # Set up event handlers
+        # Set up event handlers with explicit types
        generate_btn.click(
            fn=process_image,
            inputs=[
-
-
-
-
-
+                gr.Image(type="pil"),
+                gr.Textbox(),
+                gr.Slider(),
+                gr.Slider(),
+                gr.Slider()
            ],
-            outputs=
+            outputs=gr.Textbox()
        )
 
    return interface
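
Note: in Gradio's Blocks API, the inputs and outputs of an event listener are normally the component instances already created in the layout; constructing fresh gr.Image()/gr.Slider() objects at wiring time binds the handler to components that never appear on the page, and the bare gr.Slider() calls also drop whatever ranges and defaults the UI defines. A sketch of the usual pattern (component names and slider ranges are illustrative, not taken from this file):

    import gradio as gr

    with gr.Blocks() as interface:
        # Components are created once in the layout...
        image_input = gr.Image(type="pil")
        prompt_box = gr.Textbox(label="Prompt")
        max_tokens = gr.Slider(1, 512, value=128, step=1, label="Max new tokens")
        temperature = gr.Slider(0.0, 2.0, value=0.7, label="Temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
        output_box = gr.Textbox(label="Response", show_copy_button=True)
        generate_btn = gr.Button("Generate")

        # ...and those same instances are referenced when wiring the event.
        # process_image is the handler defined earlier in this file.
        generate_btn.click(
            fn=process_image,
            inputs=[image_input, prompt_box, max_tokens, temperature, top_p],
            outputs=output_box,
        )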
@@ -156,4 +203,8 @@ def main():
    )
 
 if __name__ == "__main__":
-
+    # Initialize model
+    if initialize_model():
+        main()
+    else:
+        print("Failed to initialize model. Exiting...")
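
Note: when initialize_model() fails, this guard prints a message and falls off the end of the script, so the process still exits with status 0. If the Space's supervisor should see the failure, a nonzero exit code makes it visible; a sketch (the sys.exit call is a suggestion, not part of this commit):

    import sys

    if __name__ == "__main__":
        if initialize_model():
            main()
        else:
            print("Failed to initialize model. Exiting...")
            sys.exit(1)  # suggested: report failure to the supervisor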