cwadayi commited on
Commit
40ad967
·
verified ·
1 Parent(s): 04b3116

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -38
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  from urllib.parse import quote
 
5
 
6
  # --- 模型配置 (CPU TinyLlama) ---
7
  MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@@ -10,25 +11,7 @@ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
10
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
  print("模型與分詞器載入完成。")
12
 
13
- # --- 核心功能函數 ---
14
-
15
- def llm_generate(prompt, max_new_tokens=256, temperature=0.7, top_p=0.9):
16
- """基礎的語言模型文字生成函數"""
17
- chat = [{"role": "user", "content": prompt}]
18
- formatted_prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
19
- inputs = tokenizer(formatted_prompt, return_tensors="pt")
20
- with torch.no_grad():
21
- outputs = model.generate(
22
- **inputs,
23
- max_new_tokens=int(max_new_tokens),
24
- do_sample=True,
25
- temperature=float(temperature),
26
- top_p=float(top_p),
27
- eos_token_id=tokenizer.eos_token_id
28
- )
29
- response = tokenizer.decode(outputs[0, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
30
- return response.strip()
31
-
32
  def get_map_link(location_query):
33
  """根據地點名稱生成 Google 地圖的 URL"""
34
  if not location_query:
@@ -37,45 +20,84 @@ def get_map_link(location_query):
37
  map_url = base_url + quote(location_query)
38
  return f"點擊這裡查看 **{location_query}** 的地圖:\n[在 Google 地圖中開啟]({map_url})"
39
 
 
40
  def unified_processor(query):
41
  """
42
- 統一處理函數:接收使用者輸入,同時生成地點描述和地圖連結。
 
 
 
43
  """
44
  if not query:
45
- return "請輸入一個地點或一段描述。", ""
 
46
 
47
- print(f"正在處理查詢:'{query}'")
 
48
 
49
- # 步驟 1: 讓模型生成關於這個地點的描述
50
- description_prompt = f"請用繁體中文,生動地介紹一下「{query}」這個地方的特色、歷史或是有趣的景點。"
51
- generated_description = llm_generate(description_prompt)
52
 
53
- # 步驟 2: 產生該地點的地圖連結
54
- map_link_markdown = get_map_link(query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- # 步驟 3: 讓模型推薦附近的景點
57
  recommendation_prompt = f"我在「{query}」這個地方,請用條列的方式,推薦3個附近的必去景點或必吃美食。"
58
- generated_recommendations = llm_generate(recommendation_prompt)
 
 
 
 
 
 
59
 
60
- # 組合最終的文字輸出
61
  final_text_output = (
62
- f"### 關於「{query}」\n"
63
- f"{generated_description}\n\n"
64
  f"### 附近推薦\n"
65
- f"{generated_recommendations}"
66
  )
67
 
68
- return final_text_output, map_link_markdown
 
 
 
69
 
70
- # --- Gradio 介面 (整合版) ---
 
71
  with gr.Blocks(theme=gr.themes.Default()) as demo:
72
  gr.Markdown(
73
  """
74
  # 🗺️ AI 智慧導遊 ✨
75
- 輸入一個地點,AI 將為您生成生動的介紹、推薦附近景點,並附上地圖!
76
  """
77
  )
78
 
 
 
79
  with gr.Row():
80
  query_input = gr.Textbox(
81
  label="請輸入地點名稱或描述",
@@ -90,10 +112,11 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
90
  with gr.Column(scale=1):
91
  map_output = gr.Markdown(label="地圖連結")
92
 
 
93
  process_button.click(
94
  fn=unified_processor,
95
  inputs=query_input,
96
- outputs=[text_output, map_output]
97
  )
98
 
99
  gr.Examples(
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
3
  import torch
4
  from urllib.parse import quote
5
+ from threading import Thread
6
 
7
  # --- 模型配置 (CPU TinyLlama) ---
8
  MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
  print("模型與分詞器載入完成。")
13
 
14
+ # --- 地圖連結函數 (保持不變) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def get_map_link(location_query):
16
  """根據地點名稱生成 Google 地圖的 URL"""
17
  if not location_query:
 
20
  map_url = base_url + quote(location_query)
21
  return f"點擊這裡查看 **{location_query}** 的地圖:\n[在 Google 地圖中開啟]({map_url})"
22
 
23
+ # --- 核心處理函數 (修改為支援 Streaming) ---
24
  def unified_processor(query):
25
  """
26
+ 統一處理函數:
27
+ 1. 以串流方式生成地點描述。
28
+ 2. 生成完畢後,一次性生成推薦內容。
29
+ 3. 最後顯示地圖。
30
  """
31
  if not query:
32
+ yield "請輸入一個地點或一段描述。", "", "狀態:待機中"
33
+ return
34
 
35
+ # --- 階段一:串流生成地點描述 ---
36
+ yield "", "", f"狀態:正在為「{query}」生成介紹..."
37
 
38
+ # 設置 streamer
39
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
40
 
41
+ # 準備生成參數
42
+ prompt = f"請用繁體中文,生動地介紹一下「{query}」這個地方的特色、歷史或是有趣的景點。"
43
+ chat = [{"role": "user", "content": prompt}]
44
+ formatted_prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
45
+ inputs = tokenizer(formatted_prompt, return_tensors="pt")
46
+
47
+ generation_kwargs = dict(
48
+ inputs,
49
+ streamer=streamer,
50
+ max_new_tokens=256,
51
+ do_sample=True,
52
+ temperature=0.7,
53
+ top_p=0.9
54
+ )
55
+
56
+ # 使用多執行緒來運行 blocking 的 generate 方法
57
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
58
+ thread.start()
59
+
60
+ # 即時從 streamer 讀取並更新輸出
61
+ description_output = "### 關於「" + query + "」\n"
62
+ for new_text in streamer:
63
+ description_output += new_text
64
+ yield description_output, "", f"狀態:正在為「{query}」生成介紹..."
65
+
66
+ # --- 階段二:生成推薦內容 ---
67
+ yield description_output, "", f"狀態:正在為「{query}」生成附近推薦..."
68
 
 
69
  recommendation_prompt = f"我在「{query}」這個地方,請用條列的方式,推薦3個附近的必去景點或必吃美食。"
70
+ chat_reco = [{"role": "user", "content": recommendation_prompt}]
71
+ formatted_reco_prompt = tokenizer.apply_chat_template(chat_reco, tokenize=False, add_generation_prompt=True)
72
+ inputs_reco = tokenizer(formatted_reco_prompt, return_tensors="pt")
73
+
74
+ with torch.no_grad():
75
+ outputs_reco = model.generate(**inputs_reco, max_new_tokens=150)
76
+ recommendations = tokenizer.decode(outputs_reco[0, inputs_reco["input_ids"].shape[-1]:], skip_special_tokens=True)
77
 
 
78
  final_text_output = (
79
+ f"{description_output}\n\n"
 
80
  f"### 附近推薦\n"
81
+ f"{recommendations.strip()}"
82
  )
83
 
84
+ # --- 階段三:生成地圖 ---
85
+ map_link_markdown = get_map_link(query)
86
+
87
+ yield final_text_output, map_link_markdown, "狀態:導覽完成!"
88
 
89
+
90
+ # --- Gradio 介面 ---
91
  with gr.Blocks(theme=gr.themes.Default()) as demo:
92
  gr.Markdown(
93
  """
94
  # 🗺️ AI 智慧導遊 ✨
95
+ 輸入一個地點,AI 將為您即時生成生動的介紹、推薦附近景點,並附上地圖!
96
  """
97
  )
98
 
99
+ status_display = gr.Markdown("狀態:待機中")
100
+
101
  with gr.Row():
102
  query_input = gr.Textbox(
103
  label="請輸入地點名稱或描述",
 
112
  with gr.Column(scale=1):
113
  map_output = gr.Markdown(label="地圖連結")
114
 
115
+ # Gradio 會自動處理 yield 函數的串流輸出
116
  process_button.click(
117
  fn=unified_processor,
118
  inputs=query_input,
119
+ outputs=[text_output, map_output, status_display]
120
  )
121
 
122
  gr.Examples(