talkingAvater_bgk / core /models /stitch_network.py
oKen38461's picture
初回コミットに基づくファイルの追加
ac7cda5
import numpy as np
import torch
from ..utils.load_model import load_model
class StitchNetwork:
def __init__(self, model_path, device="cuda"):
kwargs = {
"module_name": "StitchingNetwork",
}
self.model, self.model_type = load_model(model_path, device=device, **kwargs)
self.device = device
def __call__(self, kp_source, kp_driving):
if self.model_type == "onnx":
pred = self.model.run(None, {"kp_source": kp_source, "kp_driving": kp_driving})[0]
elif self.model_type == "tensorrt":
self.model.setup({"kp_source": kp_source, "kp_driving": kp_driving})
self.model.infer()
pred = self.model.buffer["out"][0].copy()
elif self.model_type == 'pytorch':
with torch.no_grad():
pred = self.model(
torch.from_numpy(kp_source).to(self.device),
torch.from_numpy(kp_driving).to(self.device)
).cpu().numpy()
else:
raise ValueError(f"Unsupported model type: {self.model_type}")
return pred