import os
import re
from functools import lru_cache

import gradio as gr
import torch

# -------------------
# Writable caches for HF + Gradio (fixes PermissionError in Spaces)
# -------------------
os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.cache/huggingface/transformers")
os.environ.setdefault("HF_HUB_CACHE", "/data/.cache/huggingface/hub")
os.environ.setdefault("GRADIO_TEMP_DIR", "/data/gradio")
os.environ.setdefault("GRADIO_CACHE_DIR", "/data/gradio")

for p in [
    "/data/.cache/huggingface/transformers",
    "/data/.cache/huggingface/hub",
    "/data/gradio",
]:
    try:
        os.makedirs(p, exist_ok=True)
    except Exception:
        pass

# Timezone (Python 3.9+)
try:
    from zoneinfo import ZoneInfo
except Exception:
    ZoneInfo = None

# Cohere SDK (hosted path)
try:
    import cohere
    _HAS_COHERE = True
except Exception:
    _HAS_COHERE = False

from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login, HfApi

# -------------------
# Config
# -------------------
MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-r7b-12-2024")
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
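# Prefer the hosted Cohere API when both the key and the SDK are available;
# otherwise fall back to loading the model locally from the Hub (see chat_fn).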
# -------------------
# Helpers
# -------------------
def pick_dtype_and_map():
    if torch.cuda.is_available():
        return torch.float16, "auto"
    if torch.backends.mps.is_available():
        return torch.float16, {"": "mps"}
    return torch.float32, "cpu"
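

# gr.ChatInterface below is configured with type="messages", so history arrives as a list of
# {"role": ..., "content": ...} dicts rather than (user, assistant) tuples. This helper is an
# addition (not in the original prompt builders) that normalizes either format into pairs.
def _history_pairs(history):
    pairs = []
    if not history:
        return pairs
    if isinstance(history[0], dict):
        user_msg = None
        for m in history:
            if m.get("role") == "user":
                user_msg = m.get("content")
            elif m.get("role") == "assistant":
                pairs.append((user_msg, m.get("content")))
                user_msg = None
        if user_msg is not None:
            pairs.append((user_msg, None))
    else:
        pairs = [(u, a) for u, a in history]
    return pairs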
def is_identity_query(message, history):
    """Return True if the user is asking who or what the assistant is."""
    patterns = [
        r"\bwho\s+are\s+you\b", r"\bwhat\s+are\s+you\b",
        r"\bwhat\s+is\s+your\s+name\b", r"\bwho\s+is\s+this\b",
        r"\bidentify\s+yourself\b", r"\btell\s+me\s+about\s+yourself\b",
        r"\bdescribe\s+yourself\b", r"\band\s+you\s*\?\b",
        r"\byour\s+name\b", r"\bwho\s+am\s+i\s+chatting\s+with\b"
    ]

    def match(t):
        return any(re.search(p, (t or "").strip().lower()) for p in patterns)

    if match(message):
        return True
    pairs = _history_pairs(history)
    if pairs and match(pairs[-1][0]):
        return True
    return False
def _history_to_prompt(message, history):
    """Build a simple text prompt for the stable cohere.chat API."""
    parts = []
    for u, a in _history_pairs(history):
        if u:
            parts.append(f"User: {u}")
        if a:
            parts.append(f"Assistant: {a}")
    parts.append(f"User: {message}")
    parts.append("Assistant:")
    return "\n".join(parts)
# -------------------
# Cohere Hosted
# -------------------
_co_client = None
if USE_HOSTED_COHERE:
    _co_client = cohere.Client(api_key=COHERE_API_KEY)
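
# Send the flattened conversation to the hosted Cohere chat endpoint and return the reply text.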
def cohere_chat(message, history):
    try:
        prompt = _history_to_prompt(message, history)
        resp = _co_client.chat(
            model="command-r7b-12-2024",
            message=prompt,
            temperature=0.3,
            max_tokens=350,
        )
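        # Different cohere SDK versions expose the reply text under different attributes.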
        if hasattr(resp, "text") and resp.text:
            return resp.text.strip()
        if hasattr(resp, "reply") and resp.reply:
            return resp.reply.strip()
        if hasattr(resp, "generations") and resp.generations:
            return resp.generations[0].text.strip()
        return "Sorry, I couldn't parse the response from Cohere."
    except Exception as e:
        return f"Error calling Cohere API: {e}"
# -------------------
# Local HF Model
# -------------------
@lru_cache(maxsize=1)  # cache the loaded model/tokenizer so they are not reloaded on every chat turn
def load_local_model():
    if not HF_TOKEN:
        raise RuntimeError("HUGGINGFACE_HUB_TOKEN is not set.")
    login(token=HF_TOKEN, add_to_git_credential=False)
    dtype, device_map = pick_dtype_and_map()
    tok = AutoTokenizer.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        use_fast=True,
        model_max_length=4096,
        padding_side="left",
        trust_remote_code=True,
    )
    mdl = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        device_map=device_map,
        low_cpu_mem_usage=True,
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
        mdl.config.eos_token_id = tok.eos_token_id
    return mdl, tok
def build_inputs(tokenizer, message, history):
    msgs = []
    for u, a in _history_pairs(history):
        if u:
            msgs.append({"role": "user", "content": u})
        if a:
            msgs.append({"role": "assistant", "content": a})
    msgs.append({"role": "user", "content": message})
    # Render the conversation with the model's chat template and return input token IDs.
    return tokenizer.apply_chat_template(
        msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
    input_ids = input_ids.to(model.device)
    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
            repetition_penalty=1.15,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
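    # Keep only the newly generated tokens (everything after the prompt).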
    gen_only = out[0, input_ids.shape[-1]:]
    return tokenizer.decode(gen_only, skip_special_tokens=True).strip()
# -------------------
# Chat Function
# -------------------
def chat_fn(message, history, user_tz):
    try:
        if is_identity_query(message, history):
return "I am ClarityOps, your strategic decision making AI partner." | |
        if USE_HOSTED_COHERE:
            return cohere_chat(message, history)
        model, tokenizer = load_local_model()
        inputs = build_inputs(tokenizer, message, history)
        return local_generate(model, tokenizer, inputs, max_new_tokens=350)
    except Exception as e:
        return f"Error: {e}"
# -------------------
# Theme & CSS
# -------------------
theme = gr.themes.Soft(
    primary_hue="teal",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_lg,
)

custom_css = """
:root {
  --brand-bg: #e6f7f8;        /* soft medical teal */
  --brand-accent: #0d9488;    /* teal-600 */
  --brand-text: #0f172a;
  --brand-text-light: #ffffff;
}
/* Page background */
.gradio-container {
  background: var(--brand-bg);
}
/* Title */
h1 {
  color: var(--brand-text);
  font-weight: 700;
  font-size: 28px !important;
}
/* Try to hide the default Chatbot label via CSS for multiple Gradio builds */
.chatbot header,
.chatbot .label,
.chatbot .label-wrap,
.chatbot .top,
.chatbot .header,
.chatbot > .wrap > header {
  display: none !important;
}
/* Both bot and user bubbles teal with white text */
.message.user, .message.bot {
  background: var(--brand-accent) !important;
  color: var(--brand-text-light) !important;
  border-radius: 12px !important;
  padding: 8px 12px !important;
}
/* Inputs a bit softer */
textarea, input, .gr-input {
  border-radius: 12px !important;
}
"""
# -------------------
# UI
# -------------------
with gr.Blocks(theme=theme, css=custom_css) as demo:
    # Hidden box to carry timezone (still useful for future features)
    tz_box = gr.Textbox(visible=False)
    demo.load(lambda tz: tz, inputs=[tz_box], outputs=[tz_box],
              js="() => Intl.DateTimeFormat().resolvedOptions().timeZone")

    # Extra JS hard-removal of the Chatbot label to cover all DOM variants
    hide_label_sink = gr.HTML(visible=False)
    demo.load(
        fn=lambda: "",
        inputs=None,
        outputs=hide_label_sink,
        js="""
        () => {
            const sel = [
                '.chatbot header',
                '.chatbot .label',
                '.chatbot .label-wrap',
                '.chatbot .top',
                '.chatbot .header',
                '.chatbot > .wrap > header'
            ];
            sel.forEach(s => document.querySelectorAll(s).forEach(el => el.style.display = 'none'));
            return "";
        }
        """
    )

    # Updated title
    gr.Markdown("# ClarityOps Augmented Decision AI")

    gr.ChatInterface(
        fn=chat_fn,
        type="messages",
        additional_inputs=[tz_box],
        chatbot=gr.Chatbot(label="", show_label=False, type="messages"),  # aligned type + no label
        examples=[
            ["What are the symptoms of hypertension?", ""],
            ["What are common drug interactions with aspirin?", ""],
            ["What are the warning signs of diabetes?", ""],
        ],
        cache_examples=False,  # prevent permission error in Spaces
    )

if __name__ == "__main__":
    demo.launch()