Prashant26am committed on
Commit 5b58ac7 · 1 Parent(s): fe25f9c

fix: Update Gradio interface to fix type handling issues and improve memory management

Files changed (1): src/api/app.py (+67 -16)
src/api/app.py CHANGED
@@ -4,6 +4,9 @@ Gradio interface for the LLaVA model.
 
 import gradio as gr
 from PIL import Image
+import os
+import tempfile
+import torch
 
 from ..configs.settings import (
     GRADIO_THEME,
@@ -25,7 +28,30 @@ setup_logging()
 logger = get_logger(__name__)
 
 # Initialize model
-model = LLaVAModel()
+model = None
+
+def initialize_model():
+    global model
+    try:
+        logger.info("Initializing LLaVA model...")
+        # Use a smaller model variant and enable memory optimizations
+        model = LLaVAModel(
+            vision_model_path="openai/clip-vit-base-patch32",  # Smaller vision model
+            language_model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # Smaller language model
+            device="cpu",  # Force CPU for Hugging Face Spaces
+            projection_hidden_dim=2048  # Reduce projection layer size
+        )
+
+        # Enable memory optimizations
+        torch.cuda.empty_cache()  # Clear any cached memory
+        if hasattr(model, 'language_model'):
+            model.language_model.config.use_cache = False  # Disable KV cache
+
+        logger.info(f"Model initialized on {model.device}")
+        return True
+    except Exception as e:
+        logger.error(f"Error initializing model: {e}")
+        return False
 
 def process_image(
     image: Image.Image,
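Two details of this initialization are worth flagging: with device="cpu", the torch.cuda.empty_cache() call is a safe no-op (PyTorch only touches the CUDA allocator once it has been initialized), and use_cache = False trades slower token-by-token generation for a smaller memory footprint. A minimal sketch of a cleanup helper that makes the CPU/GPU distinction explicit (the helper name is hypothetical, not part of this codebase):

    import gc
    import torch

    def free_memory() -> None:
        # On a CPU-only Space, releasing Python references and running the
        # garbage collector is what actually frees memory; the CUDA call
        # only matters when a GPU allocator is active.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()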
@@ -47,15 +73,36 @@ def process_image(
     Returns:
         str: Model response
     """
+    if not model:
+        return "Error: Model not initialized"
+
     try:
         logger.info(f"Processing image with prompt: {prompt[:100]}...")
-        response = model(
-            image=image,
-            prompt=prompt,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            top_p=top_p
-        )
+
+        # Save the uploaded image temporarily
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
+            image.save(temp_file.name)
+            temp_path = temp_file.name
+
+        # Clear memory before processing
+        torch.cuda.empty_cache()
+
+        # Generate response with reduced memory usage
+        with torch.inference_mode():  # More memory efficient than no_grad
+            response = model(
+                image=image,
+                prompt=prompt,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                top_p=top_p
+            )
+
+        # Clean up temporary file
+        os.unlink(temp_path)
+
+        # Clear memory after processing
+        torch.cuda.empty_cache()
+
         logger.info("Successfully generated response")
         return response
     except Exception as e:
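One caveat in this hunk: the temporary PNG is written and deleted but never passed to the model (model() receives the in-memory image), and the file is leaked if model() raises, since cleanup only happens on the success path shown here. If a file path were genuinely required, a try/finally would make cleanup unconditional — a sketch, with run_inference standing in for whatever call actually consumes the path:

    import os
    import tempfile
    from PIL import Image

    def generate_from_file(image: Image.Image, run_inference) -> str:
        # mkstemp creates the file securely; close our handle so PIL can
        # reopen the path on all platforms.
        fd, temp_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        try:
            image.save(temp_path)
            return run_inference(temp_path)
        finally:
            os.unlink(temp_path)  # runs even if inference raises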
@@ -129,17 +176,17 @@ Try these prompts to get started:
         show_copy_button=True
     )
 
-    # Set up event handlers
+    # Set up event handlers with explicit types
     generate_btn.click(
         fn=process_image,
         inputs=[
-            image_input,
-            prompt_input,
-            max_tokens,
-            temperature,
-            top_p
+            gr.Image(type="pil"),
+            gr.Textbox(),
+            gr.Slider(),
+            gr.Slider(),
+            gr.Slider()
         ],
-        outputs=output
+        outputs=gr.Textbox()
     )
 
     return interface
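A note on this wiring: instantiating fresh gr.Image()/gr.Textbox()/gr.Slider() objects inside the inputs list creates new components instead of referencing the ones laid out in the Blocks, so the handler would not see what the user actually entered. The usual Gradio pattern is to fix the type on the component at creation and pass the existing variables to .click() — a sketch, with component names, labels, and slider ranges assumed from the pre-change code rather than taken from this repository:

    import gradio as gr

    with gr.Blocks() as interface:
        image_input = gr.Image(type="pil", label="Image")  # type fixed at creation
        prompt_input = gr.Textbox(label="Prompt")
        max_tokens = gr.Slider(1, 1024, value=256, step=1, label="Max new tokens")
        temperature = gr.Slider(0.0, 2.0, value=0.7, label="Temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
        generate_btn = gr.Button("Generate")
        output = gr.Textbox(label="Response", show_copy_button=True)

        # Wire the click to the components the user interacts with.
        generate_btn.click(
            fn=process_image,
            inputs=[image_input, prompt_input, max_tokens, temperature, top_p],
            outputs=output,
        )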
@@ -156,4 +203,8 @@ def main():
     )
 
 if __name__ == "__main__":
-    main()
+    # Initialize model
+    if initialize_model():
+        main()
+    else:
+        print("Failed to initialize model. Exiting...")
 
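Lastly, the new __main__ guard prints a message and falls through on failure, so the process still exits with status 0. For a Space or any supervised container, a nonzero exit code makes the failure visible to the platform — a sketch of the same guard with that change (and the module's logger instead of print):

    import sys

    if __name__ == "__main__":
        if initialize_model():
            main()
        else:
            logger.error("Failed to initialize model. Exiting...")
            sys.exit(1)  # nonzero status so the supervisor sees the failure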