Spaces:
Sleeping
Sleeping
import argparse
import json
import os
import sys
import warnings
from typing import Dict, List, Tuple

import joblib
import numpy as np
import torch
import torch.nn.functional as F
from smplx.lbs import batch_rodrigues
def get_rotated_axes(global_orient):
    """Rotate the canonical unit axes by per-frame axis-angle orientations.

    Args:
        global_orient: ``[T, 3]`` axis-angle rotations, either a numpy array or a
            torch tensor (on any device; gradients are not tracked).

    Returns:
        dict with keys ``'x'``, ``'y'`` and ``'z'``; each value is a ``[T, 3]``
        numpy array holding that unit axis after rotation, one row per frame.
        The original author's stated convention for the unrotated axes is:
        X right -> left, Y down -> up, Z back -> front.
    """
    # as_tensor accepts numpy arrays and tensors alike (torch.tensor() warns when
    # copy-constructing from a tensor). Detaching and moving to CPU up front makes
    # the final .numpy() work for CUDA or autograd-tracked inputs too.
    orient = torch.as_tensor(global_orient).detach().cpu().float()
    R = batch_rodrigues(orient)  # [T, 3, 3] rotation matrices
    # R @ e_k is exactly the k-th column of R, so no matmul with unit vectors is needed.
    return {name: R[:, :, k].numpy() for k, name in enumerate(("x", "y", "z"))}
class BaseEvaluator:
    """Common interface shared by every motion evaluator in this module.

    Subclasses override ``evaluate`` to turn a joint-position tensor (plus any
    evaluator-specific keyword arguments) into a result tensor.
    """

    def __init__(self):
        # No shared state; subclasses still chain here via super().__init__().
        return None

    def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
        """Evaluate ``joints``; concrete evaluators must override this."""
        raise NotImplementedError
class ThreeJointAngleEvaluator(BaseEvaluator):
    """Threshold the angle at joint ``b`` formed by joints ``a`` - ``b`` - ``c``.

    Per frame, the angle (degrees) between vectors b->a and b->c is compared with
    ``threshold``: ``angle > threshold`` when ``greater_than`` is truthy, else
    ``angle < threshold``.
    """

    def __init__(self, joint_indices: Tuple[int, int, int], threshold: float, greater_than: bool = True):
        super().__init__()
        self.a, self.b, self.c = joint_indices
        self.threshold = threshold
        self.greater_than = greater_than

    def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return one boolean per frame; ``joints`` is indexed ``[frame, joint, xyz]``."""
        vertex = joints[:, self.b]
        arm_1 = F.normalize(joints[:, self.a] - vertex, dim=-1)
        arm_2 = F.normalize(joints[:, self.c] - vertex, dim=-1)
        cosine = torch.clamp((arm_1 * arm_2).sum(dim=-1), -1.0, 1.0)
        degrees = torch.acos(cosine) * 180.0 / torch.pi
        if self.greater_than:
            return degrees > self.threshold
        return degrees < self.threshold
class VectorAngleEvaluator(BaseEvaluator):
    """Threshold the angle between two joint-difference vectors.

    Vector 1 is ``joint[a1] - joint[a2]`` and vector 2 is ``joint[b1] - joint[b2]``.
    The per-frame angle in degrees is compared with ``threshold``:
    ``angle < threshold`` when ``less_than`` is truthy (the default), otherwise
    ``angle > threshold``.
    """

    def __init__(self, pair1: Tuple[int, int], pair2: Tuple[int, int], threshold: float, less_than=True):
        super().__init__()
        self.a1, self.a2 = pair1
        self.b1, self.b2 = pair2
        self.threshold = threshold
        self.less_than = less_than

    def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return one boolean per frame; ``joints`` is indexed ``[frame, joint, xyz]``."""
        first = F.normalize(joints[:, self.a1] - joints[:, self.a2], dim=-1)
        second = F.normalize(joints[:, self.b1] - joints[:, self.b2], dim=-1)
        dot = (first * second).sum(dim=-1)
        angle = torch.acos(dot.clamp(-1.0, 1.0)) * 180.0 / torch.pi
        if not self.less_than:
            return angle > self.threshold
        return angle < self.threshold
