import os
import json
import logging
from datetime import datetime, timezone
from typing import List, Dict, Any, Optional
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import torch
from transformers import (
AutoTokenizer,
AutoModel,
AutoModelForMaskedLM,
AutoModelForTokenClassification,
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
pipeline
)
import numpy as np
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(
title="SobroJuriBert API",
description="French Legal AI API powered by JuriBERT for comprehensive legal document analysis",
version="1.0.0"
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Global model storage
models = {}
tokenizers = {}
# Pydantic models
class TextRequest(BaseModel):
text: str = Field(..., description="Text to analyze")
class MaskFillRequest(BaseModel):
text: str = Field(..., description="Text with [MASK] tokens")
top_k: int = Field(5, description="Number of predictions to return")
class NERRequest(BaseModel):
text: str = Field(..., description="Legal text for entity extraction")
class QARequest(BaseModel):
context: str = Field(..., description="Legal document context")
question: str = Field(..., description="Question about the document")
class ClassificationRequest(BaseModel):
text: str = Field(..., description="Legal document to classify")
class EmbeddingRequest(BaseModel):
texts: List[str] = Field(..., description="List of texts to embed")
class JurisprudenceSearchRequest(BaseModel):
query: str = Field(..., description="Search query")
filters: Optional[Dict[str, Any]] = Field(None, description="Filters for search")
limit: int = Field(10, description="Number of results")
class ContractAnalysisRequest(BaseModel):
text: str = Field(..., description="Contract text to analyze")
contract_type: Optional[str] = Field(None, description="Type of contract")
@app.on_event("startup")
async def load_models():
"""Load all required models on startup"""
logger.info("Loading French legal models...")
try:
# Load JuriBERT base model for embeddings and mask filling
logger.info("Loading JuriBERT base model...")
models['juribert_base'] = AutoModel.from_pretrained('dascim/juribert-base')
tokenizers['juribert_base'] = AutoTokenizer.from_pretrained('dascim/juribert-base')
models['juribert_mlm'] = AutoModelForMaskedLM.from_pretrained('dascim/juribert-base')
# Load CamemBERT models as fallback/complement
logger.info("Loading CamemBERT models...")
models['camembert_ner'] = pipeline(
'ner',
model='Jean-Baptiste/camembert-ner-with-dates',
aggregation_strategy="simple"
)
        # Load legal-specific models
        logger.info("Loading French legal classification model...")
        # Placeholder: a multilingual sentiment model stands in until a
        # classifier fine-tuned on French legal documents is available; it is
        # loaded here but /classify currently relies on keyword matching.
        models['legal_classifier'] = pipeline(
            'text-classification',
            model='nlptown/bert-base-multilingual-uncased-sentiment'
        )
logger.info("All models loaded successfully!")
except Exception as e:
logger.error(f"Error loading models: {e}")
raise
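
# The encoders above stay on CPU. A minimal sketch of optional GPU placement,
# assuming a single device (not part of the original deployment):
#
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     for name in ('juribert_base', 'juribert_mlm'):
#         models[name] = models[name].to(device)
#
# The /embeddings handler would then send its tokenized inputs to the same
# device, e.g. inputs = {k: v.to(device) for k, v in inputs.items()}.
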
@app.get("/")
async def root():
"""Root endpoint with API information"""
return {
"name": "SobroJuriBert API",
"version": "1.0.0",
"description": "French Legal AI API for lawyers",
"endpoints": {
"mask_fill": "/mask-fill - Fill masked tokens in legal text",
"embeddings": "/embeddings - Generate legal text embeddings",
"ner": "/ner - Extract legal entities",
"qa": "/qa - Answer questions about legal documents",
"classify": "/classify - Classify legal documents",
"analyze_contract": "/analyze-contract - Analyze legal contracts",
"search_jurisprudence": "/search-jurisprudence - Search case law",
"extract_articles": "/extract-articles - Extract legal article references",
"check_compliance": "/check-compliance - Check legal compliance",
"generate_summary": "/generate-summary - Generate legal summaries"
},
"models": {
"base": "dascim/juribert-base",
"ner": "Jean-Baptiste/camembert-ner-with-dates",
"training_data": "6.3GB French legal texts from Légifrance + 100k+ court decisions"
}
}
@app.post("/mask-fill")
async def mask_fill(request: MaskFillRequest):
"""Fill [MASK] tokens in French legal text"""
try:
tokenizer = tokenizers['juribert_base']
model = models['juribert_mlm']
# Create pipeline
fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)
# Get predictions
predictions = fill_mask(request.text, top_k=request.top_k)
return {
"input": request.text,
"predictions": [
{
"sequence": pred['sequence'],
"score": pred['score'],
"token": pred['token_str']
}
for pred in predictions
]
}
except Exception as e:
logger.error(f"Mask fill error: {e}")
raise HTTPException(status_code=500, detail=str(e))
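
# The handler above constructs a new fill-mask pipeline wrapper on every
# request. A minimal sketch of building it once at startup and reusing it
# (assumes load_models() stores it under a hypothetical models['fill_mask'] key):
#
#     models['fill_mask'] = pipeline(
#         'fill-mask',
#         model=models['juribert_mlm'],
#         tokenizer=tokenizers['juribert_base'],
#     )
#
# The endpoint would then call models['fill_mask'](request.text, top_k=request.top_k).
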
@app.post("/embeddings")
async def generate_embeddings(request: EmbeddingRequest):
"""Generate embeddings for French legal texts"""
try:
tokenizer = tokenizers['juribert_base']
model = models['juribert_base']
embeddings = []
for text in request.texts:
# Tokenize
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True
)
# Generate embeddings
with torch.no_grad():
outputs = model(**inputs)
# Use CLS token embedding
embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
embeddings.append(embedding.tolist())
return {
"embeddings": embeddings,
"dimension": len(embeddings[0]) if embeddings else 0
}
except Exception as e:
logger.error(f"Embedding error: {e}")
raise HTTPException(status_code=500, detail=str(e))
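
# The CLS vector of an MLM-pretrained encoder is not trained as a sentence
# representation, so mask-aware mean pooling is a common alternative. A
# minimal sketch (a hypothetical helper, not wired into the endpoint above):
def _mean_pool(outputs, attention_mask):
    """Average token embeddings over real (non-padding) positions."""
    mask = attention_mask.unsqueeze(-1).float()             # (batch, seq, 1)
    summed = (outputs.last_hidden_state * mask).sum(dim=1)  # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)                # avoid div-by-zero
    return summed / counts
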
@app.post("/ner")
async def extract_entities(request: NERRequest):
"""Extract named entities from French legal text"""
try:
# Use CamemBERT NER model
ner_pipeline = models['camembert_ner']
entities = ner_pipeline(request.text)
# Format results
formatted_entities = []
for entity in entities:
formatted_entities.append({
"text": entity['word'],
"type": entity['entity_group'],
"score": entity['score'],
"start": entity['start'],
"end": entity['end']
})
return {
"entities": formatted_entities,
"text": request.text
}
except Exception as e:
logger.error(f"NER error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/qa")
async def question_answering(request: QARequest):
"""Answer questions about French legal documents"""
try:
# Simple implementation for now
# In production, use a fine-tuned QA model
return {
"question": request.question,
"answer": "This feature requires a fine-tuned QA model. Please check back later.",
"confidence": 0.0,
"relevant_articles": [],
"explanation": "QA model is being fine-tuned on French legal data"
}
except Exception as e:
logger.error(f"QA error: {e}")
raise HTTPException(status_code=500, detail=str(e))
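
# A drop-in extractive QA implementation could reuse a question-answering
# pipeline loaded at startup. A minimal sketch; the model name is an
# illustrative choice for French QA, not part of this deployment:
#
#     models['qa'] = pipeline(
#         'question-answering',
#         model='etalab-ia/camembert-base-squadFR-fquad-piaf'
#     )
#     result = models['qa'](question=request.question, context=request.context)
#     # result is a dict with 'answer', 'score', 'start' and 'end'
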
@app.post("/classify")
async def classify_document(request: ClassificationRequest):
"""Classify French legal documents"""
try:
# Simple keyword-based classification for now
text_lower = request.text.lower()
categories = {
"contract": ["contrat", "accord", "convention", "parties"],
"litigation": ["tribunal", "jugement", "litige", "procès"],
"corporate": ["société", "sarl", "sas", "entreprise"],
"employment": ["travail", "salarié", "employeur", "licenciement"]
}
scores = {}
for category, keywords in categories.items():
score = sum(1 for kw in keywords if kw in text_lower)
if score > 0:
scores[category] = score
if not scores:
primary_category = "general"
else:
primary_category = max(scores, key=scores.get)
return {
"primary_category": primary_category,
"categories": [{"category": cat, "score": score} for cat, score in scores.items()],
"confidence": 0.8 if scores else 0.5,
"document_type": "legal_document",
"legal_domain": primary_category
}
except Exception as e:
logger.error(f"Classification error: {e}")
raise HTTPException(status_code=500, detail=str(e))
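
# The keyword heuristic above could be replaced by comparing a document
# embedding against per-category prototype embeddings (both produced the same
# way as in /embeddings). A minimal sketch of the scoring step; the prototype
# vectors would come from hand-picked example texts per category (hypothetical):
#
#     def _cosine(a, b):
#         a, b = np.asarray(a), np.asarray(b)
#         return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-9))
#
#     scores = {cat: _cosine(doc_vec, proto_vec)
#               for cat, proto_vec in prototype_vectors.items()}
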
@app.post("/analyze-contract")
async def analyze_contract(request: ContractAnalysisRequest):
"""Analyze French legal contracts"""
try:
# Extract entities first
entities_response = await extract_entities(NERRequest(text=request.text))
# Basic contract analysis
text_lower = request.text.lower()
analysis = {
"contract_type": request.contract_type or "general",
"parties": [e for e in entities_response['entities'] if e['type'] in ['PER', 'ORG']],
"key_clauses": [],
"obligations": [],
"risks": [],
"missing_clauses": [],
"recommendations": [],
"legal_references": []
}
# Check for key clauses
clause_checks = [
("price", ["prix", "montant", "coût"]),
("duration", ["durée", "période", "terme"]),
("termination", ["résiliation", "rupture", "fin"])
]
for clause_name, keywords in clause_checks:
if any(kw in text_lower for kw in keywords):
analysis['key_clauses'].append(clause_name)
else:
analysis['missing_clauses'].append(f"Missing {clause_name} clause")
analysis['recommendations'].append(f"Add {clause_name} clause")
return analysis
except Exception as e:
logger.error(f"Contract analysis error: {e}")
raise HTTPException(status_code=500, detail=str(e))
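
# The 'obligations' list above is left empty. A rough, illustrative sketch of
# a pattern-based pass for common French obligation phrasings (not exhaustive):
#
#     import re
#     OBLIGATION_RE = re.compile(
#         r"\b(?:s'engage(?:nt)? à|doit|doivent|est tenue? de)\s+[^.;]+",
#         re.IGNORECASE,
#     )
#     analysis['obligations'] = OBLIGATION_RE.findall(request.text)
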
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"models_loaded": list(models.keys()),
"timestamp": datetime.utcnow().isoformat()
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
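
# Example client call against a local instance (port 7860, as configured
# above); a sketch using the third-party `requests` library:
#
#     import requests
#     resp = requests.post(
#         'http://localhost:7860/ner',
#         json={'text': "La société Dupont a saisi le tribunal de Paris."},
#     )
#     print(resp.json()['entities'])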