import gradio as gr
from PIL import Image
import os
import json
import time
import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
import spaces

ckpt = "unsloth/Llama-3.2-11B-Vision-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and processor, quantized to 4-bit NF4 to reduce VRAM usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
)
model = MllamaForConditionalGeneration.from_pretrained(
    ckpt,
    device_map="auto",  # automatically shard the model across available GPUs
    quantization_config=bnb_config,
)
processor = AutoProcessor.from_pretrained(ckpt)

SAVE_DIR = "corrections"
os.makedirs(SAVE_DIR, exist_ok=True)


@spaces.GPU
def ocr_on_image(image):
    # More constrained prompt, kept as an unused alternative to prompt2
    prompt1 = (
        "Output ONLY the raw text exactly as it appears in the image. Do not add anything.\n\n"
        "The image may contain both handwritten and printed text in French and/or English, including punctuation and underscores.\n\n"
        "Your task: Transcribe all visible text exactly, preserving:\n"
        "- All characters, accents, punctuation, spacing, and line breaks.\n"
        "- The original reading order and layout, including tables and forms if present.\n\n"
        "Rules:\n"
        "- Do NOT add any explanations, summaries, comments, or extra text.\n"
        "- Do NOT duplicate any content.\n"
        "- Do NOT indicate blank space.\n"
        "- Do NOT separate handwritten and printed text.\n"
        "- Do NOT confuse '.' (a period) with '|' (a border).\n\n"
        "Only extract the text that is actually visible in the image, and nothing else."
    )
    prompt2 = (
        "Extract all visible text from the image, including both handwritten and printed content. "
        "Do not translate the text — preserve the original language exactly as it appears. "
        "Return only the extracted text, with no explanation, no formatting, and no additions."
    )

    messages = [{"role": "user", "content": [{"type": "text", "text": prompt2}, {"type": "image"}]}]
    texts = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=texts, images=[image], return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=250)
    result = processor.decode(outputs[0], skip_special_tokens=True)

    # Naive cleanup: keep only what follows the "assistant" turn marker,
    # then strip any echo of the prompt from the decoded output
    if "assistant" in result.lower():
        result = result[result.lower().find("assistant") + len("assistant"):].strip()
    result = result.replace("user", "").replace(prompt2, "").strip()
    return result


def batch_ocr(images):
    if not images:
        return [], "No images uploaded."
    results = []
    status_text = f"Processing {len(images)} image(s)...\n"

    for i, img_file in enumerate(images):
        try:
            pil_img = Image.open(img_file.name).convert("RGB")
            text = ocr_on_image(pil_img)
            results.append({
                "image": pil_img,
                "filepath": img_file.name,
                "ocr_text": text,
                "corrected_text": text
            })
            status_text += f"Image {i+1}: ✓ Text extracted\n"
        except Exception as e:
            status_text += f"Image {i+1}: ❌ Error: {str(e)}\n"

    return results, status_text


def save_all_corrections(data_list, *corrections):
    if not data_list:
        return "No data to save.", None

    # Overwrite entries with any manually edited text
    for i, correction in enumerate(corrections):
        if i < len(data_list) and correction.strip():
            data_list[i]["corrected_text"] = correction

    timestamp = int(time.time())

    # Build a single consolidated, downloadable JSON file
    consolidated_data = []
    for i, data in enumerate(data_list):
        # Save the image locally alongside the corrections
        img_path = f"{SAVE_DIR}/image_{timestamp}_{i}.png"
        data["image"].save(img_path)

        entry = {
            "image_id": f"image_{i+1}",
            "original_filename": data["filepath"],
            "ocr_text": data["ocr_text"],
            "corrected_text": data["corrected_text"],
            "timestamp": timestamp
        }
        consolidated_data.append(entry)

    # Write the download file
    download_path = f"corrections_{timestamp}.json"
    with open(download_path, "w", encoding="utf-8") as f:
        json.dump(consolidated_data, f, ensure_ascii=False, indent=2)

    status_msg = f"✅ {len(consolidated_data)} correction(s) ready for download."
    # Return the status plus a file update that makes the download component visible
    return status_msg, gr.update(value=download_path, visible=True)


# Gradio interface
with gr.Blocks(title="OCR with Llama Vision", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔍 Multi-Image OCR with Manual Correction")
    gr.Markdown("Upload your images, extract the text automatically, then correct it if needed.")

    with gr.Row():
        uploaded = gr.Files(
            file_types=[".png", ".jpg", ".jpeg", ".tif"],
            label="📁 Upload multiple images",
            file_count="multiple"
        )

    btn_ocr = gr.Button("🚀 Extract OCR text", variant="primary", size="lg")
    status = gr.Textbox(label="📊 Status", lines=3, visible=False)

    # Result containers (fixed, not created dynamically)
    results_data = gr.State([])

    with gr.Column(visible=False) as results_section:
        gr.Markdown("## 📝 OCR Results - You can edit the text below")

        # Fixed layout for up to 5 images (adjust to your needs)
        image_components = []
        text_components = []

        for i in range(5):  # Maximum 5 images
            with gr.Row(visible=False) as row:
                with gr.Column(scale=1):
                    img_comp = gr.Image(label=f"Image {i+1}", height=300)
                    image_components.append((row, img_comp))
                with gr.Column(scale=2):
                    txt_comp = gr.Textbox(
                        label=f"Extracted text - Image {i+1}",
                        lines=10,
                        placeholder="The extracted text will appear here..."
                    )
                    text_components.append(txt_comp)

        btn_save = gr.Button("💾 Save all corrections", variant="secondary", size="lg")
        save_status = gr.Textbox(label="💾 Save status", visible=False)
        download_file = gr.File(label="📥 Download corrections", visible=False)

    def process_images(images):
        if not images:
            return (
                gr.update(visible=True, value="❌ No images uploaded."),
                gr.update(visible=False),
                gr.update(visible=False),
                [],
                # 10 updates: one per row and one per image component
                *[gr.update(visible=False) for _ in range(10)],
                *[gr.update(value="") for _ in range(5)]
            )

        results, status_text = batch_ocr(images)

        # Update the image and text components
        image_updates = []
        text_updates = []

        for i in range(5):
            if i < len(results):
                # Show the row and fill in the image and its extracted text
                image_updates.append(gr.update(visible=True))
                image_updates.append(gr.update(value=results[i]["image"]))
                text_updates.append(gr.update(value=results[i]["ocr_text"]))
            else:
                # Hide the unused components
                image_updates.append(gr.update(visible=False))
                image_updates.append(gr.update(value=None))
                text_updates.append(gr.update(value=""))

        return (
            gr.update(visible=True, value=status_text),
            gr.update(visible=True),
            gr.update(visible=True),
            results,
            *image_updates,
            *text_updates
        )

    # Flatten the (row, image) pairs into the output list expected by process_images
    image_outputs = []
    for row, img in image_components:
        image_outputs.extend([row, img])

    btn_ocr.click(
        process_images,
        inputs=[uploaded],
        outputs=[
            status,
            results_section,
            save_status,
            results_data,
            *image_outputs,
            *text_components
        ]
    )

    btn_save.click(
        save_all_corrections,
        inputs=[results_data] + text_components,
        outputs=[save_status, download_file]  # the function returns two values
    )

if __name__ == "__main__":
    demo.launch(debug=True)
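
# --- Usage sketch (illustrative, not part of the app) ---
# A minimal way to exercise ocr_on_image directly, e.g. to debug prompts
# without launching the Gradio UI. "sample.png" is a hypothetical local file.
#
#     from PIL import Image
#     img = Image.open("sample.png").convert("RGB")
#     print(ocr_on_image(img))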