import logging
from datetime import datetime
from typing import List, Dict, Any, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import torch
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForMaskedLM,
    pipeline,
)
import numpy as np
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="SobroJuriBert API",
    description="French Legal AI API powered by JuriBERT for comprehensive legal document analysis",
    version="1.0.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model storage
models = {}
tokenizers = {}
# Pydantic models
class TextRequest(BaseModel):
    text: str = Field(..., description="Text to analyze")


class MaskFillRequest(BaseModel):
    text: str = Field(..., description="Text with [MASK] tokens")
    top_k: int = Field(5, description="Number of predictions to return")


class NERRequest(BaseModel):
    text: str = Field(..., description="Legal text for entity extraction")


class QARequest(BaseModel):
    context: str = Field(..., description="Legal document context")
    question: str = Field(..., description="Question about the document")


class ClassificationRequest(BaseModel):
    text: str = Field(..., description="Legal document to classify")


class EmbeddingRequest(BaseModel):
    texts: List[str] = Field(..., description="List of texts to embed")


class JurisprudenceSearchRequest(BaseModel):
    query: str = Field(..., description="Search query")
    filters: Optional[Dict[str, Any]] = Field(None, description="Filters for search")
    limit: int = Field(10, description="Number of results")


class ContractAnalysisRequest(BaseModel):
    text: str = Field(..., description="Contract text to analyze")
    contract_type: Optional[str] = Field(None, description="Type of contract")

@app.on_event("startup")
async def load_models():
    """Load all required models on startup."""
    logger.info("Loading French legal models...")
    try:
        # Load JuriBERT base model for embeddings and mask filling
        logger.info("Loading JuriBERT base model...")
        models['juribert_base'] = AutoModel.from_pretrained('dascim/juribert-base')
        tokenizers['juribert_base'] = AutoTokenizer.from_pretrained('dascim/juribert-base')
        models['juribert_mlm'] = AutoModelForMaskedLM.from_pretrained('dascim/juribert-base')

        # Load CamemBERT NER model as fallback/complement
        logger.info("Loading CamemBERT models...")
        models['camembert_ner'] = pipeline(
            'ner',
            model='Jean-Baptiste/camembert-ner-with-dates',
            aggregation_strategy="simple",
        )

        # Load legal-specific classification model
        logger.info("Loading French legal classification model...")
        models['legal_classifier'] = pipeline(
            'text-classification',
            model='nlptown/bert-base-multilingual-uncased-sentiment'  # Placeholder until a legal classifier is available
        )

        logger.info("All models loaded successfully!")
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise

@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "name": "SobroJuriBert API",
        "version": "1.0.0",
        "description": "French Legal AI API for lawyers",
        "endpoints": {
            "mask_fill": "/mask-fill - Fill masked tokens in legal text",
            "embeddings": "/embeddings - Generate legal text embeddings",
            "ner": "/ner - Extract legal entities",
            "qa": "/qa - Answer questions about legal documents",
            "classify": "/classify - Classify legal documents",
            "analyze_contract": "/analyze-contract - Analyze legal contracts",
            "search_jurisprudence": "/search-jurisprudence - Search case law",
            "extract_articles": "/extract-articles - Extract legal article references",
            "check_compliance": "/check-compliance - Check legal compliance",
            "generate_summary": "/generate-summary - Generate legal summaries"
        },
        "models": {
            "base": "dascim/juribert-base",
            "ner": "Jean-Baptiste/camembert-ner-with-dates",
            "training_data": "6.3GB French legal texts from Légifrance + 100k+ court decisions"
        }
    }

@app.post("/mask-fill")
async def mask_fill(request: MaskFillRequest):
    """Fill [MASK] tokens in French legal text."""
    try:
        tokenizer = tokenizers['juribert_base']
        model = models['juribert_mlm']

        # Build the fill-mask pipeline (rebuilt on every request; it could
        # instead be created once at startup and cached alongside the models)
        fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)

        # Get predictions
        predictions = fill_mask(request.text, top_k=request.top_k)

        return {
            "input": request.text,
            "predictions": [
                {
                    "sequence": pred['sequence'],
                    "score": pred['score'],
                    "token": pred['token_str']
                }
                for pred in predictions
            ]
        }
    except Exception as e:
        logger.error(f"Mask fill error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
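
# Example call to /mask-fill (a sketch; assumes the API is running locally on
# port 7860 and that the tokenizer's mask token is literally "[MASK]"):
#
#   curl -X POST http://localhost:7860/mask-fill \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Le contrat est [MASK] par les parties.", "top_k": 5}'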

@app.post("/embeddings")
async def generate_embeddings(request: EmbeddingRequest):
    """Generate embeddings for French legal texts."""
    try:
        tokenizer = tokenizers['juribert_base']
        model = models['juribert_base']

        embeddings = []
        for text in request.texts:
            # Tokenize
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True,
            )

            # Generate embeddings
            with torch.no_grad():
                outputs = model(**inputs)

            # Use the [CLS] token embedding as the sentence representation
            embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
            embeddings.append(embedding.tolist())

        return {
            "embeddings": embeddings,
            "dimension": len(embeddings[0]) if embeddings else 0
        }
    except Exception as e:
        logger.error(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
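
# A minimal sketch of how the /embeddings output might be consumed downstream:
# cosine similarity between two returned vectors, e.g. to rank legal documents
# by semantic closeness. vec_a and vec_b are hypothetical entries taken from
# the "embeddings" field of the response.
def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
    """Cosine similarity between two embedding vectors."""
    a, b = np.asarray(vec_a), np.asarray(vec_b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))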

@app.post("/ner")
async def extract_entities(request: NERRequest):
    """Extract named entities from French legal text."""
    try:
        # Use CamemBERT NER model
        ner_pipeline = models['camembert_ner']
        entities = ner_pipeline(request.text)

        # Format results
        formatted_entities = []
        for entity in entities:
            formatted_entities.append({
                "text": entity['word'],
                "type": entity['entity_group'],
                "score": float(entity['score']),  # numpy float -> JSON-serializable
                "start": entity['start'],
                "end": entity['end']
            })

        return {
            "entities": formatted_entities,
            "text": request.text
        }
    except Exception as e:
        logger.error(f"NER error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
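
# Example call to /ner (a sketch; assumes a local instance). The underlying
# model also tags dates, so DATE entities may appear alongside PER/ORG/LOC:
#
#   curl -X POST http://localhost:7860/ner \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Jean Dupont a signé le contrat à Paris le 12 mars 2021."}'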

@app.post("/qa")
async def question_answering(request: QARequest):
    """Answer questions about French legal documents."""
    try:
        # Placeholder implementation; in production, use a fine-tuned QA model
        return {
            "question": request.question,
            "answer": "This feature requires a fine-tuned QA model. Please check back later.",
            "confidence": 0.0,
            "relevant_articles": [],
            "explanation": "QA model is being fine-tuned on French legal data"
        }
    except Exception as e:
        logger.error(f"QA error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/classify")
async def classify_document(request: ClassificationRequest):
    """Classify French legal documents."""
    try:
        # Simple keyword-based classification for now
        text_lower = request.text.lower()

        categories = {
            "contract": ["contrat", "accord", "convention", "parties"],
            "litigation": ["tribunal", "jugement", "litige", "procès"],
            "corporate": ["société", "sarl", "sas", "entreprise"],
            "employment": ["travail", "salarié", "employeur", "licenciement"]
        }

        # Score each category by the number of its keywords present in the text
        scores = {}
        for category, keywords in categories.items():
            score = sum(1 for kw in keywords if kw in text_lower)
            if score > 0:
                scores[category] = score

        if not scores:
            primary_category = "general"
        else:
            primary_category = max(scores, key=scores.get)

        return {
            "primary_category": primary_category,
            "categories": [{"category": cat, "score": score} for cat, score in scores.items()],
            "confidence": 0.8 if scores else 0.5,
            "document_type": "legal_document",
            "legal_domain": primary_category
        }
    except Exception as e:
        logger.error(f"Classification error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
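
# Example call to /classify (a sketch; assumes a local instance):
#
#   curl -X POST http://localhost:7860/classify \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Le salarié conteste son licenciement devant le tribunal."}'
#
# With the keyword lists above, "salarié" and "licenciement" score employment
# at 2 and "tribunal" scores litigation at 1, so primary_category is "employment".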

@app.post("/analyze-contract")
async def analyze_contract(request: ContractAnalysisRequest):
    """Analyze French legal contracts."""
    try:
        # Extract entities first
        entities_response = await extract_entities(NERRequest(text=request.text))

        # Basic contract analysis
        text_lower = request.text.lower()

        analysis = {
            "contract_type": request.contract_type or "general",
            "parties": [e for e in entities_response['entities'] if e['type'] in ['PER', 'ORG']],
            "key_clauses": [],
            "obligations": [],
            "risks": [],
            "missing_clauses": [],
            "recommendations": [],
            "legal_references": []
        }

        # Check for key clauses
        clause_checks = [
            ("price", ["prix", "montant", "coût"]),
            ("duration", ["durée", "période", "terme"]),
            ("termination", ["résiliation", "rupture", "fin"])
        ]

        for clause_name, keywords in clause_checks:
            if any(kw in text_lower for kw in keywords):
                analysis['key_clauses'].append(clause_name)
            else:
                analysis['missing_clauses'].append(f"Missing {clause_name} clause")
                analysis['recommendations'].append(f"Add {clause_name} clause")

        return analysis
    except Exception as e:
        logger.error(f"Contract analysis error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
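
# Example call to /analyze-contract (a sketch; assumes a local instance):
#
#   curl -X POST http://localhost:7860/analyze-contract \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Le prix est fixé à 1000 euros pour une durée de un an.", "contract_type": "vente"}'
#
# Here "prix" and "durée" match the price and duration checks, while the
# termination keywords are absent, so termination is reported as a missing clause.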

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "models_loaded": list(models.keys()),
        "timestamp": datetime.utcnow().isoformat()
    }
if __name__ == "__main__": | |
import uvicorn | |
uvicorn.run(app, host="0.0.0.0", port=7860) |