Manju080 committed on
Commit 6416f7d · Parent: e161cb9

Fix the Optimization

Files changed (5):
  1. README.md +1 -1
  2. app.py +63 -23
  3. model_utils.py +65 -8
  4. requirements.txt +8 -6
  5. startup_test.py +136 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🗄️
 colorFrom: blue
 colorTo: purple
 sdk: gradio
-sdk_version: 4.0.0
+sdk_version: 5.35.0
 app_file: app.py
 pinned: false
 ---
app.py CHANGED
@@ -5,9 +5,9 @@ from pydantic import BaseModel
 from typing import List, Optional
 import uvicorn
 import logging
-from model_utils import get_model
 import time
 import os
+import asyncio
 from contextlib import asynccontextmanager
 
 # Configure logging
@@ -16,18 +16,44 @@ logger = logging.getLogger(__name__)
 
 # Global model instance
 model = None
+model_loading = False
+model_load_error = None
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # Startup
-    global model
+    global model, model_loading, model_load_error
     logger.info("Starting Text-to-SQL API...")
+
+    # Start model loading in background
+    model_loading = True
+    model_load_error = None
+
     try:
-        model = get_model()
-        logger.info("Model loaded successfully!")
+        # Import here to avoid startup delays
+        from model_utils import get_model
+
+        # Set a timeout for model loading (5 minutes)
+        try:
+            # Run model loading in a thread to avoid blocking
+            import concurrent.futures
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                future = executor.submit(get_model)
+                model = future.result(timeout=300)  # 5 minute timeout
+            logger.info("Model loaded successfully!")
+        except concurrent.futures.TimeoutError:
+            logger.error("Model loading timed out after 5 minutes")
+            model_load_error = "Model loading timed out"
+        except Exception as e:
+            logger.error(f"Failed to load model: {str(e)}")
+            model_load_error = str(e)
+
     except Exception as e:
-        logger.error(f"Failed to load model: {str(e)}")
-        raise
+        logger.error(f"Failed to import model_utils: {str(e)}")
+        model_load_error = f"Import error: {str(e)}"
+    finally:
+        model_loading = False
+
     yield
     # Shutdown
     logger.info("Shutting down Text-to-SQL API...")
@@ -62,12 +88,10 @@ class BatchResponse(BaseModel):
 class HealthResponse(BaseModel):
     status: str
     model_loaded: bool
+    model_loading: bool
+    model_error: Optional[str] = None
     timestamp: float
 
-
-
-
-
 @app.get("/", response_class=HTMLResponse)
 async def root():
     """Serve the main HTML interface"""
@@ -111,8 +135,14 @@ async def predict_sql(request: SQLRequest):
     Returns:
         SQLResponse with generated SQL query
     """
+    global model, model_loading, model_load_error
+
+    if model_loading:
+        raise HTTPException(status_code=503, detail="Model is still loading, please try again in a few minutes")
+
     if model is None:
-        raise HTTPException(status_code=503, detail="Model not loaded")
+        error_msg = model_load_error or "Model not loaded"
+        raise HTTPException(status_code=503, detail=f"Model not available: {error_msg}")
 
     start_time = time.time()
 
@@ -142,8 +172,14 @@ async def batch_predict(request: BatchRequest):
     Returns:
         BatchResponse with generated SQL queries
     """
+    global model, model_loading, model_load_error
+
+    if model_loading:
+        raise HTTPException(status_code=503, detail="Model is still loading, please try again in a few minutes")
+
     if model is None:
-        raise HTTPException(status_code=503, detail="Model not loaded")
+        error_msg = model_load_error or "Model not loaded"
+        raise HTTPException(status_code=503, detail=f"Model not available: {error_msg}")
 
     start_time = time.time()
 
@@ -197,17 +233,28 @@ async def health_check():
     Returns:
         HealthResponse with service status
     """
+    global model, model_loading, model_load_error
+
     model_loaded = model is not None and model.health_check()
 
+    if model_loaded:
+        status = "healthy"
+    elif model_loading:
+        status = "loading"
+    else:
+        status = "unhealthy"
+
     return HealthResponse(
-        status="healthy" if model_loaded else "unhealthy",
+        status=status,
         model_loaded=model_loaded,
+        model_loading=model_loading,
+        model_error=model_load_error,
         timestamp=time.time()
     )
 
 @app.get("/example")
 async def get_example():
-    """Get example request format"""
+    """Get example usage"""
     return {
         "example_request": {
             "question": "How many employees are older than 30?",
@@ -217,16 +264,9 @@ async def get_example():
             "question": "How many employees are older than 30?",
             "table_headers": ["id", "name", "age", "department", "salary"],
             "sql_query": "SELECT COUNT(*) FROM table WHERE age > 30",
-            "processing_time": 0.123
+            "processing_time": 0.5
        }
    }
 
 if __name__ == "__main__":
-    # Run the application
-    uvicorn.run(
-        "app:app",
-        host="0.0.0.0",
-        port=8000,
-        reload=False,
-        log_level="info"
-    )
+    uvicorn.run(app, host="0.0.0.0", port=8000)
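Note on this file: the server now starts serving before the model is ready, and /predict returns 503 while loading is in progress, so callers should poll /health until `status` becomes "healthy". A minimal client sketch follows; `BASE_URL` and the `requests` dependency are assumptions (requests is not in requirements.txt), and the /predict body is inferred from the /example payload shown above.

```python
# Client-side sketch: wait for the model to finish loading, then call /predict.
# BASE_URL is a placeholder; `requests` is an assumed extra dependency.
import time
import requests

BASE_URL = "http://localhost:8000"

def wait_until_ready(timeout: float = 360.0, interval: float = 5.0) -> None:
    """Poll /health until status is 'healthy'; surface model_error if loading failed."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        health = requests.get(f"{BASE_URL}/health", timeout=10).json()
        if health["status"] == "healthy":
            return
        if health["status"] == "unhealthy" and health.get("model_error"):
            raise RuntimeError(f"Model failed to load: {health['model_error']}")
        time.sleep(interval)  # status is "loading" (or not yet settled); retry
    raise TimeoutError("model did not become ready in time")

wait_until_ready()
response = requests.post(
    f"{BASE_URL}/predict",
    json={
        "question": "How many employees are older than 30?",
        "table_headers": ["id", "name", "age", "department", "salary"],
    },
    timeout=60,
)
print(response.json())
```

One caveat on the lifespan code itself: `future.result(timeout=300)` raises after five minutes, but exiting the `with ThreadPoolExecutor()` block calls `shutdown(wait=True)`, which waits for the worker thread to finish. A truly hung `get_model()` can therefore still stall startup past the nominal timeout.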
model_utils.py CHANGED
@@ -2,6 +2,8 @@ import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from peft import PeftModel
 import logging
+import os
+import gc
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -19,22 +21,59 @@ class TextToSQLModel:
         self._load_model()
 
     def _load_model(self):
-        """Load the trained model and tokenizer"""
+        """Load the trained model and tokenizer with optimizations for HF Spaces"""
         try:
+            # Check if model directory exists
+            if not os.path.exists(self.model_dir):
+                raise FileNotFoundError(f"Model directory {self.model_dir} not found")
+
             logger.info("Loading tokenizer...")
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.model_dir,
+                trust_remote_code=True,
+                use_fast=True
+            )
 
             logger.info("Loading base model...")
-            base_model = AutoModelForSeq2SeqLM.from_pretrained(self.base_model)
+            # Use lower precision and CPU if needed for memory optimization
+            device = "cpu"  # Force CPU for HF Spaces stability
+            torch_dtype = torch.float32  # Use float32 for better compatibility
+
+            base_model = AutoModelForSeq2SeqLM.from_pretrained(
+                self.base_model,
+                torch_dtype=torch_dtype,
+                device_map=device,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True
+            )
 
             logger.info("Loading PEFT model...")
-            self.model = PeftModel.from_pretrained(base_model, self.model_dir)
+            self.model = PeftModel.from_pretrained(
+                base_model,
+                self.model_dir,
+                torch_dtype=torch_dtype,
+                device_map=device
+            )
+
+            # Move to CPU and set to eval mode
+            self.model = self.model.to(device)
             self.model.eval()
 
+            # Clear cache to free memory
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
+
             logger.info("Model loaded successfully!")
 
         except Exception as e:
             logger.error(f"Error loading model: {str(e)}")
+            # Clean up on error
+            self.model = None
+            self.tokenizer = None
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
             raise
 
     def predict(self, question: str, table_headers: list) -> str:
@@ -49,6 +88,9 @@ class TextToSQLModel:
             str: Generated SQL query
         """
         try:
+            if self.model is None or self.tokenizer is None:
+                raise RuntimeError("Model not properly loaded")
+
             # Format input text
             table_headers_str = ", ".join(table_headers)
             input_text = f"### Table columns:\n{table_headers_str}\n### Question:\n{question}\n### SQL:"
@@ -62,14 +104,26 @@ class TextToSQLModel:
                 max_length=self.max_length
             )
 
-            # Generate prediction
+            # Generate prediction with memory optimization
             with torch.no_grad():
-                outputs = self.model.generate(**inputs, max_length=self.max_length)
+                outputs = self.model.generate(
+                    **inputs,
+                    max_length=self.max_length,
+                    num_beams=1,  # Use greedy decoding for speed
+                    do_sample=False,
+                    pad_token_id=self.tokenizer.pad_token_id,
+                    eos_token_id=self.tokenizer.eos_token_id
+                )
 
             # Decode prediction
             sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-            return sql_query
+            # Clean up
+            del inputs, outputs
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+            return sql_query.strip()
 
         except Exception as e:
             logger.error(f"Error generating SQL: {str(e)}")
@@ -96,6 +150,7 @@ class TextToSQLModel:
                 'status': 'success'
             })
         except Exception as e:
+            logger.error(f"Error in batch prediction for query '{query['question']}': {str(e)}")
             results.append({
                 'question': query['question'],
                 'table_headers': query['table_headers'],
@@ -108,7 +163,9 @@ class TextToSQLModel:
 
     def health_check(self) -> bool:
         """Check if model is loaded and ready"""
-        return self.model is not None and self.tokenizer is not None
+        return (self.model is not None and
+                self.tokenizer is not None and
+                hasattr(self.model, 'generate'))
 
 # Global model instance
 _model_instance = None
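Note on this file: since the LoRA adapter stays attached, every `generate()` call routes through PEFT's adapter layers. If the adapter never needs to be swapped at runtime, one further step this commit does not take is merging the adapter into the base weights with PEFT's `merge_and_unload()`. A sketch under that assumption; `base_model_name` and `adapter_dir` are hypothetical stand-ins for `self.base_model` and `self.model_dir`:

```python
# Sketch: fold the LoRA adapter into the base model once at load time so
# inference runs on a plain seq2seq model with no PEFT indirection.
import torch
from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel

def load_merged_model(base_model_name: str, adapter_dir: str):
    base = AutoModelForSeq2SeqLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    peft_model = PeftModel.from_pretrained(base, adapter_dir)
    merged = peft_model.merge_and_unload()  # base model with LoRA weights merged in
    merged.eval()
    return merged
```

The trade-off: a one-time merge cost and no adapter hot-swapping, in exchange for dropping the per-call LoRA matmuls, which is a small but free latency win on CPU.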
requirements.txt CHANGED
@@ -1,8 +1,10 @@
 fastapi==0.104.1
 uvicorn[standard]==0.24.0
-torch>=2.0.0
-transformers>=4.35.0
-peft>=0.6.0
-accelerate>=0.24.0
-pydantic>=2.0.0
-python-multipart>=0.0.6
+torch==2.1.0
+transformers==4.35.0
+peft==0.6.0
+accelerate==0.24.0
+pydantic==2.5.0
+python-multipart==0.0.6
+tokenizers==0.15.0
+safetensors==0.4.0
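Two notes on the exact pins. First, pin sets can conflict internally: transformers 4.35.x declares its own bound on tokenizers (below 0.15, if memory serves), so `tokenizers==0.15.0` is worth verifying against a clean `pip install` before deploying. Second, once everything is pinned, drift is cheap to detect at startup. A standard-library sketch; the `PINS` mapping simply restates requirements.txt:

```python
# Sketch: report any installed package versions that drift from the pins.
from importlib.metadata import PackageNotFoundError, version

PINS = {
    "fastapi": "0.104.1",
    "torch": "2.1.0",
    "transformers": "4.35.0",
    "peft": "0.6.0",
    "accelerate": "0.24.0",
    "pydantic": "2.5.0",
    "tokenizers": "0.15.0",
    "safetensors": "0.4.0",
}

def check_pins() -> list[str]:
    problems = []
    for package, expected in PINS.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            problems.append(f"{package}: not installed")
            continue
        # torch builds may report local tags like "2.1.0+cpu"; compare the prefix
        if installed.split("+")[0] != expected:
            problems.append(f"{package}: pinned {expected}, installed {installed}")
    return problems

if __name__ == "__main__":
    for problem in check_pins():
        print(problem)
```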
startup_test.py ADDED
@@ -0,0 +1,136 @@
+#!/usr/bin/env python3
+"""
+Startup test script for Hugging Face Spaces deployment
+This script helps debug model loading issues
+"""
+
+import os
+import sys
+import time
+import logging
+import traceback
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+def test_imports():
+    """Test if all required packages can be imported"""
+    logger.info("Testing imports...")
+
+    try:
+        import torch
+        logger.info(f"PyTorch version: {torch.__version__}")
+    except ImportError as e:
+        logger.error(f"Failed to import torch: {e}")
+        return False
+
+    try:
+        import transformers
+        logger.info(f"Transformers version: {transformers.__version__}")
+    except ImportError as e:
+        logger.error(f"Failed to import transformers: {e}")
+        return False
+
+    try:
+        import peft
+        logger.info(f"PEFT version: {peft.__version__}")
+    except ImportError as e:
+        logger.error(f"Failed to import peft: {e}")
+        return False
+
+    try:
+        import fastapi
+        logger.info(f"FastAPI version: {fastapi.__version__}")
+    except ImportError as e:
+        logger.error(f"Failed to import fastapi: {e}")
+        return False
+
+    return True
+
+def test_model_files():
+    """Test if model files exist"""
+    logger.info("Testing model files...")
+
+    model_dir = "./final-model"
+    required_files = [
+        "adapter_config.json",
+        "adapter_model.safetensors",
+        "tokenizer.json",
+        "tokenizer_config.json",
+        "vocab.json"
+    ]
+
+    if not os.path.exists(model_dir):
+        logger.error(f"Model directory {model_dir} does not exist")
+        return False
+
+    missing_files = []
+    for file in required_files:
+        file_path = os.path.join(model_dir, file)
+        if not os.path.exists(file_path):
+            missing_files.append(file)
+        else:
+            size = os.path.getsize(file_path)
+            logger.info(f"✓ {file} exists ({size} bytes)")
+
+    if missing_files:
+        logger.error(f"Missing required files: {missing_files}")
+        return False
+
+    return True
+
+def test_model_loading():
+    """Test model loading with timeout"""
+    logger.info("Testing model loading...")
+
+    try:
+        from model_utils import get_model
+
+        start_time = time.time()
+        model = get_model()
+        load_time = time.time() - start_time
+
+        logger.info(f"Model loaded successfully in {load_time:.2f} seconds")
+
+        # Test a simple prediction
+        test_question = "How many records are there?"
+        test_headers = ["id", "name", "age"]
+
+        start_time = time.time()
+        result = model.predict(test_question, test_headers)
+        predict_time = time.time() - start_time
+
+        logger.info(f"Test prediction successful in {predict_time:.2f} seconds")
+        logger.info(f"Generated SQL: {result}")
+
+        return True
+
+    except Exception as e:
+        logger.error(f"Model loading failed: {e}")
+        logger.error(traceback.format_exc())
+        return False
+
+def main():
+    """Run all tests"""
+    logger.info("Starting Hugging Face Spaces deployment tests...")
+
+    # Test 1: Imports
+    if not test_imports():
+        logger.error("Import test failed")
+        sys.exit(1)
+
+    # Test 2: Model files
+    if not test_model_files():
+        logger.error("Model files test failed")
+        sys.exit(1)
+
+    # Test 3: Model loading
+    if not test_model_loading():
+        logger.error("Model loading test failed")
+        sys.exit(1)
+
+    logger.info("All tests passed! Ready for deployment.")
+
+if __name__ == "__main__":
+    main()
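Note on this file: the docstring on `test_model_loading()` promises a timeout, but the body calls `get_model()` directly, so a hung download stalls the script indefinitely. A sketch of the same guard app.py's lifespan() uses, adapted for the standalone test; it assumes it lives in startup_test.py so `logger` and `test_model_loading` are in scope:

```python
# Sketch: bound test_model_loading with the ThreadPoolExecutor pattern from
# app.py. Note the hung worker thread cannot be killed: shutdown(wait=False)
# returns control immediately, but Python still joins worker threads at
# interpreter exit, so a subprocess is the stronger isolation if this matters.
import concurrent.futures

def test_model_loading_with_timeout(timeout_s: float = 300.0) -> bool:
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(test_model_loading)
    try:
        return future.result(timeout=timeout_s)
    except concurrent.futures.TimeoutError:
        logger.error(f"Model loading test timed out after {timeout_s:.0f}s")
        return False
    finally:
        executor.shutdown(wait=False)
```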