File size: 5,314 Bytes
c914f37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# This file contains the endpoint implementations
# In production, merge this with main.py

@app.post("/mask-fill")
async def mask_fill(request: MaskFillRequest):
    """Fill mask token(s) in French legal text with the JuriBERT MLM model.

    The placeholder the client must use is the tokenizer's own mask token
    (``tokenizers['juribert_base'].mask_token``), not necessarily ``[MASK]``.

    Args:
        request: Provides ``text`` (must contain the mask token) and ``top_k``
            (number of candidates the pipeline returns per mask).

    Returns:
        ``{"input": text, "predictions": [...]}`` where each prediction has
        ``sequence``, ``score`` and ``token``. When the text contains more than
        one mask token, the candidates for every mask are returned in one flat
        list and each entry additionally carries ``mask_index`` (0-based
        position of the mask it fills).

    Raises:
        HTTPException: 400 if the text contains no mask token; 500 on any
            other failure.
    """
    try:
        tokenizer = tokenizers['juribert_base']
        model = models['juribert_mlm']

        mask_token = tokenizer.mask_token
        # A missing mask token is a client error; without this check the
        # pipeline raises and the caller sees an opaque 500.
        if mask_token not in request.text:
            raise HTTPException(
                status_code=400,
                detail=f"Input text must contain the mask token {mask_token!r}",
            )

        # NOTE(review): the pipeline is rebuilt on every request and inference
        # runs synchronously inside this async handler; consider caching it.
        fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)
        predictions = fill_mask(request.text, top_k=request.top_k)

        def _format(pred):
            # Project a pipeline candidate onto the public response fields.
            return {
                "sequence": pred['sequence'],
                "score": pred['score'],
                "token": pred['token_str'],
            }

        # With several masks the pipeline returns one candidate list per mask
        # (a list of lists) instead of a flat list of dicts.
        if predictions and isinstance(predictions[0], list):
            formatted = [
                {**_format(pred), "mask_index": mask_index}
                for mask_index, mask_preds in enumerate(predictions)
                for pred in mask_preds
            ]
        else:
            formatted = [_format(pred) for pred in predictions]

        return {
            "input": request.text,
            "predictions": formatted,
        }

    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above) instead of
        # rewrapping them as 500s below.
        raise
    except Exception as e:
        logger.exception(f"Mask fill error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/embeddings")
