SaoYear committed on
Commit fe17ce1 · 1 Parent(s): d49a4f8

first commit

.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ .gradio
app.py ADDED
@@ -0,0 +1,168 @@
+ import torch
+ import spaces
+ import tempfile
+ import soundfile as sf
+ import gradio as gr
+ import librosa as lb
+ import yaml
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from model.cleanmel import CleanMel
+ from model.vocos.pretrained import Vocos
+ from model.stft import InputSTFT, TargetMel
+
+ DEVICE = torch.device("cuda")  # single visible GPU when running under @spaces.GPU
+
+ def read_audio(file_path):
+     audio, sample_rate = sf.read(file_path)
+     if audio.ndim > 1:
+         audio = audio[:, 0]
+     if sample_rate != 16000:
+         audio = lb.resample(audio, orig_sr=sample_rate, target_sr=16000)
+         sample_rate = 16000
+
+     return torch.tensor(audio).float().squeeze().unsqueeze(0)
+
+ def stft(audio):
+     transform = InputSTFT(
+         n_fft=512,
+         n_win=512,
+         n_hop=128,
+         normalize=False,
+         center=True,
+         onesided=True,
+         online=False
+     ).eval().to(DEVICE)
+     return transform(audio)
+
+ def mel_transform(audio, X_norm):
+     transform = TargetMel(
+         sample_rate=16000,
+         n_fft=512,
+         n_win=512,
+         n_hop=128,
+         n_mels=80,
+         f_min=0,
+         f_max=8000,
+         power=2,
+         center=True,
+         normalize=False,
+         onesided=True,
+         mel_norm="slaney",
+         mel_scale="slaney",
+         librosa_mel=True,
+         online=False
+     ).eval().to(DEVICE)
+     return transform(audio, X_norm)
+
+ def load_cleanmel(model_name):
+     model_config = f"./configs/cleanmel_offline.yaml"
+     model_config = yaml.safe_load(open(model_config, "r"))["model"]["arch"]["init_args"]
+     cleanmel = CleanMel(**model_config)
+     cleanmel.load_state_dict(torch.load(f"./ckpts/CleanMel/{model_name}.ckpt"))
+     return cleanmel.eval()
+
+ def load_vocos():
+     vocos = Vocos.from_hparams(config_path="./configs/vocos_offline.yaml")
+     vocos = Vocos.from_pretrained(None, model_path=f"./ckpts/Vocos/vocos_offline.pt", model=vocos)
+     return vocos.eval()
+
+ def get_mrm_pred(Y_hat, x, X_norm):
+     X_noisy = mel_transform(x, X_norm)
+     Y_hat = Y_hat.squeeze()
+     Y_hat = torch.square(Y_hat * (torch.sqrt(X_noisy) + 1e-10))
+     return Y_hat
+
+ def safe_log(x):
+     return torch.log(torch.clip(x, min=1e-5))
+
+ def output(y_hat, logMel_hat):
+     with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
+         sf.write(tmp_file.name, y_hat.squeeze().cpu().numpy(), 16000)
+     with tempfile.NamedTemporaryFile(suffix='.npy', delete=False) as tmp_logmel_np_file:
+         np.save(tmp_logmel_np_file.name, logMel_hat.squeeze().cpu().numpy())
+     logMel_img = logMel_hat.squeeze().cpu().numpy()[::-1, :]
+     with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_logmel_img:
+         # size the figure according to the logMel shape
+         plt.figure(figsize=(logMel_img.shape[1] / 100, logMel_img.shape[0] / 50))
+         plt.clf()
+         plt.imshow(logMel_img, vmin=-11, cmap="jet")
+         plt.tight_layout()
+         plt.ylabel("Mel bands")
+         plt.xlabel("Time (s)")
+         plt.yticks([0, 80], [80, 0])
+         dur = y_hat.shape[-1] / 16000
+         xticks = [int(x) for x in np.linspace(0, logMel_img.shape[-1], 11)]
+         xticks_str = ["{:.1f}".format(x) for x in np.linspace(0, dur, 11)]
+         plt.xticks(xticks, xticks_str)
+         plt.savefig(tmp_logmel_img.name)
+
+     return tmp_file.name, tmp_logmel_img.name, tmp_logmel_np_file.name
+
+ @spaces.GPU
+ @torch.inference_mode()
+ def enhance_cleanmel_L_mask(audio_path):
+     model = load_cleanmel("offline_CleanMel_L_mask").to(DEVICE)
+     vocos = load_vocos().to(DEVICE)
+     x = read_audio(audio_path).to(DEVICE)
+     X, X_norm = stft(x)
+     Y_hat = model(X, inference=True)
+     MRM_hat = torch.sigmoid(Y_hat)
+     Y_hat = get_mrm_pred(MRM_hat, x, X_norm)
+     logMel_hat = safe_log(Y_hat)
+     y_hat = vocos(logMel_hat, X_norm).clamp(min=-1, max=1)
+     return output(y_hat, logMel_hat)
+
+ @spaces.GPU
+ @torch.inference_mode()
+ def enhance_cleanmel_L_map(audio_path):
+     model = load_cleanmel("offline_CleanMel_L_map").to(DEVICE)
+     vocos = load_vocos().to(DEVICE)
+     x = read_audio(audio_path).to(DEVICE)
+     X, X_norm = stft(x)
+     logMel_hat = model(X, inference=True)
+     y_hat = vocos(logMel_hat, X_norm).clamp(min=-1, max=1)
+     return output(y_hat, logMel_hat)
+
+ def reset_everything():
+     """Reset all components to initial state"""
+     return None, None, None
+
+ if __name__ == "__main__":
+     demo = gr.Blocks()
+     with gr.Blocks(title="CleanMel Demo") as demo:
+         gr.Markdown("## CleanMel Demo")
+         gr.Markdown("This demo showcases the CleanMel model for speech enhancement.")
+
+         with gr.Row():
+             audio_input = gr.Audio(label="Input Audio", type="filepath", sources="upload")
+             with gr.Column():
+                 enhance_button_map = gr.Button("Enhance Audio (offline CleanMel_L_map)")
+                 enhance_button_mask = gr.Button("Enhance Audio (offline CleanMel_L_mask)")
+                 clear_btn = gr.Button(
+                     "🗑️ Clear All",
+                     variant="secondary",
+                     size="lg"
+                 )
+
+         output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
+         output_mel = gr.Image(label="Output LogMel Spectrogram", type="filepath", visible=True)
+         output_np = gr.File(label="Enhanced LogMel Spec. (.npy)", type="filepath")
+
+         enhance_button_map.click(
+             enhance_cleanmel_L_map,
+             inputs=audio_input,
+             outputs=[output_audio, output_mel, output_np]
+         )
+
+         enhance_button_mask.click(
+             enhance_cleanmel_L_mask,
+             inputs=audio_input,
+             outputs=[output_audio, output_mel, output_np]
+         )
+         clear_btn.click(
+             fn=reset_everything,
+             outputs=[output_audio, output_mel, output_np]
+         )
+
+     demo.launch(debug=False, share=True)
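
The two enhancement paths above differ only in how the network output is interpreted: enhance_cleanmel_L_map treats the output as the enhanced log-mel spectrogram directly, while enhance_cleanmel_L_mask passes it through a sigmoid to get a ratio mask, applies it to the noisy mel magnitude in get_mrm_pred, and then takes the log. A minimal sketch of calling the pipeline outside the Gradio UI, assuming the checkpoints and configs added in this commit are in place and a CUDA device is available ("noisy.wav" is a placeholder input path):

# Run the demo's enhancement pipeline from a script instead of the Gradio UI.
from app import enhance_cleanmel_L_map  # enhance_cleanmel_L_mask works the same way

wav_path, mel_png_path, mel_npy_path = enhance_cleanmel_L_map("noisy.wav")
print(wav_path)      # enhanced 16 kHz waveform written to a temporary .wav file
print(mel_png_path)  # rendered enhanced log-mel spectrogram (.png)
print(mel_npy_path)  # raw enhanced log-mel spectrogram (.npy)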
ckpts/CleanMel/offline_CleanMel_L_map.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dd0a5118f0c57c91c521564f8275f3a731e10a5afdba9859a3b997067e1eefdd
+ size 29065251
ckpts/CleanMel/offline_CleanMel_L_mask.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:42ecba65266fe11d1beb97138c1e83146a57da06daf075a367e902e151f73afa
+ size 29065251
ckpts/Vocos/vocos_offline.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f55c3f8cd69ab92e97578d5a99f7207ece6e2c08ccde801a8402e9fa77b72d14
+ size 223334266
configs/cleanmel_offline.yaml ADDED
@@ -0,0 +1,67 @@
+ seed_everything: 2
+
+ trainer:
+   gradient_clip_val: 10
+   gradient_clip_algorithm: norm
+   devices: null
+   accelerator: gpu
+   strategy: ddp_find_unused_parameters_false
+   sync_batchnorm: false
+   precision: 32
+   num_sanity_val_steps: 3
+   deterministic: true
+   max_epochs: 100
+   log_every_n_steps: 40
+
+ model:
+   arch:
+     class_path: model.arch.cleanmel.CleanMel
+     init_args:
+       dim_input: 2
+       dim_output: 1
+       n_layers: 16
+       dim_hidden: 144
+       layer_linear_freq: 1
+       f_kernel_size: 5
+       f_conv_groups: 8
+       n_freqs: 257
+       n_mels: 80
+       mamba_state: 16
+       mamba_conv_kernel: 4
+       online: false
+       sr: 16000
+       n_fft: 512
+   input_stft:
+     class_path: model.io.stft.InputSTFT
+     init_args:
+       n_fft: 512
+       n_win: 512
+       n_hop: 128
+       center: true
+       normalize: false
+       onesided: true
+       online: false
+   target_stft:
+     class_path: model.io.stft.TargetMel
+     init_args:
+       sample_rate: 16000
+       n_fft: 512
+       n_win: 512
+       n_hop: 128
+       n_mels: 80
+       f_min: 0
+       f_max: 8000
+       power: 2
+       center: true
+       normalize: false
+       onesided: true
+       mel_norm: "slaney"
+       mel_scale: "slaney"
+       librosa_mel: true
+       online: false
+
+   optimizer: [AdamW, { lr: 0.001, weight_decay: 0.001}]
+   lr_scheduler: [ExponentialLR, { gamma: 0.99 }]
+   exp_name: exp
+   metrics: [DNSMOS]
+   log_eps: 1e-5
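
Only the model.arch.init_args block of this config is consumed by the demo: load_cleanmel in app.py reads it and passes it straight to the CleanMel constructor. A minimal sketch of that path, using the paths added in this commit:

import yaml
from model.cleanmel import CleanMel

with open("./configs/cleanmel_offline.yaml", "r") as f:
    init_args = yaml.safe_load(f)["model"]["arch"]["init_args"]
model = CleanMel(**init_args).eval()  # offline variant: bidirectional (non-causal) Mamba layers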
configs/vocos_offline.yaml ADDED
@@ -0,0 +1,44 @@
+ feature_extractor:
+   class_path: model.vocos.feature_extractors.MelSpectrogramFeatures
+   init_args:
+     sample_rate: 16000
+     n_fft: 512
+     n_win: 512
+     n_hop: 128
+     n_mels: 80
+     f_min: 0
+     f_max: 8000
+     power: 2
+     center: true
+     normalize: false
+     onesided: true
+     mel_norm: slaney
+     mel_scale: slaney
+     librosa_mel: true
+     clip_val: 0.00001
+ backbone:
+   class_path: model.vocos.models.VocosBackbone
+   init_args:
+     input_channels: 80
+     dim: 512
+     intermediate_dim: 1536
+     num_layers: 8
+     layer_scale_init_value: null
+     adanorm_num_embeddings: null
+ head:
+   class_path: model.vocos.heads.ISTFTHead
+   init_args:
+     dim: 512
+     n_fft: 512
+     hop_length: 128
+     padding: center
+ sample_rate: 16000
+ initial_learning_rate: 0.0005
+ num_warmup_steps: 0
+ mel_loss_coeff: 45.0
+ mrd_loss_coeff: 0.1
+ pretrain_mel_steps: 0
+ decay_mel_coeff: false
+ evaluate_utmos: true
+ evaluate_pesq: true
+ evaluate_periodicty: true
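
This config describes the Vocos vocoder that inverts the enhanced log-mel spectrogram back to a waveform. A sketch mirroring load_vocos in app.py, with the checkpoint path added in this commit:

from model.vocos.pretrained import Vocos

vocos = Vocos.from_hparams(config_path="./configs/vocos_offline.yaml")
vocos = Vocos.from_pretrained(None, model_path="./ckpts/Vocos/vocos_offline.pt", model=vocos)
vocos = vocos.eval()  # called as vocos(logMel_hat, X_norm) in app.py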
model/cleanmel.py ADDED
@@ -0,0 +1,401 @@
+ from typing import *
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import pytorch_lightning
8
+ import librosa
9
+
10
+ from torch import Tensor
11
+ from torch.nn import Parameter, init
12
+ from torch.nn.common_types import _size_1_t
13
+
14
+ from mamba_ssm import Mamba
15
+ from mamba_ssm.utils.generation import InferenceParams
16
+
17
+ class LinearGroup(nn.Module):
18
+
19
+ def __init__(self, in_features: int, out_features: int, num_groups: int, bias: bool = True) -> None:
20
+ super(LinearGroup, self).__init__()
21
+ self.in_features = in_features
22
+ self.out_features = out_features
23
+ self.num_groups = num_groups
24
+ self.weight = Parameter(torch.empty((num_groups, out_features, in_features)))
25
+ if bias:
26
+ self.bias = Parameter(torch.empty(num_groups, out_features))
27
+ else:
28
+ self.register_parameter('bias', None)
29
+ self.reset_parameters()
30
+
31
+ def reset_parameters(self) -> None:
32
+ # same as linear
33
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
34
+ if self.bias is not None:
35
+ fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
36
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
37
+ init.uniform_(self.bias, -bound, bound)
38
+
39
+ def forward(self, x: Tensor) -> Tensor:
40
+ """shape [..., group, feature]"""
41
+ x = torch.einsum("...gh,gkh->...gk", x, self.weight)
42
+ if self.bias is not None:
43
+ x = x + self.bias
44
+ return x
45
+
46
+ def extra_repr(self) -> str:
47
+ return f"{self.in_features}, {self.out_features}, num_groups={self.num_groups}, bias={True if self.bias is not None else False}"
48
+
49
+ class LayerNorm(nn.LayerNorm):
50
+
51
+ def __init__(self, seq_last: bool, **kwargs) -> None:
52
+ """
53
+ Arg s:
54
+ seq_last (bool): whether the sequence dim is the last dim
55
+ """
56
+ super().__init__(**kwargs)
57
+ self.seq_last = seq_last
58
+
59
+ def forward(self, input: Tensor) -> Tensor:
60
+ if self.seq_last:
61
+ input = input.transpose(-1, 1) # [B, H, Seq] -> [B, Seq, H], or [B,H,w,h] -> [B,h,w,H]
62
+ o = super().forward(input)
63
+ if self.seq_last:
64
+ o = o.transpose(-1, 1)
65
+ return o
66
+
67
+ class CausalConv1d(nn.Conv1d):
68
+
69
+ def __init__(
70
+ self,
71
+ in_channels: int,
72
+ out_channels: int,
73
+ kernel_size: _size_1_t,
74
+ stride: _size_1_t = 1,
75
+ padding: _size_1_t | str = 0,
76
+ dilation: _size_1_t = 1,
77
+ groups: int = 1,
78
+ bias: bool = True,
79
+ padding_mode: str = 'zeros',
80
+ device=None,
81
+ dtype=None,
82
+ look_ahead: int = 0,
83
+ ) -> None:
84
+ super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
85
+ self.look_ahead = look_ahead
86
+ assert look_ahead <= self.kernel_size[0] - 1, (look_ahead, self.kernel_size)
87
+
88
+ def forward(self, x: Tensor, state: Dict[int, Any] = None) -> Tensor:
89
+ # x [B,H,T]
90
+ B, H, T = x.shape
91
+ if state is None or id(self) not in state:
92
+ x = F.pad(x, pad=(self.kernel_size[0] - 1 - self.look_ahead, self.look_ahead))
93
+ else:
94
+ x = torch.concat([state[id(self)], x], dim=-1)
95
+ if state is not None:
96
+ state[id(self)] = x[..., -self.kernel_size[0] + 1:]
97
+ x = super().forward(x)
98
+ return x
99
+
100
+ class CleanMelLayer(nn.Module):
101
+
102
+ def __init__(
103
+ self,
104
+ dim_hidden: int,
105
+ dim_squeeze: int,
106
+ n_freqs: int,
107
+ dropout: Tuple[float, float, float] = (0, 0, 0),
108
+ f_kernel_size: int = 5,
109
+ f_conv_groups: int = 8,
110
+ padding: str = 'zeros',
111
+ full: nn.Module = None,
112
+ mamba_state: int = None,
113
+ mamba_conv_kernel: int = None,
114
+ online: bool = False,
115
+ ) -> None:
116
+ super().__init__()
117
+ self.online = online
118
+ # cross-band block
119
+ # frequency-convolutional module
120
+ self.fconv1 = nn.ModuleList([
121
+ LayerNorm(seq_last=True, normalized_shape=dim_hidden),
122
+ nn.Conv1d(in_channels=dim_hidden, out_channels=dim_hidden, kernel_size=f_kernel_size, groups=f_conv_groups, padding='same', padding_mode=padding),
123
+ nn.PReLU(dim_hidden),
124
+ ])
125
+ # full-band linear module
126
+ self.norm_full = LayerNorm(seq_last=False, normalized_shape=dim_hidden)
127
+ self.full_share = False if full == None else True
128
+ self.squeeze = nn.Sequential(nn.Conv1d(in_channels=dim_hidden, out_channels=dim_squeeze, kernel_size=1), nn.SiLU())
129
+ self.dropout_full = nn.Dropout2d(dropout[2]) if dropout[2] > 0 else None
130
+ self.full = LinearGroup(n_freqs, n_freqs, num_groups=dim_squeeze) if full == None else full
131
+ self.unsqueeze = nn.Sequential(nn.Conv1d(in_channels=dim_squeeze, out_channels=dim_hidden, kernel_size=1), nn.SiLU())
132
+ # frequency-convolutional module
133
+ self.fconv2 = nn.ModuleList([
134
+ LayerNorm(seq_last=True, normalized_shape=dim_hidden),
135
+ nn.Conv1d(in_channels=dim_hidden, out_channels=dim_hidden, kernel_size=f_kernel_size, groups=f_conv_groups, padding='same', padding_mode=padding),
136
+ nn.PReLU(dim_hidden),
137
+ ])
138
+
139
+ # narrow-band block
140
+ self.norm_mamba = LayerNorm(seq_last=False, normalized_shape=dim_hidden)
141
+ if online:
142
+ self.mamba = Mamba(d_model=dim_hidden, d_state=mamba_state, d_conv=mamba_conv_kernel, layer_idx=0)
143
+ else:
144
+ self.mamba = nn.ModuleList([
145
+ Mamba(d_model=dim_hidden, d_state=mamba_state, d_conv=mamba_conv_kernel, layer_idx=0),
146
+ Mamba(d_model=dim_hidden, d_state=mamba_state, d_conv=mamba_conv_kernel, layer_idx=1),
147
+ ])
148
+
149
+ self.dropout_mamba = nn.Dropout(dropout[0])
150
+
151
+ def forward(self, x: Tensor, inference: bool = False) -> Tensor:
152
+ x = x + self._fconv(self.fconv1, x)
153
+ x = x + self._full(x)
154
+ x = x + self._fconv(self.fconv2, x)
155
+ if self.online:
156
+ x = x + self._mamba(x, self.mamba, self.norm_mamba, self.dropout_mamba, inference)
157
+ else:
158
+ x_fw = x + self._mamba(x, self.mamba[0], self.norm_mamba, self.dropout_mamba, inference)
159
+ x_bw = x.flip(dims=[2]) + self._mamba(x.flip(dims=[2]), self.mamba[1], self.norm_mamba, self.dropout_mamba, inference)
160
+ x = (x_fw + x_bw.flip(dims=[2])) / 2
161
+ return x
162
+
163
+ def _mamba(self, x: Tensor, mamba: Mamba, norm: nn.Module, dropout: nn.Module, inference: bool = False):
164
+ B, F, T, H = x.shape
165
+ x = norm(x)
166
+ x = x.reshape(B * F, T, H)
167
+ if inference:
168
+ inference_params = InferenceParams(T, B * F)
169
+ xs = []
170
+ for i in range(T):
171
+ inference_params.seqlen_offset = i
172
+ xi = mamba.forward(x[:, [i], :], inference_params)
173
+ xs.append(xi)
174
+ x = torch.concat(xs, dim=1)
175
+ else:
176
+ x = mamba.forward(x)
177
+ x = x.reshape(B, F, T, H)
178
+ return dropout(x)
179
+
180
+ def _fconv(self, ml: nn.ModuleList, x: Tensor) -> Tensor:
181
+ B, F, T, H = x.shape
182
+ x = x.permute(0, 2, 3, 1) # [B,T,H,F]
183
+ x = x.reshape(B * T, H, F)
184
+ for m in ml:
185
+ x = m(x)
186
+ x = x.reshape(B, T, H, F)
187
+ x = x.permute(0, 3, 1, 2) # [B,F,T,H]
188
+ return x
189
+
190
+ def _full(self, x: Tensor) -> Tensor:
191
+ B, F, T, H = x.shape
192
+ x = self.norm_full(x)
193
+ x = x.permute(0, 2, 3, 1) # [B,T,H,F]
194
+ x = x.reshape(B * T, H, F)
195
+ x = self.squeeze(x) # [B*T,H',F]
196
+ if self.dropout_full:
197
+ x = x.reshape(B, T, -1, F)
198
+ x = x.transpose(1, 3) # [B,F,H',T]
199
+ x = self.dropout_full(x) # dropout some frequencies in one utterance
200
+ x = x.transpose(1, 3) # [B,T,H',F]
201
+ x = x.reshape(B * T, -1, F)
202
+ x = self.full(x) # [B*T,H',F]
203
+ x = self.unsqueeze(x) # [B*T,H,F]
204
+ x = x.reshape(B, T, H, F)
205
+ x = x.permute(0, 3, 1, 2) # [B,F,T,H]
206
+ return x
207
+
208
+ def extra_repr(self) -> str:
209
+ return f"full_share={self.full_share}"
210
+
211
+
212
+ class CleanMel(nn.Module):
213
+
214
+ def __init__(
215
+ self,
216
+ dim_input: int, # the input dim for each time-frequency point
217
+ dim_output: int, # the output dim for each time-frequency point
218
+ n_layers: int,
219
+ n_freqs: int,
220
+ n_mels: int = 80,
221
+ layer_linear_freq: int = 1,
222
+ encoder_kernel_size: int = 5,
223
+ dim_hidden: int = 192,
224
+ dropout: Tuple[float, float, float] = (0, 0, 0),
225
+ f_kernel_size: int = 5,
226
+ f_conv_groups: int = 8,
227
+ padding: str = 'zeros',
228
+ mamba_state: int = 16,
229
+ mamba_conv_kernel: int = 4,
230
+ online: bool = True,
231
+ sr: int = 16000,
232
+ n_fft: int = 512,
233
+ ):
234
+ super().__init__()
235
+ self.layer_linear_freq = layer_linear_freq
236
+ self.online = online
237
+ # encoder
238
+ self.encoder = CausalConv1d(in_channels=dim_input, out_channels=dim_hidden, kernel_size=encoder_kernel_size, look_ahead=0)
239
+ # cleanmel layers
240
+ full = None
241
+ layers = []
242
+ for l in range(n_layers):
243
+ layer = CleanMelLayer(
244
+ dim_hidden=dim_hidden,
245
+ dim_squeeze=8 if l < layer_linear_freq else dim_hidden,
246
+ n_freqs=n_freqs if l < layer_linear_freq else n_mels,
247
+ dropout=dropout,
248
+ f_kernel_size=f_kernel_size,
249
+ f_conv_groups=f_conv_groups,
250
+ padding=padding,
251
+ full=full if l > layer_linear_freq else None,
252
+ online=online,
253
+ mamba_conv_kernel=mamba_conv_kernel,
254
+ mamba_state=mamba_state,
255
+ )
256
+ if hasattr(layer, 'full'):
257
+ full = layer.full
258
+ layers.append(layer)
259
+ self.layers = nn.ModuleList(layers)
260
+ # Mel filterbank
261
+ linear2mel = librosa.filters.mel(**{"sr": sr, "n_fft": n_fft, "n_mels": n_mels})
262
+ self.register_buffer("linear2mel", torch.nn.Parameter(torch.tensor(linear2mel.T, dtype=torch.float32)))
263
+ # decoder
264
+ self.decoder = nn.Linear(in_features=dim_hidden, out_features=dim_output)
265
+
266
+ def forward(self, x: Tensor, inference: bool = False) -> Tensor:
267
+ # x: [Batch, Freq, Time, Feature]
268
+ B, F, T, H0 = x.shape
269
+ x = self.encoder(x.reshape(B * F, T, H0).permute(0, 2, 1)).permute(0, 2, 1)
270
+
271
+ H = x.shape[2]
272
+ x = x.reshape(B, F, T, H)
273
+ # First Cross-Narrow band block in Linear Frequency
274
+ for i in range(self.layer_linear_freq):
275
+ m = self.layers[i]
276
+ x = m(x, inference).contiguous()
277
+
278
+ # Mel-filterbank
279
+ x = torch.einsum("bfth,fm->bmth", x, self.linear2mel)
280
+
281
+ for i in range(self.layer_linear_freq, len(self.layers)):
282
+ m = self.layers[i]
283
+ x = m(x, inference).contiguous()
284
+
285
+ y = self.decoder(x).squeeze(-1)
286
+ return y.contiguous()
287
+
288
+ if __name__ == '__main__':
289
+ # a quick demo here for the CleanMel model
290
+ # input: wavs
291
+ # output: enhanced log-mel spectrogram
292
+ pytorch_lightning.seed_everything(1234)
293
+ import soundfile as sf
294
+ import matplotlib.pyplot as plt
295
+ import numpy as np
296
+ from model.io.stft import InputSTFT
297
+ from model.io.stft import TargetMel
298
+ from torch.utils.flop_counter import FlopCounterMode
299
+
300
+ online=False
301
+ # Define input STFT and target Mel
302
+ stft = InputSTFT(
303
+ n_fft=512,
304
+ n_win=512,
305
+ n_hop=128,
306
+ center=True,
307
+ normalize=False,
308
+ onesided=True,
309
+ online=online).to("cuda")
310
+
311
+ target_mel = TargetMel(
312
+ sample_rate=16000,
313
+ n_fft=512,
314
+ n_win=512,
315
+ n_hop=128,
316
+ n_mels=80,
317
+ f_min=0,
318
+ f_max=8000,
319
+ power=2,
320
+ center=True,
321
+ normalize=False,
322
+ onesided=True,
323
+ mel_norm="slaney",
324
+ mel_scale="slaney",
325
+ librosa_mel=True,
326
+ online=online).to("cuda")
327
+
328
+ def customize_soxnorm(wav, gain=-3, factor=None):
329
+ wav = np.clip(wav, a_max=1, a_min=-1)
330
+ if factor is None:
331
+ linear_gain = 10 ** (gain / 20)
332
+ factor = linear_gain / np.abs(wav).max()
333
+ wav = wav * factor
334
+ return wav, factor
335
+ else:
336
+ wav = wav * factor
337
+ return wav, None
338
+
339
+ # Noisy file path
340
+ wav = "./src/demos/noisy_CHIME-real_F05_442C020S_STR_REAL.wav"
341
+ wavname = wav.split("/")[-1].split(".")[0]
342
+
343
+ print(f"Processing {wav}")
344
+ noisy, fs = sf.read(wav)
345
+ dur = len(noisy) / fs
346
+ noisy, factor = customize_soxnorm(noisy, gain=-3)
347
+ noisy = torch.tensor(noisy).unsqueeze(0).float().to("cuda")
348
+ # vocos norm
349
+ x = stft(noisy)
350
+ # Load the model
351
+ hidden=96
352
+ depth=8
353
+ model = CleanMel(
354
+ dim_input=2,
355
+ dim_output=1,
356
+ n_layers=depth,
357
+ dim_hidden=hidden,
358
+ layer_linear_freq=1,
359
+ f_kernel_size=5,
360
+ f_conv_groups=8,
361
+ n_freqs=257,
362
+ mamba_state=16,
363
+ mamba_conv_kernel=4,
364
+ online=online,
365
+ sr=16000,
366
+ n_fft=512
367
+ ).to("cuda")
368
+
369
+ # Load the pretrained model
370
+ state_dict = torch.load("./pretrained/CleanMel_S_L1.ckpt")
371
+ model.load_state_dict(state_dict)
372
+
373
+ model.eval()
374
+ with FlopCounterMode(model, display=False) as fcm:
375
+ y_hat = model(x, inference=False)
376
+ flops_forward_eval = fcm.get_total_flops()
377
+ params_eval = sum(param.numel() for param in model.parameters())
378
+ print(f"flops_forward={flops_forward_eval/1e9 / dur:.2f}G")
379
+ print(f"params={params_eval/1e6:.2f} M")
380
+
381
+ # y_hat is the enhanced log-mel spectrogram
382
+ y_hat = y_hat[0].cpu().detach().numpy()
383
+
384
+ # sanity check
385
+ if wavname == "noisy_CHIME-real_F05_442C020S_STR_REAL":
386
+ assert np.allclose(y_hat, np.load("./src/inference/check_CHIME-real_F05_442C020S_STR_REAL.npy"), atol=1e-5)
387
+
388
+ # plot the enhanced mel spectrogram
389
+ noisy_mel = target_mel(noisy)
390
+ noisy_mel = torch.log(noisy_mel.clamp(min=1e-5))[0].cpu().detach().numpy()
391
+ vmax = math.log(1e2)
392
+ vmin = math.log(1e-5)
393
+ plt.figure(figsize=(8, 4))
394
+ plt.subplot(2, 1, 1)
395
+ plt.imshow(noisy_mel, aspect='auto', origin='lower', cmap='jet', vmax=vmax, vmin=vmin)
396
+ plt.colorbar()
397
+ plt.subplot(2, 1, 2)
398
+ plt.imshow(y_hat, aspect='auto', origin='lower', cmap='jet', vmax=vmax, vmin=vmin)
399
+ plt.colorbar()
400
+ plt.tight_layout()
401
+ plt.savefig(f"./src/inference/{wavname}.png")
model/stft.py ADDED
@@ -0,0 +1,154 @@
+ import torch
2
+ import librosa
3
+ import torch.nn as nn
4
+ import random
5
+ from torch import Tensor
6
+ from typing import Optional
7
+ from torchaudio.transforms import Spectrogram
8
+ from torchaudio.transforms import Spectrogram, MelScale
9
+
10
+ def soxnorm(wav: torch.Tensor, gain, factor=None):
11
+ """sox norm, used in Vocos codes;
12
+ """
13
+ wav = torch.clip(wav, max=1, min=-1).float()
14
+ if factor is None:
15
+ linear_gain = 10 ** (gain / 20)
16
+ factor = linear_gain / torch.abs(wav).max().item()
17
+ wav = wav * factor
18
+ else:
19
+ # for clean speech, normed by the noisy factor
20
+ wav = wav * factor
21
+ assert torch.all(wav.abs() <= 1), f"out wavform is not in [-1, 1], {wav.abs().max()}"
22
+ return wav, factor
23
+
24
+
25
+ class InputSTFT(nn.Module):
26
+ """
27
+ The STFT of the input signal of CleanMel (STFT coefficients);
28
+ In online mode, the recursive normalization is used.
29
+ """
30
+ def __init__(
31
+ self,
32
+ n_fft: int,
33
+ n_win: int,
34
+ n_hop: int,
35
+ center: bool,
36
+ normalize: bool,
37
+ onesided: bool,
38
+ online: bool = False):
39
+ super().__init__()
40
+
41
+ self.online = online
42
+ self.stft=Spectrogram(
43
+ n_fft=n_fft,
44
+ win_length=n_win,
45
+ hop_length=n_hop,
46
+ normalized=normalize,
47
+ center=center,
48
+ onesided=onesided,
49
+ power=None
50
+ )
51
+
52
+ def forward(self, x):
53
+ if self.online:
54
+ # recursive normalization
55
+ x = self.stft(x)
56
+ x_mag = x.abs()
57
+ x_norm = recursive_normalization(x_mag)
58
+ x = x / x_norm.clamp(min=1e-8)
59
+ x = torch.view_as_real(x)
60
+ else:
61
+ # vocos dBFS normalization
62
+ x, x_norm = soxnorm(x, random.randint(-6, -1) if self.training else -3)
63
+ x = self.stft(x)
64
+ x = torch.view_as_real(x)
65
+ return x, x_norm
66
+
67
+
68
+ class LibrosaMelScale(nn.Module):
69
+ r"""Pytorch implementation of librosa mel scale to align with common ESPNet ASR models;
70
+ You might need to define .
71
+ """
72
+ def __init__(self, n_mels, sample_rate, f_min, f_max, n_stft, norm=None, mel_scale="slaney"):
73
+ super(LibrosaMelScale, self).__init__()
74
+
75
+ _mel_options = dict(
76
+ sr=sample_rate,
77
+ n_fft=(n_stft - 1) * 2,
78
+ n_mels=n_mels,
79
+ fmin=f_min,
80
+ fmax=f_max if f_max is not None else float(sample_rate // 2),
81
+ htk=mel_scale=="htk",
82
+ norm=norm
83
+ )
84
+
85
+ fb = torch.from_numpy(librosa.filters.mel(**_mel_options).T).float()
86
+ self.register_buffer("fb", fb)
87
+
88
+ def forward(self, specgram):
89
+ mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)
90
+ return mel_specgram
91
+
92
+
93
+ class TargetMel(nn.Module):
94
+ """
95
+ This class generates the enhancement TARGET mel spectrogram;
96
+ """
97
+ def __init__(
98
+ self,
99
+ sample_rate: int,
100
+ n_fft: int,
101
+ n_win: int,
102
+ n_hop: int,
103
+ n_mels: int,
104
+ f_min: int,
105
+ f_max: int,
106
+ power: int,
107
+ center: bool,
108
+ normalize: bool,
109
+ onesided: bool,
110
+ mel_norm: str | None,
111
+ mel_scale: str,
112
+ librosa_mel: bool = True,
113
+ online: bool = False,
114
+ ):
115
+ super().__init__()
116
+ # This implementation vs torchaudio.transforms.MelSpectrogram: Add librosa melscale
117
+ # librosa melscale is numerically different from the torchaudio melscale (x_diff > 1e-5)
118
+
119
+ self.sample_rate = sample_rate
120
+ self.n_fft = n_fft
121
+ self.online = online
122
+ self.stft = Spectrogram(
123
+ n_fft=n_fft,
124
+ win_length=n_win,
125
+ hop_length=n_hop,
126
+ power=None if online else power,
127
+ normalized=normalize,
128
+ center=center,
129
+ onesided=onesided,
130
+ )
131
+ mel_method = LibrosaMelScale if librosa_mel else MelScale
132
+ self.mel_scale = mel_method(
133
+ n_mels=n_mels,
134
+ sample_rate=sample_rate,
135
+ f_min=f_min,
136
+ f_max=f_max,
137
+ n_stft=n_fft // 2 + 1,
138
+ norm=mel_norm,
139
+ mel_scale=mel_scale,
140
+ )
141
+
142
+ def forward(self, x: Tensor, x_norm=None):
143
+ if self.online:
144
+ # apply recursive normalization to target waveform
145
+ spectrogram = self.stft(x)
146
+ spectrogram = spectrogram / (x_norm + 1e-8)
147
+ spectrogram = spectrogram.abs().pow(2) # to power spectrogram
148
+ else:
149
+ # apply vocos dBFS normalization to target waveform
150
+ x, _ = soxnorm(x, None, x_norm)
151
+ spectrogram = self.stft(x)
152
+ # mel spectrogram
153
+ mel_specgram = self.mel_scale(spectrogram)
154
+ return mel_specgram
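
In offline mode, InputSTFT applies the sox-style dBFS normalization (soxnorm) and returns the complex STFT as stacked real/imaginary parts together with the gain factor, while TargetMel reuses that factor so the power mel spectrogram is computed at the same level. A small sketch with the settings used in app.py (the input tensor is dummy data):

import torch
from model.stft import InputSTFT, TargetMel

stft = InputSTFT(n_fft=512, n_win=512, n_hop=128, center=True,
                 normalize=False, onesided=True, online=False).eval()
mel = TargetMel(sample_rate=16000, n_fft=512, n_win=512, n_hop=128, n_mels=80,
                f_min=0, f_max=8000, power=2, center=True, normalize=False,
                onesided=True, mel_norm="slaney", mel_scale="slaney",
                librosa_mel=True, online=False).eval()

wav = torch.rand(1, 16000) * 2 - 1  # 1 s of dummy audio in [-1, 1]
X, x_norm = stft(wav)               # [1, 257, T, 2] real/imag STFT plus the sox gain factor
mel_power = mel(wav, x_norm)        # [1, 80, T] power mel spectrogram at the same gain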
model/vocos/__init__.py ADDED
@@ -0,0 +1 @@
+
model/vocos/dataset.py ADDED
@@ -0,0 +1,93 @@
+ from dataclasses import dataclass
2
+
3
+ import numpy as np
4
+ from pytorch_lightning.utilities.types import EVAL_DATALOADERS
5
+ import torch
6
+ import torchaudio
7
+ import warnings
8
+ from pytorch_lightning import LightningDataModule
9
+ from torch.utils.data import Dataset, DataLoader
10
+
11
+ torch.set_num_threads(1)
12
+
13
+
14
+ @dataclass
15
+ class DataConfig:
16
+ filelist_path: str
17
+ sampling_rate: int
18
+ num_samples: int
19
+ batch_size: int
20
+ num_workers: int
21
+
22
+
23
+ class VocosDataModule(LightningDataModule):
24
+ def __init__(self, train_params: DataConfig, val_params: DataConfig):
25
+ super().__init__()
26
+ self.train_config = train_params
27
+ self.val_config = val_params
28
+
29
+ def _get_dataloder(self, cfg: DataConfig, train: bool):
30
+ dataset = VocosDataset(cfg, train=train)
31
+ dataloader = DataLoader(
32
+ dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True,
33
+ )
34
+ return dataloader
35
+
36
+ def train_dataloader(self) -> DataLoader:
37
+ return self._get_dataloder(self.train_config, train=True)
38
+
39
+ def val_dataloader(self) -> DataLoader:
40
+ return self._get_dataloder(self.val_config, train=False)
41
+
42
+ def test_dataloader(self) -> DataLoader:
43
+ return self.val_dataloader()
44
+
45
+ class VocosDataset(Dataset):
46
+ def __init__(self, cfg: DataConfig, train: bool):
47
+ with open(cfg.filelist_path) as f:
48
+ self.filelist = f.read().splitlines()
49
+ self.sampling_rate = cfg.sampling_rate
50
+ self.num_samples = cfg.num_samples
51
+ self.train = train
52
+
53
+ def __len__(self) -> int:
54
+ return len(self.filelist)
55
+
56
+ def customize_soxnorm(self, wav, gain=-3, factor=None):
57
+ wav = np.clip(wav, a_max=1, a_min=-1)
58
+ if factor is None:
59
+ linear_gain = 10 ** (gain / 20)
60
+ wav = wav / np.abs(wav).max() * linear_gain
61
+ return wav, linear_gain / np.abs(wav).max()
62
+ else:
63
+ wav = wav * factor
64
+ return wav, None
65
+
66
+ def __getitem__(self, index: int) -> torch.Tensor:
67
+ audio_path = self.filelist[index]
68
+ try:
69
+ y, sr = torchaudio.load(audio_path)
70
+ except:
71
+ warnings.warn(f"Error loading {audio_path}")
72
+ return self.__getitem__(np.random.randint(len(self.filelist)))
73
+ if y.size(-1) == 0:
74
+ return self.__getitem__(np.random.randint(len(self.filelist)))
75
+ if y.size(0) > 1:
76
+ # mix to mono
77
+ y = y.mean(dim=0, keepdim=True)
78
+ gain = np.random.uniform(-1, -6) if self.train else -3
79
+ y, _ = self.customize_soxnorm(y, gain)
80
+ if sr != self.sampling_rate:
81
+ y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
82
+ if y.size(-1) < self.num_samples:
83
+ pad_length = self.num_samples - y.size(-1)
84
+ padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
85
+ y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
86
+ elif self.train:
87
+ start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
88
+ y = y[:, start : start + self.num_samples]
89
+ else:
90
+ # During validation, take always the first segment for determinism
91
+ y = y[:, : self.num_samples]
92
+
93
+ return y[0]
model/vocos/discriminators.py ADDED
@@ -0,0 +1,211 @@
+ from typing import List, Optional, Tuple
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import nn
6
+ from torch.nn import Conv2d
7
+ from torch.nn.utils import weight_norm
8
+ from torchaudio.transforms import Spectrogram
9
+
10
+
11
+ class MultiPeriodDiscriminator(nn.Module):
12
+ """
13
+ Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
14
+ Additionally, it allows incorporating conditional information with a learned embeddings table.
15
+
16
+ Args:
17
+ periods (tuple[int]): Tuple of periods for each discriminator.
18
+ num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
19
+ Defaults to None.
20
+ """
21
+
22
+ def __init__(self, periods: Tuple[int, ...] = (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None):
23
+ super().__init__()
24
+ self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])
25
+
26
+ def forward(
27
+ self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
28
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
29
+ y_d_rs = []
30
+ y_d_gs = []
31
+ fmap_rs = []
32
+ fmap_gs = []
33
+ for d in self.discriminators:
34
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
35
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
36
+ y_d_rs.append(y_d_r)
37
+ fmap_rs.append(fmap_r)
38
+ y_d_gs.append(y_d_g)
39
+ fmap_gs.append(fmap_g)
40
+
41
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
42
+
43
+
44
+ class DiscriminatorP(nn.Module):
45
+ def __init__(
46
+ self,
47
+ period: int,
48
+ in_channels: int = 1,
49
+ kernel_size: int = 5,
50
+ stride: int = 3,
51
+ lrelu_slope: float = 0.1,
52
+ num_embeddings: Optional[int] = None,
53
+ ):
54
+ super().__init__()
55
+ self.period = period
56
+ self.convs = nn.ModuleList(
57
+ [
58
+ weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
59
+ weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
60
+ weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
61
+ weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
62
+ weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))),
63
+ ]
64
+ )
65
+ if num_embeddings is not None:
66
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024)
67
+ torch.nn.init.zeros_(self.emb.weight)
68
+
69
+ self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
70
+ self.lrelu_slope = lrelu_slope
71
+
72
+ def forward(
73
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
74
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
75
+ x = x.unsqueeze(1)
76
+ fmap = []
77
+ # 1d to 2d
78
+ b, c, t = x.shape
79
+ if t % self.period != 0: # pad first
80
+ n_pad = self.period - (t % self.period)
81
+ x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
82
+ t = t + n_pad
83
+ x = x.view(b, c, t // self.period, self.period)
84
+
85
+ for i, l in enumerate(self.convs):
86
+ x = l(x)
87
+ x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
88
+ if i > 0:
89
+ fmap.append(x)
90
+ if cond_embedding_id is not None:
91
+ emb = self.emb(cond_embedding_id)
92
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
93
+ else:
94
+ h = 0
95
+ x = self.conv_post(x)
96
+ fmap.append(x)
97
+ x += h
98
+ x = torch.flatten(x, 1, -1)
99
+
100
+ return x, fmap
101
+
102
+
103
+ class MultiResolutionDiscriminator(nn.Module):
104
+ def __init__(
105
+ self,
106
+ fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
107
+ num_embeddings: Optional[int] = None,
108
+ ):
109
+ """
110
+ Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
111
+ Additionally, it allows incorporating conditional information with a learned embeddings table.
112
+
113
+ Args:
114
+ fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
115
+ num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
116
+ Defaults to None.
117
+ """
118
+
119
+ super().__init__()
120
+ self.discriminators = nn.ModuleList(
121
+ [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
122
+ )
123
+
124
+ def forward(
125
+ self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
126
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
127
+ y_d_rs = []
128
+ y_d_gs = []
129
+ fmap_rs = []
130
+ fmap_gs = []
131
+
132
+ for d in self.discriminators:
133
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
134
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
135
+ y_d_rs.append(y_d_r)
136
+ fmap_rs.append(fmap_r)
137
+ y_d_gs.append(y_d_g)
138
+ fmap_gs.append(fmap_g)
139
+
140
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
141
+
142
+
143
+ class DiscriminatorR(nn.Module):
144
+ def __init__(
145
+ self,
146
+ window_length: int,
147
+ num_embeddings: Optional[int] = None,
148
+ channels: int = 32,
149
+ hop_factor: float = 0.25,
150
+ bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
151
+ ):
152
+ super().__init__()
153
+ self.window_length = window_length
154
+ self.hop_factor = hop_factor
155
+ self.spec_fn = Spectrogram(
156
+ n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
157
+ )
158
+ n_fft = window_length // 2 + 1
159
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
160
+ self.bands = bands
161
+ convs = lambda: nn.ModuleList(
162
+ [
163
+ weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
164
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
165
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
166
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
167
+ weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
168
+ ]
169
+ )
170
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
171
+
172
+ if num_embeddings is not None:
173
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
174
+ torch.nn.init.zeros_(self.emb.weight)
175
+
176
+ self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
177
+
178
+ def spectrogram(self, x):
179
+ # Remove DC offset
180
+ x = x - x.mean(dim=-1, keepdims=True)
181
+ # Peak normalize the volume of input audio
182
+ x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
183
+ x = self.spec_fn(x)
184
+ x = torch.view_as_real(x)
185
+ x = rearrange(x, "b f t c -> b c t f")
186
+ # Split into bands
187
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
188
+ return x_bands
189
+
190
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
191
+ x_bands = self.spectrogram(x)
192
+ fmap = []
193
+ x = []
194
+ for band, stack in zip(x_bands, self.band_convs):
195
+ for i, layer in enumerate(stack):
196
+ band = layer(band)
197
+ band = torch.nn.functional.leaky_relu(band, 0.1)
198
+ if i > 0:
199
+ fmap.append(band)
200
+ x.append(band)
201
+ x = torch.cat(x, dim=-1)
202
+ if cond_embedding_id is not None:
203
+ emb = self.emb(cond_embedding_id)
204
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
205
+ else:
206
+ h = 0
207
+ x = self.conv_post(x)
208
+ fmap.append(x)
209
+ x += h
210
+
211
+ return x, fmap
model/vocos/experiment.py ADDED
@@ -0,0 +1,398 @@
+ import math
2
+
3
+ import numpy as np
4
+ import pytorch_lightning as pl
5
+ import torch
6
+ import torchaudio
7
+ import transformers
8
+ import torch.nn as nn
9
+
10
+ from model.vocos.offline.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
11
+ from model.vocos.offline.feature_extractors import FeatureExtractor
12
+ from model.vocos.offline.heads import FourierHead
13
+ from model.vocos.offline.helpers import plot_spectrogram_to_numpy
14
+ from model.vocos.offline.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss
15
+ # from models.vocos.offline.models import Backbone
16
+ from model.vocos.offline.modules import safe_log
17
+
18
+
19
+ class VocosExp(pl.LightningModule):
20
+ # noinspection PyUnusedLocal
21
+ def __init__(
22
+ self,
23
+ feature_extractor: FeatureExtractor,
24
+ backbone: nn.Module,
25
+ head: nn.Module,
26
+ sample_rate: int,
27
+ initial_learning_rate: float,
28
+ num_warmup_steps: int = 0,
29
+ mel_loss_coeff: float = 45,
30
+ mrd_loss_coeff: float = 1.0,
31
+ pretrain_mel_steps: int = 0,
32
+ decay_mel_coeff: bool = False,
33
+ evaluate_utmos: bool = False,
34
+ evaluate_pesq: bool = False,
35
+ evaluate_periodicty: bool = False,
36
+ ):
37
+ """
38
+ Args:
39
+ feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals.
40
+ backbone (Backbone): An instance of Backbone model.
41
+ head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform.
42
+ sample_rate (int): Sampling rate of the audio signals.
43
+ initial_learning_rate (float): Initial learning rate for the optimizer.
44
+ num_warmup_steps (int): Number of steps for the warmup phase of learning rate scheduler. Default is 0.
45
+ mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45.
46
+ mrd_loss_coeff (float, optional): Coefficient for Multi Resolution Discriminator loss. Default is 1.0.
47
+ pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0.
48
+ decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False.
49
+ evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run.
50
+ evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run.
51
+ evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run.
52
+ """
53
+ super().__init__()
54
+ self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"])
55
+ self.feature_extractor = feature_extractor
56
+ self.backbone = backbone
57
+ self.head = head
58
+ self.sample_rate = sample_rate
59
+ self.initial_learning_rate = initial_learning_rate
60
+ self.num_warmup_steps = num_warmup_steps
61
+ self.mel_loss_coeff = mel_loss_coeff
62
+ self.mrd_loss_coeff = mrd_loss_coeff
63
+ self.pretrain_mel_steps = pretrain_mel_steps
64
+ self.decay_mel_coeff = decay_mel_coeff
65
+ self.evaluate_utmos = evaluate_utmos
66
+ self.evaluate_pesq = evaluate_pesq
67
+ self.evaluate_periodicty = evaluate_periodicty
68
+
69
+ self.multiperioddisc = MultiPeriodDiscriminator()
70
+ self.multiresddisc = MultiResolutionDiscriminator()
71
+
72
+ self.disc_loss = DiscriminatorLoss()
73
+ self.gen_loss = GeneratorLoss()
74
+ self.feat_matching_loss = FeatureMatchingLoss()
75
+ self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)
76
+
77
+ self.train_discriminator = False
78
+ self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff
79
+ self.temp_cache=None
80
+ self.temp_grad=None
81
+
82
+ def configure_optimizers(self):
83
+ disc_params = [
84
+ {"params": self.multiperioddisc.parameters()},
85
+ {"params": self.multiresddisc.parameters()},
86
+ ]
87
+ gen_params = [
88
+ {"params": self.feature_extractor.parameters()},
89
+ {"params": self.backbone.parameters()},
90
+ {"params": self.head.parameters()},
91
+ ]
92
+
93
+ opt_disc = torch.optim.AdamW(disc_params, lr=self.initial_learning_rate, betas=(0.8, 0.9))
94
+ opt_gen = torch.optim.AdamW(gen_params, lr=self.initial_learning_rate, betas=(0.8, 0.9))
95
+
96
+ max_steps = self.trainer.max_steps // 2 # Max steps per optimizer
97
+ scheduler_disc = transformers.get_cosine_schedule_with_warmup(
98
+ opt_disc, num_warmup_steps=self.num_warmup_steps, num_training_steps=max_steps,
99
+ )
100
+ scheduler_gen = transformers.get_cosine_schedule_with_warmup(
101
+ opt_gen, num_warmup_steps=self.num_warmup_steps, num_training_steps=max_steps,
102
+ )
103
+
104
+ return (
105
+ [opt_disc, opt_gen],
106
+ [{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}],
107
+ )
108
+
109
+ def forward(self, audio_input, **kwargs):
110
+ features = self.feature_extractor(audio_input, **kwargs)
111
+ x = self.backbone(features, **kwargs)
112
+ audio_output = self.head(x)
113
+ return audio_output
114
+
115
+ def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
116
+ audio_input = batch
117
+ # train discriminator
118
+ if optimizer_idx == 0 and self.train_discriminator:
119
+ with torch.no_grad():
120
+ audio_hat = self(audio_input, **kwargs)
121
+ real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
122
+ real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
123
+ loss_mp, loss_mp_real, _ = self.disc_loss(
124
+ disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
125
+ )
126
+ loss_mrd, loss_mrd_real, _ = self.disc_loss(
127
+ disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
128
+ )
129
+ loss_mp /= len(loss_mp_real)
130
+ loss_mrd /= len(loss_mrd_real)
131
+ loss = loss_mp + self.mrd_loss_coeff * loss_mrd
132
+
133
+ self.log("discriminator/total", loss, prog_bar=True)
134
+ self.log("discriminator/multi_period_loss", loss_mp)
135
+ self.log("discriminator/multi_res_loss", loss_mrd)
136
+ return loss
137
+
138
+ # train generator
139
+ if optimizer_idx == 1:
140
+ audio_hat = self(audio_input, **kwargs)
141
+ if self.train_discriminator:
142
+ _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc(
143
+ y=audio_input, y_hat=audio_hat, **kwargs,
144
+ )
145
+ _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc(
146
+ y=audio_input, y_hat=audio_hat, **kwargs,
147
+ )
148
+ loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
149
+ loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd)
150
+ loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp)
151
+ loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
152
+ loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp)
153
+ loss_fm_mrd = self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd)
154
+
155
+ self.log("generator/multi_period_loss", loss_gen_mp)
156
+ self.log("generator/multi_res_loss", loss_gen_mrd)
157
+ self.log("generator/feature_matching_mp", loss_fm_mp)
158
+ self.log("generator/feature_matching_mrd", loss_fm_mrd)
159
+ else:
160
+ loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0
161
+
162
+ mel_loss = self.melspec_loss(audio_hat, audio_input)
163
+ loss = (
164
+ loss_gen_mp
165
+ + self.mrd_loss_coeff * loss_gen_mrd
166
+ + loss_fm_mp
167
+ + self.mrd_loss_coeff * loss_fm_mrd
168
+ + self.mel_loss_coeff * mel_loss
169
+ )
170
+
171
+ self.log("generator/total_loss", loss, prog_bar=True)
172
+ self.log("mel_loss_coeff", self.mel_loss_coeff)
173
+ self.log("generator/mel_loss", mel_loss)
174
+
175
+ if self.global_step % 1000 == 0 and self.global_rank == 0:
176
+ self.logger.experiment.add_audio(
177
+ "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.sample_rate
178
+ )
179
+ self.logger.experiment.add_audio(
180
+ "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.sample_rate
181
+ )
182
+ with torch.no_grad():
183
+ mel = safe_log(self.melspec_loss.mel_spec(audio_input[0]))
184
+ mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0]))
185
+ self.logger.experiment.add_image(
186
+ "train/mel_target",
187
+ plot_spectrogram_to_numpy(mel.data.cpu().numpy()),
188
+ self.global_step,
189
+ dataformats="HWC",
190
+ )
191
+ self.logger.experiment.add_image(
192
+ "train/mel_pred",
193
+ plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
194
+ self.global_step,
195
+ dataformats="HWC",
196
+ )
197
+ return loss
198
+
199
+ def on_validation_epoch_start(self):
200
+ if self.evaluate_utmos:
201
+ from model.vocos.metrics.UTMOS import UTMOSScore
202
+ # if not hasattr(self, "utmos_model"):
203
+ self.utmos_model = UTMOSScore(device=self.device)
204
+
205
+ def validation_step(self, batch, batch_idx, **kwargs):
206
+ audio_input = batch
207
+ audio_hat = self(audio_input, **kwargs)
208
+
209
+ audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.sample_rate, new_freq=16000)
210
+ audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.sample_rate, new_freq=16000)
211
+
212
+ if self.evaluate_periodicty:
213
+ from model.vocos.metrics.periodicity import calculate_periodicity_metrics
214
+
215
+ periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz)
216
+ else:
217
+ periodicity_loss = pitch_loss = f1_score = 0
218
+
219
+ if self.evaluate_utmos:
220
+ utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean()
221
+ else:
222
+ utmos_score = torch.zeros(1, device=self.device)
223
+
224
+ if self.evaluate_pesq:
225
+ from pesq import pesq
226
+
227
+ pesq_score = 0
228
+ for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()):
229
+ pesq_score += pesq(16000, ref, deg, "wb", on_error=1)
230
+ pesq_score /= len(audio_16_khz)
231
+ pesq_score = torch.tensor(pesq_score)
232
+ else:
233
+ pesq_score = torch.zeros(1, device=self.device)
234
+
235
+ mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
236
+ total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score)
237
+
238
+ return {
239
+ "val_loss": total_loss,
240
+ "mel_loss": mel_loss,
241
+ "utmos_score": utmos_score,
242
+ "pesq_score": pesq_score,
243
+ "periodicity_loss": periodicity_loss,
244
+ "pitch_loss": pitch_loss,
245
+ "f1_score": f1_score,
246
+ "audio_input": audio_input[0],
247
+ "audio_pred": audio_hat[0],
248
+ }
249
+
250
+ def validation_epoch_end(self, outputs):
251
+ if self.global_rank == 0:
252
+ for i, output in enumerate(outputs):
253
+ *_, audio_in, audio_pred = output.values()
254
+ self.logger.experiment.add_audio(
255
+ f"val_in_{i}", audio_in.data.cpu().numpy(), self.global_step, self.sample_rate
256
+ )
257
+ self.logger.experiment.add_audio(
258
+ f"val_pred_{i}", audio_pred.data.cpu().numpy(), self.global_step, self.sample_rate
259
+ )
260
+ mel_target = safe_log(self.melspec_loss.mel_spec(audio_in))
261
+ mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred))
262
+ self.logger.experiment.add_image(
263
+ f"val_mel_target_{i}",
264
+ plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()),
265
+ self.global_step,
266
+ dataformats="HWC",
267
+ )
268
+ self.logger.experiment.add_image(
269
+ f"val_mel_hat_{i}",
270
+ plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
271
+ self.global_step,
272
+ dataformats="HWC",
273
+ )
274
+ avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
275
+ mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean()
276
+ utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean()
277
+ pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean()
278
+ periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean()
279
+ pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean()
280
+ f1_score = np.array([x["f1_score"] for x in outputs]).mean()
281
+
282
+ self.log("val_loss", avg_loss, sync_dist=True)
283
+ self.log("val/mel_loss", mel_loss, sync_dist=True)
284
+ self.log("val/utmos_score", utmos_score, sync_dist=True)
285
+ self.log("val/pesq_score", pesq_score, sync_dist=True)
286
+ self.log("val/periodicity_loss", periodicity_loss, sync_dist=True)
287
+ self.log("val/pitch_loss", pitch_loss, sync_dist=True)
288
+ self.log("val/f1_score", f1_score, sync_dist=True)
289
+
290
+ return {
291
+ "avg_loss": avg_loss,
292
+ "mel_loss": mel_loss,
293
+ "utmos_score": utmos_score,
294
+ "pesq_score": pesq_score,
295
+ "periodicity_loss": periodicity_loss,
296
+ "pitch_loss": pitch_loss,
297
+ "f1_score": f1_score,
298
+ }
299
+ def on_test_epoch_start(self):
300
+ self.on_validation_epoch_start()
301
+
302
+ def test_step(self, *args, **kwargs):
303
+ return self.validation_step(*args, **kwargs)
304
+
305
+ def test_epoch_end(self, outputs):
306
+ results = self.validation_epoch_end(outputs)
307
+ print(results)
308
+ @property
309
+ def global_step(self):
310
+ """
311
+ Override global_step so that it returns the total number of batches processed
312
+ """
313
+ return self.trainer.fit_loop.epoch_loop.total_batch_idx
314
+
315
+ def on_train_batch_start(self, *args):
316
+ if self.global_step >= self.pretrain_mel_steps:
317
+ self.train_discriminator = True
318
+ else:
319
+ self.train_discriminator = False
320
+
321
+ def on_train_batch_end(self, *args):
322
+ def mel_loss_coeff_decay(current_step, num_cycles=0.5):
323
+ max_steps = self.trainer.max_steps // 2
324
+ if current_step < self.num_warmup_steps:
325
+ return 1.0
326
+ progress = float(current_step - self.num_warmup_steps) / float(
327
+ max(1, max_steps - self.num_warmup_steps)
328
+ )
329
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
330
+
331
+ if self.decay_mel_coeff:
332
+ self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1)
333
+
334
+
335
+ class VocosEncodecExp(VocosExp):
336
+ """
337
+ VocosEncodecExp is a subclass of VocosExp that overrides the parent experiment to function as a conditional GAN.
338
+ It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to
339
+ a specific bandwidth value of EnCodec. During training, a random bandwidth_id is generated for each step,
340
+ while during validation, a fixed bandwidth_id is used.
341
+ """
342
+
343
+ def __init__(
344
+ self,
345
+ feature_extractor: FeatureExtractor,
346
+ backbone: pl.LightningModule,
347
+ head: pl.LightningModule,
348
+ sample_rate: int,
349
+ initial_learning_rate: float,
350
+ num_warmup_steps: int,
351
+ mel_loss_coeff: float = 45,
352
+ mrd_loss_coeff: float = 1.0,
353
+ pretrain_mel_steps: int = 0,
354
+ decay_mel_coeff: bool = False,
355
+ evaluate_utmos: bool = False,
356
+ evaluate_pesq: bool = False,
357
+ evaluate_periodicty: bool = False,
358
+ ):
359
+ super().__init__(
360
+ feature_extractor,
361
+ backbone,
362
+ head,
363
+ sample_rate,
364
+ initial_learning_rate,
365
+ num_warmup_steps,
366
+ mel_loss_coeff,
367
+ mrd_loss_coeff,
368
+ pretrain_mel_steps,
369
+ decay_mel_coeff,
370
+ evaluate_utmos,
371
+ evaluate_pesq,
372
+ evaluate_periodicty,
373
+ )
374
+ # Override with conditional discriminators
375
+ self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
376
+ self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
377
+
378
+ def training_step(self, *args):
379
+ bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,)
380
+ output = super().training_step(*args, bandwidth_id=bandwidth_id)
381
+ return output
382
+
383
+ def validation_step(self, *args):
384
+ bandwidth_id = torch.tensor([0], device=self.device)
385
+ output = super().validation_step(*args, bandwidth_id=bandwidth_id)
386
+ return output
387
+
388
+ def validation_epoch_end(self, outputs):
389
+ if self.global_rank == 0:
390
+ *_, audio_in, _ = outputs[0].values()
391
+ # Resynthesis with encodec for reference
392
+ self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0])
393
+ encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :])
394
+ self.logger.experiment.add_audio(
395
+ "encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.sample_rate,
396
+ )
397
+
398
+ super().validation_epoch_end(outputs)
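As a side note on the cosine schedule used in `on_train_batch_end` above, the following standalone sketch traces how the mel-loss coefficient decays. The step counts and the base coefficient of 45 are made-up illustrative values, not the project's training configuration.

import math

def mel_loss_coeff_decay(current_step, num_warmup_steps=10_000, max_steps=500_000, num_cycles=0.5):
    # Mirrors the inner helper above: flat during warmup, then a cosine half-cycle down to zero.
    if current_step < num_warmup_steps:
        return 1.0
    progress = (current_step - num_warmup_steps) / max(1, max_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

for step in (0, 10_000, 255_000, 500_000):
    print(step, round(45 * mel_loss_coeff_decay(step), 2))  # 45.0, 45.0, 22.5, 0.0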
model/vocos/feature_extractors.py ADDED
@@ -0,0 +1,170 @@
1
+ from typing import List
2
+
3
+ import torch
4
+ import librosa
5
+ from encodec import EncodecModel
6
+ from torch import nn
7
+ from torch import Tensor
8
+ from typing import Optional
9
+ from torchaudio.transforms import Spectrogram, MelScale
10
+ from model.vocos.modules import safe_log
11
+
12
+
13
+ class FeatureExtractor(nn.Module):
14
+ """Base class for feature extractors."""
15
+
16
+ def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
17
+ """
18
+ Extract features from the given audio.
19
+
20
+ Args:
21
+ audio (Tensor): Input audio waveform.
22
+
23
+ Returns:
24
+ Tensor: Extracted features of shape (B, C, L), where B is the batch size,
25
+ C denotes output features, and L is the sequence length.
26
+ """
27
+ raise NotImplementedError("Subclasses must implement the forward method.")
28
+
29
+
30
+ class LibrosaMelScale(nn.Module):
31
+ r"""This MelScale has a create_fb_matrix function that can be used to create a filterbank matrix.
32
+ same as previous torchaudio version
33
+ """
34
+ __constants__ = ["n_mels", "sample_rate", "f_min", "f_max"]
35
+
36
+ def __init__(
37
+ self,
38
+ n_mels: int = 128,
39
+ sample_rate: int = 16000,
40
+ f_min: float = 0.0,
41
+ f_max: Optional[float] = None,
42
+ n_stft: int = 201,
43
+ norm: Optional[str] = None,
44
+ mel_scale: str = "htk",
45
+ ) -> None:
46
+ super(LibrosaMelScale, self).__init__()
47
+ self.n_mels = n_mels
48
+ self.sample_rate = sample_rate
49
+ self.f_max = f_max if f_max is not None else float(sample_rate // 2)
50
+ self.f_min = f_min
51
+ self.norm = norm
52
+ self.mel_scale = mel_scale
53
+
54
+ if f_min > self.f_max:
55
+ raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max))
56
+ _mel_options = dict(
57
+ sr=sample_rate,
58
+ n_fft=(n_stft - 1) * 2,
59
+ n_mels=n_mels,
60
+ fmin=f_min,
61
+ fmax=f_max,
62
+ htk=mel_scale=="htk",
63
+ norm=norm
64
+ )
65
+ fb = torch.from_numpy(librosa.filters.mel(**_mel_options).T).float()
66
+ self.register_buffer("fb", fb)
67
+
68
+ def forward(self, specgram):
69
+ mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)
70
+ return mel_specgram
71
+
72
+
73
+ class MelSpectrogramFeatures(FeatureExtractor):
74
+ def __init__(
75
+ self,
76
+ sample_rate: int,
77
+ n_fft: int,
78
+ n_win: int,
79
+ n_hop: int,
80
+ n_mels: int,
81
+ f_min: int,
82
+ f_max: int,
83
+ power: int,
84
+ center: bool,
85
+ normalize: bool,
86
+ onesided: bool,
87
+ mel_norm: str | None,
88
+ mel_scale: str,
89
+ librosa_mel: bool = True,
90
+ clip_val: float = 1e-7
91
+ ):
92
+ super().__init__()
93
+ # This implementation vs torchaudio.transforms.MelSpectrogram: Add librosa melscale
94
+ # librosa melscale is numerically different from the torchaudio melscale (x_diff > 1e-5)
95
+ self.n_fft = n_fft
96
+ self.spectrogram = Spectrogram(
97
+ n_fft=n_fft,
98
+ win_length=n_win,
99
+ hop_length=n_hop,
100
+ power=power,
101
+ normalized=normalize,
102
+ center=center,
103
+ onesided=onesided,
104
+ )
105
+ mel_method = LibrosaMelScale if librosa_mel else MelScale
106
+ self.mel_scale = mel_method(
107
+ n_mels=n_mels,
108
+ sample_rate=sample_rate,
109
+ f_min=f_min,
110
+ f_max=f_max,
111
+ n_stft=n_fft // 2 + 1,
112
+ norm=mel_norm,
113
+ mel_scale=mel_scale,
114
+ )
115
+ self.clip_val = clip_val
116
+
117
+ def forward(self, x: Tensor) -> Tensor:
118
+ # Compute Spectrogram
119
+ specgram = self.spectrogram(x)
120
+ mel_specgram = self.mel_scale(specgram)
121
+ return safe_log(mel_specgram, self.clip_val)
122
+
123
+ class EncodecFeatures(FeatureExtractor):
124
+ def __init__(
125
+ self,
126
+ encodec_model: str = "encodec_24khz",
127
+ bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
128
+ train_codebooks: bool = False,
129
+ ):
130
+ super().__init__()
131
+ if encodec_model == "encodec_24khz":
132
+ encodec = EncodecModel.encodec_model_24khz
133
+ elif encodec_model == "encodec_48khz":
134
+ encodec = EncodecModel.encodec_model_48khz
135
+ else:
136
+ raise ValueError(
137
+ f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz' and 'encodec_48khz'."
138
+ )
139
+ self.encodec = encodec(pretrained=True)
140
+ for param in self.encodec.parameters():
141
+ param.requires_grad = False
142
+ self.num_q = self.encodec.quantizer.get_num_quantizers_for_bandwidth(
143
+ self.encodec.frame_rate, bandwidth=max(bandwidths)
144
+ )
145
+ codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
146
+ self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
147
+ self.bandwidths = bandwidths
148
+
149
+ @torch.no_grad()
150
+ def get_encodec_codes(self, audio):
151
+ audio = audio.unsqueeze(1)
152
+ emb = self.encodec.encoder(audio)
153
+ codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
154
+ return codes
155
+
156
+ def forward(self, audio: torch.Tensor, **kwargs):
157
+ bandwidth_id = kwargs.get("bandwidth_id")
158
+ if bandwidth_id is None:
159
+ raise ValueError("The 'bandwidth_id' argument is required")
160
+ self.encodec.eval() # Force eval mode as Pytorch Lightning automatically sets child modules to training mode
161
+ self.encodec.set_target_bandwidth(self.bandwidths[bandwidth_id])
162
+ codes = self.get_encodec_codes(audio)
163
+ # Instead of summing codebook lookups in a loop, the successive VQ codebooks are stored in a single
164
+ # `self.codebook_weights` tensor with per-codebook offsets of `bins`, and the lookups are summed in one vectorized operation.
165
+ offsets = torch.arange(
166
+ 0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device
167
+ )
168
+ embeddings_idxs = codes + offsets.view(-1, 1, 1)
169
+ features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
170
+ return features.transpose(1, 2)
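A minimal usage sketch for `MelSpectrogramFeatures` (not part of the commit); the parameter values below are illustrative assumptions rather than the project's official configuration.

import torch
from model.vocos.feature_extractors import MelSpectrogramFeatures

feat = MelSpectrogramFeatures(
    sample_rate=16000, n_fft=512, n_win=512, n_hop=128, n_mels=80,
    f_min=0, f_max=8000, power=2, center=True, normalize=False,
    onesided=True, mel_norm="slaney", mel_scale="slaney", librosa_mel=True,
)
audio = torch.randn(1, 16000)   # one second of random "audio" at 16 kHz
logmel = feat(audio)            # safe_log of the mel spectrogram, shape (1, 80, 126)
print(logmel.shape)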
model/vocos/heads.py ADDED
@@ -0,0 +1,164 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
6
+
7
+ from model.vocos.spectral_ops import IMDCT, ISTFT
8
+ from model.vocos.modules import symexp
9
+
10
+
11
+ class FourierHead(nn.Module):
12
+ """Base class for inverse fourier modules."""
13
+
14
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
15
+ """
16
+ Args:
17
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
18
+ L is the sequence length, and H denotes the model dimension.
19
+
20
+ Returns:
21
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
22
+ """
23
+ raise NotImplementedError("Subclasses must implement the forward method.")
24
+
25
+
26
+ class ISTFTHead(FourierHead):
27
+ """
28
+ ISTFT Head module for predicting STFT complex coefficients.
29
+
30
+ Args:
31
+ dim (int): Hidden dimension of the model.
32
+ n_fft (int): Size of Fourier transform.
33
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
34
+ the resolution of the input features.
35
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
36
+ """
37
+
38
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
39
+ super().__init__()
40
+ out_dim = n_fft + 2
41
+ self.out = torch.nn.Linear(dim, out_dim)
42
+ self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
43
+
44
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
45
+ """
46
+ Forward pass of the ISTFTHead module.
47
+
48
+ Args:
49
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
50
+ L is the sequence length, and H denotes the model dimension.
51
+
52
+ Returns:
53
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
54
+ """
55
+ x = self.out(x).transpose(1, 2)
56
+ mag, p = x.chunk(2, dim=1)
57
+ mag = torch.exp(mag)
58
+ mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes
59
+ # Phase wrapping happens here: these two lines produce the real and imaginary parts.
60
+ x = torch.cos(p)
61
+ y = torch.sin(p)
62
+ # Recomputing the phase via atan2 would add nothing new
63
+ # and would only cost time:
64
+ # phase = torch.atan2(y, x)
65
+ # S = mag * torch.exp(phase * 1j)
66
+ # It is cheaper to build the complex value directly.
67
+ S = mag * (x + 1j * y)
68
+ audio = self.istft(S)
69
+ return audio
70
+
71
+
72
+ class IMDCTSymExpHead(FourierHead):
73
+ """
74
+ IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
75
+
76
+ Args:
77
+ dim (int): Hidden dimension of the model.
78
+ mdct_frame_len (int): Length of the MDCT frame.
79
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
80
+ sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
81
+ based on perceptual scaling. Defaults to None.
82
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ dim: int,
88
+ mdct_frame_len: int,
89
+ padding: str = "same",
90
+ sample_rate: Optional[int] = None,
91
+ clip_audio: bool = False,
92
+ ):
93
+ super().__init__()
94
+ out_dim = mdct_frame_len // 2
95
+ self.out = nn.Linear(dim, out_dim)
96
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
97
+ self.clip_audio = clip_audio
98
+
99
+ if sample_rate is not None:
100
+ # optionally init the last layer following mel-scale
101
+ m_max = _hz_to_mel(sample_rate // 2)
102
+ m_pts = torch.linspace(0, m_max, out_dim)
103
+ f_pts = _mel_to_hz(m_pts)
104
+ scale = 1 - (f_pts / f_pts.max())
105
+
106
+ with torch.no_grad():
107
+ self.out.weight.mul_(scale.view(-1, 1))
108
+
109
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
110
+ """
111
+ Forward pass of the IMDCTSymExpHead module.
112
+
113
+ Args:
114
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
115
+ L is the sequence length, and H denotes the model dimension.
116
+
117
+ Returns:
118
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
119
+ """
120
+ x = self.out(x)
121
+ x = symexp(x)
122
+ x = torch.clip(x, min=-1e2, max=1e2) # safeguard to prevent excessively large magnitudes
123
+ audio = self.imdct(x)
124
+ if self.clip_audio:
125
+ audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the reconstructed waveform, not the MDCT coefficients
126
+
127
+ return audio
128
+
129
+
130
+ class IMDCTCosHead(FourierHead):
131
+ """
132
+ IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
133
+
134
+ Args:
135
+ dim (int): Hidden dimension of the model.
136
+ mdct_frame_len (int): Length of the MDCT frame.
137
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
138
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
139
+ """
140
+
141
+ def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
142
+ super().__init__()
143
+ self.clip_audio = clip_audio
144
+ self.out = nn.Linear(dim, mdct_frame_len)
145
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
146
+
147
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
148
+ """
149
+ Forward pass of the IMDCTCosHead module.
150
+
151
+ Args:
152
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
153
+ L is the sequence length, and H denotes the model dimension.
154
+
155
+ Returns:
156
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
157
+ """
158
+ x = self.out(x)
159
+ m, p = x.chunk(2, dim=2)
160
+ m = torch.exp(m).clip(max=1e2) # safeguard to prevent excessively large magnitudes
161
+ audio = self.imdct(m * torch.cos(p))
162
+ if self.clip_audio:
163
+ audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the reconstructed waveform, not the head output
164
+ return audio
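A hedged shape-level sketch of `ISTFTHead`: backbone features of shape (B, L, H) go in, and a waveform of length L * hop_length comes out. The dimensions below are assumptions chosen only for illustration.

import torch
from model.vocos.heads import ISTFTHead

head = ISTFTHead(dim=512, n_fft=512, hop_length=128, padding="same")
features = torch.randn(1, 126, 512)   # (B, L, H) as produced by a backbone
audio = head(features)                # (B, L * hop_length) = (1, 16128)
print(audio.shape)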
model/vocos/helpers.py ADDED
@@ -0,0 +1,71 @@
1
+ import matplotlib
2
+ import numpy as np
3
+ import torch
4
+ from matplotlib import pyplot as plt
5
+ from pytorch_lightning import Callback
6
+
7
+ matplotlib.use("Agg")
8
+
9
+
10
+ def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
11
+ """
12
+ Save a matplotlib figure to a numpy array.
13
+
14
+ Args:
15
+ fig (Figure): Matplotlib figure object.
16
+
17
+ Returns:
18
+ ndarray: Numpy array representing the figure.
19
+ """
20
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated for binary buffers
21
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
22
+ return data
23
+
24
+
25
+ def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
26
+ """
27
+ Plot a spectrogram and convert it to a numpy array.
28
+
29
+ Args:
30
+ spectrogram (ndarray): Spectrogram data.
31
+
32
+ Returns:
33
+ ndarray: Numpy array representing the plotted spectrogram.
34
+ """
35
+ spectrogram = spectrogram.astype(np.float32)
36
+ fig, ax = plt.subplots(figsize=(12, 3))
37
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
38
+ plt.colorbar(im, ax=ax)
39
+ plt.xlabel("Frames")
40
+ plt.ylabel("Channels")
41
+ plt.tight_layout()
42
+
43
+ fig.canvas.draw()
44
+ data = save_figure_to_numpy(fig)
45
+ plt.close()
46
+ return data
47
+
48
+
49
+ class GradNormCallback(Callback):
50
+ """
51
+ Callback to log the gradient norm.
52
+ """
53
+
54
+ def on_after_backward(self, trainer, model):
55
+ model.log("grad_norm", gradient_norm(model))
56
+
57
+
58
+ def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
59
+ """
60
+ Compute the gradient norm.
61
+
62
+ Args:
63
+ model (Module): PyTorch model.
64
+ norm_type (float, optional): Type of the norm. Defaults to 2.0.
65
+
66
+ Returns:
67
+ Tensor: Gradient norm.
68
+ """
69
+ grads = [p.grad for p in model.parameters() if p.grad is not None]
70
+ total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
71
+ return total_norm
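A quick sketch of `gradient_norm` on a toy module (illustrative only). In training, `GradNormCallback()` would be passed to the Lightning Trainer's callbacks to log the same value after each backward pass.

import torch
from model.vocos.helpers import gradient_norm

net = torch.nn.Linear(4, 1)
net(torch.randn(8, 4)).sum().backward()
print(gradient_norm(net))   # total L2 norm over all parameter gradients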
model/vocos/loss.py ADDED
@@ -0,0 +1,114 @@
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ import torchaudio
5
+ from torch import nn
6
+
7
+ from model.vocos.modules import safe_log
8
+
9
+
10
+ class MelSpecReconstructionLoss(nn.Module):
11
+ """
12
+ L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
13
+ """
14
+
15
+ def __init__(
16
+ self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100,
17
+ ):
18
+ super().__init__()
19
+ self.mel_spec = torchaudio.transforms.MelSpectrogram(
20
+ sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1,
21
+ )
22
+
23
+ def forward(self, y_hat, y) -> torch.Tensor:
24
+ """
25
+ Args:
26
+ y_hat (Tensor): Predicted audio waveform.
27
+ y (Tensor): Ground truth audio waveform.
28
+
29
+ Returns:
30
+ Tensor: L1 loss between the mel-scaled magnitude spectrograms.
31
+ """
32
+ mel_hat = safe_log(self.mel_spec(y_hat))
33
+ mel = safe_log(self.mel_spec(y))
34
+
35
+ loss = torch.nn.functional.l1_loss(mel, mel_hat)
36
+
37
+ return loss
38
+
39
+
40
+ class GeneratorLoss(nn.Module):
41
+ """
42
+ Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
43
+ """
44
+
45
+ def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
46
+ """
47
+ Args:
48
+ disc_outputs (List[Tensor]): List of discriminator outputs.
49
+
50
+ Returns:
51
+ Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
52
+ the sub-discriminators
53
+ """
54
+ loss = torch.zeros(1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype)
55
+ gen_losses = []
56
+ for dg in disc_outputs:
57
+ l = torch.mean(torch.clamp(1 - dg, min=0))
58
+ gen_losses.append(l)
59
+ loss += l
60
+
61
+ return loss, gen_losses
62
+
63
+
64
+ class DiscriminatorLoss(nn.Module):
65
+ """
66
+ Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
67
+ """
68
+
69
+ def forward(
70
+ self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
71
+ ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
72
+ """
73
+ Args:
74
+ disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
75
+ disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
76
+
77
+ Returns:
78
+ Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
79
+ the sub-discriminators for real outputs, and a list of
80
+ loss values for generated outputs.
81
+ """
82
+ loss = torch.zeros(1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype)
83
+ r_losses = []
84
+ g_losses = []
85
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
86
+ r_loss = torch.mean(torch.clamp(1 - dr, min=0))
87
+ g_loss = torch.mean(torch.clamp(1 + dg, min=0))
88
+ loss += r_loss + g_loss
89
+ r_losses.append(r_loss)
90
+ g_losses.append(g_loss)
91
+
92
+ return loss, r_losses, g_losses
93
+
94
+
95
+ class FeatureMatchingLoss(nn.Module):
96
+ """
97
+ Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
98
+ """
99
+
100
+ def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
101
+ """
102
+ Args:
103
+ fmap_r (List[List[Tensor]]): List of feature maps from real samples.
104
+ fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
105
+
106
+ Returns:
107
+ Tensor: The calculated feature matching loss.
108
+ """
109
+ loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
110
+ for dr, dg in zip(fmap_r, fmap_g):
111
+ for rl, gl in zip(dr, dg):
112
+ loss += torch.mean(torch.abs(rl - gl))
113
+
114
+ return loss
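A tiny calling-convention sketch for the loss modules above; the discriminator outputs and feature maps are random tensors standing in for real sub-discriminator results, purely to show shapes and return values.

import torch
from model.vocos.loss import GeneratorLoss, DiscriminatorLoss, FeatureMatchingLoss

real_logits = [torch.randn(4, 100), torch.randn(4, 50)]   # one entry per sub-discriminator
fake_logits = [torch.randn(4, 100), torch.randn(4, 50)]

d_loss, real_losses, fake_losses = DiscriminatorLoss()(real_logits, fake_logits)
g_loss, per_disc_losses = GeneratorLoss()(fake_logits)
fm_loss = FeatureMatchingLoss()([[torch.randn(4, 8, 10)]], [[torch.randn(4, 8, 10)]])
print(d_loss.item(), g_loss.item(), fm_loss.item())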
model/vocos/models.py ADDED
@@ -0,0 +1,118 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.utils import weight_norm
6
+
7
+ from model.vocos.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm
8
+
9
+
10
+ class Backbone(nn.Module):
11
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
12
+
13
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
14
+ """
15
+ Args:
16
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
17
+ C denotes output features, and L is the sequence length.
18
+
19
+ Returns:
20
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
21
+ and H denotes the model dimension.
22
+ """
23
+ raise NotImplementedError("Subclasses must implement the forward method.")
24
+
25
+
26
+ class VocosBackbone(Backbone):
27
+ """
28
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
29
+
30
+ Args:
31
+ input_channels (int): Number of input features channels.
32
+ dim (int): Hidden dimension of the model.
33
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
34
+ num_layers (int): Number of ConvNeXtBlock layers.
35
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
36
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
37
+ None means non-conditional model. Defaults to None.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ input_channels: int,
43
+ dim: int,
44
+ intermediate_dim: int,
45
+ num_layers: int,
46
+ layer_scale_init_value: Optional[float] = None,
47
+ adanorm_num_embeddings: Optional[int] = None,
48
+ ):
49
+ super().__init__()
50
+ self.input_channels = input_channels
51
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
52
+ self.adanorm = adanorm_num_embeddings is not None
53
+ if adanorm_num_embeddings:
54
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
55
+ else:
56
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
57
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
58
+ self.convnext = nn.ModuleList(
59
+ [
60
+ ConvNeXtBlock(
61
+ dim=dim,
62
+ intermediate_dim=intermediate_dim,
63
+ layer_scale_init_value=layer_scale_init_value,
64
+ adanorm_num_embeddings=adanorm_num_embeddings,
65
+ )
66
+ for _ in range(num_layers)
67
+ ]
68
+ )
69
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
70
+ self.apply(self._init_weights)
71
+
72
+ def _init_weights(self, m):
73
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
74
+ nn.init.trunc_normal_(m.weight, std=0.02)
75
+ nn.init.constant_(m.bias, 0)
76
+
77
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
78
+ bandwidth_id = kwargs.get('bandwidth_id', None)
79
+ x = self.embed(x)
80
+ if self.adanorm:
81
+ assert bandwidth_id is not None
82
+ x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
83
+ else:
84
+ x = self.norm(x.transpose(1, 2))
85
+ x = x.transpose(1, 2)
86
+ for conv_block in self.convnext:
87
+ x = conv_block(x, cond_embedding_id=bandwidth_id)
88
+ x = self.final_layer_norm(x.transpose(1, 2))
89
+ return x
90
+
91
+
92
+ class VocosResNetBackbone(Backbone):
93
+ """
94
+ Vocos backbone module built with ResBlocks.
95
+
96
+ Args:
97
+ input_channels (int): Number of input features channels.
98
+ dim (int): Hidden dimension of the model.
99
+ num_blocks (int): Number of ResBlock1 blocks.
100
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
101
+ """
102
+
103
+ def __init__(
104
+ self, input_channels, dim, num_blocks, layer_scale_init_value=None,
105
+ ):
106
+ super().__init__()
107
+ self.input_channels = input_channels
108
+ self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1))
109
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
110
+ self.resnet = nn.Sequential(
111
+ *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)]
112
+ )
113
+
114
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
115
+ x = self.embed(x)
116
+ x = self.resnet(x)
117
+ x = x.transpose(1, 2)
118
+ return x
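A minimal sketch of the backbone's tensor contract, (B, C, L) in and (B, L, H) out; the channel and hidden sizes below are assumptions chosen only to show the shapes, not the project's configuration.

import torch
from model.vocos.models import VocosBackbone

backbone = VocosBackbone(input_channels=80, dim=512, intermediate_dim=1536, num_layers=8)
mel = torch.randn(1, 80, 126)   # (B, C, L) log-mel features
hidden = backbone(mel)          # (B, L, H) = (1, 126, 512)
print(hidden.shape)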
model/vocos/modules.py ADDED
@@ -0,0 +1,213 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.utils import weight_norm, remove_weight_norm
6
+
7
+
8
+ class ConvNeXtBlock(nn.Module):
9
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
10
+
11
+ Args:
12
+ dim (int): Number of input channels.
13
+ intermediate_dim (int): Dimensionality of the intermediate layer.
14
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
15
+ Defaults to None.
16
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
17
+ None means non-conditional LayerNorm. Defaults to None.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ dim: int,
23
+ intermediate_dim: int,
24
+ layer_scale_init_value: float,
25
+ adanorm_num_embeddings: Optional[int] = None,
26
+ ):
27
+ super().__init__()
28
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
29
+ self.adanorm = adanorm_num_embeddings is not None
30
+ if adanorm_num_embeddings:
31
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
32
+ else:
33
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
34
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
35
+ self.act = nn.GELU()
36
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
37
+ self.gamma = (
38
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
39
+ if layer_scale_init_value > 0
40
+ else None
41
+ )
42
+
43
+ def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
44
+ residual = x
45
+ x = self.dwconv(x)
46
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
47
+ if self.adanorm:
48
+ assert cond_embedding_id is not None
49
+ x = self.norm(x, cond_embedding_id)
50
+ else:
51
+ x = self.norm(x)
52
+ x = self.pwconv1(x)
53
+ x = self.act(x)
54
+ x = self.pwconv2(x)
55
+ if self.gamma is not None:
56
+ x = self.gamma * x
57
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
58
+
59
+ x = residual + x
60
+ return x
61
+
62
+
63
+ class AdaLayerNorm(nn.Module):
64
+ """
65
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
66
+
67
+ Args:
68
+ num_embeddings (int): Number of embeddings.
69
+ embedding_dim (int): Dimension of the embeddings.
70
+ """
71
+
72
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
73
+ super().__init__()
74
+ self.eps = eps
75
+ self.dim = embedding_dim
76
+ self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
77
+ self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
78
+ torch.nn.init.ones_(self.scale.weight)
79
+ torch.nn.init.zeros_(self.shift.weight)
80
+
81
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
82
+ scale = self.scale(cond_embedding_id)
83
+ shift = self.shift(cond_embedding_id)
84
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
85
+ x = x * scale + shift
86
+ return x
87
+
88
+
89
+ class ResBlock1(nn.Module):
90
+ """
91
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
92
+ but without upsampling layers.
93
+
94
+ Args:
95
+ dim (int): Number of input channels.
96
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
97
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
98
+ Defaults to (1, 3, 5).
99
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
100
+ Defaults to 0.1.
101
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
102
+ Defaults to None.
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ dim: int,
108
+ kernel_size: int = 3,
109
+ dilation: Tuple[int, int, int] = (1, 3, 5),
110
+ lrelu_slope: float = 0.1,
111
+ layer_scale_init_value: Optional[float] = None,
112
+ ):
113
+ super().__init__()
114
+ self.lrelu_slope = lrelu_slope
115
+ self.convs1 = nn.ModuleList(
116
+ [
117
+ weight_norm(
118
+ nn.Conv1d(
119
+ dim,
120
+ dim,
121
+ kernel_size,
122
+ 1,
123
+ dilation=dilation[0],
124
+ padding=self.get_padding(kernel_size, dilation[0]),
125
+ )
126
+ ),
127
+ weight_norm(
128
+ nn.Conv1d(
129
+ dim,
130
+ dim,
131
+ kernel_size,
132
+ 1,
133
+ dilation=dilation[1],
134
+ padding=self.get_padding(kernel_size, dilation[1]),
135
+ )
136
+ ),
137
+ weight_norm(
138
+ nn.Conv1d(
139
+ dim,
140
+ dim,
141
+ kernel_size,
142
+ 1,
143
+ dilation=dilation[2],
144
+ padding=self.get_padding(kernel_size, dilation[2]),
145
+ )
146
+ ),
147
+ ]
148
+ )
149
+
150
+ self.convs2 = nn.ModuleList(
151
+ [
152
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
153
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
154
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
155
+ ]
156
+ )
157
+
158
+ self.gamma = nn.ParameterList(
159
+ [
160
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
161
+ if layer_scale_init_value is not None
162
+ else None,
163
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
164
+ if layer_scale_init_value is not None
165
+ else None,
166
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
167
+ if layer_scale_init_value is not None
168
+ else None,
169
+ ]
170
+ )
171
+
172
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
173
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
174
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
175
+ xt = c1(xt)
176
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
177
+ xt = c2(xt)
178
+ if gamma is not None:
179
+ xt = gamma * xt
180
+ x = xt + x
181
+ return x
182
+
183
+ def remove_weight_norm(self):
184
+ for l in self.convs1:
185
+ remove_weight_norm(l)
186
+ for l in self.convs2:
187
+ remove_weight_norm(l)
188
+
189
+ @staticmethod
190
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
191
+ return int((kernel_size * dilation - dilation) / 2)
192
+
193
+
194
+ def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
195
+ """
196
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
197
+
198
+ Args:
199
+ x (Tensor): Input tensor.
200
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
201
+
202
+ Returns:
203
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
204
+ """
205
+ return torch.log(torch.clip(x, min=clip_val))
206
+
207
+
208
+ def symlog(x: torch.Tensor) -> torch.Tensor:
209
+ return torch.sign(x) * torch.log1p(x.abs())
210
+
211
+
212
+ def symexp(x: torch.Tensor) -> torch.Tensor:
213
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
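A short sanity sketch for the helper functions above: `symexp` inverts `symlog`, and `safe_log` never returns -inf thanks to the clipping.

import torch
from model.vocos.modules import safe_log, symlog, symexp

x = torch.tensor([-3.0, 0.0, 0.5, 10.0])
assert torch.allclose(symexp(symlog(x)), x)
print(safe_log(torch.tensor([0.0, 1.0])))   # log(1e-7) for the zero entry instead of -inf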
model/vocos/pretrained.py ADDED
@@ -0,0 +1,162 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, Tuple, Union, Optional
4
+
5
+ import torch
6
+ import yaml
7
+ from huggingface_hub import hf_hub_download
8
+ from torch import nn
9
+ from model.vocos.feature_extractors import FeatureExtractor, EncodecFeatures
10
+ from model.vocos.heads import FourierHead
11
+ from model.vocos.models import Backbone
12
+
13
+
14
+ def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
15
+ """Instantiates a class with the given args and init.
16
+
17
+ Args:
18
+ args: Positional arguments required for instantiation.
19
+ init: Dict of the form {"class_path":...,"init_args":...}.
20
+
21
+ Returns:
22
+ The instantiated class object.
23
+ """
24
+ kwargs = init.get("init_args", {})
25
+ if not isinstance(args, tuple):
26
+ args = (args,)
27
+ class_module, class_name = init["class_path"].rsplit(".", 1)
28
+ module = __import__(class_module, fromlist=[class_name])
29
+ args_class = getattr(module, class_name)
30
+ return args_class(*args, **kwargs)
31
+
32
+
33
+ class Vocos(nn.Module):
34
+ """
35
+ The Vocos class represents a Fourier-based neural vocoder for audio synthesis.
36
+ This class is primarily designed for inference, with support for loading from pretrained
37
+ model checkpoints. It consists of three main components: a feature extractor,
38
+ a backbone, and a head.
39
+ """
40
+
41
+ def __init__(
42
+ self, feature_extractor: nn.Module, backbone: Backbone, head: FourierHead,
43
+ ):
44
+ super().__init__()
45
+ self.feature_extractor = feature_extractor
46
+ self.backbone = backbone
47
+ self.head = head
48
+
49
+ @classmethod
50
+ def from_hparams(cls, config_path: str) -> "Vocos":
51
+ """
52
+ Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
53
+ """
54
+ with open(config_path, "r") as f:
55
+ config = yaml.safe_load(f)
56
+ feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
57
+ backbone = instantiate_class(args=(), init=config["backbone"])
58
+ head = instantiate_class(args=(), init=config["head"])
59
+ model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
60
+ return model
61
+
62
+ @classmethod
63
+ def from_pretrained(cls, config_path: str, model_path: str, model: Optional[nn.Module] = None) -> "Vocos":
64
+ """
65
+ Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
66
+ """
67
+ if model is None:
68
+ model = cls.from_hparams(config_path)
69
+ state_dict = torch.load(model_path, map_location="cpu")
70
+ prefixes = ("backbone", "feature_extractor", "head")
71
+ state_dict = {
72
+ key: value
73
+ for key, value in state_dict.items()
74
+ if any(key.startswith(prefix) for prefix in prefixes)
75
+ }
76
+ if isinstance(model.feature_extractor, EncodecFeatures):
77
+ encodec_parameters = {
78
+ "feature_extractor.encodec." + key: value
79
+ for key, value in model.feature_extractor.encodec.state_dict().items()
80
+ }
81
+ state_dict.update(encodec_parameters)
82
+ model.load_state_dict(state_dict)
83
+ model.eval()
84
+ return model
85
+
86
+ @torch.inference_mode()
87
+ def forward(self, features_input: torch.Tensor, X_norm, **kwargs: Any) -> torch.Tensor:
88
+ """
89
+ Method to reconstruct audio from pre-computed features. The features are passed through the
90
+ backbone and the head, and the resulting waveform is rescaled by the normalization factor.
91
+
92
+ Args:
93
+ features_input (Tensor): Input features of shape (B, C, L), where B is the batch size,
94
+ C denotes the feature dimension, and L is the sequence length.
+ X_norm (Tensor): Normalization factor; the reconstructed waveform is divided by it.
95
+
96
+
97
+ Returns:
98
+ Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
99
+ """
100
+ audio_output = self.decode(features_input, **kwargs)
101
+ return audio_output / X_norm
102
+
103
+ @torch.inference_mode()
104
+ def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
105
+ """
106
+ Method to decode audio waveform from already calculated features. The features input is passed through
107
+ the backbone and the head to reconstruct the audio output.
108
+
109
+ Args:
110
+ features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
111
+ C denotes the feature dimension, and L is the sequence length.
112
+
113
+ Returns:
114
+ Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
115
+ """
116
+ x = self.backbone(features_input, **kwargs)
117
+ audio_output = self.head(x)
118
+ return audio_output
119
+
120
+ @torch.inference_mode()
121
+ def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
122
+ """
123
+ Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
124
+ codebook weights.
125
+
126
+ Args:
127
+ codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
128
+ where K is the number of codebooks, B is the batch size and L is the sequence length.
129
+
130
+ Returns:
131
+ Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
132
+ and L is the sequence length.
133
+ """
134
+ assert isinstance(
135
+ self.feature_extractor, EncodecFeatures
136
+ ), "Feature extractor should be an instance of EncodecFeatures"
137
+
138
+ if codes.dim() == 2:
139
+ codes = codes.unsqueeze(1)
140
+
141
+ n_bins = self.feature_extractor.encodec.quantizer.bins
142
+ offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
143
+ embeddings_idxs = codes + offsets.view(-1, 1, 1)
144
+ features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
145
+ features = features.transpose(1, 2)
146
+
147
+ return features
148
+
149
+
150
+ if __name__ == "__main__":
151
+ model = Vocos.from_pretrained(
152
+ "/nvmework3/shaonian/MelSpatialNet/MelSpatialNet/models/vocos/pretrained/pretrained_rec_normed.yaml",
153
+ "/nvmework3/shaonian/MelSpatialNet/MelSpatialNet/models/vocos/pretrained/vocos_hop128_clip1e-5_rts.ckpt").to("meta")
154
+ x = torch.randn(1, 80, 501)
155
+ x = x.to('meta')
156
+ from torch.utils.flop_counter import FlopCounterMode # requires torch>=2.1.0
157
+ with FlopCounterMode(model, display=False) as fcm:
158
+ y = model.decode(x)
159
+ flops_forward_eval = fcm.get_total_flops()
160
+
161
+ params_eval = sum(param.numel() for param in model.parameters())
162
+ print(f"flops_forward={flops_forward_eval/4e9:.2f}G, params={params_eval/1e6:.2f} M")
model/vocos/spectral_ops.py ADDED
@@ -0,0 +1,192 @@
1
+ import numpy as np
2
+ import scipy
3
+ import torch
4
+ from torch import nn, view_as_real, view_as_complex
5
+
6
+
7
+ class ISTFT(nn.Module):
8
+ """
9
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
10
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
11
+ See issue: https://github.com/pytorch/pytorch/issues/62323
12
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
13
+ The NOLA constraint is met as we trim padded samples anyway.
14
+
15
+ Args:
16
+ n_fft (int): Size of Fourier transform.
17
+ hop_length (int): The distance between neighboring sliding window frames.
18
+ win_length (int): The size of window frame and STFT filter.
19
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
20
+ """
21
+
22
+ def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
23
+ super().__init__()
24
+ if padding not in ["center", "same"]:
25
+ raise ValueError("Padding must be 'center' or 'same'.")
26
+ self.padding = padding
27
+ self.n_fft = n_fft
28
+ self.hop_length = hop_length
29
+ self.win_length = win_length
30
+ window = torch.hann_window(win_length)
31
+ self.register_buffer("window", window)
32
+
33
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
34
+ """
35
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
36
+
37
+ Args:
38
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
39
+ N is the number of frequency bins, and T is the number of time frames.
40
+
41
+ Returns:
42
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
43
+ """
44
+ if self.padding == "center":
45
+ # Fallback to pytorch native implementation
46
+ return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
47
+ elif self.padding == "same":
48
+ pad = (self.win_length - self.hop_length) // 2
49
+ else:
50
+ raise ValueError("Padding must be 'center' or 'same'.")
51
+
52
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
53
+ B, N, T = spec.shape
54
+
55
+ # Inverse FFT
56
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
57
+ ifft = ifft * self.window[None, :, None]
58
+
59
+ # Overlap and Add
60
+ output_size = (T - 1) * self.hop_length + self.win_length
61
+ y = torch.nn.functional.fold(
62
+ ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
63
+ )[:, 0, 0, pad:-pad]
64
+
65
+ # Window envelope
66
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
67
+ window_envelope = torch.nn.functional.fold(
68
+ window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
69
+ ).squeeze()[pad:-pad]
70
+
71
+ # Normalize
72
+ assert (window_envelope > 1e-11).all()
73
+ y = y / window_envelope
74
+
75
+ return y
76
+
77
+
78
+ class MDCT(nn.Module):
79
+ """
80
+ Modified Discrete Cosine Transform (MDCT) module.
81
+
82
+ Args:
83
+ frame_len (int): Length of the MDCT frame.
84
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
85
+ """
86
+
87
+ def __init__(self, frame_len: int, padding: str = "same"):
88
+ super().__init__()
89
+ if padding not in ["center", "same"]:
90
+ raise ValueError("Padding must be 'center' or 'same'.")
91
+ self.padding = padding
92
+ self.frame_len = frame_len
93
+ N = frame_len // 2
94
+ n0 = (N + 1) / 2
95
+ window = torch.from_numpy(scipy.signal.windows.cosine(frame_len)).float()  # window functions live under scipy.signal.windows in current SciPy
96
+ self.register_buffer("window", window)
97
+
98
+ pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
99
+ post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
100
+ # view_as_real: NCCL Backend does not support ComplexFloat data type
101
+ # https://github.com/pytorch/pytorch/issues/71613
102
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
103
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
104
+
105
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
106
+ """
107
+ Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
108
+
109
+ Args:
110
+ audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
111
+ and T is the length of the audio.
112
+
113
+ Returns:
114
+ Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
115
+ and N is the number of frequency bins.
116
+ """
117
+ if self.padding == "center":
118
+ audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
119
+ elif self.padding == "same":
120
+ # hop_length is 1/2 frame_len
121
+ audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
122
+ else:
123
+ raise ValueError("Padding must be 'center' or 'same'.")
124
+
125
+ x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
126
+ N = self.frame_len // 2
127
+ x = x * self.window.expand(x.shape)
128
+ X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
129
+ res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
130
+ return torch.real(res) * np.sqrt(2)
131
+
132
+
133
+ class IMDCT(nn.Module):
134
+ """
135
+ Inverse Modified Discrete Cosine Transform (IMDCT) module.
136
+
137
+ Args:
138
+ frame_len (int): Length of the MDCT frame.
139
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
140
+ """
141
+
142
+ def __init__(self, frame_len: int, padding: str = "same"):
143
+ super().__init__()
144
+ if padding not in ["center", "same"]:
145
+ raise ValueError("Padding must be 'center' or 'same'.")
146
+ self.padding = padding
147
+ self.frame_len = frame_len
148
+ N = frame_len // 2
149
+ n0 = (N + 1) / 2
150
+ window = torch.from_numpy(scipy.signal.windows.cosine(frame_len)).float()  # window functions live under scipy.signal.windows in current SciPy
151
+ self.register_buffer("window", window)
152
+
153
+ pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
154
+ post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
155
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
156
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
157
+
158
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
159
+ """
160
+ Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
161
+
162
+ Args:
163
+ X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
164
+ L is the number of frames, and N is the number of frequency bins.
165
+
166
+ Returns:
167
+ Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
168
+ """
169
+ B, L, N = X.shape
170
+ Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
171
+ Y[..., :N] = X
172
+ Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
173
+ y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
174
+ y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
175
+ result = y * self.window.expand(y.shape)
176
+ output_size = (1, (L + 1) * N)
177
+ audio = torch.nn.functional.fold(
178
+ result.transpose(1, 2),
179
+ output_size=output_size,
180
+ kernel_size=(1, self.frame_len),
181
+ stride=(1, self.frame_len // 2),
182
+ )[:, 0, 0, :]
183
+
184
+ if self.padding == "center":
185
+ pad = self.frame_len // 2
186
+ elif self.padding == "same":
187
+ pad = self.frame_len // 4
188
+ else:
189
+ raise ValueError("Padding must be 'center' or 'same'.")
190
+
191
+ audio = audio[:, pad:-pad]
192
+ return audio
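A shape-level sketch of the MDCT/IMDCT pair above; frame_len=512 and the signal length are assumptions, used only to show that the round trip preserves the waveform length.

import torch
from model.vocos.spectral_ops import MDCT, IMDCT

mdct = MDCT(frame_len=512, padding="same")
imdct = IMDCT(frame_len=512, padding="same")
audio = torch.randn(1, 16384)
coeffs = mdct(audio)    # (B, L, N) = (1, 64, 256) with N = frame_len // 2
recon = imdct(coeffs)   # (B, T) = (1, 16384), same length as the input
print(coeffs.shape, recon.shape)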