yuki-imajuku committed on
Commit eb4186d · 0 Parent(s):

initial commit

Files changed (4)
  1. .gitattributes +35 -0
  2. README.md +14 -0
  3. app.py +218 -0
  4. requirements.txt +4 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: MangaLMM Demo
+ emoji: 📚
+ colorFrom: purple
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 5.30.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ short_description: The official demo of MangaLMM
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,218 @@
+ # Install FlashAttention
+ import subprocess
+ subprocess.run(
+     "pip install flash-attn --no-build-isolation",
+     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+     shell=True,
+ )
+
+ import base64
+ from collections import Counter
+ from io import BytesIO
+ import re
+
+ from PIL import Image, ImageDraw
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
+ from qwen_vl_utils import process_vision_info, smart_resize
+
+
+ repo_id = "hal-utokyo/MangaLMM"
+ processor = Qwen2_5_VLProcessor.from_pretrained(repo_id)
+
+
+ def pil2base64(image: Image.Image) -> str:
+     buffered = BytesIO()
+     image.save(buffered, format="PNG")
+     return base64.b64encode(buffered.getvalue()).decode()
+
+
+ def bbox2d_to_quad(bbox_2d):
+     xmin, ymin, xmax, ymax = bbox_2d
+     return [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]
+
+
+ def normalize_repeated_symbols(text):
+     text = re.sub(r'([~\~\〜\-\ー]+)', lambda m: m.group(1)[0], text)
+     text = re.sub(r'[~~〜]', '~', text)
+     text = re.sub(r'[-ー]', '-', text)
+     return text
+
+
+ def normalize_punctuation(text):
+     conversion_map = {
+         "！": "!",
+         "？": "?",
+         "…": "..."
+     }
+     text = re.sub("|".join(map(re.escape, conversion_map.keys())), lambda m: conversion_map[m.group()], text)
+     text = re.sub(r'[・・.]', '・', text)
+     return text
+
+
+ def restore_chouon(text):
+     # hiragana + katakana + kanji
+     # jp_range = r"ぁ-んァ-ン一-龯㐀-䶵" # \u3400-\u4DBF = r"㐀-䶵"
+     # Extended Unicode version: covers Hiragana, Katakana, and a wide range of Kanji (including Extension A)
+     jp_range = r"\u3040-\u309F\u30A0-\u30FF\u3400-\u4DBF\u4E00-\u9FFF"
+     pattern = rf"(?<=[{jp_range}])-(?=[{jp_range}])"
+     return re.sub(pattern, "ー", text)
+
+
+ def process_text(text: str) -> str:
+     text = re.sub(r"[\s\u3000]+", "", text)
+     text = normalize_repeated_symbols(text)
+     text = normalize_punctuation(text)
+     text = restore_chouon(text)
+     return text
+
+
+ def parse_ocr_text(text: str) -> list[list]:
+     if not text.strip():
+         return []
+     # handle escape
+     text = text.replace('\\"', '"')
+     # find \n\t{ ... } blocks
+     blocks = re.findall(r"\n\t\{.*?\}", text, re.DOTALL)
+     # extract OCR text and bounding box
+     ocrs = []
+     for block in blocks:
+         block = block.strip()  # remove \n\t
+         bbox_match = re.search(r'"bbox_2d"\s*:\s*\[([^\]]+)\]', block, flags=re.DOTALL)
+         text_match = re.search(
+             r'"text_content"\s*:\s*"([^"]*)"', block, flags=re.DOTALL
+         )
+
+         if bbox_match and text_match:
+             try:
+                 bbox_list = [int(x.strip()) for x in bbox_match.group(1).split(",")]
+                 content = process_text(text_match.group(1))
+                 quad = bbox2d_to_quad(bbox_list)
+                 ocrs.append([content, quad])
+             except Exception:  # skip malformed bbox or text blocks
+                 continue
+     # drop texts repeated 10 or more times (the model sometimes loops and emits the same text)
+     counter = Counter([ocr[0] for ocr in ocrs])
+     ocrs = [ocr for ocr in ocrs if counter[ocr[0]] < 10]
+     return ocrs
+
+
+ @spaces.GPU
+ @torch.inference_mode()
+ def inference_fn(
+     image: Image.Image | None,
+     text: str | None,
+     # progress=gr.Progress(track_tqdm=True),
+ ) -> tuple[str, str, Image.Image | None]:
+     if image is None:
+         gr.Warning("Please upload an image!", duration=10)
+         return "Please upload an image!", "Please upload an image!", None
+     if image.width * image.height > 2116800:
+         gr.Warning("The image is too large! It will be resized to a smaller size.", duration=10)
+         resized_height, resized_width = smart_resize(
+             height=image.height,
+             width=image.width,
+             factor=processor.image_processor.patch_size * processor.image_processor.merge_size,
+             min_pixels=processor.image_processor.min_pixels,
+             max_pixels=processor.image_processor.max_pixels,
+         )
+         image = image.resize((resized_width, resized_height), resample=Image.Resampling.BICUBIC)
+     if text is None or text.strip() == "":
+         # OCR
+         text = "Please perform OCR on this image and output the recognized Japanese text along with its position (grounding)."
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+         repo_id,
+         torch_dtype=torch.bfloat16,
+         attn_implementation="flash_attention_2",
+         device_map=device,
+     )
+
+     base64_image = pil2base64(image)
+     messages = [
+         {"role": "user", "content": [
+             {"type": "image", "image": f"data:image;base64,{base64_image}"},
+             {"type": "text", "text": text},
+         ]},
+     ]
+     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     )
+     inputs = inputs.to(model.device)
+
+     generated_ids = model.generate(**inputs, max_new_tokens=4096)
+     generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+     raw_output = processor.batch_decode(
+         generated_ids_trimmed,
+         skip_special_tokens=True,
+         clean_up_tokenization_spaces=False,
+     )[0]
+     result_image = image_inputs[0].copy()
+
+     ocrs = parse_ocr_text(raw_output)
+     if not ocrs:
+         return raw_output, "OCR was not performed.", result_image
+
+     draw = ImageDraw.Draw(result_image)
+     ocr_texts = []
+     for ocr_text, quad in ocrs:
+         ocr_texts.append(f'{ocr_text} ({quad[0]}, {quad[1]}, {quad[4]}, {quad[5]})')
+         for i in range(4):
+             start_point = quad[i*2:i*2+2]
+             end_point = quad[i*2+2:i*2+4] if i < 3 else quad[:2]
+             draw.line(start_point + end_point, fill="red", width=4)
+         draw.polygon(quad, outline="red", width=4)
+         # draw.text((quad[0], quad[1]), ocr_text, fill="red")
+     ocr_texts_str = "\n".join(ocr_texts)
+     return "No question was entered.", ocr_texts_str, result_image
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("""# MangaLMM Official Demo
+
+ ![GitHub Repo](https://img.shields.io/badge/repo-manga109%2FMangaLMM-9E95B7?logo=refinedgithub)
+
+ We propose MangaVQA, a benchmark, and MangaLMM, a specialized LMM, for multimodal manga understanding.
+
+ This demo uses our [MangaLMM model](https://huggingface.co/hal-utokyo/MangaLMM) to perform OCR on an image of manga panels and answer a question about the image.
+
+ Please ensure that the image contains fewer than 2116800 pixels (e.g. 1600x1200 or 1920x1080). Larger images are automatically resized to a smaller size.
+
+ *Note: This model is for research purposes only and may return incorrect results. Please use it at your own risk.*
+ """)
+     with gr.Row():
+         with gr.Column():
+             input_button = gr.Button(value="Submit")
+             input_text = gr.Textbox(
+                 label="Input Text", lines=5, max_lines=5,
+                 placeholder="Please enter a question about your image.\nLeave this empty to perform OCR.",
+             )
+             input_image = gr.Image(label="Input Image", image_mode="RGB", type="pil")
+         with gr.Column():
+             vqa_text = gr.Textbox(label="VQA Result", lines=2, max_lines=2)
+             ocr_text = gr.Textbox(label="OCR Result", lines=3, max_lines=3)
+             ocr_image = gr.Image(label="OCR Result", type="pil", show_label=False)
+
+     input_button.click(
+         fn=inference_fn,
+         inputs=[input_image, input_text],
+         outputs=[vqa_text, ocr_text, ocr_image],
+     )
+     ocr_examples = gr.Examples(
+         examples=[],
+         fn=inference_fn,
+         inputs=[input_image, input_text],
+         outputs=[vqa_text, ocr_text, ocr_image],
+         cache_examples=False,
+     )
+
+ demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ accelerate==1.7.0
+ qwen-vl-utils==0.0.11
+ torchvision==0.20.1 --extra-index-url https://download.pytorch.org/whl/cu121
+ transformers @ git+https://github.com/huggingface/transformers@6b550462139655d488d4c663086a63e98713c6b9
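For reference, below is a minimal standalone sketch of the OCR call that `app.py` wraps in Gradio/Spaces, assuming the packages from `requirements.txt` (plus `torch`) are installed. It mirrors `inference_fn`: same `repo_id`, processor, chat template, and OCR prompt. The local file `page.png` is a hypothetical example image, and the `attn_implementation="flash_attention_2"` setting from `app.py` is omitted so the sketch also runs without flash-attn.

```python
# Minimal sketch of the OCR path in app.py, run outside the Gradio demo.
import torch
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
from qwen_vl_utils import process_vision_info

repo_id = "hal-utokyo/MangaLMM"
processor = Qwen2_5_VLProcessor.from_pretrained(repo_id)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda" if torch.cuda.is_available() else "cpu",
)

# Same OCR prompt that app.py uses when the text box is left empty.
prompt = (
    "Please perform OCR on this image and output the recognized Japanese text "
    "along with its position (grounding)."
)
messages = [
    {"role": "user", "content": [
        {"type": "image", "image": Image.open("page.png")},  # hypothetical local manga page
        {"type": "text", "text": prompt},
    ]},
]

text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to(model.device)

with torch.inference_mode():
    generated_ids = model.generate(**inputs, max_new_tokens=4096)
trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
print(processor.batch_decode(trimmed, skip_special_tokens=True)[0])
```

The raw output printed here is what `app.py` feeds into `parse_ocr_text` to obtain `[text, quad]` pairs for drawing the red bounding boxes.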