# Header captured with this file from its hosting page: Space status "Sleeping";
# file size 4,623 bytes; id 4786618 (presumably a commit/revision identifier).
import logging
import os
import re
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModel, AutoTokenizer, pipeline
# Logging: INFO-level root configuration plus a named logger for this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# The ASGI application object; every route in this module is registered on it.
app = FastAPI(
    version="1.0.0",
    title="SobroJuriBert API",
    description="French Legal AI API powered by JuriBERT",
)
# CORS: accept cross-origin calls from any origin, with any method and header.
# Credentials (cookies / HTTP auth) are deliberately NOT enabled. The FastAPI
# CORS documentation states that allow_origins=["*"] cannot be combined with
# credentials, and nothing visible in this service reads cookies or auth
# headers. If a browser client ever needs credentialed requests, list its
# origins explicitly instead of "*" and only then set allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Module-level caches for on-demand models/tokenizers (presumably keyed by
# model name - TODO confirm). Nothing in the visible code populates them.
models: Dict[str, Any] = dict()
tokenizers: Dict[str, Any] = dict()
# Request bodies (Pydantic models).
# Generic single-text payload with one required field; no endpoint in the
# visible code accepts it.
class TextRequest(BaseModel):
    text: str = Field(default=..., description="Text to analyze")
# Body of POST /ner: one required `text` field holding the document to scan.
class NERRequest(BaseModel):
    text: str = Field(default=..., description="Legal text for entity extraction")
# Body of POST /classify: one required `text` field holding the document.
class ClassificationRequest(BaseModel):
    text: str = Field(default=..., description="Legal document to classify")
@app.on_event("startup")
async def load_models():
    """Startup hook that only writes two log lines; despite its name it loads nothing.

    The second message announces on-demand model loading, but no code in this
    handler touches the ``models`` / ``tokenizers`` caches.
    """
    startup_messages = (
        "Starting SobroJuriBert API...",
        "Models will be loaded on demand to save memory",
    )
    for message in startup_messages:
        logger.info(message)
@app.get("/")
async def root():
    """Root endpoint with API information

    \f
    Returns a static, hard-coded payload: service name, version, description,
    a fixed "operational" status and a one-line help string per endpoint. It
    does not probe any runtime state. (Text after the form feed is kept out of
    the OpenAPI description.)
    """
    endpoint_help = {
        "ner": "/ner - Extract legal entities",
        "classify": "/classify - Classify legal documents",
        "health": "/health - Health check",
    }
    return dict(
        name="SobroJuriBert API",
        version="1.0.0",
        description="French Legal AI API for lawyers",
        status="operational",
        endpoints=endpoint_help,
    )
# Heuristic regex patterns behind POST /ner, compiled once at import time.
# Dates such as 12/03/2020 or 1-4-21. The digit look-arounds keep a match from
# starting or ending inside a longer digit run, so "2020/12/03" does not
# produce the bogus date "20/12/03".
_DATE_PATTERN = re.compile(r"(?<!\d)\d{1,2}[/-]\d{1,2}[/-]\d{2,4}(?!\d)")
# A company legal form followed by the words after it on the same line. The
# word boundaries stop "SA" from matching inside longer tokens such as "VISA",
# and [ \t] (instead of \s) keeps a name from running across line breaks.
_ORG_PATTERN = re.compile(r"\b(?:SARL|SAS|SA|EURL)\b[ \t]+\w[\w \t'’-]*")
# "Tribunal"/"Cour" followed by its words up to the next punctuation mark or
# line break, e.g. "Cour d'appel de Paris" or "Cour de cassation", not just the
# first word after the head noun. Apostrophes are allowed inside words so that
# French elisions ("d'appel") do not end the match.
_COURT_PATTERN = re.compile(r"\b(?:Tribunal|Cour)\b[ \t]+[\w'’-]+(?:[ \t]+[\w'’-]+)*")


@app.post("/ner")
async def extract_entities(request: NERRequest):
    """Extract named entities from French legal text

    \f
    Regex heuristics only; no model is loaded. Entities are grouped by type
    (all DATE matches, then ORG, then COURT), each group in text order, as
    {"text": ..., "type": ...}. Duplicates are kept. Any unexpected error is
    logged with its traceback and re-raised as HTTP 500. (Text after the form
    feed is kept out of the OpenAPI description.)
    """
    try:
        text = request.text
        entities = [
            {"text": date, "type": "DATE"} for date in _DATE_PATTERN.findall(text)
        ]
        entities += [
            {"text": org.strip(), "type": "ORG"} for org in _ORG_PATTERN.findall(text)
        ]
        entities += [
            {"text": court.strip(), "type": "COURT"}
            for court in _COURT_PATTERN.findall(text)
        ]
        return {
            "entities": entities,
            "text": text,
            "message": "Basic entity extraction (full NER model loading on demand)",
        }
    except Exception as e:
        # Service boundary: keep the traceback in the logs, answer with a 500.
        logger.exception("NER error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/classify")
async def classify_document(request: ClassificationRequest):
    """Classify French legal documents

    \f
    Keyword heuristic, not a model. A category's score is the number of its
    keywords that occur as substrings of the lower-cased text, so "contrat"
    also matches "contrats". Categories scoring 0 are omitted. The primary
    category has the highest score (the earlier-declared category wins a tie),
    or is "general" when nothing matched. "confidence" is a fixed 0.8 / 0.5
    value, not a probability. (Text after the form feed is kept out of the
    OpenAPI description.)
    """
    try:
        lowered = request.text.lower()
        keyword_map = {
            "contract": ["contrat", "accord", "convention", "parties"],
            "litigation": ["tribunal", "jugement", "litige", "procès"],
            "corporate": ["société", "sarl", "sas", "entreprise"],
            "employment": ["travail", "salarié", "employeur", "licenciement"],
        }

        hits = {}
        for label, words in keyword_map.items():
            found = len([w for w in words if w in lowered])
            if found:
                hits[label] = found

        primary = max(hits, key=hits.get) if hits else "general"
        breakdown = [dict(category=label, score=n) for label, n in hits.items()]
        return {
            "primary_category": primary,
            "categories": breakdown,
            "confidence": 0.8 if hits else 0.5,
            "document_type": "legal_document",
        }
    except Exception as e:
        logger.error(f"Classification error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
    """Health check endpoint

    \f
    Liveness response only: always reports "healthy" and checks no model or
    dependency. The timestamp is the current UTC time as a timezone-aware
    ISO 8601 string ending in "+00:00". datetime.utcnow() is deprecated since
    Python 3.12 and produced a naive string that clients may read as local
    time. (Text after the form feed is kept out of the OpenAPI description.)
    """
    return {
        "status": "healthy",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "version": "1.0.0",
        "message": "SobroJuriBert API is running",
    }
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces. The port defaults to 7860 and can be overridden
    # with the PORT environment variable.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))