# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Feature extractor class for MERaLiON-SpeechEncoder, modified from original WhisperFeatureExtractor
"""
import itertools
import os
from shutil import copyfile
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from transformers import is_torch_available
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import TensorType, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.model"}
class ModifiedWhisperFeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs a modified Whisper feature extractor.
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
most of the main methods. Users should refer to this superclass for more information regarding those methods.
    This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the Short-Time
    Fourier Transform, which should match the results of PyTorch's `torch.stft`.
Differences from WhisperFeatureExtractor:
- mel_filter_bank
- norm: "slaney" -> None
- mel_scale: "slaney" -> "htk"
- still uses log scaling and clamp but removes additional min-max/mean normalization
Args:
feature_size (`int`, *optional*, defaults to 80):
The feature dimension of the extracted features.
sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
hop_length (`int`, *optional*, defaults to 160):
Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
        chunk_length (`int`, *optional*, defaults to 120):
            The maximum length of the audio, in chunks of `sampling_rate` samples (i.e. in seconds), used to trim and
            pad longer or shorter audio sequences.
n_fft (`int`, *optional*, defaults to 400):
Size of the Fourier transform.
padding_value (`float`, *optional*, defaults to 0.0):
Padding value used to pad the audio. Should correspond to silences.
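
    Example (an instantiation sketch with the default parameters; the shape shown follows from `1 + n_fft // 2 = 201`
    frequency bins mapped onto 80 HTK-scale mel filters):

    ```python
    >>> extractor = ModifiedWhisperFeatureExtractor()
    >>> extractor.mel_filters.shape
    (201, 80)
    ```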
"""
model_input_names = ["input_values"]
def __init__(
self,
feature_size=80,
sampling_rate=16000,
hop_length=160,
chunk_length=120,
n_fft=400,
padding_value=0.0,
        return_attention_mask=True,  # unlike Whisper, default to returning an attention mask; inputs are padded with the silence token (zero)
**kwargs,
):
super().__init__(
feature_size=feature_size,
sampling_rate=sampling_rate,
padding_value=padding_value,
return_attention_mask=return_attention_mask,
**kwargs,
)
self.n_fft = n_fft
self.hop_length = hop_length
self.chunk_length = chunk_length
self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length
self.sampling_rate = sampling_rate
self.mel_filters = mel_filter_bank(
num_frequency_bins=1 + n_fft // 2,
num_mel_filters=feature_size,
min_frequency=0.0,
max_frequency=8000.0,
sampling_rate=sampling_rate,
norm=None,
mel_scale="htk",
)
    def _np_extract_fbank_features(self, waveform_batch: np.ndarray, device: str) -> np.ndarray:
"""
        Compute the log-mel spectrogram of the provided audio. Results match Whisper's original torch implementation
        to within a 1e-5 tolerance.
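
        Example (a shape sketch; the zero waveform is illustrative, and the frame count assumes the centered-STFT
        default of `transformers.audio_utils.spectrogram`, i.e. 101 frames with the last one dropped):

        ```python
        >>> import numpy as np
        >>> extractor = ModifiedWhisperFeatureExtractor()
        >>> batch = np.zeros((2, 16000), dtype=np.float32)  # two 1 s clips at 16 kHz
        >>> extractor._np_extract_fbank_features(batch, device="cpu").shape
        (2, 80, 100)
        ```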
"""
if device != "cpu":
raise ValueError(
f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator "
"devices requires torch, which is not installed. Either set `device='cpu'`, or "
"install torch according to the official instructions: https://pytorch.org/get-started/locally/"
)
log_spec_batch = []
for waveform in waveform_batch:
log_spec = spectrogram(
waveform,
window_function(self.n_fft, "hann"),
frame_length=self.n_fft,
hop_length=self.hop_length,
power=2.0,
mel_filters=self.mel_filters,
log_mel="log10",
)
log_spec = log_spec[:, :-1]
log_spec_batch.append(log_spec)
log_spec_batch = np.array(log_spec_batch)
return log_spec_batch
    def _torch_extract_fbank_features(self, waveform: np.ndarray, device: str = "cpu") -> np.ndarray:
"""
        Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching,
        yielding results that match the CPU implementation to within a 1e-5 tolerance.
"""
waveform = torch.from_numpy(waveform).type(torch.float32)
window = torch.hann_window(self.n_fft)
if device != "cpu":
waveform = waveform.to(device)
window = window.to(device)
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
if device != "cpu":
mel_filters = mel_filters.to(device)
mel_spec = mel_filters.T @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
if device != "cpu":
log_spec = log_spec.detach().cpu()
return log_spec.numpy()
@staticmethod
# Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
def zero_mean_unit_var_norm(
        input_values: List[np.ndarray], attention_mask: Optional[List[np.ndarray]], padding_value: float = 0.0
) -> List[np.ndarray]:
"""
Every array in the list is normalized to have zero mean and unit variance
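
        Example (illustrative values; with `attention_mask=None` each full array is normalized):

        ```python
        >>> import numpy as np
        >>> x = np.array([0.0, 2.0, 4.0], dtype=np.float32)
        >>> normed = ModifiedWhisperFeatureExtractor.zero_mean_unit_var_norm([x], attention_mask=None)
        >>> # normed[0] is approximately [-1.2247, 0.0, 1.2247]: zero mean, unit variance
        ```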
"""
if attention_mask is not None:
attention_mask = np.array(attention_mask, np.int32)
normed_input_values = []
for vector, length in zip(input_values, attention_mask.sum(-1)):
normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
if length < normed_slice.shape[0]:
normed_slice[length:] = padding_value
normed_input_values.append(normed_slice)
else:
normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
return normed_input_values
def __call__(
self,
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
truncation: bool = True,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: Optional[bool] = True,
padding: Optional[Union[bool, str]] = True,
max_length: Optional[int] = None,
sampling_rate: Optional[int] = None,
do_normalize: Optional[bool] = None,
device: Optional[str] = "cpu",
return_token_timestamps: Optional[bool] = None,
**kwargs,
) -> BatchFeature:
"""
Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for
the STFT computation if available, otherwise a slower NumPy based one.
Args:
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
stereo, i.e. single float per timestep.
            truncation (`bool`, *optional*, defaults to `True`):
Activates truncation to cut input sequences longer than *max_length* to *max_length*.
pad_to_multiple_of (`int`, *optional*, defaults to None):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
return_attention_mask (`bool`, *optional*):
Whether to return the attention mask. If left to the default, will return the attention mask according
to the specific feature_extractor's default.
[What are attention masks?](../glossary#attention-mask)
For Whisper models, `attention_mask` should always be passed for batched inference, to avoid subtle
bugs.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors instead of list of python integers. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return Numpy `np.ndarray` objects.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
                `sampling_rate` at the forward call to prevent silent errors and to allow the automatic speech
                recognition pipeline to work correctly.
padding_value (`float`, *optional*, defaults to 0.0):
The value that is used to fill the padding values / vectors.
do_normalize (`bool`, *optional*, defaults to `False`):
Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
improve the performance of the model.
device (`str`, *optional*, defaults to `'cpu'`):
Specifies the device for computation of the log-mel spectrogram of audio signals in the
`_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
return_token_timestamps (`bool`, *optional*, defaults to `None`):
Whether or not to return the number of frames of the input raw_speech.
These num_frames can be used by the model to compute word level timestamps.
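
        Example (a usage sketch; the random waveform is illustrative, and the frame count assumes the centered-STFT
        default):

        ```python
        >>> import numpy as np
        >>> extractor = ModifiedWhisperFeatureExtractor()
        >>> speech = np.random.randn(16000).astype(np.float32)  # 1 s of audio at 16 kHz
        >>> features = extractor(speech, sampling_rate=16000, return_tensors="np")
        >>> features["input_values"].shape  # (batch, mel bins, frames)
        (1, 80, 100)
        ```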
"""
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
f" was sampled with {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
"It is strongly recommended to pass the `sampling_rate` argument to this function. "
"Failing to do so can result in silent errors that might be hard to debug."
)
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
if is_batched_numpy and len(raw_speech.shape) > 2:
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
is_batched = is_batched_numpy or (
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
)
if is_batched:
raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech, dtype=np.float32)
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
raw_speech = raw_speech.astype(np.float32)
# always return batch
if not is_batched:
raw_speech = [np.asarray([raw_speech]).T]
batched_speech = BatchFeature({"input_values": raw_speech})
# convert into correct format for padding
        padded_inputs = self.pad(  # pad the raw waveforms first; the log-mel transform is applied afterwards
batched_speech,
padding=padding,
max_length=max_length if max_length else self.n_samples,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask or do_normalize,
)
# zero-mean and unit-variance normalization
if do_normalize:
padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
padded_inputs["input_values"],
attention_mask=padded_inputs["attention_mask"],
padding_value=self.padding_value,
)
padded_inputs["input_values"] = np.stack(padded_inputs["input_values"], axis=0)
        # (batch_size, num_samples, 1) -> (1, batch_size, num_samples)
        input_values = padded_inputs.get("input_values").transpose(2, 0, 1)
extract_fbank_features = (
self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
)
input_values = extract_fbank_features(input_values[0], device)
        if isinstance(input_values[0], list):
padded_inputs["input_values"] = [np.asarray(feature, dtype=np.float32) for feature in input_values]
else:
padded_inputs["input_values"] = input_values
if return_attention_mask:
            # downsample the attention mask from sample resolution to frame resolution (one frame per hop_length samples)
padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
if return_token_timestamps is not None:
padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech]
if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
return padded_inputs
class MeralionBestRqConformerTokenizer(PreTrainedTokenizer):
"""
Constructs a MeralionBestRqConformer tokenizer. Based on `SentencePiece`.
Args:
vocab_file (`str`):
Path to the vocabulary file.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        blank_token (`str`, *optional*, defaults to `"<blank>"`):
            The CTC blank token, removed from the output during decoding.
**kwargs
Additional keyword arguments passed along to
[`PreTrainedTokenizer.__init__`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__init__).
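
    Example (a sketch; the SentencePiece model path below is a hypothetical placeholder):

    ```python
    >>> tokenizer = MeralionBestRqConformerTokenizer("sentencepiece.model")
    >>> ids = tokenizer("hello world")["input_ids"]  # SentencePiece token ids
    >>> tokenizer.decode(ids, group_tokens=False)  # round-trips without CTC collapsing of repeats
    ```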
"""
vocab_files_names = VOCAB_FILES_NAMES
def __init__(
self,
vocab_file,
unk_token="",
pad_token="",
bos_token="",
eos_token="",
blank_token="",
**kwargs
):
import sentencepiece as spm
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(vocab_file)
super().__init__(
unk_token=unk_token,
pad_token=pad_token,
bos_token=bos_token,
eos_token=eos_token,
blank_token=blank_token,
**kwargs,
)
self.blank_token_id = self.sp_model.piece_to_id(blank_token)
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False,
) -> List[int]:
"""
        Retrieves a sequence of 0s and 1s specifying whether the corresponding token ID is a special token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
if token_ids_1 is None:
return [0] * len(token_ids_0)
return ([0] * len(token_ids_0)) + ([0] * len(token_ids_1))
def _tokenize(self, text: str) -> List[str]:
"""
Converts a string in a sequence of tokens (string), using the `sp_model` tokenizer.
"""
# SentencePiece doesn't like empty strings
if not text:
return []
return self.sp_model.encode(text, out_type=str)
def _convert_token_to_id(self, token: str) -> int:
"""
        Converts a token (str) into an id (integer) using the vocab.
"""
return self.sp_model.piece_to_id(token)
def _convert_id_to_token(self, index: int) -> str:
"""
        Converts an id (integer) into a token (str) using the vocab.
"""
return self.sp_model.id_to_piece(index)
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self.sp_model.decode(tokens)
def decode(
self,
        token_ids: Union[List[int], np.ndarray, "torch.Tensor"],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
group_tokens: bool = True,
**kwargs,
) -> str:
"""
        Converts a sequence of ids into a string, using the tokenizer and vocabulary with CTC decoding logic.
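
        Example (illustrative ids; assumes the blank token maps to id 0):

        ```python
        >>> # grouping collapses repeats first, then blanks are removed:
        >>> # [5, 5, 0, 5, 6, 6] -> groupby -> [5, 0, 5, 6] -> drop blanks -> [5, 5, 6]
        ```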
"""
        if isinstance(token_ids, np.ndarray) or (is_torch_available() and isinstance(token_ids, torch.Tensor)):
            token_ids = token_ids.tolist()
# CTC decoding
if group_tokens:
token_ids = [token_id for token_id, _ in itertools.groupby(token_ids)]
# Remove blank tokens
token_ids = [token_id for token_id in token_ids if token_id != self.blank_token_id]
        return super().decode(
            token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
def batch_decode(
self,
        sequences: Union[List[int], List[List[int]], np.ndarray, "torch.Tensor"],
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: Optional[bool] = None,
**kwargs,
) -> List[str]:
"""
Convert a list of lists of token ids into a list of strings by calling decode.
"""
batch_decoded = [
self.decode(
seq,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
for seq in sequences
]
return batch_decoded
def get_vocab(self) -> Dict[str, int]:
"""
Returns the vocabulary as a dictionary of token to index.
"""
        vocab = {self.sp_model.id_to_piece(i): i for i in range(self.sp_model.get_piece_size())}
vocab.update(self.added_tokens_encoder)
return vocab
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
"""
Save the vocabulary and special tokens file to a directory.
"""
if not os.path.isdir(save_directory):
os.makedirs(save_directory)
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
        if os.path.abspath(self.vocab_file) != os.path.abspath(vocab_file):
            copyfile(self.vocab_file, vocab_file)
return (vocab_file,)
@property
def vocab_size(self) -> int:
return self.sp_model.get_piece_size()
class MeralionBestRqConformerProcessor(ProcessorMixin):
r"""
    Constructs a Wav2Vec2-like processor which wraps a ModifiedWhisperFeatureExtractor feature extractor and a
    MeralionBestRqConformerTokenizer SentencePiece tokenizer into a single processor.
[`MeralionBestRqConformerProcessor`] offers all the functionalities of [`ModifiedWhisperFeatureExtractor`] and
[`MeralionBestRqConformerTokenizer`].
See the docstring of [`~MeralionBestRqConformerProcessor.__call__`] and [`~MeralionBestRqConformerProcessor.decode`]
for more information.
Args:
feature_extractor (`ModifiedWhisperFeatureExtractor`):
An instance of [`ModifiedWhisperFeatureExtractor`]. The feature extractor is a required input.
tokenizer ([`MeralionBestRqConformerTokenizer`]):
An instance of [`MeralionBestRqConformerTokenizer`]. The tokenizer is a required input.
"""
feature_extractor_class = "ModifiedWhisperFeatureExtractor"
tokenizer_class = "MeralionBestRqConformerTokenizer"
def __init__(self, feature_extractor, tokenizer):
self.feature_extractor = feature_extractor
self.tokenizer = tokenizer
self.chat_template = None
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
feature_extractor = ModifiedWhisperFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
tokenizer = MeralionBestRqConformerTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
    def __call__(
        self,
        audio: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        *args,
        **kwargs,
    ):
        return self.feature_extractor(audio, *args, **kwargs)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to MeralionBestRqConformerTokenizer's [`~MeralionBestRqConformerTokenizer.batch_decode`].
Please refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to MeralionBestRqConformerTokenizer's [`~MeralionBestRqConformerTokenizer.decode`].
Please refer to the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
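

# A minimal end-to-end sketch, guarded so it only runs when this file is executed
# directly. The repository id below is a hypothetical placeholder; substitute the
# actual MERaLiON-SpeechEncoder checkpoint.
if __name__ == "__main__":
    processor = MeralionBestRqConformerProcessor.from_pretrained(
        "MERaLiON/MERaLiON-SpeechEncoder"  # hypothetical repo id
    )
    audio = np.random.randn(16000).astype(np.float32)  # 1 s of synthetic audio
    inputs = processor(audio, sampling_rate=16000, return_tensors="np")
    print(inputs["input_values"].shape)  # e.g. (1, 80, 100)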