feat(parakeet): pin `torch` and fix formatting

#1
by alvarobartt HF Staff - opened
Files changed (3) hide show
  1. Dockerfile +4 -4
  2. handler.py +34 -25
  3. requirements.txt +7 -4
Dockerfile CHANGED
@@ -1,4 +1,5 @@
1
- ARG SDK_VERSION=latest
 
2
  FROM huggingface/hfendpoints-sdk:${SDK_VERSION} AS sdk
3
 
4
  FROM nvcr.io/nvidia/nemo:25.04
@@ -7,8 +8,7 @@ RUN --mount=type=bind,from=sdk,source=/opt/hfendpoints/dist,target=/usr/local/en
7
  python3 -m pip install -r /tmp/requirements.txt && \
8
  python3 -m pip install /usr/local/endpoints/dist/*.whl
9
 
10
-
11
- COPY handler.py /usr/local/endpoint/
12
 
13
  # Disable TQDM
14
  ENV TQDM_DISABLE=1
@@ -20,4 +20,4 @@ ENV PORT=80
20
  EXPOSE 80
21
 
22
  ENTRYPOINT ["python3"]
23
- CMD ["/usr/local/endpoint/handler.py"]
 
1
+ # ARG SDK_VERSION=6751aaa
2
+ ARG SDK_VERSION=v0.2.0
3
  FROM huggingface/hfendpoints-sdk:${SDK_VERSION} AS sdk
4
 
5
  FROM nvcr.io/nvidia/nemo:25.04
 
8
  python3 -m pip install -r /tmp/requirements.txt && \
9
  python3 -m pip install /usr/local/endpoints/dist/*.whl
10
 
11
+ COPY handler.py /usr/local/endpoint/handler.py
 
12
 
13
  # Disable TQDM
14
  ENV TQDM_DISABLE=1
 
20
  EXPOSE 80
21
 
22
  ENTRYPOINT ["python3"]
23
+ CMD ["/usr/local/endpoint/handler.py"]
handler.py CHANGED
@@ -4,35 +4,40 @@ from functools import partial
4
  from io import BytesIO
5
 
6
  import torch
 
7
  from hfendpoints.openai import Context, run
8
- from hfendpoints.openai.audio import AutomaticSpeechRecognitionEndpoint, SegmentBuilder, Segment, \
9
- TranscriptionRequest, TranscriptionResponse, TranscriptionResponseKind, VerboseTranscription
10
- from librosa import load as load_audio, get_duration
 
 
 
 
 
 
 
 
11
  from loguru import logger
12
  from nemo.collections.asr.models import ASRModel
13
 
14
- from hfendpoints import EndpointConfig, Handler, __version__
15
-
16
 
17
  def compression_ratio(text: str) -> float:
18
- """
19
- :param text:
20
- :return:
21
- """
22
  text_bytes = text.encode("utf-8")
23
  return len(text_bytes) / len(zlib.compress(text_bytes))
24
 
25
 
26
  def get_segment(idx: int, segment, tokenizer, request: TranscriptionRequest) -> Segment:
27
- return SegmentBuilder() \
28
- .id(idx) \
29
- .start(segment['start']) \
30
- .end(segment['end']) \
31
- .text(segment['segment']) \
32
- .tokens(tokenizer.text_to_ids(segment['segment'])) \
33
- .temperature(request.temperature) \
34
- .compression_ratio(compression_ratio(segment['segment'])) \
 
35
  .build()
 
36
 
37
 
38
  class NemoAsrHandler(Handler):
@@ -43,7 +48,9 @@ class NemoAsrHandler(Handler):
43
  self._model = ASRModel.from_pretrained(model_name=str(config.repository)).eval()
44
  self._model = self._model.to(torch.bfloat16)
45
 
46
- async def __call__(self, request: TranscriptionRequest, ctx: Context) -> TranscriptionResponse:
 
 
47
  with logger.contextualize(request_id=ctx.request_id):
48
  with memoryview(request) as audio:
49
  (waveform, sampling) = load_audio(BytesIO(audio), sr=16000, mono=True)
@@ -52,13 +59,15 @@ class NemoAsrHandler(Handler):
52
  )
53
 
54
  # Do we need to compute the timestamps?
55
- needs_timestamps = request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
56
- transcribe_f = partial(self._model.transcribe, timestamps=needs_timestamps, verbose=False)
 
 
 
 
57
 
58
  outputs = await asyncio.get_running_loop().run_in_executor(
59
- None,
60
- transcribe_f,
61
- (waveform,)
62
  )
63
 
64
  output = outputs[0]
@@ -66,7 +75,7 @@ class NemoAsrHandler(Handler):
66
 
67
  match request.response_kind:
68
  case TranscriptionResponseKind.VERBOSE_JSON:
69
- segment_timestamps = output.timestamp['segment']
70
  segments = [
71
  get_segment(idx, stamp, self._model.tokenizer, request)
72
  for (idx, stamp) in enumerate(segment_timestamps)
@@ -102,5 +111,5 @@ def entrypoint():
102
  run(endpoint, config.interface, config.port)
103
 
104
 
105
- if __name__ == '__main__':
106
  entrypoint()
 
4
  from io import BytesIO
5
 
6
  import torch
7
+ from hfendpoints import EndpointConfig, Handler, __version__
8
  from hfendpoints.openai import Context, run
9
+ from hfendpoints.openai.audio import (
10
+ AutomaticSpeechRecognitionEndpoint,
11
+ Segment,
12
+ SegmentBuilder,
13
+ TranscriptionRequest,
14
+ TranscriptionResponse,
15
+ TranscriptionResponseKind,
16
+ VerboseTranscription,
17
+ )
18
+ from librosa import get_duration
19
+ from librosa import load as load_audio
20
  from loguru import logger
21
  from nemo.collections.asr.models import ASRModel
22
 
 
 
23
 
24
  def compression_ratio(text: str) -> float:
 
 
 
 
25
  text_bytes = text.encode("utf-8")
26
  return len(text_bytes) / len(zlib.compress(text_bytes))
27
 
28
 
29
  def get_segment(idx: int, segment, tokenizer, request: TranscriptionRequest) -> Segment:
30
+ return (
31
+ SegmentBuilder()
32
+ .id(idx)
33
+ .start(segment["start"])
34
+ .end(segment["end"])
35
+ .text(segment["segment"])
36
+ .tokens(tokenizer.text_to_ids(segment["segment"]))
37
+ .temperature(request.temperature)
38
+ .compression_ratio(compression_ratio(segment["segment"]))
39
  .build()
40
+ )
41
 
42
 
43
  class NemoAsrHandler(Handler):
 
48
  self._model = ASRModel.from_pretrained(model_name=str(config.repository)).eval()
49
  self._model = self._model.to(torch.bfloat16)
50
 
51
+ async def __call__(
52
+ self, request: TranscriptionRequest, ctx: Context
53
+ ) -> TranscriptionResponse:
54
  with logger.contextualize(request_id=ctx.request_id):
55
  with memoryview(request) as audio:
56
  (waveform, sampling) = load_audio(BytesIO(audio), sr=16000, mono=True)
 
59
  )
60
 
61
  # Do we need to compute the timestamps?
62
+ needs_timestamps = (
63
+ request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
64
+ )
65
+ transcribe_f = partial(
66
+ self._model.transcribe, timestamps=needs_timestamps, verbose=False
67
+ )
68
 
69
  outputs = await asyncio.get_running_loop().run_in_executor(
70
+ None, transcribe_f, (waveform,)
 
 
71
  )
72
 
73
  output = outputs[0]
 
75
 
76
  match request.response_kind:
77
  case TranscriptionResponseKind.VERBOSE_JSON:
78
+ segment_timestamps = output.timestamp["segment"]
79
  segments = [
80
  get_segment(idx, stamp, self._model.tokenizer, request)
81
  for (idx, stamp) in enumerate(segment_timestamps)
 
111
  run(endpoint, config.interface, config.port)
112
 
113
 
114
+ if __name__ == "__main__":
115
  entrypoint()
requirements.txt CHANGED
@@ -1,5 +1,8 @@
1
- huggingface_hub [hf_xet]
2
- librosa >= 0.11.0
3
- nemo_toolkit [asr] >= 2.3.0
 
 
 
4
  numpy
5
- tqdm
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+ torch>=2.6.0,<2.7.0
3
+ torchvision
4
+ huggingface_hub[hf_xet]
5
+ librosa>=0.11.0
6
+ nemo_toolkit[asr]>=2.3.0
7
  numpy
8
+ tqdm