# NOTE: `spaces` is kept as the very first import on purpose: Hugging Face
# ZeroGPU expects it to be imported before CUDA-touching libraries (torch).
import spaces

import base64
import json
import math
import os
import re
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import fitz  # PyMuPDF
import gradio as gr
import requests
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForCausalLM, AutoProcessor

# Constants
MIN_PIXELS = 3136
MAX_PIXELS = 11289600
IMAGE_FACTOR = 28

# Seconds to wait for a remote image before giving up.
REQUEST_TIMEOUT = 30

# File extensions accepted as single-page images by the preview loader.
SUPPORTED_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')

# Outline / label colour per layout category (unknown categories use black).
CATEGORY_COLORS = {
    'Caption': '#FF6B6B',
    'Footnote': '#4ECDC4',
    'Formula': '#45B7D1',
    'List-item': '#96CEB4',
    'Page-footer': '#FFEAA7',
    'Page-header': '#DDA0DD',
    'Picture': '#FFD93D',
    'Section-header': '#6C5CE7',
    'Table': '#FD79A8',
    'Text': '#74B9FF',
    'Title': '#E17055',
}

LABEL_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"

# Markdown line classifiers used by `is_arabic_text`; compiled once.
_HEADER_RE = re.compile(r'^#{1,6}\s+(.+)$')
_PARAGRAPH_RE = re.compile(
    r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
)

# Unicode blocks that hold Arabic-script letters: Arabic, Arabic Supplement,
# Arabic Extended-A, and the two Presentation Forms blocks.
_ARABIC_RANGES = (
    (0x0600, 0x06FF),
    (0x0750, 0x077F),
    (0x08A0, 0x08FF),
    (0xFB50, 0xFDFF),
    (0xFE70, 0xFEFF),
)

# Prompts
prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]

2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].

3. Text Extraction & Formatting Rules:
    - Picture: For the 'Picture' category, the text field should be omitted.
    - Formula: Format its text as LaTeX.
    - Table: Format its text as HTML.
    - All Others (Text, Title, etc.): Format their text as Markdown.

4. Constraints:
    - The output text must be the original text from the image, with no translation.
    - All layout elements must be sorted according to human reading order.