class SingleAxisComparisonEvaluator(BaseEvaluator):
    """Compare where two joints lie along one per-frame rotated body axis.

    Both joint positions are projected onto the chosen axis (``'x'``, ``'y'`` or
    ``'z'``, rotated per frame by ``global_orient`` via ``get_rotated_axes``). The
    result is ``proj(a) > proj(b)`` when ``greater_than`` is truthy, otherwise
    ``proj(a) < proj(b)``.
    """

    def __init__(self, joint_a: int, joint_b: int, axis: str, greater_than=True):
        super().__init__()
        # Validate at construction, like the other axis-based evaluators, so a bad
        # config is rejected when the evaluator is built instead of mid-evaluation.
        assert axis in ["x", "y", "z"]
        self.joint_a = joint_a
        self.joint_b = joint_b
        self.axis = axis
        self.greater_than = greater_than

    def evaluate(self, joints: torch.Tensor, global_orient, **kwargs) -> torch.Tensor:
        """Return one boolean per frame.

        Args:
            joints: positions indexed ``[frame, joint, xyz]``.
            global_orient: per-frame axis-angle orientation, ``[T, 3]``.
        """
        T = joints.shape[0]
        rotated_axes = get_rotated_axes(global_orient)
        assert rotated_axes[self.axis].shape == (T, 3)
        axis_tensor = torch.tensor(rotated_axes[self.axis], dtype=joints.dtype, device=joints.device)  # [T, 3]
        # Project each joint onto the current frame's axis direction.
        a_proj = torch.sum(joints[:, self.joint_a, :] * axis_tensor, dim=1)  # [T]
        b_proj = torch.sum(joints[:, self.joint_b, :] * axis_tensor, dim=1)  # [T]
        return a_proj > b_proj if self.greater_than else a_proj < b_proj
