Commit eb16c9f (parent: c7afc09): update model

Files changed:
- .gitattributes copy +35 -0
- Dockerfile +22 -0
- README copy.md +10 -0
- __pycache__/main.cpython-311.pyc +0 -0
- init.py +7 -0
- main.py +160 -0
- models/best_model_1.pth +3 -0
- requirements.txt +7 -0
- templates/index.html +60 -0
.gitattributes copy
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,22 @@
+FROM python:3.10-slim
+
+ENV HF_HOME=/app/hf_cache
+
+
+RUN apt-get update && \
+    apt-get install -y git curl && \
+    rm -rf /var/lib/apt/lists/*
+
+
+WORKDIR /app
+RUN mkdir -p /app/hf_cache /app/models \
+    && chmod -R 777 /app/hf_cache
+
+COPY requirements.txt /app/
+RUN pip install --no-cache-dir -r /app/requirements.txt
+
+
+COPY . /app
+
+
+CMD ["sh", "-c", "uvicorn main:app --host 0.0.0.0 --port=${PORT:-7860}"]
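For a quick local smoke test of this image, something along these lines should work (the image tag datn-demo is illustrative; on Spaces the PORT variable is injected by the platform, so the CMD above falls back to 7860 locally):

docker build -t datn-demo .
docker run -p 7860:7860 datn-demo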
README copy.md
ADDED
@@ -0,0 +1,10 @@
+---
+title: DATN
+emoji: 👁
+colorFrom: green
+colorTo: yellow
+sdk: docker
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/main.cpython-311.pyc
ADDED
Binary file (8.42 kB)
init.py
ADDED
@@ -0,0 +1,7 @@
+from transformers import AutoTokenizer, AutoModel
+
+model = AutoModel.from_pretrained("vinai/phobert-base")
+tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base", use_fast=False)
+
+model.save_pretrained("D:/ktlt/python/thuchanh/DATN/DATN/tokenizer/phobert-base-fixed")
+tokenizer.save_pretrained("D:/ktlt/python/thuchanh/DATN/DATN/tokenizer/phobert-base-fixed")
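init.py is a one-off helper that re-saves the vinai/phobert-base weights and slow tokenizer into a local directory (the absolute Windows path above). A minimal sketch of loading that saved copy back, assuming a relative tokenizer/phobert-base-fixed directory rather than the hard-coded D:/ path:

from transformers import AutoModel, AutoTokenizer

# Hypothetical local directory; adjust to wherever init.py saved the files.
LOCAL_DIR = "tokenizer/phobert-base-fixed"
model = AutoModel.from_pretrained(LOCAL_DIR)
tokenizer = AutoTokenizer.from_pretrained(LOCAL_DIR, use_fast=False)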
main.py
ADDED
@@ -0,0 +1,160 @@
+import os
+
+os.environ.setdefault("HF_HOME", "/app/hf_cache")
+
+import logging
+from fastapi import FastAPI, Request
+from fastapi.templating import Jinja2Templates
+from pydantic import BaseModel
+import torch
+import torch.nn as nn
+from transformers import AutoModel, AutoTokenizer
+import uvicorn
+from huggingface_hub import hf_hub_download
+
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+app = FastAPI()
+templates = Jinja2Templates(directory="templates")
+
+HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+HF_REPO = "ThanhDT127/pho-bert-bilstm"
+HF_FILE = "best_model_1.pth"
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+MODEL_DIR = "models"
+os.makedirs(MODEL_DIR, exist_ok=True)
+MODEL_PATH = os.path.join(MODEL_DIR, HF_FILE)
+
+
+try:
+    if not os.path.isfile(MODEL_PATH):
+        logger.info("Downloading model from Hugging Face Hub")
+        # force_filename is deprecated in recent huggingface_hub releases,
+        # so the cache path returned by hf_hub_download is used directly.
+        MODEL_PATH = hf_hub_download(
+            repo_id=HF_REPO,
+            filename=HF_FILE,
+            cache_dir=os.environ["HF_HOME"],
+            token=HF_TOKEN
+        )
+    logger.info("Loading model from %s", MODEL_PATH)
+    model_state_dict = torch.load(MODEL_PATH, map_location=device)
+    logger.info("Model loaded successfully")
+except Exception as e:
+    logger.error("Error loading model: %s", str(e))
+    raise
+
+class TextInput(BaseModel):
+    text: str
+
+class BertBiLSTMClassifier(nn.Module):
+    def __init__(self, bert_model_name, num_emotion_classes, binary_cols, lstm_hidden_size=256, dropout=0.3):
+        super().__init__()
+        self.bert = AutoModel.from_pretrained(bert_model_name)
+        self.lstm = nn.LSTM(
+            input_size=self.bert.config.hidden_size,
+            hidden_size=lstm_hidden_size,
+            num_layers=1,
+            batch_first=True,
+            bidirectional=True
+        )
+        self.dropout = nn.Dropout(dropout)
+        # One shared emotion head plus an independent binary head per aspect.
+        self.emotion_fc = nn.Linear(lstm_hidden_size * 2, num_emotion_classes)
+        self.binary_fcs = nn.ModuleDict({
+            col: nn.Linear(lstm_hidden_size * 2, 1)
+            for col in binary_cols
+        })
+
+    def forward(self, input_ids, attention_mask):
+        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+        seq_out = bert_out.last_hidden_state
+        lstm_out, _ = self.lstm(seq_out)
+        # Uses the BiLSTM output at the final timestep as the sequence summary.
+        last_hidden = lstm_out[:, -1, :]
+        dropped = self.dropout(last_hidden)
+        emo_logits = self.emotion_fc(dropped)
+        bin_logits = {
+            col: self.binary_fcs[col](dropped).squeeze(-1)
+            for col in self.binary_fcs
+        }
+        return emo_logits, bin_logits
+
+tokenizer = AutoTokenizer.from_pretrained(
+    "vinai/phobert-base",
+    use_fast=False,
+    cache_dir=os.environ["HF_HOME"]
+)
+# Aspect columns: product, price, shipping,
+# attitude and customer service, other.
+binary_cols = [
+    'sản phẩm', 'giá cả', 'vận chuyển',
+    'thái độ và dịch vụ khách hàng', 'khác'
+]
+model = BertBiLSTMClassifier(
+    bert_model_name="vinai/phobert-base",
+    num_emotion_classes=3,
+    binary_cols=binary_cols,
+    lstm_hidden_size=256
+).to(device)
+
+# Load model state dict
+model.load_state_dict(model_state_dict)
+model.eval()
+
+# Per-aspect decision thresholds applied to the sigmoid outputs.
+threshold_dict = {
+    'sản phẩm': 0.6,
+    'giá cả': 0.4,
+    'vận chuyển': 0.45,
+    'thái độ và dịch vụ khách hàng': 0.35,
+    'khác': 0.4
+}
+
+def predict(text: str):
+    logger.info("Starting prediction for text: %s", text)
+    try:
+        enc = tokenizer(
+            text, add_special_tokens=True, max_length=128,
+            padding='max_length', truncation=True, return_tensors='pt'
+        )
+        input_ids = enc['input_ids'].to(device)
+        attention_mask = enc['attention_mask'].to(device)
+        with torch.no_grad():
+            emo_logits, bin_logits = model(input_ids, attention_mask)
+        emo_pred = torch.argmax(emo_logits, dim=1).item()
+        bin_pred = {
+            col: (torch.sigmoid(bin_logits[col]) > threshold_dict[col]).float().item()
+            for col in binary_cols
+        }
+        # Emotion labels: negative, neutral, positive.
+        emo_label = ['tiêu cực', 'trung tính', 'tích cực'][emo_pred]
+        # 'có' / 'không' = yes / no per aspect.
+        bin_labels = {col: ('có' if bin_pred[col] == 1 else 'không') for col in binary_cols}
+        logger.info("Prediction completed: emotion=%s, binary=%s", emo_label, bin_labels)
+        return emo_label, bin_labels
+    except Exception as e:
+        logger.error("Error during prediction: %s", str(e))
+        raise
+
+@app.get("/")
+async def read_root(request: Request):
+    logger.info("Received GET request for /")
+    try:
+        response = templates.TemplateResponse("index.html", {"request": request})
+        logger.info("Successfully rendered index.html")
+        return response
+    except Exception as e:
+        logger.error("Error rendering index.html: %s", str(e))
+        raise
+
+@app.post("/predict")
+async def predict_text(input: TextInput):
+    logger.info("Received POST request for /predict with input: %s", input.text)
+    try:
+        emotion, binary = predict(input.text)
+        logger.info("Sending prediction response: emotion=%s, binary=%s", emotion, binary)
+        return {"emotion": emotion, "binary": binary}
+    except Exception as e:
+        logger.error("Error in predict_text endpoint: %s", str(e))
+        raise
+
+if __name__ == "__main__":
+    port = int(os.getenv("PORT", 8000))
+    logger.info("Starting Uvicorn server on port %d", port)
+    uvicorn.run("main:app", host="0.0.0.0", port=port)
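Once the app is running, the /predict endpoint can be exercised with a small standard-library client. A sketch, assuming the server is reachable on localhost:8000 (the default in the __main__ block above) and using an arbitrary Vietnamese review as input:

import json
import urllib.request

# Example review: "Good product, fast shipping" (the input text is arbitrary).
payload = json.dumps({"text": "Sản phẩm tốt, giao hàng nhanh"}).encode("utf-8")
req = urllib.request.Request(
    "http://localhost:8000/predict",
    data=payload,
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))  # {"emotion": ..., "binary": {"sản phẩm": ..., ...}}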
models/best_model_1.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0cbff8d51a18ef563326782917ec30e9c143343914a632f7f5db0695bddf2fa8
+size 543768578
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+fastapi
+uvicorn[standard]
+torch
+transformers
+jinja2
+pydantic
+huggingface_hub
templates/index.html
ADDED
@@ -0,0 +1,60 @@
+<!DOCTYPE html>
+<html lang="vi">
+<head>
+    <meta charset="UTF-8">
+    <title>Phân loại văn bản</title> <!-- "Text classification" -->
+    <script src="https://cdn.tailwindcss.com"></script>
+</head>
+<body class="bg-gray-100 min-h-screen flex flex-col items-center justify-center p-6">
+    <div class="bg-white shadow-xl rounded-lg p-8 w-full max-w-xl">
+        <!-- "Emotion & aspect text classification" -->
+        <h1 class="text-2xl font-bold text-center text-blue-600 mb-6">Phân loại văn bản cảm xúc & khía cạnh</h1>
+
+        <label for="textInput" class="block mb-2 text-sm font-medium text-gray-700">Nhập văn bản:</label> <!-- "Enter text:" -->
+        <textarea id="textInput" rows="4" placeholder="Nhập văn bản cần phân loại..."
+            class="w-full p-3 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-400 resize-none"></textarea>
+
+        <button onclick="classifyText()"
+            class="mt-4 w-full bg-blue-500 hover:bg-blue-600 text-white font-semibold py-2 px-4 rounded-lg transition duration-200">
+            Phân loại <!-- "Classify" -->
+        </button>
+
+        <div id="result" class="mt-6 text-sm text-gray-800"></div>
+    </div>
+
+    <script>
+        async function classifyText() {
+            const text = document.getElementById("textInput").value;
+            if (!text) {
+                alert("Vui lòng nhập văn bản!"); // "Please enter some text!"
+                return;
+            }
+
+            try {
+                const response = await fetch("/predict", {
+                    method: "POST",
+                    headers: { "Content-Type": "application/json" },
+                    body: JSON.stringify({ text })
+                });
+
+                if (!response.ok) {
+                    throw new Error(`HTTP error! status: ${response.status}`);
+                }
+
+                const data = await response.json();
+                const resultDiv = document.getElementById("result");
+                // The keys below must match the Vietnamese aspect names returned by main.py.
+                resultDiv.innerHTML = `
+                    <p><strong>Cảm xúc:</strong> <span class="text-blue-600">${data.emotion}</span></p>
+                    <p><strong>Sản phẩm:</strong> ${data.binary["sản phẩm"]}</p>
+                    <p><strong>Giá cả:</strong> ${data.binary["giá cả"]}</p>
+                    <p><strong>Vận chuyển:</strong> ${data.binary["vận chuyển"]}</p>
+                    <p><strong>Thái độ và dịch vụ khách hàng:</strong> ${data.binary["thái độ và dịch vụ khách hàng"]}</p>
+                    <p><strong>Khác:</strong> ${data.binary["khác"]}</p>
+                `;
+            } catch (error) {
+                console.error("Error:", error);
+                alert("Có lỗi xảy ra khi phân loại văn bản!"); // "An error occurred while classifying the text!"
+            }
+        }
+    </script>
+</body>
+</html>