5. Final Output: The entire output must be a single JSON object.
"""


# Utility functions
def round_by_factor(number: float, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def floor_by_factor(number: float, factor: int) -> int:
    """Returns the largest multiple of 'factor' that is <= 'number'."""
    return math.floor(number / factor) * factor


def ceil_by_factor(number: float, factor: int) -> int:
    """Returns the smallest multiple of 'factor' that is >= 'number'."""
    return math.ceil(number / factor) * factor


def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600,
) -> Tuple[int, int]:
    """Compute a target (height, width) for rescaling an image.

    The result satisfies:

    1. Both dimensions are divisible by 'factor'.
    2. The total number of pixels is within ['min_pixels', 'max_pixels'].
    3. The aspect ratio is maintained as closely as possible.

    Raises:
        ValueError: if a dimension is not positive, or the aspect ratio
            exceeds 200.
    """
    if height <= 0 or width <= 0:
        raise ValueError(
            f"height and width must be positive, got {height}x{width}"
        )

    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {aspect_ratio}"
        )

    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))

    if h_bar * w_bar > max_pixels:
        # Shrink: round *down* so the product can never exceed max_pixels.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        # Grow: round *up* so the product can never fall below min_pixels.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)

    return h_bar, w_bar


def fetch_image(
    image_input,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None,
) -> Image.Image:
    """Load an image as RGB and optionally rescale it with `smart_resize`.

    Args:
        image_input: an http(s) URL, a local file path, or a PIL image.
        min_pixels: lower pixel bound; defaults to MIN_PIXELS when only
            'max_pixels' is given.
        max_pixels: upper pixel bound; defaults to MAX_PIXELS when only
            'min_pixels' is given.

    Raises:
        ValueError: if 'image_input' is neither a string nor a PIL image.
        requests.RequestException: if a remote image cannot be downloaded.
    """
    if isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")):
            response = requests.get(image_input, timeout=REQUEST_TIMEOUT)
            # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            # convert() loads the pixel data, so the file can be closed here.
            with Image.open(image_input) as opened:
                image = opened.convert('RGB')
    elif isinstance(image_input, Image.Image):
        image = image_input.convert('RGB')
    else:
        raise ValueError(f"Invalid image input type: {type(image_input)}")

    if min_pixels is not None or max_pixels is not None:
        min_pixels = min_pixels or MIN_PIXELS
        max_pixels = max_pixels or MAX_PIXELS
        height, width = smart_resize(
            image.height,
            image.width,
            factor=IMAGE_FACTOR,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
        image = image.resize((width, height), Image.LANCZOS)

    return image


def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
    """Render every page of a PDF to an RGB image.

    Returns an empty list (after printing the error) if the file cannot be
    opened or any page fails to render.
    """
    images = []
    try:
        # The context manager closes the document even if rendering fails.
        with fitz.open(pdf_path) as pdf_document:
            mat = fitz.Matrix(2.0, 2.0)  # 2x zoom for higher resolution
            for page_num in range(len(pdf_document)):
                page = pdf_document.load_page(page_num)
                pix = page.get_pixmap(matrix=mat)
                img_data = pix.tobytes("ppm")
                images.append(Image.open(BytesIO(img_data)).convert('RGB'))
    except Exception as e:
        print(f"Error loading PDF: {e}")
        return []
    return images


def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
    """Return a copy of 'image' with labelled bounding boxes drawn on it.

    Items lacking 'bbox' or 'category' are skipped. A failure while drawing
    one item is printed and does not prevent the remaining items from being
    drawn.
    """
    img_copy = image.copy()
    draw = ImageDraw.Draw(img_copy)

    try:
        try:
            font = ImageFont.truetype(LABEL_FONT_PATH, 12)
        except Exception:
            font = ImageFont.load_default()

        for item in layout_data:
            if not isinstance(item, dict):
                continue
            if 'bbox' not in item or 'category' not in item:
                continue

            try:
                bbox = item['bbox']
                category = item['category']
                color = CATEGORY_COLORS.get(category, '#000000')

                # Draw rectangle
                draw.rectangle(bbox, outline=color, width=2)

                # Measure the label so its background fits the text.
                label = category
                label_bbox = draw.textbbox((0, 0), label, font=font)
                label_width = label_bbox[2] - label_bbox[0]
                label_height = label_bbox[3] - label_bbox[1]

                # Position label above the box, clamped to the image top.
                label_x = bbox[0]
                label_y = max(0, bbox[1] - label_height - 2)

                draw.rectangle(
                    [
                        label_x,
                        label_y,
                        label_x + label_width + 4,
                        label_y + label_height + 2,
                    ],
                    fill=color,
                )
                draw.text((label_x + 2, label_y + 1), label, fill='white', font=font)
            except Exception as e:
                print(f"Error drawing layout: {e}")
    except Exception as e:
        # e.g. 'layout_data' is not iterable
        print(f"Error drawing layout: {e}")

    return img_copy


def _is_arabic_char(char: str) -> bool:
    """Return True if 'char' lies in one of the Arabic-script Unicode blocks."""
    code_point = ord(char)
    return any(low <= code_point <= high for low, high in _ARABIC_RANGES)


def is_arabic_text(text: str) -> bool:
    """Check if text in headers and paragraphs contains mostly Arabic characters.

    Only markdown headers and plain paragraph lines are considered; list
    items, table rows, code fences and image lines are ignored. Returns True
    when more than 50% of the alphabetic characters in those lines are Arabic.
    """
    if not text:
        return False

    content_text = []
    for line in text.split('\n'):
        line = line.strip()
        if not line:
            continue

        header_match = _HEADER_RE.match(line)
        if header_match:
            # Keep only the header text, without the leading '#' marks.
            content_text.append(header_match.group(1))
            continue

        if _PARAGRAPH_RE.match(line):
            content_text.append(line)

    if not content_text:
        return False

    letters = [char for char in ' '.join(content_text) if char.isalpha()]
    if not letters:
        return False

    arabic_chars = sum(1 for char in letters if _is_arabic_char(char))
    return (arabic_chars / len(letters)) > 0.5


def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
    """Convert layout JSON to markdown format.

    'Picture' items are cropped out of 'image' and embedded as base64 PNG
    data URIs. Page headers/footers and items with empty text are dropped.
    If anything goes wrong, the error is printed and 'str(layout_data)' is
    returned instead.
    """
    markdown_lines = []

    try:
        # Sort items top to bottom, then left to right.
        # NOTE(review): the prompt already asks the model for reading order;
        # this positional sort may interleave multi-column pages - confirm
        # it is wanted.
        sorted_items = sorted(
            layout_data,
            key=lambda x: (
                x.get('bbox', [0, 0, 0, 0])[1],
                x.get('bbox', [0, 0, 0, 0])[0],
            ),
        )

        for item in sorted_items:
            category = item.get('category', '')
            text = item.get(text_key, '')
            bbox = item.get('bbox', [])

            if category == 'Picture':
                if bbox and len(bbox) == 4:
                    try:
                        x1, y1, x2, y2 = bbox
                        # Clamp the crop rectangle to the image bounds.
                        x1, y1 = max(0, int(x1)), max(0, int(y1))
                        x2, y2 = min(image.width, int(x2)), min(image.height, int(y2))

                        if x2 > x1 and y2 > y1:
                            cropped_img = image.crop((x1, y1, x2, y2))

                            buffer = BytesIO()
                            cropped_img.save(buffer, format='PNG')
                            img_data = base64.b64encode(buffer.getvalue()).decode()

                            # Embed the crop inline as a data-URI image.
                            markdown_lines.append(
                                f"![Image](data:image/png;base64,{img_data})\n"
                            )
                        else:
                            markdown_lines.append("\n")
                    except Exception as e:
                        print(f"Error processing image region: {e}")
                        markdown_lines.append("\n")
                else:
                    markdown_lines.append("\n")
            elif not text:
                continue
            elif category == 'Title':
                markdown_lines.append(f"# {text}\n")
            elif category == 'Section-header':
                markdown_lines.append(f"## {text}\n")
            elif category == 'Text':
                markdown_lines.append(f"{text}\n")
            elif category == 'List-item':
                markdown_lines.append(f"- {text}\n")
            elif category == 'Table':
                # HTML tables are passed through untouched.
                if text.strip().startswith('<'):
                    markdown_lines.append(f"{text}\n")
                else:
                    markdown_lines.append(f"**Table:** {text}\n")
            elif category == 'Formula':
                # Wrap LaTeX-looking text in a display-math block.
                if text.strip().startswith('$') or '\\' in text:
                    markdown_lines.append(f"$$\n{text}\n$$\n")
                else:
                    markdown_lines.append(f"**Formula:** {text}\n")
            elif category == 'Caption':
                markdown_lines.append(f"*{text}*\n")
            elif category == 'Footnote':
                markdown_lines.append(f"^{text}^\n")
            elif category in ['Page-header', 'Page-footer']:
                # Skip headers and footers in main content
                continue
            else:
                markdown_lines.append(f"{text}\n")

            markdown_lines.append("")  # Add spacing
    except Exception as e:
        print(f"Error converting to markdown: {e}")
        return str(layout_data)

    return "\n".join(markdown_lines)


# Initialize model and processor at script level
model_id = "rednote-hilab/dots.ocr"
model_path = "./models/dots-ocr-local"

snapshot_download(
    repo_id=model_id,
    local_dir=model_path,
    local_dir_use_symlinks=False,  # Recommended to set to False to avoid symlink issues
)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(
    model_path,
    trust_remote_code=True,
)

# Global state variables
device = "cuda" if torch.cuda.is_available() else "cpu"

# PDF handling state
pdf_cache = {
    "images": [],
    "current_page": 0,
    "total_pages": 0,
    "file_type": None,
    "is_parsed": False,
    "results": [],
}


@spaces.GPU()
def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
    """Run inference on an image with the given prompt.

    Returns the decoded model output. Any failure is printed with its
    traceback and reported as an "Error during inference: ..." string
    rather than raised.
    """
    try:
        if model is None or processor is None:
            raise RuntimeError("Model not loaded. Please check model initialization.")

        # Prepare messages in the expected format
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Apply chat template
        text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Process vision information
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        # Greedy decoding: no sampling temperature is passed because it is
        # ignored when do_sample=False.
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

        # Strip the prompt tokens so only newly generated tokens are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        return output_text[0] if output_text else ""

    except Exception as e:
        print(f"Error during inference: {e}")
        traceback.print_exc()
        return f"Error during inference: {str(e)}"


def process_image(
    image: Image.Image,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None,
) -> Dict[str, Any]:
    """Run layout analysis on a single image using the module-level prompt.

    Returns a dict with keys 'original_image', 'raw_output',
    'processed_image', 'layout_result' and 'markdown_content'. When the model
    output is not valid JSON, 'layout_result' stays None and
    'markdown_content' is the raw output. Errors are reported inside the
    returned dict rather than raised.
    """
    try:
        # Resize image if needed
        if min_pixels is not None or max_pixels is not None:
            image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)

        raw_output = inference(image, prompt)

        result = {
            'original_image': image,
            'raw_output': raw_output,
            'processed_image': image,
            'layout_result': None,
            'markdown_content': None,
        }

        try:
            layout_data = json.loads(raw_output)
        except json.JSONDecodeError:
            print("Failed to parse JSON output, using raw output")
            result['markdown_content'] = raw_output
            return result

        result['layout_result'] = layout_data

        # Create visualization with bounding boxes
        try:
            result['processed_image'] = draw_layout_on_image(image, layout_data)
        except Exception as e:
            print(f"Error drawing layout: {e}")
            result['processed_image'] = image

        # Generate markdown from layout data
        try:
            result['markdown_content'] = layoutjson2md(image, layout_data, text_key='text')
        except Exception as e:
            print(f"Error generating markdown: {e}")
            result['markdown_content'] = raw_output

        return result

    except Exception as e:
        print(f"Error processing image: {e}")
        traceback.print_exc()
        return {
            'original_image': image,
            'raw_output': f"Error processing image: {str(e)}",
            'processed_image': image,
            'layout_result': None,
            'markdown_content': f"Error processing image: {str(e)}",
        }


def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
    """Load file for preview (supports PDF and images).

    On success the module-level 'pdf_cache' is reset for the new file and the
    first page plus a "Page i / n" label is returned. On failure returns
    (None, <message>).
    """
    if not file_path or not os.path.exists(file_path):
        return None, "No file selected"

    file_ext = os.path.splitext(file_path)[1].lower()

    try:
        if file_ext == '.pdf':
            images = load_images_from_pdf(file_path)
            if not images:
                return None, "Failed to load PDF"

            pdf_cache.update({
                "images": images,
                "current_page": 0,
                "total_pages": len(images),
                "file_type": "pdf",
                "is_parsed": False,
                "results": [],
            })
            return images[0], f"Page 1 / {len(images)}"

        if file_ext in SUPPORTED_IMAGE_EXTENSIONS:
            # convert() loads the pixel data, so the file can be closed here.
            with Image.open(file_path) as opened:
                image = opened.convert('RGB')

            pdf_cache.update({
                "images": [image],
                "current_page": 0,
                "total_pages": 1,
                "file_type": "image",
                "is_parsed": False,
                "results": [],
            })
            return image, "Page 1 / 1"

        return None, f"Unsupported file format: {file_ext}"

    except Exception as e:
        print(f"Error loading file: {e}")
        return None, f"Error loading file: {str(e)}"
turn_page(direction: str) -> Tuple[Optional[Image.Image], str, Any, Optional[Image.Image], Optional[Dict]]: """Navigate through PDF pages and update all relevant outputs.""" global pdf_cache if not pdf_cache["images"]: return None, '
A state-of-the-art image/pdf-to-markdown vision language model for intelligent document processing