# This file contains the endpoint implementations
# In production, merge this with main.py

from datetime import datetime

import torch
from fastapi import HTTPException
from transformers import pipeline

# `app`, `logger`, the `models`/`tokenizers` registries, and the request schemas
# (MaskFillRequest, EmbeddingRequest, NERRequest, QARequest, ClassificationRequest)
# are defined in main.py.


@app.post("/mask-fill")
async def mask_fill(request: MaskFillRequest):
    """Fill [MASK] tokens in French legal text."""
    try:
        tokenizer = tokenizers['juribert_base']
        model = models['juribert_mlm']

        # Create pipeline
        fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)

        # Get predictions
        predictions = fill_mask(request.text, top_k=request.top_k)

        return {
            "input": request.text,
            "predictions": [
                {
                    "sequence": pred['sequence'],
                    "score": pred['score'],
                    "token": pred['token_str']
                }
                for pred in predictions
            ]
        }
    except Exception as e:
        logger.error(f"Mask fill error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/embeddings")
async def generate_embeddings(request: EmbeddingRequest):
    """Generate embeddings for French legal texts."""
    try:
        tokenizer = tokenizers['juribert_base']
        model = models['juribert_base']

        embeddings = []
        for text in request.texts:
            # Tokenize
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True
            )

            # Generate embeddings
            with torch.no_grad():
                outputs = model(**inputs)

            # Use the [CLS] token embedding as the sentence representation
            embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
            embeddings.append(embedding.tolist())

        return {
            "embeddings": embeddings,
            "dimension": len(embeddings[0]) if embeddings else 0
        }
    except Exception as e:
        logger.error(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/ner")
async def extract_entities(request: NERRequest):
    """Extract named entities from French legal text."""
    try:
        # Use CamemBERT NER model
        ner_pipeline = models['camembert_ner']
        entities = ner_pipeline(request.text)

        # Format results (cast numpy floats so the response is JSON-serializable)
        formatted_entities = []
        for entity in entities:
            formatted_entities.append({
                "text": entity['word'],
                "type": entity['entity_group'],
                "score": float(entity['score']),
                "start": entity['start'],
                "end": entity['end']
            })

        return {
            "entities": formatted_entities,
            "text": request.text
        }
    except Exception as e:
        logger.error(f"NER error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/qa")
async def question_answering(request: QARequest):
    """Answer questions about French legal documents."""
    try:
        # Simple placeholder implementation for now
        # In production, use a fine-tuned QA model
        return {
            "question": request.question,
            "answer": "This feature requires a fine-tuned QA model. Please check back later.",
            "confidence": 0.0,
            "relevant_articles": [],
            "explanation": "QA model is being fine-tuned on French legal data"
        }
    except Exception as e:
        logger.error(f"QA error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/classify")
async def classify_document(request: ClassificationRequest):
    """Classify French legal documents."""
    try:
        # Simple keyword-based classification for now
        text_lower = request.text.lower()

        categories = {
            "contract": ["contrat", "accord", "convention", "parties"],
            "litigation": ["tribunal", "jugement", "litige", "procès"],
            "corporate": ["société", "sarl", "sas", "entreprise"],
            "employment": ["travail", "salarié", "employeur", "licenciement"]
        }

        # Score each category by the number of its keywords found in the text
        scores = {}
        for category, keywords in categories.items():
            score = sum(1 for kw in keywords if kw in text_lower)
            if score > 0:
                scores[category] = score

        if not scores:
            primary_category = "general"
        else:
            primary_category = max(scores, key=scores.get)

        return {
            "primary_category": primary_category,
            "categories": [
                {"category": cat, "score": score}
                for cat, score in scores.items()
            ],
            "confidence": 0.8 if scores else 0.5,
            "document_type": "legal_document",
            "legal_domain": primary_category
        }
    except Exception as e:
        logger.error(f"Classification error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "models_loaded": list(models.keys()),
        "timestamp": datetime.utcnow().isoformat()
    }

# Add this to main.py when deploying
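

# ---------------------------------------------------------------------------
# Minimal sketch of the main.py pieces these endpoints assume: the FastAPI
# app, logger, request schemas, and the `models`/`tokenizers` registries.
# Everything below is illustrative only; the field names, registry keys, and
# model identifiers are assumptions, not the project's actual main.py, and
# these definitions would live in main.py ahead of the endpoint code above.
# ---------------------------------------------------------------------------
# import logging
# from typing import List
#
# from fastapi import FastAPI
# from pydantic import BaseModel
# from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, pipeline
#
# logger = logging.getLogger(__name__)
# app = FastAPI(title="French Legal NLP API")
#
#
# class MaskFillRequest(BaseModel):
#     text: str
#     top_k: int = 5
#
#
# class EmbeddingRequest(BaseModel):
#     texts: List[str]
#
#
# class NERRequest(BaseModel):
#     text: str
#
#
# class QARequest(BaseModel):
#     question: str
#
#
# class ClassificationRequest(BaseModel):
#     text: str
#
#
# # Registries keyed by the names used in the endpoints above. The checkpoint
# # names are placeholders; substitute the JuriBERT weights and the French NER
# # model the project actually uses.
# JURIBERT_CHECKPOINT = "camembert-base"  # placeholder for the JuriBERT checkpoint
#
# tokenizers = {
#     "juribert_base": AutoTokenizer.from_pretrained(JURIBERT_CHECKPOINT),
# }
# models = {
#     "juribert_base": AutoModel.from_pretrained(JURIBERT_CHECKPOINT),
#     "juribert_mlm": AutoModelForMaskedLM.from_pretrained(JURIBERT_CHECKPOINT),
#     "camembert_ner": pipeline(
#         "ner",
#         model="Jean-Baptiste/camembert-ner",  # a public French NER model
#         aggregation_strategy="simple",
#     ),
# }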