VED-AGI-1 committed
Commit b7a949a · verified · Parent: 64972fd

Update app.py

Files changed (1): app.py (+46, -18)
app.py CHANGED
@@ -24,6 +24,7 @@ except Exception:
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from huggingface_hub import login, HfApi
 
+
 # -------------------
 # Configuration
 # -------------------
@@ -37,8 +38,9 @@ HF_TOKEN = (
 COHERE_API_KEY = os.getenv("COHERE_API_KEY")
 USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
 
+
 # -------------------
-# Helpers (for page header / connection card only)
+# Helpers (used for the connection card only)
 # -------------------
 def local_now_str(user_tz: str | None) -> tuple[str, str]:
     """Returns (label, formatted_time). Falls back to UTC if tz missing/invalid."""
@@ -54,7 +56,9 @@ def local_now_str(user_tz: str | None) -> tuple[str, str]:
         label = "UTC"
     return label, dt.strftime("%Y-%m-%d %H:%M:%S")
 
+
 def header(processing_time=None, user_tz: str | None = None):
+    """Only used in the connection status panel (not in chat replies)."""
     tz_label, now_str = local_now_str(user_tz)
     s = (
         f"Current Date and Time ({tz_label} - YYYY-MM-DD HH:MM:SS formatted): {now_str}\n"
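The body of local_now_str is elided by the diff context above. For reference, a minimal sketch of what a helper with this signature and docstring usually looks like (the zoneinfo-based implementation is an assumption, not code from this commit):

from datetime import datetime, timezone
from zoneinfo import ZoneInfo  # stdlib since Python 3.9

def local_now_str(user_tz: str | None) -> tuple[str, str]:
    """Returns (label, formatted_time). Falls back to UTC if tz missing/invalid."""
    try:
        tz = ZoneInfo(user_tz) if user_tz else timezone.utc
        label = user_tz or "UTC"
    except Exception:  # unknown or malformed tz name
        tz, label = timezone.utc, "UTC"
    return label, datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")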
@@ -64,6 +68,7 @@ def header(processing_time=None, user_tz: str | None = None):
         s += f"Processing Time: {processing_time:.2f} seconds\n"
     return s
 
+
 def pick_dtype_and_map():
     if torch.cuda.is_available():
         return torch.float16, "auto"
@@ -71,6 +76,7 @@ def pick_dtype_and_map():
         return torch.float16, {"": "mps"}
     return torch.float32, "cpu"  # CPU path (likely too big for R7B)
 
+
 def is_identity_query(message: str, history) -> bool:
     """Detects identity questions in current message or most recent user turn."""
     patterns = [
@@ -85,17 +91,23 @@
         r"\byour\s+name\b",
         r"\bwho\s+am\s+i\s+chatting\s+with\b",
     ]
+
     def hit(text: str | None) -> bool:
         t = (text or "").strip().lower()
         return any(re.search(p, t) for p in patterns)
+
     if hit(message):
         return True
+
     if history:
+        # Gradio history: List[Tuple[user, assistant]]
         last_user = history[-1][0] if isinstance(history[-1], (list, tuple)) and history[-1] else None
        if hit(last_user):
            return True
+
     return False
 
+
 # -------------------
 # Cohere Hosted Path
 # -------------------
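A quick sanity check of the detector, using only the patterns visible in this hunk and the tuple-pair history shape the helper expects:

# Hits in the current message:
assert is_identity_query("What is your name?", history=[])
assert is_identity_query("Who am I chatting with?", history=None)
# Hit in the most recent user turn of a tuple-style history:
assert is_identity_query("And in French?", history=[("tell me your name", "...")])
# Assuming none of the elided patterns match ordinary medical questions:
assert not is_identity_query("What causes hypertension?", history=[])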
@@ -103,6 +115,7 @@ _co_client = None
 if USE_HOSTED_COHERE:
     _co_client = cohere.Client(api_key=COHERE_API_KEY)
 
+
 def _cohere_parse(resp):
     # v5+ responses.create
     if hasattr(resp, "output_text") and resp.output_text:
@@ -116,6 +129,7 @@ def _cohere_parse(resp):
         return resp.text.strip()
     return "Sorry, I couldn't parse the response from Cohere."
 
+
 def cohere_chat(message, history):
     try:
         # Prefer modern API
@@ -143,6 +157,7 @@ def cohere_chat(message, history):
     except Exception as e:
         return f"Error calling Cohere API: {e}"
 
+
 # -------------------
 # Local HF Path
 # -------------------
@@ -153,20 +168,31 @@ def load_local_model():
             "HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is not set. "
             "Either set it, or provide COHERE_API_KEY to use Cohere's hosted API."
         )
+
     login(token=HF_TOKEN, add_to_git_credential=False)
+
     dtype, device_map = pick_dtype_and_map()
     tok = AutoTokenizer.from_pretrained(
-        MODEL_ID, token=HF_TOKEN, use_fast=True, model_max_length=4096,
-        padding_side="left", trust_remote_code=True,
+        MODEL_ID,
+        token=HF_TOKEN,
+        use_fast=True,
+        model_max_length=4096,
+        padding_side="left",
+        trust_remote_code=True,
     )
     mdl = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID, token=HF_TOKEN, device_map=device_map, low_cpu_mem_usage=True,
-        torch_dtype=dtype, trust_remote_code=True,
+        MODEL_ID,
+        token=HF_TOKEN,
+        device_map=device_map,
+        low_cpu_mem_usage=True,
+        torch_dtype=dtype,
+        trust_remote_code=True,
     )
     if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
         mdl.config.eos_token_id = tok.eos_token_id
     return mdl, tok
 
+
 def build_inputs(tokenizer, message, history):
     msgs = []
     for u, a in (history or []):
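Reflowing the two from_pretrained calls is purely cosmetic; the eos_token_id backfill below them is what guarantees generation can stop. A hypothetical smoke test for the loader (requires HF_TOKEN and enough memory for MODEL_ID):

mdl, tok = load_local_model()
ids = build_inputs(tok, "Hello!", history=None)
print(local_generate(mdl, tok, ids, max_new_tokens=32))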
@@ -177,6 +203,7 @@ def build_inputs(tokenizer, message, history):
         msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
     )
 
+
 def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
     input_ids = input_ids.to(model.device)
     with torch.no_grad():
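For context, build_inputs assembles the standard transformers chat-message shape before the apply_chat_template call shown above: alternating role dicts with the current message last. The content below is illustrative; tokenizer is the AutoTokenizer loaded earlier.

msgs = [
    {"role": "user", "content": "What are the symptoms of hypertension?"},
    {"role": "assistant", "content": "Common symptoms include ..."},
    {"role": "user", "content": "And the main causes?"},
]
input_ids = tokenizer.apply_chat_template(
    msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)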
@@ -194,6 +221,7 @@ def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
     text = tokenizer.decode(gen_only, skip_special_tokens=True)
     return text.strip()
 
+
 # -------------------
 # Chat callback (no header/meta in chat replies)
 # -------------------
@@ -218,6 +246,7 @@ def chat_fn(message, history, user_tz):
     except Exception as e:
         return f"Error during chat: {e}"
 
+
 # -------------------
 # Connection check (keeps header/meta)
 # -------------------
@@ -243,22 +272,23 @@ def check_connection(user_tz=None):
     except Exception as e:
         return f"{header(user_tz=user_tz)}Connection Status: ❌ Error\nDetails: {e}"
 
+
 # -------------------
 # UI
 # -------------------
 with gr.Blocks(theme=gr.themes.Default()) as demo:
-    # Capture browser timezone via JS and store in state
+    # Hold browser timezone (e.g., "America/Toronto")
     user_tz_state = gr.State("")
-    # On load, capture browser timezone via JS and store in user_tz_state
-    demo.load(
-        fn=lambda tz: tz,  # echo the JS value back to Gradio
-        inputs=None,
-        outputs=[user_tz_state],  # outputs must be a LIST
-        js="() => Intl.DateTimeFormat().resolvedOptions().timeZone"
-    )
 
+    # On load, capture browser timezone via JS and store in user_tz_state
+    demo.load(
+        fn=lambda tz: tz,  # echo the JS value
+        inputs=None,
+        outputs=[user_tz_state],  # outputs must be a LIST
+        js="() => Intl.DateTimeFormat().resolvedOptions().timeZone"
+    )
 
-    gr.Markdown(f"# Medical Decision Support AI\n{header(user_tz=None)}")
+    gr.Markdown("# Medical Decision Support AI")
 
     with gr.Row():
         btn = gr.Button("Check Connection Status")
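The demo.load(js=...) wiring is the heart of this commit: the arrow function runs in the browser, and its return value is handed to fn, whose output lands in user_tz_state. A minimal self-contained sketch of the same pattern, assuming the Gradio 4.x semantics this commit relies on (component names here are illustrative):

import gradio as gr

with gr.Blocks() as tz_demo:
    tz_state = gr.State("")
    tz_box = gr.Textbox(label="Detected browser timezone")
    tz_demo.load(
        fn=lambda tz: (tz, tz),  # fan the JS value out to both outputs
        inputs=None,
        outputs=[tz_state, tz_box],
        js="() => Intl.DateTimeFormat().resolvedOptions().timeZone",
    )

tz_demo.launch()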
@@ -273,7 +303,7 @@ demo.load(
     chat = gr.ChatInterface(
         fn=chat_fn,
         type="messages",
-        additional_inputs=[user_tz_state],  # pass timezone into chat_fn
+        additional_inputs=[user_tz_state],  # pass timezone into chat_fn (for future use)
         description="A medical decision support system that provides healthcare-related information and decision making support.",
         examples=[
             ["What are the symptoms of hypertension?", ""],
@@ -283,12 +313,10 @@
         cache_examples=False,
     )
 
+    # Wire timezone into the connection check as well
     btn.click(fn=check_connection, inputs=user_tz_state, outputs=status)
 
 if __name__ == "__main__":
     demo.launch()
 
 
-
-
-
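One caveat: with type="messages", Gradio passes history as role/content dicts rather than (user, assistant) tuples, so tuple-indexing helpers such as is_identity_query may need a shim along these lines (an assumption about intent, not part of this commit):

def last_user_text(history) -> str | None:
    # Accept both messages-format dicts and legacy (user, assistant) tuples.
    for turn in reversed(history or []):
        if isinstance(turn, dict) and turn.get("role") == "user":
            return turn.get("content")
        if isinstance(turn, (list, tuple)) and turn:
            return turn[0]
    return None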
 