import gc
import logging
import os

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TextToSQLModel:
    """Text-to-SQL model wrapper for deployment."""

    def __init__(self, model_dir="./final-model", base_model="Salesforce/codet5-base"):
        self.model_dir = model_dir
        self.base_model = base_model
        self.max_length = 128
        self.model = None
        self.tokenizer = None
        self._load_model()
    def _load_model(self):
        """Load the trained model and tokenizer with optimizations for HF Spaces."""
        try:
            # Check that the adapter directory exists before loading anything
            if not os.path.exists(self.model_dir):
                raise FileNotFoundError(f"Model directory {self.model_dir} not found")

            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_dir,
                trust_remote_code=True,
                use_fast=True
            )

            logger.info("Loading base model...")
            # Force CPU and float32: CPU keeps HF Spaces stable, and
            # float32 avoids dtype-compatibility issues on CPU
            device = "cpu"
            torch_dtype = torch.float32

            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                self.base_model,
                torch_dtype=torch_dtype,
                device_map=device,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            logger.info("Loading PEFT model...")
            self.model = PeftModel.from_pretrained(
                base_model,
                self.model_dir,
                torch_dtype=torch_dtype,
                device_map=device
            )

            # Ensure everything is on the target device and in eval mode
            self.model = self.model.to(device)
            self.model.eval()

            # Clear caches to free memory after loading
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            # Clean up partially loaded state on error
            self.model = None
            self.tokenizer = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
            raise
    def predict(self, question: str, table_headers: list) -> str:
        """
        Generate a SQL query for a given question and table headers.

        Args:
            question (str): Natural language question
            table_headers (list): List of table column names

        Returns:
            str: Generated SQL query
        """
        try:
            if self.model is None or self.tokenizer is None:
                raise RuntimeError("Model not properly loaded")

            # Build the prompt from the column list and the question
            table_headers_str = ", ".join(table_headers)
            input_text = f"### Table columns:\n{table_headers_str}\n### Question:\n{question}\n### SQL:"

            # Tokenize input
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length
            )

            # Generate prediction with memory optimization
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=self.max_length,
                    num_beams=1,  # Greedy decoding for speed
                    do_sample=False,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            # Decode prediction
            sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Free tensors immediately
            del inputs, outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return sql_query.strip()
        except Exception as e:
            logger.error(f"Error generating SQL: {str(e)}")
            raise
    def batch_predict(self, queries: list) -> list:
        """
        Generate SQL queries for multiple questions.

        Args:
            queries (list): List of dicts with 'question' and 'table_headers' keys

        Returns:
            list: List of result dicts containing the generated SQL and a status flag
        """
        results = []
        for query in queries:
            try:
                sql = self.predict(query['question'], query['table_headers'])
                results.append({
                    'question': query['question'],
                    'table_headers': query['table_headers'],
                    'sql': sql,
                    'status': 'success'
                })
            except Exception as e:
                logger.error(f"Error in batch prediction for query '{query['question']}': {str(e)}")
                results.append({
                    'question': query['question'],
                    'table_headers': query['table_headers'],
                    'sql': None,
                    'status': 'error',
                    'error': str(e)
                })
        return results
    def health_check(self) -> bool:
        """Check whether the model is loaded and ready."""
        return (self.model is not None and
                self.tokenizer is not None and
                hasattr(self.model, 'generate'))
# Global model instance (lazily initialized)
_model_instance = None


def get_model():
    """Get or create the global model instance."""
    global _model_instance
    if _model_instance is None:
        _model_instance = TextToSQLModel()
    return _model_instance
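

# A minimal usage sketch for local smoke-testing. This block is illustrative,
# not part of the deployed app: the question and column names below are
# hypothetical, and it assumes a trained adapter exists at ./final-model
# as configured above.
if __name__ == "__main__":
    model = get_model()
    if model.health_check():
        # Single prediction with made-up table headers and question
        sql = model.predict(
            "How many employees are in the sales department?",
            ["id", "name", "department", "salary"]
        )
        print(f"SQL: {sql}")

        # Batch prediction over a list of {'question', 'table_headers'} dicts
        results = model.batch_predict([
            {"question": "What is the highest salary?",
             "table_headers": ["id", "name", "department", "salary"]}
        ])
        for result in results:
            print(result["status"], result["sql"])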