qihfang committed on
Commit
447adea
·
1 Parent(s): 348010c

First Commit


.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: 视频Pose Estimation处理平台
3
  emoji: 🎬
4
  colorFrom: blue
5
  colorTo: purple
@@ -9,61 +9,66 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- # 🎬 视频Pose Estimation处理平台
13
 
14
- 一个基于Gradio的视频姿态估计和AI分析平台,支持视频上传、姿态检测、时序下采样和智能分析。
15
 
16
- ## 功能特性
17
 
18
- - 📹 **视频上传**: 支持多种视频格式
19
- - 🤖 **Pose Estimation**: 人体姿态关键点检测
20
- - ⏱️ **时序下采样**: 可配置的帧率降采样
21
- - 🧠 **AI分析**: 集成ChatGPT进行智能分析
22
- - 🐍 **Python处理**: 后续数据处理和分析
23
- - 📊 **结果展示**: 多维度结果可视化
 
24
 
25
- ## 处理流程
26
 
27
- 1. **视频上传**上传待分析的视频文件
28
- 2. **Pose Estimation** → 提取人体关键点数据
29
- 3. **时序下采样**根据设定比率降采样
30
- 4. **ChatGPT分析**AI智能分析姿态特征
31
- 5. **Python处理**运行自定义分析程序
32
- 6. **结果生成**输出综合分析报告
 
33
 
34
- ## 使用方法
35
 
36
- 1. 上传视频文件
37
- 2. 设置下采样率 (1-10)
38
- 3. 输入ChatGPT分析提示词
39
- 4. 点击"开始处理"
40
- 5. 在不同标签页查看结果
41
 
42
- ## 技术栈
43
 
44
- - **前端**: Gradio Web UI
45
- - **视频处理**: OpenCV
46
- - **数据处理**: NumPy
47
- - **AI集成**: ChatGPT API (待集成)
48
- - **部署平台**: Hugging Face Spaces
49
 
50
- ## 开发说明
 
 
 
 
51
 
52
- 当前版本使用模拟数据进行演示。实际部署时需要:
 
 
 
53
 
54
- 1. 集成真实的Pose Estimation模型 (如MediaPipe, OpenPose等)
55
- 2. 配置ChatGPT API密钥
56
- 3. 实现具体的Python分析程序
57
- 4. 添加错误处理和日志记录
58
 
59
- ## 部署到Hugging Face
 
 
60
 
61
- 1. 创建新的Space
62
- 2. 选择Gradio SDK
63
- 3. 上传所有文件
64
- 4. 配置环境变量 (如API密钥)
65
- 5. 启动应用
66
 
67
- ## 许可证
68
 
69
- MIT License
 
 
 
1
  ---
2
+ title: AI Sports Coaching
3
  emoji: 🎬
4
  colorFrom: blue
5
  colorTo: purple
 
9
  pinned: false
10
  ---
11
 
12
+ # 🎬 AI Sports Coaching System
13
 
14
+ A video-based pose estimation and AI analysis platform powered by Vision-Language Models (VLMs). Users can upload videos, run pose detection and temporal downsampling, and receive intelligent feedback.
15
 
16
+ ## Features
17
 
18
+ - 📹 **Video Upload**: Support for multiple video formats
19
+ - 🤖 **Pose Estimation**: Human keypoint detection
20
+ - ⏱️ **Temporal Downsampling**: Configurable frame rate reduction
21
+ - 🧠 **AI Analysis**: Integrate LLMs/VLMs for intelligent insights
22
+ - 🐍 **Python Processing**: Custom data processing and analysis pipelines
23
+ - 📊 **Result Visualization**: Multi-dimensional result display
24
+ - ⭐ **Scoring Mechanism**: User feedback scoring for outputs
25
 
26
+ ## Workflow
27
 
28
+ 1. **Video Upload** → Upload one or more videos for analysis
29
+ 2. **Pose Estimation** → Extract human keypoint data
30
+ 3. **Temporal Downsampling** → Reduce frame rate according to settings (a minimal sketch follows this list)
31
+ 4. **AI Analysis** → Use VLM/LLM to analyze pose features
32
+ 5. **Python Processing** → Run custom analysis scripts
33
+ 6. **Result Generation** → Produce a comprehensive analysis report
34
+ 7. **Scoring** → User rates output quality (e.g., 1–5)
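+
+ As a rough illustration of the downsampling step (item 3 above), the sketch below keeps every Nth frame with OpenCV. It mirrors `downsample_video` in `app.py` but simplifies the details; the function name is illustrative:
+
+ ```python
+ import cv2
+
+ def downsample(input_path: str, output_path: str, rate: int = 10) -> int:
+     """Write a lower-frame-rate copy of the video, keeping every `rate`-th frame."""
+     cap = cv2.VideoCapture(input_path)
+     fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
+     size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
+             int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
+     out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
+                           max(1, int(fps / rate)), size)
+     kept = idx = 0
+     while True:
+         ok, frame = cap.read()
+         if not ok:
+             break
+         if idx % rate == 0:  # keep every `rate`-th frame
+             out.write(frame)
+             kept += 1
+         idx += 1
+     cap.release()
+     out.release()
+     return kept
+ ```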
35
 
36
+ ## Usage
37
 
38
+ 1. Upload video file(s)
39
+ 2. Configure the downsampling rate (1–30, default 10)
40
+ 3. Click “Start Processing”
41
+ 4. View results in different tabs (pose visualization, analysis report, charts, etc.)
42
+ 5. Provide feedback score for the outputs
43
 
44
+ ## TODO
45
 
46
+ - **Accelerate Pose Estimation**
47
+ - Optimize model inference (e.g., model pruning/quantization, GPU/CPU parallelism)
48
+ - Batch processing for multiple videos or frames
49
+ - Investigate lightweight architectures or delegate to hardware accelerators
 
50
 
51
+ - **Local Deployment of VLMs**
52
+ - Documentation for downloading and setting up VLM weights locally
53
+ - Instructions for environment configuration (dependencies, hardware requirements)
54
+ - Offline inference capabilities and fallback strategies
55
+ - Security considerations for storing API keys or model files (see the token-handling sketch after this list)
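+
+ Until local weights are supported, the remote path relies on the `HF_TOKEN` secret. A minimal sketch of the token handling already used in `app.py` (the model name is copied from the code; the rest is illustrative):
+
+ ```python
+ import os
+ from huggingface_hub import InferenceClient
+
+ token = os.getenv("HF_TOKEN")  # configure as a Space secret; never hard-code it
+ if not token:
+     raise RuntimeError("HF_TOKEN environment variable not set")
+ client = InferenceClient(model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
+                          token=token)
+ ```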
56
 
57
+ - **Support Multiple Video Formats**
58
+ - Automatic compatibility check and conversion (e.g., mp4, avi, mov, webm)
59
+ - Integrate ffmpeg (or similar) for on-the-fly format handling (a sketch follows this list)
60
+ - Graceful fallback or user guidance when format is unsupported
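+
+ One possible shape for the ffmpeg-based conversion mentioned above; the helper name and encoder options are illustrative, not part of the current code:
+
+ ```python
+ import os
+ import subprocess
+
+ def to_mp4(input_path: str, out_dir: str) -> str:
+     """Convert an arbitrary container (avi/mov/webm/...) to an H.264 mp4 via ffmpeg."""
+     base = os.path.splitext(os.path.basename(input_path))[0]
+     out_path = os.path.join(out_dir, base + ".mp4")
+     cmd = ["ffmpeg", "-y", "-i", input_path,
+            "-c:v", "libx264", "-pix_fmt", "yuv420p", out_path]
+     result = subprocess.run(cmd, capture_output=True, text=True)
+     if result.returncode != 0:
+         raise RuntimeError(f"ffmpeg conversion failed: {result.stderr[-500:]}")
+     return out_path
+ ```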
61
 
62
+ - **Extend Scoring & Feedback Loop**
63
+ - Store user scores along with video metadata (record shape shown below)
64
+ - Use scores to fine-tune or adjust analysis parameters over time
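+
+ For reference, each feedback record currently appended to `results/ratings.json` by `save_rating` in `app.py` has this shape (values are illustrative):
+
+ ```python
+ record = {
+     "timestamp": "2025-01-01T12:00:00",                   # datetime.now().isoformat()
+     "video_path": "results/20250101120000_ab12cd34.mp4",  # persisted downsampled clip
+     "downsample_rate": 10,
+     "rating": 4,
+ }
+ ```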
 
65
 
66
+ - **Support Multiple Languages**
67
+ - Use language-specific prompts for each supported language
68
+ - Update prompts for stable language control
69
 
 
 
 
 
 
70
 
 
71
 
72
+ ## License
73
+
74
+ MIT License
WHAM ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 2b54f7797391c94876848b905ed875b154c4a295
app.py CHANGED
@@ -2,288 +2,560 @@ import gradio as gr
2
  import numpy as np
3
  import cv2
4
  import json
5
- import time
6
- from typing import Tuple, Dict, Any
7
- import tempfile
8
  import os
9
 
10
  class PoseEstimationApp:
11
- def __init__(self):
12
  self.processing_steps = [
13
- "视频上传完成",
14
- "开始Pose Estimation...",
15
- "执行时序下采样...",
16
- "调用ChatGPT分析...",
17
- "运行Python程序...",
18
- "生成最终结果"
 
 
19
  ]
20
-
21
- def extract_video_frames(self, video_path: str, downsample_rate: int = 2):
22
- """提取视频帧并进行下采样"""
23
- cap = cv2.VideoCapture(video_path)
24
- frames = []
 
25
  frame_count = 0
26
-
27
  while True:
28
  ret, frame = cap.read()
29
  if not ret:
30
  break
31
-
32
- # 根据下采样率选择帧
33
  if frame_count % downsample_rate == 0:
34
- frames.append(frame)
35
-
36
  frame_count += 1
37
-
38
  cap.release()
39
- return frames, frame_count
40
-
41
- def mock_pose_estimation(self, frames):
42
- """模拟pose estimation处理"""
43
- # 这里应该是真实的pose estimation代码
44
- # 目前返回模拟数据
45
- pose_data = []
46
- for i, frame in enumerate(frames):
47
- # 模拟关键点数据 (17个关键点,每个点有x,y,confidence)
48
- keypoints = np.random.rand(17, 3).tolist()
49
- pose_data.append({
50
- "frame_id": i,
51
- "keypoints": keypoints,
52
- "timestamp": i * 0.033 # 假设30fps
53
- })
54
- return pose_data
55
-
56
- def mock_chatgpt_analysis(self, pose_data, custom_prompt):
57
- """模拟ChatGPT分析"""
58
- # 这里应该调用实际的ChatGPT API
59
- analysis = f"""
60
- 基于您的提示词: "{custom_prompt}"
61
-
62
- 姿态分析结果:
63
- - 检测到 {len(pose_data)} 个有效帧
64
- - 主要动作模式:步行/站立
65
- - 姿态稳定性:良好
66
- - 运动幅度:中等
67
- - 异常检测:未发现明显异常
68
-
69
- 建议:保持当前运动模式,注意关节角度的稳定性。
70
  """
71
- return analysis
72
-
73
- def mock_python_processing(self, chatgpt_result):
74
- """模拟Python程序处理"""
75
- # 这里应该运行实际的Python分析程序
76
- python_output = f"""
77
- 执行Python分析程序...
78
-
79
- 处理ChatGPT结果: {len(chatgpt_result)} 字符
80
- 计算统计指标...
81
- 生成可视化图表...
82
-
83
- 结果摘要:
84
- - 处理状态: 成功
85
- - 分析维度: 多维度姿态分析
86
- - 输出格式: JSON + 可视化
87
- - 置信度: 0.87
88
  """
89
- return python_output
90
-
91
- def generate_final_result(self, pose_data, chatgpt_analysis, python_output):
92
- """生成最终结果"""
93
- final_result = {
94
- "summary": {
95
- "total_frames": len(pose_data),
96
- "processing_time": "模拟处理",
97
- "confidence_score": 0.87
98
- },
99
- "pose_analysis": "姿态数据已提取",
100
- "ai_insights": chatgpt_analysis,
101
- "technical_analysis": python_output,
102
- "recommendations": [
103
- "继续保持良好的运动姿态",
104
- "注意关节角度的协调性",
105
- "建议增加运动的多样性"
106
- ]
107
- }
108
- return json.dumps(final_result, indent=2, ensure_ascii=False)
109
-
110
- def process_video(self, video_file, downsample_rate, custom_prompt, progress=gr.Progress()):
111
- """主要的视频处理函数"""
112
  if video_file is None:
113
- return "请先上传视频文件", "", "", ""
114
-
115
  try:
116
- # 步骤1: 视频上传完成
117
- progress(0.1, desc="视频上传完成")
118
- time.sleep(1)
119
-
120
- # 步骤2: 提取帧和Pose Estimation
121
- progress(0.3, desc="执行Pose Estimation...")
122
- frames, total_frames = self.extract_video_frames(video_file, downsample_rate)
123
- pose_data = self.mock_pose_estimation(frames)
124
- time.sleep(2)
125
-
126
- # 步骤3: 时序下采样
127
- progress(0.5, desc="时序下采样完成")
128
- time.sleep(1)
129
-
130
- # 步骤4: ChatGPT分析
131
- progress(0.7, desc="ChatGPT分析中...")
132
- if not custom_prompt.strip():
133
- custom_prompt = "请分析这个视频中的姿态特征"
134
- chatgpt_result = self.mock_chatgpt_analysis(pose_data, custom_prompt)
135
- time.sleep(2)
136
-
137
- # 步骤5: Python程序处理
138
- progress(0.9, desc="运行Python分析程序...")
139
- python_result = self.mock_python_processing(chatgpt_result)
140
- time.sleep(1)
141
-
142
- # 步骤6: 生成最终结果
143
- progress(1.0, desc="生成最终结果")
144
- final_result = self.generate_final_result(pose_data, chatgpt_result, python_result)
145
-
146
- # 格式化pose数据用于显示
147
- pose_summary = f"提取了 {len(pose_data)} 帧数据\n"
148
- pose_summary += f"原始帧数: {total_frames}\n"
149
- pose_summary += f"下采样率: {downsample_rate}\n"
150
- pose_summary += f"有效关键点数: {len(pose_data[0]['keypoints']) if pose_data else 0}"
151
-
152
- return pose_summary, chatgpt_result, python_result, final_result
153
 
154
  except Exception as e:
155
- return f"处理出错: {str(e)}", "", "", ""
 
 
156
 
157
- # 创建应用实例
158
  app = PoseEstimationApp()
159
 
160
- # 创建Gradio界面
161
  def create_interface():
162
  with gr.Blocks(
163
  theme=gr.themes.Soft(),
164
- title="视频Pose Estimation处理平台",
165
  css="""
166
- .gradio-container {
167
- max-width: 1200px !important;
168
- }
169
- .tab-nav {
170
- background: linear-gradient(90deg, #667eea, #764ba2) !important;
171
- }
172
  """
173
  ) as demo:
174
-
175
- gr.Markdown(
176
- """
177
- # 🎬 视频Pose Estimation处理平台
178
-
179
- 上传视频文件,进行姿态估计分析,获得AI驱动的智能洞察
180
-
181
- **功能流程**: 视频上传 Pose Estimation → 时序下采样 → ChatGPT分析 → Python处理 → 结果生成
182
- """
183
- )
184
-
 
185
  with gr.Row():
186
  with gr.Column(scale=1):
187
- gr.Markdown("## 📤 输入设置")
188
-
189
- video_input = gr.Video(
190
- label="上传视频文件",
191
- sources=["upload"],
192
- height=300
193
- )
194
-
195
  with gr.Row():
196
- downsample_rate = gr.Slider(
197
- minimum=1,
198
- maximum=10,
199
- value=2,
200
- step=1,
201
- label="时序下采样率",
202
- info="每N帧取1帧"
203
- )
204
-
205
- custom_prompt = gr.Textbox(
206
- label="ChatGPT提示词",
207
- placeholder="请输入分析姿态数据的提示词...",
208
- lines=3,
209
- value="请详细分析这个视频中的人体姿态特征,包括运动模式、稳定性和任何异常情况。"
210
- )
211
-
212
  with gr.Row():
213
- process_btn = gr.Button("🚀 开始处理", variant="primary", size="lg")
214
- clear_btn = gr.Button("🔄 清除", variant="secondary")
215
-
216
- with gr.Column(scale=2):
217
- gr.Markdown("## 📊 处理结果")
218
-
219
- with gr.Tabs() as tabs:
220
- with gr.TabItem("姿态数据"):
221
- pose_output = gr.Textbox(
222
- label="Pose Estimation结果",
223
- lines=8,
224
- max_lines=15,
225
- placeholder="姿态数据将在此显示..."
226
- )
227
-
228
- with gr.TabItem("AI分析"):
229
- ai_output = gr.Textbox(
230
- label="ChatGPT分析结果",
231
- lines=8,
232
- max_lines=15,
233
- placeholder="AI分析结果将在此显示..."
234
- )
235
-
236
- with gr.TabItem("Python输出"):
237
- python_output = gr.Textbox(
238
- label="Python程序输出",
239
- lines=8,
240
- max_lines=15,
241
- placeholder="Python处理结果将在此显示..."
242
- )
243
-
244
- with gr.TabItem("最终结果"):
245
- final_output = gr.Textbox(
246
- label="综合分析结果",
247
- lines=8,
248
- max_lines=15,
249
- placeholder="最终结果将在此显示..."
250
- )
251
-
252
- # 事件绑定
253
- process_btn.click(
254
- fn=app.process_video,
255
- inputs=[video_input, downsample_rate, custom_prompt],
256
- outputs=[pose_output, ai_output, python_output, final_output],
257
- show_progress=True
258
- )
259
-
260
- clear_btn.click(
261
- fn=lambda: (None, 2, "", "", "", "", ""),
262
- outputs=[video_input, downsample_rate, custom_prompt, pose_output, ai_output, python_output, final_output]
263
- )
264
-
265
- # 示例
266
- gr.Markdown(
267
- """
268
- ## 💡 使用说明
269
-
270
- 1. **上传视频**: 支持常见视频格式 (MP4, AVI, MOV等)
271
- 2. **设置参数**: 调整下采样率和自定义提示词
272
- 3. **开始处理**: 点击处理按钮,等待分析完成
273
- 4. **查看结果**: 在不同标签页中查看详细结果
274
-
275
- **注意**: 当前版本使用模拟数据,实际部署时需要集成真实的API和算法。
276
- """
277
- )
278
-
279
- return demo
280
 
281
- # 创建并启动应用
282
  if __name__ == "__main__":
283
  demo = create_interface()
284
- demo.launch(
285
- server_name="0.0.0.0",
286
- server_port=7860,
287
- share=True,
288
- show_error=True
289
- )
 
2
  import numpy as np
3
  import cv2
4
  import json
5
+ import subprocess
 
 
6
  import os
7
+ from typing import Tuple, List, Dict, Any
8
+ import tempfile
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor
10
+ import torch
11
+ from PIL import Image
12
+ import datetime
13
+ import uuid
14
+ from PIL import Image
15
+ import io, base64
16
+
17
+ # Use HuggingFace remote inference
18
+ try:
19
+ from huggingface_hub import InferenceClient
20
+ except ImportError:
21
+ InferenceClient = None
22
 
23
  class PoseEstimationApp:
24
+ def __init__(self, model_name: str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", use_remote: bool = True, max_history_turns: int = 50):
25
  self.processing_steps = [
26
+ "Video upload completed",
27
+ "Starting video downsampling...",
28
+ "Executing Pose Estimation...",
29
+ "Running Stage1 prompt...",
30
+ "Running Stage2 prompt...",
31
+ "Running Evaluator...",
32
+ "Running Stage3 prompt...",
33
+ "Generating final result"
34
  ]
35
+ self.use_remote = use_remote
36
+ self.model_name = model_name
37
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ if not use_remote:
39
+ raise RuntimeError("Remote inference only supported, please set use_remote=True and provide HF_TOKEN environment variable.")
40
+ if InferenceClient is None:
41
+ raise RuntimeError("huggingface_hub not installed, please install to use remote inference.")
42
+ token = os.getenv("HF_TOKEN")
43
+ if not token:
44
+ raise RuntimeError("HF_TOKEN environment variable not set, please set the access token in deployment environment.")
45
+ try:
46
+ self.client = InferenceClient(model=model_name, token=token)
47
+ except Exception as e:
48
+ raise RuntimeError(f"Failed to initialize remote inference client: {e}")
49
+
50
+ # Conversation history management
51
+ # Use a list to store several rounds of conversation, each item is a dict containing 'role' ('user' or 'assistant') and 'content'
52
+ self.conversation_history: List[Dict[str, str]] = []
53
+ # Keep the most recent number of rounds (user+assistant), truncate when exceeded
54
+ self.max_history_turns = max_history_turns
55
+
56
+ def reset_history(self):
57
+ """
58
+ Clear conversation history, call when starting a new multi-turn conversation scenario.
59
+ """
60
+ self.conversation_history = []
61
+
62
+ def add_user_message(self, message: str):
63
+ self.conversation_history.append({"role": "user", "content": message})
64
+ # If exceeding maximum rounds (here a round refers to user+assistant), remove earliest rounds
65
+ # Calculate current entries, if len > 2 * max_history_turns, truncate the earliest two entries
66
+ max_items = 2 * self.max_history_turns
67
+ if len(self.conversation_history) > max_items:
68
+ # Discard the earliest two entries
69
+ self.conversation_history = self.conversation_history[-max_items:]
70
+
71
+ def add_assistant_message(self, message: str):
72
+ self.conversation_history.append({"role": "assistant", "content": message})
73
+ # Similarly truncate history
74
+ max_items = 2 * self.max_history_turns
75
+ if len(self.conversation_history) > max_items:
76
+ self.conversation_history = self.conversation_history[-max_items:]
77
+
78
+ def build_prompt_with_history(self, new_user_input: str) -> str:
79
+ """
80
+ Concatenate history rounds with current user input into a prompt string.
81
+ Example:
82
+ User: ...
83
+ Assistant: ...
84
+ User: new_user_input
85
+ Assistant:
86
+ """
87
+ prompt_parts = []
88
+ for turn in self.conversation_history:
89
+ if turn["role"] == "user":
90
+ prompt_parts.append(f"User: {turn['content']}")
91
+ else:
92
+ prompt_parts.append(f"Assistant: {turn['content']}")
93
+ # Add new user input, model reply will be generated at the end
94
+ prompt_parts.append(f"User: {new_user_input}")
95
+ prompt_parts.append("Assistant:") # Guide model generation
96
+ full_prompt = "\n".join(prompt_parts)
97
+ return full_prompt
98
+
99
+ def image_to_datauri(self, img: Image.Image, max_size=640, jpeg_quality=70):
100
+ # First scale so that the longest side is at most max_size
101
+ w, h = img.size
102
+ scale = max_size / max(w, h)
103
+ if scale < 1.0:
104
+ new_w, new_h = int(w*scale), int(h*scale)
105
+ img = img.resize((new_w, new_h), Image.BILINEAR)
106
+ # Convert to JPEG
107
+ buffered = io.BytesIO()
108
+ img.save(buffered, format="JPEG", quality=jpeg_quality)
109
+ img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
110
+ return f"data:image/jpeg;base64,{img_b64}"
111
+
112
+
113
+
114
+ def query_llm_multimodal(self, text: str, indexed_images: list, use_history: bool = True, max_tokens: int = 1024):
115
+ messages = []
116
+ if use_history:
117
+ for turn in self.conversation_history:
118
+ messages.append({"role": turn['role'], "content": turn['content']})
119
+ # Handle the text part first (if any)
120
+ if text:
121
+ messages.append({"role": "user", "content": [{"type": "text", "text": text}]})
122
+ self.add_user_message(text)
123
+ # Add the image inputs frame by frame
124
+ for clip_idx, (video_idx, img) in enumerate(indexed_images):
125
+ uri = self.image_to_datauri(img, max_size=512, jpeg_quality=60)
126
+ # Only send the most important frames, or filter them before calling
127
+ msg_content = [
128
+ {"type":"image_url","image_url": {"url": uri}},
129
+ {"type":"text","text":f"Image {clip_idx} corresponds to frame {video_idx+1} of the video. Please analyze and summarize its content, and refer to it by its original frame index ({video_idx+1})."}
130
+ ]
131
+ messages.append({"role":"user","content":msg_content})
132
+ self.add_user_message(f"[IMAGE frame {video_idx}]")
133
+ # If there are too many frames, break here or only process the first K frames
134
+ try:
135
+ response = self.client.chat.completions.create(messages=messages, max_tokens=max_tokens)
136
+ reply = response.choices[0].message.content
137
+ except Exception as e:
138
+ raise RuntimeError(f"Multimodal inference error: {e}")
139
+ self.add_assistant_message(reply)
140
+ return reply
141
+
142
+ def query_llm(self, prompt: str, max_length: int = 2048, use_history: bool = True) -> str:
143
+ if use_history:
144
+ messages = []
145
+ for turn in self.conversation_history:
146
+ messages.append({"role": turn['role'], "content": turn['content']})
147
+ messages.append({"role": "user", "content": prompt})
148
+ self.add_user_message(prompt)
149
+ else:
150
+ messages = [{"role": "user", "content": prompt}]
151
+ try:
152
+ response = self.client.chat.completions.create(messages=messages, max_tokens=max_length)
153
+ reply = response.choices[0].message.content
154
+ except Exception as e:
155
+ raise RuntimeError(f"Remote inference error: {e}")
156
+ self.add_assistant_message(reply)
157
+ return reply
158
+
159
+ # Other methods remain unchanged...
160
+ def downsample_video(self, input_path: str, output_path: str, downsample_rate: int) -> Tuple[str, int, List[Tuple[int, Image.Image]]]:
161
+ cap = cv2.VideoCapture(input_path)
162
+ if not cap.isOpened():
163
+ raise RuntimeError(f"Cannot open video file {input_path}")
164
+ fps = cap.get(cv2.CAP_PROP_FPS)
165
+ new_fps = max(1, int(fps / downsample_rate))
166
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
167
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
168
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
169
+ out = cv2.VideoWriter(output_path, fourcc, new_fps, (width, height))
170
  frame_count = 0
171
+ frames = []
172
  while True:
173
  ret, frame = cap.read()
174
  if not ret:
175
  break
 
 
176
  if frame_count % downsample_rate == 0:
177
+ out.write(frame)
178
+ frames.append((frame_count, Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))))
179
  frame_count += 1
180
+ cap.release()
181
+ out.release()
182
+ return output_path, new_fps, frames
183
+
184
+ def run_pose_estimation(self, tmp_dir, video_path: str) -> str:
185
+ base_name = os.path.splitext(os.path.basename(video_path))[0]
186
+ out_dir = os.path.join(tmp_dir, "output")
187
+ os.makedirs(out_dir, exist_ok=True)
188
+ cmd = [
189
+ "python", "WHAM/demo.py",
190
+ "--video", video_path,
191
+ "--save_pkl",
192
+ "--output_pth", out_dir
193
+ ]
194
+ out_dir = os.path.join(out_dir, base_name)
195
+ result = subprocess.run(cmd, capture_output=True, text=True)
196
+ if result.returncode != 0:
197
+ raise RuntimeError(f"Pose Estimation failed: {result.stderr}")
198
+ result_path = os.path.join(out_dir, "wham_output.pkl")
199
+ if not os.path.exists(result_path):
200
+ raise FileNotFoundError(f"Result file not found: {result_path}")
201
+ return result_path
202
+
203
+ def load_pose_data(self, pth_path: str):
204
+ try:
205
+ data = torch.load(pth_path, map_location="cpu")
206
+ return data
207
+ except Exception as e:
208
+ raise RuntimeError(f"Failed to load Pose data: {e}")
209
+
210
+ def extract_frames(self, video_path: str, frame_skip: int = 1) -> List[Tuple[int, Image.Image]]:
211
+ """
212
+ Eagerly read and return all frames as a list of (frame_index, PIL.Image).
213
+ """
214
+ frames = []
215
+ cap = cv2.VideoCapture(video_path)
216
+ if not cap.isOpened():
217
+ return frames
218
+ idx = 0
219
+ while True:
220
+ ret, frame = cap.read()
221
+ if not ret:
222
+ break
223
+ if idx % frame_skip == 0:
224
+ img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
225
+ frames.append((idx, img))
226
+ idx += 1
227
  cap.release()
228
+ return frames
229
+
230
+
231
+ def process_video(self, video_file, downsample_rate, progress=gr.Progress()):
232
  """
233
+ process_video returns a (result_text, downsampled_video_path, downsample_rate) triple
234
+ so that the UI can display and persist the video.
235
  """
236
  if video_file is None:
237
+ return "Please upload a video file first", None, None
 
238
  try:
239
+ self.reset_history()
240
+
241
+ progress(0.1, desc=self.processing_steps[0])
242
+ tmp_dir = tempfile.mkdtemp()
243
+ if hasattr(video_file, 'name'):
244
+ orig_ext = os.path.splitext(video_file.name)[1] # e.g. ".avi"
245
+ else:
246
+ orig_ext = os.path.splitext(video_file)[1]
247
+ input_path = os.path.join(tmp_dir, f"input{orig_ext}")
248
+ os.replace(video_file, input_path)
249
+
250
+
251
+ progress(0.2, desc=self.processing_steps[1])
252
+ downsampled_tmp = os.path.join(tmp_dir, "downsample.mp4")
253
+ downsampled_path, new_fps, frames = self.downsample_video(input_path, downsampled_tmp, downsample_rate)
254
 
255
+ print(type(frames), len(frames))
256
+ frr = self.extract_frames(downsampled_path)
257
+ print(type(frr), len(frr))
258
+
259
+ results_dir = "results"
260
+ os.makedirs(results_dir, exist_ok=True)
261
+ unique_name = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + "_" + str(uuid.uuid4())[:8] + ".mp4"
262
+ persistent_path = os.path.join(results_dir, unique_name)
263
+ # Copy the downsampled video to persistent storage
264
+ import shutil
265
+ shutil.copyfile(downsampled_path, persistent_path)
266
+
267
+
268
+ progress(0.3, desc=self.processing_steps[2])
269
+ pth_path = self.run_pose_estimation(tmp_dir, input_path)
270
+ # WHAM writes its result to out_dir/{video_name}/wham_output.pkl
271
+ with open(pth_path, "rb") as f:
272
+ import joblib
273
+ pkl_file = joblib.load(f)
274
+ subjs = len(pkl_file.keys())
275
+ if subjs < 1:
276
+ return "Failed to detect any person in the video. Please upload a new video with a higher frame rate and a clearly visible subject.", None, None
277
+
278
+
279
+ #pth_path = "wham_output.pth"
280
+
281
+ # Stage1
282
+ progress(0.4, desc=self.processing_steps[3])
283
+ stage1_path = os.path.join("prompts", "stage1.txt")
284
+ if not os.path.exists(stage1_path):
285
+ raise RuntimeError("Missing prompts/stage1.txt prompt file")
286
+ with open(stage1_path, 'r', encoding='utf-8') as f:
287
+ prompt1 = f.read()
288
+ prompt1_1 = prompt1.split("[IMAGEFLAG]")[0].strip()
289
+ prompt1_2 = prompt1.split("[IMAGEFLAG]")[1].strip()
290
+ out_stage1_1 = self.query_llm(prompt1_1, use_history=True)
291
+ out_images = self.query_llm_multimodal(text="", indexed_images=frames, use_history=True)
292
+ out_stage1_part2 = self.query_llm(prompt1_2, use_history=True)
293
+ print(out_images)
294
+ print("-------")
295
+ print(out_stage1_part2)
296
+ # Stage2
297
+ progress(0.5, desc=self.processing_steps[4])
298
+ stage2_path = os.path.join("prompts", "stage2.txt")
299
+ if not os.path.exists(stage2_path):
300
+ raise RuntimeError("Missing prompts/stage2.txt prompt file")
301
+ with open(stage2_path, 'r', encoding='utf-8') as f:
302
+ prompt2 = f.read()
303
+ prompt2 = prompt2.replace("[FRAMERATE]", str(new_fps))
304
+ max_retries = 3
305
+ out_stage2 = ""
306
+ temp_json_path = os.path.join(tmp_dir, "temp_json.json")
307
+ for attempt in range(max_retries):
308
+ out_stage2 = self.query_llm(prompt2, use_history=True)
309
+ try:
310
+ parsed = json.loads(out_stage2)
311
+ with open(temp_json_path, 'w', encoding='utf-8') as f:
312
+ json.dump(parsed, f, ensure_ascii=False, indent=2)
313
+ break
314
+ except json.JSONDecodeError:
315
+ prompt2 = "The previous output was not valid JSON. Please output only valid JSON without any extra content." + "\n" + out_stage2
316
+ if attempt == max_retries - 1:
317
+ with open(temp_json_path, 'w', encoding='utf-8') as f:
318
+ f.write(out_stage2)
319
+
320
+ # Evaluator
321
+ progress(0.6, desc=self.processing_steps[5])
322
+ evaluator_cmd = ["python", "estimator.py", pth_path, temp_json_path]
323
+ result = subprocess.run(evaluator_cmd, capture_output=True, text=True)
324
+ if result.returncode != 0:
325
+ raise RuntimeError(f"Evaluator error: {result.stderr}")
326
+ output_txt_path = os.path.join(tmp_dir, "temp_json_output.txt")
327
+ with open(output_txt_path, 'r', encoding='utf-8') as f:
328
+ evaluator_output = f.read()
329
+
330
+ # Stage3
331
+ progress(0.7, desc=self.processing_steps[6])
332
+ stage3_path = os.path.join("prompts", "stage3.txt")
333
+ if not os.path.exists(stage3_path):
334
+ raise RuntimeError("Missing prompts/stage3.txt prompt file")
335
+ with open(stage3_path, 'r', encoding='utf-8') as f:
336
+ prompt3 = f.read()
337
+ prompt3 = prompt3.replace("[RESULTS]", evaluator_output)
338
+ out_stage3 = self.query_llm(prompt3, use_history=True)
339
+
340
+ progress(1.0, desc=self.processing_steps[7])
341
+ # Return the final text, the persisted video path, and the downsample rate
342
+ return out_stage3, persistent_path, downsample_rate
343
  except Exception as e:
344
+ # On error, return three values; the video path and downsample rate are None
345
+ return "Processing error: " + str(e), None, None
346
+
347
 
 
348
  app = PoseEstimationApp()
349
 
 
350
  def create_interface():
351
+ # Predefined UI text for both languages
352
+ texts = {
353
+ "en": {
354
+ "title_md": "# 🎬 Video Pose Estimation Processing Platform",
355
+ "description_md": "Upload a video to downsample and perform pose estimation, combine multimodal LLM analysis to generate intelligent insights",
356
+ "input_settings": "## 📤 Input Settings",
357
+ "video_label": "Upload video file",
358
+ "downsample_label": "Temporal downsampling rate",
359
+ "downsample_info": "Take 1 frame every N frames and reduce frame rate. Higher rate runs faster, lower rate yields more accurate results.",
360
+ "process_btn": "🚀 Start Processing",
361
+ "clear_btn": "🔄 Clear",
362
+ "results_md": "## 📊 Processing Results",
363
+ "final_tab": "Final Result",
364
+ "final_label": "Final Comprehensive Result",
365
+ "rating_label": "Please rate the result (1–5):",
366
+ "submit_rating_btn": "Submit Rating",
367
+ "thankyou_msg": "Thank you for your feedback!",
368
+ "instructions_md": """
369
+ ## 💡 Instructions
370
+ 1. After uploading a video, the system will generate downsample.mp4 based on the downsampling rate.
371
+ 2. Run WHAM/demo.py for Pose Estimation; results are saved in output/<video_name>/wham_output.pkl.
372
+ 3. The system will automatically read prompts/stage1.txt, stage2.txt, stage3.txt; user custom prompts are not accepted.
373
+ 4. Stage1: prompts/stage1.txt can include [POSE_SUMMARY] placeholder, auto-replaced with pose summary.
374
+ 5. Stage2: prompts/stage2.txt can include [FRAMERATE] and [STAGE1_RESULT] placeholders, auto-replaced.
375
+ 6. Prompts will be forced to output JSON format for Evaluator use.
376
+ 7. After Evaluator runs, it generates output.txt; content is automatically passed to Stage3.
377
+ 8. Deployment requires HF_TOKEN environment variable set for HuggingFace access token; code uses it automatically.
378
+ 9. Ensure the project root contains prompts/stage1.txt, stage2.txt, stage3.txt as well as WHAM/demo.py and estimator.py.
379
+ """
380
+ },
381
+ "zh": {
382
+ "title_md": "# 🎬 视频姿态估计处理平台",
383
+ "description_md": "上传视频进行降采样和姿态估计,结合多模态 LLM 分析生成智能化见解",
384
+ "input_settings": "## 📤 输入设置",
385
+ "video_label": "上传视频文件",
386
+ "downsample_label": "时间降采样率",
387
+ "downsample_info": "每隔 N 帧取 1 帧并降低帧率。更高的采样率速度更快,但精度可能下降;更低采样率更准确。",
388
+ "process_btn": "🚀 开始处理",
389
+ "clear_btn": "🔄 清除",
390
+ "results_md": "## 📊 处理结果",
391
+ "final_tab": "最终结果",
392
+ "final_label": "最终综合结果",
393
+ "rating_label": "请对结果进行评分 (1–5):",
394
+ "submit_rating_btn": "提交评分",
395
+ "thankyou_msg": "感谢您的反馈!",
396
+ "instructions_md": """
397
+ ## 💡 使用说明
398
+ 1. 上传视频后,系统会根据降采样率生成 downsample.mp4。
399
+ 2. 运行 WHAM/demo.py 进行姿态估计;结果保存在 output/<video_name>/wham_output.pkl。
400
+ 3. 系统会自动读取 prompts/stage1.txt、stage2.txt、stage3.txt;不接受用户自定义提示。
401
+ 4. Stage1: prompts/stage1.txt 可包含 [POSE_SUMMARY] 占位符,将被自动替换。
402
+ 5. Stage2: prompts/stage2.txt 可包含 [FRAMERATE] 和 [STAGE1_RESULT] 占位符,将被自动替换。
403
+ 6. 提示将被强制输出 JSON 格式以供 Evaluator 使用。
404
+ 7. Evaluator 运行后会生成 output.txt;内容会自动传递到 Stage3。
405
+ 8. 部署需要设置 HF_TOKEN 环境变量以获得 HuggingFace 访问令牌;代码会自动使用。
406
+ 9. 确保项目根目录下包含 prompts/stage1.txt、stage2.txt、stage3.txt 以及 WHAM/demo.py、estimator.py。
407
+ """
408
+ }
409
+ }
410
+
411
  with gr.Blocks(
412
  theme=gr.themes.Soft(),
413
+ title="Video Pose Estimation Processing Platform",
414
  css="""
415
+ .gradio-container { max-width: 1200px !important; }
416
+ .tab-nav { background: linear-gradient(90deg, #667eea, #764ba2) !important; }
 
 
 
 
417
  """
418
  ) as demo:
419
+ # Language state
420
+ lang_state = gr.State("en")
421
+ # Hidden state: the most recently processed video path and its downsample rate
422
+ last_video_path = gr.State(None)
423
+ last_downsample_rate = gr.State(None)
424
+
425
+ # Language toggle button
426
+ lang_btn = gr.Button("中文")  # initial language is English, so the button offers "中文"
427
+ # Header Markdown
428
+ header_md = gr.Markdown(texts["en"]["title_md"])
429
+ desc_md = gr.Markdown(texts["en"]["description_md"])
430
+
431
  with gr.Row():
432
  with gr.Column(scale=1):
433
+ input_md = gr.Markdown(texts["en"]["input_settings"])
434
+ video_input = gr.Video(label=texts["en"]["video_label"], sources=["upload"], height=300)
 
 
 
 
 
 
435
  with gr.Row():
436
+ downsample_rate = gr.Slider(minimum=1, maximum=30, value=10, step=1,
437
+ label=texts["en"]["downsample_label"],
438
+ info=texts["en"]["downsample_info"])
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  with gr.Row():
440
+ process_btn = gr.Button(texts["en"]["process_btn"], variant="primary", size="lg")
441
+ clear_btn = gr.Button(texts["en"]["clear_btn"], variant="secondary")
442
+
443
+ with gr.Column(scale=2):
444
+ results_md = gr.Markdown(texts["en"]["results_md"])
445
+ with gr.Tabs() as tabs:
446
+ with gr.TabItem(texts["en"]["final_tab"]):
447
+ final_output = gr.Textbox(label=texts["en"]["final_label"], lines=12, max_lines=20)
448
+ # New: preview of the downsampled video
449
+ downsampled_video_output = gr.Video(label="Downsampled Video", interactive=False)
450
+ # New: rating slider and submit button
451
+ rating_slider = gr.Slider(minimum=1, maximum=5, step=1,
452
+ label=texts["en"]["rating_label"])
453
+ submit_rating_btn = gr.Button(value=texts["en"]["submit_rating_btn"])
454
+ # Shows the thank-you message after submission
455
+ thankyou_text = gr.Markdown("") # initially empty
456
+
457
+ # Language toggle callback
458
+ def toggle_language(current_lang):
459
+ # current_lang: "en" or "zh"; returns the new current_lang plus a set of component updates
460
+ new_lang = "zh" if current_lang == "en" else "en"
461
+ t = texts[new_lang]
462
+ # Update the text of every component
463
+ updates = {
464
+ lang_state: new_lang,
465
+ header_md: gr.update(value=t["title_md"]),
466
+ desc_md: gr.update(value=t["description_md"]),
467
+ input_md: gr.update(value=t["input_settings"]),
468
+ video_input: gr.update(label=t["video_label"]),
469
+ downsample_rate: gr.update(label=t["downsample_label"], info=t["downsample_info"]),
470
+ process_btn: gr.update(value=t["process_btn"]),
471
+ clear_btn: gr.update(value=t["clear_btn"]),
472
+ results_md: gr.update(value=t["results_md"]),
473
+ final_output: gr.update(label=t["final_label"]),
474
+ rating_slider: gr.update(label=t["rating_label"]),
475
+ submit_rating_btn: gr.update(value=t["submit_rating_btn"]),
476
+ thankyou_text: gr.update(value="") # clear the thank-you message when switching languages
477
+ }
478
+ # The toggle button's own label also changes: show "中文" while in English, "English" while in Chinese
479
+ btn_text = "English" if new_lang == "zh" else "中文"
480
+ updates[lang_btn] = gr.update(value=btn_text)
481
+ return updates
482
+
483
+ lang_btn.click(fn=toggle_language,
484
+ inputs=[lang_state],
485
+ outputs=[lang_state,
486
+ header_md, desc_md,
487
+ input_md, video_input, downsample_rate, process_btn, clear_btn,
488
+ results_md, final_output, rating_slider, submit_rating_btn, thankyou_text,
489
+ lang_btn])
490
+
491
+ # Video-processing callback: process_video returns (result_text, video_path, downsample_rate)
492
+ def on_process(video, rate):
493
+ result_text, video_path, dr = app.process_video(video, rate)
494
+ # Update the stored state
495
+ # On success, video_path is not None
496
+ return result_text, video_path, video_path, dr, gr.update(value=None), gr.update(value="")
497
+
498
+ # Note: the order of outputs matches on_process's return values
499
+ # outputs: final_output (text), downsampled_video_output (video), last_video_path (state), last_downsample_rate (state), rating_slider (reset), thankyou_text (cleared)
500
+ process_btn.click(fn=on_process,
501
+ inputs=[video_input, downsample_rate],
502
+ outputs=[final_output, downsampled_video_output,
503
+ last_video_path, last_downsample_rate,
504
+ rating_slider, thankyou_text])
505
+
506
+ # Clear button: reset everything
507
+ def on_clear(current_lang):
508
+ return None, 10, None, None, gr.update(value=None), gr.update(value=""), "中文" if current_lang == "en" else "English"
509
+ # Return order: video_input, downsample_rate, last_video_path, last_downsample_rate, rating_slider, thankyou_text, lang_btn
510
+ clear_btn.click(fn=on_clear, inputs=[lang_state],
511
+ outputs=[video_input, downsample_rate,
512
+ last_video_path, last_downsample_rate,
513
+ rating_slider, thankyou_text, lang_btn])
514
+
515
+ # Rating-submission callback: reads last_video_path, last_downsample_rate and the slider value
516
+ def save_rating(video_path, dr, rating, current_lang):
517
+ # video_path may be None, so check it
518
+ if (video_path is None) or (dr is None):
519
+ # No valid video yet; just acknowledge
520
+ msg = texts[current_lang]["thankyou_msg"]
521
+ return msg
522
+ # Make sure ratings.json exists under the results directory
523
+ ratings_file = "results/ratings.json"
524
+ os.makedirs("results", exist_ok=True)
525
+ record = {
526
+ "timestamp": datetime.datetime.now().isoformat(),
527
+ "video_path": video_path,
528
+ "downsample_rate": dr,
529
+ "rating": int(rating)
530
+ }
531
+ # If the file exists, read it and append; otherwise start a new list
532
+ try:
533
+ if os.path.exists(ratings_file):
534
+ with open(ratings_file, "r", encoding="utf-8") as f:
535
+ data = json.load(f)
536
+ if not isinstance(data, list):
537
+ data = []
538
+ else:
539
+ data = []
540
+ except Exception:
541
+ data = []
542
+ data.append(record)
543
+ # Write back
544
+ with open(ratings_file, "w", encoding="utf-8") as f:
545
+ json.dump(data, f, ensure_ascii=False, indent=2)
546
+ # Return the thank-you message
547
+ return texts[current_lang]["thankyou_msg"]
548
+
549
+ # Bind the rating submission button
550
+ submit_rating_btn.click(fn=save_rating,
551
+ inputs=[last_video_path, last_downsample_rate, rating_slider, lang_state],
552
+ outputs=[thankyou_text])
553
+
554
+ # Footer instructions
555
+ instructions_md = gr.Markdown(texts["en"]["instructions_md"])
556
+ # NOTE: instructions_md is not included in toggle_language's outputs, so it currently stays in English
557
+ return demo
558
 
 
559
  if __name__ == "__main__":
560
  demo = create_interface()
561
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True, show_error=True)
 
 
 
 
 
estimator.py ADDED
@@ -0,0 +1,408 @@
1
+ from typing import Tuple, List
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from smplx.lbs import batch_rodrigues
5
+
6
+ import json
7
+ from typing import Dict
8
+ import numpy as np
9
+ import joblib
10
+ import sys
11
+ import argparse
12
+
13
+ def get_rotated_axes(global_orient):
14
+ """
15
+ Input:
16
+ global_orient: [T, 3] numpy array (axis-angle)
17
+ Output:
18
+ rotated_axes: dict of [T, 3] numpy arrays for X, Y, Z
19
+ """
20
+ R = batch_rodrigues(torch.tensor(global_orient).float()) # [T, 3, 3]
21
+
22
+ # Local unit coordinate axes
23
+ x_local = torch.tensor([1.0, 0.0, 0.0]) # X axis: right → left
24
+ y_local = torch.tensor([0.0, 1.0, 0.0]) # Y axis: down → up
25
+ z_local = torch.tensor([0.0, 0.0, 1.0]) # Z axis: back → front
26
+
27
+ # Apply the rotation
28
+ x_world = torch.matmul(R, x_local) # [T, 3]
29
+ y_world = torch.matmul(R, y_local)
30
+ z_world = torch.matmul(R, z_local)
31
+
32
+ return {
33
+ 'x': x_world.numpy(),
34
+ 'y': y_world.numpy(),
35
+ 'z': z_world.numpy()
36
+ }
37
+
38
+
39
+
40
+ class BaseEvaluator:
41
+ def __init__(self):
42
+ pass
43
+
44
+ def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
45
+ raise NotImplementedError
46
+
47
+
48
+ class ThreeJointAngleEvaluator(BaseEvaluator):
49
+ def __init__(self, joint_indices: Tuple[int, int, int], threshold: float, greater_than: bool = True):
50
+ super().__init__()
51
+ self.a, self.b, self.c = joint_indices
52
+ self.threshold = threshold
53
+ self.greater_than = greater_than
54
+
55
+ def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
56
+ ba = F.normalize(joints[:, self.a] - joints[:, self.b], dim=-1)
57
+ bc = F.normalize(joints[:, self.c] - joints[:, self.b], dim=-1)
58
+ cos_angle = (ba * bc).sum(dim=-1).clamp(-1.0, 1.0)
59
+ angles = torch.acos(cos_angle) * 180.0 / torch.pi
60
+ return angles > self.threshold if self.greater_than else angles < self.threshold
61
+
62
+
63
+ class VectorAngleEvaluator(BaseEvaluator):
64
+ def __init__(self, pair1: Tuple[int, int], pair2: Tuple[int, int], threshold: float, less_than=True):
65
+ super().__init__()
66
+ self.a1, self.a2 = pair1
67
+ self.b1, self.b2 = pair2
68
+ self.threshold = threshold
69
+ self.less_than = less_than
70
+
71
+ def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
72
+ v1 = F.normalize(joints[:, self.a1] - joints[:, self.a2], dim=-1)
73
+ v2 = F.normalize(joints[:, self.b1] - joints[:, self.b2], dim=-1)
74
+ angle = torch.acos((v1 * v2).sum(dim=-1).clamp(-1.0, 1.0)) * 180.0 / torch.pi
75
+ return angle < self.threshold if self.less_than else angle > self.threshold
76
+
77
+
78
+ class SingleAxisComparisonEvaluator(BaseEvaluator):
79
+ def __init__(self, joint_a: int, joint_b: int, axis: str, greater_than=True):
80
+ super().__init__()
81
+ self.joint_a = joint_a
82
+ self.joint_b = joint_b
83
+ self.axis = axis
84
+ self.greater_than = greater_than
85
+
86
+ def evaluate(self, joints: torch.Tensor, global_orient, **kwargs) -> torch.Tensor:
87
+ T = joints.shape[0]
88
+ assert self.axis in ["x", "y", "z"]
89
+ rotated_axes = get_rotated_axes(global_orient)
90
+ assert rotated_axes[self.axis].shape == (T, 3)
91
+
92
+ axis_tensor = torch.tensor(rotated_axes[self.axis], dtype=joints.dtype, device=joints.device) # [T, 3]
93
+
94
+ vec_a = joints[:, self.joint_a, :] # [T, 3]
95
+ vec_b = joints[:, self.joint_b, :] # [T, 3]
96
+
97
+ # Project onto the per-frame axis direction
98
+ a_proj = torch.sum(vec_a * axis_tensor, dim=1) # [T]
99
+ b_proj = torch.sum(vec_b * axis_tensor, dim=1) # [T]
100
+
101
+ return a_proj > b_proj if self.greater_than else a_proj < b_proj
102
+
103
+
104
+ class JointDistanceEvaluator(BaseEvaluator):
105
+ def __init__(self, joint_a: int, joint_b: int, threshold: float, greater_than=True):
106
+ super().__init__()
107
+ self.joint_a = joint_a
108
+ self.joint_b = joint_b
109
+ self.threshold = threshold
110
+ self.greater_than = greater_than
111
+
112
+ def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
113
+ dist = torch.norm(joints[:, self.joint_a] - joints[:, self.joint_b], dim=-1)
114
+ return dist > self.threshold if self.greater_than else dist < self.threshold
115
+
116
+
117
+ class RelativeOffsetDirectionEvaluator(BaseEvaluator):
118
+ def __init__(self, joint_a: int, joint_b: int, axis: str, threshold: float, greater_than=True):
119
+ super().__init__()
120
+ assert axis in ["x", "y", "z"]
121
+ self.joint_a = joint_a
122
+ self.joint_b = joint_b
123
+ self.axis = axis
124
+ self.threshold = threshold
125
+ self.greater_than = greater_than
126
+
127
+ def evaluate(self, joints: torch.Tensor, global_orient, **kwargs) -> torch.Tensor:
128
+ T = joints.shape[0]
129
+ rotated_axes = get_rotated_axes(global_orient)
130
+ assert self.axis in rotated_axes
131
+ assert rotated_axes[self.axis].shape == (T, 3)
132
+
133
+ axis_tensor = torch.tensor(rotated_axes[self.axis], dtype=joints.dtype, device=joints.device) # [T, 3]
134
+
135
+ offset_vec = joints[:, self.joint_a, :] - joints[:, self.joint_b, :] # [T, 3]
136
+ projection = torch.sum(offset_vec * axis_tensor, dim=1) # [T]
137
+
138
+ return projection > self.threshold if self.greater_than else projection < self.threshold
139
+
140
+
141
+ class VelocityThresholdEvaluator(BaseEvaluator):
142
+ def __init__(self, joint: int, axis: str, threshold: float, greater_than=True):
143
+ super().__init__()
144
+ assert axis in ["x", "y", "z"]
145
+ self.joint = joint
146
+ self.axis = axis
147
+ self.threshold = threshold
148
+ self.greater_than = greater_than
149
+
150
+ def evaluate(self, joints: torch.Tensor, global_orient, dt: float = 1.0, **kwargs) -> torch.Tensor:
151
+ T = joints.shape[0]
152
+ rotated_axes = get_rotated_axes(global_orient)
153
+ assert self.axis in rotated_axes
154
+ assert rotated_axes[self.axis].shape == (T, 3)
155
+
156
+ axis_tensor = torch.tensor(rotated_axes[self.axis], dtype=joints.dtype, device=joints.device) # [T, 3]
157
+
158
+ # Project the joint position onto the current axis
159
+ joint_pos = joints[:, self.joint, :] # [T, 3]
160
+ projection = torch.sum(joint_pos * axis_tensor, dim=1) # [T]
161
+
162
+ # Compute velocity via finite differences in time
163
+ velocity = (projection[1:] - projection[:-1]) / dt # [T-1]
164
+
165
+ # Compare against the threshold
166
+ result = velocity > self.threshold if self.greater_than else velocity < self.threshold
167
+
168
+ # Note: the result has length T-1 (one element shorter than the input)
169
+ return result
170
+
171
+
172
+ class AccelerationThresholdEvaluator(BaseEvaluator):
173
+ def __init__(self, joint: int, axis: str, threshold: float, greater_than=True):
174
+ super().__init__()
175
+ assert axis in ["x", "y", "z"]
176
+ self.joint = joint
177
+ self.axis = axis
178
+ self.threshold = threshold
179
+ self.greater_than = greater_than
180
+
181
+ def evaluate(self, joints: torch.Tensor, global_orient, dt: float = 1.0, **kwargs) -> torch.Tensor:
182
+ T = joints.shape[0]
183
+ rotated_axes = get_rotated_axes(global_orient)
184
+ assert self.axis in rotated_axes
185
+ assert rotated_axes[self.axis].shape == (T, 3)
186
+
187
+ axis_tensor = torch.tensor(rotated_axes[self.axis], dtype=joints.dtype, device=joints.device) # [T, 3]
188
+ joint_pos = joints[:, self.joint, :] # [T, 3]
189
+ projection = torch.sum(joint_pos * axis_tensor, dim=1) # [T]
190
+
191
+ velocity = (projection[1:] - projection[:-1]) / dt # [T-1]
192
+ acceleration = (velocity[1:] - velocity[:-1]) / dt # [T-2]
193
+
194
+ result = acceleration > self.threshold if self.greater_than else acceleration < self.threshold
195
+ return result # shape: [T-2]
196
+
197
+
198
+ class AngleRangeEvaluator(BaseEvaluator):
199
+ def __init__(self, joint_indices: Tuple[int, int, int], threshold: float, greater_than=True):
200
+ super().__init__()
201
+ self.a, self.b, self.c = joint_indices
202
+ self.threshold = threshold
203
+ self.greater_than = greater_than
204
+
205
+ def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
206
+ ba = F.normalize(joints[:, self.a] - joints[:, self.b], dim=-1)
207
+ bc = F.normalize(joints[:, self.c] - joints[:, self.b], dim=-1)
208
+ cos_angle = (ba * bc).sum(dim=-1).clamp(-1.0, 1.0)
209
+ angles = torch.acos(cos_angle) * 180.0 / torch.pi
210
+ motion_range = angles.max() - angles.min()
211
+ return torch.tensor([motion_range > self.threshold]) if self.greater_than else torch.tensor([motion_range < self.threshold])
212
+
213
+
214
+ class AngleChangeEvaluator(BaseEvaluator):
215
+ def __init__(self, joint_indices: Tuple[int, int, int], frame1: int, frame2: int, threshold: float, greater_than=True):
216
+ super().__init__()
217
+ self.a, self.b, self.c = joint_indices
218
+ self.frame1 = frame1
219
+ self.frame2 = frame2
220
+ self.threshold = threshold
221
+ self.greater_than = greater_than
222
+
223
+ def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
224
+ def compute_angle(frame_idx):
225
+ ba = F.normalize(joints[frame_idx, self.a] - joints[frame_idx, self.b], dim=-1)
226
+ bc = F.normalize(joints[frame_idx, self.c] - joints[frame_idx, self.b], dim=-1)
227
+ cos_angle = (ba * bc).sum().clamp(-1.0, 1.0)
228
+ return torch.acos(cos_angle) * 180.0 / torch.pi
229
+
230
+ angle_diff = torch.abs(compute_angle(self.frame2) - compute_angle(self.frame1))
231
+ return torch.tensor([angle_diff > self.threshold]) if self.greater_than else torch.tensor([angle_diff < self.threshold])
232
+
233
+
234
+
235
+ # Joint name to index mapping (for SMPL 24-joint model, common names)
236
+ SMPL_JOINT_NAMES = {
237
+ "pelvis": 0,
238
+ "left_hip": 1,
239
+ "right_hip": 2,
240
+ "spine1": 3,
241
+ "left_knee": 4,
242
+ "right_knee": 5,
243
+ "spine2": 6,
244
+ "left_ankle": 7,
245
+ "right_ankle": 8,
246
+ "spine3": 9,
247
+ "left_foot": 10,
248
+ "right_foot": 11,
249
+ "neck": 12,
250
+ "left_collar": 13,
251
+ "right_collar": 14,
252
+ "head": 15,
253
+ "left_shoulder": 16,
254
+ "right_shoulder": 17,
255
+ "left_elbow": 18,
256
+ "right_elbow": 19,
257
+ "left_wrist": 20,
258
+ "right_wrist": 21,
259
+ "left_hand": 22,
260
+ "right_hand": 23,
261
+ }
262
+
263
+ # Mapping from JSON "type" to class constructor
264
+ EVALUATOR_CLASSES = {
265
+ "ThreeJointAngle": ThreeJointAngleEvaluator,
266
+ "VectorAngle": VectorAngleEvaluator,
267
+ "SingleAxisComparison": SingleAxisComparisonEvaluator,
268
+ "JointDistance": JointDistanceEvaluator,
269
+ "RelativeOffsetDirection": RelativeOffsetDirectionEvaluator,
270
+ "VelocityThreshold": VelocityThresholdEvaluator,
271
+ "AccelerationThreshold": AccelerationThresholdEvaluator,
272
+ "AngleRange": AngleRangeEvaluator,
273
+ "AngleChange": AngleChangeEvaluator,
274
+ # PositionRange can reuse RelativeOffset with a max-min wrapper
275
+ }
276
+
277
+
278
+ def get_joint_index(name: str) -> int:
279
+ if name not in SMPL_JOINT_NAMES:
280
+ raise ValueError(f"Unknown joint name: {name}")
281
+ return SMPL_JOINT_NAMES[name]
282
+
283
+
284
+ def build_evaluator_from_json(json_data: Dict) -> BaseEvaluator:
285
+ etype = json_data["type"]
286
+
287
+
288
+ if etype == "ThreeJointAngle":
289
+ a = get_joint_index(json_data["joint_a"])
290
+ b = get_joint_index(json_data["joint_b"])
291
+ c = get_joint_index(json_data["joint_c"])
292
+ return ThreeJointAngleEvaluator((a, b, c), json_data["threshold"], json_data.get("greater_than", True))
293
+
294
+ elif etype == "VectorAngle":
295
+ a1 = get_joint_index(json_data["joint_a1"])
296
+ a2 = get_joint_index(json_data["joint_a2"])
297
+ b1 = get_joint_index(json_data["joint_b1"])
298
+ b2 = get_joint_index(json_data["joint_b2"])
299
+ return VectorAngleEvaluator((a1, a2), (b1, b2), json_data["threshold"], json_data.get("less_than", True))
300
+
301
+ elif etype == "SingleAxisComparison":
302
+ a = get_joint_index(json_data["joint_a"])
303
+ b = get_joint_index(json_data["joint_b"])
304
+ return SingleAxisComparisonEvaluator(a, b, json_data["axis"], json_data.get("greater_than", True))
305
+
306
+ elif etype == "JointDistance":
307
+ a = get_joint_index(json_data["joint_a"])
308
+ b = get_joint_index(json_data["joint_b"])
309
+ return JointDistanceEvaluator(a, b, json_data["threshold"], json_data.get("greater_than", True))
310
+
311
+ elif etype == "RelativeOffsetDirection":
312
+ a = get_joint_index(json_data["joint_a"])
313
+ b = get_joint_index(json_data["joint_b"])
314
+ return RelativeOffsetDirectionEvaluator(a, b, json_data["axis"], json_data["threshold"], json_data.get("greater_than", True))
315
+
316
+ elif etype == "VelocityThreshold":
317
+ j = get_joint_index(json_data["joint"])
318
+ return VelocityThresholdEvaluator(j, json_data["axis"], json_data["threshold"], json_data.get("greater_than", True))
319
+
320
+ elif etype == "AccelerationThreshold":
321
+ j = get_joint_index(json_data["joint"])
322
+ return AccelerationThresholdEvaluator(j, json_data["axis"], json_data["threshold"], json_data.get("greater_than", True))
323
+
324
+ elif etype == "AngleRange":
325
+ a = get_joint_index(json_data["joint_a"])
326
+ b = get_joint_index(json_data["joint_b"])
327
+ c = get_joint_index(json_data["joint_c"])
328
+ return AngleRangeEvaluator((a, b, c), json_data["threshold"], json_data.get("greater_than", True))
329
+
330
+ elif etype == "AngleChange":
331
+ a = get_joint_index(json_data["joint_a"])
332
+ b = get_joint_index(json_data["joint_b"])
333
+ c = get_joint_index(json_data["joint_c"])
334
+ return AngleChangeEvaluator((a, b, c), json_data["frame1"], json_data["frame2"], json_data["threshold"], json_data.get("greater_than", True))
335
+
336
+ else:
337
+ raise ValueError(f"Unknown evaluator type: {etype}")
338
+
339
+
340
+
341
+ # Main function: load motion tensor and json, return evaluation result
342
+ def evaluate_motion_from_json(
343
+ json_path: str,
344
+ motion_tensor: torch.Tensor,
345
+ global_orient_tensor: torch.Tensor = None
346
+ ) -> Dict[str, List[bool]]:
347
+ with open(json_path, "r") as f:
348
+ configs = json.load(f)
349
+
350
+ results = {}
351
+ for idx, cfg in enumerate(configs):
352
+ try:
353
+ evalr = build_evaluator_from_json(cfg)
354
+ except Exception: # skip rule configs that fail to build
355
+ continue
356
+ name = cfg.get("name", f"eval_{idx}")
357
+
358
+ # Slice the requested frame range
359
+ if "start_frame" in cfg and "end_frame" in cfg:
360
+ s, e = cfg["start_frame"], cfg["end_frame"]
361
+ seg = motion_tensor[s - 1 : e] # [T_seg, J, 3]
362
+ orient_seg = None
363
+ if global_orient_tensor is not None:
364
+ orient_seg = global_orient_tensor[s - 1 : e]
365
+ elif "frame" in cfg:
366
+ # Backward compatibility with the old single-frame format
367
+ f0 = cfg["frame"]
368
+ seg = motion_tensor[f0 - 1 : f0]
369
+ orient_seg = None
370
+ if global_orient_tensor is not None:
371
+ orient_seg = global_orient_tensor[f0 - 1 : f0]
372
+ else:
373
+ seg = motion_tensor
374
+ orient_seg = None
375
+ if global_orient_tensor is not None:
376
+ orient_seg = global_orient_tensor
377
+
378
+ # Run the evaluator
379
+ if orient_seg is not None:
380
+ out = evalr.evaluate(seg, global_orient=orient_seg)
381
+ else:
382
+ out = evalr.evaluate(seg)
383
+ results[name] = out.tolist()
384
+
385
+ # Save results next to the JSON config
386
+ out_path = json_path.replace(".json", "_output.txt")
387
+ with open(out_path, "w") as f:
388
+ json.dump(results, f)
389
+
390
+ return results
391
+
392
+ if __name__ == "__main__":
393
+ pkl_file_path = sys.argv[1]
394
+ json_file_path = sys.argv[2]
395
+
396
+
397
+ with open(pkl_file_path, 'rb') as f:
398
+ pose = joblib.load(f)
399
+
400
+ global_orient = pose[0]['pose_world'][:, :3]
401
+ global_orient_tensor = torch.from_numpy(global_orient)
402
+
403
+ pose = pose[0]['joint'].reshape(-1, 45, 3)[:, :24, :]
404
+ pose = torch.from_numpy(pose)
405
+
406
+ evaluate_motion_from_json(json_file_path, pose, global_orient_tensor)
407
+
408
+
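+ # Example usage (illustrative): a minimal rules file this script can evaluate.
+ # rules.json:
+ #   [{"type": "ThreeJointAngle", "name": "elbow_straightens",
+ #     "start_frame": 10, "end_frame": 20,
+ #     "joint_a": "left_shoulder", "joint_b": "left_elbow", "joint_c": "left_wrist",
+ #     "threshold": 160, "greater_than": true}]
+ # Run:  python estimator.py wham_output.pkl rules.json
+ # Per-rule boolean lists are written next to the JSON file (here: rules_output.txt).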
prompts/stage1.txt ADDED
@@ -0,0 +1,18 @@
1
+ You are a sports motion analyst and coach, helping to evaluate the technical correctness of an athletic motion based on 10 sampled keyframes from a video (e.g., of a tennis backhand). I will provide you with several frame images or descriptions and will tell you once all the frames have been uploaded. Until then, please wait; you may analyze and summarize the content of each image to aid later understanding.
2
+
3
+ [IMAGEFLAG]
4
+
5
+ All the frames are uploaded, please analyze the frames based on my further instructions.
6
+ ---
7
+
8
+ ### 🥇 Step 1: Phase Assignment
9
+
10
+ Please assign each of the frames to one of several **action phases** (e.g., "ready position", "backswing", "forward swing", "contact", "follow through"). Label each frame like:
11
+ - Frame 5: [phase]
12
+ - Frame 10: [phase]
13
+ ...
14
+ - Frame 45: [phase]
15
+
16
+ After phase labeling, determine the **segment boundaries**: for every point where the phase changes (e.g., Frame 3→4), treat it as the **midpoint of a transition segment**. Your job is to design evaluation rules for each segment.
17
+
18
+
prompts/stage2.txt ADDED
@@ -0,0 +1,85 @@
1
+ ### 🧠 Step 2: Generate Evaluation Rules
2
+
3
+ For each phase, please output a list of JSON rules that check whether this phase was performed correctly.
4
+
5
+ Distances are in meters and the frame rate is [FRAMERATE]. Each rule must use one of the 10 allowed evaluator types (defined below) and must use **SMPL joint names**. Output the rules as a JSON list, following this format:
6
+
7
+ ---
8
+
9
+ **Coordinate System (negative → positive):**
10
+ - **X axis**: right → left
11
+ - **Y axis**: down → up
12
+ - **Z axis**: back → front
13
+
14
+ ### 🎯 Available evaluator types:
15
+
16
+ 1. ThreeJointAngle
17
+ → Checks angle(A, B, C)
18
+ Fields: phase, start_frame, end_frame, joint_a, joint_b, joint_c, threshold, greater_than
19
+
20
+ 2. VectorAngle
21
+ → Angle between two joint vectors
22
+ Fields: phase, start_frame, end_frame, joint_a1, joint_a2, joint_b1, joint_b2, threshold, less_than
23
+
24
+ 3. SingleAxisComparison
25
+ → Compares one joint's value vs another's along x/y/z
26
+ Fields: phase, start_frame, end_frame, joint_a, joint_b, axis, greater_than
27
+
28
+ 4. JointDistance
29
+ → Distance between two joints
30
+ Fields: phase, start_frame, end_frame, joint_a, joint_b, threshold, greater_than
31
+
32
+ 5. RelativeOffsetDirection
33
+ → Whether joint A is in front/above/etc. of joint B
34
+ Fields: phase, start_frame, end_frame, joint_a, joint_b, axis, threshold, greater_than
35
+
36
+ 6. VelocityThreshold
37
+ → Movement speed along a joint's axis
38
+ Fields: phase, start_frame, end_frame, joint, axis, threshold, greater_than
39
+
40
+ 7. AccelerationThreshold
41
+ → Change in joint velocity
42
+ Fields: phase, start_frame, end_frame, joint, axis, threshold, greater_than
43
+
44
+ 8. AngleRange
45
+ → Angle variation over time
46
+ Fields: phase, start_frame, end_frame, joint_a, joint_b, joint_c, threshold, greater_than
47
+
48
+ 9. AngleChange
49
+ → Angle difference between two specific frames
50
+ Fields: phase, start_frame, end_frame, joint_a, joint_b, joint_c, frame1, frame2, threshold, greater_than
51
+
52
+ 10. PositionRange
53
+ → Spatial displacement of a joint over time
54
+ Fields: phase, start_frame, end_frame, joint, axis, threshold, greater_than
55
+
56
+ Note that the axis must be one of [x, y, z]. All distances are in meters.
57
+ ---
58
+
59
+ ### ✅ SMPL joint names:
60
+
61
+ pelvis, left_hip, right_hip, spine1, left_knee, right_knee, spine2, left_ankle, right_ankle,
62
+ spine3, left_foot, right_foot, neck, left_collar, right_collar, head,
63
+ left_shoulder, right_shoulder, left_elbow, right_elbow, left_wrist, right_wrist, left_hand, right_hand
64
+
65
+ ---
66
+
67
+ ### ✅ Output format:
68
+
69
+ You MUST output only a JSON list, with no explanatory text, no markdown, no code fences. Output must start with [ and end with ].
70
+ Example:
71
+ [
72
+ {
73
+ "type": "ThreeJointAngle",
74
+ "name": "elbow_straightens",
75
+ "start_frame": 10,
76
+ "end_frame": 20,
77
+ "joint_a": "left_shoulder",
78
+ "joint_b": "left_elbow",
79
+ "joint_c": "left_wrist",
80
+ "threshold": 160,
81
+ "greater_than": true
82
+ },
83
+ ...
84
+ ]
85
+
prompts/stage3.txt ADDED
@@ -0,0 +1,12 @@
1
+ ### 🧩 Step 3: Interpret binary detection results
2
+ After you've generated evaluation rules, I will run them through a system and return a dict of boolean results, where each key is a rule name and each value is a list of True/False results per frame.
3
+
4
+ Your job:
5
+
6
+ Analyze which rules failed (False in any frame).
7
+
8
+ Provide concise, practical feedback about the motion technique (e.g., "Elbow should be more extended at contact", "Backswing was too short"). Use clear language suitable for an athlete or coach.
9
+
10
+
11
+ The boolean results:
12
+ [RESULTS]
requirements.txt CHANGED
@@ -1,6 +1,35 @@
1
  gradio>=4.0.0
2
- opencv-python-headless>=4.8.0
3
- numpy>=1.24.0
4
- Pillow>=9.5.0
5
  requests>=2.31.0
6
- python-multipart>=0.0.6
1
+ # Core Python libraries
2
+ numpy==1.24.3
3
+ yacs
4
+ joblib
5
+ scikit-image
6
+ opencv-python
7
+ imageio[ffmpeg]
8
+ matplotlib
9
+ tensorboard
10
+ smplx
11
+ progress
12
+ einops
13
+ munkres
14
+ xtcocotools>=1.8
15
+ loguru
16
+ tqdm
17
+ ultralytics
18
+ gdown==4.6.0
19
+ smplx
20
  gradio>=4.0.0
21
+ python-multipart>=0.0.6
 
 
22
  requests>=2.31.0
23
+ Pillow>=9.5.0
24
+ transformers>=4.28.0
25
+ accelerate
26
+ torch==2.7.0
27
+ torchvision==0.22.0
28
+ torchaudio
29
+ chumpy @ git+https://github.com/mattloper/chumpy
30
+ mmcv==1.3.9
31
+ timm==0.4.9
32
+ setuptools==59.5.0
33
+
34
+ -e ./third-party/ViTPose
35
+ --find-links https://data.pyg.org/whl/torch-2.7.0+cpu.html
+ torch-scatter
smplx/smpl/SMPL_NEUTRAL.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4924f235e63f7c5d5b690acedf736419c2edb846a2d69fc0956169615fa75688
3
+ size 247186228
tracking_results.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aac30cfd9715393f0d600a72c27687ff3866c9a7969cc0d3b8dfc93a84642d82
3
+ size 66