import os
import time
from threading import Thread

import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer,
)

# ---------------------------
# Models
# ---------------------------
MODEL_PATHS = {
    "Model 1 (Qwen2.5-VL-7B-Abliterated)": (
        "prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it",
        Qwen2_5_VLForConditionalGeneration,
    ),
    "Model 2 (Nanonets-OCR-s)": (
        "nanonets/Nanonets-OCR-s",
        Qwen2_5_VLForConditionalGeneration,
    ),
    "Model 3 (Finetuned HTR)": (
        "Emeritus-21/Finetuned-full-HTR-model",
        AutoModelForImageTextToText,
    ),
}

MAX_NEW_TOKENS_DEFAULT = 512
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------
# Preload models at startup
# ---------------------------
_loaded_processors = {}
_loaded_models = {}

print("🚀 Preloading models into GPU/CPU memory...")
for name, (repo_id, cls) in MODEL_PATHS.items():
    try:
        print(f"Loading {name} ...")
        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        model = cls.from_pretrained(
            repo_id, trust_remote_code=True, torch_dtype=torch.float16
        ).to(device).eval()
        _loaded_processors[name] = processor
        _loaded_models[name] = model
        print(f"✅ {name} ready.")
    except Exception as e:
        print(f"⚠️ Failed to load {name}: {e}")

# ---------------------------
# Warmup (GPU)
# ---------------------------
@spaces.GPU
def warmup():
    try:
        default_model_choice = list(MODEL_PATHS.keys())[0]
        processor = _loaded_processors[default_model_choice]
        model = _loaded_models[default_model_choice]

        messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
        chat_prompt = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
        with torch.inference_mode():
            _ = model.generate(**inputs, max_new_tokens=1)
        return f"GPU warm and {default_model_choice} ready."
    except Exception as e:
        return f"Warmup skipped: {e}"

# ---------------------------
# OCR Function (RAW ONLY)
# ---------------------------
@spaces.GPU
def ocr_image(image: Image.Image,
              model_choice: str,
              query: str = None,
              max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
              temperature: float = 0.1,
              top_p: float = 1.0,
              top_k: int = 0,
              repetition_penalty: float = 1.0):
    if image is None:
        yield "Please upload an image."
        return
    if model_choice not in _loaded_models:
        yield f"Invalid model: {model_choice}"
        return

    processor = _loaded_processors[model_choice]
    model = _loaded_models[model_choice]

    if query and query.strip():
        prompt = query.strip()
    else:
        prompt = (
            "You are a professional Handwritten OCR system.\n"
            "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
            "- Preserve original structure and line breaks.\n"
            "- Keep spacing, bullet points, numbering, and indentation.\n"
            "- Render tables as Markdown tables if present.\n"
            "- Do NOT autocorrect spelling or grammar.\n"
            "- Do NOT merge lines.\n"
            "Return RAW transcription only."
        )

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt}
        ]
    }]

    chat_prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(text=[chat_prompt], images=[image], return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # Note: with do_sample=False decoding is greedy, so temperature/top_p/top_k
    # only take effect if sampling is enabled.
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )

    # Run generation in a background thread and stream partial text to the UI.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        new_text = new_text.replace("<|im_end|>", "")
        buffer += new_text
        time.sleep(0.01)
        yield buffer

# ---------------------------
# Gradio Interface
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("## wilson Handwritten OCR")

    model_choice = gr.Radio(
        choices=list(MODEL_PATHS.keys()),
        value=list(MODEL_PATHS.keys())[0],
        label="Select OCR Model"
    )

    with gr.Tab("🖼 Image Inference"):
        query_input = gr.Textbox(
            label="Custom Prompt (optional)",
            placeholder="Leave empty for RAW structured output"
        )
        image_input = gr.Image(type="pil", label="Upload Handwritten Image")

        with gr.Accordion("⚙️ Advanced Options", open=False):
            max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")

        with gr.Row():
            extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
            clear_btn = gr.Button("🧹 Clear")

        raw_output = gr.Textbox(
            label="📜 RAW Structured Output (exact as written)",
            lines=18,
            show_copy_button=True
        )

        extract_btn.click(
            fn=ocr_image,
            inputs=[image_input, model_choice, query_input, max_new_tokens,
                    temperature, top_p, top_k, repetition_penalty],
            outputs=[raw_output],
            api_name="ocr_image"  # expose this handler as a named API endpoint
        )

        clear_btn.click(
            fn=lambda: ("", None, ""),
            outputs=[raw_output, image_input, query_input]
        )

if __name__ == "__main__":
    demo.queue(max_size=50).launch(share=True, ssr_mode=False, show_error=True)
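
# ---------------------------------------------------------------
# Example client usage (illustrative sketch, not executed by the app):
# because the extract button is registered with api_name="ocr_image",
# the endpoint can also be called programmatically with gradio_client.
# The URL and file name below are placeholders/assumptions; adjust them
# to your running instance or Space.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")          # app URL (assumption)
#   result = client.predict(
#       handle_file("sample_page.png"),                # image_input (local file, assumption)
#       "Model 1 (Qwen2.5-VL-7B-Abliterated)",         # model_choice
#       "",                                            # query_input (empty -> default RAW prompt)
#       512, 0.1, 1.0, 0, 1.0,                         # max_new_tokens, temperature, top_p, top_k, repetition_penalty
#       api_name="/ocr_image",
#   )
#   print(result)                                      # final transcription (predict returns the last streamed value)
# ---------------------------------------------------------------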