Spaces: Running on Zero
import time
from threading import Thread

import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer,
)
# ---------------------------
# Models
# ---------------------------
MODEL_PATHS = {
    "Model 1 (Qwen2.5-VL-7B-Abliterated)": (
        "prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it",
        Qwen2_5_VLForConditionalGeneration,
    ),
    "Model 2 (Nanonets-OCR-s)": (
        "nanonets/Nanonets-OCR-s",
        Qwen2_5_VLForConditionalGeneration,
    ),
    "Model 3 (Finetuned HTR)": (
        "Emeritus-21/Finetuned-full-HTR-model",
        AutoModelForImageTextToText,
    ),
}
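# To register another checkpoint, add an entry above mapping a display name to
# (repo_id, model class), e.g. with a hypothetical repo id:
#   "Model 4 (My-HTR)": ("your-username/your-htr-model", AutoModelForImageTextToText),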
MAX_NEW_TOKENS_DEFAULT = 512

device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------
# Preload models at startup
# ---------------------------
_loaded_processors = {}
_loaded_models = {}
print("๐ Preloading models into GPU/CPU memory...") | |
for name, (repo_id, cls) in MODEL_PATHS.items(): | |
try: | |
print(f"Loading {name} ...") | |
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True) | |
model = cls.from_pretrained( | |
repo_id, | |
trust_remote_code=True, | |
torch_dtype=torch.float16 | |
).to(device).eval() | |
_loaded_processors[name] = processor | |
_loaded_models[name] = model | |
print(f"โ {name} ready.") | |
except Exception as e: | |
print(f"โ ๏ธ Failed to load {name}: {e}") | |
# ---------------------------
# Warmup (GPU)
# ---------------------------
def warmup():
    try:
        default_model_choice = list(MODEL_PATHS.keys())[0]
        processor = _loaded_processors[default_model_choice]
        model = _loaded_models[default_model_choice]
        messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
        chat_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
        with torch.inference_mode():
            _ = model.generate(**inputs, max_new_tokens=1)
        return f"GPU warm and {default_model_choice} ready."
    except Exception as e:
        return f"Warmup skipped: {e}"
# ---------------------------
# OCR Function (RAW ONLY)
# ---------------------------
@spaces.GPU  # the Space runs on ZeroGPU; request a GPU for each OCR call
def ocr_image(image: Image.Image, model_choice: str, query: str = "",
              max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
              temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0,
              repetition_penalty: float = 1.0):
    if image is None:
        yield "Please upload an image."
        return
    if model_choice not in _loaded_models:
        yield f"Invalid model: {model_choice}"
        return

    processor = _loaded_processors[model_choice]
    model = _loaded_models[model_choice]

    if query and query.strip():
        prompt = query.strip()
    else:
        prompt = (
            "You are a professional Handwritten OCR system.\n"
            "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
            "- Preserve original structure and line breaks.\n"
            "- Keep spacing, bullet points, numbering, and indentation.\n"
            "- Render tables as Markdown tables if present.\n"
            "- Do NOT autocorrect spelling or grammar.\n"
            "- Do NOT merge lines.\n"
            "Return RAW transcription only."
        )

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }]
    chat_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[chat_prompt], images=[image], return_tensors="pt").to(device)
    # TextIteratorStreamer only needs decode(); pass the tokenizer explicitly.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        # do_sample=False means greedy decoding; the sampling knobs below are
        # only consulted when do_sample=True.
        do_sample=False,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread so tokens can be yielded as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        new_text = new_text.replace("<|im_end|>", "")
        buffer += new_text
        time.sleep(0.01)  # small pause smooths the streaming UI updates
        yield buffer
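# Quick local smoke test (a sketch, not part of the app; assumes a sample.png
# exists beside this file):
#   last = ""
#   for last in ocr_image(Image.open("sample.png"), list(MODEL_PATHS.keys())[0]):
#       pass
#   print(last)  # final transcription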
# ---------------------------
# Gradio Interface
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Wilson Handwritten OCR")
    model_choice = gr.Radio(
        choices=list(MODEL_PATHS.keys()),
        value=list(MODEL_PATHS.keys())[0],
        label="Select OCR Model",
    )
    with gr.Tab("🖼 Image Inference"):
        query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
        image_input = gr.Image(type="pil", label="Upload Handwritten Image")

        with gr.Accordion("⚙️ Advanced Options", open=False):
            max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")

        with gr.Row():
            extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
            clear_btn = gr.Button("🧹 Clear")

        raw_output = gr.Textbox(label="📜 RAW Structured Output (exact as written)", lines=18, show_copy_button=True)
    extract_btn.click(
        fn=ocr_image,
        inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[raw_output],
        api_name="ocr_image",  # stable endpoint name, callable via the API / gradio_client
    )
    clear_btn.click(
        fn=lambda: ("", None, ""),
        outputs=[raw_output, image_input, query_input],
    )
if __name__ == "__main__":
    demo.queue(max_size=50).launch(share=True, ssr_mode=False, show_error=True)
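# ---------------------------
# Remote usage (sketch)
# ---------------------------
# Because the click handler sets api_name="ocr_image", the endpoint can be called
# from another machine with gradio_client. "<owner>/<space>" is a placeholder:
#
#   from gradio_client import Client, handle_file
#   client = Client("<owner>/<space>")
#   text = client.predict(
#       handle_file("page.png"),    # image
#       "Model 3 (Finetuned HTR)",  # model_choice
#       "",                         # query (empty -> default RAW prompt)
#       512, 0.1, 1.0, 0, 1.0,      # max_new_tokens, temperature, top_p, top_k, repetition_penalty
#       api_name="/ocr_image",
#   )
#   print(text)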