# oKen38461 - commit ac7cda5: add files based on the initial commit
import numpy as np
import torch
from ..utils.load_model import load_model
def make_beta(n_timestep, cosine_s=8e-3):
    """Return the betas of a cosine noise schedule as a NumPy array.

    The cumulative signal level is ``cos^2`` of the (offset, rescaled)
    normalized timestep, normalized so that its first entry equals 1. Each
    beta is ``1 - alphas[t + 1] / alphas[t]``, clamped to ``[0, 0.999]``.

    Args:
        n_timestep: Number of diffusion steps; the result has this many
            entries.
        cosine_s: Small offset added to the normalized timestep before the
            cosine is taken.

    Returns:
        A 1-D ``float64`` NumPy array of length ``n_timestep``.
    """
    timesteps = (
        torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
    )
    alphas = timesteps / (1 + cosine_s) * np.pi / 2
    alphas = torch.cos(alphas).pow(2)
    alphas = alphas / alphas[0]
    betas = 1 - alphas[1:] / alphas[:-1]
    # Clamp with torch itself: `betas` is a torch.Tensor, and np.clip on a
    # tensor only works through NumPy's fallback dispatch, which is fragile
    # across NumPy/PyTorch versions.
    betas = torch.clamp(betas, min=0, max=0.999)
    return betas.numpy()
