|
import subprocess |
|
import sys |
|
import shlex |
|
|
|
# Runtime dependency bootstrap: refresh the apt package lists, then install the
# pinned PyTorch stack and the remaining Python deps into *this* interpreter.
# check_call raises CalledProcessError if any step fails, aborting startup.
subprocess.check_call(["apt-get", "update"])
for _requirements in (
    ("torch==2.2.0", "torchaudio==2.2.0"),
    ("einops", "encodec"),
):
    subprocess.check_call([sys.executable, "-m", "pip", "install", *_requirements])
del _requirements
|
def install_mamba():
    """Best-effort install of the prebuilt causal-conv1d and mamba-ssm wheels.

    Installs through ``sys.executable -m pip`` (like the other installs in this
    script) so the wheels go into the interpreter running this app rather than
    whichever ``pip`` happens to be first on PATH.

    Failures are deliberately not raised (``subprocess.run`` without
    ``check=True``, as before); presumably a failed install surfaces later when
    the model code imports these packages — verify.

    NOTE(review): the wheel tags do not match the pinned ``torch==2.2.0``:
    the causal-conv1d wheel is built for torch 2.3, and the mamba-ssm wheel for
    the cxx11 ABI (``cxx11abiTRUE``), whereas PyPI torch wheels of this era are
    typically pre-cxx11-ABI. Confirm both extensions load against the installed
    torch.
    """
    wheel_urls = (
        "https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
        "https://github.com/state-spaces/mamba/releases/download/v1.2.0.post1/mamba_ssm-1.2.0.post1+cu122torch2.2cxx11abiTRUE-cp310-cp310-linux_x86_64.whl",
    )
    # Order matters: causal-conv1d is installed before mamba-ssm, as originally.
    for url in wheel_urls:
        subprocess.run([sys.executable, "-m", "pip", "install", url])


install_mamba()
|
|
|
import torch |
|
import spaces |
|
import tempfile |
|
import soundfile as sf |
|
import gradio as gr |
|
import librosa as lb |
|
import yaml |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from pydub import AudioSegment |
|
from model.cleanmel import CleanMel |
|
from model.vocos.pretrained import Vocos |
|
from model.stft import InputSTFT, TargetMel |
|
|
|
# Compute device for all models and transforms: CUDA when available, else CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
def read_audio(file_path, target_sr=16000):
    """Load a .wav/.flac file as a single-channel float32 tensor at ``target_sr``.

    Multi-channel input is reduced to the one channel with the largest mean
    absolute amplitude (it is not downmixed).

    Args:
        file_path: path to the audio file (as provided by the Gradio Audio
            component with ``type="filepath"``); may be ``None`` when nothing
            was uploaded or recorded.
        target_sr: sample rate to resample to. Defaults to 16000, the rate the
            STFT/mel front-ends in this file are configured for.

    Returns:
        torch.Tensor of shape ``(1, num_samples)`` and dtype float32.

    Raises:
        ValueError: if no path is given or its extension is not .wav/.flac.
    """
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if not file_path:
        raise ValueError("No audio provided. Please upload or record a .wav or .flac file.")
    # Case-insensitive so e.g. "clip.WAV" / "clip.FLAC" are accepted too.
    if not str(file_path).lower().endswith(('.wav', '.flac')):
        raise ValueError("Unsupported audio format. Please upload a .wav, .flac file.")

    audio, sample_rate = sf.read(file_path)

    if audio.ndim > 1:
        # Keep only the loudest channel (by mean |amplitude|).
        audio = audio[:, np.argmax(np.abs(audio).mean(axis=0))]

    if sample_rate != target_sr:
        audio = lb.resample(audio, orig_sr=sample_rate, target_sr=target_sr)

    return torch.tensor(audio).float().squeeze().unsqueeze(0)
