# This file contains the endpoint implementations.
# In production, merge this with main.py.
async def mask_fill(request: MaskFillRequest):
    """Predict replacements for [MASK] tokens in French legal text.

    Builds a Hugging Face fill-mask pipeline from the preloaded JuriBERT
    MLM model and returns the top-k candidate completions, each with its
    completed sequence, score, and predicted token.

    NOTE(review): the pipeline is reconstructed on every request; if this
    endpoint is hot, consider building it once at startup.

    Raises:
        HTTPException: 500 carrying the underlying error message.
    """
    try:
        tokenizer = tokenizers['juribert_base']
        model = models['juribert_mlm']

        fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)
        raw_predictions = fill_mask(request.text, top_k=request.top_k)

        formatted = [
            {
                "sequence": item['sequence'],
                "score": item['score'],
                "token": item['token_str'],
            }
            for item in raw_predictions
        ]
        return {"input": request.text, "predictions": formatted}
    except Exception as e:
        logger.error(f"Mask fill error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def generate_embeddings(request: EmbeddingRequest):
    """Generate [CLS]-token embeddings for a list of French legal texts.

    Each text is tokenized (truncated to 512 tokens), passed through the
    JuriBERT encoder with gradient tracking disabled, and reduced to the
    hidden state at position 0 (the [CLS] token).

    Returns:
        dict with "embeddings" (list of float lists, one per input text)
        and "dimension" (length of each vector, 0 when no texts given).

    Raises:
        HTTPException: 500 carrying the underlying error message.
    """
    try:
        tokenizer = tokenizers['juribert_base']
        model = models['juribert_base']

        vectors = []
        for text in request.texts:
            encoded = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True,
            )
            with torch.no_grad():
                output = model(**encoded)
            # [CLS] sits at sequence position 0; squeeze drops the
            # singleton batch axis before converting to a plain list.
            # NOTE(review): .numpy() assumes the model lives on CPU.
            cls_vector = output.last_hidden_state[:, 0, :].squeeze().numpy()
            vectors.append(cls_vector.tolist())

        return {
            "embeddings": vectors,
            "dimension": len(vectors[0]) if vectors else 0,
        }
    except Exception as e:
        logger.error(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def extract_entities(request: NERRequest):
    """Extract named entities from French legal text.

    Runs the preloaded CamemBERT NER pipeline and returns each entity
    with its surface text, entity type, confidence score, and character
    offsets into the input text.

    Raises:
        HTTPException: 500 carrying the underlying error message.
    """
    try:
        # Use CamemBERT NER model
        ner_pipeline = models['camembert_ner']
        entities = ner_pipeline(request.text)

        # Format results. HF token-classification pipelines return numpy
        # scalars (np.float32 score, numpy int offsets), which the default
        # JSON encoder cannot serialize — cast to builtins explicitly.
        formatted_entities = []
        for entity in entities:
            formatted_entities.append({
                "text": entity['word'],
                "type": entity['entity_group'],
                "score": float(entity['score']),
                "start": int(entity['start']),
                "end": int(entity['end']),
            })

        return {
            "entities": formatted_entities,
            "text": request.text
        }
    except Exception as e:
        logger.error(f"NER error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def question_answering(request: QARequest):
    """Answer questions about French legal documents.

    Placeholder implementation: echoes the question back with a fixed
    "not yet available" answer until a fine-tuned QA model is deployed.

    Raises:
        HTTPException: 500 carrying the underlying error message.
    """
    try:
        # Simple implementation for now; in production, use a
        # fine-tuned QA model.
        placeholder = {
            "question": request.question,
            "answer": "This feature requires a fine-tuned QA model. Please check back later.",
            "confidence": 0.0,
            "relevant_articles": [],
            "explanation": "QA model is being fine-tuned on French legal data",
        }
        return placeholder
    except Exception as e:
        logger.error(f"QA error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def classify_document(request: ClassificationRequest):
    """Classify French legal documents via keyword matching.

    Counts how many of each category's keywords occur in the lowercased
    text and picks the category with the most hits; falls back to
    "general" when nothing matches.

    NOTE(review): matching is plain substring `in` — "salarié" also hits
    "salariés", but "sas" can hit inside unrelated words; presumably the
    loose matching is intentional for inflected forms — confirm.

    Raises:
        HTTPException: 500 carrying the underlying error message.
    """
    try:
        text_lower = request.text.lower()

        categories = {
            "contract": ["contrat", "accord", "convention", "parties"],
            "litigation": ["tribunal", "jugement", "litige", "procès"],
            "corporate": ["société", "sarl", "sas", "entreprise"],
            "employment": ["travail", "salarié", "employeur", "licenciement"],
        }

        # Keep only categories with at least one keyword hit.
        scores = {
            category: hits
            for category, keywords in categories.items()
            if (hits := sum(kw in text_lower for kw in keywords)) > 0
        }

        primary_category = max(scores, key=scores.get) if scores else "general"

        return {
            "primary_category": primary_category,
            "categories": [
                {"category": cat, "score": score}
                for cat, score in scores.items()
            ],
            "confidence": 0.8 if scores else 0.5,
            "document_type": "legal_document",
            "legal_domain": primary_category,
        }
    except Exception as e:
        logger.error(f"Classification error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def health_check():
    """Health check endpoint.

    Returns service status, the keys of the currently loaded models,
    and a timezone-aware UTC timestamp in ISO-8601 form.
    """
    # Local import: this chunk doesn't show the file's import block, and
    # `timezone` may not already be in scope there.
    from datetime import timezone

    return {
        "status": "healthy",
        "models_loaded": list(models.keys()),
        # datetime.utcnow() is deprecated (3.12+) and returns a naive
        # datetime; use an aware UTC timestamp instead. The ISO string
        # now carries an explicit "+00:00" offset.
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
# Add this to main.py when deploying.