import os
import re
import logging
from datetime import datetime
from typing import List, Dict, Any, Optional
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import torch
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForMaskedLM,
    pipeline
)
import numpy as np
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="SobroJuriBert API - Full Version",
    description="French Legal AI API powered by JuriBERT with complete functionality",
    version="2.0.0"
)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model storage
models = {}
tokenizers = {}
models_loaded = False
# Pydantic models
class TextRequest(BaseModel):
    text: str = Field(..., description="Text to analyze")

class MaskFillRequest(BaseModel):
    text: str = Field(..., description="Text with [MASK] tokens")
    top_k: int = Field(5, description="Number of predictions to return")

class NERRequest(BaseModel):
    text: str = Field(..., description="Legal text for entity extraction")

class QARequest(BaseModel):
    context: str = Field(..., description="Legal document context")
    question: str = Field(..., description="Question about the document")

class ClassificationRequest(BaseModel):
    text: str = Field(..., description="Legal document to classify")

class EmbeddingRequest(BaseModel):
    texts: List[str] = Field(..., description="List of texts to embed")

async def load_models_on_demand():
    """Load models on first request"""
    global models_loaded

    if models_loaded:
        return

    logger.info("Loading JuriBERT models on demand...")
    try:
        # Load JuriBERT for embeddings and mask filling
        models['juribert_base'] = AutoModel.from_pretrained(
            'dascim/juribert-base',
            cache_dir="/app/.cache/huggingface"
        )
        tokenizers['juribert_base'] = AutoTokenizer.from_pretrained(
            'dascim/juribert-base',
            cache_dir="/app/.cache/huggingface"
        )
        models['juribert_mlm'] = AutoModelForMaskedLM.from_pretrained(
            'dascim/juribert-base',
            cache_dir="/app/.cache/huggingface"
        )

        models_loaded = True
        logger.info("JuriBERT models loaded successfully!")
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise HTTPException(status_code=503, detail="Models could not be loaded")

@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "name": "SobroJuriBert API - Full Version",
        "version": "2.0.0",
        "description": "Complete French Legal AI API",
        "status": "operational",
        "endpoints": {
            "mask_fill": "/mask-fill - Fill masked tokens in legal text",
            "embeddings": "/embeddings - Generate legal text embeddings",
            "ner": "/ner - Extract legal entities (enhanced)",
            "qa": "/qa - Answer questions about legal documents",
            "classify": "/classify - Classify legal documents",
            "health": "/health - Health check"
        },
        "models": {
            "base": "dascim/juribert-base",
            "status": "loaded" if models_loaded else "on-demand"
        }
    }

@app.post("/mask-fill")
async def mask_fill(request: MaskFillRequest):
    """Fill [MASK] tokens in French legal text using JuriBERT"""
    await load_models_on_demand()

    try:
        tokenizer = tokenizers['juribert_base']
        model = models['juribert_mlm']

        # Create pipeline
        fill_mask = pipeline(
            'fill-mask',
            model=model,
            tokenizer=tokenizer,
            device=-1  # CPU
        )

        # Get predictions
        predictions = fill_mask(request.text, top_k=request.top_k)

        return {
            "input": request.text,
            "predictions": [
                {
                    "sequence": pred['sequence'],
                    "score": float(pred['score']),
                    "token": pred['token_str']
                }
                for pred in predictions
            ]
        }
    except Exception as e:
        logger.error(f"Mask fill error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/embeddings")
async def generate_embeddings(request: EmbeddingRequest):
    """Generate embeddings for French legal texts using JuriBERT"""
    await load_models_on_demand()

    try:
        tokenizer = tokenizers['juribert_base']
        model = models['juribert_base']

        embeddings = []
        for text in request.texts:
            # Tokenize
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True
            )

            # Generate embeddings
            with torch.no_grad():
                outputs = model(**inputs)

            # Mean pooling over token embeddings, weighted by the attention mask
            attention_mask = inputs['attention_mask']
            token_embeddings = outputs.last_hidden_state
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

            embeddings.append(embedding.squeeze().numpy().tolist())

        return {
            "embeddings": embeddings,
            "dimension": len(embeddings[0]) if embeddings else 0,
            "model": "juribert-base"
        }
    except Exception as e:
        logger.error(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

def extract_enhanced_entities(text: str) -> List[Dict[str, Any]]:
    """Enhanced entity extraction for French legal text"""
    entities = []

    # Extract persons (PER)
    person_patterns = [
        r'\b(?:M\.|Mme|Mlle|Me|Dr|Prof\.?)\s+[A-Z][a-zÀ-ÿ]+(?:\s+[A-Z][a-zÀ-ÿ]+)*',
        r'\b[A-Z][a-zÀ-ÿ]+\s+[A-Z][A-Z]+\b',  # Jean DUPONT
    ]
    for pattern in person_patterns:
        for match in re.finditer(pattern, text):
            entities.append({
                "text": match.group(),
                "type": "PER",
                "start": match.start(),
                "end": match.end()
            })

    # Extract money amounts (MONEY)
    money_patterns = [
        r'\b\d{1,3}(?:\s?\d{3})*(?:[,\.]\d{2})?\s?(?:€|EUR|euros?)\b',
        r'\b(?:€|EUR)\s?\d{1,3}(?:\s?\d{3})*(?:[,\.]\d{2})?\b',
    ]
    for pattern in money_patterns:
        for match in re.finditer(pattern, text, re.IGNORECASE):
            entities.append({
                "text": match.group(),
                "type": "MONEY",
                "start": match.start(),
                "end": match.end()
            })

    # Extract legal references (LEGAL_REF)
    legal_patterns = [
        r'article\s+(?:L\.?)?\d+(?:-\d+)?(?:\s+(?:alinéa|al\.)\s+\d+)?',
        r'articles?\s+\d+\s+(?:à|et)\s+\d+',
        r'(?:loi|décret|ordonnance)\s+n°\s*\d{4}-\d+',
        r'directive\s+\d{4}/\d+/[A-Z]+',
    ]
    for pattern in legal_patterns:
        for match in re.finditer(pattern, text, re.IGNORECASE):
            entities.append({
                "text": match.group(),
                "type": "LEGAL_REF",
                "start": match.start(),
                "end": match.end()
            })

    # Extract dates (DATE)
    date_patterns = [
        r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b',
        r'\b\d{1,2}\s+(?:janvier|février|mars|avril|mai|juin|juillet|août|septembre|octobre|novembre|décembre)\s+\d{4}\b',
    ]
    for pattern in date_patterns:
        for match in re.finditer(pattern, text, re.IGNORECASE):
            entities.append({
                "text": match.group(),
                "type": "DATE",
                "start": match.start(),
                "end": match.end()
            })

    # Extract organizations (ORG)
    org_patterns = [
        r'\b(?:SARL|SAS|SA|EURL|SCI|SASU|SNC)\s+[A-Z][A-Za-zÀ-ÿ\s&\'-]+',
        r'\b(?:Société|Entreprise|Compagnie|Association)\s+[A-Z][A-Za-zÀ-ÿ\s&\'-]+',
    ]
    for pattern in org_patterns:
        for match in re.finditer(pattern, text):
            entities.append({
                "text": match.group(),
                "type": "ORG",
                "start": match.start(),
                "end": match.end()
            })

    # Extract courts (COURT)
    court_patterns = [
        r'(?:Cour|Tribunal|Conseil)\s+(?:de\s+)?[A-Za-zÀ-ÿ\s\'-]+?(?=\s|,|\.)',
    ]
    for pattern in court_patterns:
        for match in re.finditer(pattern, text, re.IGNORECASE):
            entities.append({
                "text": match.group().strip(),
                "type": "COURT",
                "start": match.start(),
                "end": match.end()
            })

    # Remove duplicates and sort by position
    seen = set()
    unique_entities = []
    for ent in sorted(entities, key=lambda x: x['start']):
        key = (ent['text'], ent['type'], ent['start'])
        if key not in seen:
            seen.add(key)
            unique_entities.append(ent)

    return unique_entities
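
# Illustrative sketch of the helper's output (the sample sentence is an assumed
# input; the dict shape mirrors the entries built above):
#
#   extract_enhanced_entities("Me Dupont a saisi le Tribunal de commerce le 12/03/2024.")
#   -> [{"text": "Me Dupont", "type": "PER", "start": 0, "end": 9},
#       {"text": "Tribunal de commerce", "type": "COURT", ...},
#       {"text": "12/03/2024", "type": "DATE", ...}]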

@app.post("/ner")
async def extract_entities(request: NERRequest):
    """Enhanced NER for French legal text"""
    try:
        entities = extract_enhanced_entities(request.text)

        # Group by type for summary
        entity_summary = {}
        for ent in entities:
            if ent['type'] not in entity_summary:
                entity_summary[ent['type']] = []
            entity_summary[ent['type']].append(ent['text'])

        return {
            "entities": entities,
            "summary": {
                ent_type: list(set(texts))  # Unique entities per type
                for ent_type, texts in entity_summary.items()
            },
            "total": len(entities),
            "text": request.text
        }
    except Exception as e:
        logger.error(f"NER error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/qa")
async def question_answering(request: QARequest):
    """Answer questions about French legal documents"""
    await load_models_on_demand()

    try:
        # Generate embeddings for context and question
        embedding_req = EmbeddingRequest(texts=[request.context, request.question])
        embeddings = await generate_embeddings(embedding_req)

        context_emb = np.array(embeddings['embeddings'][0])
        question_emb = np.array(embeddings['embeddings'][1])

        # Calculate cosine similarity between context and question embeddings
        similarity = np.dot(context_emb, question_emb) / (np.linalg.norm(context_emb) * np.linalg.norm(question_emb))

        # Extract relevant part of context based on question keywords
        question_words = set(request.question.lower().split())
        sentences = request.context.split('.')

        relevant_sentences = []
        for sent in sentences:
            sent_words = set(sent.lower().split())
            overlap = len(question_words & sent_words)
            if overlap > 0:
                relevant_sentences.append((sent.strip(), overlap))

        # Sort by relevance
        relevant_sentences.sort(key=lambda x: x[1], reverse=True)

        if relevant_sentences:
            answer = relevant_sentences[0][0]
            confidence = min(0.9, similarity + 0.3)
        else:
            answer = "Aucune réponse trouvée dans le contexte fourni."
            confidence = 0.1

        return {
            "question": request.question,
            "answer": answer,
            "confidence": float(confidence),
            "context_relevance": float(similarity),
            "model": "juribert-base (similarity-based QA)"
        }
    except Exception as e:
        logger.error(f"QA error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/classify")
async def classify_document(request: ClassificationRequest):
    """Enhanced document classification"""
    try:
        text_lower = request.text.lower()

        # Enhanced categories with more keywords
        categories = {
            "contract": {
                "keywords": ["contrat", "accord", "convention", "parties", "obligations", "clause", "engagement"],
                "weight": 1.0
            },
            "litigation": {
                "keywords": ["tribunal", "jugement", "litige", "procès", "avocat", "défendeur", "demandeur", "arrêt", "décision"],
                "weight": 1.2
            },
            "corporate": {
                "keywords": ["société", "sarl", "sas", "entreprise", "capital", "associés", "statuts", "assemblée"],
                "weight": 1.0
            },
            "employment": {
                "keywords": ["travail", "salarié", "employeur", "licenciement", "contrat de travail", "cdi", "cdd", "rupture"],
                "weight": 1.1
            },
            "real_estate": {
                "keywords": ["immobilier", "location", "bail", "propriété", "locataire", "propriétaire", "loyer"],
                "weight": 1.0
            },
            "intellectual_property": {
                "keywords": ["brevet", "marque", "propriété intellectuelle", "invention", "droit d'auteur", "œuvre"],
                "weight": 1.0
            }
        }

        scores = {}
        matched_keywords = {}

        # Keyword-frequency scoring, weighted per category
        for category, info in categories.items():
            score = 0
            keywords_found = []
            for keyword in info['keywords']:
                if keyword in text_lower:
                    count = text_lower.count(keyword)
                    score += count * info['weight']
                    keywords_found.append(keyword)
            if score > 0:
                scores[category] = score
                matched_keywords[category] = keywords_found

        if not scores:
            primary_category = "general"
            confidence = 0.3
        else:
            total_score = sum(scores.values())
            primary_category = max(scores, key=scores.get)
            confidence = min(0.95, scores[primary_category] / total_score + 0.2)

        return {
            "primary_category": primary_category,
            "categories": [
                {
                    "category": cat,
                    "score": score,
                    "keywords_found": matched_keywords.get(cat, [])
                }
                for cat, score in sorted(scores.items(), key=lambda x: x[1], reverse=True)
            ],
            "confidence": float(confidence),
            "document_type": "legal_document"
        }
    except Exception as e:
        logger.error(f"Classification error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "timestamp": datetime.utcnow().isoformat(),
        "version": "2.0.0",
        "models_loaded": models_loaded,
        "available_models": list(models.keys())
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
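
# Example client calls (a minimal sketch; assumes the service is running with
# the uvicorn host/port configured above and the route paths listed by the
# root endpoint):
#
#   import requests
#
#   r = requests.post(
#       "http://localhost:7860/mask-fill",
#       json={"text": "Le contrat est [MASK] par les parties.", "top_k": 5},
#   )
#   print(r.json()["predictions"])
#
#   r = requests.post(
#       "http://localhost:7860/embeddings",
#       json={"texts": ["Le bail est résilié de plein droit."]},
#   )
#   print(r.json()["dimension"])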