class LMDM:
    """Backend-agnostic wrapper around the LMDM diffusion model.

    The model is loaded through ``load_model`` and may be a ``"pytorch"``,
    ``"onnx"`` or ``"tensorrt"`` backend. For the PyTorch backend sampling is
    delegated to ``model.ddim_sample``; for the other backends the sampling
    loop is run here in NumPy, calling the model once per step.
    """

    def __init__(self, model_path, device="cuda", **kwargs):
        """Load the model and, for non-PyTorch backends, build the schedule.

        Args:
            model_path: Path passed through to ``load_model``.
            device: Device passed to ``load_model``; also used to place the
                inputs in the PyTorch path of ``__call__``.
            **kwargs: Forwarded to ``load_model`` (with ``module_name`` forced
                to ``"LMDM"``). ``motion_feat_dim`` (default 265),
                ``audio_feat_dim`` (default 1024 + 35) and ``seq_frames``
                (default 80) are also read from here.
        """
        kwargs["module_name"] = "LMDM"
        self.model, self.model_type = load_model(model_path, device=device, **kwargs)
        self.device = device

        self.motion_feat_dim = kwargs.get("motion_feat_dim", 265)
        self.audio_feat_dim = kwargs.get("audio_feat_dim", 1024 + 35)
        self.seq_frames = kwargs.get("seq_frames", 80)

        # The PyTorch model samples on its own; only the other backends need
        # the NumPy-side noise schedule.
        if self.model_type != "pytorch":
            self._init_np()

    def setup(self, sampling_timesteps):
        """Prepare the sampler for ``sampling_timesteps`` denoising steps."""
        if self.model_type == "pytorch":
            self.model.setup(sampling_timesteps)
        else:
            self._setup_np(sampling_timesteps)

    def _init_np(self):
        """Compute the cumulative alpha products used by the NumPy sampler."""
        # None marks "no per-step schedule cached yet" for _setup_np.
        self.sampling_timesteps = None
        self.n_timestep = 1000

        # torch.Tensor(...) converts the float64 betas to float32.
        betas = torch.Tensor(make_beta(n_timestep=self.n_timestep))
        alphas = 1.0 - betas
        self.alphas_cumprod = torch.cumprod(alphas, axis=0).cpu().numpy()

    def _setup_np(self, sampling_timesteps=50):
        """Precompute the per-step coefficients for the NumPy sampling loop.

        The result is cached: calling again with the same
        ``sampling_timesteps`` is a no-op.

        NOTE: the per-step noise is drawn here, not in ``_call_np``, so the
        same noise arrays are reused by every call until the step count
        changes.

        Args:
            sampling_timesteps: Number of denoising steps; must be >= 1.

        Raises:
            ValueError: If ``sampling_timesteps`` is ``None`` or less than 1.
        """
        # Validate before the cache check: None would otherwise compare equal
        # to the initial "unset" marker and return without building anything,
        # and 0 would build an empty schedule that makes _call_np return
        # unrefined noise.
        if sampling_timesteps is None or sampling_timesteps < 1:
            raise ValueError(
                f"sampling_timesteps must be a positive integer, got {sampling_timesteps!r}"
            )

        if self.sampling_timesteps == sampling_timesteps:
            return

        self.sampling_timesteps = sampling_timesteps

        total_timesteps = self.n_timestep
        eta = 1
        shape = (1, self.seq_frames, self.motion_feat_dim)

        # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
        self.time_pairs = list(zip(times[:-1], times[1:]))

        self.time_cond_list = []
        self.alpha_next_sqrt_list = []
        self.sigma_list = []
        self.c_list = []
        self.noise_list = []

        for time, time_next in self.time_pairs:
            # Every step has a time condition, including the final one.
            self.time_cond_list.append(np.full((1,), time, dtype=np.int64))
            if time_next < 0:
                # Final step: _call_np returns x_start directly, so no
                # coefficients are needed.
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * np.sqrt((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha))
            c = np.sqrt(1 - alpha_next - sigma ** 2)
            noise = np.random.randn(*shape).astype(np.float32)

            self.alpha_next_sqrt_list.append(np.sqrt(alpha_next))
            self.sigma_list.append(sigma)
            self.c_list.append(c)
            self.noise_list.append(noise)

    def _one_step(self, x, cond_frame, cond, time_cond):
        """Run the model once and return ``(pred_noise, x_start)``.

        Raises:
            ValueError: If ``self.model_type`` is not a supported backend.
        """
        if self.model_type == "onnx":
            pred = self.model.run(None, {"x": x, "cond_frame": cond_frame, "cond": cond, "time_cond": time_cond})
            pred_noise, x_start = pred[0], pred[1]
        elif self.model_type == "tensorrt":
            self.model.setup({"x": x, "cond_frame": cond_frame, "cond": cond, "time_cond": time_cond})
            self.model.infer()
            pred_noise, x_start = self.model.buffer["pred_noise"][0], self.model.buffer["x_start"][0]
        elif self.model_type == "pytorch":
            with torch.no_grad():
                pred_noise, x_start = self.model(x, cond_frame, cond, time_cond)
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

        return pred_noise, x_start

    def _call_np(self, kp_cond, aud_cond, sampling_timesteps):
        """Run the NumPy sampling loop and return the final sample.

        Args:
            kp_cond: Passed to the model as ``cond_frame``.
            aud_cond: Passed to the model as ``cond``.
            sampling_timesteps: Number of denoising steps (see ``_setup_np``).

        Returns:
            The model's ``x_start`` prediction from the last step. The loop
            starts from a float32 array of shape
            ``(1, seq_frames, motion_feat_dim)``.
        """
        self._setup_np(sampling_timesteps)

        x = np.random.randn(1, self.seq_frames, self.motion_feat_dim).astype(np.float32)

        # Only the last pair has time_next < 0, so the coefficient lists
        # (one entry shorter than time_pairs) share the index of every
        # earlier step.
        for step, (_, time_next) in enumerate(self.time_pairs):
            pred_noise, x_start = self._one_step(x, kp_cond, aud_cond, self.time_cond_list[step])
            if time_next < 0:
                x = x_start
                continue

            alpha_next_sqrt = self.alpha_next_sqrt_list[step]
            c = self.c_list[step]
            sigma = self.sigma_list[step]
            noise = self.noise_list[step]
            x = x_start * alpha_next_sqrt + c * pred_noise + sigma * noise

        return x

    def __call__(self, kp_cond, aud_cond, sampling_timesteps):
        """Sample a sequence conditioned on ``kp_cond`` and ``aud_cond``.

        Args:
            kp_cond: NumPy array (converted with ``torch.from_numpy`` on the
                PyTorch path).
            aud_cond: NumPy array (converted with ``torch.from_numpy`` on the
                PyTorch path).
            sampling_timesteps: Number of denoising steps.

        Returns:
            A NumPy array on the PyTorch path; on the other paths, whatever
            the backend returns as ``x_start`` at the final step.
        """
        if self.model_type == "pytorch":
            pred_kp_seq = self.model.ddim_sample(
                torch.from_numpy(kp_cond).to(self.device),
                torch.from_numpy(aud_cond).to(self.device),
                sampling_timesteps,
            ).cpu().numpy()
        else:
            pred_kp_seq = self._call_np(kp_cond, aud_cond, sampling_timesteps)
        return pred_kp_seq