# SobroJuriBert / main_endpoints.py
# This file contains the endpoint implementations.
# In production, merge this with main.py, which is expected to define
# `app`, `logger`, and the `models` / `tokenizers` registries used below.

import torch
from datetime import datetime

from fastapi import HTTPException
from transformers import pipeline

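# If main.py does not already define them, here is a minimal sketch of the
# request models used below, inferred from the fields each endpoint reads
# (defaults such as top_k=5 are assumptions, not confirmed values):

from typing import List

from pydantic import BaseModel


class MaskFillRequest(BaseModel):
    text: str
    top_k: int = 5


class EmbeddingRequest(BaseModel):
    texts: List[str]


class NERRequest(BaseModel):
    text: str


class QARequest(BaseModel):
    question: str


class ClassificationRequest(BaseModel):
    text: str
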
@app.post("/mask-fill")
async def mask_fill(request: MaskFillRequest):
"""Fill [MASK] tokens in French legal text"""
try:
tokenizer = tokenizers['juribert_base']
model = models['juribert_mlm']
# Create pipeline
fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)
# Get predictions
predictions = fill_mask(request.text, top_k=request.top_k)
return {
"input": request.text,
"predictions": [
{
"sequence": pred['sequence'],
"score": pred['score'],
"token": pred['token_str']
}
for pred in predictions
]
}
except Exception as e:
logger.error(f"Mask fill error: {e}")
raise HTTPException(status_code=500, detail=str(e))
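# Example call (illustrative; "[MASK]" follows the docstring above -- a
# CamemBERT-style tokenizer would expect "<mask>" instead):
#   POST /mask-fill
#   {"text": "Le contrat est signé par les deux [MASK].", "top_k": 5}
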
@app.post("/embeddings")
async def generate_embeddings(request: EmbeddingRequest):
"""Generate embeddings for French legal texts"""
try:
tokenizer = tokenizers['juribert_base']
model = models['juribert_base']
embeddings = []
for text in request.texts:
# Tokenize
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True
)
# Generate embeddings
with torch.no_grad():
outputs = model(**inputs)
# Use CLS token embedding
embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
embeddings.append(embedding.tolist())
return {
"embeddings": embeddings,
"dimension": len(embeddings[0]) if embeddings else 0
}
except Exception as e:
logger.error(f"Embedding error: {e}")
raise HTTPException(status_code=500, detail=str(e))
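# Example client call (hypothetical host/port; a sketch, not part of this file):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/embeddings",
#       json={"texts": ["Le bail commercial est résilié.", "La société est dissoute."]},
#   )
#   vectors = resp.json()["embeddings"]  # one vector per input text
#                                        # (768-dim for a BERT-base-sized model)
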
@app.post("/ner")
async def extract_entities(request: NERRequest):
"""Extract named entities from French legal text"""
try:
# Use CamemBERT NER model
ner_pipeline = models['camembert_ner']
entities = ner_pipeline(request.text)
# Format results
formatted_entities = []
for entity in entities:
formatted_entities.append({
"text": entity['word'],
"type": entity['entity_group'],
"score": entity['score'],
"start": entity['start'],
"end": entity['end']
})
return {
"entities": formatted_entities,
"text": request.text
}
except Exception as e:
logger.error(f"NER error: {e}")
raise HTTPException(status_code=500, detail=str(e))
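# Example response shape (illustrative; entity labels depend on the checkpoint,
# PER/ORG/LOC/MISC being typical for French CamemBERT NER models):
#   {"entities": [{"text": "Société Générale", "type": "ORG", "score": 0.99,
#                  "start": 14, "end": 30}], "text": "..."}
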
@app.post("/qa")
async def question_answering(request: QARequest):
"""Answer questions about French legal documents"""
try:
# Simple implementation for now
# In production, use a fine-tuned QA model
return {
"question": request.question,
"answer": "This feature requires a fine-tuned QA model. Please check back later.",
"confidence": 0.0,
"relevant_articles": [],
"explanation": "QA model is being fine-tuned on French legal data"
}
except Exception as e:
logger.error(f"QA error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/classify")
async def classify_document(request: ClassificationRequest):
"""Classify French legal documents"""
try:
# Simple keyword-based classification for now
text_lower = request.text.lower()
categories = {
"contract": ["contrat", "accord", "convention", "parties"],
"litigation": ["tribunal", "jugement", "litige", "procès"],
"corporate": ["société", "sarl", "sas", "entreprise"],
"employment": ["travail", "salarié", "employeur", "licenciement"]
}
scores = {}
for category, keywords in categories.items():
score = sum(1 for kw in keywords if kw in text_lower)
if score > 0:
scores[category] = score
if not scores:
primary_category = "general"
else:
primary_category = max(scores, key=scores.get)
return {
"primary_category": primary_category,
"categories": [{"category": cat, "score": score} for cat, score in scores.items()],
"confidence": 0.8 if scores else 0.5,
"document_type": "legal_document",
"legal_domain": primary_category
}
except Exception as e:
logger.error(f"Classification error: {e}")
raise HTTPException(status_code=500, detail=str(e))
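# Worked example of the scoring above: "Le salarié conteste son licenciement
# devant le tribunal" matches "salarié" and "licenciement" (employment, score 2)
# and "tribunal" (litigation, score 1), so primary_category is "employment".
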
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"models_loaded": list(models.keys()),
"timestamp": datetime.utcnow().isoformat()
}
# Add this to main.py when deploying
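# A minimal sketch of the registries main.py is expected to build at startup.
# The checkpoint identifiers are placeholders, not confirmed paths:
#
#   from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, pipeline
#
#   JURIBERT = "path/to/juribert-base"        # placeholder: actual JuriBERT checkpoint
#   CAMEMBERT_NER = "path/to/camembert-ner"   # placeholder: actual French NER checkpoint
#
#   tokenizers = {"juribert_base": AutoTokenizer.from_pretrained(JURIBERT)}
#   models = {
#       "juribert_base": AutoModel.from_pretrained(JURIBERT),
#       "juribert_mlm": AutoModelForMaskedLM.from_pretrained(JURIBERT),
#       # aggregation_strategy yields grouped spans with 'entity_group' keys,
#       # which the /ner endpoint relies on
#       "camembert_ner": pipeline("ner", model=CAMEMBERT_NER, aggregation_strategy="simple"),
#   }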