# SobroJuriBert / main_simple.py — Sobro Inc
# Commit 4786618: "Fix permission errors and use simplified version"
# (Header reconstructed from Hugging Face file-viewer chrome: raw / history / blame, 4.62 kB)
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModel, pipeline
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="SobroJuriBert API",
    description="French Legal AI API powered by JuriBERT",
    version="1.0.0"
)

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers per the CORS spec (Starlette works around it by
# echoing the Origin header, which effectively disables origin checks) —
# confirm whether credentialed cross-origin requests are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model storage — populated lazily; kept empty at startup to save memory.
models = {}
tokenizers = {}
# Pydantic models
# Pydantic request models
class TextRequest(BaseModel):
    """Generic request payload carrying a single block of text.

    NOTE(review): not referenced by any endpoint in this file — confirm
    whether it is used elsewhere or can be retired.
    """

    text: str = Field(..., description="Text to analyze")
class NERRequest(BaseModel):
    """Request payload for the /ner endpoint."""

    text: str = Field(..., description="Legal text for entity extraction")
class ClassificationRequest(BaseModel):
    """Request payload for the /classify endpoint."""

    text: str = Field(..., description="Legal document to classify")
@app.on_event("startup")
async def load_models():
    """Announce startup; actual model loading is deferred until first use."""
    for message in (
        "Starting SobroJuriBert API...",
        "Models will be loaded on demand to save memory",
    ):
        logger.info(message)
@app.get("/")
async def root():
    """Describe the API: name, version, status, and available endpoints."""
    endpoint_map = {
        "ner": "/ner - Extract legal entities",
        "classify": "/classify - Classify legal documents",
        "health": "/health - Health check",
    }
    return {
        "name": "SobroJuriBert API",
        "version": "1.0.0",
        "description": "French Legal AI API for lawyers",
        "status": "operational",
        "endpoints": endpoint_map,
    }
@app.post("/ner")
async def extract_entities(request: NERRequest):
    """Extract named entities from French legal text via regex heuristics.

    Returns a dict with the matched entities, the original text, and a note
    that this is the lightweight fallback (the full NER model loads on demand).
    Raises HTTP 500 on any unexpected failure.
    """
    try:
        import re

        source_text = request.text
        # (regex, entity label) pairs, applied in order: dates, orgs, courts.
        # NOTE(review): the lazy [\w\s]+? followed by a (?=\s|,|\.) lookahead
        # stops at the first whitespace boundary, so multi-word court names
        # come out truncated (e.g. "Tribunal de") — kept as-is to preserve
        # behavior; confirm intended capture before tightening.
        heuristics = (
            (r'\d{1,2}[/-]\d{1,2}[/-]\d{2,4}', "DATE"),
            (r'(?:SARL|SAS|SA|EURL)\s+[\w\s]+', "ORG"),
            (r'(?:Tribunal|Cour)\s+[\w\s]+?(?=\s|,|\.)', "COURT"),
        )
        entities = [
            {"text": match.strip(), "type": label}
            for pattern, label in heuristics
            for match in re.findall(pattern, source_text)
        ]
        return {
            "entities": entities,
            "text": source_text,
            "message": "Basic entity extraction (full NER model loading on demand)"
        }
    except Exception as e:
        logger.error(f"NER error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/classify")
async def classify_document(request: ClassificationRequest):
    """Classify a French legal document by keyword frequency.

    Counts how many category keywords appear in the lowercased text; the
    category with the most hits wins, defaulting to "general" when nothing
    matches. Raises HTTP 500 on any unexpected failure.
    """
    try:
        lowered = request.text.lower()
        categories = {
            "contract": ["contrat", "accord", "convention", "parties"],
            "litigation": ["tribunal", "jugement", "litige", "procès"],
            "corporate": ["société", "sarl", "sas", "entreprise"],
            "employment": ["travail", "salarié", "employeur", "licenciement"]
        }
        # Keep only categories with at least one keyword hit.
        scores = {}
        for category, keywords in categories.items():
            hits = sum(keyword in lowered for keyword in keywords)
            if hits:
                scores[category] = hits
        primary_category = max(scores, key=scores.get) if scores else "general"
        return {
            "primary_category": primary_category,
            "categories": [{"category": cat, "score": score} for cat, score in scores.items()],
            # Confidence is a fixed heuristic, not a model probability.
            "confidence": 0.8 if scores else 0.5,
            "document_type": "legal_document"
        }
    except Exception as e:
        logger.error(f"Classification error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns service status, version, and a timezone-aware UTC timestamp.
    """
    # datetime.utcnow() is deprecated (Python 3.12+) and yields a naive
    # datetime; use an aware UTC timestamp (ISO string now carries "+00:00").
    return {
        "status": "healthy",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "version": "1.0.0",
        "message": "SobroJuriBert API is running"
    }
# Launch a development server when run directly (port 7860 is the
# Hugging Face Spaces convention).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)