|
import gradio as gr |
|
import torch |
|
from transformers import AutoModelForImageTextToText, AutoProcessor, TextIteratorStreamer |
|
from peft import PeftModel |
|
from transformers.image_utils import load_image |
|
from threading import Thread |
|
import time |
|
import html |
|
|
|
|
|
def progress_bar_html(label: str) -> str:
    """Return an HTML snippet showing *label* next to a thin animated progress bar.

    The bar is a 110x5 px purple track (#9370DB) whose indigo (#4B0082) fill
    slides across it from left to right in a continuous 1.5 s loop, driven by
    the ``@keyframes loading`` rule embedded in the snippet.

    Args:
        label: Text displayed to the left of the bar. It is HTML-escaped, so
            characters such as ``<`` or ``&`` are shown literally instead of
            being parsed as markup.

    Returns:
        A self-contained HTML string: the markup followed by its ``<style>``
        block.
    """
    # Escape for consistency with how streamed model text is escaped before
    # being rendered in the chat window.
    safe_label = html.escape(label)
    # Doubled braces are literal CSS braces inside the f-string.
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{safe_label}</span>
    <div style="width: 110px; height: 5px; background-color: #9370DB; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: #4B0082; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
'''
|
|
|
|
|
# Base vision-language checkpoint and the PEFT adapter applied on top of it
# (both given as Hugging Face repo ids).
model_name = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
adapter_name = "xco2/smolvlm2-500M-illustration-description"

# Load the base model, then wrap it with the adapter weights.
model = AutoModelForImageTextToText.from_pretrained(
    model_name,
)
model = PeftModel.from_pretrained(model, adapter_name)

# The processor (text tokenization + image preprocessing) comes from the base
# checkpoint, not from the adapter repo.
processor = AutoProcessor.from_pretrained(model_name)

# Fold the adapter into the base weights so inference runs on a plain model,
# then cast to half precision and switch to inference mode.
# NOTE(review): the UI advertises CPU execution; float16 on CPU can be slow or
# hit ops without half-precision kernels depending on the PyTorch version —
# confirm this dtype is intended for the deployment target.
model = model.merge_and_unload().to(torch.float16).eval()

print(f"Successfully loaded the model: {model}")
|
|
|
|
|
def model_inference(input_dict, history):
    """Stream the model's answer to one multimodal chat message.

    Args:
        input_dict: Message from the ``gr.MultimodalTextbox``. Reads the
            ``"text"`` query and the ``"files"`` list of uploaded image paths.
        history: Previous chat turns passed by ``gr.ChatInterface``. Currently
            unused: every request is answered as a stand-alone single turn.

    Yields:
        First an HTML progress-bar placeholder, then, after each streamed
        chunk, the full HTML-escaped response accumulated so far.

    Raises:
        gr.Error: If the text query is empty or whitespace-only (with or
            without attached images). Raising it is what makes Gradio show
            the message to the user.
    """
    text = (input_dict.get("text") or "").strip()
    files = input_dict.get("files") or []

    # Validate before loading any image so bad requests fail fast.
    if not text:
        if files:
            raise gr.Error("Please input a text query along with the image(s).")
        raise gr.Error("Please input a query and optionally image(s).")

    images = [load_image(path) for path in files]

    # Single user turn: all images first, then the text query.
    messages = [
        {
            "role": "user",
            "content": [
                *({"type": "image", "image": image} for image in images),
                {"type": "text", "text": text},
            ],
        }
    ]
    # BatchFeature.to() applies ``dtype`` only to floating-point tensors, so
    # integer token ids keep their dtype while pixel values match the model.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)

    # Run generation in a worker thread; the streamer hands decoded text back
    # to this generator as it is produced.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    yield progress_bar_html("Processing...")

    buffer = ""
    for new_text in streamer:
        # Escape so any markup the model emits is displayed, not interpreted.
        buffer += html.escape(new_text)
        # Short pause between UI updates (presumably to throttle re-renders).
        time.sleep(0.01)
        yield buffer

    # The streamer is exhausted once generation ends; reap the worker thread.
    thread.join()
|
|
|
|
|
# Clickable example queries; both use the same bundled illustration.
_EXAMPLE_IMAGE = "example_images/阿能_129888755.jpg"
_EXAMPLE_PROMPTS = (
    "Write a descriptive caption for this image in a formal tone.",
    "What are the characters wearing?",
)
examples = [
    [{"text": prompt, "files": [_EXAMPLE_IMAGE]}] for prompt in _EXAMPLE_PROMPTS
]

# Multimodal chat UI: each submission (text plus optional images) is handed to
# ``model_inference``, whose yielded values are shown as the reply.
chat_textbox = gr.MultimodalTextbox(label="Query Input", file_types=["image"])
demo = gr.ChatInterface(
    fn=model_inference,
    multimodal=True,
    textbox=chat_textbox,
    examples=examples,
    cache_examples=False,
    description="# **Smolvlm2-500M-illustration-description** \n (running on CPU)",
    fill_height=True,
    stop_btn="Stop Generation",
)

# Start the web app.
demo.launch(debug=True)
|
|