Bruno committed (verified)
Commit 5aab8fc · Parent: 2fdc223

Update app.py

Files changed (1):
  1. app.py +106 -24
app.py CHANGED
@@ -6,6 +6,7 @@ from threading import Thread
 import gradio as gr
 import spaces
 import torch
+import re
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
 
 model_id = "google/gemma-3-12b-it"
@@ -14,8 +15,62 @@ model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
 )
 
+import cv2
+from PIL import Image
+import numpy as np
+import tempfile
+
+def downsample_video(video_path):
+    vidcap = cv2.VideoCapture(video_path)
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    frame_interval = max(int(fps / 3), 1)  # sample ~3 frames/s; the floor of 1 avoids a zero step when fps < 3
+    frames = []
+
+    for i in range(0, total_frames, frame_interval):
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
+        if success:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
+
+    vidcap.release()
+    return frames
+
+
 def process_new_user_message(message: dict) -> list[dict]:
-    return [{"type": "text", "text": message["text"]}, *[{"type": "image", "url": path} for path in message["files"]]]
+    if message["files"]:
+        if "<image>" in message["text"]:  # interleaved: place each file where its <image> tag appears
+            content = []
+            parts = re.split(r"(<image>)", message["text"])
+            image_index = 0
+            for part in parts:
+                if part == "<image>":
+                    content.append({"type": "image", "url": message["files"][image_index]})
+                    image_index += 1
+                elif part.strip():
+                    content.append({"type": "text", "text": part.strip()})
+            return content
+        elif message["files"][0].endswith(".mp4"):  # video: expand into timestamped frames
+            content = []
+            video = message["files"].pop(0)
+            frames = downsample_video(video)
+            for frame in frames:
+                pil_image, timestamp = frame
+                with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
+                    pil_image.save(temp_file.name)
+                    content.append({"type": "text", "text": f"Frame {timestamp}:"})
+                    content.append({"type": "image", "url": temp_file.name})
+            return content
+        else:
+            # non-interleaved images
+            return [{"type": "text", "text": message["text"]}, *[{"type": "image", "url": path} for path in message["files"]]]
+    else:
+        return [{"type": "text", "text": message["text"]}]
+
 
 def process_history(history: list[dict]) -> list[dict]:
     messages = []
@@ -34,6 +89,7 @@ def process_history(history: list[dict]) -> list[dict]:
                 current_user_content.append({"type": "image", "url": content[0]})
     return messages
 
+
 @spaces.GPU(duration=120)
 def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
     messages = []
@@ -64,35 +120,30 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
         output += delta
         yield output
 
+
 examples = [
     [
         {
-            "text": "Descreve a imagem",
-            "files": ["assets/sample-images/11.png"],
+            "text": "Preciso estar no Japão por 10 dias, indo para Tóquio, Kyoto e Osaka. Pense no número de atrações em cada uma delas e aloque o número de dias para cada cidade. Faça recomendações de transporte público.",
+            "files": [],
         }
     ],
     [
         {
-            "text": "O que diz a placa?",
-            "files": ["assets/sample-images/02.png"],
+            "text": "Escreva o código matplotlib para gerar o mesmo gráfico de barras.",
+            "files": ["assets/sample-images/barchart.png"],
        }
    ],
    [
        {
-            "text": "Compare e contraste as duas imagens.",
-            "files": ["assets/sample-images/03.png"],
+            "text": "O que há de estranho neste vídeo?",
+            "files": ["assets/sample-images/tmp.mp4"],
        }
    ],
    [
        {
-            "text": "Liste todos os objetos na imagem e suas cores.",
-            "files": ["assets/sample-images/04.png"],
-        }
-    ],
-    [
-        {
-            "text": "Descreva a atmosfera da cena.",
-            "files": ["assets/sample-images/05.png"],
+            "text": "Eu tenho este suplemento <image> e quero comprar este outro <image>. Há algum aviso que eu deva saber?",
+            "files": ["assets/sample-images/pill1.png", "assets/sample-images/pill2.png"],
        }
    ],
    [
@@ -120,7 +171,7 @@ examples = [
     ],
     [
         {
-            "text": "Crie uma história curta com base na sequência de imagens.",
+            "text": "Crie uma história curta baseada na sequência de imagens.",
             "files": [
                 "assets/sample-images/09-1.png",
                 "assets/sample-images/09-2.png",
@@ -132,8 +183,8 @@ examples = [
     ],
     [
         {
-            "text": "Descreva as criaturas que viveriam neste mundo.",
-            "files": ["assets/sample-images/10.png"],
+            "text": "Descreva essa imagem.",
+            "files": ["assets/sample-images/PIX.png"],
         }
     ],
     [
@@ -160,20 +211,51 @@ examples = [
             "files": ["assets/additional-examples/4.png"],
         }
     ],
+    [
+        {
+            "text": "Legende esta imagem.",
+            "files": ["assets/sample-images/01.png"],
+        }
+    ],
+    [
+        {
+            "text": "O que diz a placa?",
+            "files": ["assets/sample-images/02.png"],
+        }
+    ],
+    [
+        {
+            "text": "Compare e contraste as duas imagens.",
+            "files": ["assets/sample-images/03.png"],
+        }
+    ],
+    [
+        {
+            "text": "Liste todos os objetos na imagem e suas cores.",
+            "files": ["assets/sample-images/04.png"],
+        }
+    ],
+    [
+        {
+            "text": "Descreva a atmosfera da cena.",
+            "files": ["assets/sample-images/05.png"],
+        }
+    ],
 ]
 
+
 demo = gr.ChatInterface(
     fn=run,
     type="messages",
-    textbox=gr.MultimodalTextbox(file_types=["image"], file_count="multiple"),
+    textbox=gr.MultimodalTextbox(file_types=["image", ".mp4"], file_count="multiple"),
     multimodal=True,
     additional_inputs=[
-        gr.Textbox(label="System Prompt", value="Você é um assistente útil. responder em pt br"),
-        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=500),
+        gr.Textbox(label="System Prompt", value="Você é um assistente útil. Responda em português brasileiro."),
+        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
     ],
     stop_btn=False,
-    title="Gemma 3 12B it - Bruno Henrique",
-    description="<img src='https://huggingface.co/spaces/huggingface-projects/gemma-3-12b-it/resolve/main/assets/logo.png' id='logo' />",
+    title="Gemma 3 12B PT-BR",
+    description="<img src='https://huggingface.co/spaces/huggingface-projects/gemma-3-12b-it/resolve/main/assets/logo.png' id='logo' /><br>This is a demo of Gemma 3 12B it, a vision-language model with strong performance across a wide range of tasks. You can upload images, interleaved images, and videos. Note that video input only supports single-turn conversations and .mp4 files.",
     examples=examples,
     run_examples_on_click=False,
     cache_examples=False,
@@ -182,4 +264,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
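
For reference, a minimal standalone sketch of what the new interleaving branch of process_new_user_message yields, reusing the pill example from the diff above (the paths are the demo's own sample assets and purely illustrative):

import re

message = {
    "text": "Eu tenho este suplemento <image> e quero comprar este outro <image>. Há algum aviso que eu deva saber?",
    "files": ["assets/sample-images/pill1.png", "assets/sample-images/pill2.png"],
}

# Same splitting logic as the committed code: each <image> tag is replaced
# by the next uploaded file, and the surrounding text becomes text chunks.
content = []
image_index = 0
for part in re.split(r"(<image>)", message["text"]):
    if part == "<image>":
        content.append({"type": "image", "url": message["files"][image_index]})
        image_index += 1
    elif part.strip():
        content.append({"type": "text", "text": part.strip()})

for item in content:
    print(item)
# {'type': 'text', 'text': 'Eu tenho este suplemento'}
# {'type': 'image', 'url': 'assets/sample-images/pill1.png'}
# {'type': 'text', 'text': 'e quero comprar este outro'}
# {'type': 'image', 'url': 'assets/sample-images/pill2.png'}
# {'type': 'text', 'text': '. Há algum aviso que eu deva saber?'}

The video path works the same way at a higher level: downsample_video samples roughly three frames per second, and each frame enters the prompt as a "Frame {timestamp}:" text chunk followed by the frame image, so the model sees a timestamped film strip. As the description notes, video input is single-turn and .mp4 only.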