Dionyssos committed
Commit e0f0baf · 1 Parent(s): 95ad439

display dawn / teacher
README.md CHANGED
@@ -1,14 +1,14 @@
 ---
-title: Emotional Attributes
-emoji: 📚
-colorFrom: indigo
-colorTo: indigo
+title: Wav2Vec2 / Wav2small
+emoji: 🎵
+colorFrom: blue
+colorTo: pink
 sdk: gradio
-sdk_version: 5.41.0
+sdk_version: 5.25.2
 app_file: app.py
 pinned: false
 license: cc-by-nc-4.0
 short_description: Perceive speech Arousal / Dominance / Valence
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+A space for [Dawn](https://huggingface.co/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim) and [wav2small](https://huggingface.co/dkounadis/wav2small). Follows this [paper](https://arxiv.org/abs/2408.13920).
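
For orientation, a minimal sketch of the inference the Space performs per utterance: each checkpoint maps a 16 kHz waveform to three scores (arousal, dominance, valence), and the Wav2Small teacher prediction is the average of the two. This sketch is not part of the commit; it assumes the `dawn` and `base` objects constructed in app.py below are already in scope.

```python
# Minimal sketch, assuming `dawn` and `base` from app.py are in scope.
import torch

x = torch.zeros(1, 16000)   # 1 s of silence: (batch, samples) at 16 kHz
with torch.no_grad():
    adv_dawn = dawn(x)      # shape (1, 3): arousal / dominance / valence
    adv_wavlm = base(x)     # same layout from the WavLM baseline
adv_teacher = .5 * adv_dawn + .5 * adv_wavlm   # Wav2Small teacher prediction
print(adv_teacher)
```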
app.py ADDED
@@ -0,0 +1,273 @@
+import gradio as gr
+import torch.nn as nn
+import audresample
+import matplotlib.pyplot as plt
+from matplotlib import colors as mcolors
+import torch
+import librosa
+import numpy as np
+import types
+from transformers import AutoModelForAudioClassification
+from transformers.models.wav2vec2.modeling_wav2vec2 import (Wav2Vec2Model,
+                                                             Wav2Vec2PreTrainedModel)
+
+
+plt.style.use('seaborn-v0_8-whitegrid')
+
+
+def _prenorm(x, attention_mask=None):
+    '''Zero-mean / unit-variance normalisation over time (dim 1).
+       (Defined for reference; not called below.)'''
+    if attention_mask is not None:
+        N = attention_mask.sum(1, keepdim=True)  # 0=ignored 1=valid
+        x -= x.sum(1, keepdim=True) / N
+        var = (x * x).sum(1, keepdim=True) / N
+    else:
+        # mean() maps to a single ONNX ReduceMean op, which saves casting the
+        # integer N to float and the explicit division
+        x -= x.mean(1, keepdim=True)
+        var = (x * x).mean(1, keepdim=True)
+    return x / torch.sqrt(var + 1e-7)
+
+
+class ADV(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, x):
+        x = self.dense(x)
+        x = torch.tanh(x)
+        return self.out_proj(x)
+
+
+class Dawn(Wav2Vec2PreTrainedModel):
+    r"""Wav2Vec2 arousal/dominance/valence regressor, https://arxiv.org/abs/2203.07378"""
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.classifier = ADV(config)
+
+    def forward(self, x):
+        # normalise out of place so the caller's tensor is left untouched
+        # (process_audio below reuses the same tensor for the WavLM baseline)
+        x = x - x.mean(1, keepdim=True)
+        variance = (x * x).mean(1, keepdim=True) + 1e-7
+        x = self.wav2vec2(x / variance.sqrt())
+        return self.classifier(x.last_hidden_state.mean(1))
+
+
+def _forward(self, x):
+    '''x: (batch, audio-samples-16KHz)'''
+    x = (x + self.config.mean) / self.config.std  # sgn
+    x = self.ssl_model(x, attention_mask=None).last_hidden_state
+    # attentive statistics pooling: attention-weighted mean and std, concatenated
+    h = self.pool_model.sap_linear(x).tanh()
+    w = torch.matmul(h, self.pool_model.attention).softmax(1)
+    mu = (x * w).sum(1)
+    x = torch.cat(
+        [
+            mu,
+            ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()
+        ], 1)
+    return self.ser_model(x)
+
+
+# WavLM
+device = 'cpu'
+base = AutoModelForAudioClassification.from_pretrained(
+    '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
+    trust_remote_code=True).to(device).eval()
+base.forward = types.MethodType(_forward, base)
+
+# Wav2Vec2
+dawn = Dawn.from_pretrained(
+    'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
+).to(device).eval()
+
+
+def wav2small(x):
+    '''Wav2Small teacher: average of the Dawn and WavLM predictions.'''
+    return .5 * dawn(x) + .5 * base(x)
+
+
+# Placeholder figure returned when no audio is provided
+fig_error, ax = plt.subplots(figsize=(8, 6))
+
+error_message = "Error: No .wav or Mic. audio provided."
+
+# Place the message in the centre of the otherwise empty axes
+ax.text(0.5, 0.5, error_message,
+        ha='center',
+        va='center',
+        fontsize=24,
+        color='gray',
+        fontweight='bold',
+        transform=ax.transAxes)
+
+# Hide ticks and labels for a cleaner look
+ax.set_xticks([])
+ax.set_yticks([])
+ax.set_xticklabels([])
+ax.set_yticklabels([])
+
+# Keep the frame on but hide all spines
+ax.set_frame_on(True)
+ax.spines['top'].set_visible(False)
+ax.spines['right'].set_visible(False)
+ax.spines['bottom'].set_visible(False)
+ax.spines['left'].set_visible(False)
+
+
+def process_audio(audio_filepath):
+    if audio_filepath is None:
+        return fig_error
+
+    # Load the audio file (librosa downmixes stereo to mono by default)
+    waveform, sample_rate = librosa.load(audio_filepath)
+
+    # Resample to 16 kHz if necessary; audresample returns shape (channels, samples)
+    if sample_rate != 16000:
+        waveform = audresample.resample(waveform, sample_rate, 16000)
+    # ensure a (batch, samples) shape whether or not resampling ran
+    x = torch.from_numpy(np.atleast_2d(waveform))
+    x = x[:, :64000]  # keep at most 4 s (64000 samples at 16 kHz)
+    with torch.no_grad():
+        logits_dawn = dawn(x).cpu().numpy()[0, :]
+        logits_wavlm = base(x).cpu().numpy()[0, :]
+
+    logits_wav2small = .5 * logits_dawn + .5 * logits_wavlm
+
+    left_bars_data = logits_dawn.clip(0, 1)
+    right_bars_data = logits_wav2small.clip(0, 1)
+
+    bar_labels = ['\nArousal', '\nDominance', '\nValence']
+    y_pos = np.arange(len(bar_labels))
+
+    # One base colormap per attribute so each row gets its own colour
+    # (Greys for Dominance)
+    category_colormaps = [plt.cm.Blues, plt.cm.Greys, plt.cm.Oranges]
+
+    # Colour shades for the left / right bars of each attribute
+    left_filled_colors = []
+    right_filled_colors = []
+    background_colors = []
+
+    for i, cmap in enumerate(category_colormaps):
+        # darker shade for the left filled bar
+        left_filled_colors.append(cmap(0.74))
+        # slightly lighter shade for the right filled bar
+        right_filled_colors.append(cmap(0.64))
+        # very light shade for the transparent background bar
+        background_colors.append(cmap(0.1))
+
+    # Set up the figure and axes
+    fig, ax = plt.subplots(figsize=(10, 6))
+
+    # Background bars (transparent, light shade of the category colour)
+    for i in range(len(bar_labels)):
+        ax.barh(y_pos[i], -1, color=background_colors[i], alpha=0.3, height=0.6)
+        ax.barh(y_pos[i], 1, color=background_colors[i], alpha=0.3, height=0.6)
+
+    # Filled bars: Wav2Vec2 (Dawn) on the left, Wav2Small teacher on the right
+    for i in range(len(bar_labels)):
+        ax.barh(y_pos[i], -left_bars_data[i], color=left_filled_colors[i], alpha=1, height=0.6)
+        ax.barh(y_pos[i], right_bars_data[i], color=right_filled_colors[i], alpha=1, height=0.6)
+
+    # Central axis divider
+    ax.axvline(0, color='black', linewidth=0.8, linestyle='--')
+
+    # x-axis limits and y-axis ticks
+    ax.set_xlim(-1, 1)
+    ax.set_yticks(y_pos)
+    ax.set_yticklabels(bar_labels, fontsize=12)
+
+    # Show absolute percentages on both halves of the x-axis
+    def abs_tick_formatter(x, pos):
+        return f'{int(abs(x) * 100)}%'
+    ax.xaxis.set_major_formatter(plt.FuncFormatter(abs_tick_formatter))
+
+    # Title and axis labels
+    ax.set_title('', fontsize=16, pad=20)
+    ax.set_xlabel('Outputs of Wav2Vec2 (left)  /  Outputs of Wav2Small Teacher (right)', fontsize=12)
+
+    # Remove spines for a cleaner look
+    ax.spines['top'].set_visible(False)
+    ax.spines['right'].set_visible(False)
+    ax.spines['left'].set_visible(False)
+
+    # Annotate the filled bars with their values
+    for i in range(len(bar_labels)):
+        ax.text(-left_bars_data[i] - 0.05, y_pos[i], f'{int(left_bars_data[i] * 100)}%',
+                va='center', ha='right', color=left_filled_colors[i], fontweight='bold')
+        ax.text(right_bars_data[i] + 0.05, y_pos[i], f'{int(right_bars_data[i] * 100)}%',
+                va='center', ha='left', color=right_filled_colors[i], fontweight='bold')
+
+    return fig
+
+
+iface = gr.Interface(
+    fn=process_audio,
+    inputs=gr.Audio(
+        sources=["microphone", "upload"],
+        type="filepath",  # hand the component's file path to process_audio
+        label=''
+    ),
+    outputs=[
+        gr.Plot(label="Arousal / Dominance / Valence Plots"),
+    ],
+    title='',
+    description='',
+    flagging_mode="never",  # do not store flagged audio / .csv on the machine
+    examples=[
+        "female-46-neutral.wav",
+        "female-20-happy.wav",
+        "male-60-angry.wav",
+        "male-27-sad.wav",
+    ],
+    css="footer {visibility: hidden}"
+)
+
+with gr.Blocks() as demo:
+
+    # https://discuss.huggingface.co/t/how-to-get-the-microphone-streaming-input-file-when-using-blocks/37204/3
+    with gr.Tab(label="Arousal / Dominance / Valence"):
+        iface.render()
+    with gr.Tab(label="CCC"):
+        gr.Markdown('''<table style="width:500px"><tr><th colspan=5 >CCC MSP Podcast v1.7</th></tr>
+        <tr> <td> </td><td>Arousal</td> <td>Dominance</td> <td>Valence</td> <td> Associated Paper </td> </tr>
+        <tr> <td> <a href="https://huggingface.co/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim">Wav2Vec2</a></td><td>0.744</td><td>0.655</td><td> 0.638 </td><td> <a href="https://arxiv.org/abs/2203.07378">arXiv</a> </td> </tr>
+        <tr> <td> <a href="https://huggingface.co/dkounadis/wav2small">Wav2Small Teacher</a></td><td> 0.762 </td> <td> 0.684 </td><td> 0.676 </td><td> <a href="https://arxiv.org/abs/2408.13920">arXiv</a> </td> </tr>
+        </table>
+        ''')
+
+if __name__ == "__main__":
+    demo.launch(share=False)
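
A hypothetical local check (not part of the commit) that drives the plotting function directly with one of the example clips added below:

```python
# Assumes app.py from this repo is importable; importing it loads both
# checkpoints, which takes a moment on CPU.
from app import process_audio

fig = process_audio("female-20-happy.wav")  # one of the bundled example clips
fig.savefig("adv_female-20-happy.png")      # Dawn (left) vs. Wav2Small teacher (right)
```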
female-20-happy.wav ADDED
Binary file (51 kB).
 
female-46-neutral.wav ADDED
Binary file (37.6 kB).
 
male-27-sad.wav ADDED
Binary file (50.4 kB).
 
male-60-angry.wav ADDED
Binary file (60.5 kB).
 
requirements.txt ADDED
@@ -0,0 +1,5 @@
+audresample
+matplotlib
+torch
+transformers
+librosa