import numpy as np
import torch

from ..utils.load_model import load_model
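
# Cosine beta schedule (Nichol & Dhariwal, 2021): alpha_bar(t) follows a squared cosine
# in t, and beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}, clipped at 0.999 for numerical
# stability near t = T.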
def make_beta(n_timestep, cosine_s=8e-3):
    timesteps = (
        torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
    )
    alphas = timesteps / (1 + cosine_s) * np.pi / 2
    alphas = torch.cos(alphas).pow(2)
    alphas = alphas / alphas[0]
    betas = 1 - alphas[1:] / alphas[:-1]
    betas = torch.clamp(betas, min=0, max=0.999)
    return betas.numpy()
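
# LMDM wraps the diffusion-based motion generator behind a single sampling interface.
# A PyTorch checkpoint is sampled with the model's own ddim_sample(); ONNX / TensorRT
# exports expose only the single-step denoiser, so DDIM sampling is reproduced in NumPy
# using the schedule precomputed in _init_np / _setup_np.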
class LMDM:
    def __init__(self, model_path, device="cuda", **kwargs):
        kwargs["module_name"] = "LMDM"
        self.model, self.model_type = load_model(model_path, device=device, **kwargs)
        self.device = device
        self.motion_feat_dim = kwargs.get("motion_feat_dim", 265)
        self.audio_feat_dim = kwargs.get("audio_feat_dim", 1024 + 35)
        self.seq_frames = kwargs.get("seq_frames", 80)

        if self.model_type != "pytorch":
            self._init_np()

    def setup(self, sampling_timesteps):
        if self.model_type == "pytorch":
            self.model.setup(sampling_timesteps)
        else:
            self._setup_np(sampling_timesteps)

    def _init_np(self):
        self.sampling_timesteps = None
        self.n_timestep = 1000
        betas = torch.Tensor(make_beta(n_timestep=self.n_timestep))
        alphas = 1.0 - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0).cpu().numpy()
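
    # _setup_np precomputes everything the NumPy DDIM loop needs for a given step
    # count: the reversed (t, t_next) time pairs, the per-step coefficients
    # sqrt(alpha_bar_next), sigma and c, and one noise sample per step. The noise is
    # drawn once here and reused on every call, so repeated calls differ only through
    # the initial x drawn in _call_np.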
    def _setup_np(self, sampling_timesteps=50):
        if self.sampling_timesteps == sampling_timesteps:
            return
        self.sampling_timesteps = sampling_timesteps

        total_timesteps = self.n_timestep
        eta = 1
        shape = (1, self.seq_frames, self.motion_feat_dim)

        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)  # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        self.time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        self.time_cond_list = []
        self.alpha_next_sqrt_list = []
        self.sigma_list = []
        self.c_list = []
        self.noise_list = []
        for time, time_next in self.time_pairs:
            time_cond = np.full((1,), time, dtype=np.int64)
            self.time_cond_list.append(time_cond)
            if time_next < 0:
                # final (0, -1) pair needs no coefficients; _call_np returns x_start there
                continue
            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]
            sigma = eta * np.sqrt((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha))
            c = np.sqrt(1 - alpha_next - sigma ** 2)
            noise = np.random.randn(*shape).astype(np.float32)
            self.alpha_next_sqrt_list.append(np.sqrt(alpha_next))
            self.sigma_list.append(sigma)
            self.c_list.append(c)
            self.noise_list.append(noise)
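
    # _one_step runs the denoiser once for the current timestep and returns
    # (pred_noise, x_start), dispatching on the loaded backend: an ONNX Runtime
    # session, a TensorRT engine with named I/O buffers, or a PyTorch module under
    # torch.no_grad().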
    def _one_step(self, x, cond_frame, cond, time_cond):
        if self.model_type == "onnx":
            pred = self.model.run(None, {"x": x, "cond_frame": cond_frame, "cond": cond, "time_cond": time_cond})
            pred_noise, x_start = pred[0], pred[1]
        elif self.model_type == "tensorrt":
            self.model.setup({"x": x, "cond_frame": cond_frame, "cond": cond, "time_cond": time_cond})
            self.model.infer()
            pred_noise, x_start = self.model.buffer["pred_noise"][0], self.model.buffer["x_start"][0]
        elif self.model_type == "pytorch":
            with torch.no_grad():
                pred_noise, x_start = self.model(x, cond_frame, cond, time_cond)
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")
        return pred_noise, x_start
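
    # _call_np is the NumPy DDIM sampling loop. Starting from Gaussian noise it applies,
    # for each precomputed (t, t_next) pair,
    #     x <- sqrt(alpha_bar_next) * x_start + c * pred_noise + sigma * noise
    # with c = sqrt(1 - alpha_bar_next - sigma^2); on the final pair (t_next < 0) the
    # predicted x_start is taken as the result.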
    def _call_np(self, kp_cond, aud_cond, sampling_timesteps):
        self._setup_np(sampling_timesteps)

        cond_frame = kp_cond
        cond = aud_cond

        x = np.random.randn(1, self.seq_frames, self.motion_feat_dim).astype(np.float32)
        x_start = None
        i = 0
        for _, time_next in self.time_pairs:
            time_cond = self.time_cond_list[i]
            pred_noise, x_start = self._one_step(x, cond_frame, cond, time_cond)
            if time_next < 0:
                x = x_start
                continue
            alpha_next_sqrt = self.alpha_next_sqrt_list[i]
            c = self.c_list[i]
            sigma = self.sigma_list[i]
            noise = self.noise_list[i]
            x = x_start * alpha_next_sqrt + c * pred_noise + sigma * noise
            i += 1  # the coefficient lists skip the final (t, -1) pair, so only advance here
        return x

    def __call__(self, kp_cond, aud_cond, sampling_timesteps):
        if self.model_type == "pytorch":
            pred_kp_seq = self.model.ddim_sample(
                torch.from_numpy(kp_cond).to(self.device),
                torch.from_numpy(aud_cond).to(self.device),
                sampling_timesteps,
            ).cpu().numpy()
        else:
            pred_kp_seq = self._call_np(kp_cond, aud_cond, sampling_timesteps)
        return pred_kp_seq
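
# Usage sketch. Assumptions: the checkpoint path below is hypothetical, and kp_cond /
# aud_cond are float32 NumPy arrays produced by the upstream keypoint and audio feature
# extractors; their exact shapes are dictated by the exported model, not by this wrapper.
#
#   lmdm = LMDM("./checkpoints/lmdm.onnx", device="cuda")
#   lmdm.setup(sampling_timesteps=50)
#   pred_kp_seq = lmdm(kp_cond, aud_cond, sampling_timesteps=50)  # (1, seq_frames, motion_feat_dim)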