geolocation-from-speech-demo committed
Commit 57b83ef · verified · 1 Parent(s): 68975a7

Create app.py

Files changed (1)
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
import torch
import time
import lhotse
import numpy as np
import os
from transformers import Wav2Vec2ForCTC, Wav2Vec2ForPreTraining
import gradio as gr
import geoviews as gv
import geoviews.tile_sources as gts
import uuid
import gdown
import math
import torch.nn as nn


device = torch.device("cpu")

class AttentionPool(nn.Module):
    def __init__(self, att, query_embed):
        super(AttentionPool, self).__init__()
        self.query_embed = query_embed
        self.att = att

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        # Build a padding mask: True marks frames beyond each utterance's length
        max_seq_length = x_lens.max().item()
        mask = torch.arange(max_seq_length)[None, :].to(x.device) >= x_lens[:, None]

        # Pool the sequence with a single learned query; key_padding_mask has
        # shape (batch_size, max_seq_length), as nn.MultiheadAttention expects
        x, w = self.att(
            self.query_embed.unsqueeze(0).unsqueeze(1).repeat(x.size(0), 1, 1),
            x,
            x,
            key_padding_mask=mask
        )
        x = x.squeeze(1)
        return x, w


class AveragePool(nn.Module):
    def __init__(self):
        super(AveragePool, self).__init__()

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        # Mask out padded frames, then average only over the valid frames
        max_seq_length = x_lens.max().item()
        mask = torch.arange(max_seq_length)[None, :].to(x.device) >= x_lens[:, None]
        x[mask] = torch.nan
        return x.nanmean(dim=1), None


class Wav2Vec2Model(nn.Module):
    def __init__(self,
                 modelpath='facebook/mms-300m',
                 freeze_feat_extractor=True,
                 pooling_loc=0,
                 pooling_type='att',
                 ):
        super(Wav2Vec2Model, self).__init__()
        try:
            self.encoder = Wav2Vec2ForCTC.from_pretrained(modelpath).wav2vec2
        except Exception:
            self.encoder = Wav2Vec2ForPreTraining.from_pretrained(modelpath).wav2vec2

        if freeze_feat_extractor:
            self.encoder.feature_extractor._freeze_parameters()
        self.freeze_feat_extractor = freeze_feat_extractor
        self.odim = self._get_output_dim()

        self.frozen = False
        if pooling_type == 'att':
            assert pooling_loc == 0
            self.att = nn.MultiheadAttention(self.odim, 1, batch_first=True)
            self.loc_embed = nn.Parameter(
                torch.FloatTensor(self.odim).uniform_(-1, 1)
            )
            self.pooling = AttentionPool(self.att, self.loc_embed)
        elif pooling_type == 'avg':
            self.pooling = AveragePool()
        self.pooling_type = pooling_type
        # pooling_loc controls where pooling happens:
        # 0: on encoder embeddings, 1: on unnormalized coords, 2: on normalized coords
        self.pooling_loc = pooling_loc
        self.linear_out = nn.Linear(self.odim, 3)

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        x = self.encoder(
            x.squeeze(-1), output_hidden_states=False
        )[0]

        # Track sample counts through the wav2vec2 feature extractor's conv layers
        # (kernel width, stride) so x_lens matches the encoder output length
        for width, stride in [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]:
            x_lens = torch.floor((x_lens - width) / stride + 1)
        if self.pooling_loc == 0:
            x, w = self.pooling(x, x_lens)
            x = self.linear_out(x)
            x = x.div(x.norm(dim=1).unsqueeze(-1))  # project onto the unit sphere
        elif self.pooling_loc == 1:
            x = self.linear_out(x)
            x, w = self.pooling(x, x_lens)
            x = x.div(x.norm(dim=1).unsqueeze(-1))
        elif self.pooling_loc == 2:
            x = self.linear_out(x)
            x = x.div(x.norm(dim=1).unsqueeze(-1))
            x, w = self.pooling(x, x_lens)
            x = x.div(x.norm(dim=1).unsqueeze(-1))
        return x, w

    def freeze_encoder(self):
        for p in self.encoder.encoder.parameters():
            if p.requires_grad:
                p.requires_grad = False
        self.frozen = True

    def unfreeze_encoder(self):
        for i, p in enumerate(self.encoder.encoder.parameters()):
            p.requires_grad = True
        if self.freeze_feat_extractor:
            self.encoder.feature_extractor._freeze_parameters()
        self.frozen = False

    def _get_output_dim(self):
        # Probe the encoder with a dummy waveform to get its hidden dimension
        x = torch.rand(1, 400)
        return self.encoder(x).last_hidden_state.size(-1)


# download model checkpoint
# bad way to do this probably but oh well
if 'checkpoint.pt' not in os.listdir():
    checkpoint_url = "https://drive.google.com/uc?id=162jJ_YC4MGEfXBWvAK-kXnZcXX3v1smr"
    output = "checkpoint.pt"
    gdown.download(checkpoint_url, output, quiet=False)

model = Wav2Vec2Model()
model.to(device)

# load the first matching checkpoint found in the working directory
for f in os.listdir():
    if '.pt' in f and 'checkpoint' in f:
        checkpoint = torch.load(f, map_location='cpu')
        model.load_state_dict(checkpoint)
        model.eval()
        print(f'Loaded state dict {f}')

def predict(audio_path):
    # get raw audio data
    try:
        a = lhotse.Recording.from_file(audio_path)
    except Exception:
        return (None, "Please wait a bit until the audio file has uploaded, then try again")
    a = a.resample(16000)
    # if multi-channel, downmix to a single channel
    a = lhotse.cut.MultiCut(recording=a, start=0, duration=10, id="temp", channel=a.to_dict()['sources'][0]['channels']).to_mono(mono_downmix=True)
    cuts = lhotse.CutSet(cuts={"cut": a})

    audio_data, audio_lens = lhotse.dataset.collation.collate_audio(cuts)

    # pass through model; x is a unit vector on the sphere
    x, _ = model.forward(audio_data, audio_lens)
    print(x)

    # convert the predicted 3D unit vector to latitude/longitude in degrees
    pred_lon = torch.atan2(x[:, 0], x[:, 1]).unsqueeze(-1)
    pred_lat = torch.asin(x[:, 2]).unsqueeze(-1)
    x_polar = torch.cat((pred_lat, pred_lon), dim=1).to(device)
    coords = x_polar.mul(180. / math.pi).cpu().detach().numpy()
    print(coords)

    # wraparound fix (lat > 90); result is ordered (lon, lat) for plotting
    coords = [[-lon, math.degrees(math.asin(math.sin(math.radians(lat))))] if lat > 90 else [lon, lat] for lat, lon in coords][0]

    # create plot
    guesses = gv.Points([coords]).opts(
        size=8, cmap='Spectral_r', color='blue', fill_alpha=1
    )
    plot = (gts.OSM * guesses).options(
        gv.opts.Points(width=800, height=400, xlim=(-180*110000, 180*110000), ylim=(-90*140000, 90*140000), xaxis=None, yaxis=None)
    )
    filename = f"{str(uuid.uuid4())}.png"
    gv.save(plot, filename=filename, fmt='png')
    coords = [round(i, 2) for i in coords]
    coords = [coords[1], coords[0]]  # reorder to (latitude, longitude) for display
    print(filename, coords)
    return (filename, str(coords)[1:-1])

gradio_app = gr.Interface(
    predict,
    inputs=gr.Audio(label="Record Audio (10 seconds)", type="filepath", min_length=10.0),
    outputs=[gr.Image(type="filepath", label="Map of Prediction"), gr.Textbox(placeholder="Latitude, Longitude", label="Prediction (Latitude, Longitude)")],
    title="Speech Geolocation Demo",
)

if __name__ == "__main__":
    gradio_app.launch()
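
Note on the coordinate decoding in predict(): the model's 3-dimensional, unit-normalized output is read back as longitude = atan2(x0, x1) and latitude = asin(x2). The sketch below shows the mapping this decoding implies, both directions, so the geometry is easy to check; it is an inference from the code above (the training-side target construction is not part of this commit), and the helper names are illustrative only.

    import math

    def latlon_to_unit_vector(lat_deg: float, lon_deg: float):
        # Map (latitude, longitude) in degrees to the unit-sphere point implied
        # by predict(): lon = atan2(x0, x1), lat = asin(x2). Assumed convention.
        lat, lon = math.radians(lat_deg), math.radians(lon_deg)
        return (math.cos(lat) * math.sin(lon),   # x0
                math.cos(lat) * math.cos(lon),   # x1
                math.sin(lat))                   # x2

    def unit_vector_to_latlon(x0: float, x1: float, x2: float):
        # Inverse mapping, matching the decoding used in predict()
        return math.degrees(math.asin(x2)), math.degrees(math.atan2(x0, x1))

    # Round-trip check with an arbitrary point (~50.1 N, 14.4 E)
    vec = latlon_to_unit_vector(50.1, 14.4)
    print(unit_vector_to_latlon(*vec))  # -> approximately (50.1, 14.4)

Running `python app.py` launches the Gradio interface defined above; the recorded clip is resampled to 16 kHz, truncated to 10 seconds, and the predicted point is rendered on an OSM tile map.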