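"""
Gradio Space entry point for a video pose-estimation coaching demo.

High-level flow, as implemented below: ensure mmpose is installed from a local
wheel, downsample the uploaded video, run WHAM pose estimation, drive a
multi-stage prompt pipeline through a Hugging Face InferenceClient, run the
external estimator.py script on the pose output, and log uploaded videos,
generated instructions, and user ratings to a Hugging Face dataset.
"""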
import spaces
import importlib
import importlib.util
import subprocess
import sys
import os
def ensure_mmpose_installed(local_path='/data/wheelhouse/mmpose-0.24.0-py2.py3-none-any.whl'):
    """
    Check whether 'mmpose' can be imported; if not, attempt to install it from local_path.
    local_path can be anything pip accepts (here a pre-built wheel).
    """
    package_name = 'mmpose'
    # Try to find the spec for import
    spec = importlib.util.find_spec(package_name)
    if spec is not None:
        try:
            module = importlib.import_module(package_name)
            print(f"'{package_name}' is already installed, version: {getattr(module, '__version__', 'unknown')}")
            return True
        except Exception as e:
            print(f"Found '{package_name}', but import failed: {e}. Will attempt re-install.")
    else:
        print(f"'{package_name}' not found, attempting installation...")
    # If we reach here, we need to install or reinstall.
    # Construct the pip install command using the current Python executable.
    cmd = [sys.executable, "-m", "pip", "install", local_path]
    print("Running command:", " ".join(cmd))
    try:
        subprocess.check_call(cmd)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Failed to install mmpose: {e}")
    # After installation, try importing again
    try:
        module = importlib.import_module(package_name)
        print(f"'{package_name}' installed and imported successfully, version: {getattr(module, '__version__', 'unknown')}")
        return True
    except Exception as e:
        raise RuntimeError(f"Installed but still cannot import '{package_name}': {e}")
ensure_mmpose_installed('/data/wheelhouse/mmpose-0.24.0-py2.py3-none-any.whl')
import gradio as gr
import numpy as np
import cv2
import json
import joblib
from typing import Tuple, List, Dict, Any
import tempfile
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor
import torch
from PIL import Image
import datetime
import uuid
import io, base64
from datasets import load_dataset, Dataset
from huggingface_hub import HfApi, Repository, upload_file

sys.path.append("/data/WHAM")
from demo import wham_execute
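# Dataset repo used to log uploaded videos, generated instructions, and user ratings.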
HF_DATASET_ID = "qihfang/sportscoaching"

# Use HuggingFace remote inference
try:
    from huggingface_hub import InferenceClient
except ImportError:
    InferenceClient = None
class PoseEstimationApp:
    def __init__(self, model_name: str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", use_remote: bool = True, max_history_turns: int = 50):
        self.processing_steps = [
            "Video upload completed",
            "Starting video downsampling...",
            "Executing Pose Estimation...",
            "Running Stage1 prompt...",
            "Running Stage2 prompt...",
            "Running Evaluator...",
            "Running Stage3 prompt...",
            "Generating final result"
        ]
        self.use_remote = use_remote
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if not use_remote:
            raise RuntimeError("Only remote inference is supported; please set use_remote=True and provide the HF_TOKEN environment variable.")
        if InferenceClient is None:
            raise RuntimeError("huggingface_hub is not installed; please install it to use remote inference.")
        token = os.getenv("HF_TOKEN")
        if not token:
            raise RuntimeError("HF_TOKEN environment variable is not set; please configure the access token in the deployment environment.")
        try:
            self.client = InferenceClient(model=model_name, token=token, provider="novita")
        except Exception as e:
            raise RuntimeError(f"Failed to initialize remote inference client: {e}")
        # Conversation history management
        # Store several rounds of conversation in a list; each item is a dict with 'role' ('user' or 'assistant') and 'content'
        self.conversation_history: List[Dict[str, str]] = []
        # Keep only the most recent rounds (user+assistant); truncate when exceeded
        self.max_history_turns = max_history_turns
    def reset_history(self):
        """
        Clear conversation history; call when starting a new multi-turn conversation scenario.
        """
        self.conversation_history = []

    def add_user_message(self, message: str):
        self.conversation_history.append({"role": "user", "content": message})
        # If the maximum number of rounds is exceeded (a round is user+assistant),
        # keep only the most recent 2 * max_history_turns entries.
        max_items = 2 * self.max_history_turns
        if len(self.conversation_history) > max_items:
            # Discard the earliest entries
            self.conversation_history = self.conversation_history[-max_items:]

    def add_assistant_message(self, message: str):
        self.conversation_history.append({"role": "assistant", "content": message})
        # Similarly truncate history
        max_items = 2 * self.max_history_turns
        if len(self.conversation_history) > max_items:
            self.conversation_history = self.conversation_history[-max_items:]
    def build_prompt_with_history(self, new_user_input: str) -> str:
        """
        Concatenate history rounds with the current user input into a prompt string.
        Example:
        User: ...
        Assistant: ...
        User: new_user_input
        Assistant:
        """
        prompt_parts = []
        for turn in self.conversation_history:
            if turn["role"] == "user":
                prompt_parts.append(f"User: {turn['content']}")
            else:
                prompt_parts.append(f"Assistant: {turn['content']}")
        # Add the new user input; the model reply will be generated at the end
        prompt_parts.append(f"User: {new_user_input}")
        prompt_parts.append("Assistant:")  # Guide model generation
        full_prompt = "\n".join(prompt_parts)
        return full_prompt
    def image_to_datauri(self, img: Image.Image, max_size=640, jpeg_quality=70):
        # First scale so the longest side is at most max_size
        w, h = img.size
        scale = max_size / max(w, h)
        if scale < 1.0:
            new_w, new_h = int(w * scale), int(h * scale)
            img = img.resize((new_w, new_h), Image.BILINEAR)
        # Convert to JPEG
        buffered = io.BytesIO()
        img.save(buffered, format="JPEG", quality=jpeg_quality)
        img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
        return f"data:image/jpeg;base64,{img_b64}"
    def query_llm_multimodal(self, text: str, indexed_images: list, use_history: bool = True, max_tokens: int = 1024):
        messages = []
        if use_history:
            for turn in self.conversation_history:
                messages.append({"role": turn['role'], "content": turn['content']})
        # Handle the text part first (if any)
        if text:
            messages.append({"role": "user", "content": [{"type": "text", "text": text}]})
            self.add_user_message(text)
        # Add image inputs frame by frame
        for clip_idx, (video_idx, img) in enumerate(indexed_images):
            uri = self.image_to_datauri(img, max_size=512, jpeg_quality=60)
            # Only send the few most important frames, or filter before calling
            msg_content = [
                {"type": "image_url", "image_url": {"url": uri}},
                {"type": "text", "text": f"The {clip_idx}th image is the {video_idx+1}th frame in the video, please analyze and summarize the content. Use the original frame index ({video_idx+1}) for reminder."}
            ]
            messages.append({"role": "user", "content": msg_content})
            self.add_user_message(f"[IMAGE frame {video_idx}]")
            # If there are too many frames, break here or only process the first K frames
        try:
            response = self.client.chat.completions.create(messages=messages, max_tokens=max_tokens)
            reply = response.choices[0].message.content
        except Exception as e:
            raise RuntimeError(f"Multimodal inference error: {e}")
        self.add_assistant_message(reply)
        return reply
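    # Plain-text chat completion; optionally prepends the stored conversation
    # history so later stages can refer back to earlier analysis.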
    def query_llm(self, prompt: str, max_length: int = 2048, use_history: bool = True) -> str:
        if use_history:
            messages = []
            for turn in self.conversation_history:
                messages.append({"role": turn['role'], "content": turn['content']})
            messages.append({"role": "user", "content": prompt})
            self.add_user_message(prompt)
        else:
            messages = [{"role": "user", "content": prompt}]
        try:
            response = self.client.chat.completions.create(messages=messages, max_tokens=max_length)
            reply = response.choices[0].message.content
        except Exception as e:
            raise RuntimeError(f"Remote inference error: {e}")
        self.add_assistant_message(reply)
        return reply
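    # Writes every Nth frame to a new mp4 at a reduced fps and also returns the
    # kept frames as (original_frame_index, PIL.Image) pairs for the LLM stage.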
    def downsample_video(self, input_path: str, output_path: str, downsample_rate: int) -> Tuple[str, int, List[Tuple[int, Image.Image]]]:
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video file {input_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        new_fps = max(1, int(fps / downsample_rate))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_path, fourcc, new_fps, (width, height))
        frame_count = 0
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % downsample_rate == 0:
                out.write(frame)
                frames.append((frame_count, Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))))
            frame_count += 1
        cap.release()
        out.release()
        return output_path, new_fps, frames
    def run_pose_estimation(self, tmp_dir, video_path: str) -> str:
        base_name = os.path.splitext(os.path.basename(video_path))[0]
        out_dir = os.path.join(tmp_dir, "output")
        os.makedirs(out_dir, exist_ok=True)
        # cmd = [
        #     "python", "/data/WHAM/demo.py",
        #     "--video", video_path,
        #     "--save_pkl",
        #     "--output_pth", out_dir
        # ]
        wham_execute(video_path, out_dir, True, True, False)
        out_dir = os.path.join(out_dir, base_name)
        # result = subprocess.run(cmd, capture_output=True, text=True)
        # if result.returncode != 0:
        #     raise RuntimeError(f"Pose Estimation failed: {result.stderr}")
        result_path = os.path.join(out_dir, "wham_output.pkl")
        if not os.path.exists(result_path):
            raise FileNotFoundError(f"Result file not found: {result_path}")
        return result_path
    def load_pose_data(self, pth_path: str):
        try:
            data = torch.load(pth_path, map_location="cpu")
            return data
        except Exception as e:
            raise RuntimeError(f"Failed to load Pose data: {e}")
    def extract_frames(self, video_path: str, frame_skip: int = 1) -> List[Tuple[int, Image.Image]]:
        """
        Eagerly read and return all frames as a list of (frame_index, PIL.Image).
        """
        frames = []
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return frames
        idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if idx % frame_skip == 0:
                img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                frames.append((idx, img))
            idx += 1
        cap.release()
        return frames
    def upload_initial_entry(self, video_path: str, instruction: str, downsample_rate: int, token: str):
        """
        Upload the video file and instruction to the HF Dataset and append a record without a rating.
        """
        api = HfApi()
        # Generate a unique ID
        entry_id = uuid.uuid4().hex
        # Upload the video to HF
        filename = os.path.basename(video_path)
        remote_video_path = f"videos/{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}_{entry_id}_{filename}"
        api.upload_file(
            path_or_fileobj=video_path,
            path_in_repo=remote_video_path,
            repo_id=HF_DATASET_ID,
            repo_type="dataset",
            token=token
        )
        # Load or initialize the Dataset
        try:
            ds = load_dataset(HF_DATASET_ID, use_auth_token=token)
            if isinstance(ds, dict):
                ds = ds["train"]
        except Exception:
            ds = Dataset.from_dict({
                "entry_id": [], "timestamp": [], "video_path": [],
                "instruction": [], "downsample_rate": [], "rating": []
            })
        # Append a record carrying this entry_id (rating left empty)
        new = {
            "entry_id": entry_id,
            "timestamp": datetime.datetime.now().isoformat(),
            "video_path": remote_video_path,
            "instruction": instruction,
            "downsample_rate": downsample_rate,
            "rating": None
        }
        ds = ds.add_item(new)
        ds.push_to_hub(HF_DATASET_ID, token=token)
        # Return entry_id to the frontend
        return entry_id
    def process_video(self, video_file, downsample_rate, progress=gr.Progress()):
        """
        process_video returns a (result_text, video_path, downsample_rate, entry_id) tuple
        so the interface can display the video and store the state.
        """
        if video_file is None:
            return "Please upload a video file first", None, None, None
        try:
            self.reset_history()
            progress(0.1, desc=self.processing_steps[0])
            tmp_dir = tempfile.mkdtemp()
            if hasattr(video_file, 'name'):
                orig_ext = os.path.splitext(video_file.name)[1]  # e.g. ".avi"
            else:
                orig_ext = os.path.splitext(video_file)[1]
            input_path = os.path.join(tmp_dir, f"input{orig_ext}")
            os.replace(video_file, input_path)
            progress(0.2, desc=self.processing_steps[1])
            downsampled_tmp = os.path.join(tmp_dir, "downsample.mp4")
            downsampled_path, new_fps, frames = self.downsample_video(input_path, downsampled_tmp, downsample_rate)
            frr = self.extract_frames(downsampled_path)
            # results_dir = "results"
            # os.makedirs(results_dir, exist_ok=True)
            # unique_name = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + "_" + str(uuid.uuid4())[:8] + ".mp4"
            # persistent_path = os.path.join(results_dir, unique_name)
            # # Copy the file
            # import shutil
            # shutil.copyfile(downsampled_path, persistent_path)
            # No longer copy to a persistent directory; use the file in tmp_dir directly for the upload
            persistent_path = input_path
            progress(0.3, desc=self.processing_steps[2])
            pth_path = self.run_pose_estimation(tmp_dir, input_path)
            with open(pth_path, "rb") as f:
                pkl_file = joblib.load(f)
            subjs = len(pkl_file.keys())
            if subjs < 1:
                return "Failed to detect any person in the video, please upload a new video with a higher frame rate.", None, None, None
            # pth_path = "wham_output.pth"
            # Stage1
            progress(0.4, desc=self.processing_steps[3])
            stage1_path = os.path.join("prompts", "stage1.txt")
            if not os.path.exists(stage1_path):
                raise RuntimeError("Missing prompts/stage1.txt prompt file")
            with open(stage1_path, 'r', encoding='utf-8') as f:
                prompt1 = f.read()
            prompt1_1 = prompt1.split("[IMAGEFLAG]")[0].strip()
            prompt1_2 = prompt1.split("[IMAGEFLAG]")[1].strip()
            out_stage1_1 = self.query_llm(prompt1_1, use_history=True)
            out_images = self.query_llm_multimodal(text="", indexed_images=frames, use_history=True)
            out_stage1_part2 = self.query_llm(prompt1_2, use_history=True)
            # Stage2
            progress(0.5, desc=self.processing_steps[4])
            stage2_path = os.path.join("prompts", "stage2.txt")
            if not os.path.exists(stage2_path):
                raise RuntimeError("Missing prompts/stage2.txt prompt file")
            with open(stage2_path, 'r', encoding='utf-8') as f:
                prompt2 = f.read()
            prompt2 = prompt2.replace("[FRAMERATE]", str(new_fps))
            max_retries = 3
            out_stage2 = ""
            temp_json_path = os.path.join(tmp_dir, "temp_json.json")
            for attempt in range(max_retries):
                out_stage2 = self.query_llm(prompt2, use_history=True)
                try:
                    parsed = json.loads(out_stage2)
                    with open(temp_json_path, 'w', encoding='utf-8') as f:
                        json.dump(parsed, f, ensure_ascii=False, indent=2)
                    break
                except json.JSONDecodeError:
                    prompt2 = "The previous output was not valid JSON. Please output only valid JSON without any extra content." + "\n" + out_stage2
                    if attempt == max_retries - 1:
                        with open(temp_json_path, 'w', encoding='utf-8') as f:
                            f.write(out_stage2)
            # Evaluator
            progress(0.6, desc=self.processing_steps[5])
            evaluator_cmd = ["python", "estimator.py", pth_path, temp_json_path]
            result = subprocess.run(evaluator_cmd, capture_output=True, text=True)
            if result.returncode != 0:
                raise RuntimeError(f"Evaluator error: {result.stderr}")
            output_txt_path = os.path.join(tmp_dir, "temp_json_output.txt")
            with open(output_txt_path, 'r', encoding='utf-8') as f:
                evaluator_output = f.read()
            # Stage3
            progress(0.7, desc=self.processing_steps[6])
            stage3_path = os.path.join("prompts", "stage3.txt")
            if not os.path.exists(stage3_path):
                raise RuntimeError("Missing prompts/stage3.txt prompt file")
            with open(stage3_path, 'r', encoding='utf-8') as f:
                prompt3 = f.read()
            prompt3 = prompt3.replace("[RESULTS]", evaluator_output)
            out_stage3 = self.query_llm(prompt3, use_history=True)
            stage4_path = os.path.join("prompts", "stage4.txt")
            if not os.path.exists(stage4_path):
                raise RuntimeError("Missing prompts/stage4.txt prompt file")
            with open(stage4_path, 'r', encoding='utf-8') as f:
                prompt4 = f.read()
            prompt4 = prompt4.replace("[RESULTS]", evaluator_output)
            out_stage4 = self.query_llm(prompt4, use_history=True)
            hf_token = os.getenv("HF_TOKEN")
            entry_id = None
            try:
                entry_id = self.upload_initial_entry(persistent_path, out_stage4, downsample_rate, hf_token)
            except Exception as e:
                # Just log it; do not affect the normal return
                print(f"Warning: initial upload failed: {e}")
            progress(1.0, desc=self.processing_steps[7])
            # Return the final text, the saved video path, the downsample rate, and the entry id
            return out_stage4, persistent_path, downsample_rate, entry_id
        except Exception as e:
            # On error, return four values; video path, downsample rate, and entry id are None
            return "Processing error: " + str(e), None, None, None
app = PoseEstimationApp()
def create_interface():
    # Predefined UI text for both languages
    texts = {
        "en": {
            "title_md": "# 🎬 Video Pose Estimation Processing Platform",
            "description_md": "Upload a video to downsample and perform pose estimation, combine multimodal LLM analysis to generate intelligent insights",
            "input_settings": "## 📤 Input Settings",
            "video_label": "Upload video file",
            "downsample_label": "Temporal downsampling rate",
            "downsample_info": "Take 1 frame every N frames and reduce frame rate. Higher rate runs faster, lower rate yields more accurate results.",
            "process_btn": "🚀 Start Processing",
            "clear_btn": "🔄 Clear",
            "results_md": "## 📊 Processing Results",
            "final_tab": "Final Result",
            "final_label": "Final Comprehensive Result",
            "rating_label": "Please rate the result (1–5):",
            "submit_rating_btn": "Submit Rating",
            "thankyou_msg": "Thank you for your feedback!",
            "instructions_md": """
## 💡 Instructions
1. After uploading a video, the system will generate downsample.mp4 based on the downsampling rate.
2. Run WHAM/demo.py for Pose Estimation; results are saved in output/<video_name>/wham_output.pth.
3. The system will automatically read prompts/stage1.txt, stage2.txt, stage3.txt; user custom prompts are not accepted.
4. Stage1: prompts/stage1.txt can include [POSE_SUMMARY] placeholder, auto-replaced with pose summary.
5. Stage2: prompts/stage2.txt can include [FRAMERATE] and [STAGE1_RESULT] placeholders, auto-replaced.
6. Prompts will be forced to output JSON format for Evaluator use.
7. After Evaluator runs, it generates output.txt; content is automatically passed to Stage3.
8. Deployment requires HF_TOKEN environment variable set for HuggingFace access token; code uses it automatically.
9. Ensure project root contains prompts/stage1.txt, stage2.txt, stage3.txt and WHAM/demo.py, evaluator.py.
"""
        },
        "zh": {
            "title_md": "# 🎬 视频姿态估计处理平台",
            "description_md": "上传视频进行降采样和姿态估计,结合多模态 LLM 分析生成智能化见解",
            "input_settings": "## 📤 输入设置",
            "video_label": "上传视频文件",
            "downsample_label": "时间降采样率",
            "downsample_info": "每隔 N 帧取 1 帧并降低帧率。更高的采样率速度更快,但精度可能下降;更低采样率更准确。",
            "process_btn": "🚀 开始处理",
            "clear_btn": "🔄 清除",
            "results_md": "## 📊 处理结果",
            "final_tab": "最终结果",
            "final_label": "最终综合结果",
            "rating_label": "请对结果进行评分 (1–5):",
            "submit_rating_btn": "提交评分",
            "thankyou_msg": "感谢您的反馈!",
            "instructions_md": """
## 💡 使用说明
1. 上传视频后,系统会根据降采样率生成 downsample.mp4。
2. 运行 WHAM/demo.py 进行姿态估计;结果保存在 output/<video_name>/wham_output.pth。
3. 系统会自动读取 prompts/stage1.txt、stage2.txt、stage3.txt;不接受用户自定义提示。
4. Stage1: prompts/stage1.txt 可包含 [POSE_SUMMARY] 占位符,将被自动替换。
5. Stage2: prompts/stage2.txt 可包含 [FRAMERATE] 和 [STAGE1_RESULT] 占位符,将被自动替换。
6. 提示将被强制输出 JSON 格式以供 Evaluator 使用。
7. Evaluator 运行后会生成 output.txt;内容会自动传递到 Stage3。
8. 部署需要设置 HF_TOKEN 环境变量以获得 HuggingFace 访问令牌;代码会自动使用。
9. 确保项目根目录下包含 prompts/stage1.txt、stage2.txt、stage3.txt 以及 WHAM/demo.py、evaluator.py。
"""
        }
    }
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="Video Pose Estimation Processing Platform",
        css="""
        .gradio-container { max-width: 1200px !important; }
        .tab-nav { background: linear-gradient(90deg, #667eea, #764ba2) !important; }
        """
    ) as demo:
        # Language state
        lang_state = gr.State("en")
        # Hidden state: store the most recently processed video path and downsample rate
        last_video_path = gr.State(None)
        last_downsample_rate = gr.State(None)
        last_entry_id = gr.State(None)
        # Language toggle button
        lang_btn = gr.Button("中文")  # Initial language is "en", so the button text is "中文"
        # Header markdown
        header_md = gr.Markdown(texts["en"]["title_md"])
        desc_md = gr.Markdown(texts["en"]["description_md"])
        with gr.Row():
            with gr.Column(scale=1):
                input_md = gr.Markdown(texts["en"]["input_settings"])
                video_input = gr.Video(label=texts["en"]["video_label"], sources=["upload"], height=300)
                with gr.Row():
                    downsample_rate = gr.Slider(minimum=1, maximum=30, value=10, step=1,
                                                label=texts["en"]["downsample_label"],
                                                info=texts["en"]["downsample_info"])
                with gr.Row():
                    process_btn = gr.Button(texts["en"]["process_btn"], variant="primary", size="lg")
                    clear_btn = gr.Button(texts["en"]["clear_btn"], variant="secondary")
            with gr.Column(scale=2):
                results_md = gr.Markdown(texts["en"]["results_md"])
                with gr.Tabs() as tabs:
                    with gr.TabItem(texts["en"]["final_tab"]):
                        final_output = gr.Textbox(label=texts["en"]["final_label"], lines=12, max_lines=20)
                        # Rating slider and submit button
                        rating_slider = gr.Slider(minimum=1, maximum=5, step=1,
                                                  label=texts["en"]["rating_label"])
                        submit_rating_btn = gr.Button(value=texts["en"]["submit_rating_btn"])
                        # Shows the thank-you message after submission
                        thankyou_text = gr.Markdown("")  # Initially empty
        # Language toggle callback
        def toggle_language(current_lang):
            # current_lang: "en" or "zh"; return the new current_lang plus a batch of component updates
            new_lang = "zh" if current_lang == "en" else "en"
            t = texts[new_lang]
            # Update the text of each component
            updates = {
                lang_state: new_lang,
                header_md: gr.update(value=t["title_md"]),
                desc_md: gr.update(value=t["description_md"]),
                input_md: gr.update(value=t["input_settings"]),
                video_input: gr.update(label=t["video_label"]),
                downsample_rate: gr.update(label=t["downsample_label"], info=t["downsample_info"]),
                process_btn: gr.update(value=t["process_btn"]),
                clear_btn: gr.update(value=t["clear_btn"]),
                results_md: gr.update(value=t["results_md"]),
                final_output: gr.update(label=t["final_label"]),
                rating_slider: gr.update(label=t["rating_label"]),
                submit_rating_btn: gr.update(value=t["submit_rating_btn"]),
                thankyou_text: gr.update(value="")  # Clear the thank-you message when switching languages
            }
            # The toggle button text also needs updating: show "English" when the new language is Chinese, otherwise "中文"
            btn_text = "English" if new_lang == "zh" else "中文"
            updates[lang_btn] = gr.update(value=btn_text)
            return updates

        lang_btn.click(fn=toggle_language,
                       inputs=[lang_state],
                       outputs=[lang_state,
                                header_md, desc_md,
                                input_md, video_input, downsample_rate, process_btn, clear_btn,
                                results_md, final_output, rating_slider, submit_rating_btn, thankyou_text,
                                lang_btn])
        # Video processing callback: process_video returns (result_text, video_path, downsample_rate, entry_id)
        def on_process(video, rate):
            result_text, video_path, dr, entry_id = app.process_video(video, rate)
            # Update state; on success, video_path is not None
            return result_text, video_path, dr, entry_id, gr.update(value=None), gr.update(value="")

        # Note: the outputs order matches on_process's return values
        # outputs: final_output (text), last_video_path (state), last_downsample_rate (state), last_entry_id (state), rating_slider (reset), thankyou_text (cleared)
        process_btn.click(fn=on_process,
                          inputs=[video_input, downsample_rate],
                          outputs=[final_output, last_video_path,
                                   last_downsample_rate, last_entry_id,
                                   rating_slider, thankyou_text])
        # Clear button: reset everything
        def on_clear(current_lang):
            return None, 10, None, None, gr.update(value=None), gr.update(value=""), "中文" if current_lang == "en" else "English"

        # Return order: video_input, downsample_rate, last_video_path, last_downsample_rate, rating_slider, thankyou_text, lang_btn
        clear_btn.click(fn=on_clear,
                        inputs=[lang_state],
                        outputs=[video_input, downsample_rate,
                                 last_video_path, last_downsample_rate,
                                 rating_slider, thankyou_text, lang_btn])
        # Rating submission callback: reads last_entry_id and the rating value
        def save_rating(entry_id, rating, current_lang):
            hf_token = os.getenv("HF_TOKEN")
            if not hf_token or entry_id is None:
                return texts[current_lang]["thankyou_msg"]
            # Load the Dataset
            ds = load_dataset(HF_DATASET_ID, use_auth_token=hf_token)
            ds = ds["train"] if isinstance(ds, dict) else ds
            # Update the rating of the row matching entry_id
            records = ds.to_list()
            for rec in records:
                if rec["entry_id"] == entry_id:
                    rec["rating"] = int(rating)
                    break
            # Push back to the Hub
            new_ds = Dataset.from_list(records)
            new_ds.push_to_hub(HF_DATASET_ID, token=hf_token)
            return texts[current_lang]["thankyou_msg"]

        # Wire up the rating submit button (inputs must match the save_rating signature order)
        submit_rating_btn.click(fn=save_rating,
                                inputs=[last_entry_id, rating_slider, lang_state],
                                outputs=[thankyou_text])
        # Footer instructions
        instructions_md = gr.Markdown(texts["en"]["instructions_md"])
        # Note: instructions_md is not among toggle_language's outputs, so it stays in English after a language switch
    return demo
if __name__ == "__main__":
    demo = create_interface()
    # port = int(os.environ.get("GRADIO_SERVER_PORT", os.environ.get("PORT", 7860)))
    # server_name = str(os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"))
    demo.launch(show_error=True, mcp_server=True)