Spaces:
Sleeping
Sleeping
import logging
import os
import re
import unicodedata
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModel, AutoTokenizer, pipeline
# Module logger, named after this module.
logger = logging.getLogger(__name__)
# Configure the root logger to emit records at INFO level and above
# (basicConfig has no effect if the root logger already has handlers).
logging.basicConfig(level="INFO")
# Application object: the CORS middleware is attached to it below and the
# __main__ block serves it with uvicorn.
app = FastAPI(
    version="1.0.0",
    description="French Legal AI API powered by JuriBERT",
    title="SobroJuriBert API",
)
# Cross-origin access: any origin, method and header is accepted.
# NOTE(review): FastAPI's CORS documentation states that credentials cannot be
# combined with allow_origins=["*"]. Nothing in the code visible here reads
# cookies or auth headers, so allow_credentials=True is probably unnecessary.
# Confirm no browser client relies on credentialed requests before dropping it.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# Process-wide registries for loaded models and their tokenizers. Both start
# empty, and no code visible in this module fills them yet.
# Presumably keyed by model name (TODO confirm).
models: Dict[str, Any] = dict()
tokenizers: Dict[str, Any] = dict()
# ---- Request bodies (Pydantic) ----
# Generic single-text payload. No endpoint in the visible code accepts it.
# Comments are used instead of a docstring because Pydantic would publish a
# class docstring as the schema description.
class TextRequest(BaseModel):
    text: str = Field(
        default=...,
        description="Text to analyze",
    )
# Payload for extract_entities: the legal text to scan (required field).
class NERRequest(BaseModel):
    text: str = Field(
        default=...,
        description="Legal text for entity extraction",
    )
# Payload for classify_document: the legal document to classify (required).
class ClassificationRequest(BaseModel):
    text: str = Field(
        default=...,
        description="Legal document to classify",
    )
# Register as a startup handler so the hook actually runs; without the
# decorator the coroutine was never attached to `app`. (FastAPI now prefers
# the `lifespan=` parameter, but on_event keeps this change local.)
@app.on_event("startup")
async def load_models():
    """Startup hook: log that the API is starting.

    No model is loaded here. The second log line says loading happens on
    demand. NOTE(review): no on-demand loading code is visible in this
    module (`models` / `tokenizers` are never populated), so confirm.
    """
    logger.info("Starting SobroJuriBert API...")
    logger.info("Models will be loaded on demand to save memory")
@app.get("/")
async def root():
    """Return API metadata and the list of advertised endpoints.

    The route decorator was missing, so this handler (and the endpoints it
    advertises) were not reachable. The version is read from the FastAPI app
    so it cannot drift from the value declared in the constructor.
    """
    return {
        "name": "SobroJuriBert API",
        "version": app.version,
        "description": "French Legal AI API for lawyers",
        "status": "operational",
        "endpoints": {
            "ner": "/ner - Extract legal entities",
            "classify": "/classify - Classify legal documents",
            "health": "/health - Health check",
        },
    }
# Patterns for the rule-based extractor, compiled once at import time.
# Numeric dates (12/05/2023, 1-5-23): both separators must be the same, the
# year has 2 or 4 digits, and the date may not sit inside a longer digit run
# (so "2023-05-12" no longer yields the fragment "23-05-12").
_DATE_PATTERN = re.compile(r"(?<!\d)\d{1,2}([/-])\d{1,2}\1(?:\d{4}|\d{2})(?!\d)")
# Company legal form as a whole word (so "SA" does not match inside "NASA"),
# followed by a name on the same line. The name runs until punctuation or a
# line break, so trailing words of the sentence may be captured as well.
_ORG_PATTERN = re.compile(r"\b(?:SARL|SAS|SA|EURL)\b[ \t]+[\w&'’ \t-]+")
# "Tribunal ..." / "Cour ..." up to the next punctuation mark, line break or
# end of text, e.g. "Cour d'appel de Paris" or "Tribunal judiciaire de Lyon".
# The previous lazy pattern stopped after the first word ("Tribunal de") and
# found nothing for "Cour d'appel" or for a court at the end of the text.
_COURT_PATTERN = re.compile(r"\b(?:Tribunal|Cour)\b[ \t]+[\w'’ \t-]+")


