Transformer / app.py
cwadayi's picture
Update app.py
40ad967 verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from urllib.parse import quote
from threading import Thread
# --- 模型配置 (CPU TinyLlama) ---
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
print("正在載入 TinyLlama 模型與分詞器,請稍候...")
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print("模型與分詞器載入完成。")
# --- 地圖連結函數 (保持不變) ---
def get_map_link(location_query):
"""根據地點名稱生成 Google 地圖的 URL"""
if not location_query:
return ""
base_url = "https://www.google.com/maps/search/?api=1&query="
map_url = base_url + quote(location_query)
return f"點擊這裡查看 **{location_query}** 的地圖:\n[在 Google 地圖中開啟]({map_url})"
# --- 核心處理函數 (修改為支援 Streaming) ---
def unified_processor(query):
"""
統一處理函數:
1. 以串流方式生成地點描述。
2. 生成完畢後,一次性生成推薦內容。
3. 最後顯示地圖。
"""
if not query:
yield "請輸入一個地點或一段描述。", "", "狀態:待機中"
return
# --- 階段一:串流生成地點描述 ---
yield "", "", f"狀態:正在為「{query}」生成介紹..."
# 設置 streamer
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# 準備生成參數
prompt = f"請用繁體中文,生動地介紹一下「{query}」這個地方的特色、歷史或是有趣的景點。"
chat = [{"role": "user", "content": prompt}]
formatted_prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(formatted_prompt, return_tensors="pt")
generation_kwargs = dict(
inputs,
streamer=streamer,
max_new_tokens=256,
do_sample=True,
temperature=0.7,
top_p=0.9
)
# 使用多執行緒來運行 blocking 的 generate 方法
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# 即時從 streamer 讀取並更新輸出
description_output = "### 關於「" + query + "」\n"
for new_text in streamer:
description_output += new_text
yield description_output, "", f"狀態:正在為「{query}」生成介紹..."
# --- 階段二:生成推薦內容 ---
yield description_output, "", f"狀態:正在為「{query}」生成附近推薦..."
recommendation_prompt = f"我在「{query}」這個地方,請用條列的方式,推薦3個附近的必去景點或必吃美食。"
chat_reco = [{"role": "user", "content": recommendation_prompt}]
formatted_reco_prompt = tokenizer.apply_chat_template(chat_reco, tokenize=False, add_generation_prompt=True)
inputs_reco = tokenizer(formatted_reco_prompt, return_tensors="pt")
with torch.no_grad():
outputs_reco = model.generate(**inputs_reco, max_new_tokens=150)
recommendations = tokenizer.decode(outputs_reco[0, inputs_reco["input_ids"].shape[-1]:], skip_special_tokens=True)
final_text_output = (
f"{description_output}\n\n"
f"### 附近推薦\n"
f"{recommendations.strip()}"
)
# --- 階段三:生成地圖 ---
map_link_markdown = get_map_link(query)
yield final_text_output, map_link_markdown, "狀態:導覽完成!"
# --- Gradio 介面 ---
with gr.Blocks(theme=gr.themes.Default()) as demo:
gr.Markdown(
"""
# 🗺️ AI 智慧導遊 ✨
輸入一個地點,AI 將為您即時生成生動的介紹、推薦附近景點,並附上地圖!
"""
)
status_display = gr.Markdown("狀態:待機中")
with gr.Row():
query_input = gr.Textbox(
label="請輸入地點名稱或描述",
placeholder="例如:九份老街、有著巨大玻璃金字塔的博物館..."
)
process_button = gr.Button("開始導覽 ✨", variant="primary")
with gr.Row():
with gr.Column(scale=2):
text_output = gr.Markdown(label="AI 導覽介紹")
with gr.Column(scale=1):
map_output = gr.Markdown(label="地圖連結")
# Gradio 會自動處理 yield 函數的串流輸出
process_button.click(
fn=unified_processor,
inputs=query_input,
outputs=[text_output, map_output, status_display]
)
gr.Examples(
examples=[
"東京迪士尼樂園",
"羅浮宮",
"台灣的阿里山",
"一個可以看到極光的玻璃屋飯店"
],
inputs=query_input,
label="試試看這些例子"
)
# --- 啟動 Gradio 應用 ---
demo.launch()