async def generate_embeddings(request: EmbeddingRequest):
    """Generate one CLS-token embedding per input text.

    Each text in ``request.texts`` is tokenized on its own (truncated to 512
    tokens) and run through ``models['juribert_base']``. The last-layer
    hidden state at position 0 is returned as a list of floats.

    Returns:
        ``{"embeddings": [[float, ...], ...], "dimension": int}``, where
        ``dimension`` is the embedding length, or 0 when ``request.texts`` is
        empty.

    Raises:
        HTTPException: 500 on any failure (tokenization, inference, ...).
    """
    try:
        tokenizer = tokenizers['juribert_base']
        model = models['juribert_base']
        # Inputs must be on the same device as the model weights, otherwise the
        # forward pass fails when the model has been moved off the CPU.
        device = next(model.parameters()).device

        embeddings = []
        with torch.no_grad():
            for text in request.texts:
                inputs = tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=512,
                ).to(device)

                outputs = model(**inputs)
                # CLS token (position 0) of the single sequence in the batch.
                cls_vector = outputs.last_hidden_state[0, 0, :]
                # .cpu() is required before converting if the model is on GPU.
                embeddings.append(cls_vector.cpu().tolist())

        return {
            "embeddings": embeddings,
            "dimension": len(embeddings[0]) if embeddings else 0
        }

    except Exception as e:
        logger.exception(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/ner")
async def extract_entities(request: NERRequest):
    """Extract named entities from French legal text.

    Runs the pipeline stored in ``models['camembert_ner']`` on ``request.text``
    and reshapes each result into a JSON-friendly dict.

    Returns:
        ``{"entities": [...], "text": request.text}`` where each entity has
        ``text``, ``type``, ``score`` (plain float), ``start`` and ``end``.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        ner_pipeline = models['camembert_ner']
        entities = ner_pipeline(request.text)

        formatted_entities = [
            {
                "text": entity['word'],
                # Aggregated pipelines label results with 'entity_group';
                # token-level pipelines use 'entity' instead.
                "type": (
                    entity['entity_group']
                    if 'entity_group' in entity
                    else entity['entity']
                ),
                # The pipeline yields numpy.float32 scores, which FastAPI's
                # JSON encoder cannot serialize; convert to a builtin float.
                "score": float(entity['score']),
                "start": entity['start'],
                "end": entity['end'],
            }
            for entity in entities
        ]

        return {
            "entities": formatted_entities,
            "text": request.text
        }

    except Exception as e:
        logger.exception(f"NER error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/qa")
async def question_answering(request: QARequest):
    """Placeholder question-answering endpoint for French legal documents.

    No QA model is wired in yet. The handler echoes the question back with a
    fixed "not available" answer, zero confidence and an empty article list.
    """
    try:
        # TODO: swap this canned response for a fine-tuned QA model.
        placeholder_answer = (
            "This feature requires a fine-tuned QA model. Please check back later."
        )
        response = {
            "question": request.question,
            "answer": placeholder_answer,
            "confidence": 0.0,
            "relevant_articles": [],
            "explanation": "QA model is being fine-tuned on French legal data",
        }
        return response

    except Exception as error:
        logger.error(f"QA error: {error}")
        raise HTTPException(status_code=500, detail=str(error))

@app.post("/classify")
async def classify_document(request: ClassificationRequest):
    """Classify a French legal document by simple keyword counting.

    Each category scores one point per keyword that occurs (as a substring) in
    the normalized, lower-cased text. Categories scoring zero are omitted. The
    highest-scoring category is the primary one (ties go to the category
    declared first), or ``"general"`` when nothing matched.

    Returns:
        Dict with ``primary_category``, ``categories`` (list of
        ``{"category", "score"}``), a fixed ``confidence`` (0.8 if any keyword
        matched, else 0.5), ``document_type`` and ``legal_domain`` (same value
        as ``primary_category``).

    Raises:
        HTTPException: 500 on any unexpected failure.
    """
    import unicodedata

    try:
        # NFC-normalize first so accented keywords ("procès", "société",
        # "salarié", ...) also match text whose accents arrive decomposed
        # (base letter followed by a combining mark).
        text_lower = unicodedata.normalize("NFC", request.text).lower()

        categories = {
            "contract": ["contrat", "accord", "convention", "parties"],
            "litigation": ["tribunal", "jugement", "litige", "procès"],
            "corporate": ["société", "sarl", "sas", "entreprise"],
            "employment": ["travail", "salarié", "employeur", "licenciement"]
        }

        scores = {}
        for category, keywords in categories.items():
            score = sum(1 for kw in keywords if kw in text_lower)
            if score > 0:
                scores[category] = score

        primary_category = max(scores, key=scores.get) if scores else "general"

        return {
            "primary_category": primary_category,
            "categories": [{"category": cat, "score": score} for cat, score in scores.items()],
            "confidence": 0.8 if scores else 0.5,
            "document_type": "legal_document",
            "legal_domain": primary_category
        }

    except Exception as e:
        logger.exception(f"Classification error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """Report service liveness and the names of the loaded models.

    Returns:
        ``status`` (always "healthy"; the models are not probed),
        ``models_loaded`` (the keys of the ``models`` registry) and
        ``timestamp`` (current UTC time as an ISO 8601 string with an explicit
        ``+00:00`` offset).
    """
    from datetime import timezone

    return {
        "status": "healthy",
        "models_loaded": list(models),
        # datetime.utcnow() is deprecated (Python 3.12+) and returns a naive
        # value; an aware UTC timestamp makes the time zone explicit.
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

# Add this to main.py when deploying