File size: 5,343 Bytes
e041ea9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
from transformers import ViTImageProcessor, ViTForImageClassification, pipeline
from fastai.learner import load_learner
from fastai.vision.core import PILImage
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import io
import base64

# ViT image classifier from the Hugging Face Hub (fine-tuned on HAM10000),
# plus the image processor that prepares its pixel inputs.
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
# nn.Module.eval() returns the module itself, so loading and switching to
# inference mode can be chained in one expression.
model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME).eval()

# Two fastai learners exported to local files.
# NOTE: load_learner unpickles these files — only ship trusted artifacts.
model_malignancy = load_learner(fname="ada_learn_malben.pkl")
model_norm2000 = load_learner(fname="ada_learn_skin_norm2000.pkl")

# Pretrained binary ISIC classifier, wrapped in a transformers pipeline.
classifier_isic = pipeline(task="image-classification", model="VRJBro/skin-cancer-detection")

# Display names for the ViT output indices, in model index order.
# NOTE(review): assumes this order matches model_vit.config.id2label — TODO confirm.
CLASSES = [
    "Queratosis actínica / Bowen",
    "Carcinoma células basales",
    "Lesión queratósica benigna",
    "Dermatofibroma",
    "Melanoma maligno",
    "Nevus melanocítico",
    "Lesión vascular",
]

# Per-class risk metadata keyed by CLASSES index: UI level label, bar colour,
# and the weight used to fold ViT probabilities into a single risk score.
RISK_LEVELS = {
    idx: {'level': level, 'color': color, 'weight': weight}
    for idx, (level, color, weight) in enumerate((
        ('Moderado', '#ffaa00', 0.6),  # Queratosis actínica / Bowen
        ('Alto', '#ff4444', 0.8),      # Carcinoma células basales
        ('Bajo', '#44ff44', 0.1),      # Lesión queratósica benigna
        ('Bajo', '#44ff44', 0.1),      # Dermatofibroma
        ('Crítico', '#cc0000', 1.0),   # Melanoma maligno
        ('Bajo', '#44ff44', 0.1),      # Nevus melanocítico
        ('Bajo', '#44ff44', 0.1),      # Lesión vascular
    ))
}

def _render_vit_chart(probs_vit):
    """Render ViT per-class probabilities as an inline PNG ``<img>`` tag.

    Args:
        probs_vit: NumPy array of ViT softmax probabilities, one entry per
            element of ``CLASSES`` (same index order).

    Returns:
        An HTML ``<img>`` string whose ``src`` is a base64 PNG data URI.
    """
    colors_bars = [RISK_LEVELS[i]['color'] for i in range(len(CLASSES))]
    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.bar(CLASSES, probs_vit * 100, color=colors_bars)
        ax.set_title("Probabilidad ViT por tipo de lesión")
        ax.set_ylabel("Probabilidad (%)")
        ax.set_xticks(np.arange(len(CLASSES)))
        ax.set_xticklabels(CLASSES, rotation=45, ha='right')
        ax.grid(axis='y', alpha=0.2)
        # Operate on this figure explicitly rather than on pyplot's global
        # "current figure", which a concurrent request could have replaced.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always unregister the figure from pyplot, even if rendering fails,
        # so figures do not accumulate across requests.
        plt.close(fig)
    img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'


def _recomendacion_html(prob_malignant, cancer_risk_score, pred_isic, confidence_isic):
    """Map the combined model scores to a triage recommendation HTML snippet.

    Args:
        prob_malignant: fastai malignancy probability in [0, 1].
        cancer_risk_score: probability-weighted ViT risk score.
        pred_isic: top label returned by the ISIC pipeline.
        confidence_isic: score of that top label.

    Returns:
        One HTML line, from most to least urgent: critical, high, moderate, low.
    """
    # Case-insensitive: the label's casing comes from the remote model config,
    # and the report already normalises it for display via .capitalize().
    isic_cancerous = pred_isic.lower() == "cancerous"
    if prob_malignant > 0.7 or cancer_risk_score > 0.6 or (isic_cancerous and confidence_isic > 0.9):
        return "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
    if prob_malignant > 0.4 or cancer_risk_score > 0.4 or isic_cancerous:
        return "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
    if cancer_risk_score > 0.2:
        return "📋 <b>RIESGO MODERADO</b> – Evaluación programada (2-4 semanas)"
    return "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"


