ThanhDT127 committed
Commit eb16c9f · 1 Parent(s): c7afc09

update model

.gitattributes copy ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,22 @@
+ FROM python:3.10-slim
+
+ ENV HF_HOME=/app/hf_cache
+
+
+ RUN apt-get update && \
+     apt-get install -y git curl && \
+     rm -rf /var/lib/apt/lists/*
+
+
+ WORKDIR /app
+ RUN mkdir -p /app/hf_cache /app/models \
+     && chmod -R 777 /app/hf_cache
+
+ COPY requirements.txt /app/
+ RUN pip install --no-cache-dir -r /app/requirements.txt
+
+
+ COPY . /app
+
+
+ CMD ["sh", "-c", "uvicorn main:app --host 0.0.0.0 --port=${PORT:-7860}"]
README copy.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ title: DATN
+ emoji: 👁
+ colorFrom: green
+ colorTo: yellow
+ sdk: docker
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/main.cpython-311.pyc ADDED
Binary file (8.42 kB)
init.py ADDED
@@ -0,0 +1,7 @@
+ from transformers import AutoTokenizer, AutoModel
+
+ model = AutoModel.from_pretrained("vinai/phobert-base")
+ tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base", use_fast=False)
+
+ model.save_pretrained("D:/ktlt/python/thuchanh/DATN/DATN/tokenizer/phobert-base-fixed")
+ tokenizer.save_pretrained("D:/ktlt/python/thuchanh/DATN/DATN/tokenizer/phobert-base-fixed")
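
The save paths here are Windows-specific absolute paths, so this one-off helper only runs on the machine it was written on and is never executed inside the container. A portable sketch, assuming a repo-relative output directory (the tokenizer/phobert-base-fixed name is kept from the script, but the relative location is illustrative):

from pathlib import Path
from transformers import AutoTokenizer, AutoModel

# Save next to this script instead of into a hard-coded absolute path.
out_dir = Path(__file__).resolve().parent / "tokenizer" / "phobert-base-fixed"
out_dir.mkdir(parents=True, exist_ok=True)

model = AutoModel.from_pretrained("vinai/phobert-base")
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base", use_fast=False)
model.save_pretrained(str(out_dir))
tokenizer.save_pretrained(str(out_dir))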
main.py ADDED
@@ -0,0 +1,160 @@
+ import os
+
+ os.environ.setdefault("HF_HOME", "/app/hf_cache")
+
+ import logging
+ from fastapi import FastAPI, Request
+ from fastapi.templating import Jinja2Templates
+ from pydantic import BaseModel
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModel, AutoTokenizer
+ import uvicorn
+ from huggingface_hub import hf_hub_download
+
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI()
+ templates = Jinja2Templates(directory="templates")
+
+ HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+ HF_REPO = "ThanhDT127/pho-bert-bilstm"
+ HF_FILE = "best_model_1.pth"
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ MODEL_DIR = "models"
+ os.makedirs(MODEL_DIR, exist_ok=True)
+ MODEL_PATH = os.path.join(MODEL_DIR, HF_FILE)
+
+
+ try:
+     if not os.path.isfile(MODEL_PATH):
+         logger.info("Downloading model from Hugging Face Hub")
+         MODEL_PATH = hf_hub_download(
+             repo_id=HF_REPO,
+             filename=HF_FILE,
+             cache_dir=os.environ["HF_HOME"],
+             force_filename=HF_FILE,
+             token=HF_TOKEN
+         )
+     logger.info("Loading model from %s", MODEL_PATH)
+     model_state_dict = torch.load(MODEL_PATH, map_location=device)
+     logger.info("Model loaded successfully")
+ except Exception as e:
+     logger.error("Error loading model: %s", str(e))
+     raise
+
+ class TextInput(BaseModel):
+     text: str
+
+ class BertBiLSTMClassifier(nn.Module):
+     def __init__(self, bert_model_name, num_emotion_classes, binary_cols, lstm_hidden_size=256, dropout=0.3):
+         super().__init__()
+         self.bert = AutoModel.from_pretrained(bert_model_name)
+         self.lstm = nn.LSTM(
+             input_size=self.bert.config.hidden_size,
+             hidden_size=lstm_hidden_size,
+             num_layers=1,
+             batch_first=True,
+             bidirectional=True
+         )
+         self.dropout = nn.Dropout(dropout)
+         self.emotion_fc = nn.Linear(lstm_hidden_size * 2, num_emotion_classes)
+         self.binary_fcs = nn.ModuleDict({
+             col: nn.Linear(lstm_hidden_size * 2, 1)
+             for col in binary_cols
+         })
+
+     def forward(self, input_ids, attention_mask):
+         bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+         seq_out = bert_out.last_hidden_state
+         lstm_out, _ = self.lstm(seq_out)
+         last_hidden = lstm_out[:, -1, :]
+         dropped = self.dropout(last_hidden)
+         emo_logits = self.emotion_fc(dropped)
+         bin_logits = {
+             col: self.binary_fcs[col](dropped).squeeze(-1)
+             for col in self.binary_fcs
+         }
+         return emo_logits, bin_logits
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     "vinai/phobert-base",
+     use_fast=False,
+     cache_dir=os.environ["HF_HOME"]
+ )
+ binary_cols = [
+     'sản phẩm', 'giá cả', 'vận chuyển',
+     'thái độ và dịch vụ khách hàng', 'khác'
+ ]
+ model = BertBiLSTMClassifier(
+     bert_model_name="vinai/phobert-base",
+     num_emotion_classes=3,
+     binary_cols=binary_cols,
+     lstm_hidden_size=256
+ ).to(device)
+
+ # Load model state dict
+ model.load_state_dict(model_state_dict)
+ model.eval()
+
+ threshold_dict = {
+     'sản phẩm': 0.6,
+     'giá cả': 0.4,
+     'vận chuyển': 0.45,
+     'thái độ và dịch vụ khách hàng': 0.35,
+     'khác': 0.4
+ }
+
+ def predict(text: str):
+     logger.info("Starting prediction for text: %s", text)
+     try:
+         enc = tokenizer(
+             text, add_special_tokens=True, max_length=128,
+             padding='max_length', truncation=True, return_tensors='pt'
+         )
+         input_ids = enc['input_ids'].to(device)
+         attention_mask = enc['attention_mask'].to(device)
+         with torch.no_grad():
+             emo_logits, bin_logits = model(input_ids, attention_mask)
+         emo_pred = torch.argmax(emo_logits, dim=1).item()
+         bin_pred = {
+             col: (torch.sigmoid(bin_logits[col]) > threshold_dict[col]).float().item()
+             for col in binary_cols
+         }
+         emo_label = ['tiêu cực', 'trung tính', 'tích cực'][emo_pred]
+         bin_labels = {col: ('có' if bin_pred[col] == 1 else 'không') for col in binary_cols}
+         logger.info("Prediction completed: emotion=%s, binary=%s", emo_label, bin_labels)
+         return emo_label, bin_labels
+     except Exception as e:
+         logger.error("Error during prediction: %s", str(e))
+         raise
+
+ @app.get("/")
+ async def read_root(request: Request):
+     logger.info("Received GET request for /")
+     try:
+         response = templates.TemplateResponse("index.html", {"request": request})
+         logger.info("Successfully rendered index.html")
+         return response
+     except Exception as e:
+         logger.error("Error rendering index.html: %s", str(e))
+         raise
+
+ @app.post("/predict")
+ async def predict_text(input: TextInput):
+     logger.info("Received POST request for /predict with input: %s", input.text)
+     try:
+         emotion, binary = predict(input.text)
+         logger.info("Sending prediction response: emotion=%s, binary=%s", emotion, binary)
+         return {"emotion": emotion, "binary": binary}
+     except Exception as e:
+         logger.error("Error in predict_text endpoint: %s", str(e))
+         raise
+
+ if __name__ == "__main__":
+     port = int(os.getenv("PORT", 8000))
+     logger.info("Starting Uvicorn server on port %d", port)
+     uvicorn.run("main:app", host="0.0.0.0", port=port)
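
The /predict route accepts a JSON body matching TextInput and returns the emotion label plus the per-aspect có/không flags. A minimal client using only the standard library, assuming the app is running locally on port 8000 (the default in the __main__ block; the sample review text is illustrative):

import json
import urllib.request

payload = json.dumps({"text": "Giao hàng nhanh, sản phẩm tốt"}).encode("utf-8")
req = urllib.request.Request(
    "http://localhost:8000/predict",
    data=payload,
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    # e.g. {"emotion": "tích cực", "binary": {"sản phẩm": "có", ...}}
    print(json.load(resp))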
models/best_model_1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0cbff8d51a18ef563326782917ec30e9c143343914a632f7f5db0695bddf2fa8
+ size 543768578
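
The checkpoint itself is stored via Git LFS; only this pointer lives in the regular history, and main.py fetches the real ~544 MB file from ThanhDT127/pho-bert-bilstm at startup. A quick sanity check of a downloaded copy, assuming it sits at models/best_model_1.pth as in main.py:

import torch

# Load the raw state dict on CPU and list a few parameter names to confirm
# it matches BertBiLSTMClassifier (bert.*, lstm.*, emotion_fc.*, binary_fcs.*).
state = torch.load("models/best_model_1.pth", map_location="cpu")
print(len(state), "entries")
print(sorted(state)[:5])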
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ fastapi
+ uvicorn[standard]
+ torch
+ transformers
+ jinja2
+ pydantic
+ huggingface_hub
templates/index.html ADDED
@@ -0,0 +1,60 @@
+ <!DOCTYPE html>
+ <html lang="vi">
+ <head>
+     <meta charset="UTF-8">
+     <title>Phân loại văn bản</title>
+     <script src="https://cdn.tailwindcss.com"></script>
+ </head>
+ <body class="bg-gray-100 min-h-screen flex flex-col items-center justify-center p-6">
+     <div class="bg-white shadow-xl rounded-lg p-8 w-full max-w-xl">
+         <h1 class="text-2xl font-bold text-center text-blue-600 mb-6">Phân loại văn bản cảm xúc & khía cạnh</h1>
+
+         <label for="textInput" class="block mb-2 text-sm font-medium text-gray-700">Nhập văn bản:</label>
+         <textarea id="textInput" rows="4" placeholder="Nhập văn bản cần phân loại..."
+             class="w-full p-3 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-400 resize-none"></textarea>
+
+         <button onclick="classifyText()"
+             class="mt-4 w-full bg-blue-500 hover:bg-blue-600 text-white font-semibold py-2 px-4 rounded-lg transition duration-200">
+             Phân loại
+         </button>
+
+         <div id="result" class="mt-6 text-sm text-gray-800"></div>
+     </div>
+
+     <script>
+         async function classifyText() {
+             const text = document.getElementById("textInput").value;
+             if (!text) {
+                 alert("Vui lòng nhập văn bản!");
+                 return;
+             }
+
+             try {
+                 const response = await fetch("/predict", {
+                     method: "POST",
+                     headers: { "Content-Type": "application/json" },
+                     body: JSON.stringify({ text })
+                 });
+
+                 if (!response.ok) {
+                     throw new Error(`HTTP error! status: ${response.status}`);
+                 }
+
+                 const data = await response.json();
+                 const resultDiv = document.getElementById("result");
+                 resultDiv.innerHTML = `
+                     <p><strong>Cảm xúc:</strong> <span class="text-blue-600">${data.emotion}</span></p>
+                     <p><strong>Sản phẩm:</strong> ${data.binary["sản phẩm"]}</p>
+                     <p><strong>Giá cả:</strong> ${data.binary["giá cả"]}</p>
+                     <p><strong>Vận chuyển:</strong> ${data.binary["vận chuyển"]}</p>
+                     <p><strong>Thái độ và dịch vụ khách hàng:</strong> ${data.binary["thái độ và dịch vụ khách hàng"]}</p>
+                     <p><strong>Khác:</strong> ${data.binary["khác"]}</p>
+                 `;
+             } catch (error) {
+                 console.error("Error:", error);
+                 alert("Có lỗi xảy ra khi phân loại văn bản!");
+             }
+         }
+     </script>
+ </body>
+ </html>