|
|
|
def stft(audio):
    """Apply the model's input STFT (512-point FFT/window, hop 128) to ``audio``.

    A fresh ``InputSTFT`` module is built on every call, switched to eval mode
    and moved to ``DEVICE``. Callers unpack the result as ``X, X_norm``.
    """
    stft_settings = dict(
        n_fft=512,
        n_win=512,
        n_hop=128,
        normalize=False,
        center=True,
        onesided=True,
        online=False,
    )
    analyzer = InputSTFT(**stft_settings).eval()
    analyzer = analyzer.to(DEVICE)
    return analyzer(audio)
|
|
|
def mel_transform(audio, X_norm):
    """Compute an 80-band Slaney mel power spectrogram (power=2) of 16 kHz ``audio``.

    Uses the same 512/512/128 FFT/window/hop framing as ``stft``. ``X_norm`` is
    forwarded unchanged to ``TargetMel`` — presumably the normalisation value
    returned by ``stft``; TODO confirm.
    """
    mel_settings = {
        "sample_rate": 16000,
        "n_fft": 512,
        "n_win": 512,
        "n_hop": 128,
        "n_mels": 80,
        "f_min": 0,
        "f_max": 8000,
        "power": 2,
        "center": True,
        "normalize": False,
        "onesided": True,
        "mel_norm": "slaney",
        "mel_scale": "slaney",
        "librosa_mel": True,
        "online": False,
    }
    extractor = TargetMel(**mel_settings).eval()
    extractor = extractor.to(DEVICE)
    return extractor(audio, X_norm)
|
|
|
def load_cleanmel(model_name):
    """Build a CleanMel model from its YAML config and load checkpoint weights.

    Args:
        model_name: checkpoint stem under ``./ckpts/CleanMel/`` (e.g.
            ``"offline_CleanMel_S_mask"``). If one of its ``_``-separated
            tokens is ``"S"``, ``cleanmel_offline_S.yaml`` is used; otherwise
            ``cleanmel_offline_L.yaml``.

    Returns:
        The CleanMel module in eval mode (callers in this file move it to
        ``DEVICE`` themselves).
    """
    # Token match rather than substring match, so a stray capital "S" elsewhere
    # in the name cannot select the wrong config.
    size = "S" if "S" in model_name.split("_") else "L"
    config_path = f"./configs/cleanmel_offline_{size}.yaml"
    # Context manager so the config file handle is closed promptly.
    with open(config_path, "r", encoding="utf-8") as config_file:
        init_args = yaml.safe_load(config_file)["model"]["arch"]["init_args"]

    cleanmel = CleanMel(**init_args)
    # NOTE: torch.load unpickles the checkpoint — only load trusted local files.
    state_dict = torch.load(f"./ckpts/CleanMel/{model_name}.ckpt", map_location=DEVICE)
    cleanmel.load_state_dict(state_dict)
    return cleanmel.eval()
|
|
|
def load_vocos():
    """Build the offline Vocos vocoder from its config, load its weights, and return it in eval mode."""
    architecture = Vocos.from_hparams(config_path="./configs/vocos_offline.yaml")
    pretrained = Vocos.from_pretrained(
        None,
        model_path="./ckpts/Vocos/vocos_offline.pt",
        model=architecture,
    )
    return pretrained.eval()
|
|
|
def get_mrm_pred(Y_hat, x, X_norm):
    """Apply a mel-domain mask ``Y_hat`` to the noisy input and return a power mel.

    The noisy power mel of ``x`` (``mel_transform`` uses power=2) is converted to
    magnitude via sqrt (plus 1e-10), multiplied by the squeezed mask, and squared
    back to power.
    """
    noisy_power = mel_transform(x, X_norm)
    noisy_magnitude = torch.sqrt(noisy_power) + 1e-10
    masked_magnitude = Y_hat.squeeze() * noisy_magnitude
    return torch.square(masked_magnitude)
|
|
|
def safe_log(x):
    """Natural log of tensor ``x`` with values floored at 1e-5, so zeros never yield -inf."""
    floored = x.clip(min=1e-5)
    return floored.log()
