Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,170 +1,60 @@
 import os
-import pathlib
-import tempfile
 from collections.abc import Iterator
 from threading import Thread

-import av
 import gradio as gr
 import spaces
 import torch
-from gradio.utils import get_upload_folder
 from transformers import AutoModelForImageTextToText, AutoProcessor
 from transformers.generation.streamers import TextIteratorStreamer

 model_id = "google/gemma-3n-E4B-it"

 processor = AutoProcessor.from_pretrained(model_id)
 model = AutoModelForImageTextToText.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

-AUDIO_FILE_TYPES = (".mp3", ".wav")
-
-GRADIO_TEMP_DIR = get_upload_folder()
-
-TARGET_FPS = int(os.getenv("TARGET_FPS", "3"))
-MAX_FRAMES = int(os.getenv("MAX_FRAMES", "30"))
 MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))


-def get_file_type(path: str) -> str:
-    if path.endswith(IMAGE_FILE_TYPES):
-        return "image"
-    if path.endswith(VIDEO_FILE_TYPES):
-        return "video"
-    if path.endswith(AUDIO_FILE_TYPES):
-        return "audio"
-    error_message = f"Unsupported file type: {path}"
-    raise ValueError(error_message)
-
-
-def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
-    video_count = 0
-    non_video_count = 0
-    for path in paths:
-        if path.endswith(VIDEO_FILE_TYPES):
-            video_count += 1
-        else:
-            non_video_count += 1
-    return video_count, non_video_count
-
-
-def validate_media_constraints(message: dict) -> bool:
-    video_count, non_video_count = count_files_in_new_message(message["files"])
-    if video_count > 1:
-        gr.Warning("Only one video is supported.")
-        return False
-    if video_count == 1 and non_video_count > 0:
-        gr.Warning("Mixing images and videos is not allowed.")
-        return False
-    return True
-
-
-def extract_frames_to_tempdir(
-    video_path: str,
-    target_fps: float,
-    max_frames: int | None = None,
-    parent_dir: str | None = None,
-    prefix: str = "frames_",
-) -> str:
-    temp_dir = tempfile.mkdtemp(prefix=prefix, dir=parent_dir)
-
-    container = av.open(video_path)
-    video_stream = container.streams.video[0]
-
-    if video_stream.duration is None or video_stream.time_base is None:
-        raise ValueError("video_stream is missing duration or time_base")
-
-    time_base = video_stream.time_base
-    duration = float(video_stream.duration * time_base)
-    interval = 1.0 / target_fps
-
-    total_frames = int(duration * target_fps)
-    if max_frames is not None:
-        total_frames = min(total_frames, max_frames)
-
-    target_times = [i * interval for i in range(total_frames)]
-    target_index = 0
-
-    for frame in container.decode(video=0):
-        if frame.pts is None:
-            continue
-
-        timestamp = float(frame.pts * time_base)
-
-        if target_index < len(target_times) and abs(timestamp - target_times[target_index]) < (interval / 2):
-            frame_path = pathlib.Path(temp_dir) / f"frame_{target_index:04d}.jpg"
-            frame.to_image().save(frame_path)
-            target_index += 1
-
-        if max_frames is not None and target_index >= max_frames:
-            break
-
-    container.close()
-    return temp_dir
-
-
-def process_new_user_message(message: dict) -> list[dict]:
-    if not message["files"]:
-        return [{"type": "text", "text": message["text"]}]
-
-    file_types = [get_file_type(path) for path in message["files"]]
-
-    if len(file_types) == 1 and file_types[0] == "video":
-        gr.Info(f"Video will be processed at {TARGET_FPS} FPS, max {MAX_FRAMES} frames in this Space.")
-
-        temp_dir = extract_frames_to_tempdir(
-            message["files"][0],
-            target_fps=TARGET_FPS,
-            max_frames=MAX_FRAMES,
-            parent_dir=GRADIO_TEMP_DIR,
-        )
-        paths = sorted(pathlib.Path(temp_dir).glob("*.jpg"))
-        return [
-            {"type": "text", "text": message["text"]},
-            *[{"type": "image", "image": path.as_posix()} for path in paths],
-        ]
-
-    return [
-        {"type": "text", "text": message["text"]},
-        *[{"type": file_type, file_type: path} for path, file_type in zip(message["files"], file_types, strict=True)],
-    ]
-
-
-def process_history(history: list[dict]) -> list[dict]:
-    messages = []
-    current_user_content: list[dict] = []
-    for item in history:
-        if item["role"] == "assistant":
-            if current_user_content:
-                messages.append({"role": "user", "content": current_user_content})
-                current_user_content = []
-            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
-        else:
-            content = item["content"]
-            if isinstance(content, str):
-                current_user_content.append({"type": "text", "text": content})
-            else:
-                filepath = content[0]
-                file_type = get_file_type(filepath)
-                current_user_content.append({"type": file_type, file_type: filepath})
-    return messages


 @spaces.GPU(duration=120)
 @torch.inference_mode()
-def generate(message: dict, history: list[dict], system_prompt: str = "", max_ne
        yield ""
        return

-    messages

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
@@ -172,87 +62,69 @@ def generate(message: dict, history: list[dict], system_prompt: str = "", max_ne
        return_dict=True,
        return_tensors="pt",
    )
    n_tokens = inputs["input_ids"].shape[1]
    if n_tokens > MAX_INPUT_TOKENS:
        gr.Warning(
-            f"
        )
        yield ""
        return

    inputs = inputs.to(device=model.device, dtype=torch.bfloat16)

    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
-        max_new_tokens=
        do_sample=False,
-        disable_compile=True,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    output = ""
    for delta in streamer:
        output += delta
        yield output


-    ],
-    [
-        {
-            "text": "Describe this image in detail.",
-            "files": ["assets/cat.jpeg"],
-        }
-    ],
-    [
-        {
-            "text": "Transcribe the following speech segment in English.",
-            "files": ["assets/speech.wav"],
-        }
    ],
-    [
-        {
-            "text": "Transcribe the following speech segment in English.",
-            "files": ["assets/speech2.wav"],
-        }
    ],
}

-demo = gr.ChatInterface(
-    fn=generate,
-    type="messages",
-    textbox=gr.MultimodalTextbox(
-        file_types=list(IMAGE_FILE_TYPES + VIDEO_FILE_TYPES + AUDIO_FILE_TYPES),
-        file_count="multiple",
-        autofocus=True,
-    ),
-    multimodal=True,
-    additional_inputs=[
-        gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
-        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
-    ],
-    stop_btn=False,
-    title="Gemma 3n E4B it",
-    examples=examples,
-    run_examples_on_click=False,
-    cache_examples=False,
-    css_paths="style.css",
-    delete_cache=(1800, 1800),
)

if __name__ == "__main__":
    demo.launch()

import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers.generation.streamers import TextIteratorStreamer

+# --- Model and processor setup ---
+# Loads the model and the processor from Hugging Face.
+# device_map="auto" automatically places the model on the available hardware (GPU/CPU).
+# torch_dtype=torch.bfloat16 uses a more memory-efficient data type for faster computation.
model_id = "google/gemma-3n-E4B-it"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

+# --- Constants ---
+# Maximum number of input tokens allowed, to avoid out-of-memory errors.
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))


@spaces.GPU(duration=120)
@torch.inference_mode()
+def transcribe_german(audio_input: str | None) -> Iterator[str]:
+    """
+    Takes a path to an audio file, transcribes it into German with the Gemma model,
+    and returns the transcription as a text stream.
+
+    Args:
+        audio_input: Path to the audio file to transcribe.
+            May be None if no input was provided.
+
+    Yields:
+        A string iterator that emits the transcript piece by piece.
+    """
+    if audio_input is None:
+        gr.Warning("Bitte stellen Sie eine Audiodatei zur Verfügung oder nehmen Sie Audio auf.")
        yield ""
        return

+    # Prepare the input for the model.
+    # We give the model an explicit instruction ("Transkribiere dies auf Deutsch.")
+    # together with the audio file.
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Transkribiere dies auf Deutsch."},
+                {"type": "audio", "audio": audio_input},
+            ],
+        }
+    ]

+    # Use the processor's chat template to format the input correctly.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
+
+    # Check whether the input exceeds the maximum token length.
    n_tokens = inputs["input_ids"].shape[1]
    if n_tokens > MAX_INPUT_TOKENS:
        gr.Warning(
+            f"Eingabe zu lang. Max. {MAX_INPUT_TOKENS} Tokens. Habe {n_tokens} Tokens erhalten. "
+            "Dieses Limit dient dazu, CUDA-Speicherfehler in diesem Space zu vermeiden."
        )
        yield ""
        return

+    # Move the formatted inputs to the model's device.
    inputs = inputs.to(device=model.device, dtype=torch.bfloat16)

+    # Set up a streamer to receive the model's output in real time.
    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
+
+    # Define the generation arguments.
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
+        max_new_tokens=700,  # enough for most transcriptions
        do_sample=False,
+        disable_compile=True,  # disable compilation, which can cause problems in some setups
    )
+
+    # Run generation in a separate thread so the UI is not blocked.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

+    # Iterate over the generated text and yield it incrementally.
    output = ""
    for delta in streamer:
        output += delta
        yield output


+# --- Gradio UI setup ---
+# Builds the user interface with Gradio.
+demo = gr.Interface(
+    fn=transcribe_german,
+    inputs=[
+        gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio hochladen oder aufnehmen")
    ],
+    outputs=[
+        gr.Textbox(label="Deutsches Transkript", interactive=False)
    ],
+    title="Audio-zu-Text Transkription (Deutsch)",
+    description="Laden Sie eine Audiodatei hoch oder nehmen Sie Audio mit Ihrem Mikrofon auf, um ein deutsches Transkript zu erhalten. Das Modell ist `google/gemma-3n-E4B-it`.",
+    theme="soft",
+    css="""
+    .gradio-container {
+        font-family: 'Inter', sans-serif;
+        max-width: 800px;
+        margin-left: auto;
+        margin-right: auto;
    }
+    footer {display: none !important}
+    """,
+    allow_flagging="never",
)

+# --- Launch the app ---
if __name__ == "__main__":
    demo.launch()
+
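The new transcribe_german is a plain generator that yields the growing transcript, so it can also be smoke-tested without the Gradio UI. The lines below are a minimal sketch, not part of the Space: they assume app.py is importable from the current directory, that a local file named sample.wav exists (a hypothetical example path), and that the machine can load the model; the @spaces.GPU decorator only takes effect on ZeroGPU hardware and is expected to pass the call through elsewhere.

# Minimal local smoke test for the streaming transcription function (sketch, see assumptions above).
from app import transcribe_german  # importing app.py loads the model and processor

last = ""
for partial in transcribe_german("sample.wav"):  # hypothetical local audio file
    # Each yielded value is the full transcript so far; print only the newly added part.
    print(partial[len(last):], end="", flush=True)
    last = partial
print()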