# nvidia-nemo-asr / handler.py
import asyncio
import zlib
from functools import partial
from io import BytesIO

import torch
from hfendpoints import EndpointConfig, Handler, __version__
from hfendpoints.openai import Context, run
from hfendpoints.openai.audio import (
    AutomaticSpeechRecognitionEndpoint,
    Segment,
    SegmentBuilder,
    TranscriptionRequest,
    TranscriptionResponse,
    TranscriptionResponseKind,
    VerboseTranscription,
)
from librosa import get_duration
from librosa import load as load_audio
from loguru import logger
from nemo.collections.asr.models import ASRModel
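

# Per-segment quality heuristic, as in OpenAI's verbose transcription output:
# repetitive (often degenerate) text compresses well under zlib, so a high
# raw-to-compressed ratio flags suspicious segments.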
def compression_ratio(text: str) -> float:
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))
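

# Map a single NeMo segment-timestamp entry (a dict with "start", "end" and
# "segment" keys) onto the OpenAI-style Segment used in verbose responses.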
def get_segment(idx: int, segment, tokenizer, request: TranscriptionRequest) -> Segment:
    return (
        SegmentBuilder()
        .id(idx)
        .start(segment["start"])
        .end(segment["end"])
        .text(segment["segment"])
        .tokens(tokenizer.text_to_ids(segment["segment"]))
        .temperature(request.temperature)
        .compression_ratio(compression_ratio(segment["segment"]))
        .build()
    )
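

# The handler is built once at startup: it downloads the NeMo model named by
# the endpoint's repository and keeps it resident for the server's lifetime.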
class NemoAsrHandler(Handler):
    __slots__ = ("_model",)

    def __init__(self, config: EndpointConfig):
        logger.info(config.repository)
        self._model = ASRModel.from_pretrained(model_name=str(config.repository)).eval()
        # bfloat16 halves the memory footprint and speeds up inference on GPUs
        # that support it, with negligible impact on transcription quality.
        self._model = self._model.to(torch.bfloat16)
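
    # One request carries one audio payload: decode it, transcribe it off the
    # event loop, then shape the response according to `request.response_kind`.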
    async def __call__(
        self, request: TranscriptionRequest, ctx: Context
    ) -> TranscriptionResponse:
        with logger.contextualize(request_id=ctx.request_id):
            with memoryview(request) as audio:
                # Decode the raw request body to a 16 kHz mono waveform, the
                # input format NeMo's ASR models expect.
                (waveform, sampling) = load_audio(BytesIO(audio), sr=16000, mono=True)
                logger.debug(
                    f"Successfully decoded {len(waveform)} samples of PCM audio"
                )

            # Do we need to compute the timestamps?
            needs_timestamps = (
                request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
            )
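
            # NeMo's transcribe() is blocking, so it runs in the default thread
            # pool executor to keep the event loop free for other requests.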
            transcribe_f = partial(
                self._model.transcribe, timestamps=needs_timestamps, verbose=False
            )
            outputs = await asyncio.get_running_loop().run_in_executor(
                None, transcribe_f, (waveform,)
            )
            output = outputs[0]
            text = output.text
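
            # Render the response in the OpenAI-compatible format the client asked for.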
            match request.response_kind:
                case TranscriptionResponseKind.VERBOSE_JSON:
                    segment_timestamps = output.timestamp["segment"]
                    segments = [
                        get_segment(idx, stamp, self._model.tokenizer, request)
                        for (idx, stamp) in enumerate(segment_timestamps)
                    ]
                    logger.info(f"Segment: {segment_timestamps[0]}")

                    return TranscriptionResponse.verbose(
                        VerboseTranscription(
                            text=text,
                            duration=get_duration(y=waveform, sr=sampling),
                            language=request.language,
                            segments=segments,
                            # word=None
                        )
                    )
                case TranscriptionResponseKind.JSON:
                    return TranscriptionResponse.json(text)
                case TranscriptionResponseKind.TEXT:
                    return TranscriptionResponse.text(text)

        # Theoretically, we can't end up here as Rust validates the enum value beforehand
        raise RuntimeError(f"unknown response_kind: {request.response_kind}")
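

# Wire the handler into the OpenAI-compatible ASR endpoint and start the server.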
def entrypoint():
    config = EndpointConfig.from_env()
    handler = NemoAsrHandler(config)
    endpoint = AutomaticSpeechRecognitionEndpoint(handler)

    logger.info(f"[Hugging Face Endpoint v{__version__}] Serving: {config.model_id}")
    run(endpoint, config.interface, config.port)


if __name__ == "__main__":
    entrypoint()
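
# Example request (a sketch: the /v1/audio/transcriptions route and the port
# below are assumptions based on the OpenAI audio API this endpoint mirrors;
# adjust both to the deployed config.interface / config.port):
#
#   curl http://localhost:8000/v1/audio/transcriptions \
#     -F file=@sample.wav \
#     -F response_format=verbose_json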