# SportsCoaching/estimator.py
import json
import sys
from typing import Dict, List, Tuple
import joblib
import numpy as np
import torch
import torch.nn.functional as F
from smplx.lbs import batch_rodrigues
def get_rotated_axes(global_orient):
"""
输入:
global_orient: [T, 3] numpy array (axis-angle)
输出:
rotated_axes: dict of [T, 3] numpy arrays for X, Y, Z
"""
    R = batch_rodrigues(torch.as_tensor(global_orient).float())  # [T, 3, 3]
    # Local unit axes of the body frame
    x_local = torch.tensor([1.0, 0.0, 0.0])  # X axis: right -> left
    y_local = torch.tensor([0.0, 1.0, 0.0])  # Y axis: down -> up
    z_local = torch.tensor([0.0, 0.0, 1.0])  # Z axis: back -> front
    # Rotate the local axes into world coordinates
x_world = torch.matmul(R, x_local) # [T, 3]
y_world = torch.matmul(R, y_local)
z_world = torch.matmul(R, z_local)
return {
'x': x_world.numpy(),
'y': y_world.numpy(),
'z': z_world.numpy()
}
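# Minimal usage sketch for get_rotated_axes (a hedged example; the dummy identity
# rotations below are illustrative, not values from the original pipeline):
#   T = 4
#   dummy_orient = np.zeros((T, 3), dtype=np.float32)  # axis-angle, identity rotation per frame
#   axes = get_rotated_axes(dummy_orient)
#   # axes['x'], axes['y'], axes['z'] each have shape (T, 3)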
class BaseEvaluator:
def __init__(self):
pass
def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
raise NotImplementedError
class ThreeJointAngleEvaluator(BaseEvaluator):
def __init__(self, joint_indices: Tuple[int, int, int], threshold: float, greater_than: bool = True):
super().__init__()
self.a, self.b, self.c = joint_indices
self.threshold = threshold
self.greater_than = greater_than
def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
ba = F.normalize(joints[:, self.a] - joints[:, self.b], dim=-1)
bc = F.normalize(joints[:, self.c] - joints[:, self.b], dim=-1)
cos_angle = (ba * bc).sum(dim=-1).clamp(-1.0, 1.0)
angles = torch.acos(cos_angle) * 180.0 / torch.pi
return angles > self.threshold if self.greater_than else angles < self.threshold
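# Construction sketch for the evaluators in this file (a hedged example; the joint
# triple and the 120-degree threshold are illustrative, not values used elsewhere):
#   knee_angle = ThreeJointAngleEvaluator((1, 4, 7),  # left_hip, left_knee, left_ankle (see SMPL_JOINT_NAMES below)
#                                         threshold=120.0, greater_than=True)
#   flags = knee_angle.evaluate(joints)  # joints: [T, 24, 3] tensor -> [T] bool tensor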
class VectorAngleEvaluator(BaseEvaluator):
def __init__(self, pair1: Tuple[int, int], pair2: Tuple[int, int], threshold: float, less_than=True):
super().__init__()
self.a1, self.a2 = pair1
self.b1, self.b2 = pair2
self.threshold = threshold
self.less_than = less_than
def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
v1 = F.normalize(joints[:, self.a1] - joints[:, self.a2], dim=-1)
v2 = F.normalize(joints[:, self.b1] - joints[:, self.b2], dim=-1)
angle = torch.acos((v1 * v2).sum(dim=-1).clamp(-1.0, 1.0)) * 180.0 / torch.pi
return angle < self.threshold if self.less_than else angle > self.threshold
class SingleAxisComparisonEvaluator(BaseEvaluator):
def __init__(self, joint_a: int, joint_b: int, axis: str, greater_than=True):
super().__init__()
self.joint_a = joint_a
self.joint_b = joint_b
self.axis = axis
self.greater_than = greater_than
def evaluate(self, joints: torch.Tensor, global_orient, **kwargs) -> torch.Tensor:
T = joints.shape[0]
assert self.axis in ["x", "y", "z"]
rotated_axes = get_rotated_axes(global_orient)
assert rotated_axes[self.axis].shape == (T, 3)
axis_tensor = torch.tensor(rotated_axes[self.axis], dtype=joints.dtype, device=joints.device) # [T, 3]
vec_a = joints[:, self.joint_a, :] # [T, 3]
vec_b = joints[:, self.joint_b, :] # [T, 3]
        # Project both joints onto the per-frame axis direction
a_proj = torch.sum(vec_a * axis_tensor, dim=1) # [T]
b_proj = torch.sum(vec_b * axis_tensor, dim=1) # [T]
return a_proj > b_proj if self.greater_than else a_proj < b_proj
class JointDistanceEvaluator(BaseEvaluator):
def __init__(self, joint_a: int, joint_b: int, threshold: float, greater_than=True):
super().__init__()
self.joint_a = joint_a
self.joint_b = joint_b
self.threshold = threshold
self.greater_than = greater_than
def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
dist = torch.norm(joints[:, self.joint_a] - joints[:, self.joint_b], dim=-1)
return dist > self.threshold if self.greater_than else dist < self.threshold
class RelativeOffsetDirectionEvaluator(BaseEvaluator):
def __init__(self, joint_a: int, joint_b: int, axis: str, threshold: float, greater_than=True):
super().__init__()
assert axis in ["x", "y", "z"]
self.joint_a = joint_a
self.joint_b = joint_b
self.axis = axis
self.threshold = threshold
self.greater_than = greater_than
def evaluate(self, joints: torch.Tensor, global_orient, **kwargs) -> torch.Tensor:
T = joints.shape[0]
rotated_axes = get_rotated_axes(global_orient)
assert self.axis in rotated_axes
assert rotated_axes[self.axis].shape == (T, 3)
axis_tensor = torch.tensor(rotated_axes[self.axis], dtype=joints.dtype, device=joints.device) # [T, 3]
offset_vec = joints[:, self.joint_a, :] - joints[:, self.joint_b, :] # [T, 3]
projection = torch.sum(offset_vec * axis_tensor, dim=1) # [T]
return projection > self.threshold if self.greater_than else projection < self.threshold
class VelocityThresholdEvaluator(BaseEvaluator):
def __init__(self, joint: int, axis: str, threshold: float, greater_than=True):
super().__init__()
assert axis in ["x", "y", "z"]
self.joint = joint
self.axis = axis
self.threshold = threshold
self.greater_than = greater_than
def evaluate(self, joints: torch.Tensor, global_orient, dt: float = 1.0, **kwargs) -> torch.Tensor:
T = joints.shape[0]
rotated_axes = get_rotated_axes(global_orient)
assert self.axis in rotated_axes
assert rotated_axes[self.axis].shape == (T, 3)
axis_tensor = torch.tensor(rotated_axes[self.axis], dtype=joints.dtype, device=joints.device) # [T, 3]
        # Project the joint position onto the per-frame axis
        joint_pos = joints[:, self.joint, :]  # [T, 3]
        projection = torch.sum(joint_pos * axis_tensor, dim=1)  # [T]
        # Velocity via finite differences over time
        velocity = (projection[1:] - projection[:-1]) / dt  # [T-1]
        # Compare against the threshold
        result = velocity > self.threshold if self.greater_than else velocity < self.threshold
        # Note: the result has T-1 entries (one per consecutive frame pair); it is not padded back to T
        return result
class AccelerationThresholdEvaluator(BaseEvaluator):
def __init__(self, joint: int, axis: str, threshold: float, greater_than=True):
super().__init__()
assert axis in ["x", "y", "z"]
self.joint = joint
self.axis = axis
self.threshold = threshold
self.greater_than = greater_than
def evaluate(self, joints: torch.Tensor, global_orient, dt: float = 1.0, **kwargs) -> torch.Tensor:
T = joints.shape[0]
rotated_axes = get_rotated_axes(global_orient)
assert self.axis in rotated_axes
assert rotated_axes[self.axis].shape == (T, 3)
axis_tensor = torch.tensor(rotated_axes[self.axis], dtype=joints.dtype, device=joints.device) # [T, 3]
joint_pos = joints[:, self.joint, :] # [T, 3]
projection = torch.sum(joint_pos * axis_tensor, dim=1) # [T]
velocity = (projection[1:] - projection[:-1]) / dt # [T-1]
acceleration = (velocity[1:] - velocity[:-1]) / dt # [T-2]
result = acceleration > self.threshold if self.greater_than else acceleration < self.threshold
return result # shape: [T-2]
class AngleRangeEvaluator(BaseEvaluator):
def __init__(self, joint_indices: Tuple[int, int, int], threshold: float, greater_than=True):
super().__init__()
self.a, self.b, self.c = joint_indices
self.threshold = threshold
self.greater_than = greater_than
def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
ba = F.normalize(joints[:, self.a] - joints[:, self.b], dim=-1)
bc = F.normalize(joints[:, self.c] - joints[:, self.b], dim=-1)
cos_angle = (ba * bc).sum(dim=-1).clamp(-1.0, 1.0)
angles = torch.acos(cos_angle) * 180.0 / torch.pi
motion_range = angles.max() - angles.min()
return torch.tensor([motion_range > self.threshold]) if self.greater_than else torch.tensor([motion_range < self.threshold])
class AngleChangeEvaluator(BaseEvaluator):
def __init__(self, joint_indices: Tuple[int, int, int], frame1: int, frame2: int, threshold: float, greater_than=True):
super().__init__()
self.a, self.b, self.c = joint_indices
self.frame1 = frame1
self.frame2 = frame2
self.threshold = threshold
self.greater_than = greater_than
def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
def compute_angle(frame_idx):
ba = F.normalize(joints[frame_idx, self.a] - joints[frame_idx, self.b], dim=-1)
bc = F.normalize(joints[frame_idx, self.c] - joints[frame_idx, self.b], dim=-1)
cos_angle = (ba * bc).sum().clamp(-1.0, 1.0)
return torch.acos(cos_angle) * 180.0 / torch.pi
        # Compare the last and first frames of the segment passed in; the frame range
        # itself is selected upstream via start_frame/end_frame in evaluate_motion_from_json,
        # so self.frame1/self.frame2 are only kept for reference.
        angle_diff = torch.abs(compute_angle(-1) - compute_angle(0))
return torch.tensor([angle_diff > self.threshold]) if self.greater_than else torch.tensor([angle_diff < self.threshold])
# Joint name to index mapping (for SMPL 24-joint model, common names)
SMPL_JOINT_NAMES = {
"pelvis": 0,
"left_hip": 1,
"right_hip": 2,
"spine1": 3,
"left_knee": 4,
"right_knee": 5,
"spine2": 6,
"left_ankle": 7,
"right_ankle": 8,
"spine3": 9,
"left_foot": 10,
"right_foot": 11,
"neck": 12,
"left_collar": 13,
"right_collar": 14,
"head": 15,
"left_shoulder": 16,
"right_shoulder": 17,
"left_elbow": 18,
"right_elbow": 19,
"left_wrist": 20,
"right_wrist": 21,
"left_hand": 22,
"right_hand": 23,
}
# Mapping from JSON "type" to class constructor
EVALUATOR_CLASSES = {
"ThreeJointAngle": ThreeJointAngleEvaluator,
"VectorAngle": VectorAngleEvaluator,
"SingleAxisComparison": SingleAxisComparisonEvaluator,
"JointDistance": JointDistanceEvaluator,
"RelativeOffsetDirection": RelativeOffsetDirectionEvaluator,
"VelocityThreshold": VelocityThresholdEvaluator,
"AccelerationThreshold": AccelerationThresholdEvaluator,
"AngleRange": AngleRangeEvaluator,
"AngleChange": AngleChangeEvaluator,
# PositionRange can reuse RelativeOffset with a max-min wrapper
}
def get_joint_index(name: str) -> int:
if name not in SMPL_JOINT_NAMES:
raise ValueError(f"Unknown joint name: {name}")
return SMPL_JOINT_NAMES[name]
def build_evaluator_from_json(json_data: Dict) -> BaseEvaluator:
etype = json_data["type"]
if etype == "ThreeJointAngle":
a = get_joint_index(json_data["joint_a"])
b = get_joint_index(json_data["joint_b"])
c = get_joint_index(json_data["joint_c"])
return ThreeJointAngleEvaluator((a, b, c), json_data["threshold"], json_data.get("greater_than", True))
elif etype == "VectorAngle":
a1 = get_joint_index(json_data["joint_a1"])
a2 = get_joint_index(json_data["joint_a2"])
b1 = get_joint_index(json_data["joint_b1"])
b2 = get_joint_index(json_data["joint_b2"])
return VectorAngleEvaluator((a1, a2), (b1, b2), json_data["threshold"], json_data.get("less_than", True))
elif etype == "SingleAxisComparison":
a = get_joint_index(json_data["joint_a"])
b = get_joint_index(json_data["joint_b"])
return SingleAxisComparisonEvaluator(a, b, json_data["axis"], json_data.get("greater_than", True))
elif etype == "JointDistance":
a = get_joint_index(json_data["joint_a"])
b = get_joint_index(json_data["joint_b"])
return JointDistanceEvaluator(a, b, json_data["threshold"], json_data.get("greater_than", True))
elif etype == "RelativeOffsetDirection":
a = get_joint_index(json_data["joint_a"])
b = get_joint_index(json_data["joint_b"])
return RelativeOffsetDirectionEvaluator(a, b, json_data["axis"], json_data["threshold"], json_data.get("greater_than", True))
elif etype == "VelocityThreshold":
j = get_joint_index(json_data["joint"])
return VelocityThresholdEvaluator(j, json_data["axis"], json_data["threshold"], json_data.get("greater_than", True))
elif etype == "AccelerationThreshold":
j = get_joint_index(json_data["joint"])
return AccelerationThresholdEvaluator(j, json_data["axis"], json_data["threshold"], json_data.get("greater_than", True))
elif etype == "AngleRange":
a = get_joint_index(json_data["joint_a"])
b = get_joint_index(json_data["joint_b"])
c = get_joint_index(json_data["joint_c"])
return AngleRangeEvaluator((a, b, c), json_data["threshold"], json_data.get("greater_than", True))
elif etype == "AngleChange":
a = get_joint_index(json_data["joint_a"])
b = get_joint_index(json_data["joint_b"])
c = get_joint_index(json_data["joint_c"])
return AngleChangeEvaluator((a, b, c), json_data["frame1"], json_data["frame2"], json_data["threshold"], json_data.get("greater_than", True))
else:
raise ValueError(f"Unknown evaluator type: {etype}")
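# Example config entry for build_evaluator_from_json (a hedged sketch of the schema;
# the keys match the ones read above, while the joint names, threshold, and frame
# range are illustrative; start_frame/end_frame are consumed by
# evaluate_motion_from_json below, and the config file is a JSON list of such entries):
# {
#   "name": "knee_extension",
#   "type": "ThreeJointAngle",
#   "joint_a": "left_hip", "joint_b": "left_knee", "joint_c": "left_ankle",
#   "threshold": 160.0,
#   "greater_than": true,
#   "start_frame": 1, "end_frame": 30
# }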
# Main function: load evaluator configs from JSON, run them on the motion tensor, and save/return the results
def evaluate_motion_from_json(
json_path: str,
motion_tensor: torch.Tensor,
global_orient_tensor: torch.Tensor = None
) -> Dict[str, List[bool]]:
with open(json_path, "r") as f:
configs = json.load(f)
results = {}
for idx, cfg in enumerate(configs):
        try:
            evalr = build_evaluator_from_json(cfg)
        except (KeyError, ValueError) as e:
            print(f"Skipping config {idx}: {e}", file=sys.stderr)
            continue
name = cfg.get("name", f"eval_{idx}")
        # Slice the configured frame range (frame indices are 1-based, inclusive)
if "start_frame" in cfg and "end_frame" in cfg:
s, e = cfg["start_frame"], cfg["end_frame"]
seg = motion_tensor[s - 1 : e] # [T_seg, J, 3]
orient_seg = None
if global_orient_tensor is not None:
orient_seg = global_orient_tensor[s - 1 : e]
elif "frame" in cfg:
            # Backwards compatibility: single-frame config
f0 = cfg["frame"]
seg = motion_tensor[f0 - 1 : f0]
orient_seg = None
if global_orient_tensor is not None:
orient_seg = global_orient_tensor[f0 - 1 : f0]
else:
seg = motion_tensor
orient_seg = None
if global_orient_tensor is not None:
orient_seg = global_orient_tensor
        # Run the evaluator (pass the global orientation only when available)
if orient_seg is not None:
out = evalr.evaluate(seg, global_orient=orient_seg)
else:
out = evalr.evaluate(seg)
results[name] = out.tolist()
    # Save results next to the config file
out_path = json_path.replace(".json", "_output.txt")
with open(out_path, "w") as f:
json.dump(results, f)
return results
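# Usage sketch for evaluate_motion_from_json (a hedged example with dummy tensors;
# "criteria.json" is a hypothetical path):
#   T = 30
#   dummy_joints = torch.zeros(T, 24, 3)  # [T, J, 3] joint positions
#   dummy_orient = torch.zeros(T, 3)      # [T, 3] axis-angle global orientation
#   results = evaluate_motion_from_json("criteria.json", dummy_joints, dummy_orient)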
if __name__ == "__main__":
pkl_file_path = sys.argv[1]
json_file_path = sys.argv[2]
with open(pkl_file_path, 'rb') as f:
pose = joblib.load(f)
global_orient = pose[0]['pose_world'][:, :3]
global_orient_tensor = torch.from_numpy(global_orient)
pose = pose[0]['joint'].reshape(-1, 45, 3)[:, :24, :]
pose = torch.from_numpy(pose)
    evaluate_motion_from_json(json_file_path, pose, global_orient_tensor)