import os
import time
from threading import Thread
import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer,
)
# ---------------------------
# Models
# ---------------------------
MODEL_PATHS = {
    "Model 1 (Qwen2.5-VL-7B-Abliterated)": (
        "prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it",
        Qwen2_5_VLForConditionalGeneration,
    ),
    "Model 2 (Nanonets-OCR-s)": (
        "nanonets/Nanonets-OCR-s",
        Qwen2_5_VLForConditionalGeneration,
    ),
    "Model 3 (Finetuned HTR)": (
        "Emeritus-21/Finetuned-full-HTR-model",
        AutoModelForImageTextToText,
    ),
}
MAX_NEW_TOKENS_DEFAULT = 512
device = "cuda" if torch.cuda.is_available() else "cpu"
# ---------------------------
# Preload models at startup
# ---------------------------
_loaded_processors = {}
_loaded_models = {}
print("๐Ÿš€ Preloading models into GPU/CPU memory...")
for name, (repo_id, cls) in MODEL_PATHS.items():
    try:
        print(f"Loading {name} ...")
        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        model = cls.from_pretrained(
            repo_id,
            trust_remote_code=True,
            # fp16 halves memory on GPU; fall back to fp32 on CPU, where half precision is poorly supported
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        ).to(device).eval()
        _loaded_processors[name] = processor
        _loaded_models[name] = model
        print(f"✅ {name} ready.")
    except Exception as e:
        print(f"⚠️ Failed to load {name}: {e}")
# ---------------------------
# Warmup (GPU)
# ---------------------------
@spaces.GPU
def warmup():
    try:
        default_model_choice = list(MODEL_PATHS.keys())[0]
        processor = _loaded_processors[default_model_choice]
        model = _loaded_models[default_model_choice]
        messages = [{"role": "user", "content": [{"type": "text", "text": "Warmup."}]}]
        chat_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[chat_prompt], images=None, return_tensors="pt").to(device)
        with torch.inference_mode():
            _ = model.generate(**inputs, max_new_tokens=1)
        return f"GPU warm and {default_model_choice} ready."
    except Exception as e:
        return f"Warmup skipped: {e}"
# ---------------------------
# OCR Function (RAW ONLY)
# ---------------------------
@spaces.GPU
def ocr_image(image: Image.Image, model_choice: str, query: str = None,
              max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
              temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0,
              repetition_penalty: float = 1.0):
    if image is None:
        yield "Please upload an image."
        return
    if model_choice not in _loaded_models:
        yield f"Invalid model: {model_choice}"
        return
    processor = _loaded_processors[model_choice]
    model = _loaded_models[model_choice]
    if query and query.strip():
        prompt = query.strip()
    else:
        prompt = (
            "You are a professional Handwritten OCR system.\n"
            "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
            "- Preserve original structure and line breaks.\n"
            "- Keep spacing, bullet points, numbering, and indentation.\n"
            "- Render tables as Markdown tables if present.\n"
            "- Do NOT autocorrect spelling or grammar.\n"
            "- Do NOT merge lines.\n"
            "Return RAW transcription only."
        )
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }]
    chat_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[chat_prompt], images=[image], return_tensors="pt").to(device)
    # TextIteratorStreamer needs the tokenizer (not the full processor) to decode tokens.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding; the sampling knobs below take effect only if do_sample=True
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a worker thread so tokens can be streamed as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        new_text = new_text.replace("<|im_end|>", "")
        buffer += new_text
        time.sleep(0.01)  # brief pause keeps the streamed UI updates smooth
        yield buffer
    thread.join()
# ---------------------------
# Gradio Interface
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Wilson Handwritten OCR")
    model_choice = gr.Radio(
        choices=list(MODEL_PATHS.keys()),
        value=list(MODEL_PATHS.keys())[0],
        label="Select OCR Model",
    )
    with gr.Tab("🖼 Image Inference"):
        query_input = gr.Textbox(label="Custom Prompt (optional)", placeholder="Leave empty for RAW structured output")
        image_input = gr.Image(type="pil", label="Upload Handwritten Image")
        with gr.Accordion("⚙️ Advanced Options", open=False):
            max_new_tokens = gr.Slider(1, 2048, value=MAX_NEW_TOKENS_DEFAULT, step=1, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.1, step=0.05, label="Temperature")
            top_p = gr.Slider(0.05, 1.0, value=1.0, step=0.05, label="Top-p (nucleus)")
            top_k = gr.Slider(0, 1000, value=0, step=1, label="Top-k")
            repetition_penalty = gr.Slider(0.8, 2.0, value=1.0, step=0.05, label="Repetition penalty")
        with gr.Row():
            extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
            clear_btn = gr.Button("🧹 Clear")
        raw_output = gr.Textbox(label="📜 RAW Structured Output (exact as written)", lines=18, show_copy_button=True)
        extract_btn.click(
            fn=ocr_image,
            inputs=[image_input, model_choice, query_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
            outputs=[raw_output],
            api_name="ocr_image",  # stable endpoint name so the OCR function is reachable via the Gradio API
        )
        clear_btn.click(
            fn=lambda: ("", None, ""),
            outputs=[raw_output, image_input, query_input],
        )
if __name__ == "__main__":
    demo.queue(max_size=50).launch(share=True, ssr_mode=False, show_error=True)
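# ---------------------------
# Example client usage (sketch)
# ---------------------------
# A minimal sketch of calling the `ocr_image` endpoint exposed above from
# another process with `gradio_client`. The Space id below is hypothetical;
# point it at your own deployment. `client.predict` returns only the final
# value of the streamed output; use `client.submit` instead to iterate
# partial results as they arrive.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("Emeritus-21/Htr-testing")  # hypothetical Space id
#   text = client.predict(
#       handle_file("page.png"),                # image
#       "Model 3 (Finetuned HTR)",              # model_choice
#       "",                                     # query: empty -> default RAW prompt
#       512, 0.1, 1.0, 0, 1.0,                  # max_new_tokens, temperature, top_p, top_k, repetition_penalty
#       api_name="/ocr_image",
#   )
#   print(text)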