freddyaboulton (HF Staff) committed
Commit bdb5512 · verified · 1 Parent(s): 872589f

Update app.py

Files changed (1):
  1. app.py +3 -35
app.py CHANGED

@@ -7,7 +7,6 @@ from fastrtc import (
     ReplyOnPause,
     Stream,
     WebRTCError,
-    audio_to_float32,
     get_current_context,
     get_hf_turn_credentials,
     get_hf_turn_credentials_async,
@@ -15,33 +14,6 @@ from fastrtc import (
     get_tts_model,
 )
 from huggingface_hub import InferenceClient
-from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
-import spaces
-
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-
-model_id = "openai/whisper-large-v3-turbo"
-
-model = AutoModelForSpeechSeq2Seq.from_pretrained(
-    model_id,
-    torch_dtype=torch_dtype,
-    low_cpu_mem_usage=True,
-    use_safetensors=True,
-)
-model.to(device)
-
-processor = AutoProcessor.from_pretrained(model_id)
-
-pipe = pipeline(
-    "automatic-speech-recognition",
-    model=model,
-    tokenizer=processor.tokenizer,
-    feature_extractor=processor.feature_extractor,
-    torch_dtype=torch_dtype,
-    device=device,
-)
-
 
 load_dotenv()
 
@@ -50,7 +22,7 @@ tts_model = get_tts_model()
 
 conversations: dict[str, list[dict[str, str]]] = {}
 
-@spaces.GPU
+
 def response(
     audio: tuple[int, np.ndarray],
     hf_token: str | None,
@@ -60,12 +32,6 @@ def response(
 
     llm_client = InferenceClient(provider="auto", token=hf_token)
 
-    result = pipe(
-        {"array": audio_to_float32(audio[1]).squeeze(), "sampling_rate": audio[0]},
-        generate_kwargs={"language": "en"},
-    )
-    transcription = result["text"]
-
     context = get_current_context()
     if context.webrtc_id not in conversations:
         conversations[context.webrtc_id] = [
@@ -81,6 +47,8 @@ def response(
 
     messages = conversations[context.webrtc_id]
 
+    transcription = stt_model.stt(audio)
+
     messages.append({"role": "user", "content": transcription})
 
     output = llm_client.chat.completions.create(  # type: ignore
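For reference, a minimal sketch of the speech-to-text path after this commit. The `stt_model = get_stt_model()` setup line is an assumption (mirroring the `tts_model = get_tts_model()` line visible in the diff context); only the `stt_model.stt(audio)` call itself appears in the diff.

```python
# Sketch of the new STT path, assuming stt_model comes from fastrtc's
# get_stt_model(); the diff only shows the stt_model.stt(audio) call.
import numpy as np
from fastrtc import get_stt_model

stt_model = get_stt_model()  # assumed setup, alongside get_tts_model()

def transcribe(audio: tuple[int, np.ndarray]) -> str:
    # fastrtc's STT accepts the same (sample_rate, samples) tuple the
    # ReplyOnPause handler receives, so the manual audio_to_float32()
    # conversion and the local Whisper pipeline are no longer needed.
    return stt_model.stt(audio)
```

Because transcription no longer runs a GPU-resident Whisper model inside the Space, the `@spaces.GPU` decorator and the entire `transformers` setup block could be removed along with it.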