|
|
|
def output(y_hat, logMel_hat, sample_rate=16000):
    """Write the enhancement results to temporary files for the Gradio outputs.

    Args:
        y_hat: enhanced waveform tensor; squeezed to 1-D and written as WAV.
        logMel_hat: enhanced log-mel tensor; squeezed to 2-D (rows are mel
            bands, columns are frames, per the plotting below).
        sample_rate: sample rate of ``y_hat``; used for the WAV header and the
            time axis of the plot. Defaults to 16000.

    Returns:
        ``(wav_path, png_path, npy_path)``. The files are created with
        ``delete=False``, so they persist after this function returns.
    """
    wav = y_hat.squeeze().cpu().numpy()
    log_mel = logMel_hat.squeeze().cpu().numpy()

    # Reserve unique names first and write after the handles are closed, so the
    # paths can be reopened by sf/np/matplotlib on any platform.
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
        wav_path = tmp.name
    with tempfile.NamedTemporaryFile(suffix='.npy', delete=False) as tmp:
        npy_path = tmp.name
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        png_path = tmp.name

    sf.write(wav_path, wav, sample_rate)
    np.save(npy_path, log_mel)

    # Flip vertically so the highest mel band is drawn at the top.
    img = log_mel[::-1, :]
    n_mels, n_frames = img.shape
    duration = y_hat.shape[-1] / sample_rate

    # Object-oriented API plus an explicit close: a figure per request would
    # otherwise accumulate in pyplot's global registry for the server's lifetime.
    fig, ax = plt.subplots(figsize=(n_frames / 100, n_mels / 50))
    try:
        ax.imshow(img, vmin=-11, cmap="jet")
        ax.set_ylabel("Mel bands")
        ax.set_xlabel("Time (second)")
        ax.set_yticks([0, n_mels])
        ax.set_yticklabels([n_mels, 0])
        ax.set_xticks([int(t) for t in np.linspace(0, n_frames, 11)])
        ax.set_xticklabels(["{:.1f}".format(t) for t in np.linspace(0, duration, 11)])
        # After labels/ticks exist, so the layout accounts for them.
        fig.tight_layout()
        fig.savefig(png_path)
    finally:
        plt.close(fig)

    return wav_path, png_path, npy_path
|
|
|
def _run_cleanmel(audio_path, model_name, predicts_mask):
    """Shared CleanMel + Vocos pipeline behind the four Gradio enhance handlers.

    Args:
        audio_path: path of the input .wav/.flac file.
        model_name: CleanMel checkpoint stem passed to ``load_cleanmel``.
        predicts_mask: True for ``*_mask`` checkpoints, whose output is passed
            through a sigmoid, applied to the noisy mel, and log-compressed.
            False for ``*_map`` checkpoints, whose output is used directly as
            the log-mel estimate.

    Returns:
        The ``(wav_path, png_path, npy_path)`` tuple produced by ``output``.
    """
    model = load_cleanmel(model_name).to(DEVICE)
    vocos = load_vocos().to(DEVICE)
    x = read_audio(audio_path).to(DEVICE)
    X, X_norm = stft(x)
    Y_hat = model(X, inference=True)
    if predicts_mask:
        MRM_hat = torch.sigmoid(Y_hat)
        logMel_hat = safe_log(get_mrm_pred(MRM_hat, x, X_norm))
    else:
        logMel_hat = Y_hat
    # Keep the vocoded waveform within [-1, 1] before writing it out.
    y_hat = vocos(logMel_hat, X_norm).clamp(min=-1, max=1)
    return output(y_hat, logMel_hat)


@spaces.GPU
@torch.inference_mode()
def enhance_cleanmel_L_mask(audio_path):
    """Enhance ``audio_path`` with the offline CleanMel L mask checkpoint."""
    return _run_cleanmel(audio_path, "offline_CleanMel_L_mask", predicts_mask=True)


@spaces.GPU
@torch.inference_mode()
def enhance_cleanmel_S_mask(audio_path):
    """Enhance ``audio_path`` with the offline CleanMel S mask checkpoint."""
    return _run_cleanmel(audio_path, "offline_CleanMel_S_mask", predicts_mask=True)