@app.post("/ner")
async def extract_entities(request: NERRequest):
    """Extract named entities from French legal text.

    This is a lightweight rule-based extractor. No NER model is used. It
    reports three entity types:

    * DATE: numeric dates such as 12/05/2023 or 1-5-23;
    * ORG: company names introduced by SARL, SAS, SA or EURL;
    * COURT: court names starting with "Tribunal" or "Cour".

    Args:
        request: Request body whose ``text`` field is scanned.

    Returns:
        A dict with ``entities`` (a list of ``{"text", "type"}`` dicts,
        grouped as dates, then organisations, then courts), the original
        input ``text``, and an informational ``message``.

    Raises:
        HTTPException: status 500 if extraction fails unexpectedly.
    """
    try:
        # NFC-normalise first: in decomposed text an accent is a separate
        # combining mark that \w does not match, which would cut names at the
        # first accented letter.
        text = unicodedata.normalize("NFC", request.text)
        entities = [
            {"text": match.group(0), "type": "DATE"}
            for match in _DATE_PATTERN.finditer(text)
        ]
        entities.extend(
            {"text": match.group(0).strip(), "type": "ORG"}
            for match in _ORG_PATTERN.finditer(text)
        )
        entities.extend(
            {"text": match.group(0).strip(), "type": "COURT"}
            for match in _COURT_PATTERN.finditer(text)
        )
        return {
            "entities": entities,
            "text": request.text,
            "message": "Basic entity extraction (full NER model loading on demand)",
        }
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("NER error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Keywords per document category, built once at import rather than per call.
# Matching is substring-based on the lower-cased text, so inflected forms
# such as "contrats" or "salariés" count as well.
_CATEGORY_KEYWORDS = {
    "contract": ["contrat", "accord", "convention", "parties"],
    "litigation": ["tribunal", "jugement", "litige", "procès"],
    "corporate": ["société", "sarl", "sas", "entreprise"],
    "employment": ["travail", "salarié", "employeur", "licenciement"],
}


@app.post("/classify")
async def classify_document(request: ClassificationRequest):
    """Classify a French legal document by keyword counts.

    Each category earns one point per keyword that occurs in the text.
    Repeated occurrences of the same keyword do not add points.

    Args:
        request: Request body whose ``text`` field is classified.

    Returns:
        A dict with:

        * ``primary_category``: the highest-scoring category, or
          ``"general"`` when no keyword matched;
        * ``categories``: every category with a non-zero score;
        * ``confidence``: a fixed heuristic value (0.8 if any keyword
          matched, else 0.5), not a calibrated probability;
        * ``document_type``: always ``"legal_document"``.

    Raises:
        HTTPException: status 500 if classification fails unexpectedly.
    """
    try:
        # NFC-normalise so accented keywords ("société", "procès") also match
        # text whose accents are encoded as separate combining marks.
        text_lower = unicodedata.normalize("NFC", request.text).lower()
        scores = {}
        for category, keywords in _CATEGORY_KEYWORDS.items():
            score = sum(1 for kw in keywords if kw in text_lower)
            if score > 0:
                scores[category] = score
        # On a tie, max() keeps the first category in _CATEGORY_KEYWORDS order.
        primary_category = max(scores, key=scores.get) if scores else "general"
        return {
            "primary_category": primary_category,
            "categories": [
                {"category": cat, "score": score} for cat, score in scores.items()
            ],
            "confidence": 0.8 if scores else 0.5,
            "document_type": "legal_document",
        }
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Classification error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health_check():
    """Health-check endpoint.

    Returns:
        A dict with a static ``"healthy"`` status, the current time as a
        timezone-aware UTC ISO 8601 string (with a ``+00:00`` offset), the
        API version taken from the FastAPI app, and a message.
    """
    return {
        "status": "healthy",
        # datetime.utcnow() is deprecated since Python 3.12 and returns a
        # naive value with no offset, so use an aware UTC "now" instead.
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "version": app.version,
        "message": "SobroJuriBert API is running",
    }
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces,
    # port 7860.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)