Commit 0f4b503 · update readme
Parent(s): 579d11f

Files changed:
- README.md (+1 / -1)
- chat_gradio_demo.py → app.py (renamed, +77 / -30)
README.md CHANGED

@@ -17,7 +17,7 @@ The **joint training** of visual and language instructions effectively improves
 To install the package in an existing environment, run
 
 ```bash
-git clone
+git clone https://github.com/open-mmlab/Multimodal-GPT.git
 pip install -r requirements.txt
 pip install -e. -v
 ```
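After the editable install, a quick import check confirms the environment can see the package and the demo's core dependencies. This is a minimal sketch; the import name `mmgpt` is an assumption, since the diff never states how the package is exposed.

```python
# Hypothetical post-install smoke test; `mmgpt` as the import name is an
# assumption, while gradio and torch are clearly used by app.py below.
import importlib.util

for name in ("mmgpt", "gradio", "torch"):
    found = importlib.util.find_spec(name) is not None
    print(f"{name}: {'ok' if found else 'missing'}")
```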
chat_gradio_demo.py → app.py RENAMED

@@ -12,12 +12,16 @@ Prompt_Tutorial = "Model Inputs = {Prompt}({seperator}Image:\n<image> if image u
 
 
 class Inferencer:
+
     def __init__(self, finetune_path, llama_path, open_flamingo_path):
         ckpt = torch.load(finetune_path, map_location="cpu")
         if "model_state_dict" in ckpt:
             state_dict = ckpt["model_state_dict"]
             # remove the "module." prefix
-            state_dict = {
+            state_dict = {
+                k[7:]: v
+                for k, v in state_dict.items() if k.startswith("module.")
+            }
         else:
             state_dict = ckpt
         tuning_config = ckpt.get("tuning_config")

@@ -35,16 +39,21 @@ class Inferencer:
             tuning_config=tuning_config,
         )
         model.load_state_dict(state_dict, strict=False)
+        model.half()
         model = model.to("cuda")
+        model.eval()
         tokenizer.padding_side = "left"
         tokenizer.add_eos_token = False
         self.model = model
         self.image_processor = image_processor
         self.tokenizer = tokenizer
 
-    def __call__(self, prompt, imgpaths, max_new_token, num_beams, temperature,
+    def __call__(self, prompt, imgpaths, max_new_token, num_beams, temperature,
+                 top_k, top_p, do_sample):
         if len(imgpaths) > 1:
-            raise gr.Error(
+            raise gr.Error(
+                "Current only support one image, please clear gallery and upload one image"
+            )
         lang_x = self.tokenizer([prompt], return_tensors="pt")
         if len(imgpaths) == 0 or imgpaths is None:
             for layer in self.model.lang_encoder._get_decoder_layers():

@@ -65,7 +74,7 @@ class Inferencer:
             images = (Image.open(fp) for fp in imgpaths)
             vision_x = [self.image_processor(im).unsqueeze(0) for im in images]
             vision_x = torch.cat(vision_x, dim=0)
-            vision_x = vision_x.unsqueeze(1).unsqueeze(0)
+            vision_x = vision_x.unsqueeze(1).unsqueeze(0).half()
 
         output_ids = self.model.generate(
             vision_x=vision_x.cuda(),

@@ -78,13 +87,15 @@ class Inferencer:
             top_p=top_p,
             do_sample=do_sample,
         )[0]
-        generated_text = self.tokenizer.decode(
+        generated_text = self.tokenizer.decode(
+            output_ids, skip_special_tokens=True)
         # print(generated_text)
         result = generated_text.split(response_split)[-1].strip()
         return result
 
 
 class PromptGenerator:
+
     def __init__(
         self,
         prompt_template=TEMPLATE,
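The new `__init__` strips the `module.` prefix that DistributedDataParallel prepends to parameter names when a checkpoint is saved from the wrapped model, so the keys match the unwrapped model before `load_state_dict(strict=False)`. A standalone sketch of the same pattern, using a toy state dict rather than a real checkpoint:

```python
# Toy stand-in for a checkpoint saved from a DDP-wrapped model; the key
# names here are hypothetical, real ones come from torch.load(finetune_path).
state_dict = {
    "module.encoder.weight": 1.0,
    "module.decoder.bias": 2.0,
}

# Same normalization as the new __init__: drop the 7-character "module."
# prefix. Note that keys without the prefix are discarded by the filter.
state_dict = {
    k[7:]: v
    for k, v in state_dict.items() if k.startswith("module.")
}
print(sorted(state_dict))  # ['decoder.bias', 'encoder.weight']
```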
@@ -106,7 +117,7 @@ class PromptGenerator:
     def get_images(self):
         img_list = list()
         if self.buffer_size > 0:
-            all_history = self.all_history[-2 * (self.buffer_size + 1)
+            all_history = self.all_history[-2 * (self.buffer_size + 1):]
         elif self.buffer_size == 0:
             all_history = self.all_history[-2:]
         else:

@@ -125,7 +136,7 @@ class PromptGenerator:
         prompt_template = self.prompt_template.format(**format_dict)
         ret = prompt_template
         if self.buffer_size > 0:
-            all_history = self.all_history[-2 * (self.buffer_size + 1)
+            all_history = self.all_history[-2 * (self.buffer_size + 1):]
         elif self.buffer_size == 0:
             all_history = self.all_history[-2:]
         else:

@@ -134,9 +145,11 @@ class PromptGenerator:
         have_image = False
         for role, message in all_history[::-1]:
             if message:
-                if type(message) is tuple and message[
+                if type(message) is tuple and message[
+                        1] is not None and not have_image:
                     message, _ = message
-                    context.append(self.sep + "Image:\n<image>" + self.sep +
+                    context.append(self.sep + "Image:\n<image>" + self.sep +
+                                   role + ":\n" + message)
                 else:
                     context.append(self.sep + role + ":\n" + message)
             else:

@@ -162,7 +175,8 @@ def to_gradio_chatbot(prompt_generator):
             max_hw, min_hw = max(image.size), min(image.size)
             aspect_ratio = max_hw / min_hw
             max_len, min_len = 800, 400
-            shortest_edge = int(
+            shortest_edge = int(
+                min(max_len / aspect_ratio, min_len, min_hw))
             longest_edge = int(shortest_edge * aspect_ratio)
             H, W = image.size
             if H > W:
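Both `get_images` and the prompt builder now close the history slice as `self.all_history[-2 * (self.buffer_size + 1):]`, which keeps `buffer_size` previous (user, response) pairs plus the current pair; `buffer_size == 0` keeps only the current pair. A small sketch of that arithmetic on a fabricated history list (the two-entries-per-turn layout and the role names are assumptions, matching the demo's default prefixes):

```python
# Illustrative history: one (role, message) entry per side of each turn.
all_history = [
    ["Instruction", "q1"], ["Response", "a1"],
    ["Instruction", "q2"], ["Response", "a2"],
    ["Instruction", "q3"], ["Response", None],  # current, unanswered turn
]

buffer_size = 1  # keep one previous turn plus the current one
kept = all_history[-2 * (buffer_size + 1):]
print(len(kept))      # 4 entries, i.e. the q2/a2 turn and the q3 turn
print(kept[0])        # ['Instruction', 'q2']
```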
@@ -211,10 +225,12 @@ def bot(
     inputs = state.get_prompt()
     image_paths = state.get_images()[-1:]
 
-    inference_results = inferencer(inputs, image_paths, max_new_token,
+    inference_results = inferencer(inputs, image_paths, max_new_token,
+                                   num_beams, temperature, top_k, top_p,
+                                   do_sample)
     state.all_history[-1][-1] = inference_results
-
-    return state, to_gradio_chatbot(state), "", None, inputs
+    memory_allocated = str(torch.cuda.memory_allocated() / 1024**3) + 'GB'
+    return state, to_gradio_chatbot(state), "", None, inputs, memory_allocated
 
 
 def clear(state):

@@ -222,41 +238,57 @@ def clear(state):
     return state, to_gradio_chatbot(state), "", None, ""
 
 
+title_markdown = ("""
+# 🤖 Multi-modal GPT
+[[Project]](https://github.com/open-mmlab/Multimodal-GPT.git)""")
+
+
 def build_conversation_demo():
-    with gr.Blocks() as demo:
+    with gr.Blocks(title="Multi-modal GPT") as demo:
+        gr.Markdown(title_markdown)
+
         state = gr.State(PromptGenerator())
         with gr.Row():
             with gr.Column(scale=3):
+                memory_allocated = gr.Textbox(
+                    value=init_memory, label="Memory")
                 imagebox = gr.Image(type="filepath")
                 # TODO config parameters
                 with gr.Accordion(
-
-
+                        "Parameters",
+                        open=True,
                 ):
-                    max_new_token_bar = gr.Slider(
-
-
+                    max_new_token_bar = gr.Slider(
+                        0, 1024, 512, label="max_new_token", step=1)
+                    num_beams_bar = gr.Slider(
+                        0.0, 10, 3, label="num_beams", step=1)
+                    temperature_bar = gr.Slider(
+                        0.0, 1.0, 1.0, label="temperature", step=0.01)
                     topk_bar = gr.Slider(0, 100, 20, label="top_k", step=1)
                     topp_bar = gr.Slider(0, 1.0, 1.0, label="top_p", step=0.01)
                     do_sample = gr.Checkbox(True, label="do_sample")
                 with gr.Accordion(
-
-
+                        "Prompt",
+                        open=False,
                 ):
-                    with gr.Accordion(
+                    with gr.Accordion(
+                            "Click to hide the tutorial", open=False):
                         gr.Markdown(Prompt_Tutorial)
                     with gr.Row():
                         ai_prefix = gr.Text("Response", label="AI Prefix")
-                        user_prefix = gr.Text(
+                        user_prefix = gr.Text(
+                            "Instruction", label="User Prefix")
                         seperator = gr.Text("\n\n### ", label="Seperator")
-                        history_buffer = gr.Slider(
+                        history_buffer = gr.Slider(
+                            -1, 10, -1, label="History buffer", step=1)
                     prompt = gr.Text(TEMPLATE, label="Prompt")
                     model_inputs = gr.Textbox(label="Actual inputs for Model")
 
             with gr.Column(scale=6):
                 with gr.Row():
                     with gr.Column():
-                        chatbot = gr.Chatbot(elem_id="chatbot").style(
+                        chatbot = gr.Chatbot(elem_id="chatbot").style(
+                            height=750)
                 with gr.Row():
                     with gr.Column(scale=8):
                         textbox = gr.Textbox(

@@ -268,7 +300,10 @@ def build_conversation_demo():
         cur_dir = os.path.dirname(os.path.abspath(__file__))
         gr.Examples(
             examples=[
-                [
+                [
+                    f"{cur_dir}/docs/images/demo_image.jpg",
+                    "What is in this image?"
+                ],
             ],
             inputs=[imagebox, textbox],
         )

@@ -290,7 +325,10 @@ def build_conversation_demo():
                 topp_bar,
                 do_sample,
             ],
-            [
+            [
+                state, chatbot, textbox, imagebox, model_inputs,
+                memory_allocated
+            ],
         )
         submit_btn.click(
             bot,

@@ -310,9 +348,13 @@ def build_conversation_demo():
                 topp_bar,
                 do_sample,
             ],
-            [
+            [
+                state, chatbot, textbox, imagebox, model_inputs,
+                memory_allocated
+            ],
         )
-        clear_btn.click(clear, [state],
+        clear_btn.click(clear, [state],
+                        [state, chatbot, textbox, imagebox, model_inputs])
     return demo
 
 

@@ -320,7 +362,12 @@ if __name__ == "__main__":
     llama_path = "checkpoints/llama-7b_hf"
     open_flamingo_path = "checkpoints/OpenFlamingo-9B/checkpoint.pt"
     finetune_path = "checkpoints/mmgpt-lora-v0-release.pt"
-
+
+    inferencer = Inferencer(
+        llama_path=llama_path,
+        open_flamingo_path=open_flamingo_path,
+        finetune_path=finetune_path)
+    init_memory = str(torch.cuda.memory_allocated() / 1024**3) + 'GB'
     demo = build_conversation_demo()
     demo.queue(concurrency_count=3)
     IP = "0.0.0.0"
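The `bot` function and the new `Memory` textbox surface allocated GPU memory, formatted as `str(torch.cuda.memory_allocated() / 1024**3) + 'GB'`, with the initial value captured right after the model is built. A minimal sketch of that readout; the CPU-only guard is an addition for convenience here, not part of the demo:

```python
import torch


def memory_readout() -> str:
    """Allocated CUDA memory in GiB, formatted like the demo's Memory box."""
    if not torch.cuda.is_available():
        return "0.0GB"  # guard for CPU-only machines; the demo assumes CUDA
    return str(torch.cuda.memory_allocated() / 1024**3) + 'GB'


print(memory_readout())
```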