SaoYear committed
Commit fe17ce1 · 1 Parent(s): d49a4f8
first commit
Browse files
- .gitignore +2 -0
- app.py +168 -0
- ckpts/CleanMel/offline_CleanMel_L_map.ckpt +3 -0
- ckpts/CleanMel/offline_CleanMel_L_mask.ckpt +3 -0
- ckpts/Vocos/vocos_offline.pt +3 -0
- configs/cleanmel_offline.yaml +67 -0
- configs/vocos_offline.yaml +44 -0
- model/cleanmel.py +401 -0
- model/stft.py +154 -0
- model/vocos/__init__.py +1 -0
- model/vocos/dataset.py +93 -0
- model/vocos/discriminators.py +211 -0
- model/vocos/experiment.py +398 -0
- model/vocos/feature_extractors.py +170 -0
- model/vocos/heads.py +164 -0
- model/vocos/helpers.py +71 -0
- model/vocos/loss.py +114 -0
- model/vocos/models.py +118 -0
- model/vocos/modules.py +213 -0
- model/vocos/pretrained.py +162 -0
- model/vocos/spectral_ops.py +192 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
__pycache__
.gradio
app.py ADDED
@@ -0,0 +1,168 @@
import torch
import spaces
import tempfile
import soundfile as sf
import gradio as gr
import librosa as lb
import yaml
import numpy as np
import matplotlib.pyplot as plt
from model.cleanmel import CleanMel
from model.vocos.pretrained import Vocos
from model.stft import InputSTFT, TargetMel

DEVICE = torch.device("cuda:5")

def read_audio(file_path):
    audio, sample_rate = sf.read(file_path)
    if audio.ndim > 1:
        audio = audio[:, 0]
    if sample_rate != 16000:
        audio = lb.resample(audio, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000

    return torch.tensor(audio).float().squeeze().unsqueeze(0)

def stft(audio):
    transform = InputSTFT(
        n_fft=512,
        n_win=512,
        n_hop=128,
        normalize=False,
        center=True,
        onesided=True,
        online=False
    ).eval().to(DEVICE)
    return transform(audio)

def mel_transform(audio, X_norm):
    transform = TargetMel(
        sample_rate=16000,
        n_fft=512,
        n_win=512,
        n_hop=128,
        n_mels=80,
        f_min=0,
        f_max=8000,
        power=2,
        center=True,
        normalize=False,
        onesided=True,
        mel_norm="slaney",
        mel_scale="slaney",
        librosa_mel=True,
        online=False
    ).eval().to(DEVICE)
    return transform(audio, X_norm)

def load_cleanmel(model_name):
    model_config = "./configs/cleanmel_offline.yaml"
    model_config = yaml.safe_load(open(model_config, "r"))["model"]["arch"]["init_args"]
    cleanmel = CleanMel(**model_config)
    cleanmel.load_state_dict(torch.load(f"./ckpts/CleanMel/{model_name}.ckpt"))
    return cleanmel.eval()

def load_vocos():
    vocos = Vocos.from_hparams(config_path="./configs/vocos_offline.yaml")
    vocos = Vocos.from_pretrained(None, model_path="./ckpts/Vocos/vocos_offline.pt", model=vocos)
    return vocos.eval()

def get_mrm_pred(Y_hat, x, X_norm):
    X_noisy = mel_transform(x, X_norm)
    Y_hat = Y_hat.squeeze()
    Y_hat = torch.square(Y_hat * (torch.sqrt(X_noisy) + 1e-10))
    return Y_hat

def safe_log(x):
    return torch.log(torch.clip(x, min=1e-5))

def output(y_hat, logMel_hat):
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
        sf.write(tmp_file.name, y_hat.squeeze().cpu().numpy(), 16000)
    with tempfile.NamedTemporaryFile(suffix='.npy', delete=False) as tmp_logmel_np_file:
        np.save(tmp_logmel_np_file.name, logMel_hat.squeeze().cpu().numpy())
    logMel_img = logMel_hat.squeeze().cpu().numpy()[::-1, :]
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_logmel_img:
        # give a plt figure size according to the logMel shape
        plt.figure(figsize=(logMel_img.shape[1] / 100, logMel_img.shape[0] / 50))
        plt.clf()
        plt.imshow(logMel_img, vmin=-11, cmap="jet")
        plt.tight_layout()
        plt.ylabel("Mel bands")
        plt.xlabel("Time (second)")
        plt.yticks([0, 80], [80, 0])
        dur = y_hat.shape[-1] / 16000
        xticks = [int(x) for x in np.linspace(0, logMel_img.shape[-1], 11)]
        xticks_str = ["{:.1f}".format(x) for x in np.linspace(0, dur, 11)]
        plt.xticks(xticks, xticks_str)
        plt.savefig(tmp_logmel_img.name)

    return tmp_file.name, tmp_logmel_img.name, tmp_logmel_np_file.name

@spaces.GPU
@torch.inference_mode()
def enhance_cleanmel_L_mask(audio_path):
    model = load_cleanmel("offline_CleanMel_L_mask").to(DEVICE)
    vocos = load_vocos().to(DEVICE)
    x = read_audio(audio_path).to(DEVICE)
    X, X_norm = stft(x)
    Y_hat = model(X, inference=True)
    MRM_hat = torch.sigmoid(Y_hat)
    Y_hat = get_mrm_pred(MRM_hat, x, X_norm)
    logMel_hat = safe_log(Y_hat)
    y_hat = vocos(logMel_hat, X_norm).clamp(min=-1, max=1)
    return output(y_hat, logMel_hat)

@spaces.GPU
@torch.inference_mode()
def enhance_cleanmel_L_map(audio_path):
    model = load_cleanmel("offline_CleanMel_L_map").to(DEVICE)
    vocos = load_vocos().to(DEVICE)
    x = read_audio(audio_path).to(DEVICE)
    X, X_norm = stft(x)
    logMel_hat = model(X, inference=True)
    y_hat = vocos(logMel_hat, X_norm).clamp(min=-1, max=1)
    return output(y_hat, logMel_hat)

def reset_everything():
    """Reset all components to initial state"""
    return None, None, None

if __name__ == "__main__":
    with gr.Blocks(title="CleanMel Demo") as demo:
        gr.Markdown("## CleanMel Demo")
        gr.Markdown("This demo showcases the CleanMel model for speech enhancement.")

        with gr.Row():
            audio_input = gr.Audio(label="Input Audio", type="filepath", sources="upload")
            with gr.Column():
                enhance_button_map = gr.Button("Enhance Audio (offline CleanMel_L_map)")
                enhance_button_mask = gr.Button("Enhance Audio (offline CleanMel_L_mask)")
                clear_btn = gr.Button(
                    "🗑️ Clear All",
                    variant="secondary",
                    size="lg"
                )

        output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
        output_mel = gr.Image(label="Output LogMel Spectrogram", type="filepath", visible=True)
        output_np = gr.File(label="Enhanced LogMel Spec. (.npy)", type="filepath")

        enhance_button_map.click(
            enhance_cleanmel_L_map,
            inputs=audio_input,
            outputs=[output_audio, output_mel, output_np]
        )

        enhance_button_mask.click(
            enhance_cleanmel_L_mask,
            inputs=audio_input,
            outputs=[output_audio, output_mel, output_np]
        )
        clear_btn.click(
            fn=reset_everything,
            outputs=[output_audio, output_mel, output_np]
        )

    demo.launch(debug=False, share=True)
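For reference, a minimal sketch of driving the same pipeline without the Gradio UI (assumptions: the ckpts/ and configs/ files from this commit are in place, a CUDA device matching DEVICE is available, and "sample.wav" is a hypothetical input file):

from app import enhance_cleanmel_L_map

wav_path, mel_png, mel_npy = enhance_cleanmel_L_map("sample.wav")  # hypothetical input
print(wav_path, mel_png, mel_npy)  # temp-file paths: enhanced audio, logMel image, logMel .npy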
ckpts/CleanMel/offline_CleanMel_L_map.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd0a5118f0c57c91c521564f8275f3a731e10a5afdba9859a3b997067e1eefdd
size 29065251
ckpts/CleanMel/offline_CleanMel_L_mask.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:42ecba65266fe11d1beb97138c1e83146a57da06daf075a367e902e151f73afa
size 29065251
ckpts/Vocos/vocos_offline.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f55c3f8cd69ab92e97578d5a99f7207ece6e2c08ccde801a8402e9fa77b72d14
size 223334266
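These three checkpoint entries are Git LFS pointer files, not the weights themselves; the binaries are fetched by LFS on clone/pull. A small sketch for verifying a fetched checkpoint against the pointer's oid (which, per the LFS spec, is the sha256 of the file bytes):

import hashlib

with open("ckpts/Vocos/vocos_offline.pt", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
# oid recorded in the pointer above
assert digest == "f55c3f8cd69ab92e97578d5a99f7207ece6e2c08ccde801a8402e9fa77b72d14"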
configs/cleanmel_offline.yaml ADDED
@@ -0,0 +1,67 @@
seed_everything: 2

trainer:
  gradient_clip_val: 10
  gradient_clip_algorithm: norm
  devices: null
  accelerator: gpu
  strategy: ddp_find_unused_parameters_false
  sync_batchnorm: false
  precision: 32
  num_sanity_val_steps: 3
  deterministic: true
  max_epochs: 100
  log_every_n_steps: 40

model:
  arch:
    class_path: model.arch.cleanmel.CleanMel
    init_args:
      dim_input: 2
      dim_output: 1
      n_layers: 16
      dim_hidden: 144
      layer_linear_freq: 1
      f_kernel_size: 5
      f_conv_groups: 8
      n_freqs: 257
      n_mels: 80
      mamba_state: 16
      mamba_conv_kernel: 4
      online: false
      sr: 16000
      n_fft: 512
  input_stft:
    class_path: model.io.stft.InputSTFT
    init_args:
      n_fft: 512
      n_win: 512
      n_hop: 128
      center: true
      normalize: false
      onesided: true
      online: false
  target_stft:
    class_path: model.io.stft.TargetMel
    init_args:
      sample_rate: 16000
      n_fft: 512
      n_win: 512
      n_hop: 128
      n_mels: 80
      f_min: 0
      f_max: 8000
      power: 2
      center: true
      normalize: false
      onesided: true
      mel_norm: "slaney"
      mel_scale: "slaney"
      librosa_mel: true
      online: false

  optimizer: [AdamW, { lr: 0.001, weight_decay: 0.001}]
  lr_scheduler: [ExponentialLR, { gamma: 0.99 }]
  exp_name: exp
  metrics: [DNSMOS]
  log_eps: 1e-5
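Note that app.py consumes only model.arch.init_args from this file; the trainer, optimizer, and metrics entries are training-time settings. A quick sketch of that read path:

import yaml

cfg = yaml.safe_load(open("configs/cleanmel_offline.yaml"))
init_args = cfg["model"]["arch"]["init_args"]
print(init_args["n_layers"], init_args["dim_hidden"])  # 16 144 for this (L) configuration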
configs/vocos_offline.yaml ADDED
@@ -0,0 +1,44 @@
feature_extractor:
  class_path: model.vocos.feature_extractors.MelSpectrogramFeatures
  init_args:
    sample_rate: 16000
    n_fft: 512
    n_win: 512
    n_hop: 128
    n_mels: 80
    f_min: 0
    f_max: 8000
    power: 2
    center: true
    normalize: false
    onesided: true
    mel_norm: slaney
    mel_scale: slaney
    librosa_mel: true
    clip_val: 0.00001
backbone:
  class_path: model.vocos.models.VocosBackbone
  init_args:
    input_channels: 80
    dim: 512
    intermediate_dim: 1536
    num_layers: 8
    layer_scale_init_value: null
    adanorm_num_embeddings: null
head:
  class_path: model.vocos.heads.ISTFTHead
  init_args:
    dim: 512
    n_fft: 512
    hop_length: 128
    padding: center
sample_rate: 16000
initial_learning_rate: 0.0005
num_warmup_steps: 0
mel_loss_coeff: 45.0
mrd_loss_coeff: 0.1
pretrain_mel_steps: 0
decay_mel_coeff: false
evaluate_utmos: true
evaluate_pesq: true
evaluate_periodicty: true
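This file is consumed by load_vocos() in app.py, which builds the module tree from the class_path/init_args entries and then loads the offline checkpoint:

from model.vocos.pretrained import Vocos

vocos = Vocos.from_hparams(config_path="./configs/vocos_offline.yaml")
vocos = Vocos.from_pretrained(None, model_path="./ckpts/Vocos/vocos_offline.pt", model=vocos)
vocos = vocos.eval()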
model/cleanmel.py ADDED
@@ -0,0 +1,401 @@
from typing import *

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning
import librosa

from torch import Tensor
from torch.nn import Parameter, init
from torch.nn.common_types import _size_1_t

from mamba_ssm import Mamba
from mamba_ssm.utils.generation import InferenceParams

class LinearGroup(nn.Module):

    def __init__(self, in_features: int, out_features: int, num_groups: int, bias: bool = True) -> None:
        super(LinearGroup, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_groups = num_groups
        self.weight = Parameter(torch.empty((num_groups, out_features, in_features)))
        if bias:
            self.bias = Parameter(torch.empty(num_groups, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # same as linear
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, x: Tensor) -> Tensor:
        """shape [..., group, feature]"""
        x = torch.einsum("...gh,gkh->...gk", x, self.weight)
        if self.bias is not None:
            x = x + self.bias
        return x

    def extra_repr(self) -> str:
        return f"{self.in_features}, {self.out_features}, num_groups={self.num_groups}, bias={True if self.bias is not None else False}"

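LinearGroup applies one independent linear map per group: the weight has shape [num_groups, out_features, in_features], and the einsum contracts the last axis separately for each group. A shape sketch (hypothetical sizes, assuming model/cleanmel.py is importable):

import torch
from model.cleanmel import LinearGroup

lg = LinearGroup(in_features=7, out_features=5, num_groups=3)
x = torch.randn(2, 10, 3, 7)  # [..., group, feature]
print(lg(x).shape)            # torch.Size([2, 10, 3, 5])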
class LayerNorm(nn.LayerNorm):

    def __init__(self, seq_last: bool, **kwargs) -> None:
        """
        Args:
            seq_last (bool): whether the sequence dim is the last dim
        """
        super().__init__(**kwargs)
        self.seq_last = seq_last

    def forward(self, input: Tensor) -> Tensor:
        if self.seq_last:
            input = input.transpose(-1, 1)  # [B, H, Seq] -> [B, Seq, H], or [B,H,w,h] -> [B,h,w,H]
        o = super().forward(input)
        if self.seq_last:
            o = o.transpose(-1, 1)
        return o

class CausalConv1d(nn.Conv1d):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t | str = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None,
        look_ahead: int = 0,
    ) -> None:
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
        self.look_ahead = look_ahead
        assert look_ahead <= self.kernel_size[0] - 1, (look_ahead, self.kernel_size)

    def forward(self, x: Tensor, state: Dict[int, Any] = None) -> Tensor:
        # x [B,H,T]
        B, H, T = x.shape
        if state is None or id(self) not in state:
            x = F.pad(x, pad=(self.kernel_size[0] - 1 - self.look_ahead, self.look_ahead))
        else:
            x = torch.concat([state[id(self)], x], dim=-1)
        if state is not None:
            # keep the last kernel_size-1 samples for the next streaming call
            state[id(self)] = x[..., -self.kernel_size[0] + 1:]
        x = super().forward(x)
        return x

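CausalConv1d shifts the padding so that each output frame sees only past context (plus look_ahead future frames), and the optional state dict carries the trailing kernel_size-1 inputs between streaming calls. A quick causality check (hypothetical sizes):

import torch
from model.cleanmel import CausalConv1d

conv = CausalConv1d(in_channels=1, out_channels=1, kernel_size=5, look_ahead=0).eval()
x = torch.randn(1, 1, 100)
y1 = conv(x)
x2 = x.clone()
x2[..., 50:] = 0.0                                  # perturb only the future
y2 = conv(x2)
print(torch.allclose(y1[..., :50], y2[..., :50]))   # True: outputs before t=50 are unchanged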
class CleanMelLayer(nn.Module):

    def __init__(
        self,
        dim_hidden: int,
        dim_squeeze: int,
        n_freqs: int,
        dropout: Tuple[float, float, float] = (0, 0, 0),
        f_kernel_size: int = 5,
        f_conv_groups: int = 8,
        padding: str = 'zeros',
        full: nn.Module = None,
        mamba_state: int = None,
        mamba_conv_kernel: int = None,
        online: bool = False,
    ) -> None:
        super().__init__()
        self.online = online
        # cross-band block
        # frequency-convolutional module
        self.fconv1 = nn.ModuleList([
            LayerNorm(seq_last=True, normalized_shape=dim_hidden),
            nn.Conv1d(in_channels=dim_hidden, out_channels=dim_hidden, kernel_size=f_kernel_size, groups=f_conv_groups, padding='same', padding_mode=padding),
            nn.PReLU(dim_hidden),
        ])
        # full-band linear module
        self.norm_full = LayerNorm(seq_last=False, normalized_shape=dim_hidden)
        self.full_share = False if full == None else True
        self.squeeze = nn.Sequential(nn.Conv1d(in_channels=dim_hidden, out_channels=dim_squeeze, kernel_size=1), nn.SiLU())
        self.dropout_full = nn.Dropout2d(dropout[2]) if dropout[2] > 0 else None
        self.full = LinearGroup(n_freqs, n_freqs, num_groups=dim_squeeze) if full == None else full
        self.unsqueeze = nn.Sequential(nn.Conv1d(in_channels=dim_squeeze, out_channels=dim_hidden, kernel_size=1), nn.SiLU())
        # frequency-convolutional module
        self.fconv2 = nn.ModuleList([
            LayerNorm(seq_last=True, normalized_shape=dim_hidden),
            nn.Conv1d(in_channels=dim_hidden, out_channels=dim_hidden, kernel_size=f_kernel_size, groups=f_conv_groups, padding='same', padding_mode=padding),
            nn.PReLU(dim_hidden),
        ])

        # narrow-band block
        self.norm_mamba = LayerNorm(seq_last=False, normalized_shape=dim_hidden)
        if online:
            self.mamba = Mamba(d_model=dim_hidden, d_state=mamba_state, d_conv=mamba_conv_kernel, layer_idx=0)
        else:
            self.mamba = nn.ModuleList([
                Mamba(d_model=dim_hidden, d_state=mamba_state, d_conv=mamba_conv_kernel, layer_idx=0),
                Mamba(d_model=dim_hidden, d_state=mamba_state, d_conv=mamba_conv_kernel, layer_idx=1),
            ])

        self.dropout_mamba = nn.Dropout(dropout[0])

    def forward(self, x: Tensor, inference: bool = False) -> Tensor:
        x = x + self._fconv(self.fconv1, x)
        x = x + self._full(x)
        x = x + self._fconv(self.fconv2, x)
        if self.online:
            x = x + self._mamba(x, self.mamba, self.norm_mamba, self.dropout_mamba, inference)
        else:
            x_fw = x + self._mamba(x, self.mamba[0], self.norm_mamba, self.dropout_mamba, inference)
            x_bw = x.flip(dims=[2]) + self._mamba(x.flip(dims=[2]), self.mamba[1], self.norm_mamba, self.dropout_mamba, inference)
            x = (x_fw + x_bw.flip(dims=[2])) / 2
        return x

    def _mamba(self, x: Tensor, mamba: Mamba, norm: nn.Module, dropout: nn.Module, inference: bool = False):
        B, F, T, H = x.shape
        x = norm(x)
        x = x.reshape(B * F, T, H)
        if inference:
            inference_params = InferenceParams(T, B * F)
            xs = []
            for i in range(T):
                inference_params.seqlen_offset = i
                xi = mamba.forward(x[:, [i], :], inference_params)
                xs.append(xi)
            x = torch.concat(xs, dim=1)
        else:
            x = mamba.forward(x)
        x = x.reshape(B, F, T, H)
        return dropout(x)

    def _fconv(self, ml: nn.ModuleList, x: Tensor) -> Tensor:
        B, F, T, H = x.shape
        x = x.permute(0, 2, 3, 1)  # [B,T,H,F]
        x = x.reshape(B * T, H, F)
        for m in ml:
            x = m(x)
        x = x.reshape(B, T, H, F)
        x = x.permute(0, 3, 1, 2)  # [B,F,T,H]
        return x

    def _full(self, x: Tensor) -> Tensor:
        B, F, T, H = x.shape
        x = self.norm_full(x)
        x = x.permute(0, 2, 3, 1)  # [B,T,H,F]
        x = x.reshape(B * T, H, F)
        x = self.squeeze(x)  # [B*T,H',F]
        if self.dropout_full:
            x = x.reshape(B, T, -1, F)
            x = x.transpose(1, 3)  # [B,F,H',T]
            x = self.dropout_full(x)  # dropout some frequencies in one utterance
            x = x.transpose(1, 3)  # [B,T,H',F]
            x = x.reshape(B * T, -1, F)
        x = self.full(x)  # [B*T,H',F]
        x = self.unsqueeze(x)  # [B*T,H,F]
        x = x.reshape(B, T, H, F)
        x = x.permute(0, 3, 1, 2)  # [B,F,T,H]
        return x

    def extra_repr(self) -> str:
        return f"full_share={self.full_share}"


class CleanMel(nn.Module):

    def __init__(
        self,
        dim_input: int,  # the input dim for each time-frequency point
        dim_output: int,  # the output dim for each time-frequency point
        n_layers: int,
        n_freqs: int,
        n_mels: int = 80,
        layer_linear_freq: int = 1,
        encoder_kernel_size: int = 5,
        dim_hidden: int = 192,
        dropout: Tuple[float, float, float] = (0, 0, 0),
        f_kernel_size: int = 5,
        f_conv_groups: int = 8,
        padding: str = 'zeros',
        mamba_state: int = 16,
        mamba_conv_kernel: int = 4,
        online: bool = True,
        sr: int = 16000,
        n_fft: int = 512,
    ):
        super().__init__()
        self.layer_linear_freq = layer_linear_freq
        self.online = online
        # encoder
        self.encoder = CausalConv1d(in_channels=dim_input, out_channels=dim_hidden, kernel_size=encoder_kernel_size, look_ahead=0)
        # cleanmel layers
        full = None
        layers = []
        for l in range(n_layers):
            layer = CleanMelLayer(
                dim_hidden=dim_hidden,
                dim_squeeze=8 if l < layer_linear_freq else dim_hidden,
                n_freqs=n_freqs if l < layer_linear_freq else n_mels,
                dropout=dropout,
                f_kernel_size=f_kernel_size,
                f_conv_groups=f_conv_groups,
                padding=padding,
                full=full if l > layer_linear_freq else None,
                online=online,
                mamba_conv_kernel=mamba_conv_kernel,
                mamba_state=mamba_state,
            )
            if hasattr(layer, 'full'):
                full = layer.full
            layers.append(layer)
        self.layers = nn.ModuleList(layers)
        # Mel filterbank
        linear2mel = librosa.filters.mel(**{"sr": sr, "n_fft": n_fft, "n_mels": n_mels})
        self.register_buffer("linear2mel", torch.nn.Parameter(torch.tensor(linear2mel.T, dtype=torch.float32)))
        # decoder
        self.decoder = nn.Linear(in_features=dim_hidden, out_features=dim_output)

    def forward(self, x: Tensor, inference: bool = False) -> Tensor:
        # x: [Batch, Freq, Time, Feature]
        B, F, T, H0 = x.shape
        x = self.encoder(x.reshape(B * F, T, H0).permute(0, 2, 1)).permute(0, 2, 1)

        H = x.shape[2]
        x = x.reshape(B, F, T, H)
        # First cross-narrow band block in linear frequency
        for i in range(self.layer_linear_freq):
            m = self.layers[i]
            x = m(x, inference).contiguous()

        # Mel-filterbank
        x = torch.einsum("bfth,fm->bmth", x, self.linear2mel)

        for i in range(self.layer_linear_freq, len(self.layers)):
            m = self.layers[i]
            x = m(x, inference).contiguous()

        y = self.decoder(x).squeeze(-1)
        return y.contiguous()

if __name__ == '__main__':
    # a quick demo here for the CleanMel model
    # input: wavs
    # output: enhanced log-mel spectrogram
    pytorch_lightning.seed_everything(1234)
    import soundfile as sf
    import matplotlib.pyplot as plt
    import numpy as np
    from model.stft import InputSTFT, TargetMel  # model/stft.py in this repo
    from torch.utils.flop_counter import FlopCounterMode

    online = False
    # Define input STFT and target Mel
    stft = InputSTFT(
        n_fft=512,
        n_win=512,
        n_hop=128,
        center=True,
        normalize=False,
        onesided=True,
        online=online).to("cuda")

    target_mel = TargetMel(
        sample_rate=16000,
        n_fft=512,
        n_win=512,
        n_hop=128,
        n_mels=80,
        f_min=0,
        f_max=8000,
        power=2,
        center=True,
        normalize=False,
        onesided=True,
        mel_norm="slaney",
        mel_scale="slaney",
        librosa_mel=True,
        online=online).to("cuda")

    def customize_soxnorm(wav, gain=-3, factor=None):
        wav = np.clip(wav, a_max=1, a_min=-1)
        if factor is None:
            linear_gain = 10 ** (gain / 20)
            factor = linear_gain / np.abs(wav).max()
            wav = wav * factor
            return wav, factor
        else:
            wav = wav * factor
            return wav, None

    # Noisy file path
    wav = "./src/demos/noisy_CHIME-real_F05_442C020S_STR_REAL.wav"
    wavname = wav.split("/")[-1].split(".")[0]

    print(f"Processing {wav}")
    noisy, fs = sf.read(wav)
    dur = len(noisy) / fs
    noisy, factor = customize_soxnorm(noisy, gain=-3)
    noisy = torch.tensor(noisy).unsqueeze(0).float().to("cuda")
    # vocos norm
    x, x_norm = stft(noisy)
    # Load the model
    hidden = 96
    depth = 8
    model = CleanMel(
        dim_input=2,
        dim_output=1,
        n_layers=depth,
        dim_hidden=hidden,
        layer_linear_freq=1,
        f_kernel_size=5,
        f_conv_groups=8,
        n_freqs=257,
        mamba_state=16,
        mamba_conv_kernel=4,
        online=online,
        sr=16000,
        n_fft=512
    ).to("cuda")

    # Load the pretrained model
    state_dict = torch.load("./pretrained/CleanMel_S_L1.ckpt")
    model.load_state_dict(state_dict)

    model.eval()
    with FlopCounterMode(model, display=False) as fcm:
        y_hat = model(x, inference=False)
        flops_forward_eval = fcm.get_total_flops()
    params_eval = sum(param.numel() for param in model.parameters())
    print(f"flops_forward={flops_forward_eval/1e9 / dur:.2f}G")
    print(f"params={params_eval/1e6:.2f} M")

    # y_hat is the enhanced log-mel spectrogram
    y_hat = y_hat[0].cpu().detach().numpy()

    # sanity check
    if wavname == "noisy_CHIME-real_F05_442C020S_STR_REAL":
        assert np.allclose(y_hat, np.load("./src/inference/check_CHIME-real_F05_442C020S_STR_REAL.npy"), atol=1e-5)

    # plot the enhanced mel spectrogram
    noisy_mel = target_mel(noisy, x_norm)
    noisy_mel = torch.log(noisy_mel.clamp(min=1e-5))[0].cpu().detach().numpy()
    vmax = math.log(1e2)
    vmin = math.log(1e-5)
    plt.figure(figsize=(8, 4))
    plt.subplot(2, 1, 1)
    plt.imshow(noisy_mel, aspect='auto', origin='lower', cmap='jet', vmax=vmax, vmin=vmin)
    plt.colorbar()
    plt.subplot(2, 1, 2)
    plt.imshow(y_hat, aspect='auto', origin='lower', cmap='jet', vmax=vmax, vmin=vmin)
    plt.colorbar()
    plt.tight_layout()
    plt.savefig(f"./src/inference/{wavname}.png")
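Inside CleanMel.forward, the mel filterbank is applied mid-network with an einsum over the frequency axis. A standalone shape sketch of that projection (hypothetical batch/time sizes):

import torch
import librosa

fb = torch.tensor(librosa.filters.mel(sr=16000, n_fft=512, n_mels=80).T, dtype=torch.float32)  # [257, 80]
x = torch.randn(1, 257, 10, 144)                    # [B, F, T, H], as inside the model
print(torch.einsum("bfth,fm->bmth", x, fb).shape)   # torch.Size([1, 80, 10, 144])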
model/stft.py ADDED
@@ -0,0 +1,154 @@
import torch
import librosa
import torch.nn as nn
import random
from torch import Tensor
from typing import Optional
from torchaudio.transforms import Spectrogram, MelScale

def soxnorm(wav: torch.Tensor, gain, factor=None):
    """sox norm, used in the Vocos codes;
    """
    wav = torch.clip(wav, max=1, min=-1).float()
    if factor is None:
        linear_gain = 10 ** (gain / 20)
        factor = linear_gain / torch.abs(wav).max().item()
        wav = wav * factor
    else:
        # for clean speech, normed by the noisy factor
        wav = wav * factor
    assert torch.all(wav.abs() <= 1), f"output waveform is not in [-1, 1], {wav.abs().max()}"
    return wav, factor

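A quick worked example of the dBFS normalization (assuming model/stft.py is importable): a -3 dB target over a 0.5 peak gives factor = 10**(-3/20)/0.5 ≈ 1.416, i.e. a new peak of ≈ 0.708.

import torch
from model.stft import soxnorm

wav = torch.tensor([0.5, -0.25, 0.1])
out, factor = soxnorm(wav, gain=-3)
print(out.abs().max().item())  # ≈ 0.7079 = 10 ** (-3 / 20)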
class InputSTFT(nn.Module):
    """
    The STFT of the input signal of CleanMel (STFT coefficients);
    in online mode, recursive normalization is used.
    """
    def __init__(
            self,
            n_fft: int,
            n_win: int,
            n_hop: int,
            center: bool,
            normalize: bool,
            onesided: bool,
            online: bool = False):
        super().__init__()

        self.online = online
        self.stft = Spectrogram(
            n_fft=n_fft,
            win_length=n_win,
            hop_length=n_hop,
            normalized=normalize,
            center=center,
            onesided=onesided,
            power=None
        )

    def forward(self, x):
        if self.online:
            # recursive normalization
            # NOTE: recursive_normalization is not defined in this file; the online
            # path assumes it is provided elsewhere in the codebase.
            x = self.stft(x)
            x_mag = x.abs()
            x_norm = recursive_normalization(x_mag)
            x = x / x_norm.clamp(min=1e-8)
            x = torch.view_as_real(x)
        else:
            # vocos dBFS normalization
            x, x_norm = soxnorm(x, random.randint(-6, -1) if self.training else -3)
            x = self.stft(x)
            x = torch.view_as_real(x)
        return x, x_norm


class LibrosaMelScale(nn.Module):
    r"""PyTorch implementation of the librosa mel scale, to align with common ESPNet ASR models."""
    def __init__(self, n_mels, sample_rate, f_min, f_max, n_stft, norm=None, mel_scale="slaney"):
        super(LibrosaMelScale, self).__init__()

        _mel_options = dict(
            sr=sample_rate,
            n_fft=(n_stft - 1) * 2,
            n_mels=n_mels,
            fmin=f_min,
            fmax=f_max if f_max is not None else float(sample_rate // 2),
            htk=mel_scale == "htk",
            norm=norm
        )

        fb = torch.from_numpy(librosa.filters.mel(**_mel_options).T).float()
        self.register_buffer("fb", fb)

    def forward(self, specgram):
        mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)
        return mel_specgram


class TargetMel(nn.Module):
    """
    This class generates the enhancement TARGET mel spectrogram;
    """
    def __init__(
            self,
            sample_rate: int,
            n_fft: int,
            n_win: int,
            n_hop: int,
            n_mels: int,
            f_min: int,
            f_max: int,
            power: int,
            center: bool,
            normalize: bool,
            onesided: bool,
            mel_norm: str | None,
            mel_scale: str,
            librosa_mel: bool = True,
            online: bool = False,
    ):
        super().__init__()
        # This implementation vs torchaudio.transforms.MelSpectrogram: adds the librosa mel scale;
        # the librosa mel scale is numerically different from the torchaudio one (x_diff > 1e-5).

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.online = online
        self.stft = Spectrogram(
            n_fft=n_fft,
            win_length=n_win,
            hop_length=n_hop,
            power=None if online else power,
            normalized=normalize,
            center=center,
            onesided=onesided,
        )
        mel_method = LibrosaMelScale if librosa_mel else MelScale
        self.mel_scale = mel_method(
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=f_min,
            f_max=f_max,
            n_stft=n_fft // 2 + 1,
            norm=mel_norm,
            mel_scale=mel_scale,
        )

    def forward(self, x: Tensor, x_norm=None):
        if self.online:
            # apply recursive normalization to the target waveform
            spectrogram = self.stft(x)
            spectrogram = spectrogram / (x_norm + 1e-8)
            spectrogram = spectrogram.abs().pow(2)  # to power spectrogram
        else:
            # apply vocos dBFS normalization to the target waveform
            x, _ = soxnorm(x, None, x_norm)
            spectrogram = self.stft(x)
        # mel spectrogram
        mel_specgram = self.mel_scale(spectrogram)
        return mel_specgram
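A shape sketch of the offline input path (assumption: 1 s of 16 kHz audio): the STFT is returned via view_as_real as [batch, freq, frames, 2], matching the [Batch, Freq, Time, Feature] layout that CleanMel.forward expects.

import torch
from model.stft import InputSTFT

stft = InputSTFT(n_fft=512, n_win=512, n_hop=128, center=True,
                 normalize=False, onesided=True, online=False).eval()
X, X_norm = stft(torch.rand(1, 16000) * 0.5)
print(X.shape, X_norm)  # torch.Size([1, 257, 126, 2]) and the sox gain factor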
model/vocos/__init__.py ADDED
@@ -0,0 +1 @@

model/vocos/dataset.py ADDED
@@ -0,0 +1,93 @@
from dataclasses import dataclass

import numpy as np
import torch
import torchaudio
import warnings
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader

torch.set_num_threads(1)


@dataclass
class DataConfig:
    filelist_path: str
    sampling_rate: int
    num_samples: int
    batch_size: int
    num_workers: int


class VocosDataModule(LightningDataModule):
    def __init__(self, train_params: DataConfig, val_params: DataConfig):
        super().__init__()
        self.train_config = train_params
        self.val_config = val_params

    def _get_dataloder(self, cfg: DataConfig, train: bool):
        dataset = VocosDataset(cfg, train=train)
        dataloader = DataLoader(
            dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        return self._get_dataloder(self.train_config, train=True)

    def val_dataloader(self) -> DataLoader:
        return self._get_dataloder(self.val_config, train=False)

    def test_dataloader(self) -> DataLoader:
        return self.val_dataloader()

class VocosDataset(Dataset):
    def __init__(self, cfg: DataConfig, train: bool):
        with open(cfg.filelist_path) as f:
            self.filelist = f.read().splitlines()
        self.sampling_rate = cfg.sampling_rate
        self.num_samples = cfg.num_samples
        self.train = train

    def __len__(self) -> int:
        return len(self.filelist)

    def customize_soxnorm(self, wav, gain=-3, factor=None):
        wav = np.clip(wav, a_max=1, a_min=-1)
        if factor is None:
            linear_gain = 10 ** (gain / 20)
            factor = linear_gain / np.abs(wav).max()  # compute the factor before scaling
            wav = wav * factor
            return wav, factor
        else:
            wav = wav * factor
            return wav, None

    def __getitem__(self, index: int) -> torch.Tensor:
        audio_path = self.filelist[index]
        try:
            y, sr = torchaudio.load(audio_path)
        except Exception:
            warnings.warn(f"Error loading {audio_path}")
            return self.__getitem__(np.random.randint(len(self.filelist)))
        if y.size(-1) == 0:
            return self.__getitem__(np.random.randint(len(self.filelist)))
        if y.size(0) > 1:
            # mix to mono
            y = y.mean(dim=0, keepdim=True)
        gain = np.random.uniform(-1, -6) if self.train else -3
        y, _ = self.customize_soxnorm(y, gain)
        if sr != self.sampling_rate:
            y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
        if y.size(-1) < self.num_samples:
            pad_length = self.num_samples - y.size(-1)
            padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
            y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
        elif self.train:
            start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
            y = y[:, start : start + self.num_samples]
        else:
            # During validation, always take the first segment for determinism
            y = y[:, : self.num_samples]

        return y[0]
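The short-clip branch above pads by tiling the clip rather than zero-padding. A toy trace (hypothetical 3-sample clip, num_samples=8):

import torch

y = torch.tensor([[1., 2., 3.]])
num_samples = 8
pad_length = num_samples - y.size(-1)                       # 5
padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))  # [[1,2,3,1,2,3]]
y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
print(y)  # tensor([[1., 2., 3., 1., 2., 3., 1., 2.]])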
model/vocos/discriminators.py ADDED
@@ -0,0 +1,211 @@
from typing import List, Optional, Tuple

import torch
from einops import rearrange
from torch import nn
from torch.nn import Conv2d
from torch.nn.utils import weight_norm
from torchaudio.transforms import Spectrogram


class MultiPeriodDiscriminator(nn.Module):
    """
    Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
    Additionally, it allows incorporating conditional information with a learned embeddings table.

    Args:
        periods (tuple[int]): Tuple of periods for each discriminator.
        num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
            Defaults to None.
    """

    def __init__(self, periods: Tuple[int, ...] = (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None):
        super().__init__()
        self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
            y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorP(nn.Module):
    def __init__(
        self,
        period: int,
        in_channels: int = 1,
        kernel_size: int = 5,
        stride: int = 3,
        lrelu_slope: float = 0.1,
        num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList(
            [
                weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
                weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
                weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
                weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
                weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))),
            ]
        )
        if num_embeddings is not None:
            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024)
            torch.nn.init.zeros_(self.emb.weight)

        self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
        self.lrelu_slope = lrelu_slope

    def forward(
        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        x = x.unsqueeze(1)
        fmap = []
        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for i, l in enumerate(self.convs):
            x = l(x)
            x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
            if i > 0:
                fmap.append(x)
        if cond_embedding_id is not None:
            emb = self.emb(cond_embedding_id)
            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
        else:
            h = 0
        x = self.conv_post(x)
        fmap.append(x)
        x += h
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiResolutionDiscriminator(nn.Module):
    def __init__(
        self,
        fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
        num_embeddings: Optional[int] = None,
    ):
        """
        Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
        Additionally, it allows incorporating conditional information with a learned embeddings table.

        Args:
            fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
            num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
                Defaults to None.
        """

        super().__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
        )

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for d in self.discriminators:
            y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
            y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorR(nn.Module):
    def __init__(
        self,
        window_length: int,
        num_embeddings: Optional[int] = None,
        channels: int = 32,
        hop_factor: float = 0.25,
        bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
    ):
        super().__init__()
        self.window_length = window_length
        self.hop_factor = hop_factor
        self.spec_fn = Spectrogram(
            n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
        )
        n_fft = window_length // 2 + 1
        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
        self.bands = bands
        convs = lambda: nn.ModuleList(
            [
                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
                weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
            ]
        )
        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])

        if num_embeddings is not None:
            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
            torch.nn.init.zeros_(self.emb.weight)

        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))

    def spectrogram(self, x):
        # Remove DC offset
        x = x - x.mean(dim=-1, keepdims=True)
        # Peak normalize the volume of input audio
        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        x = self.spec_fn(x)
        x = torch.view_as_real(x)
        x = rearrange(x, "b f t c -> b c t f")
        # Split into bands
        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
        return x_bands

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
        x_bands = self.spectrogram(x)
        fmap = []
        x = []
        for band, stack in zip(x_bands, self.band_convs):
            for i, layer in enumerate(stack):
                band = layer(band)
                band = torch.nn.functional.leaky_relu(band, 0.1)
                if i > 0:
                    fmap.append(band)
            x.append(band)
        x = torch.cat(x, dim=-1)
        if cond_embedding_id is not None:
            emb = self.emb(cond_embedding_id)
            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
        else:
            h = 0
        x = self.conv_post(x)
        fmap.append(x)
        x += h

        return x, fmap
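DiscriminatorP folds the 1-D waveform into a 2-D [frames, period] grid (reflect-padding to a multiple of the period) before its strided convolutions. A toy run (hypothetical length-10 input, period 3, so the signal is padded to 12 and viewed as 4x3):

import torch
from model.vocos.discriminators import DiscriminatorP

d = DiscriminatorP(period=3).eval()
score, fmap = d(torch.randn(1, 10))
print(score.shape, len(fmap))  # flattened scores plus 5 feature maps (4 intermediate + conv_post)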
model/vocos/experiment.py ADDED
@@ -0,0 +1,398 @@
import math

import numpy as np
import pytorch_lightning as pl
import torch
import torchaudio
import transformers
import torch.nn as nn

from model.vocos.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
from model.vocos.feature_extractors import FeatureExtractor
from model.vocos.heads import FourierHead
from model.vocos.helpers import plot_spectrogram_to_numpy
from model.vocos.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss
# from models.vocos.offline.models import Backbone
from model.vocos.modules import safe_log


class VocosExp(pl.LightningModule):
    # noinspection PyUnusedLocal
    def __init__(
        self,
        feature_extractor: FeatureExtractor,
        backbone: nn.Module,
        head: nn.Module,
        sample_rate: int,
        initial_learning_rate: float,
        num_warmup_steps: int = 0,
        mel_loss_coeff: float = 45,
        mrd_loss_coeff: float = 1.0,
        pretrain_mel_steps: int = 0,
        decay_mel_coeff: bool = False,
        evaluate_utmos: bool = False,
        evaluate_pesq: bool = False,
        evaluate_periodicty: bool = False,
    ):
        """
        Args:
            feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals.
            backbone (Backbone): An instance of Backbone model.
            head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform.
            sample_rate (int): Sampling rate of the audio signals.
            initial_learning_rate (float): Initial learning rate for the optimizer.
            num_warmup_steps (int): Number of steps for the warmup phase of learning rate scheduler. Default is 0.
            mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45.
            mrd_loss_coeff (float, optional): Coefficient for Multi Resolution Discriminator loss. Default is 1.0.
            pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0.
            decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False.
            evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run.
            evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run.
            evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run.
        """
        super().__init__()
        self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"])
        self.feature_extractor = feature_extractor
        self.backbone = backbone
        self.head = head
        self.sample_rate = sample_rate
        self.initial_learning_rate = initial_learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.mel_loss_coeff = mel_loss_coeff
        self.mrd_loss_coeff = mrd_loss_coeff
        self.pretrain_mel_steps = pretrain_mel_steps
        self.decay_mel_coeff = decay_mel_coeff
        self.evaluate_utmos = evaluate_utmos
        self.evaluate_pesq = evaluate_pesq
        self.evaluate_periodicty = evaluate_periodicty

        self.multiperioddisc = MultiPeriodDiscriminator()
        self.multiresddisc = MultiResolutionDiscriminator()

        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)

        self.train_discriminator = False
        self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff
        self.temp_cache = None
        self.temp_grad = None

    def configure_optimizers(self):
        disc_params = [
            {"params": self.multiperioddisc.parameters()},
            {"params": self.multiresddisc.parameters()},
        ]
        gen_params = [
            {"params": self.feature_extractor.parameters()},
            {"params": self.backbone.parameters()},
            {"params": self.head.parameters()},
        ]

        opt_disc = torch.optim.AdamW(disc_params, lr=self.initial_learning_rate, betas=(0.8, 0.9))
        opt_gen = torch.optim.AdamW(gen_params, lr=self.initial_learning_rate, betas=(0.8, 0.9))

        max_steps = self.trainer.max_steps // 2  # Max steps per optimizer
        scheduler_disc = transformers.get_cosine_schedule_with_warmup(
            opt_disc, num_warmup_steps=self.num_warmup_steps, num_training_steps=max_steps,
        )
        scheduler_gen = transformers.get_cosine_schedule_with_warmup(
            opt_gen, num_warmup_steps=self.num_warmup_steps, num_training_steps=max_steps,
        )

        return (
            [opt_disc, opt_gen],
            [{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}],
        )

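The schedulers above pair a warmup phase with cosine decay to zero over max_steps. A tiny sketch of the schedule shape (hypothetical 100-step run, no warmup, base lr 5e-4):

import torch
import transformers

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-4)
sched = transformers.get_cosine_schedule_with_warmup(opt, num_warmup_steps=0, num_training_steps=100)
for _ in range(50):
    opt.step()
    sched.step()
print(opt.param_groups[0]["lr"])  # ≈ 2.5e-4: halfway through the cosine decay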
    def forward(self, audio_input, **kwargs):
        features = self.feature_extractor(audio_input, **kwargs)
        x = self.backbone(features, **kwargs)
        audio_output = self.head(x)
        return audio_output

    def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
        audio_input = batch
        # train discriminator
        if optimizer_idx == 0 and self.train_discriminator:
            with torch.no_grad():
                audio_hat = self(audio_input, **kwargs)
            real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
            real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
            loss_mp, loss_mp_real, _ = self.disc_loss(
                disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
            )
            loss_mrd, loss_mrd_real, _ = self.disc_loss(
                disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
            )
            loss_mp /= len(loss_mp_real)
            loss_mrd /= len(loss_mrd_real)
            loss = loss_mp + self.mrd_loss_coeff * loss_mrd

            self.log("discriminator/total", loss, prog_bar=True)
            self.log("discriminator/multi_period_loss", loss_mp)
            self.log("discriminator/multi_res_loss", loss_mrd)
            return loss

        # train generator
        if optimizer_idx == 1:
            audio_hat = self(audio_input, **kwargs)
            if self.train_discriminator:
                _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc(
                    y=audio_input, y_hat=audio_hat, **kwargs,
                )
                _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc(
                    y=audio_input, y_hat=audio_hat, **kwargs,
                )
                loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
                loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd)
                loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp)
                loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
                loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp)
                loss_fm_mrd = self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd)

                self.log("generator/multi_period_loss", loss_gen_mp)
                self.log("generator/multi_res_loss", loss_gen_mrd)
                self.log("generator/feature_matching_mp", loss_fm_mp)
                self.log("generator/feature_matching_mrd", loss_fm_mrd)
            else:
                loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0

            mel_loss = self.melspec_loss(audio_hat, audio_input)
            loss = (
                loss_gen_mp
                + self.mrd_loss_coeff * loss_gen_mrd
                + loss_fm_mp
                + self.mrd_loss_coeff * loss_fm_mrd
                + self.mel_loss_coeff * mel_loss
            )

            self.log("generator/total_loss", loss, prog_bar=True)
            self.log("mel_loss_coeff", self.mel_loss_coeff)
            self.log("generator/mel_loss", mel_loss)

            if self.global_step % 1000 == 0 and self.global_rank == 0:
                self.logger.experiment.add_audio(
                    "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.sample_rate
                )
                self.logger.experiment.add_audio(
                    "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.sample_rate
                )
                with torch.no_grad():
                    mel = safe_log(self.melspec_loss.mel_spec(audio_input[0]))
                    mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0]))
                self.logger.experiment.add_image(
                    "train/mel_target",
                    plot_spectrogram_to_numpy(mel.data.cpu().numpy()),
                    self.global_step,
                    dataformats="HWC",
                )
                self.logger.experiment.add_image(
                    "train/mel_pred",
                    plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
                    self.global_step,
                    dataformats="HWC",
                )
            return loss

    def on_validation_epoch_start(self):
        if self.evaluate_utmos:
            from model.vocos.metrics.UTMOS import UTMOSScore
            # if not hasattr(self, "utmos_model"):
            self.utmos_model = UTMOSScore(device=self.device)

    def validation_step(self, batch, batch_idx, **kwargs):
        audio_input = batch
        audio_hat = self(audio_input, **kwargs)

        audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.sample_rate, new_freq=16000)
        audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.sample_rate, new_freq=16000)

        if self.evaluate_periodicty:
            from model.vocos.metrics.periodicity import calculate_periodicity_metrics

            periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz)
|
216 |
+
else:
|
217 |
+
periodicity_loss = pitch_loss = f1_score = 0
|
218 |
+
|
219 |
+
if self.evaluate_utmos:
|
220 |
+
utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean()
|
221 |
+
else:
|
222 |
+
utmos_score = torch.zeros(1, device=self.device)
|
223 |
+
|
224 |
+
if self.evaluate_pesq:
|
225 |
+
from pesq import pesq
|
226 |
+
|
227 |
+
pesq_score = 0
|
228 |
+
for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()):
|
229 |
+
pesq_score += pesq(16000, ref, deg, "wb", on_error=1)
|
230 |
+
pesq_score /= len(audio_16_khz)
|
231 |
+
pesq_score = torch.tensor(pesq_score)
|
232 |
+
else:
|
233 |
+
pesq_score = torch.zeros(1, device=self.device)
|
234 |
+
|
235 |
+
mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
|
236 |
+
total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score)
|
237 |
+
|
238 |
+
return {
|
239 |
+
"val_loss": total_loss,
|
240 |
+
"mel_loss": mel_loss,
|
241 |
+
"utmos_score": utmos_score,
|
242 |
+
"pesq_score": pesq_score,
|
243 |
+
"periodicity_loss": periodicity_loss,
|
244 |
+
"pitch_loss": pitch_loss,
|
245 |
+
"f1_score": f1_score,
|
246 |
+
"audio_input": audio_input[0],
|
247 |
+
"audio_pred": audio_hat[0],
|
248 |
+
}
|
249 |
+
|
250 |
+
def validation_epoch_end(self, outputs):
|
251 |
+
if self.global_rank == 0:
|
252 |
+
for i, output in enumerate(outputs):
|
253 |
+
*_, audio_in, audio_pred = output.values()
|
254 |
+
self.logger.experiment.add_audio(
|
255 |
+
f"val_in_{i}", audio_in.data.cpu().numpy(), self.global_step, self.sample_rate
|
256 |
+
)
|
257 |
+
self.logger.experiment.add_audio(
|
258 |
+
f"val_pred_{i}", audio_pred.data.cpu().numpy(), self.global_step, self.sample_rate
|
259 |
+
)
|
260 |
+
mel_target = safe_log(self.melspec_loss.mel_spec(audio_in))
|
261 |
+
mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred))
|
262 |
+
self.logger.experiment.add_image(
|
263 |
+
f"val_mel_target_{i}",
|
264 |
+
plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()),
|
265 |
+
self.global_step,
|
266 |
+
dataformats="HWC",
|
267 |
+
)
|
268 |
+
self.logger.experiment.add_image(
|
269 |
+
f"val_mel_hat_{i}",
|
270 |
+
plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
|
271 |
+
self.global_step,
|
272 |
+
dataformats="HWC",
|
273 |
+
)
|
274 |
+
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
|
275 |
+
mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean()
|
276 |
+
utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean()
|
277 |
+
pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean()
|
278 |
+
periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean()
|
279 |
+
pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean()
|
280 |
+
f1_score = np.array([x["f1_score"] for x in outputs]).mean()
|
281 |
+
|
282 |
+
self.log("val_loss", avg_loss, sync_dist=True)
|
283 |
+
self.log("val/mel_loss", mel_loss, sync_dist=True)
|
284 |
+
self.log("val/utmos_score", utmos_score, sync_dist=True)
|
285 |
+
self.log("val/pesq_score", pesq_score, sync_dist=True)
|
286 |
+
self.log("val/periodicity_loss", periodicity_loss, sync_dist=True)
|
287 |
+
self.log("val/pitch_loss", pitch_loss, sync_dist=True)
|
288 |
+
self.log("val/f1_score", f1_score, sync_dist=True)
|
289 |
+
|
290 |
+
return {
|
291 |
+
"avg_loss": avg_loss,
|
292 |
+
"mel_loss": mel_loss,
|
293 |
+
"utmos_score": utmos_score,
|
294 |
+
"pesq_score": pesq_score,
|
295 |
+
"periodicity_loss": periodicity_loss,
|
296 |
+
"pitch_loss": pitch_loss,
|
297 |
+
"f1_score": f1_score,
|
298 |
+
}
|
299 |
+
def on_test_epoch_start(self):
|
300 |
+
self.on_validation_epoch_start()
|
301 |
+
|
302 |
+
def test_step(self, *args, **kwargs):
|
303 |
+
return self.validation_step(*args, **kwargs)
|
304 |
+
|
305 |
+
def test_epoch_end(self, outputs):
|
306 |
+
results = self.validation_epoch_end(outputs)
|
307 |
+
print(results)
|
308 |
+
@property
|
309 |
+
def global_step(self):
|
310 |
+
"""
|
311 |
+
Override global_step so that it returns the total number of batches processed
|
312 |
+
"""
|
313 |
+
return self.trainer.fit_loop.epoch_loop.total_batch_idx
|
314 |
+
|
315 |
+
def on_train_batch_start(self, *args):
|
316 |
+
if self.global_step >= self.pretrain_mel_steps:
|
317 |
+
self.train_discriminator = True
|
318 |
+
else:
|
319 |
+
self.train_discriminator = False
|
320 |
+
|
321 |
+
def on_train_batch_end(self, *args):
|
322 |
+
def mel_loss_coeff_decay(current_step, num_cycles=0.5):
|
323 |
+
max_steps = self.trainer.max_steps // 2
|
324 |
+
if current_step < self.num_warmup_steps:
|
325 |
+
return 1.0
|
326 |
+
progress = float(current_step - self.num_warmup_steps) / float(
|
327 |
+
max(1, max_steps - self.num_warmup_steps)
|
328 |
+
)
|
329 |
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
330 |
+
|
331 |
+
if self.decay_mel_coeff:
|
332 |
+
self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1)
|
333 |
+
|
334 |
+
|
335 |
+
class VocosEncodecExp(VocosExp):
|
336 |
+
"""
|
337 |
+
VocosEncodecExp is a subclass of VocosExp that overrides the parent experiment to function as a conditional GAN.
|
338 |
+
It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to
|
339 |
+
a specific bandwidth value of EnCodec. During training, a random bandwidth_id is generated for each step,
|
340 |
+
while during validation, a fixed bandwidth_id is used.
|
341 |
+
"""
|
342 |
+
|
343 |
+
def __init__(
|
344 |
+
self,
|
345 |
+
feature_extractor: FeatureExtractor,
|
346 |
+
backbone: pl.LightningModule,
|
347 |
+
head: pl.LightningModule,
|
348 |
+
sample_rate: int,
|
349 |
+
initial_learning_rate: float,
|
350 |
+
num_warmup_steps: int,
|
351 |
+
mel_loss_coeff: float = 45,
|
352 |
+
mrd_loss_coeff: float = 1.0,
|
353 |
+
pretrain_mel_steps: int = 0,
|
354 |
+
decay_mel_coeff: bool = False,
|
355 |
+
evaluate_utmos: bool = False,
|
356 |
+
evaluate_pesq: bool = False,
|
357 |
+
evaluate_periodicty: bool = False,
|
358 |
+
):
|
359 |
+
super().__init__(
|
360 |
+
feature_extractor,
|
361 |
+
backbone,
|
362 |
+
head,
|
363 |
+
sample_rate,
|
364 |
+
initial_learning_rate,
|
365 |
+
num_warmup_steps,
|
366 |
+
mel_loss_coeff,
|
367 |
+
mrd_loss_coeff,
|
368 |
+
pretrain_mel_steps,
|
369 |
+
decay_mel_coeff,
|
370 |
+
evaluate_utmos,
|
371 |
+
evaluate_pesq,
|
372 |
+
evaluate_periodicty,
|
373 |
+
)
|
374 |
+
# Override with conditional discriminators
|
375 |
+
self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
|
376 |
+
self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
|
377 |
+
|
378 |
+
def training_step(self, *args):
|
379 |
+
bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,)
|
380 |
+
output = super().training_step(*args, bandwidth_id=bandwidth_id)
|
381 |
+
return output
|
382 |
+
|
383 |
+
def validation_step(self, *args):
|
384 |
+
bandwidth_id = torch.tensor([0], device=self.device)
|
385 |
+
output = super().validation_step(*args, bandwidth_id=bandwidth_id)
|
386 |
+
return output
|
387 |
+
|
388 |
+
def validation_epoch_end(self, outputs):
|
389 |
+
if self.global_rank == 0:
|
390 |
+
*_, audio_in, _ = outputs[0].values()
|
391 |
+
# Resynthesis with encodec for reference
|
392 |
+
self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0])
|
393 |
+
encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :])
|
394 |
+
self.logger.experiment.add_audio(
|
395 |
+
"encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.sample_rate,
|
396 |
+
)
|
397 |
+
|
398 |
+
super().validation_epoch_end(outputs)
|
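The `mel_loss_coeff_decay` closure above is a plain half-cosine schedule evaluated per batch, independent of the optimizers. As a sanity check, a standalone sketch of the same formula (the warmup and max-step numbers below are illustrative assumptions, not values from this repo's configs):

import math

def mel_coeff_decay(step, warmup=1000, max_steps=50000, num_cycles=0.5):
    # Mirrors mel_loss_coeff_decay: flat at 1.0 during warmup, then half-cosine down to 0.
    if step < warmup:
        return 1.0
    progress = (step - warmup) / max(1, max_steps - warmup)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

for step in (0, 1000, 25500, 50000):
    print(step, round(45 * mel_coeff_decay(step), 2))  # 45 is the default base mel coefficient
# -> 45.0 at warmup end, 22.5 at the midpoint, 0.0 at max_steps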
model/vocos/feature_extractors.py
ADDED
@@ -0,0 +1,170 @@
from typing import List

import torch
import librosa
from encodec import EncodecModel
from torch import nn
from torch import Tensor
from typing import Optional
from torchaudio.transforms import Spectrogram, MelScale
from model.vocos.modules import safe_log


class FeatureExtractor(nn.Module):
    """Base class for feature extractors."""

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Extract features from the given audio.

        Args:
            audio (Tensor): Input audio waveform.

        Returns:
            Tensor: Extracted features of shape (B, C, L), where B is the batch size,
                C denotes output features, and L is the sequence length.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class LibrosaMelScale(nn.Module):
    r"""MelScale variant whose filterbank matrix is built with librosa,
    matching the behavior of earlier torchaudio versions.
    """
    __constants__ = ["n_mels", "sample_rate", "f_min", "f_max"]

    def __init__(
        self,
        n_mels: int = 128,
        sample_rate: int = 16000,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        n_stft: int = 201,
        norm: Optional[str] = None,
        mel_scale: str = "htk",
    ) -> None:
        super(LibrosaMelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.norm = norm
        self.mel_scale = mel_scale

        if f_min > self.f_max:
            raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max))
        _mel_options = dict(
            sr=sample_rate,
            n_fft=(n_stft - 1) * 2,
            n_mels=n_mels,
            fmin=f_min,
            fmax=f_max,
            htk=mel_scale == "htk",
            norm=norm,
        )
        fb = torch.from_numpy(librosa.filters.mel(**_mel_options).T).float()
        self.register_buffer("fb", fb)

    def forward(self, specgram):
        mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)
        return mel_specgram


class MelSpectrogramFeatures(FeatureExtractor):
    def __init__(
        self,
        sample_rate: int,
        n_fft: int,
        n_win: int,
        n_hop: int,
        n_mels: int,
        f_min: int,
        f_max: int,
        power: int,
        center: bool,
        normalize: bool,
        onesided: bool,
        mel_norm: str | None,
        mel_scale: str,
        librosa_mel: bool = True,
        clip_val: float = 1e-7,
    ):
        super().__init__()
        # This implementation vs torchaudio.transforms.MelSpectrogram: adds the librosa mel scale.
        # The librosa mel scale is numerically different from the torchaudio one (x_diff > 1e-5).
        self.n_fft = n_fft
        self.spectrogram = Spectrogram(
            n_fft=n_fft,
            win_length=n_win,
            hop_length=n_hop,
            power=power,
            normalized=normalize,
            center=center,
            onesided=onesided,
        )
        mel_method = LibrosaMelScale if librosa_mel else MelScale
        self.mel_scale = mel_method(
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=f_min,
            f_max=f_max,
            n_stft=n_fft // 2 + 1,
            norm=mel_norm,
            mel_scale=mel_scale,
        )
        self.clip_val = clip_val

    def forward(self, x: Tensor) -> Tensor:
        # Compute the spectrogram, project onto the mel filterbank, and take a clipped log.
        specgram = self.spectrogram(x)
        mel_specgram = self.mel_scale(specgram)
        return safe_log(mel_specgram, self.clip_val)


class EncodecFeatures(FeatureExtractor):
    def __init__(
        self,
        encodec_model: str = "encodec_24khz",
        bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
        train_codebooks: bool = False,
    ):
        super().__init__()
        if encodec_model == "encodec_24khz":
            encodec = EncodecModel.encodec_model_24khz
        elif encodec_model == "encodec_48khz":
            encodec = EncodecModel.encodec_model_48khz
        else:
            raise ValueError(
                f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz' and 'encodec_48khz'."
            )
        self.encodec = encodec(pretrained=True)
        for param in self.encodec.parameters():
            param.requires_grad = False
        self.num_q = self.encodec.quantizer.get_num_quantizers_for_bandwidth(
            self.encodec.frame_rate, bandwidth=max(bandwidths)
        )
        codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
        self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
        self.bandwidths = bandwidths

    @torch.no_grad()
    def get_encodec_codes(self, audio):
        audio = audio.unsqueeze(1)
        emb = self.encodec.encoder(audio)
        codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
        return codes

    def forward(self, audio: torch.Tensor, **kwargs):
        bandwidth_id = kwargs.get("bandwidth_id")
        if bandwidth_id is None:
            raise ValueError("The 'bandwidth_id' argument is required")
        self.encodec.eval()  # Force eval mode as PyTorch Lightning automatically sets child modules to training mode
        self.encodec.set_target_bandwidth(self.bandwidths[bandwidth_id])
        codes = self.get_encodec_codes(audio)
        # Instead of summing in a loop, subsequent VQ dictionaries are stored in a single `self.codebook_weights`
        # with offsets given by the number of bins, and finally summed in a vectorized operation.
        offsets = torch.arange(
            0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device
        )
        embeddings_idxs = codes + offsets.view(-1, 1, 1)
        features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
        return features.transpose(1, 2)
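For orientation, a minimal sketch of `MelSpectrogramFeatures` on random audio; the parameter values below are illustrative assumptions, not taken from this repo's configs:

import torch

feats = MelSpectrogramFeatures(
    sample_rate=16000, n_fft=512, n_win=512, n_hop=128, n_mels=80,
    f_min=0, f_max=8000, power=2, center=True, normalize=False,
    onesided=True, mel_norm="slaney", mel_scale="slaney", librosa_mel=True,
)
audio = torch.randn(2, 16000)  # (B, T): two one-second clips
logmel = feats(audio)          # (B, n_mels, frames), log-compressed via safe_log
print(logmel.shape)            # torch.Size([2, 80, 126]); with center=True: T // n_hop + 1 frames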
model/vocos/heads.py
ADDED
@@ -0,0 +1,164 @@
from typing import Optional

import torch
from torch import nn
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz

from model.vocos.spectral_ops import IMDCT, ISTFT
from model.vocos.modules import symexp


class FourierHead(nn.Module):
    """Base class for inverse fourier modules."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class ISTFTHead(FourierHead):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
        # phase wrapping happens here; these two lines produce the real and imaginary parts
        x = torch.cos(p)
        y = torch.sin(p)
        # recalculating the phase does not produce anything new and only costs time:
        #   phase = torch.atan2(y, x)
        #   S = mag * torch.exp(phase * 1j)
        # it is better to produce the complex value directly
        S = mag * (x + 1j * y)
        audio = self.istft(S)
        return audio


class IMDCTSymExpHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with symmetric exponential function

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
            based on perceptual scaling. Defaults to None.
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        mdct_frame_len: int,
        padding: str = "same",
        sample_rate: Optional[int] = None,
        clip_audio: bool = False,
    ):
        super().__init__()
        out_dim = mdct_frame_len // 2
        self.out = nn.Linear(dim, out_dim)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
        self.clip_audio = clip_audio

        if sample_rate is not None:
            # optionally init the last layer following mel-scale
            m_max = _hz_to_mel(sample_rate // 2)
            m_pts = torch.linspace(0, m_max, out_dim)
            f_pts = _mel_to_hz(m_pts)
            scale = 1 - (f_pts / f_pts.max())

            with torch.no_grad():
                self.out.weight.mul_(scale.view(-1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTSymExpHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        x = symexp(x)
        x = torch.clip(x, min=-1e2, max=1e2)  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(x)
        if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the waveform, not the MDCT coefficients

        return audio


class IMDCTCosHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
        super().__init__()
        self.clip_audio = clip_audio
        self.out = nn.Linear(dim, mdct_frame_len)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTCosHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        m, p = x.chunk(2, dim=2)
        m = torch.exp(m).clip(max=1e2)  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(m * torch.cos(p))
        if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the waveform, not the network output
        return audio
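A quick shape check for the ISTFT head (the dimensions are illustrative):

import torch

head = ISTFTHead(dim=256, n_fft=512, hop_length=128, padding="same")
x = torch.randn(1, 126, 256)  # (B, L, H) backbone output
audio = head(x)               # (B, T) waveform
print(audio.shape)            # torch.Size([1, 16128]): L * hop_length with "same" padding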
model/vocos/helpers.py
ADDED
@@ -0,0 +1,71 @@
import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from pytorch_lightning import Callback

matplotlib.use("Agg")


def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
    """
    Save a matplotlib figure to a numpy array.

    Args:
        fig (Figure): Matplotlib figure object.

    Returns:
        ndarray: Numpy array representing the figure.
    """
    # np.fromstring is deprecated for binary input; np.frombuffer is the supported equivalent.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data


def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
    """
    Plot a spectrogram and convert it to a numpy array.

    Args:
        spectrogram (ndarray): Spectrogram data.

    Returns:
        ndarray: Numpy array representing the plotted spectrogram.
    """
    spectrogram = spectrogram.astype(np.float32)
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close()
    return data


class GradNormCallback(Callback):
    """
    Callback to log the gradient norm.
    """

    def on_after_backward(self, trainer, model):
        model.log("grad_norm", gradient_norm(model))


def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
    """
    Compute the gradient norm.

    Args:
        model (Module): PyTorch model.
        norm_type (float, optional): Type of the norm. Defaults to 2.0.

    Returns:
        Tensor: Gradient norm.
    """
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
    return total_norm
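The callback hooks Lightning's backward pass, so wiring it in is a one-liner (a sketch, assuming a standard `pl.Trainer` setup elsewhere in the training scripts):

import pytorch_lightning as pl

trainer = pl.Trainer(max_steps=1000, callbacks=[GradNormCallback()])
# every backward pass now logs "grad_norm", the global L2 norm over all parameter gradients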
model/vocos/loss.py
ADDED
@@ -0,0 +1,114 @@
from typing import List, Tuple

import torch
import torchaudio
from torch import nn

from model.vocos.modules import safe_log


class MelSpecReconstructionLoss(nn.Module):
    """
    L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
    """

    def __init__(
        self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100,
    ):
        super().__init__()
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1,
        )

    def forward(self, y_hat, y) -> torch.Tensor:
        """
        Args:
            y_hat (Tensor): Predicted audio waveform.
            y (Tensor): Ground truth audio waveform.

        Returns:
            Tensor: L1 loss between the mel-scaled magnitude spectrograms.
        """
        mel_hat = safe_log(self.mel_spec(y_hat))
        mel = safe_log(self.mel_spec(y))

        loss = torch.nn.functional.l1_loss(mel, mel_hat)

        return loss


class GeneratorLoss(nn.Module):
    """
    Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
    """

    def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            disc_outputs (List[Tensor]): List of discriminator outputs.

        Returns:
            Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
                the sub-discriminators
        """
        loss = torch.zeros(1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype)
        gen_losses = []
        for dg in disc_outputs:
            l = torch.mean(torch.clamp(1 - dg, min=0))
            gen_losses.append(l)
            loss += l

        return loss, gen_losses


class DiscriminatorLoss(nn.Module):
    """
    Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
    """

    def forward(
        self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        """
        Args:
            disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
            disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.

        Returns:
            Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
                the sub-discriminators for real outputs, and a list of loss values for generated outputs.
        """
        loss = torch.zeros(1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype)
        r_losses = []
        g_losses = []
        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
            r_loss = torch.mean(torch.clamp(1 - dr, min=0))
            g_loss = torch.mean(torch.clamp(1 + dg, min=0))
            loss += r_loss + g_loss
            r_losses.append(r_loss)
            g_losses.append(g_loss)

        return loss, r_losses, g_losses


class FeatureMatchingLoss(nn.Module):
    """
    Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
    """

    def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
        """
        Args:
            fmap_r (List[List[Tensor]]): List of feature maps from real samples.
            fmap_g (List[List[Tensor]]): List of feature maps from generated samples.

        Returns:
            Tensor: The calculated feature matching loss.
        """
        loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
        for dr, dg in zip(fmap_r, fmap_g):
            for rl, gl in zip(dr, dg):
                loss += torch.mean(torch.abs(rl - gl))

        return loss
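Both adversarial objectives above are hinge losses. A tiny numeric check on dummy discriminator scores (illustrative values only):

import torch

disc_loss = DiscriminatorLoss()
real = [torch.tensor([2.0, 0.5])]   # scores for real samples from one sub-discriminator
fake = [torch.tensor([-2.0, 0.5])]  # scores for generated samples
total, r_losses, g_losses = disc_loss(disc_real_outputs=real, disc_generated_outputs=fake)
# r_loss = mean(max(1 - d, 0)) = mean([0.0, 0.5]) = 0.25
# g_loss = mean(max(1 + d, 0)) = mean([0.0, 1.5]) = 0.75
print(total.item(), r_losses[0].item(), g_losses[0].item())  # 1.0 0.25 0.75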
model/vocos/models.py
ADDED
@@ -0,0 +1,118 @@
from typing import Optional

import torch
from torch import nn
from torch.nn.utils import weight_norm

from model.vocos.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm


class Backbone(nn.Module):
    """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
                C denotes output features, and L is the sequence length.

        Returns:
            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
                and H denotes the model dimension.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class VocosBackbone(Backbone):
    """
    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional model. Defaults to None.
    """

    def __init__(
        self,
        input_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
        self.convnext = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    adanorm_num_embeddings=adanorm_num_embeddings,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        bandwidth_id = kwargs.get("bandwidth_id", None)
        x = self.embed(x)
        if self.adanorm:
            assert bandwidth_id is not None
            x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
        else:
            x = self.norm(x.transpose(1, 2))
        x = x.transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x, cond_embedding_id=bandwidth_id)
        x = self.final_layer_norm(x.transpose(1, 2))
        return x


class VocosResNetBackbone(Backbone):
    """
    Vocos backbone module built with ResBlocks.

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        num_blocks (int): Number of ResBlock1 blocks.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
    """

    def __init__(
        self, input_channels, dim, num_blocks, layer_scale_init_value=None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1))
        layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
        self.resnet = nn.Sequential(
            *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)]
        )

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        x = self.embed(x)
        x = self.resnet(x)
        x = x.transpose(1, 2)
        return x
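A forward-pass shape check for the ConvNeXt backbone (small illustrative dimensions):

import torch

backbone = VocosBackbone(input_channels=80, dim=256, intermediate_dim=768, num_layers=4)
mel = torch.randn(1, 80, 126)  # (B, C, L) input features
h = backbone(mel)              # (B, L, H): the temporal resolution is preserved
print(h.shape)                 # torch.Size([1, 126, 256])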
model/vocos/modules.py
ADDED
@@ -0,0 +1,213 @@
from typing import Optional, Tuple

import torch
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm


class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: float,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # guard against layer_scale_init_value=None, which the docstring allows (means no scaling)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value is not None and layer_scale_init_value > 0
            else None
        )

    def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        if self.adanorm:
            assert cond_embedding_id is not None
            x = self.norm(x, cond_embedding_id)
        else:
            x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x


class AdaLayerNorm(nn.Module):
    """
    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes

    Args:
        num_embeddings (int): Number of embeddings.
        embedding_dim (int): Dimension of the embeddings.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = embedding_dim
        self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        torch.nn.init.ones_(self.scale.weight)
        torch.nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
        scale = self.scale(cond_embedding_id)
        shift = self.shift(cond_embedding_id)
        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
        x = x * scale + shift
        return x


class ResBlock1(nn.Module):
    """
    ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
    but without upsampling layers.

    Args:
        dim (int): Number of input channels.
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
            Defaults to (1, 3, 5).
        lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
            Defaults to 0.1.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int, int] = (1, 3, 5),
        lrelu_slope: float = 0.1,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.lrelu_slope = lrelu_slope
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=self.get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=self.get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=self.get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )

        self.convs2 = nn.ModuleList(
            [
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
            ]
        )

        self.gamma = nn.ParameterList(
            [
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
            xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
            xt = c1(xt)
            xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
            xt = c2(xt)
            if gamma is not None:
                xt = gamma * xt
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

    @staticmethod
    def get_padding(kernel_size: int, dilation: int = 1) -> int:
        return int((kernel_size * dilation - dilation) / 2)


def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """
    Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.

    Returns:
        Tensor: Element-wise logarithm of the input tensor with clipping applied.
    """
    return torch.log(torch.clip(x, min=clip_val))


def symlog(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * torch.log1p(x.abs())


def symexp(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * (torch.exp(x.abs()) - 1)
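`symlog` and `symexp` are mutual inverses that compress large magnitudes while keeping the sign, which a one-line check confirms:

import torch

x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
print(symlog(x))                             # sign(x) * log1p(|x|)
print(torch.allclose(symexp(symlog(x)), x))  # True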
model/vocos/pretrained.py
ADDED
@@ -0,0 +1,162 @@
from __future__ import annotations

from typing import Any, Dict, Tuple, Union, Optional

import torch
import yaml
from huggingface_hub import hf_hub_download
from torch import nn
from model.vocos.feature_extractors import FeatureExtractor, EncodecFeatures
from model.vocos.heads import FourierHead
from model.vocos.models import Backbone


def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Instantiates a class with the given args and init.

    Args:
        args: Positional arguments required for instantiation.
        init: Dict of the form {"class_path": ..., "init_args": ...}.

    Returns:
        The instantiated class object.
    """
    kwargs = init.get("init_args", {})
    if not isinstance(args, tuple):
        args = (args,)
    class_module, class_name = init["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    args_class = getattr(module, class_name)
    return args_class(*args, **kwargs)


class Vocos(nn.Module):
    """
    The Vocos class represents a Fourier-based neural vocoder for audio synthesis.
    This class is primarily designed for inference, with support for loading from pretrained
    model checkpoints. It consists of three main components: a feature extractor,
    a backbone, and a head.
    """

    def __init__(
        self, feature_extractor: nn.Module, backbone: Backbone, head: FourierHead,
    ):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.backbone = backbone
        self.head = head

    @classmethod
    def from_hparams(cls, config_path: str) -> "Vocos":
        """
        Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
        backbone = instantiate_class(args=(), init=config["backbone"])
        head = instantiate_class(args=(), init=config["head"])
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
        return model

    @classmethod
    def from_pretrained(cls, config_path: str, model_path: str, model: nn.Module = None) -> "Vocos":
        """
        Class method to create a new Vocos model instance from a pre-trained checkpoint stored on disk.
        """
        if model is None:
            model = cls.from_hparams(config_path)
        state_dict = torch.load(model_path, map_location="cpu")
        prefixes = ("backbone", "feature_extractor", "head")
        state_dict = {
            key: value
            for key, value in state_dict.items()
            if any(key.startswith(prefix) for prefix in prefixes)
        }
        if isinstance(model.feature_extractor, EncodecFeatures):
            encodec_parameters = {
                "feature_extractor.encodec." + key: value
                for key, value in model.feature_extractor.encodec.state_dict().items()
            }
            state_dict.update(encodec_parameters)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    @torch.inference_mode()
    def forward(self, features_input: torch.Tensor, X_norm, **kwargs: Any) -> torch.Tensor:
        """
        Method to run synthesis from already extracted features. The features are passed through
        the backbone and the head to reconstruct the audio output, which is then rescaled.

        Args:
            features_input (Tensor): The input feature tensor of shape (B, C, L), where B is the batch size,
                C denotes the feature dimension, and L is the sequence length.
            X_norm: Normalization factor of the input; the synthesized waveform is divided by it.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        audio_output = self.decode(features_input, **kwargs)
        return audio_output / X_norm

    @torch.inference_mode()
    def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to decode audio waveform from already calculated features. The features input is passed through
        the backbone and the head to reconstruct the audio output.

        Args:
            features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
                C denotes the feature dimension, and L is the sequence length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        x = self.backbone(features_input, **kwargs)
        audio_output = self.head(x)
        return audio_output

    @torch.inference_mode()
    def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
        codebook weights.

        Args:
            codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
                where K is the number of codebooks, B is the batch size and L is the sequence length.

        Returns:
            Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
                and L is the sequence length.
        """
        assert isinstance(
            self.feature_extractor, EncodecFeatures
        ), "Feature extractor should be an instance of EncodecFeatures"

        if codes.dim() == 2:
            codes = codes.unsqueeze(1)

        n_bins = self.feature_extractor.encodec.quantizer.bins
        offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
        embeddings_idxs = codes + offsets.view(-1, 1, 1)
        features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
        features = features.transpose(1, 2)

        return features


if __name__ == "__main__":
    model = Vocos.from_pretrained(
        "/nvmework3/shaonian/MelSpatialNet/MelSpatialNet/models/vocos/pretrained/pretrained_rec_normed.yaml",
        "/nvmework3/shaonian/MelSpatialNet/MelSpatialNet/models/vocos/pretrained/vocos_hop128_clip1e-5_rts.ckpt").to("meta")
    x = torch.randn(1, 80, 501)
    x = x.to("meta")
    from torch.utils.flop_counter import FlopCounterMode  # requires torch>=2.1.0
    with FlopCounterMode(model, display=False) as fcm:
        y = model.decode(x)
    flops_forward_eval = fcm.get_total_flops()

    params_eval = sum(param.numel() for param in model.parameters())
    # 501 frames * 128 hop / 16 kHz is roughly 4 s of audio, so /4e9 reports GFLOPs per second of audio
    print(f"flops_forward={flops_forward_eval/4e9:.2f}G, params={params_eval/1e6:.2f} M")
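`instantiate_class` builds objects from `{"class_path": ..., "init_args": ...}` dictionaries, mirroring the Lightning CLI config format. A self-contained sketch using a standard-library class:

cfg = {"class_path": "collections.Counter", "init_args": {"a": 2, "b": 1}}
counter = instantiate_class(args=(), init=cfg)
print(counter)  # Counter({'a': 2, 'b': 1})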
model/vocos/spectral_ops.py
ADDED
@@ -0,0 +1,192 @@
import numpy as np
import scipy
import torch
from torch import nn, view_as_real, view_as_complex


class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
        elif self.padding == "same":
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        B, N, T = spec.shape

        # Inverse FFT
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap and Add
        output_size = (T - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        )[:, 0, 0, pad:-pad]

        # Window envelope
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]

        # Normalize
        assert (window_envelope > 1e-11).all()
        y = y / window_envelope

        return y
77 |
+
|
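To make the "same"-padding contract concrete, here is a small round-trip sanity sketch (not part of the repo): the analysis STFT is padded manually by `(win - hop) // 2` with `center=False`, which is exactly the layout this ISTFT inverts. Sizes and tolerances are illustrative.

# Round-trip sanity sketch for ISTFT(padding="same"); illustrative only.
import torch

n_fft, hop, win = 512, 128, 512
istft = ISTFT(n_fft=n_fft, hop_length=hop, win_length=win, padding="same")

x = torch.randn(1, 16000)
pad = (win - hop) // 2                                   # manual "same" padding, 192 samples per side
xp = torch.nn.functional.pad(x, (pad, pad))
spec = torch.stft(xp, n_fft, hop_length=hop, win_length=win,
                  window=istft.window, center=False, return_complex=True)
y = istft(spec)
print(y.shape, torch.allclose(x, y, atol=1e-4))          # torch.Size([1, 16000]) True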
class MDCT(nn.Module):
    """
    Modified Discrete Cosine Transform (MDCT) module.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        # scipy.signal.cosine is a deprecated alias removed in recent SciPy; use scipy.signal.windows.cosine
        window = torch.from_numpy(scipy.signal.windows.cosine(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
        post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
        # view_as_real: NCCL Backend does not support ComplexFloat data type
        # https://github.com/pytorch/pytorch/issues/71613
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.

        Args:
            audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
                            and T is the length of the audio.

        Returns:
            Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
                    and N is the number of frequency bins.
        """
        if self.padding == "center":
            audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
        elif self.padding == "same":
            # hop_length is 1/2 frame_len
            audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
        N = self.frame_len // 2
        x = x * self.window.expand(x.shape)
        X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
        res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
        return torch.real(res) * np.sqrt(2)


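A quick shape sketch (illustrative, not from the repo): with `frame_len = 512` the transform uses `N = 256` frequency bins and a hop of 256, so under "same" padding a 16000-sample input produces 62 frames.

# Shape sketch for MDCT with "same" padding; numbers are illustrative.
import torch

mdct = MDCT(frame_len=512, padding="same")
audio = torch.randn(2, 16000)        # (B, T)
coeffs = mdct(audio)                 # (B, L, N) with hop = frame_len // 2 = 256 and N = 256
print(coeffs.shape)                  # torch.Size([2, 62, 256])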
class IMDCT(nn.Module):
    """
    Inverse Modified Discrete Cosine Transform (IMDCT) module.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        # scipy.signal.cosine is a deprecated alias removed in recent SciPy; use scipy.signal.windows.cosine
        window = torch.from_numpy(scipy.signal.windows.cosine(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
        post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.

        Args:
            X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
                        L is the number of frames, and N is the number of frequency bins.

        Returns:
            Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
        """
        B, L, N = X.shape
        Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
        Y[..., :N] = X
        Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
        y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
        y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
        result = y * self.window.expand(y.shape)
        output_size = (1, (L + 1) * N)
        audio = torch.nn.functional.fold(
            result.transpose(1, 2),
            output_size=output_size,
            kernel_size=(1, self.frame_len),
            stride=(1, self.frame_len // 2),
        )[:, 0, 0, :]

        if self.padding == "center":
            pad = self.frame_len // 2
        elif self.padding == "same":
            pad = self.frame_len // 4
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        audio = audio[:, pad:-pad]
        return audio
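And a round-trip sketch for the pair (again illustrative): `IMDCT(MDCT(x))` should recover the interior of `x`, while the first and last half-frame lack time-domain alias cancellation and any trailing samples that do not fill a whole hop are dropped by the framing.

# MDCT/IMDCT round-trip sketch; edge regions are excluded from the comparison.
import torch

mdct = MDCT(frame_len=512, padding="same")
imdct = IMDCT(frame_len=512, padding="same")
x = torch.randn(1, 16000)
x_hat = imdct(mdct(x))               # (1, 15872): 128 trailing samples dropped by framing
err = (x[:, 256:15360] - x_hat[:, 256:15360]).abs().max()
print(x_hat.shape, err)              # interior error should be near float precision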