darknoon committed on
Commit
7dbba05
·
1 Parent(s): 85c421c

fixing conversation formatting

Browse files
Files changed (1) hide show
  1. app.py +58 -29
app.py CHANGED
@@ -10,54 +10,82 @@ import requests
# Hub id of the checkpoint used for both the model and its processor.
model_path = "facebook/chameleon-7b"

# `token=True` is the current spelling of the deprecated `use_auth_token=True`:
# it authenticates the download with the locally stored Hugging Face token.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
    token=True,
)
model.eval()  # switch the module to evaluation mode
processor = ChameleonProcessor.from_pretrained(model_path, token=True)
tokenizer = processor.tokenizer
17
 
18
def load_example_image():
    """Return the example painting, downloading it on first use.

    The loaded image is stored in the module-level ``image`` global so that
    later calls reuse it instead of fetching the URL again.
    """
    global image
    # Read the cache through globals(): a bare `image` lookup raises NameError
    # on the first call when the module has never assigned that name.
    if not globals().get("image"):
        image = Image.open(
            requests.get(
                "https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg",
                stream=True,
            ).raw
        )
    return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
@spaces.GPU(duration=90)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a model reply about the built-in example image.

    None of the chat inputs (``message``, ``history``, ``system_message``,
    ``max_tokens``, ``temperature``, ``top_p``) are used by this version: the
    prompt and image are fixed.

    Yields:
        The text accumulated so far, once per chunk produced by the streamer.
    """
    prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
    image = load_example_image()

    inputs = processor(prompt, images=[image], return_tensors="pt").to(model.device, dtype=torch.bfloat16)

    streamer = TextIteratorStreamer(tokenizer)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)

    # Run generation in a background thread so this generator can consume the
    # streamer while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message
 
 
61
 
62
 
63
  """
@@ -65,6 +93,7 @@ For information on how to customize the ChatInterface, peruse the gradio docs: h
65
  """
66
  demo = gr.ChatInterface(
67
  respond,
 
68
  additional_inputs=[
69
  gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
70
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
@@ -81,4 +110,4 @@ demo = gr.ChatInterface(
81
 
82
 
83
  if __name__ == "__main__":
84
- demo.launch()
 
# Hub id of the checkpoint used for both the model and its processor.
model_path = "facebook/chameleon-7b"

# Load the weights in bfloat16 and let `device_map="auto"` choose placement.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
)
model.eval()  # switch the module to evaluation mode

# The processor wraps the tokenizer that the streaming code uses.
processor = ChameleonProcessor.from_pretrained(model_path)
tokenizer = processor.tokenizer
17
 
18
+ multimodal_file = tuple[str, str]
19
+ multimodal_message = list[str | multimodal_file] | multimodal_file
20
+ # todo: verify this type with gr.ChatInterface
21
+ message_t = str | multimodal_message
22
+ history_t = list[tuple[str, str] | list[tuple[multimodal_message, multimodal_message]]]
23
+
24
+ def history_to_prompt(
25
+ message,
26
+ history: history_t,
27
+ eot_id = "<reserved08706>",
28
+ image_placeholder = "<image>"
29
+ ):
30
+
31
+ prompt = ""
32
+ images = []
33
+ for turn in history + (message, None):
34
+ print("turn:", turn)
35
+ # turn should be a tuple of user message and assistant message
36
+ for message in turn:
37
+ if isinstance(message, str):
38
+ prompt += user_message
39
+ prompt += eot_id
40
+ if isinstance(message, list):
41
+ for item in message:
42
+ if isinstance(item, str):
43
+ prompt += item
44
+ elif isinstance(item, tuple):
45
+ image_path, alt = item
46
+ prompt += image_placeholder
47
+ image = Image.open(requests.get(image_path, stream=True).raw)
48
+ images.append(image)
49
+ else:
50
+ prompt += f"(unhandled message type: {message})"
51
+ prompt += eot_id
52
+ return prompt, images
53
 
@spaces.GPU(duration=90)
def respond(
    message,
    history: history_t,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a model reply for the current conversation.

    The prompt and image list are built from ``message`` and ``history`` by
    ``history_to_prompt``. ``system_message``, ``max_tokens``, ``temperature``
    and ``top_p`` are accepted but not used by this function.

    Yields:
        The text accumulated so far, once per chunk produced by the streamer.
        If streaming raises, a final ``"Error: ..."`` string is yielded.
    """
    print(f"message: {message}\nhistory:\n\n{history}\n")
    prompt, images = history_to_prompt(message, history)
    print(f"prompt:\n\n{prompt}\n")

    # Pass None rather than an empty list for a text-only conversation.
    inputs = processor(prompt, images=images or None, return_tensors="pt").to(model.device, dtype=torch.bfloat16)

    streamer = TextIteratorStreamer(tokenizer)
    # NOTE(review): generation length is fixed at 20 new tokens; the
    # `max_tokens` slider value is not applied here - confirm this is intended.
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)

    try:
        # Run generation in a background thread so this generator can consume
        # the streamer while tokens are still being produced.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        partial_message = ""
        for new_token in streamer:
            partial_message += new_token
            yield partial_message
    except Exception as e:
        # This is a generator: a `return` value would never reach the
        # consumer, so the error text has to be yielded.
        yield f"Error: {e}"
89
 
90
 
91
  """
 
93
  """
94
  demo = gr.ChatInterface(
95
  respond,
96
+ multimodal=True,
97
  additional_inputs=[
98
  gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
99
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
 
110
 
111
 
112
  if __name__ == "__main__":
113
+ demo.launch(debug=True)