# Text_To_Sql_Converter_HF / model_utils.py

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import logging
import os
import gc

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TextToSQLModel:
    """Text-to-SQL model wrapper for deployment"""

    def __init__(self, model_dir="./final-model", base_model="Salesforce/codet5-base"):
        self.model_dir = model_dir
        self.base_model = base_model
        self.max_length = 128
        self.model = None
        self.tokenizer = None
        self._load_model()

    def _load_model(self):
        """Load the trained model and tokenizer with optimizations for HF Spaces"""
        try:
            # Check that the model directory exists before attempting to load
            if not os.path.exists(self.model_dir):
                raise FileNotFoundError(f"Model directory {self.model_dir} not found")

            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_dir,
                trust_remote_code=True,
                use_fast=True
            )

            logger.info("Loading base model...")
            # Force CPU and float32 for stability and compatibility on HF Spaces
            device = "cpu"
            torch_dtype = torch.float32

            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                self.base_model,
                torch_dtype=torch_dtype,
                device_map=device,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            logger.info("Loading PEFT adapter...")
            self.model = PeftModel.from_pretrained(
                base_model,
                self.model_dir,
                torch_dtype=torch_dtype,
                device_map=device
            )

            # Ensure everything is on the target device and in eval mode
            self.model = self.model.to(device)
            self.model.eval()

            # Free any transient memory used during loading
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            # Clean up partially-initialized state on failure
            self.model = None
            self.tokenizer = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
            raise

    def predict(self, question: str, table_headers: list) -> str:
        """
        Generate a SQL query for a given question and table headers.

        Args:
            question (str): Natural language question
            table_headers (list): List of table column names

        Returns:
            str: Generated SQL query

        See the __main__ block at the bottom of this file for example usage.
        """
        try:
            if self.model is None or self.tokenizer is None:
                raise RuntimeError("Model not properly loaded")

            # Format the input prompt the same way as during training
            table_headers_str = ", ".join(table_headers)
            input_text = f"### Table columns:\n{table_headers_str}\n### Question:\n{question}\n### SQL:"

            # Tokenize input
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length
            )

            # Generate a prediction without tracking gradients to save memory
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=self.max_length,
                    num_beams=1,  # Greedy decoding for speed
                    do_sample=False,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            # Decode the prediction
            sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Release tensors eagerly
            del inputs, outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return sql_query.strip()
        except Exception as e:
            logger.error(f"Error generating SQL: {str(e)}")
            raise

    def batch_predict(self, queries: list) -> list:
        """
        Generate SQL queries for multiple questions.

        Args:
            queries (list): List of dicts with 'question' and 'table_headers' keys
                (see the __main__ block below for an example payload)

        Returns:
            list: List of result dicts containing the generated SQL and a
                'status' flag ('success' or 'error')
        """
        results = []
        for query in queries:
            try:
                sql = self.predict(query['question'], query['table_headers'])
                results.append({
                    'question': query['question'],
                    'table_headers': query['table_headers'],
                    'sql': sql,
                    'status': 'success'
                })
            except Exception as e:
                logger.error(f"Error in batch prediction for query '{query['question']}': {str(e)}")
                results.append({
                    'question': query['question'],
                    'table_headers': query['table_headers'],
                    'sql': None,
                    'status': 'error',
                    'error': str(e)
                })
        return results

    def health_check(self) -> bool:
        """Check if model is loaded and ready"""
        return (self.model is not None and
                self.tokenizer is not None and
                hasattr(self.model, 'generate'))


# Global model instance
_model_instance = None


def get_model():
    """Get or create global model instance"""
    global _model_instance
    if _model_instance is None:
        _model_instance = TextToSQLModel()
    return _model_instance
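

# The demo below is a minimal usage sketch, not part of the deployed app. It
# assumes a trained adapter exists at ./final-model, and the questions and
# table headers are illustrative placeholders rather than examples from the
# training data, so the exact SQL produced will vary.
if __name__ == "__main__":
    model = get_model()
    print(f"Health check: {model.health_check()}")

    # Single prediction via predict(question, table_headers)
    sql = model.predict(
        "How many employees earn more than 50000?",
        ["id", "name", "salary", "department"]
    )
    print(f"SQL: {sql}")

    # Batch prediction: each entry needs 'question' and 'table_headers' keys
    results = model.batch_predict([
        {"question": "What is the highest salary?",
         "table_headers": ["id", "name", "salary", "department"]},
        {"question": "List all department names.",
         "table_headers": ["id", "department", "location"]},
    ])
    for result in results:
        print(f"[{result['status']}] {result['question']} -> {result['sql']}")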