flozi00 committed on
Commit dc7fd72 · verified · 1 Parent(s): a01e5ff

Update app.py

Files changed (1)
  1. app.py +68 -196
app.py CHANGED
@@ -1,170 +1,60 @@
 import os
-import pathlib
-import tempfile
 from collections.abc import Iterator
 from threading import Thread
 
-import av
 import gradio as gr
 import spaces
 import torch
-from gradio.utils import get_upload_folder
 from transformers import AutoModelForImageTextToText, AutoProcessor
 from transformers.generation.streamers import TextIteratorStreamer
 
+# --- Model and processor setup ---
+# Loads the model and the processor from Hugging Face.
+# device_map="auto" assigns the model to the available hardware (GPU/CPU) automatically.
+# torch_dtype=torch.bfloat16 uses a more memory-efficient data type for faster computation.
 model_id = "google/gemma-3n-E4B-it"
 
 processor = AutoProcessor.from_pretrained(model_id)
 model = AutoModelForImageTextToText.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
 
-IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
-VIDEO_FILE_TYPES = (".mp4", ".mov", ".webm")
-AUDIO_FILE_TYPES = (".mp3", ".wav")
-
-GRADIO_TEMP_DIR = get_upload_folder()
-
-TARGET_FPS = int(os.getenv("TARGET_FPS", "3"))
-MAX_FRAMES = int(os.getenv("MAX_FRAMES", "30"))
+# --- Constants ---
+# The maximum permitted number of input tokens, to avoid out-of-memory errors.
 MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))
 
 
-def get_file_type(path: str) -> str:
-    if path.endswith(IMAGE_FILE_TYPES):
-        return "image"
-    if path.endswith(VIDEO_FILE_TYPES):
-        return "video"
-    if path.endswith(AUDIO_FILE_TYPES):
-        return "audio"
-    error_message = f"Unsupported file type: {path}"
-    raise ValueError(error_message)
-
-
-def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
-    video_count = 0
-    non_video_count = 0
-    for path in paths:
-        if path.endswith(VIDEO_FILE_TYPES):
-            video_count += 1
-        else:
-            non_video_count += 1
-    return video_count, non_video_count
-
-
-def validate_media_constraints(message: dict) -> bool:
-    video_count, non_video_count = count_files_in_new_message(message["files"])
-    if video_count > 1:
-        gr.Warning("Only one video is supported.")
-        return False
-    if video_count == 1 and non_video_count > 0:
-        gr.Warning("Mixing images and videos is not allowed.")
-        return False
-    return True
-
-
-def extract_frames_to_tempdir(
-    video_path: str,
-    target_fps: float,
-    max_frames: int | None = None,
-    parent_dir: str | None = None,
-    prefix: str = "frames_",
-) -> str:
-    temp_dir = tempfile.mkdtemp(prefix=prefix, dir=parent_dir)
-
-    container = av.open(video_path)
-    video_stream = container.streams.video[0]
-
-    if video_stream.duration is None or video_stream.time_base is None:
-        raise ValueError("video_stream is missing duration or time_base")
-
-    time_base = video_stream.time_base
-    duration = float(video_stream.duration * time_base)
-    interval = 1.0 / target_fps
-
-    total_frames = int(duration * target_fps)
-    if max_frames is not None:
-        total_frames = min(total_frames, max_frames)
-
-    target_times = [i * interval for i in range(total_frames)]
-    target_index = 0
-
-    for frame in container.decode(video=0):
-        if frame.pts is None:
-            continue
-
-        timestamp = float(frame.pts * time_base)
-
-        if target_index < len(target_times) and abs(timestamp - target_times[target_index]) < (interval / 2):
-            frame_path = pathlib.Path(temp_dir) / f"frame_{target_index:04d}.jpg"
-            frame.to_image().save(frame_path)
-            target_index += 1
-
-        if max_frames is not None and target_index >= max_frames:
-            break
-
-    container.close()
-    return temp_dir
-
-
-def process_new_user_message(message: dict) -> list[dict]:
-    if not message["files"]:
-        return [{"type": "text", "text": message["text"]}]
-
-    file_types = [get_file_type(path) for path in message["files"]]
-
-    if len(file_types) == 1 and file_types[0] == "video":
-        gr.Info(f"Video will be processed at {TARGET_FPS} FPS, max {MAX_FRAMES} frames in this Space.")
-
-        temp_dir = extract_frames_to_tempdir(
-            message["files"][0],
-            target_fps=TARGET_FPS,
-            max_frames=MAX_FRAMES,
-            parent_dir=GRADIO_TEMP_DIR,
-        )
-        paths = sorted(pathlib.Path(temp_dir).glob("*.jpg"))
-        return [
-            {"type": "text", "text": message["text"]},
-            *[{"type": "image", "image": path.as_posix()} for path in paths],
-        ]
-
-    return [
-        {"type": "text", "text": message["text"]},
-        *[{"type": file_type, file_type: path} for path, file_type in zip(message["files"], file_types, strict=True)],
-    ]
-
-
-def process_history(history: list[dict]) -> list[dict]:
-    messages = []
-    current_user_content: list[dict] = []
-    for item in history:
-        if item["role"] == "assistant":
-            if current_user_content:
-                messages.append({"role": "user", "content": current_user_content})
-                current_user_content = []
-            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
-        else:
-            content = item["content"]
-            if isinstance(content, str):
-                current_user_content.append({"type": "text", "text": content})
-            else:
-                filepath = content[0]
-                file_type = get_file_type(filepath)
-                current_user_content.append({"type": file_type, file_type: filepath})
-    return messages
-
-
 @spaces.GPU(duration=120)
 @torch.inference_mode()
-def generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
-    if not validate_media_constraints(message):
+def transcribe_german(audio_input: str | None) -> Iterator[str]:
+    """
+    Takes an audio path, transcribes it into German with the Gemma model,
+    and returns the transcription as a text stream.
+
+    Args:
+        audio_input: The file path of the audio file to transcribe.
+            May be None if no input is provided.
+
+    Yields:
+        A string iterator that emits the transcript piece by piece.
+    """
+    if audio_input is None:
+        gr.Warning("Bitte stellen Sie eine Audiodatei zur Verfügung oder nehmen Sie Audio auf.")
         yield ""
         return
 
-    messages = []
-    if system_prompt:
-        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
-    messages.extend(process_history(history))
-    messages.append({"role": "user", "content": process_new_user_message(message)})
+    # Prepare the input for the model.
+    # We give the model an explicit instruction ("Transkribiere dies auf Deutsch.")
+    # together with the audio file.
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Transkribiere dies auf Deutsch."},
+                {"type": "audio", "audio": audio_input},
+            ],
+        }
+    ]
 
+    # Use the processor's chat template to format the input correctly.
     inputs = processor.apply_chat_template(
         messages,
         add_generation_prompt=True,
@@ -172,87 +62,69 @@ def generate(message: dict, history: list[dict], system_prompt: str = "", max_ne
         return_dict=True,
         return_tensors="pt",
     )
+
+    # Check whether the input exceeds the maximum token length.
     n_tokens = inputs["input_ids"].shape[1]
     if n_tokens > MAX_INPUT_TOKENS:
         gr.Warning(
-            f"Input too long. Max {MAX_INPUT_TOKENS} tokens. Got {n_tokens} tokens. This limit is set to avoid CUDA out-of-memory errors in this Space."
+            f"Eingabe zu lang. Max. {MAX_INPUT_TOKENS} Tokens. Habe {n_tokens} Tokens erhalten. "
+            "Dieses Limit dient dazu, CUDA-Speicherfehler in diesem Space zu vermeiden."
         )
         yield ""
         return
 
+    # Move the formatted inputs to the model's device.
     inputs = inputs.to(device=model.device, dtype=torch.bfloat16)
 
+    # Set up a streamer to receive the model's output in real time.
     streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
+
+    # Define the generation arguments.
     generate_kwargs = dict(
         inputs,
         streamer=streamer,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=700,  # sufficient for most transcriptions
         do_sample=False,
-        disable_compile=True,
+        disable_compile=True,  # disable torch compilation; it can cause problems on some setups
     )
+
+    # Start generation in a separate thread so the UI is not blocked.
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
+    # Iterate over the generated text and emit it incrementally.
     output = ""
     for delta in streamer:
         output += delta
         yield output
 
 
-examples = [
-    [
-        {
-            "text": "What is the capital of France?",
-            "files": [],
-        }
-    ],
-    [
-        {
-            "text": "Describe this image in detail.",
-            "files": ["assets/cat.jpeg"],
-        }
-    ],
-    [
-        {
-            "text": "Transcribe the following speech segment in English.",
-            "files": ["assets/speech.wav"],
-        }
+# --- Gradio UI setup ---
+# Builds the user interface with Gradio.
+demo = gr.Interface(
+    fn=transcribe_german,
+    inputs=[
+        gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio hochladen oder aufnehmen")
     ],
-    [
-        {
-            "text": "Transcribe the following speech segment in English.",
-            "files": ["assets/speech2.wav"],
-        }
+    outputs=[
+        gr.Textbox(label="Deutsches Transkript", interactive=False)
     ],
-    [
-        {
-            "text": "Describe this video",
-            "files": ["assets/holding_phone.mp4"],
+    title="Audio-zu-Text Transkription (Deutsch)",
+    description="Laden Sie eine Audiodatei hoch oder nehmen Sie Audio mit Ihrem Mikrofon auf, um ein deutsches Transkript zu erhalten. Das Modell ist `google/gemma-3n-E4B-it`.",
+    theme="soft",
+    css="""
+    .gradio-container {
+        font-family: 'Inter', sans-serif;
+        max-width: 800px;
+        margin-left: auto;
+        margin-right: auto;
     }
-    ],
-]
-
-demo = gr.ChatInterface(
-    fn=generate,
-    type="messages",
-    textbox=gr.MultimodalTextbox(
-        file_types=list(IMAGE_FILE_TYPES + VIDEO_FILE_TYPES + AUDIO_FILE_TYPES),
-        file_count="multiple",
-        autofocus=True,
-    ),
-    multimodal=True,
-    additional_inputs=[
-        gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
-        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
-    ],
-    stop_btn=False,
-    title="Gemma 3n E4B it",
-    examples=examples,
-    run_examples_on_click=False,
-    cache_examples=False,
-    css_paths="style.css",
-    delete_cache=(1800, 1800),
+    footer {display: none !important}
+    """,
+    allow_flagging="never",
 )
 
+# --- Launch the app ---
 if __name__ == "__main__":
     demo.launch()
+
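As a quick local sanity check of the new transcribe_german function, a minimal sketch follows (assumptions: a hypothetical sample.wav with German speech sits in the working directory, the machine has enough GPU memory for the model, and outside a Space the @spaces.GPU decorator is a no-op). Since the function yields the accumulated transcript, the last yielded value is the complete result:

    # Minimal smoke test; "sample.wav" is a hypothetical German speech recording.
    final = ""
    for partial in transcribe_german("sample.wav"):
        final = partial  # each yield carries the transcript so far
    print(final)

Once the Space is running, the same endpoint can be called remotely with gradio_client (also a sketch: the Space id below is a placeholder, and the endpoint name should be verified on the Space's "Use via API" page):

    from gradio_client import Client, handle_file

    client = Client("flozi00/<space-name>")  # placeholder Space id
    # predict() returns the final value of the streamed output
    print(client.predict(handle_file("sample.wav"), api_name="/predict"))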