tuanhqv123 committed · Commit d5238da · verified · 1 Parent(s): e8bbc64

Update app.py

Files changed (1): app.py (+88 -85)
app.py CHANGED
@@ -4,6 +4,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import time
 import asyncio
+import json
+import re
 from typing import Dict, Any, Optional
 import logging
 import traceback
@@ -78,6 +80,49 @@ def load_model_on_demand(model_key: str):
     models[model_key] = model
     logger.info(f"{model_name} loaded successfully!")
 
+def extract_json_from_response(text: str) -> str:
+    """Extract JSON from response text"""
+    # Remove thinking tags completely
+    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+    text = text.strip()
+
+    # Try to find JSON object
+    json_match = re.search(r'\{[^{}]*\}', text)
+    if json_match:
+        return json_match.group(0)
+
+    # If no JSON found, return the cleaned text
+    return text
+
+def format_structured_prompt(messages: list, json_schema: dict) -> list:
+    """Format messages with JSON schema instructions"""
+    # Extract schema properties for clear instructions
+    schema_info = json_schema.get('schema', {})
+    properties = schema_info.get('properties', {})
+    required = schema_info.get('required', [])
+
+    # Create clear JSON format instructions
+    json_instructions = f"""
+    You must respond with a valid JSON object only. No explanations, no markdown, no additional text.
+
+    Required JSON format:
+    {json.dumps(schema_info, indent=2)}
+
+    Example response format: {{"type": "examschedule"}}
+    """
+
+    # Build the conversation
+    formatted_messages = []
+    for msg in messages:
+        if msg["role"] == "system":
+            # Append JSON instructions to system message
+            content = msg["content"] + "\n" + json_instructions
+            formatted_messages.append({"role": "system", "content": content})
+        else:
+            formatted_messages.append(msg)
+
+    return formatted_messages
+
 @app.on_event("startup")
 async def load_models():
     """Load default model"""
@@ -99,7 +144,7 @@ def health_check():
         "available_models": list(MODEL_CONFIGS.keys()),
         "loaded_models": list(models.keys()),
         "version": "1.0.0",
-        "message": "Qwen3 API Service - OpenAI Compatible"
+        "message": "Qwen3 API Service - OpenAI Compatible with Structured Output"
     }
 
 @app.get("/models")
@@ -114,19 +159,20 @@ def list_models():
 
 @app.post("/v1/chat/completions")
 async def chat_completions(request: Dict[str, Any]):
-    """OpenAI-compatible chat completions endpoint - FIXED AttributeError"""
+    """OpenAI-compatible chat completions endpoint with Structured Output support"""
     try:
         logger.info("=== CHAT COMPLETIONS REQUEST START ===")
-        logger.info(f"Request payload: {request}")
+        logger.info(f"Request payload: {json.dumps(request, ensure_ascii=False, indent=2)}")
 
         # Parse request parameters
         model_name = request.get("model", "qwen3-1.7b")
         messages = request.get("messages", [])
         temperature = request.get("temperature", 0.7)
         max_tokens = request.get("max_tokens", 200)
+        response_format = request.get("response_format", None)
 
         logger.info(f"Model: {model_name}, Temperature: {temperature}, Max tokens: {max_tokens}")
-        logger.info(f"Messages: {messages}")
+        logger.info(f"Response format: {response_format}")
 
         # Validate input
         if not messages:
@@ -151,6 +197,12 @@ async def chat_completions(request: Dict[str, Any]):
         model = models[model_key]
         logger.info(f"Got tokenizer and model for {model_key}")
 
+        # Handle structured output
+        if response_format and response_format.get("type") == "json_schema":
+            json_schema = response_format.get("json_schema", {})
+            logger.info("Structured output requested, formatting messages with JSON schema")
+            messages = format_structured_prompt(messages, json_schema)
+
         # Format messages - FORCE DISABLE thinking mode
         logger.info("Formatting messages with apply_chat_template...")
         try:
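
As an illustration of the new structured-output path, a hypothetical client call; the URL, schema, and message contents are placeholders, not taken from the repository:

import requests

# Hypothetical request against a local deployment of this service.
resp = requests.post(
    "http://localhost:7860/v1/chat/completions",
    json={
        "model": "qwen3-1.7b",
        "messages": [
            {"role": "system", "content": "Classify the user's question."},
            {"role": "user", "content": "When is my final exam?"},
        ],
        "temperature": 0.0,
        "max_tokens": 50,
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "schema": {
                    "type": "object",
                    "properties": {"type": {"type": "string"}},
                    "required": ["type"],
                },
            },
        },
    },
)
print(resp.json())  # expected to contain something like {"type": "examschedule"}

Note that the handler only reads json_schema["schema"], so the schema body has to sit under that key.
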
@@ -161,28 +213,28 @@ async def chat_completions(request: Dict[str, Any]):
                 enable_thinking=False # CRITICAL: Force disable thinking
             )
 
-            # REMOVE thinking tags if present
-            if "<think>" in text:
+            # AGGRESSIVE thinking mode removal
+            if "<think>" in text or "think>" in text:
                 logger.warning("Found thinking tags in formatted text, removing...")
-                text = text.replace("<think>\n\n</think>\n\n", "")
-                text = text.replace("<think></think>", "")
-                # Remove any remaining thinking content
-                import re
                 text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+                text = re.sub(r'<think>\s*</think>', '', text)
+                text = text.replace("<think>", "").replace("</think>", "")
 
-            logger.info(f"Formatted text (first 200 chars): {text[:200]}...")
+            logger.info(f"Formatted text (first 300 chars): {text[:300]}...")
 
         except Exception as e:
             logger.error(f"Error in apply_chat_template: {str(e)}")
             # Fallback to simple format WITHOUT thinking
             text = ""
             for msg in messages:
-                if msg["role"] == "user":
+                if msg["role"] == "system":
+                    text += f"<|im_start|>system\n{msg['content']}<|im_end|>\n"
+                elif msg["role"] == "user":
                     text += f"<|im_start|>user\n{msg['content']}<|im_end|>\n"
                 elif msg["role"] == "assistant":
                     text += f"<|im_start|>assistant\n{msg['content']}<|im_end|>\n"
             text += "<|im_start|>assistant\n" # NO thinking tags
-            logger.info(f"Using fallback formatting: {text}")
+            logger.info(f"Using fallback formatting")
 
         # Tokenize input
         logger.info("Tokenizing input...")
@@ -204,7 +256,7 @@ async def chat_completions(request: Dict[str, Any]):
         with torch.no_grad():
             generated_ids = model.generate(
                 **model_inputs,
-                max_new_tokens=min(max_tokens, 100),
+                max_new_tokens=min(max_tokens, 200),
                 temperature=temperature,
                 do_sample=True if temperature > 0 else False,
                 pad_token_id=tokenizer.eos_token_id,
@@ -251,7 +303,7 @@ async def chat_completions(request: Dict[str, Any]):
                 "model": model_key
             }
 
-        # FIXED: Extract response - handle both tensor and dict cases
+        # Extract response
        logger.info("Extracting response...")
        try:
            # Get input length correctly
@@ -260,33 +312,34 @@ async def chat_completions(request: Dict[str, Any]):
             elif isinstance(model_inputs, dict) and 'input_ids' in model_inputs:
                 input_length = model_inputs['input_ids'].shape[1]
             else:
-                logger.error("Cannot find input_ids in model_inputs")
                 input_length = 0
 
             # Extract output tokens
-            if torch.is_tensor(generated_ids):
-                output_ids = generated_ids[0][input_length:].tolist()
-            else:
-                output_ids = generated_ids[0][input_length:].tolist()
-
+            output_ids = generated_ids[0][input_length:].tolist()
             response = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
-            logger.info(f"Generated response: {response}")
+
+            # Handle structured output
+            if response_format and response_format.get("type") == "json_schema":
+                response = extract_json_from_response(response)
+                logger.info(f"Extracted JSON response: {response}")
+
+                # Validate JSON
+                try:
+                    json.loads(response)
+                except json.JSONDecodeError:
+                    logger.warning("Generated response is not valid JSON, attempting to fix...")
+                    # Try to extract just the JSON part
+                    json_match = re.search(r'\{.*\}', response)
+                    if json_match:
+                        response = json_match.group(0)
+                    else:
+                        response = '{"type": "other"}' # Fallback
+
+            logger.info(f"Final response: {response}")
 
         except Exception as e:
             logger.error(f"Error extracting response: {str(e)}")
-            # Fallback: decode entire generated sequence
-            try:
-                if torch.is_tensor(generated_ids):
-                    response = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
-                else:
-                    response = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
-                # Remove the original prompt from response
-                if text in response:
-                    response = response.replace(text, "").strip()
-                logger.info(f"Fallback response: {response}")
-            except Exception as e2:
-                logger.error(f"Fallback extraction also failed: {str(e2)}")
-                response = "Error extracting response"
+            response = "Error extracting response"
 
         # Clean up response
         if not response:
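
A small sketch (again not from the commit) of the validate-then-repair step added above, applied to an invented response that wraps the JSON in extra prose:

import json
import re

response = 'Sure! Here is the classification: {"type": "examschedule"} Hope that helps.'

try:
    json.loads(response)
except json.JSONDecodeError:
    # Same repair as the diff: keep only the JSON object, otherwise fall back to a default.
    match = re.search(r'\{.*\}', response)
    response = match.group(0) if match else '{"type": "other"}'

print(response)  # {"type": "examschedule"}
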
@@ -333,56 +386,6 @@ async def chat_completions(request: Dict[str, Any]):
         "model": "qwen3-1.7b"
     }
 
-@app.post("/generate")
-async def simple_generate(request: Dict[str, Any]):
-    """Simple generate endpoint for testing"""
-    try:
-        text = request.get("text", "")
-        model_name = request.get("model", "qwen3-1.7b")
-        max_tokens = request.get("max_tokens", 50)
-        temperature = request.get("temperature", 0.7)
-
-        if not text:
-            raise HTTPException(status_code=400, detail="Text cannot be empty")
-
-        # Determine model key
-        if "4b" in model_name.lower():
-            model_key = "qwen3-4b"
-        else:
-            model_key = "qwen3-1.7b"
-
-        # Load model if needed
-        if model_key not in models:
-            load_model_on_demand(model_key)
-
-        tokenizer = tokenizers[model_key]
-        model = models[model_key]
-
-        # Simple generation
-        inputs = tokenizer(text, return_tensors="pt")
-        if hasattr(model, 'device'):
-            inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=max_tokens,
-                temperature=temperature,
-                do_sample=True if temperature > 0 else False,
-                pad_token_id=tokenizer.eos_token_id
-            )
-
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        return {
-            "generated_text": response,
-            "model": model_key,
-            "input_text": text
-        }
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
-
 @app.get("/health")
 def health():
     """Simple health check"""