class JointDistanceEvaluator(BaseEvaluator):
    """Threshold the per-frame Euclidean distance between two joints.

    Returns ``distance > threshold`` when ``greater_than`` is truthy, otherwise
    ``distance < threshold``.
    """

    def __init__(self, joint_a: int, joint_b: int, threshold: float, greater_than=True):
        super().__init__()
        self.joint_a, self.joint_b = joint_a, joint_b
        self.threshold = threshold
        self.greater_than = greater_than

    def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return one boolean per frame; ``joints`` is indexed ``[frame, joint, xyz]``."""
        gap = joints[:, self.joint_a] - joints[:, self.joint_b]
        distance = torch.norm(gap, dim=-1)
        return distance < self.threshold if not self.greater_than else distance > self.threshold
class RelativeOffsetDirectionEvaluator(BaseEvaluator):
    """Threshold the offset ``joint[a] - joint[b]`` along a per-frame rotated axis.

    The offset is projected onto the chosen axis (rotated by ``global_orient`` via
    ``get_rotated_axes``) and compared with ``threshold``: ``> threshold`` when
    ``greater_than`` is truthy, otherwise ``< threshold``.
    """

    def __init__(self, joint_a: int, joint_b: int, axis: str, threshold: float, greater_than=True):
        super().__init__()
        assert axis in ["x", "y", "z"]
        self.joint_a, self.joint_b = joint_a, joint_b
        self.axis = axis
        self.threshold = threshold
        self.greater_than = greater_than

    def evaluate(self, joints: torch.Tensor, global_orient, **kwargs) -> torch.Tensor:
        """Return one boolean per frame; ``joints`` is indexed ``[frame, joint, xyz]``."""
        frame_count = joints.shape[0]
        axes = get_rotated_axes(global_orient)
        assert self.axis in axes
        direction_np = axes[self.axis]
        assert direction_np.shape == (frame_count, 3)
        direction = torch.tensor(direction_np, dtype=joints.dtype, device=joints.device)  # [T, 3]
        offset = joints[:, self.joint_a, :] - joints[:, self.joint_b, :]  # [T, 3]
        along_axis = (offset * direction).sum(dim=1)  # [T]
        if self.greater_than:
            return along_axis > self.threshold
        return along_axis < self.threshold
class VelocityThresholdEvaluator(BaseEvaluator):
    """Threshold a joint's frame-to-frame velocity along a per-frame rotated axis.

    The joint position is projected onto the chosen axis (rotated by
    ``global_orient`` via ``get_rotated_axes``) and forward-differenced:
    ``v[t] = (p[t + 1] - p[t]) / dt``. The result therefore has ``T - 1`` entries;
    it is NOT padded back to ``T`` frames. With the default ``dt = 1.0`` the
    velocity is in position units per frame.
    """

    def __init__(self, joint: int, axis: str, threshold: float, greater_than=True):
        super().__init__()
        assert axis in ["x", "y", "z"]
        self.joint = joint
        self.axis = axis
        self.threshold = threshold
        self.greater_than = greater_than

    def evaluate(self, joints: torch.Tensor, global_orient, dt: float = 1.0, **kwargs) -> torch.Tensor:
        """Return a ``[T - 1]`` boolean tensor; ``joints`` is indexed ``[frame, joint, xyz]``."""
        frame_count = joints.shape[0]
        axes = get_rotated_axes(global_orient)
        assert self.axis in axes
        direction_np = axes[self.axis]
        assert direction_np.shape == (frame_count, 3)
        direction = torch.tensor(direction_np, dtype=joints.dtype, device=joints.device)  # [T, 3]
        position = (joints[:, self.joint, :] * direction).sum(dim=1)  # [T]
        velocity = (position[1:] - position[:-1]) / dt  # [T - 1]
        if self.greater_than:
            return velocity > self.threshold
        return velocity < self.threshold
class AccelerationThresholdEvaluator(BaseEvaluator):
    """Threshold a joint's acceleration along a per-frame rotated axis.

    The projected position is forward-differenced twice, each difference divided
    by ``dt``, so the result has ``T - 2`` entries (no padding).
    """

    def __init__(self, joint: int, axis: str, threshold: float, greater_than=True):
        super().__init__()
        assert axis in ["x", "y", "z"]
        self.joint = joint
        self.axis = axis
        self.threshold = threshold
        self.greater_than = greater_than

    def evaluate(self, joints: torch.Tensor, global_orient, dt: float = 1.0, **kwargs) -> torch.Tensor:
        """Return a ``[T - 2]`` boolean tensor; ``joints`` is indexed ``[frame, joint, xyz]``."""
        frame_count = joints.shape[0]
        axes = get_rotated_axes(global_orient)
        assert self.axis in axes
        direction_np = axes[self.axis]
        assert direction_np.shape == (frame_count, 3)
        direction = torch.tensor(direction_np, dtype=joints.dtype, device=joints.device)  # [T, 3]
        position = (joints[:, self.joint, :] * direction).sum(dim=1)  # [T]
        speed = (position[1:] - position[:-1]) / dt  # [T - 1]
        accel = (speed[1:] - speed[:-1]) / dt  # [T - 2]
        if self.greater_than:
            return accel > self.threshold
        return accel < self.threshold
class AngleRangeEvaluator(BaseEvaluator):
    """Threshold the range of motion of the angle at joint ``b`` over all frames.

    The ``a`` - ``b`` - ``c`` angle (degrees) is computed per frame; its spread
    ``max - min`` is compared with ``threshold`` and returned as a one-element
    boolean tensor.
    """

    def __init__(self, joint_indices: Tuple[int, int, int], threshold: float, greater_than=True):
        super().__init__()
        self.a, self.b, self.c = joint_indices
        self.threshold = threshold
        self.greater_than = greater_than

    def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return ``tensor([bool])``; ``joints`` is indexed ``[frame, joint, xyz]``."""
        vertex = joints[:, self.b]
        u = F.normalize(joints[:, self.a] - vertex, dim=-1)
        w = F.normalize(joints[:, self.c] - vertex, dim=-1)
        degrees = torch.acos((u * w).sum(dim=-1).clamp(-1.0, 1.0)) * 180.0 / torch.pi
        spread = degrees.max() - degrees.min()
        verdict = spread > self.threshold if self.greater_than else spread < self.threshold
        return torch.tensor([verdict])
