# Spaces: Sleeping
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
import torch | |
from urllib.parse import quote | |
from threading import Thread | |
# --- Model setup (TinyLlama chat model; no device placement, so it stays on CPU) ---
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

print("正在載入 TinyLlama 模型與分詞器,請稍候...")
# Loaded once at import time; the request handler below reuses these objects.
# The tuple is evaluated left to right, so the model is still fetched first.
model, tokenizer = (
    AutoModelForCausalLM.from_pretrained(MODEL_NAME),
    AutoTokenizer.from_pretrained(MODEL_NAME),
)
print("模型與分詞器載入完成。")
# --- Map link helper ---
def get_map_link(location_query):
    """Build a Markdown snippet linking to a Google Maps search for a place.

    Args:
        location_query: Free-form place name or description typed by the user.
            ``None`` and empty or whitespace-only strings mean "no query".

    Returns:
        A Markdown string containing a Google Maps search link, or ``""``
        when there is nothing to search for.
    """
    # Normalize surrounding whitespace: "  九份  " and "九份" give the same
    # link, and a whitespace-only query no longer yields a link to "%20%20".
    location_query = (location_query or "").strip()
    if not location_query:
        return ""
    base_url = "https://www.google.com/maps/search/?api=1&query="
    # quote() percent-encodes characters such as '&', '#', '(' and ')', so the
    # user's text cannot add extra URL parameters or break the Markdown link.
    map_url = base_url + quote(location_query)
    return f"點擊這裡查看 **{location_query}** 的地圖:\n[在 Google 地圖中開啟]({map_url})"
# --- Core pipeline: stream the description, then add recommendations and a map link ---
def _build_chat_inputs(prompt):
    """Wrap one user message in the model's chat template and tokenize it to tensors."""
    chat = [{"role": "user", "content": prompt}]
    formatted_prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    return tokenizer(formatted_prompt, return_tensors="pt")


def unified_processor(query):
    """Generate a streamed tour guide answer for a place.

    Yields ``(text_markdown, map_markdown, status_markdown)`` tuples for the
    three Gradio outputs:

    1. Streams a description of the place as the model produces it.
    2. After that finishes, generates a list of nearby recommendations in one pass.
    3. Finally appends a Google Maps link.

    Args:
        query: Place name or description from the textbox; ``None`` and
            whitespace-only input are treated as empty.

    Raises:
        Exception: Whatever ``model.generate`` raised in the streaming worker
            thread is re-raised here instead of hanging the request.
    """
    # Treat whitespace-only input the same as empty input.
    query = (query or "").strip()
    if not query:
        yield "請輸入一個地點或一段描述。", "", "狀態:待機中"
        return

    # --- Stage 1: stream the place description ---
    yield "", "", f"狀態:正在為「{query}」生成介紹..."
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    prompt = f"請用繁體中文,生動地介紹一下「{query}」這個地方的特色、歷史或是有趣的景點。"
    inputs = _build_chat_inputs(prompt)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

    generation_errors = []

    def _generate():
        # generate() only calls streamer.end() when it finishes normally. If it
        # raises, the consuming loop below would wait on the streamer forever,
        # so record the error and close the stream ourselves.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the caller's thread below
            generation_errors.append(exc)
            streamer.end()

    # generate() blocks, so it runs in a worker thread while this generator
    # relays the partial text to Gradio.
    thread = Thread(target=_generate)
    thread.start()

    description_output = "### 關於「" + query + "」\n"
    for new_text in streamer:
        description_output += new_text
        yield description_output, "", f"狀態:正在為「{query}」生成介紹..."
    thread.join()
    if generation_errors:
        raise generation_errors[0]

    # --- Stage 2: generate nearby recommendations (non-streaming) ---
    yield description_output, "", f"狀態:正在為「{query}」生成附近推薦..."
    recommendation_prompt = f"我在「{query}」這個地方,請用條列的方式,推薦3個附近的必去景點或必吃美食。"
    inputs_reco = _build_chat_inputs(recommendation_prompt)
    with torch.no_grad():
        outputs_reco = model.generate(**inputs_reco, max_new_tokens=150)
    # The output echoes the prompt tokens first; decode only the new continuation.
    prompt_length = inputs_reco["input_ids"].shape[-1]
    recommendations = tokenizer.decode(outputs_reco[0, prompt_length:], skip_special_tokens=True)
    final_text_output = (
        f"{description_output}\n\n"
        f"### 附近推薦\n"
        f"{recommendations.strip()}"
    )

    # --- Stage 3: attach the map link ---
    map_link_markdown = get_map_link(query)
    yield final_text_output, map_link_markdown, "狀態:導覽完成!"
# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown(
        """
# 🗺️ AI 智慧導遊 ✨
輸入一個地點,AI 將為您即時生成生動的介紹、推薦附近景點,並附上地圖!
"""
    )
    status_display = gr.Markdown("狀態:待機中")
    with gr.Row():
        query_input = gr.Textbox(
            label="請輸入地點名稱或描述",
            placeholder="例如:九份老街、有著巨大玻璃金字塔的博物館...",
        )
        process_button = gr.Button("開始導覽 ✨", variant="primary")
    with gr.Row():
        with gr.Column(scale=2):
            text_output = gr.Markdown(label="AI 導覽介紹")
        with gr.Column(scale=1):
            map_output = gr.Markdown(label="地圖連結")

    # unified_processor is a generator; Gradio streams each yielded tuple to the outputs.
    # The button and pressing Enter in the textbox both start the same tour.
    tour_event = dict(
        fn=unified_processor,
        inputs=query_input,
        outputs=[text_output, map_output, status_display],
    )
    process_button.click(**tour_event)
    query_input.submit(**tour_event)

    gr.Examples(
        examples=[
            "東京迪士尼樂園",
            "羅浮宮",
            "台灣的阿里山",
            "一個可以看到極光的玻璃屋飯店",
        ],
        inputs=query_input,
        label="試試看這些例子",
    )

# --- Launch the Gradio app ---
# Streaming generator output goes through Gradio's request queue. Older Gradio
# (3.x) requires enabling it explicitly; on newer versions this is a no-op default.
demo.queue()
demo.launch()