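"""OpenAI-compatible automatic speech recognition endpoint backed by an NVIDIA NeMo ASR model.

The handler decodes the uploaded audio with librosa, transcribes it with a pretrained NeMo
ASRModel, and returns plain text, JSON, or a verbose JSON payload with per-segment timestamps.

A minimal sketch of a client call (host, port, file name, and form fields are illustrative,
assuming the server listens on localhost:8000):

    curl http://localhost:8000/v1/audio/transcriptions \
        -F "file=@sample.wav" \
        -F "response_format=verbose_json"
"""
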
import asyncio
import zlib
from functools import partial
from io import BytesIO

import torch
from hfendpoints import EndpointConfig, Handler, __version__
from hfendpoints.openai import Context, run
from hfendpoints.openai.audio import (
    AutomaticSpeechRecognitionEndpoint,
    Segment,
    SegmentBuilder,
    TranscriptionRequest,
    TranscriptionResponse,
    TranscriptionResponseKind,
    VerboseTranscription,
)
from librosa import get_duration
from librosa import load as load_audio
from loguru import logger
from nemo.collections.asr.models import ASRModel


def compression_ratio(text: str) -> float:
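    """Ratio of the raw UTF-8 byte length to its zlib-compressed length (higher means more repetitive text)."""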
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))


def get_segment(idx: int, segment, tokenizer, request: TranscriptionRequest) -> Segment:
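    """Map a NeMo segment timestamp entry onto an OpenAI-style Segment."""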
    return (
        SegmentBuilder()
        .id(idx)
        .start(segment["start"])
        .end(segment["end"])
        .text(segment["segment"])
        .tokens(tokenizer.text_to_ids(segment["segment"]))
        .temperature(request.temperature)
        .compression_ratio(compression_ratio(segment["segment"]))
        .build()
    )


class NemoAsrHandler(Handler):
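    """Transcription handler backed by a pretrained NVIDIA NeMo ASRModel."""
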
    __slots__ = ("_model",)

    def __init__(self, config: EndpointConfig):
        logger.info(config.repository)
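        # Load the pretrained checkpoint once at startup and keep it in bfloat16 for inference.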
        self._model = ASRModel.from_pretrained(model_name=str(config.repository)).eval()
        self._model = self._model.to(torch.bfloat16)

    async def __call__(
        self, request: TranscriptionRequest, ctx: Context
    ) -> TranscriptionResponse:
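        """Decode the uploaded audio, transcribe it off the event loop, and format the response."""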
        with logger.contextualize(request_id=ctx.request_id):
            with memoryview(request) as audio:
                (waveform, sampling) = load_audio(BytesIO(audio), sr=16000, mono=True)
                logger.debug(
                    f"Successfully decoded audio chunk: {len(waveform)} PCM samples"
                )

                # Timestamps are only needed for the verbose JSON response kind.
                needs_timestamps = (
                    request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
                )
                transcribe_f = partial(
                    self._model.transcribe, timestamps=needs_timestamps, verbose=False
                )

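                # NeMo's transcribe() is blocking, so offload it to the default thread pool
                # executor to keep the event loop responsive; the waveform is wrapped in a
                # tuple because transcribe() expects a batch of inputs.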
                outputs = await asyncio.get_running_loop().run_in_executor(
                    None, transcribe_f, (waveform,)
                )

                output = outputs[0]
                text = output.text

                match request.response_kind:
                    case TranscriptionResponseKind.VERBOSE_JSON:
                        segment_timestamps = output.timestamp["segment"]
                        segments = [
                            get_segment(idx, stamp, self._model.tokenizer, request)
                            for (idx, stamp) in enumerate(segment_timestamps)
                        ]

                        logger.debug(f"First segment: {segment_timestamps[0]}")

                        return TranscriptionResponse.verbose(
                            VerboseTranscription(
                                text=text,
                                duration=get_duration(y=waveform, sr=sampling),
                                language=request.language,
                                segments=segments,
                                # word=None
                            )
                        )
                    case TranscriptionResponseKind.JSON:
                        return TranscriptionResponse.json(text)

                    case TranscriptionResponseKind.TEXT:
                        return TranscriptionResponse.text(text)

                # In theory this branch is unreachable, as the Rust layer validates the enum value beforehand.
                raise RuntimeError(f"unknown response_kind: {request.response_kind}")


def entrypoint():
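    """Build the handler from the environment configuration and start the endpoint server."""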
    config = EndpointConfig.from_env()
    handler = NemoAsrHandler(config)
    endpoint = AutomaticSpeechRecognitionEndpoint(handler)

    logger.info(f"[Hugging Face Endpoint v{__version__}] Serving: {config.model_id}")
    run(endpoint, config.interface, config.port)


if __name__ == "__main__":
    entrypoint()