Sobro Inc committed
Commit 967a5fb · 1 Parent(s): 4786618

Add full version with JuriBERT - mask filling, embeddings, enhanced NER, QA

Files changed (2):
  1. Dockerfile +3 -2
  2. main_full.py +460 -0
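
For reference, a client call against the new mask-filling endpoint could look like the sketch below. The base URL is an assumption (wherever this Space is served), and the payload mirrors the MaskFillRequest model added in main_full.py.

# Sketch only: exercising POST /mask-fill with the request shape from main_full.py.
import requests

BASE_URL = "http://localhost:7860"  # assumption: adjust to the deployed Space URL

payload = {"text": "Le contrat est résilié par le [MASK].", "top_k": 5}
resp = requests.post(f"{BASE_URL}/mask-fill", json=payload, timeout=120)
resp.raise_for_status()
for pred in resp.json()["predictions"]:
    print(f"{pred['token']!r}: {pred['score']:.3f}")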
Dockerfile CHANGED
@@ -26,6 +26,7 @@ RUN mkdir -p /app/.cache && chown -R user:user /app/.cache
 COPY --chown=user:user app/ ./app/
 COPY --chown=user:user main.py .
 COPY --chown=user:user main_simple.py .
+COPY --chown=user:user main_full.py .
 
 # Switch to user
 USER user
@@ -38,5 +39,5 @@ ENV PYTHONUNBUFFERED=1
 # Expose port
 EXPOSE 7860
 
-# Run the application (using simple version first)
-CMD ["uvicorn", "main_simple:app", "--host", "0.0.0.0", "--port", "7860"]
+# Run the application (using full version with on-demand loading)
+CMD ["uvicorn", "main_full:app", "--host", "0.0.0.0", "--port", "7860"]
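
A minimal smoke test for the switched entrypoint, assuming the image is built and running locally on port 7860 (both assumptions, not part of this commit):

# Sketch only: confirm the container now serves main_full:app.
# Assumes something like `docker run -p 7860:7860 <image>` is already up.
import requests

health = requests.get("http://localhost:7860/health", timeout=10).json()
print(health["status"], health["models_loaded"])  # models load lazily on first model-backed request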
main_full.py ADDED
@@ -0,0 +1,460 @@
+import os
+import re
+import logging
+from datetime import datetime
+from typing import List, Dict, Any, Optional
+from fastapi import FastAPI, HTTPException, File, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel, Field
+import torch
+from transformers import (
+    AutoTokenizer,
+    AutoModel,
+    AutoModelForMaskedLM,
+    pipeline
+)
+import numpy as np
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Initialize FastAPI app
+app = FastAPI(
+    title="SobroJuriBert API - Full Version",
+    description="French Legal AI API powered by JuriBERT with complete functionality",
+    version="2.0.0"
+)
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Global model storage
+models = {}
+tokenizers = {}
+models_loaded = False
+
+# Pydantic models
+class TextRequest(BaseModel):
+    text: str = Field(..., description="Text to analyze")
+
+class MaskFillRequest(BaseModel):
+    text: str = Field(..., description="Text with [MASK] tokens")
+    top_k: int = Field(5, description="Number of predictions to return")
+
+class NERRequest(BaseModel):
+    text: str = Field(..., description="Legal text for entity extraction")
+
+class QARequest(BaseModel):
+    context: str = Field(..., description="Legal document context")
+    question: str = Field(..., description="Question about the document")
+
+class ClassificationRequest(BaseModel):
+    text: str = Field(..., description="Legal document to classify")
+
+class EmbeddingRequest(BaseModel):
+    texts: List[str] = Field(..., description="List of texts to embed")
+
+async def load_models_on_demand():
+    """Load models on first request"""
+    global models_loaded
+    if models_loaded:
+        return
+
+    logger.info("Loading JuriBERT models on demand...")
+    try:
+        # Load JuriBERT for embeddings and mask filling
+        models['juribert_base'] = AutoModel.from_pretrained(
+            'dascim/juribert-base',
+            cache_dir="/app/.cache/huggingface"
+        )
+        tokenizers['juribert_base'] = AutoTokenizer.from_pretrained(
+            'dascim/juribert-base',
+            cache_dir="/app/.cache/huggingface"
+        )
+        models['juribert_mlm'] = AutoModelForMaskedLM.from_pretrained(
+            'dascim/juribert-base',
+            cache_dir="/app/.cache/huggingface"
+        )
+        models_loaded = True
+        logger.info("JuriBERT models loaded successfully!")
+    except Exception as e:
+        logger.error(f"Error loading models: {e}")
+        raise HTTPException(status_code=503, detail="Models could not be loaded")
+
+@app.get("/")
+async def root():
+    """Root endpoint with API information"""
+    return {
+        "name": "SobroJuriBert API - Full Version",
+        "version": "2.0.0",
+        "description": "Complete French Legal AI API",
+        "status": "operational",
+        "endpoints": {
+            "mask_fill": "/mask-fill - Fill masked tokens in legal text",
+            "embeddings": "/embeddings - Generate legal text embeddings",
+            "ner": "/ner - Extract legal entities (enhanced)",
+            "qa": "/qa - Answer questions about legal documents",
+            "classify": "/classify - Classify legal documents",
+            "health": "/health - Health check"
+        },
+        "models": {
+            "base": "dascim/juribert-base",
+            "status": "loaded" if models_loaded else "on-demand"
+        }
+    }
+
+@app.post("/mask-fill")
+async def mask_fill(request: MaskFillRequest):
+    """Fill [MASK] tokens in French legal text using JuriBERT"""
+    await load_models_on_demand()
+
+    try:
+        tokenizer = tokenizers['juribert_base']
+        model = models['juribert_mlm']
+
+        # Create pipeline
+        fill_mask = pipeline(
+            'fill-mask',
+            model=model,
+            tokenizer=tokenizer,
+            device=-1  # CPU
+        )
+
+        # Get predictions
+        predictions = fill_mask(request.text, top_k=request.top_k)
+
+        return {
+            "input": request.text,
+            "predictions": [
+                {
+                    "sequence": pred['sequence'],
+                    "score": float(pred['score']),
+                    "token": pred['token_str']
+                }
+                for pred in predictions
+            ]
+        }
+
+    except Exception as e:
+        logger.error(f"Mask fill error: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/embeddings")
+async def generate_embeddings(request: EmbeddingRequest):
+    """Generate embeddings for French legal texts using JuriBERT"""
+    await load_models_on_demand()
+
+    try:
+        tokenizer = tokenizers['juribert_base']
+        model = models['juribert_base']
+
+        embeddings = []
+        for text in request.texts:
+            # Tokenize
+            inputs = tokenizer(
+                text,
+                return_tensors="pt",
+                truncation=True,
+                max_length=512,
+                padding=True
+            )
+
+            # Generate embeddings
+            with torch.no_grad():
+                outputs = model(**inputs)
+                # Use mean pooling
+                attention_mask = inputs['attention_mask']
+                token_embeddings = outputs.last_hidden_state
+                input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+                embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+                embeddings.append(embedding.squeeze().numpy().tolist())
+
+        return {
+            "embeddings": embeddings,
+            "dimension": len(embeddings[0]) if embeddings else 0,
+            "model": "juribert-base"
+        }
+
+    except Exception as e:
+        logger.error(f"Embedding error: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+def extract_enhanced_entities(text: str) -> List[Dict[str, Any]]:
+    """Enhanced entity extraction for French legal text"""
+    entities = []
+
+    # Extract persons (PER)
+    person_patterns = [
+        r'\b(?:M\.|Mme|Mlle|Me|Dr|Prof\.?)\s+[A-Z][a-zÀ-ÿ]+(?:\s+[A-Z][a-zÀ-ÿ]+)*',
+        r'\b[A-Z][a-zÀ-ÿ]+\s+[A-Z][A-Z]+\b',  # Jean DUPONT
+    ]
+
+    for pattern in person_patterns:
+        for match in re.finditer(pattern, text):
+            entities.append({
+                "text": match.group(),
+                "type": "PER",
+                "start": match.start(),
+                "end": match.end()
+            })
+
+    # Extract money amounts (MONEY)
+    money_patterns = [
+        r'\b\d{1,3}(?:\s?\d{3})*(?:[,\.]\d{2})?\s?(?:€|EUR|euros?)\b',
+        r'\b(?:€|EUR)\s?\d{1,3}(?:\s?\d{3})*(?:[,\.]\d{2})?\b',
+    ]
+
+    for pattern in money_patterns:
+        for match in re.finditer(pattern, text, re.IGNORECASE):
+            entities.append({
+                "text": match.group(),
+                "type": "MONEY",
+                "start": match.start(),
+                "end": match.end()
+            })
+
+    # Extract legal references (LEGAL_REF)
+    legal_patterns = [
+        r'article\s+(?:L\.?)?\d+(?:-\d+)?(?:\s+(?:alinéa|al\.)\s+\d+)?',
+        r'articles?\s+\d+\s+(?:à|et)\s+\d+',
+        r'(?:loi|décret|ordonnance)\s+n°\s*\d{4}-\d+',
+        r'directive\s+\d{4}/\d+/[A-Z]+',
+    ]
+
+    for pattern in legal_patterns:
+        for match in re.finditer(pattern, text, re.IGNORECASE):
+            entities.append({
+                "text": match.group(),
+                "type": "LEGAL_REF",
+                "start": match.start(),
+                "end": match.end()
+            })
+
+    # Extract dates (DATE)
+    date_patterns = [
+        r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b',
+        r'\b\d{1,2}\s+(?:janvier|février|mars|avril|mai|juin|juillet|août|septembre|octobre|novembre|décembre)\s+\d{4}\b',
+    ]
+
+    for pattern in date_patterns:
+        for match in re.finditer(pattern, text, re.IGNORECASE):
+            entities.append({
+                "text": match.group(),
+                "type": "DATE",
+                "start": match.start(),
+                "end": match.end()
+            })
+
+    # Extract organizations (ORG)
+    org_patterns = [
+        r'\b(?:SARL|SAS|SA|EURL|SCI|SASU|SNC)\s+[A-Z][A-Za-zÀ-ÿ\s&\'-]+',
+        r'\b(?:Société|Entreprise|Compagnie|Association)\s+[A-Z][A-Za-zÀ-ÿ\s&\'-]+',
+    ]
+
+    for pattern in org_patterns:
+        for match in re.finditer(pattern, text):
+            entities.append({
+                "text": match.group(),
+                "type": "ORG",
+                "start": match.start(),
+                "end": match.end()
+            })
+
+    # Extract courts (COURT)
+    court_patterns = [
+        r'(?:Cour|Tribunal|Conseil)\s+(?:de\s+)?[A-Za-zÀ-ÿ\s\'-]+?(?=\s|,|\.)',
+    ]
+
+    for pattern in court_patterns:
+        for match in re.finditer(pattern, text, re.IGNORECASE):
+            entities.append({
+                "text": match.group().strip(),
+                "type": "COURT",
+                "start": match.start(),
+                "end": match.end()
+            })
+
+    # Remove duplicates and sort by position
+    seen = set()
+    unique_entities = []
+    for ent in sorted(entities, key=lambda x: x['start']):
+        key = (ent['text'], ent['type'], ent['start'])
+        if key not in seen:
+            seen.add(key)
+            unique_entities.append(ent)
+
+    return unique_entities
+
+@app.post("/ner")
+async def extract_entities(request: NERRequest):
+    """Enhanced NER for French legal text"""
+    try:
+        entities = extract_enhanced_entities(request.text)
+
+        # Group by type for summary
+        entity_summary = {}
+        for ent in entities:
+            if ent['type'] not in entity_summary:
+                entity_summary[ent['type']] = []
+            entity_summary[ent['type']].append(ent['text'])
+
+        return {
+            "entities": entities,
+            "summary": {
+                ent_type: list(set(texts))  # Unique entities per type
+                for ent_type, texts in entity_summary.items()
+            },
+            "total": len(entities),
+            "text": request.text
+        }
+
+    except Exception as e:
+        logger.error(f"NER error: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/qa")
+async def question_answering(request: QARequest):
+    """Answer questions about French legal documents"""
+    await load_models_on_demand()
+
+    try:
+        # Generate embeddings for context and question
+        embedding_req = EmbeddingRequest(texts=[request.context, request.question])
+        embeddings = await generate_embeddings(embedding_req)
+
+        context_emb = np.array(embeddings['embeddings'][0])
+        question_emb = np.array(embeddings['embeddings'][1])
+
+        # Calculate similarity
+        similarity = np.dot(context_emb, question_emb) / (np.linalg.norm(context_emb) * np.linalg.norm(question_emb))
+
+        # Extract relevant part of context based on question keywords
+        question_words = set(request.question.lower().split())
+        sentences = request.context.split('.')
+
+        relevant_sentences = []
+        for sent in sentences:
+            sent_words = set(sent.lower().split())
+            overlap = len(question_words & sent_words)
+            if overlap > 0:
+                relevant_sentences.append((sent.strip(), overlap))
+
+        # Sort by relevance
+        relevant_sentences.sort(key=lambda x: x[1], reverse=True)
+
+        if relevant_sentences:
+            answer = relevant_sentences[0][0]
+            confidence = min(0.9, similarity + 0.3)
+        else:
+            answer = "Aucune réponse trouvée dans le contexte fourni."
+            confidence = 0.1
+
+        return {
+            "question": request.question,
+            "answer": answer,
+            "confidence": float(confidence),
+            "context_relevance": float(similarity),
+            "model": "juribert-base (similarity-based QA)"
+        }
+
+    except Exception as e:
+        logger.error(f"QA error: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/classify")
+async def classify_document(request: ClassificationRequest):
+    """Enhanced document classification"""
+    try:
+        text_lower = request.text.lower()
+
+        # Enhanced categories with more keywords
+        categories = {
+            "contract": {
+                "keywords": ["contrat", "accord", "convention", "parties", "obligations", "clause", "engagement"],
+                "weight": 1.0
+            },
+            "litigation": {
+                "keywords": ["tribunal", "jugement", "litige", "procès", "avocat", "défendeur", "demandeur", "arrêt", "décision"],
+                "weight": 1.2
+            },
+            "corporate": {
+                "keywords": ["société", "sarl", "sas", "entreprise", "capital", "associés", "statuts", "assemblée"],
+                "weight": 1.0
+            },
+            "employment": {
+                "keywords": ["travail", "salarié", "employeur", "licenciement", "contrat de travail", "cdi", "cdd", "rupture"],
+                "weight": 1.1
+            },
+            "real_estate": {
+                "keywords": ["immobilier", "location", "bail", "propriété", "locataire", "propriétaire", "loyer"],
+                "weight": 1.0
+            },
+            "intellectual_property": {
+                "keywords": ["brevet", "marque", "propriété intellectuelle", "invention", "droit d'auteur", "œuvre"],
+                "weight": 1.0
+            }
+        }
+
+        scores = {}
+        matched_keywords = {}
+
+        for category, info in categories.items():
+            score = 0
+            keywords_found = []
+            for keyword in info['keywords']:
+                if keyword in text_lower:
+                    count = text_lower.count(keyword)
+                    score += count * info['weight']
+                    keywords_found.append(keyword)
+
+            if score > 0:
+                scores[category] = score
+                matched_keywords[category] = keywords_found
+
+        if not scores:
+            primary_category = "general"
+            confidence = 0.3
+        else:
+            total_score = sum(scores.values())
+            primary_category = max(scores, key=scores.get)
+            confidence = min(0.95, scores[primary_category] / total_score + 0.2)
+
+        return {
+            "primary_category": primary_category,
+            "categories": [
+                {
+                    "category": cat,
+                    "score": score,
+                    "keywords_found": matched_keywords.get(cat, [])
+                }
+                for cat, score in sorted(scores.items(), key=lambda x: x[1], reverse=True)
+            ],
+            "confidence": float(confidence),
+            "document_type": "legal_document"
+        }
+
+    except Exception as e:
+        logger.error(f"Classification error: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/health")
+async def health_check():
+    """Health check endpoint"""
+    return {
+        "status": "healthy",
+        "timestamp": datetime.utcnow().isoformat(),
+        "version": "2.0.0",
+        "models_loaded": models_loaded,
+        "available_models": list(models.keys())
+    }
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
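
Since /qa scores context relevance with cosine similarity over mean-pooled JuriBERT embeddings, the same measure can be reproduced client-side from the /embeddings response. A sketch, with the base URL again an assumption:

# Sketch only: cosine similarity between two legal texts via POST /embeddings.
import requests
import numpy as np

BASE_URL = "http://localhost:7860"  # assumption: adjust to the deployed Space URL
texts = ["Le bail est conclu pour une durée de trois ans.", "Quelle est la durée du bail ?"]
emb = requests.post(f"{BASE_URL}/embeddings", json={"texts": texts}, timeout=120).json()

a, b = (np.array(v) for v in emb["embeddings"])
cosine = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
print(f"dimension={emb['dimension']}, cosine={cosine:.3f}")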