class AngleChangeEvaluator(BaseEvaluator):
    """Threshold the change of the ``a`` - ``b`` - ``c`` angle between two frames.

    The absolute difference (degrees) between the angle at the last frame and at
    the first frame of ``joints`` is compared with ``threshold`` and returned as a
    one-element boolean tensor.

    NOTE(review): ``frame1`` / ``frame2`` are stored but never read by
    ``evaluate``; the comparison always uses the first and last frames of the
    segment it is given. Confirm whether these fields were meant to select frames.
    """

    def __init__(self, joint_indices: Tuple[int, int, int], frame1: int, frame2: int, threshold: float, greater_than=True):
        super().__init__()
        self.a, self.b, self.c = joint_indices
        self.frame1, self.frame2 = frame1, frame2
        self.threshold = threshold
        self.greater_than = greater_than

    def evaluate(self, joints: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return ``tensor([bool])``; ``joints`` is indexed ``[frame, joint, xyz]``."""

        def angle_at(frame):
            # Angle (degrees) at joint b for a single frame.
            vertex = joints[frame, self.b]
            u = F.normalize(joints[frame, self.a] - vertex, dim=-1)
            w = F.normalize(joints[frame, self.c] - vertex, dim=-1)
            return torch.acos((u * w).sum().clamp(-1.0, 1.0)) * 180.0 / torch.pi

        last_angle = angle_at(-1)
        first_angle = angle_at(0)
        change = torch.abs(last_angle - first_angle)
        verdict = change > self.threshold if self.greater_than else change < self.threshold
        return torch.tensor([verdict])
# Joint name -> index for the 24-joint SMPL skeleton; each index is the
# name's position in the tuple below.
SMPL_JOINT_NAMES = {
    joint_name: joint_idx
    for joint_idx, joint_name in enumerate(
        (
            "pelvis", "left_hip", "right_hip", "spine1",
            "left_knee", "right_knee", "spine2", "left_ankle",
            "right_ankle", "spine3", "left_foot", "right_foot",
            "neck", "left_collar", "right_collar", "head",
            "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
            "left_wrist", "right_wrist", "left_hand", "right_hand",
        )
    )
}
# JSON rule "type" string -> evaluator class.
# A future "PositionRange" rule could reuse RelativeOffsetDirection wrapped in a
# max-min reduction.
EVALUATOR_CLASSES = dict(
    ThreeJointAngle=ThreeJointAngleEvaluator,
    VectorAngle=VectorAngleEvaluator,
    SingleAxisComparison=SingleAxisComparisonEvaluator,
    JointDistance=JointDistanceEvaluator,
    RelativeOffsetDirection=RelativeOffsetDirectionEvaluator,
    VelocityThreshold=VelocityThresholdEvaluator,
    AccelerationThreshold=AccelerationThresholdEvaluator,
    AngleRange=AngleRangeEvaluator,
    AngleChange=AngleChangeEvaluator,
)
def get_joint_index(name: str) -> int:
    """Look up the SMPL index for a joint name.

    Raises:
        ValueError: if ``name`` is not a key of ``SMPL_JOINT_NAMES``.
    """
    try:
        return SMPL_JOINT_NAMES[name]
    except KeyError:
        raise ValueError(f"Unknown joint name: {name}") from None
def build_evaluator_from_json(json_data: Dict) -> BaseEvaluator:
    """Instantiate the evaluator described by one JSON rule.

    ``json_data["type"]`` selects the evaluator. Joint fields hold joint names and
    are resolved with ``get_joint_index``. When absent, the comparison flag
    defaults to True (``greater_than`` for most types, ``less_than`` for
    ``VectorAngle``).

    Raises:
        KeyError: a required key is missing.
        ValueError: unknown evaluator type or unknown joint name.
    """
    kind = json_data["type"]

    def joints(*keys):
        # Resolve the named joint fields to indices, in the order given.
        return tuple(get_joint_index(json_data[key]) for key in keys)

    def flag(key="greater_than"):
        return json_data.get(key, True)

    if kind == "ThreeJointAngle":
        triple = joints("joint_a", "joint_b", "joint_c")
        return ThreeJointAngleEvaluator(triple, json_data["threshold"], flag())
    if kind == "VectorAngle":
        a1, a2, b1, b2 = joints("joint_a1", "joint_a2", "joint_b1", "joint_b2")
        return VectorAngleEvaluator((a1, a2), (b1, b2), json_data["threshold"], flag("less_than"))
    if kind == "SingleAxisComparison":
        a, b = joints("joint_a", "joint_b")
        return SingleAxisComparisonEvaluator(a, b, json_data["axis"], flag())
    if kind == "JointDistance":
        a, b = joints("joint_a", "joint_b")
        return JointDistanceEvaluator(a, b, json_data["threshold"], flag())
    if kind == "RelativeOffsetDirection":
        a, b = joints("joint_a", "joint_b")
        return RelativeOffsetDirectionEvaluator(a, b, json_data["axis"], json_data["threshold"], flag())
    if kind == "VelocityThreshold":
        (j,) = joints("joint")
        return VelocityThresholdEvaluator(j, json_data["axis"], json_data["threshold"], flag())
    if kind == "AccelerationThreshold":
        (j,) = joints("joint")
        return AccelerationThresholdEvaluator(j, json_data["axis"], json_data["threshold"], flag())
    if kind == "AngleRange":
        triple = joints("joint_a", "joint_b", "joint_c")
        return AngleRangeEvaluator(triple, json_data["threshold"], flag())
    if kind == "AngleChange":
        triple = joints("joint_a", "joint_b", "joint_c")
        return AngleChangeEvaluator(triple, json_data["frame1"], json_data["frame2"], json_data["threshold"], flag())
    raise ValueError(f"Unknown evaluator type: {kind}")
def _slice_frames(cfg: Dict, motion_tensor, global_orient_tensor):
    """Return the ``(motion, orientation)`` frames a rule config selects.

    Frame numbers are 1-based and inclusive. ``start_frame``/``end_frame`` select a
    range, the legacy ``frame`` key selects a single frame, and no key selects the
    whole motion. The orientation part is None when no orientation was supplied.
    """
    if "start_frame" in cfg and "end_frame" in cfg:
        frames = slice(cfg["start_frame"] - 1, cfg["end_frame"])
    elif "frame" in cfg:
        # Legacy single-frame form.
        frames = slice(cfg["frame"] - 1, cfg["frame"])
    else:
        return motion_tensor, global_orient_tensor
    orient = None if global_orient_tensor is None else global_orient_tensor[frames]
    return motion_tensor[frames], orient


# Main function: load the rule JSON, evaluate each rule on the motion, save results.
def evaluate_motion_from_json(
    json_path: str,
    motion_tensor: torch.Tensor,
    global_orient_tensor: torch.Tensor = None
) -> Dict[str, List[bool]]:
    """Run every rule in a JSON file against a motion and save the results.

    Args:
        json_path: JSON file holding a list of rule configs (see
            ``build_evaluator_from_json``). A config may also carry ``name`` and
            ``start_frame``/``end_frame`` (1-based, inclusive) or ``frame``.
        motion_tensor: joint positions indexed ``[frame, joint, xyz]``.
        global_orient_tensor: optional per-frame axis-angle orientation. When
            given, it is sliced like the motion and passed to every evaluator as
            ``global_orient``; the axis-based evaluators require it.

    Returns:
        Mapping from rule name (``cfg["name"]`` or ``"eval_<idx>"``) to the
        evaluator output converted with ``.tolist()``.

    Side effects:
        Writes the same mapping as JSON to ``<json_path without extension>_output.txt``.
        Rules that fail to build are skipped with a ``warnings.warn``.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        configs = json.load(f)

    results = {}
    for idx, cfg in enumerate(configs):
        try:
            evalr = build_evaluator_from_json(cfg)
        except Exception as err:
            # Per-rule boundary: a malformed rule is skipped (best effort), but is
            # reported instead of silently ignored. KeyboardInterrupt/SystemExit
            # now propagate.
            warnings.warn(f"Skipping evaluator config #{idx}: {err!r}", stacklevel=2)
            continue
        name = cfg.get("name", f"eval_{idx}")
        seg, orient_seg = _slice_frames(cfg, motion_tensor, global_orient_tensor)
        if orient_seg is not None:
            out = evalr.evaluate(seg, global_orient=orient_seg)
        else:
            out = evalr.evaluate(seg)
        results[name] = out.tolist()

    # splitext only strips the final extension, so ".json" elsewhere in the path
    # (e.g. a directory name) is left alone. An input without a ".json" suffix
    # is also never overwritten by its own results.
    out_path = os.path.splitext(json_path)[0] + "_output.txt"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f)
    return results
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate JSON-described motion rules on a motion pickle."
    )
    parser.add_argument("pkl_file_path", help="joblib/pickle file containing the motion")
    parser.add_argument("json_file_path", help="JSON file with a list of rule configs")
    args = parser.parse_args()

    # NOTE(review): joblib.load unpickles arbitrary objects -- only load trusted files.
    with open(args.pkl_file_path, "rb") as f:
        data = joblib.load(f)

    # First 3 values of 'pose_world' per frame are used as the global (root)
    # orientation; presumably axis-angle, as get_rotated_axes expects -- confirm.
    global_orient = data[0]["pose_world"][:, :3]
    # assumes 'joint' stores 45 joints per frame (TODO confirm); keep the first 24,
    # which are indexed via SMPL_JOINT_NAMES.
    joints = torch.from_numpy(data[0]["joint"].reshape(-1, 45, 3)[:, :24, :])
    evaluate_motion_from_json(args.json_file_path, joints, global_orient)