tuanhqv123 committed
Commit ad57d9c · verified · 1 Parent(s): 010fa18

Update app.py

Files changed (1)
  1. app.py +277 -40
app.py CHANGED
@@ -1,11 +1,14 @@
- from fastapi import FastAPI
- from transformers import AutoModelForCausalLM, AutoTokenizer
+ from fastapi import FastAPI, HTTPException
+ from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch
  import json
+ import time
+ from typing import Dict, Any, Optional
+ import os

- app = FastAPI()
+ app = FastAPI(title="Qwen3 API", description="API for Qwen3 models", version="1.0.0")

- # Load models
+ # Global variables to store the loaded models
  models = {}
  tokenizers = {}

@@ -14,85 +17,319 @@ MODEL_CONFIGS = {
      "qwen3-4b": "Qwen/Qwen3-4B"
  }

+ def download_model_safely(model_name: str, max_retries: int = 3):
+     """Download a model with retry logic and error handling"""
+     for attempt in range(max_retries):
+         try:
+             print(f"Downloading {model_name} (attempt {attempt + 1}/{max_retries})...")
+
+             # Download the tokenizer with tuned parameters
+             tokenizer = AutoTokenizer.from_pretrained(
+                 model_name,
+                 resume_download=True,
+                 timeout=600,
+                 trust_remote_code=True,
+                 cache_dir=None  # Use the default cache
+             )
+
+             # Download the model with a configuration tuned for the free tier
+             model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 torch_dtype="auto",
+                 device_map="auto",
+                 resume_download=True,
+                 timeout=600,
+                 trust_remote_code=True,
+                 cache_dir=None,
+                 low_cpu_mem_usage=True  # Optimize memory usage
+             )
+
+             print(f"Successfully loaded {model_name}")
+             return tokenizer, model
+
+         except Exception as e:
+             print(f"Download failed (attempt {attempt + 1}): {str(e)}")
+             if attempt == max_retries - 1:
+                 raise e
+             time.sleep(30)  # Wait before retry
+
+ def load_model_on_demand(model_key: str):
+     """Load a model on demand, with memory management"""
+     if model_key not in models:
+         if model_key not in MODEL_CONFIGS:
+             raise ValueError(f"Unknown model key: {model_key}")
+
+         model_name = MODEL_CONFIGS[model_key]
+         print(f"Loading {model_name} on demand...")
+
+         # Memory management: keep only one model in memory because of free-tier limits
+         if len(models) >= 1:
+             oldest_model = list(models.keys())[0]
+             print(f"Unloading {oldest_model} to free memory...")
+             del models[oldest_model]
+             del tokenizers[oldest_model]
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+         tokenizer, model = download_model_safely(model_name)
+         tokenizers[model_key] = tokenizer
+         models[model_key] = model
+         print(f"{model_name} loaded successfully!")
+
  @app.on_event("startup")
  async def load_models():
-     for model_key, model_name in MODEL_CONFIGS.items():
-         print(f"Loading {model_name}...")
-         tokenizers[model_key] = AutoTokenizer.from_pretrained(model_name, resume_download=True, timeout=300)
-         models[model_key] = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             torch_dtype="auto",
-             device_map="auto",
-             resume_download=True, timeout=300
-         )
-     print("All models loaded!")
+     """Load the default model at startup"""
+     try:
+         print("Loading default model: Qwen3-1.7B...")
+         tokenizer, model = download_model_safely("Qwen/Qwen3-1.7B")
+         tokenizers["qwen3-1.7b"] = tokenizer
+         models["qwen3-1.7b"] = model
+         print("Default model loaded successfully!")
+     except Exception as e:
+         print(f"Failed to load default model: {str(e)}")
+         print("Server will continue running, models will be loaded on demand")
+
+ @app.get("/")
+ def health_check():
+     """Health check endpoint"""
+     return {
+         "status": "API is running",
+         "available_models": list(MODEL_CONFIGS.keys()),
+         "loaded_models": list(models.keys()),
+         "version": "1.0.0",
+         "message": "Qwen3 API Service"
+     }
+
+ @app.get("/models")
+ def list_models():
+     """List available models"""
+     return {
+         "available_models": MODEL_CONFIGS,
+         "loaded_models": list(models.keys()),
+         "total_available": len(MODEL_CONFIGS),
+         "total_loaded": len(models)
+     }

  @app.post("/v1/chat/completions")
- def chat_completions(request: dict):
+ def chat_completions(request: Dict[str, Any]):
+     """
+     OpenAI-compatible chat completions endpoint
+     Fully compatible with the existing AiService code
+     """
      try:
+         # Parse request parameters
          model_name = request.get("model", "qwen3-1.7b")
          messages = request.get("messages", [])
          temperature = request.get("temperature", 0.7)
          max_tokens = request.get("max_tokens", 1024)

-         # Select the model
-         if "4b" in model_name.lower() or "4" in model_name:
+         # Validate input
+         if not messages:
+             raise HTTPException(status_code=400, detail="Messages cannot be empty")
+
+         # Determine the model key from the model name - compatible with agents.py
+         if "4b" in model_name.lower() or "4" in model_name.lower():
              model_key = "qwen3-4b"
          else:
              model_key = "qwen3-1.7b"

+         print(f"Using model: {model_key} for request")
+
+         # Load the model if it is not loaded yet
          if model_key not in models:
-             return {"error": f"Model {model_key} not loaded"}
+             load_model_on_demand(model_key)

+         # Get the model and tokenizer
          tokenizer = tokenizers[model_key]
          model = models[model_key]

-         # Format messages for Qwen3 - IMPORTANT: use apply_chat_template
+         # Format messages for Qwen3 using apply_chat_template
+         # This is the key step for Qwen3 compatibility
          text = tokenizer.apply_chat_template(
              messages,
              tokenize=False,
              add_generation_prompt=True,
-             enable_thinking=False  # Disable thinking mode for faster responses
+             enable_thinking=False  # Disable thinking mode for simpler, faster responses
          )

-         model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+         # Tokenize input
+         model_inputs = tokenizer([text], return_tensors="pt")

-         # Generate with temperature
-         generated_ids = model.generate(
-             **model_inputs,
-             max_new_tokens=max_tokens,
-             temperature=temperature,
-             do_sample=True if temperature > 0 else False,
-             pad_token_id=tokenizer.eos_token_id
-         )
+         # Move to device if available
+         if torch.cuda.is_available():
+             model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
+
+         # Generate the response with tuned parameters
+         with torch.no_grad():
+             generated_ids = model.generate(
+                 **model_inputs,
+                 max_new_tokens=min(max_tokens, 2048),  # Cap max_tokens to avoid timeouts
+                 temperature=temperature,
+                 do_sample=True if temperature > 0 else False,
+                 pad_token_id=tokenizer.eos_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+                 repetition_penalty=1.1,  # Avoid repetition
+                 top_p=0.9 if temperature > 0 else None,
+                 use_cache=True  # Speed up generation
+             )

          # Extract the response - take only the newly generated part
-         output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+         input_length = model_inputs["input_ids"].shape[1]
+         output_ids = generated_ids[0][input_length:].tolist()
          response = tokenizer.decode(output_ids, skip_special_tokens=True).strip()

-         # Format the response per the OpenAI API for compatibility with AiService
+         # Clean up response
+         if not response:
+             response = "I apologize, but I couldn't generate a proper response. Please try again."
+
+         # Format the response per the OpenAI API for full compatibility with AiService
          return {
              "choices": [{
                  "message": {
                      "content": response,
                      "role": "assistant"
-                 }
+                 },
+                 "finish_reason": "stop",
+                 "index": 0
              }],
-             "model": model_key
+             "model": model_key,
+             "usage": {
+                 "prompt_tokens": input_length,
+                 "completion_tokens": len(output_ids),
+                 "total_tokens": input_length + len(output_ids)
+             },
+             "object": "chat.completion",
+             "created": int(time.time())
          }

+     except HTTPException:
+         raise
      except Exception as e:
-         print(f"Error: {str(e)}")
+         print(f"Error in chat_completions: {str(e)}")
+         # Return the error in an OpenAI-compatible format
          return {
              "choices": [{
                  "message": {
-                     "content": f"Error processing request: {str(e)}",
+                     "content": f"I encountered an error while processing your request: {str(e)}",
                      "role": "assistant"
-                 }
+                 },
+                 "finish_reason": "error",
+                 "index": 0
              }],
-             "error": str(e)
+             "error": {
+                 "message": str(e),
+                 "type": "internal_error",
+                 "code": "processing_error"
+             },
+             "model": "qwen3-1.7b"
          }

- @app.get("/")
- def health_check():
-     return {"status": "API is running", "models": list(MODEL_CONFIGS.keys())}
+ @app.post("/generate")
+ def simple_generate(request: Dict[str, Any]):
+     """
+     Simple generate endpoint for quick testing
+     """
+     try:
+         text = request.get("text", "")
+         model_name = request.get("model", "qwen3-1.7b")
+         max_tokens = request.get("max_tokens", 100)
+         temperature = request.get("temperature", 0.7)
+
+         if not text:
+             raise HTTPException(status_code=400, detail="Text cannot be empty")
+
+         # Determine model key
+         if "4b" in model_name.lower():
+             model_key = "qwen3-4b"
+         else:
+             model_key = "qwen3-1.7b"
+
+         # Load the model if needed
+         if model_key not in models:
+             load_model_on_demand(model_key)
+
+         tokenizer = tokenizers[model_key]
+         model = models[model_key]
+
+         # Simple generation
+         inputs = tokenizer(text, return_tensors="pt")
+         if torch.cuda.is_available():
+             inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 do_sample=True if temperature > 0 else False,
+                 pad_token_id=tokenizer.eos_token_id
+             )
+
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         return {
+             "generated_text": response,
+             "model": model_key,
+             "input_text": text
+         }
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.get("/health")
+ def health():
+     """Simple health check"""
+     return {
+         "status": "healthy",
+         "timestamp": int(time.time()),
+         "models_loaded": len(models)
+     }
+
+ @app.get("/status")
+ def status():
+     """Detailed status information"""
+     return {
+         "service": "Qwen3 API",
+         "status": "running",
+         "models": {
+             "available": MODEL_CONFIGS,
+             "loaded": list(models.keys()),
+             "memory_usage": {
+                 "total_models": len(models),
+                 "cuda_available": torch.cuda.is_available(),
+                 "cuda_memory": torch.cuda.get_device_properties(0).total_memory if torch.cuda.is_available() else None
+             }
+         },
+         "endpoints": [
+             "/v1/chat/completions",
+             "/generate",
+             "/models",
+             "/health",
+             "/status"
+         ]
+     }
+
+ # Error handlers
+ @app.exception_handler(404)
+ async def not_found_handler(request, exc):
+     return {
+         "error": {
+             "message": "Endpoint not found",
+             "type": "not_found_error",
+             "code": 404
+         }
+     }
+
+ @app.exception_handler(500)
+ async def internal_error_handler(request, exc):
+     return {
+         "error": {
+             "message": "Internal server error",
+             "type": "internal_server_error",
+             "code": 500
+         }
+     }
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
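
For quick manual testing of this revision, here is a minimal client sketch against the new endpoints. It assumes the app is reachable at http://localhost:7860 (the port used in the uvicorn.run call above); a deployed Space will have a different base URL, so adjust BASE_URL accordingly.

    import requests

    # Assumed base URL for a local run of app.py; change for a deployed Space.
    BASE_URL = "http://localhost:7860"

    # /v1/chat/completions reads "model", "messages", "temperature" and
    # "max_tokens" from the JSON body and returns an OpenAI-style response.
    payload = {
        "model": "qwen3-1.7b",  # any name containing "4" routes to qwen3-4b
        "messages": [{"role": "user", "content": "Hello, who are you?"}],
        "temperature": 0.7,
        "max_tokens": 256,
    }

    resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=300)
    data = resp.json()
    print(data["choices"][0]["message"]["content"])
    print(data.get("usage"))

    # /health reports how many models are currently loaded.
    print(requests.get(f"{BASE_URL}/health", timeout=30).json())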