SaoYear committed on
Commit
8521c95
·
1 Parent(s): 0c51cc6

+Small models

Browse files
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import subprocess
2
  import sys
3
  import shlex
 
 
4
  #update the system
5
  subprocess.check_call(["apt-get", "update"])
6
  subprocess.check_call([sys.executable,"-m","pip","install",
@@ -72,7 +74,10 @@ def mel_transform(audio, X_norm):
72
  return transform(audio, X_norm)
73
 
74
  def load_cleanmel(model_name):
75
- model_config = f"./configs/cleanmel_offline.yaml"
 
 
 
76
  model_config = yaml.safe_load(open(model_config, "r"))["model"]["arch"]["init_args"]
77
  cleanmel = CleanMel(**model_config)
78
  cleanmel.load_state_dict(torch.load(f"./ckpts/CleanMel/{model_name}.ckpt", map_location=DEVICE))
@@ -129,6 +134,20 @@ def enhance_cleanmel_L_mask(audio_path):
129
  y_hat = vocos(logMel_hat, X_norm).clamp(min=-1, max=1)
130
  return output(y_hat, logMel_hat)
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  @spaces.GPU
133
  @torch.inference_mode()
134
  def enhance_cleanmel_L_map(audio_path):
@@ -140,6 +159,17 @@ def enhance_cleanmel_L_map(audio_path):
140
  y_hat = vocos(logMel_hat, X_norm).clamp(min=-1, max=1)
141
  return output(y_hat, logMel_hat)
142
 
 
 
 
 
 
 
 
 
 
 
 
143
  def reset_everything():
144
  """Reset all components to initial state"""
145
  return None, None, None
@@ -153,8 +183,12 @@ with gr.Blocks(title="CleanMel Demo") as demo:
153
  with gr.Row():
154
  audio_input = gr.Audio(label="Input Audio", type="filepath", sources="upload")
155
  with gr.Column():
156
- enhance_button_map = gr.Button("Enhance Audio (offline CleanMel_L_map)")
157
- enhance_button_mask = gr.Button("Enhance Audio (offline CleanMel_L_mask)")
 
 
 
 
158
  clear_btn = gr.Button(
159
  "🗑️ Clear All",
160
  variant="secondary",
@@ -165,17 +199,30 @@ with gr.Blocks(title="CleanMel Demo") as demo:
165
  output_mel = gr.Image(label="Output LogMel Spectrogram", type="filepath", visible=True)
166
  output_np = gr.File(label="Enhanced LogMel Spec. (.npy)", type="filepath")
167
 
168
- enhance_button_map.click(
169
  enhance_cleanmel_L_map,
170
  inputs=audio_input,
171
  outputs=[output_audio, output_mel, output_np]
172
  )
173
 
174
- enhance_button_mask.click(
175
  enhance_cleanmel_L_mask,
176
  inputs=audio_input,
177
  outputs=[output_audio, output_mel, output_np]
178
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  clear_btn.click(
180
  fn=reset_everything,
181
  outputs=[output_audio, output_mel, output_np]
 
1
  import subprocess
2
  import sys
3
  import shlex
4
+
5
+ from OtherMethods.Demucs.denoiser import enhance
6
  #update the system
7
  subprocess.check_call(["apt-get", "update"])
8
  subprocess.check_call([sys.executable,"-m","pip","install",
 
74
  return transform(audio, X_norm)
75
 
76
  def load_cleanmel(model_name):
77
+ if "S" in model_name:
78
+ model_config = f"./configs/cleanmel_offline_S.yaml"
79
+ else:
80
+ model_config = f"./configs/cleanmel_offline_L.yaml"
81
  model_config = yaml.safe_load(open(model_config, "r"))["model"]["arch"]["init_args"]
82
  cleanmel = CleanMel(**model_config)
83
  cleanmel.load_state_dict(torch.load(f"./ckpts/CleanMel/{model_name}.ckpt", map_location=DEVICE))
 
134
  y_hat = vocos(logMel_hat, X_norm).clamp(min=-1, max=1)
135
  return output(y_hat, logMel_hat)
136
 
137
@spaces.GPU
@torch.inference_mode()
def enhance_cleanmel_S_mask(audio_path):
    """Enhance the audio at ``audio_path`` with the small offline mask model.

    Loads the ``offline_CleanMel_S_mask`` checkpoint and the vocoder, applies
    a sigmoid to the model output, converts it with ``get_mrm_pred`` (using
    the input waveform and the STFT norm), takes ``safe_log`` of the result
    and vocodes it.

    Args:
        audio_path: Path of the audio file handed to ``read_audio``.

    Returns:
        Whatever ``output`` builds from the clamped waveform and the
        estimated log-mel.
    """
    # Both networks are loaded on every call and moved to DEVICE.
    cleanmel = load_cleanmel("offline_CleanMel_S_mask").to(DEVICE)
    vocoder = load_vocos().to(DEVICE)

    waveform = read_audio(audio_path).to(DEVICE)
    spec, spec_norm = stft(waveform)

    # Sigmoid squashes the raw network output into (0, 1) before it is
    # passed on as the mask estimate (presumably a mel ratio mask -- confirm
    # against get_mrm_pred).
    mask_estimate = torch.sigmoid(cleanmel(spec, inference=True))
    log_mel = safe_log(get_mrm_pred(mask_estimate, waveform, spec_norm))

    # Keep the vocoded samples inside [-1, 1].
    enhanced = vocoder(log_mel, spec_norm).clamp(min=-1, max=1)
    return output(enhanced, log_mel)
150
+
151
  @spaces.GPU
152
  @torch.inference_mode()
153
  def enhance_cleanmel_L_map(audio_path):
 
159
  y_hat = vocos(logMel_hat, X_norm).clamp(min=-1, max=1)
160
  return output(y_hat, logMel_hat)
161
 
162
@spaces.GPU
@torch.inference_mode()
def enhance_cleanmel_S_map(audio_path):
    """Enhance the audio at ``audio_path`` with the small offline mapping model.

    Loads the ``offline_CleanMel_S_map`` checkpoint and the vocoder; the model
    output is fed to the vocoder directly as the log-mel estimate.

    Args:
        audio_path: Path of the audio file handed to ``read_audio``.

    Returns:
        Whatever ``output`` builds from the clamped waveform and the
        estimated log-mel.
    """
    # Both networks are loaded on every call and moved to DEVICE.
    cleanmel = load_cleanmel("offline_CleanMel_S_map").to(DEVICE)
    vocoder = load_vocos().to(DEVICE)

    waveform = read_audio(audio_path).to(DEVICE)
    spec, spec_norm = stft(waveform)

    # No mask step here: the network output is used as the log-mel as-is.
    log_mel = cleanmel(spec, inference=True)

    # Keep the vocoded samples inside [-1, 1].
    enhanced = vocoder(log_mel, spec_norm).clamp(min=-1, max=1)
    return output(enhanced, log_mel)
172
+
173
def reset_everything():
    """Return the cleared state for the three output components.

    Returns:
        A 3-tuple of ``None``, one entry per component listed in the clear
        button's ``outputs``.
    """
    cleared = (None,) * 3
    return cleared
 
183
  with gr.Row():
184
  audio_input = gr.Audio(label="Input Audio", type="filepath", sources="upload")
185
  with gr.Column():
186
+
187
+ enhance_button_map_S = gr.Button("Enhance Audio (offline CleanMel_S_map), 4 mins for 10-second audio")
188
+ enhance_button_mask_S = gr.Button("Enhance Audio (offline CleanMel_S_mask), 4 mins for 10-second audio")
189
+
190
+ enhance_button_map_L = gr.Button("Enhance Audio (offline CleanMel_L_map), 10 mins for 10-second audio")
191
+ enhance_button_mask_L = gr.Button("Enhance Audio (offline CleanMel_L_mask), 10 mins for 10-second audio")
192
  clear_btn = gr.Button(
193
  "🗑️ Clear All",
194
  variant="secondary",
 
199
  output_mel = gr.Image(label="Output LogMel Spectrogram", type="filepath", visible=True)
200
  output_np = gr.File(label="Enhanced LogMel Spec. (.npy)", type="filepath")
201
 
202
+ enhance_button_map_L.click(
203
  enhance_cleanmel_L_map,
204
  inputs=audio_input,
205
  outputs=[output_audio, output_mel, output_np]
206
  )
207
 
208
+ enhance_button_mask_L.click(
209
  enhance_cleanmel_L_mask,
210
  inputs=audio_input,
211
  outputs=[output_audio, output_mel, output_np]
212
  )
213
+
214
+ enhance_button_map_S.click(
215
+ enhance_cleanmel_S_map,
216
+ inputs=audio_input,
217
+ outputs=[output_audio, output_mel, output_np]
218
+ )
219
+
220
+ enhance_button_mask_S.click(
221
+ enhance_cleanmel_S_mask,
222
+ inputs=audio_input,
223
+ outputs=[output_audio, output_mel, output_np]
224
+ )
225
+
226
  clear_btn.click(
227
  fn=reset_everything,
228
  outputs=[output_audio, output_mel, output_np]
configs/{cleanmel_offline.yaml → cleanmel_offline_L.yaml} RENAMED
File without changes
configs/cleanmel_offline_S.yaml ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed_everything: 2
2
+
3
+ trainer:
4
+ gradient_clip_val: 10
5
+ gradient_clip_algorithm: norm
6
+ devices: null
7
+ accelerator: gpu
8
+ strategy: ddp_find_unused_parameters_false
9
+ sync_batchnorm: false
10
+ precision: 32
11
+ num_sanity_val_steps: 3
12
+ deterministic: true
13
+ max_epochs: 100
14
+ log_every_n_steps: 40
15
+
16
+ model:
17
+ arch:
18
+ class_path: model.arch.cleanmel.CleanMel
19
+ init_args:
20
+ dim_input: 2
21
+ dim_output: 1
22
+ n_layers: 8
23
+ dim_hidden: 96
24
+ layer_linear_freq: 1
25
+ f_kernel_size: 5
26
+ f_conv_groups: 8
27
+ n_freqs: 257
28
+ n_mels: 80
29
+ mamba_state: 16
30
+ mamba_conv_kernel: 4
31
+ online: false
32
+ sr: 16000
33
+ n_fft: 512
34
+ input_stft:
35
+ class_path: model.io.stft.InputSTFT
36
+ init_args:
37
+ n_fft: 512
38
+ n_win: 512
39
+ n_hop: 128
40
+ center: true
41
+ normalize: false
42
+ onesided: true
43
+ online: false
44
+ target_stft:
45
+ class_path: model.io.stft.TargetMel
46
+ init_args:
47
+ sample_rate: 16000
48
+ n_fft: 512
49
+ n_win: 512
50
+ n_hop: 128
51
+ n_mels: 80
52
+ f_min: 0
53
+ f_max: 8000
54
+ power: 2
55
+ center: true
56
+ normalize: false
57
+ onesided: true
58
+ mel_norm: "slaney"
59
+ mel_scale: "slaney"
60
+ librosa_mel: true
61
+ online: false
62
+
63
+ optimizer: [AdamW, { lr: 0.001, weight_decay: 0.001}]
64
+ lr_scheduler: [ExponentialLR, { gamma: 0.99 }]
65
+ exp_name: exp
66
+ metrics: [DNSMOS]
67
+ log_eps: 1e-5