# app.py for Hugging Face Space # Make sure to add 'gradio', 'transformers', and 'torch' (or 'tensorflow'/'flax') # to your requirements.txt file in the Hugging Face Space repository. # gated model # Set Hugging Face token if needed (for gated models, though Llama 3.1 might not require it after initial access grant) from huggingface_hub import login # app.py for Hugging Face Space # Make sure to add 'gradio', 'transformers', 'torch' (or 'tensorflow'/'flax'), # and 'huggingface_hub' to your requirements.txt file in the Hugging Face Space repository. import gradio as gr import torch # Or tensorflow/flax depending on backend from transformers import AutoModelForCausalLM, AutoTokenizer from huggingface_hub import hf_hub_download # Import hub download function import json # Import json library import os # Import os library for path joining # --- hf lpgin --- hf_token = os.getenv("HF_TOKEN") login(token=hf_token) # --- Configuration --- MODEL_NAME = "google/txgemma-2b-predict" PROMPT_FILENAME = "tdc_prompts.json" MODEL_CACHE = "model_cache" # Optional: define a cache directory MAX_EXAMPLES = 100 # Limit the number of examples loaded from the JSON EXAMPLE_SMILES = "C1=CC=CC=C1" # Default SMILES for examples (Benzene) # --- Load Model, Tokenizer, and Prompts --- print(f"Loading model: {MODEL_NAME}...") tdc_prompts_data = None # Initialize as None examples_list = [] # Initialize empty list for examples try: # Check if GPU is available and use it, otherwise use CPU device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") # Load the tokenizer tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE) print("Tokenizer loaded.") # Load the model model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, cache_dir=MODEL_CACHE, device_map="auto" # Automatically distribute model across available devices (GPU/CPU) ) print("Model loaded.") # Download and load the prompts JSON file print(f"Downloading {PROMPT_FILENAME}...") prompts_file_path = hf_hub_download( repo_id=MODEL_NAME, filename=PROMPT_FILENAME, cache_dir=MODEL_CACHE, # force_download=True, # Uncomment to force redownload if needed ) print(f"{PROMPT_FILENAME} downloaded to: {prompts_file_path}") # Load the JSON data with open(prompts_file_path, 'r') as f: tdc_prompts_data = json.load(f) print(f"Loaded prompts data from {PROMPT_FILENAME}.") # --- Prepare examples for Gradio --- # Updated logic: Parse the dictionary format from tdc_prompts.json # The JSON is expected to be a dictionary where values are prompt templates. if isinstance(tdc_prompts_data, dict): print(f"Processing {len(tdc_prompts_data)} prompts from dictionary...") count = 0 for prompt_template in tdc_prompts_data.values(): if count >= MAX_EXAMPLES: break if isinstance(prompt_template, str): # Replace the placeholder with the example SMILES string example_prompt = prompt_template.replace("{Drug SMILES}", EXAMPLE_SMILES) # Add to examples list with default parameters examples_list.append([example_prompt, 100, 0.7]) # Default max_tokens=100, temp=0.7 count += 1 else: print(f"Warning: Skipping non-string value in prompts dictionary: {prompt_template}") print(f"Prepared {len(examples_list)} examples for Gradio.") else: print(f"Warning: Expected {PROMPT_FILENAME} to contain a dictionary, but found {type(tdc_prompts_data)}. Cannot load examples.") # examples_list remains empty except Exception as e: print(f"Error loading model, tokenizer, or prompts: {e}") # Ensure examples_list is empty on error during setup examples_list = [] raise gr.Error(f"Failed during setup. Check logs for details. Error: {e}") # --- Prediction Function --- def predict(prompt, max_new_tokens=100, temperature=0.7): """ Generates text based on the input prompt using the loaded model. Args: prompt (str): The input text prompt. max_new_tokens (int): The maximum number of new tokens to generate. temperature (float): Controls the randomness of the generation. Lower is more deterministic. Returns: str: The generated text. """ print(f"Received prompt: {prompt}") print(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}") try: # Prepare the input for the model inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Move inputs to the model's device # Generate text with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=int(max_new_tokens), # Ensure it's an integer temperature=float(temperature), # Ensure it's a float do_sample=True if float(temperature) > 0 else False, # Only sample if temp > 0 pad_token_id=tokenizer.eos_token_id # Set pad token id ) # Decode the generated tokens generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) print(f"Generated text (raw): {generated_text}") # Remove the prompt from the beginning of the generated text if generated_text.startswith(prompt): prompt_length = len(prompt) result_text = generated_text[prompt_length:].lstrip() else: # Handle cases where the model might slightly alter the prompt start # This is a basic check; more robust checks might be needed common_prefix = os.path.commonprefix([prompt, generated_text]) # Check if a significant portion of the prompt is at the start # Use a threshold relative to prompt length, e.g., 80% if len(prompt) > 0 and len(common_prefix) / len(prompt) > 0.8: result_text = generated_text[len(common_prefix):].lstrip() else: result_text = generated_text # Assume prompt is not included or significantly altered print(f"Generated text (processed): {result_text}") return result_text except Exception as e: print(f"Error during prediction: {e}") return f"An error occurred during generation: {e}" # --- Gradio Interface --- print("Creating Gradio interface...") with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown( f""" # 🤖 TXGemma-2B-Predict Text Generation Enter a prompt below or select an example, and the model ({MODEL_NAME}) will generate text based on it. Adjust the parameters for different results. Examples loaded from `{PROMPT_FILENAME}`. Example prompts use the SMILES string `{EXAMPLE_SMILES}` (Benzene) as a placeholder. """ ) with gr.Row(): with gr.Column(scale=2): prompt_input = gr.Textbox( label="Your Prompt", placeholder="Enter your text prompt here, potentially including a specific Drug SMILES string...", lines=5 ) with gr.Row(): max_tokens_slider = gr.Slider( minimum=10, maximum=500, # Adjust max limit if needed value=100, step=10, label="Max New Tokens", info="Maximum number of tokens to generate after the prompt." ) temperature_slider = gr.Slider( minimum=0.0, # Allow deterministic generation maximum=1.5, value=0.7, step=0.05, # Finer control for temperature label="Temperature", info="Controls randomness (0=deterministic, >0=random)." ) submit_button = gr.Button("Generate Text", variant="primary") with gr.Column(scale=3): output_text = gr.Textbox( label="Generated Text", lines=10, interactive=False # Output is not editable by user ) # --- Connect Components --- submit_button.click( fn=predict, inputs=[prompt_input, max_tokens_slider, temperature_slider], outputs=output_text, api_name="predict" # Name for API endpoint if needed ) # Use the loaded examples if available if examples_list: gr.Examples( examples=examples_list, # Ensure inputs match the order expected by the 'predict' function and the structure of examples_list inputs=[prompt_input, max_tokens_slider, temperature_slider], outputs=output_text, fn=predict, # The function to run when an example is clicked cache_examples=False # Caching might be slow/problematic for LLMs ) else: gr.Markdown("_(Could not load examples from JSON file or file format was incorrect.)_") # --- Launch the App --- print("Launching Gradio app...") # queue() enables handling multiple users concurrently # Set share=True if you need a public link, otherwise False or omit demo.queue().launch(debug=True) # Set debug=False for production