@spaces.GPU
@torch.inference_mode()
def enhance_cleanmel_L_map(audio_path):
    """Enhance ``audio_path`` with the offline CleanMel L mapping checkpoint."""
    return _run_cleanmel(audio_path, "offline_CleanMel_L_map", predicts_mask=False)


@spaces.GPU
@torch.inference_mode()
def enhance_cleanmel_S_map(audio_path):
    """Enhance ``audio_path`` with the offline CleanMel S mapping checkpoint."""
    return _run_cleanmel(audio_path, "offline_CleanMel_S_map", predicts_mask=False)
|
|
|
def reset_everything():
    """Clear the three enhancement outputs by returning ``None`` for each of them."""
    cleared = (None,) * 3
    return cleared
|
|
|
|
|
# Gradio UI. (A dead `demo = gr.Blocks()` that was immediately shadowed by the
# `with` statement below has been removed.)
with gr.Blocks(title="CleanMel Demo") as demo:
    gr.Markdown("## CleanMel Demo")
    gr.Markdown("This demo showcases the CleanMel model for speech enhancement. <br> \
Only **.wav** and **.flac** files are supported. <br> \
--- <br> \
The model is running on CPU. Please be patient and wait for the result. <br> \
Inference time reference: <br> \
- CleanMel_L: **10 mins** for **10-second** audio <br> \
- CleanMel_S: **4 mins** for **10-second** audio <br> ")

    # Two separate inputs (upload / microphone), each with its own button set.
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(label="Input Audio", type="filepath", sources="upload")
            audio_input_record = gr.Audio(label="Input Audio (Record)", type="filepath", sources="microphone")

    with gr.Row():
        with gr.Column():
            enhance_button_map_S = gr.Button("Enhance File (offline CleanMel_S_map)")
            enhance_button_mask_S = gr.Button("Enhance File (offline CleanMel_S_mask)")
            enhance_button_map_L = gr.Button("Enhance File (offline CleanMel_L_map)")
            enhance_button_mask_L = gr.Button("Enhance File (offline CleanMel_L_mask)")

        with gr.Column():
            enhance_button_map_Sr = gr.Button("Enhance Recorded Audio (offline CleanMel_S_map)")
            enhance_button_mask_Sr = gr.Button("Enhance Recorded Audio (offline CleanMel_S_mask)")
            enhance_button_map_Lr = gr.Button("Enhance Recorded Audio (offline CleanMel_L_map)")
            enhance_button_mask_Lr = gr.Button("Enhance Recorded Audio (offline CleanMel_L_mask)")

    with gr.Row():
        clear_btn = gr.Button(
            "🗑️ Clear All",
            variant="secondary",
            size="lg",
        )

    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
    output_mel = gr.Image(label="Output LogMel Spectrogram", type="filepath", visible=True)
    output_np = gr.File(label="Enhanced LogMel Spec. (.npy)", type="filepath")
    enhanced_outputs = [output_audio, output_mel, output_np]

    # (button, handler, input component). Registration order matches the
    # original so Gradio's auto-assigned endpoint names/indices are unchanged.
    for button, handler, source in (
        (enhance_button_map_L, enhance_cleanmel_L_map, audio_input),
        (enhance_button_mask_L, enhance_cleanmel_L_mask, audio_input),
        (enhance_button_map_S, enhance_cleanmel_S_map, audio_input),
        (enhance_button_mask_S, enhance_cleanmel_S_mask, audio_input),
        (enhance_button_map_Lr, enhance_cleanmel_L_map, audio_input_record),
        (enhance_button_mask_Lr, enhance_cleanmel_L_mask, audio_input_record),
        (enhance_button_map_Sr, enhance_cleanmel_S_map, audio_input_record),
        (enhance_button_mask_Sr, enhance_cleanmel_S_mask, audio_input_record),
    ):
        button.click(handler, inputs=source, outputs=enhanced_outputs)

    clear_btn.click(
        fn=reset_everything,
        outputs=enhanced_outputs,
    )

demo.launch(debug=False)