Harold-lkk committed
Commit 0f4b503 · 1 Parent(s): 579d11f

update readme

Files changed (2)
  1. README.md +1 -1
  2. chat_gradio_demo.py → app.py +77 -30
README.md CHANGED
@@ -17,7 +17,7 @@ The **joint training** of visual and language instructions effectively improves
 To install the package in an existing environment, run
 
 ```bash
-git clone xxxxxx.git
+git clone https://github.com/open-mmlab/Multimodal-GPT.git
 pip install -r requirements.txt
 pip install -e. -v
 ```
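For reference, the install commands in the updated README are presumably meant to be run from inside the cloned repository; a minimal sketch of the full sequence (the `cd Multimodal-GPT` step is an assumption, not part of the README snippet):

```bash
git clone https://github.com/open-mmlab/Multimodal-GPT.git
cd Multimodal-GPT   # assumed: the pip commands below expect the repo root
pip install -r requirements.txt
pip install -e . -v
```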
chat_gradio_demo.py → app.py RENAMED
@@ -12,12 +12,16 @@ Prompt_Tutorial = "Model Inputs = {Prompt}({seperator}Image:\n<image> if image u
 
 
 class Inferencer:
+
     def __init__(self, finetune_path, llama_path, open_flamingo_path):
         ckpt = torch.load(finetune_path, map_location="cpu")
         if "model_state_dict" in ckpt:
             state_dict = ckpt["model_state_dict"]
             # remove the "module." prefix
-            state_dict = {k[7:]: v for k, v in state_dict.items() if k.startswith("module.")}
+            state_dict = {
+                k[7:]: v
+                for k, v in state_dict.items() if k.startswith("module.")
+            }
         else:
             state_dict = ckpt
         tuning_config = ckpt.get("tuning_config")
@@ -35,16 +39,21 @@ class Inferencer:
             tuning_config=tuning_config,
         )
         model.load_state_dict(state_dict, strict=False)
+        model.half()
         model = model.to("cuda")
+        model.eval()
         tokenizer.padding_side = "left"
         tokenizer.add_eos_token = False
         self.model = model
         self.image_processor = image_processor
         self.tokenizer = tokenizer
 
-    def __call__(self, prompt, imgpaths, max_new_token, num_beams, temperature, top_k, top_p, do_sample):
+    def __call__(self, prompt, imgpaths, max_new_token, num_beams, temperature,
+                 top_k, top_p, do_sample):
         if len(imgpaths) > 1:
-            raise gr.Error("Current only support one image, please clear gallery and upload one image")
+            raise gr.Error(
+                "Current only support one image, please clear gallery and upload one image"
+            )
         lang_x = self.tokenizer([prompt], return_tensors="pt")
         if len(imgpaths) == 0 or imgpaths is None:
             for layer in self.model.lang_encoder._get_decoder_layers():
@@ -65,7 +74,7 @@ class Inferencer:
             images = (Image.open(fp) for fp in imgpaths)
             vision_x = [self.image_processor(im).unsqueeze(0) for im in images]
             vision_x = torch.cat(vision_x, dim=0)
-            vision_x = vision_x.unsqueeze(1).unsqueeze(0)
+            vision_x = vision_x.unsqueeze(1).unsqueeze(0).half()
 
             output_ids = self.model.generate(
                 vision_x=vision_x.cuda(),
@@ -78,13 +87,15 @@ class Inferencer:
                 top_p=top_p,
                 do_sample=do_sample,
             )[0]
-        generated_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
+        generated_text = self.tokenizer.decode(
+            output_ids, skip_special_tokens=True)
         # print(generated_text)
         result = generated_text.split(response_split)[-1].strip()
         return result
 
 
 class PromptGenerator:
+
     def __init__(
         self,
         prompt_template=TEMPLATE,
@@ -106,7 +117,7 @@ class PromptGenerator:
     def get_images(self):
         img_list = list()
         if self.buffer_size > 0:
-            all_history = self.all_history[-2 * (self.buffer_size + 1) :]
+            all_history = self.all_history[-2 * (self.buffer_size + 1):]
         elif self.buffer_size == 0:
             all_history = self.all_history[-2:]
         else:
@@ -125,7 +136,7 @@ class PromptGenerator:
         prompt_template = self.prompt_template.format(**format_dict)
         ret = prompt_template
         if self.buffer_size > 0:
-            all_history = self.all_history[-2 * (self.buffer_size + 1) :]
+            all_history = self.all_history[-2 * (self.buffer_size + 1):]
         elif self.buffer_size == 0:
             all_history = self.all_history[-2:]
         else:
@@ -134,9 +145,11 @@ class PromptGenerator:
         have_image = False
         for role, message in all_history[::-1]:
             if message:
-                if type(message) is tuple and message[1] is not None and not have_image:
+                if type(message) is tuple and message[
+                        1] is not None and not have_image:
                     message, _ = message
-                    context.append(self.sep + "Image:\n<image>" + self.sep + role + ":\n" + message)
+                    context.append(self.sep + "Image:\n<image>" + self.sep +
+                                   role + ":\n" + message)
                 else:
                     context.append(self.sep + role + ":\n" + message)
             else:
@@ -162,7 +175,8 @@ def to_gradio_chatbot(prompt_generator):
                 max_hw, min_hw = max(image.size), min(image.size)
                 aspect_ratio = max_hw / min_hw
                 max_len, min_len = 800, 400
-                shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                shortest_edge = int(
+                    min(max_len / aspect_ratio, min_len, min_hw))
                 longest_edge = int(shortest_edge * aspect_ratio)
                 H, W = image.size
                 if H > W:
@@ -211,10 +225,12 @@ def bot(
     inputs = state.get_prompt()
     image_paths = state.get_images()[-1:]
 
-    inference_results = inferencer(inputs, image_paths, max_new_token, num_beams, temperature, top_k, top_p, do_sample)
+    inference_results = inferencer(inputs, image_paths, max_new_token,
+                                   num_beams, temperature, top_k, top_p,
+                                   do_sample)
     state.all_history[-1][-1] = inference_results
-
-    return state, to_gradio_chatbot(state), "", None, inputs
+    memory_allocated = str(torch.cuda.memory_allocated() / 1024**3) + 'GB'
+    return state, to_gradio_chatbot(state), "", None, inputs, memory_allocated
 
 
 def clear(state):
@@ -222,41 +238,57 @@ def clear(state):
     return state, to_gradio_chatbot(state), "", None, ""
 
 
+title_markdown = ("""
+# 🤖 Multi-modal GPT
+[[Project]](https://github.com/open-mmlab/Multimodal-GPT.git)""")
+
+
 def build_conversation_demo():
-    with gr.Blocks() as demo:
+    with gr.Blocks(title="Multi-modal GPT") as demo:
+        gr.Markdown(title_markdown)
+
         state = gr.State(PromptGenerator())
         with gr.Row():
             with gr.Column(scale=3):
+                memory_allocated = gr.Textbox(
+                    value=init_memory, label="Memory")
                 imagebox = gr.Image(type="filepath")
                 # TODO config parameters
                 with gr.Accordion(
-                    "Parameters",
-                    open=True,
+                        "Parameters",
+                        open=True,
                 ):
-                    max_new_token_bar = gr.Slider(0, 1024, 512, label="max_new_token", step=1)
-                    num_beams_bar = gr.Slider(0.0, 10, 3, label="num_beams", step=1)
-                    temperature_bar = gr.Slider(0.0, 1.0, 1.0, label="temperature", step=0.01)
+                    max_new_token_bar = gr.Slider(
+                        0, 1024, 512, label="max_new_token", step=1)
+                    num_beams_bar = gr.Slider(
+                        0.0, 10, 3, label="num_beams", step=1)
+                    temperature_bar = gr.Slider(
+                        0.0, 1.0, 1.0, label="temperature", step=0.01)
                     topk_bar = gr.Slider(0, 100, 20, label="top_k", step=1)
                     topp_bar = gr.Slider(0, 1.0, 1.0, label="top_p", step=0.01)
                     do_sample = gr.Checkbox(True, label="do_sample")
                 with gr.Accordion(
-                    "Prompt",
-                    open=False,
+                        "Prompt",
+                        open=False,
                 ):
-                    with gr.Accordion("Click to hide the tutorial", open=False):
+                    with gr.Accordion(
+                            "Click to hide the tutorial", open=False):
                         gr.Markdown(Prompt_Tutorial)
                     with gr.Row():
                         ai_prefix = gr.Text("Response", label="AI Prefix")
-                        user_prefix = gr.Text("Instruction", label="User Prefix")
+                        user_prefix = gr.Text(
+                            "Instruction", label="User Prefix")
                         seperator = gr.Text("\n\n### ", label="Seperator")
-                        history_buffer = gr.Slider(-1, 10, -1, label="History buffer", step=1)
+                        history_buffer = gr.Slider(
+                            -1, 10, -1, label="History buffer", step=1)
                         prompt = gr.Text(TEMPLATE, label="Prompt")
                     model_inputs = gr.Textbox(label="Actual inputs for Model")
 
             with gr.Column(scale=6):
                 with gr.Row():
                     with gr.Column():
-                        chatbot = gr.Chatbot(elem_id="chatbot").style(height=750)
+                        chatbot = gr.Chatbot(elem_id="chatbot").style(
+                            height=750)
                         with gr.Row():
                             with gr.Column(scale=8):
                                 textbox = gr.Textbox(
@@ -268,7 +300,10 @@ def build_conversation_demo():
                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(
                    examples=[
-                        [f"{cur_dir}/docs/images/demo_image.jpg", "What is in this image?"],
+                        [
+                            f"{cur_dir}/docs/images/demo_image.jpg",
+                            "What is in this image?"
+                        ],
                    ],
                    inputs=[imagebox, textbox],
                )
@@ -290,7 +325,10 @@ def build_conversation_demo():
                 topp_bar,
                 do_sample,
             ],
-            [state, chatbot, textbox, imagebox, model_inputs],
+            [
+                state, chatbot, textbox, imagebox, model_inputs,
+                memory_allocated
+            ],
         )
         submit_btn.click(
             bot,
@@ -310,9 +348,13 @@ def build_conversation_demo():
                 topp_bar,
                 do_sample,
             ],
-            [state, chatbot, textbox, imagebox, model_inputs],
+            [
+                state, chatbot, textbox, imagebox, model_inputs,
+                memory_allocated
+            ],
         )
-        clear_btn.click(clear, [state], [state, chatbot, textbox, imagebox, model_inputs])
+        clear_btn.click(clear, [state],
+                        [state, chatbot, textbox, imagebox, model_inputs])
     return demo
 
 
@@ -320,7 +362,12 @@ if __name__ == "__main__":
     llama_path = "checkpoints/llama-7b_hf"
     open_flamingo_path = "checkpoints/OpenFlamingo-9B/checkpoint.pt"
    finetune_path = "checkpoints/mmgpt-lora-v0-release.pt"
-    inferencer = Inferencer(llama_path=llama_path, open_flamingo_path=open_flamingo_path, finetune_path=finetune_path)
+
+    inferencer = Inferencer(
+        llama_path=llama_path,
+        open_flamingo_path=open_flamingo_path,
+        finetune_path=finetune_path)
+    init_memory = str(torch.cuda.memory_allocated() / 1024**3) + 'GB'
     demo = build_conversation_demo()
     demo.queue(concurrency_count=3)
     IP = "0.0.0.0"
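Taken together, the app.py changes run inference in fp16 and surface allocated GPU memory in the new "Memory" textbox. A minimal standalone sketch of that pattern, using a toy `torch.nn.Linear` stand-in rather than the mmgpt model (assumes a CUDA device is available):

```python
import torch

# Toy stand-in module; app.py applies the same calls to the OpenFlamingo-based model.
model = torch.nn.Linear(4096, 4096)
model.half()              # cast parameters to fp16, roughly halving weight memory
model = model.to("cuda")
model.eval()              # inference mode: disable dropout and similar layers

# Inputs must be fp16 as well, mirroring vision_x.unsqueeze(1).unsqueeze(0).half()
x = torch.randn(1, 4096).half().cuda()
with torch.no_grad():
    y = model(x)

# Same memory string that bot() now returns into the "Memory" textbox
print(str(torch.cuda.memory_allocated() / 1024**3) + "GB")
```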