def analizar_lesion_combined(img):
    """Run all four classifiers on a lesion image and build the HTML report.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an uploaded image.

    Returns:
        A ``(informe_html, chart_html)`` tuple for the two ``gr.HTML``
        outputs: the combined report table with a triage recommendation,
        and the ViT probability bar chart (empty string when no image).
    """
    if img is None:
        # Nothing to analyse: show a hint instead of failing inside a model call.
        return (
            '<div style="font-family:sans-serif">'
            "⚠️ Sube una imagen de la lesión para analizarla.</div>",
            "",
        )

    img_fastai = PILImage.create(img)

    # ViT prediction: softmax over the logits, top class mapped through CLASSES.
    inputs = feature_extractor(img, return_tensors="pt")
    with torch.no_grad():
        outputs = model_vit(**inputs)
        probs_vit = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
    pred_idx_vit = int(np.argmax(probs_vit))
    pred_class_vit = CLASSES[pred_idx_vit]
    confidence_vit = float(probs_vit[pred_idx_vit])

    # fastai predictions: predict() returns (decoded label, index, probabilities).
    # NOTE(review): assumes index 1 of the malignancy learner's vocab is the
    # malignant class — TODO confirm against model_malignancy.dls.vocab.
    _, _, probs_fast_mal = model_malignancy.predict(img_fastai)
    prob_malignant = float(probs_fast_mal[1])
    pred_fast_type, _, _ = model_norm2000.predict(img_fastai)

    # ISIC binary classification; the pipeline orders results by descending
    # score, so the first entry is the top prediction.
    top_isic = classifier_isic(img)[0]
    pred_isic = top_isic['label']
    confidence_isic = top_isic['score']

    informe = f"""

    <div style="font-family:sans-serif; max-width:800px; margin:auto">

    <h2>🧪 Diagnóstico por 4 modelos de IA</h2>

    <table style="border-collapse: collapse; width:100%; font-size:16px">

        <tr><th style="text-align:left">🔍 Modelo</th><th>Resultado</th><th>Confianza</th></tr>

        <tr><td>🧠 ViT (transformer)</td><td><b>{pred_class_vit}</b></td><td>{confidence_vit:.1%}</td></tr>

        <tr><td>🧬 Fast.ai (clasificación)</td><td><b>{pred_fast_type}</b></td><td>N/A</td></tr>

        <tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{"Maligno" if prob_malignant > 0.5 else "Benigno"}</b></td><td>{prob_malignant:.1%}</td></tr>

        <tr><td>🔬 ISIC binario</td><td><b>{pred_isic.capitalize()}</b></td><td>{confidence_isic:.1%}</td></tr>

    </table>

    <br>

    <b>🩺 Recomendación automática:</b><br>

    """

    # Expected risk: each ViT class probability weighted by its RISK_LEVELS weight.
    cancer_risk_score = sum(probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(len(CLASSES)))
    informe += _recomendacion_html(prob_malignant, cancer_risk_score, pred_isic, confidence_isic)
    informe += "</div>"

    return informe, _render_vit_chart(probs_vit)

# Gradio UI: one PIL image in; two HTML panels out (combined report, ViT chart).
demo = gr.Interface(
    analizar_lesion_combined,
    gr.Image(type="pil", label="Sube una imagen de la lesión"),
    [
        gr.HTML(label="Informe combinado"),
        gr.HTML(label="Gráfico ViT"),
    ],
    title="Detector de Lesiones Cutáneas (ViT + Fast.ai + ISIC)",
    description="Comparación entre ViT transformer (HAM10000), dos modelos Fast.ai y un clasificador binario ISIC con alta precisión.",
    flagging_mode="never",
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()