import logging
import re
from datetime import datetime, timezone

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="SobroJuriBert API",
    description="French Legal AI API powered by JuriBERT",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model storage (populated lazily; models are loaded on demand)
models = {}
tokenizers = {}


# Pydantic request models
class TextRequest(BaseModel):
    text: str = Field(..., description="Text to analyze")


class NERRequest(BaseModel):
    text: str = Field(..., description="Legal text for entity extraction")


class ClassificationRequest(BaseModel):
    text: str = Field(..., description="Legal document to classify")


@app.on_event("startup")
async def load_models():
    """Log startup; models are loaded on demand to save memory."""
    logger.info("Starting SobroJuriBert API...")
    logger.info("Models will be loaded on demand to save memory")


@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "name": "SobroJuriBert API",
        "version": "1.0.0",
        "description": "French Legal AI API for lawyers",
        "status": "operational",
        "endpoints": {
            "ner": "/ner - Extract legal entities",
            "classify": "/classify - Classify legal documents",
            "health": "/health - Health check"
        }
    }


@app.post("/ner")
async def extract_entities(request: NERRequest):
    """Extract named entities from French legal text using simple regexes."""
    try:
        entities = []

        # Extract dates such as 12/03/2024 or 1-1-24
        dates = re.findall(r'\d{1,2}[/-]\d{1,2}[/-]\d{2,4}', request.text)
        for date in dates:
            entities.append({"text": date, "type": "DATE"})

        # Extract organizations by their legal-form prefix (SARL, SAS, SA, EURL)
        orgs = re.findall(r'(?:SARL|SAS|SA|EURL)\s+[\w\s]+', request.text)
        for org in orgs:
            entities.append({"text": org.strip(), "type": "ORG"})

        # Extract courts (Tribunal ..., Cour ...)
        courts = re.findall(r'(?:Tribunal|Cour)\s+[\w\s]+?(?=\s|,|\.)', request.text)
        for court in courts:
            entities.append({"text": court.strip(), "type": "COURT"})

        return {
            "entities": entities,
            "text": request.text,
            "message": "Basic entity extraction (full NER model loading on demand)"
        }
    except Exception as e:
        logger.error(f"NER error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/classify")
async def classify_document(request: ClassificationRequest):
    """Classify French legal documents with a simple keyword heuristic."""
    try:
        text_lower = request.text.lower()

        categories = {
            "contract": ["contrat", "accord", "convention", "parties"],
            "litigation": ["tribunal", "jugement", "litige", "procès"],
            "corporate": ["société", "sarl", "sas", "entreprise"],
            "employment": ["travail", "salarié", "employeur", "licenciement"]
        }

        # Score each category by the number of its keywords present in the text
        scores = {}
        for category, keywords in categories.items():
            score = sum(1 for kw in keywords if kw in text_lower)
            if score > 0:
                scores[category] = score

        primary_category = max(scores, key=scores.get) if scores else "general"

        return {
            "primary_category": primary_category,
            "categories": [
                {"category": cat, "score": score} for cat, score in scores.items()
            ],
            "confidence": 0.8 if scores else 0.5,
            "document_type": "legal_document"
        }
    except Exception as e:
        logger.error(f"Classification error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
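
# Illustrative exchange for /classify (a sketch, not from a live run; the
# values follow directly from the scoring code above). For a short contract
# excerpt, "contrat" and "parties" both hit the "contract" keyword list, so
# that category wins with score 2 and the fixed confidence of 0.8:
#
#   curl -X POST http://localhost:7860/classify \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Le présent contrat est conclu entre les parties"}'
#
#   {"primary_category": "contract",
#    "categories": [{"category": "contract", "score": 2}],
#    "confidence": 0.8,
#    "document_type": "legal_document"}
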
@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "version": "1.0.0",
        "message": "SobroJuriBert API is running"
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
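
# Quick smoke test once the server is up (the module name "main" is an
# assumption; adjust to this file's actual name). The /ner output below is
# illustrative and follows from the regexes above: the date pattern catches
# "12/03/2024" and the organization pattern anchors on the "SARL" prefix,
# stopping at the comma:
#
#   python main.py          # or: uvicorn main:app --host 0.0.0.0 --port 7860
#   curl http://localhost:7860/health
#
#   curl -X POST http://localhost:7860/ner \
#        -H "Content-Type: application/json" \
#        -d '{"text": "La SARL Dupont, assignée le 12/03/2024"}'
#
#   {"entities": [{"text": "12/03/2024", "type": "DATE"},
#                 {"text": "SARL Dupont", "type": "ORG"}],
#    "text": "La SARL Dupont, assignée le 12/03/2024",
#    "message": "Basic entity extraction (full NER model loading on demand)"}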