from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import os
from dotenv import load_dotenv

# Load environment variables from a .env file; Hugging Face Spaces injects
# secrets as environment variables, so this also works when deployed.
load_dotenv()

app = FastAPI()

# Read the token from the environment (set as a secret on HF Spaces,
# or in a local .env file for testing).
token = os.getenv("HF_TOKEN")

# Initialize the Inference API client for the chosen model.
client = InferenceClient(
    "google/gemma-3n-E4B-it",
    token=token,
)


class Item(BaseModel):
    prompt: str
    history: list = []
    system_prompt: str = "You are a helpful AI assistant."
    temperature: float = 0.7
    max_new_tokens: int = 1024
    top_p: float = 0.95
    # Note: the chat-completion API does not accept a repetition_penalty
    # parameter; the field is kept so existing request payloads stay valid.
    repetition_penalty: float = 1.0


def build_messages(message: str, history: list, system_prompt: str) -> list:
    """Assemble the OpenAI-style message list the chat-completion API expects."""
    messages = [{"role": "system", "content": system_prompt}]
    for user_prompt, bot_response in history:
        messages.append({"role": "user", "content": user_prompt})
        messages.append({"role": "assistant", "content": bot_response})
    messages.append({"role": "user", "content": message})
    return messages


def generate(item: Item) -> str:
    # The API rejects temperatures at or near zero, so clamp to a small floor.
    temperature = max(item.temperature, 1e-2)

    messages = build_messages(item.prompt, item.history, item.system_prompt)

    # chat_completion applies the model's chat template server-side, so no
    # manual prompt formatting is needed. (InferenceClient has no
    # apply_chat_template method; that belongs to transformers tokenizers.)
    stream = client.chat_completion(
        messages,
        temperature=temperature,
        max_tokens=item.max_new_tokens,
        top_p=item.top_p,
        stream=True,
    )

    output = ""
    for chunk in stream:
        # The final chunk may carry no content, so guard against None.
        delta = chunk.choices[0].delta.content
        if delta:
            output += delta
    return output


@app.post("/generate/")
async def generate_text(item: Item):
    response_text = generate(item)
    return {"response": response_text}


# Optional: a root endpoint to show the app is alive.
@app.get("/")
def read_root():
    return {"Status": "API is running. Use the /generate/ endpoint to get a response."}
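
# --- Usage sketch ---------------------------------------------------------
# A minimal way to run and exercise the API, assuming this file is saved as
# app.py and served locally on the default port 8000 (both are assumptions;
# adjust the module name and port to your setup):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/generate/ \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Hello!", "history": [], "temperature": 0.7}'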