# This file contains the endpoint implementations
# In production, merge this with main.py
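
# ---------------------------------------------------------------------------
# The names used below (app, logger, models, tokenizers, and the request
# schemas) are expected to come from main.py. The sketch in this comment is
# only an assumption of what main.py provides, inferred from how the names
# are used in the endpoints; adjust it to the real definitions when merging.
#
#   import logging
#   from datetime import datetime
#
#   import torch
#   from fastapi import FastAPI, HTTPException
#   from pydantic import BaseModel
#   from transformers import pipeline
#
#   app = FastAPI()
#   logger = logging.getLogger(__name__)
#   models = {}      # e.g. "juribert_base", "juribert_mlm", "camembert_ner"
#   tokenizers = {}  # e.g. "juribert_base"
#
#   class MaskFillRequest(BaseModel):
#       text: str
#       top_k: int = 5
#
#   class EmbeddingRequest(BaseModel):
#       texts: list[str]
#
#   class NERRequest(BaseModel):
#       text: str
#
#   class QARequest(BaseModel):
#       question: str
#
#   class ClassificationRequest(BaseModel):
#       text: str
# ---------------------------------------------------------------------------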
@app.post("/mask-fill")
async def mask_fill(request: MaskFillRequest):
"""Fill [MASK] tokens in French legal text"""
try:
tokenizer = tokenizers['juribert_base']
model = models['juribert_mlm']
# Create pipeline
fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)
# Get predictions
predictions = fill_mask(request.text, top_k=request.top_k)
return {
"input": request.text,
"predictions": [
{
"sequence": pred['sequence'],
"score": pred['score'],
"token": pred['token_str']
}
for pred in predictions
]
}
except Exception as e:
logger.error(f"Mask fill error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/embeddings")
async def generate_embeddings(request: EmbeddingRequest):
"""Generate embeddings for French legal texts"""
try:
tokenizer = tokenizers['juribert_base']
model = models['juribert_base']
embeddings = []
for text in request.texts:
# Tokenize
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True
)
# Generate embeddings
with torch.no_grad():
outputs = model(**inputs)
# Use CLS token embedding
embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
embeddings.append(embedding.tolist())
return {
"embeddings": embeddings,
"dimension": len(embeddings[0]) if embeddings else 0
}
except Exception as e:
logger.error(f"Embedding error: {e}")
raise HTTPException(status_code=500, detail=str(e))
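
# Note: generate_embeddings above uses the first-token ([CLS]-style) embedding.
# Mean pooling over the attention-masked token embeddings is a common
# alternative; a minimal sketch with the same inputs and outputs:
#
#   mask = inputs["attention_mask"].unsqueeze(-1)
#   pooled = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
#   embedding = pooled.squeeze().numpy()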
@app.post("/ner")
async def extract_entities(request: NERRequest):
"""Extract named entities from French legal text"""
try:
# Use CamemBERT NER model
ner_pipeline = models['camembert_ner']
entities = ner_pipeline(request.text)
# Format results
formatted_entities = []
for entity in entities:
formatted_entities.append({
"text": entity['word'],
"type": entity['entity_group'],
"score": entity['score'],
"start": entity['start'],
"end": entity['end']
})
return {
"entities": formatted_entities,
"text": request.text
}
except Exception as e:
logger.error(f"NER error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/qa")
async def question_answering(request: QARequest):
"""Answer questions about French legal documents"""
try:
# Simple implementation for now
# In production, use a fine-tuned QA model
return {
"question": request.question,
"answer": "This feature requires a fine-tuned QA model. Please check back later.",
"confidence": 0.0,
"relevant_articles": [],
"explanation": "QA model is being fine-tuned on French legal data"
}
except Exception as e:
logger.error(f"QA error: {e}")
raise HTTPException(status_code=500, detail=str(e))
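
# A minimal sketch of what this endpoint could do once a fine-tuned extractive
# QA model is available. It assumes QARequest also carries a `context` field
# and that a model is registered under a hypothetical key 'juribert_qa';
# neither exists in the current code.
#
#   qa_pipeline = pipeline(
#       "question-answering",
#       model=models['juribert_qa'],
#       tokenizer=tokenizers['juribert_base'],
#   )
#   result = qa_pipeline(question=request.question, context=request.context)
#   return {
#       "question": request.question,
#       "answer": result["answer"],
#       "confidence": float(result["score"]),
#       "relevant_articles": [],
#       "explanation": "Extractive answer from the fine-tuned QA model",
#   }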
@app.post("/classify")
async def classify_document(request: ClassificationRequest):
"""Classify French legal documents"""
try:
# Simple keyword-based classification for now
text_lower = request.text.lower()
categories = {
"contract": ["contrat", "accord", "convention", "parties"],
"litigation": ["tribunal", "jugement", "litige", "procès"],
"corporate": ["société", "sarl", "sas", "entreprise"],
"employment": ["travail", "salarié", "employeur", "licenciement"]
}
scores = {}
for category, keywords in categories.items():
score = sum(1 for kw in keywords if kw in text_lower)
if score > 0:
scores[category] = score
if not scores:
primary_category = "general"
else:
primary_category = max(scores, key=scores.get)
return {
"primary_category": primary_category,
"categories": [{"category": cat, "score": score} for cat, score in scores.items()],
"confidence": 0.8 if scores else 0.5,
"document_type": "legal_document",
"legal_domain": primary_category
}
except Exception as e:
logger.error(f"Classification error: {e}")
raise HTTPException(status_code=500, detail=str(e))
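
# Worked example of the keyword scoring above: for the input
# "Le contrat de travail du salarié prévoit une clause de mobilité",
# the hits are contract=1 ("contrat") and employment=2 ("travail", "salarié"),
# so primary_category is "employment" and confidence is 0.8.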
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"models_loaded": list(models.keys()),
"timestamp": datetime.utcnow().isoformat()
}

# Add this to main.py when deploying
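
# Example client call once these routes are merged into main.py. The host,
# port, and mask token are assumptions: CamemBERT-style tokenizers expect
# <mask>, BERT-style tokenizers expect [MASK]; use whatever the loaded
# 'juribert_base' tokenizer defines.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/mask-fill",
#       json={"text": "Le contrat est signé par les [MASK].", "top_k": 5},
#   )
#   print(resp.json()["predictions"])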