import os import json import logging from datetime import datetime from typing import List, Dict, Any, Optional from fastapi import FastAPI, HTTPException, File, UploadFile, Form from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field import torch from transformers import ( AutoTokenizer, AutoModel, AutoModelForMaskedLM, AutoModelForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSequenceClassification, pipeline ) import numpy as np # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Initialize FastAPI app app = FastAPI( title="SobroJuriBert API", description="French Legal AI API powered by JuriBERT for comprehensive legal document analysis", version="1.0.0" ) # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Global model storage models = {} tokenizers = {} # Pydantic models class TextRequest(BaseModel): text: str = Field(..., description="Text to analyze") class MaskFillRequest(BaseModel): text: str = Field(..., description="Text with [MASK] tokens") top_k: int = Field(5, description="Number of predictions to return") class NERRequest(BaseModel): text: str = Field(..., description="Legal text for entity extraction") class QARequest(BaseModel): context: str = Field(..., description="Legal document context") question: str = Field(..., description="Question about the document") class ClassificationRequest(BaseModel): text: str = Field(..., description="Legal document to classify") class EmbeddingRequest(BaseModel): texts: List[str] = Field(..., description="List of texts to embed") class JurisprudenceSearchRequest(BaseModel): query: str = Field(..., description="Search query") filters: Optional[Dict[str, Any]] = Field(None, description="Filters for search") limit: int = Field(10, description="Number of results") class ContractAnalysisRequest(BaseModel): text: str = Field(..., description="Contract text to analyze") contract_type: Optional[str] = Field(None, description="Type of contract") @app.on_event("startup") async def load_models(): """Load all required models on startup""" logger.info("Loading French legal models...") try: # Load JuriBERT base model for embeddings and mask filling logger.info("Loading JuriBERT base model...") models['juribert_base'] = AutoModel.from_pretrained('dascim/juribert-base', cache_dir="/app/.cache/huggingface") tokenizers['juribert_base'] = AutoTokenizer.from_pretrained('dascim/juribert-base', cache_dir="/app/.cache/huggingface") models['juribert_mlm'] = AutoModelForMaskedLM.from_pretrained('dascim/juribert-base', cache_dir="/app/.cache/huggingface") # Load CamemBERT models as fallback/complement logger.info("Loading CamemBERT models...") models['camembert_ner'] = pipeline( 'ner', model='Jean-Baptiste/camembert-ner-with-dates', aggregation_strategy="simple", model_kwargs={"cache_dir": "/app/.cache/huggingface"} ) logger.info("Models loaded successfully!") except Exception as e: logger.error(f"Error loading models: {e}") # Don't crash completely, allow basic endpoints to work logger.warning("Running in limited mode without all models") @app.get("/") async def root(): """Root endpoint with API information""" return { "name": "SobroJuriBert API", "version": "1.0.0", "description": "French Legal AI API for lawyers", "endpoints": { "mask_fill": "/mask-fill - Fill masked tokens in legal text", "embeddings": "/embeddings - Generate legal text embeddings", "ner": "/ner - Extract legal entities", "qa": "/qa - Answer questions about legal documents", "classify": "/classify - Classify legal documents", "analyze_contract": "/analyze-contract - Analyze legal contracts", "search_jurisprudence": "/search-jurisprudence - Search case law", "extract_articles": "/extract-articles - Extract legal article references", "check_compliance": "/check-compliance - Check legal compliance", "generate_summary": "/generate-summary - Generate legal summaries" }, "models": { "base": "dascim/juribert-base", "ner": "Jean-Baptiste/camembert-ner-with-dates", "training_data": "6.3GB French legal texts from Légifrance + 100k+ court decisions" } } @app.post("/mask-fill") async def mask_fill(request: MaskFillRequest): """Fill [MASK] tokens in French legal text""" try: tokenizer = tokenizers['juribert_base'] model = models['juribert_mlm'] # Create pipeline fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer) # Get predictions predictions = fill_mask(request.text, top_k=request.top_k) return { "input": request.text, "predictions": [ { "sequence": pred['sequence'], "score": pred['score'], "token": pred['token_str'] } for pred in predictions ] } except Exception as e: logger.error(f"Mask fill error: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/embeddings") async def generate_embeddings(request: EmbeddingRequest): """Generate embeddings for French legal texts""" try: tokenizer = tokenizers['juribert_base'] model = models['juribert_base'] embeddings = [] for text in request.texts: # Tokenize inputs = tokenizer( text, return_tensors="pt", truncation=True, max_length=512, padding=True ) # Generate embeddings with torch.no_grad(): outputs = model(**inputs) # Use CLS token embedding embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy() embeddings.append(embedding.tolist()) return { "embeddings": embeddings, "dimension": len(embeddings[0]) if embeddings else 0 } except Exception as e: logger.error(f"Embedding error: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/ner") async def extract_entities(request: NERRequest): """Extract named entities from French legal text""" try: # Use CamemBERT NER model ner_pipeline = models['camembert_ner'] entities = ner_pipeline(request.text) # Format results formatted_entities = [] for entity in entities: formatted_entities.append({ "text": entity['word'], "type": entity['entity_group'], "score": entity['score'], "start": entity['start'], "end": entity['end'] }) return { "entities": formatted_entities, "text": request.text } except Exception as e: logger.error(f"NER error: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/qa") async def question_answering(request: QARequest): """Answer questions about French legal documents""" try: # Simple implementation for now # In production, use a fine-tuned QA model return { "question": request.question, "answer": "This feature requires a fine-tuned QA model. Please check back later.", "confidence": 0.0, "relevant_articles": [], "explanation": "QA model is being fine-tuned on French legal data" } except Exception as e: logger.error(f"QA error: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/classify") async def classify_document(request: ClassificationRequest): """Classify French legal documents""" try: # Simple keyword-based classification for now text_lower = request.text.lower() categories = { "contract": ["contrat", "accord", "convention", "parties"], "litigation": ["tribunal", "jugement", "litige", "procès"], "corporate": ["société", "sarl", "sas", "entreprise"], "employment": ["travail", "salarié", "employeur", "licenciement"] } scores = {} for category, keywords in categories.items(): score = sum(1 for kw in keywords if kw in text_lower) if score > 0: scores[category] = score if not scores: primary_category = "general" else: primary_category = max(scores, key=scores.get) return { "primary_category": primary_category, "categories": [{"category": cat, "score": score} for cat, score in scores.items()], "confidence": 0.8 if scores else 0.5, "document_type": "legal_document", "legal_domain": primary_category } except Exception as e: logger.error(f"Classification error: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/analyze-contract") async def analyze_contract(request: ContractAnalysisRequest): """Analyze French legal contracts""" try: # Extract entities first entities_response = await extract_entities(NERRequest(text=request.text)) # Basic contract analysis text_lower = request.text.lower() analysis = { "contract_type": request.contract_type or "general", "parties": [e for e in entities_response['entities'] if e['type'] in ['PER', 'ORG']], "key_clauses": [], "obligations": [], "risks": [], "missing_clauses": [], "recommendations": [], "legal_references": [] } # Check for key clauses clause_checks = [ ("price", ["prix", "montant", "coût"]), ("duration", ["durée", "période", "terme"]), ("termination", ["résiliation", "rupture", "fin"]) ] for clause_name, keywords in clause_checks: if any(kw in text_lower for kw in keywords): analysis['key_clauses'].append(clause_name) else: analysis['missing_clauses'].append(f"Missing {clause_name} clause") analysis['recommendations'].append(f"Add {clause_name} clause") return analysis except Exception as e: logger.error(f"Contract analysis error: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/health") async def health_check(): """Health check endpoint""" return { "status": "healthy", "models_loaded": list(models.keys()), "timestamp": datetime.utcnow().isoformat() } if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)