Spaces:
Running
Running
ping98k
committed on
Commit
·
46bbd3d
1
Parent(s):
aa85324
Refactor K-Means clustering code for improved readability and performance; streamline cluster name generation and enhance progress bar updates
Browse files
main.js
CHANGED
@@ -18,9 +18,9 @@ document.getElementById("run").onclick = async () => {
|
|
18 |
|
19 |
// Extract cluster names from lines starting with ##
|
20 |
const clusterNames = text.split(/\n/)
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
|
25 |
|
26 |
const groupEmbeddings = [];
|
@@ -54,7 +54,7 @@ document.getElementById("run").onclick = async () => {
|
|
54 |
sim.push(row);
|
55 |
}
|
56 |
// If clusterNames exist and match group count, use as axis labels
|
57 |
-
let xLabels = clusterNames && clusterNames.length === n ? clusterNames : Array.from({length: n}, (_, i) => `Group ${i+1}`);
|
58 |
const data = [{ z: sim, type: "heatmap", colorscale: "Viridis", zmin: 0.7, zmax: 1, x: xLabels, y: xLabels }];
|
59 |
Plotly.newPlot("plot-heatmap", data, {
|
60 |
xaxis: { title: "Group", scaleanchor: "y", scaleratio: 1 },
|
@@ -66,58 +66,98 @@ document.getElementById("run").onclick = async () => {
|
|
66 |
});
|
67 |
};
|
68 |
|
69 |
-
// --- K-Means Clustering ---
|
70 |
document.getElementById("kmeans-btn").onclick = async () => {
|
71 |
const progressBar = document.getElementById("progress-bar");
|
72 |
const progressBarInner = document.getElementById("progress-bar-inner");
|
73 |
progressBar.style.display = "block";
|
74 |
-
progressBarInner.style.width = "0%";
|
75 |
|
76 |
const text = document.getElementById("input").value;
|
77 |
-
|
78 |
-
const lines = text.split(/\n/)
|
79 |
-
.map(x => x.trim())
|
80 |
-
.filter(x => x && !x.startsWith('##'));
|
81 |
const prompts = lines.map(s => `Instruct: ${task}\nQuery:${s}`);
|
82 |
const out = await embed(prompts, { pooling: "mean", normalize: true });
|
83 |
-
const embeddings = typeof out.tolist ===
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
-
// K-Means implementation
|
86 |
-
const k = Math.max(2, Math.min(20, parseInt(document.getElementById("kmeans-k").value) || 3));
|
87 |
-
const n = embeddings.length, dim = embeddings[0].length;
|
88 |
let centroids = Array.from({ length: k }, () => embeddings[Math.floor(Math.random() * n)].slice());
|
89 |
let labels = new Array(n).fill(0);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
for (let iter = 0; iter < 100; ++iter) {
|
|
|
91 |
for (let i = 0; i < n; ++i) {
|
92 |
let best = 0, bestDist = Infinity;
|
93 |
for (let c = 0; c < k; ++c) {
|
94 |
let dist = 0;
|
95 |
-
for (let d = 0; d < dim; ++d)
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
}
|
98 |
-
labels[i] = best;
|
99 |
}
|
|
|
100 |
centroids = Array.from({ length: k }, () => new Array(dim).fill(0));
|
101 |
const counts = new Array(k).fill(0);
|
102 |
for (let i = 0; i < n; ++i) {
|
103 |
counts[labels[i]]++;
|
104 |
-
for (let d = 0; d < dim; ++d)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
}
|
106 |
-
|
107 |
}
|
108 |
-
|
109 |
const nNeighbors = Math.max(1, Math.min(lines.length - 1, 15));
|
110 |
const umap = new UMAP({ nComponents: 2, nNeighbors, minDist: 0.1 });
|
111 |
const proj = umap.fit(embeddings);
|
112 |
-
|
113 |
-
const clustered = Array.from({ length: k }, (
|
114 |
-
for (let i = 0; i < n; ++i)
|
115 |
-
|
|
|
116 |
const colors = ["red", "blue", "green", "orange", "purple", "cyan", "magenta", "yellow", "brown", "black", "lime", "navy", "teal", "olive", "maroon", "pink", "gray", "gold", "aqua", "indigo"];
|
117 |
const placeholderNames = Array.from({ length: k }, (_, c) => `Cluster ${c + 1}`);
|
118 |
-
|
119 |
-
x: [], y: [], text: [],
|
120 |
-
|
|
|
|
|
121 |
}));
|
122 |
for (let i = 0; i < n; ++i) {
|
123 |
traces[labels[i]].x.push(proj[i][0]);
|
@@ -133,15 +173,14 @@ document.getElementById("kmeans-btn").onclick = async () => {
|
|
133 |
title: `K-Means Clustering (k=${k})`,
|
134 |
legend: { x: 1.05, y: 0.5, orientation: "v", xanchor: "left", yanchor: "middle" }
|
135 |
});
|
136 |
-
|
137 |
const clusterNames = [];
|
138 |
for (let c = 0; c < k; ++c) {
|
139 |
-
progressBarInner.style.width = `${Math.round(((c) / k) * 100)}%`;
|
|
|
140 |
const joined = clustered[c].join("\n");
|
141 |
const messages = [
|
142 |
-
{
|
143 |
-
role: "system", content: prompt_cluster
|
144 |
-
},
|
145 |
{ role: "user", content: `Input:\n${joined}\nOutput:` }
|
146 |
];
|
147 |
|
@@ -150,64 +189,29 @@ document.getElementById("kmeans-btn").onclick = async () => {
|
|
150 |
return_dict: true,
|
151 |
enable_thinking: false,
|
152 |
});
|
153 |
-
|
154 |
-
let state = "answering";
|
155 |
-
let startTime;
|
156 |
-
let numTokens = 0;
|
157 |
-
let tps;
|
158 |
-
const token_callback_function = (tokens) => {
|
159 |
-
startTime ??= performance.now();
|
160 |
-
if (numTokens++ > 0) {
|
161 |
-
tps = (numTokens / (performance.now() - startTime)) * 1000;
|
162 |
-
}
|
163 |
-
switch (Number(tokens[0])) {
|
164 |
-
case START_THINKING_TOKEN_ID:
|
165 |
-
state = "thinking";
|
166 |
-
break;
|
167 |
-
case END_THINKING_TOKEN_ID:
|
168 |
-
state = "answering";
|
169 |
-
break;
|
170 |
-
}
|
171 |
-
console.log(state, tokens, tokenizer.decode(tokens));
|
172 |
-
};
|
173 |
-
const callback_function = (output) => {
|
174 |
-
// You can update UI here if desired
|
175 |
-
console.log({ output, tps, numTokens, state });
|
176 |
-
};
|
177 |
-
const streamer = new TextStreamer(tokenizer, {
|
178 |
-
skip_prompt: true,
|
179 |
-
skip_special_tokens: true,
|
180 |
-
callback_function,
|
181 |
-
token_callback_function,
|
182 |
-
});
|
183 |
const outputTokens = await model.generate({
|
184 |
...inputs,
|
185 |
max_new_tokens: 1024,
|
186 |
do_sample: true,
|
187 |
-
temperature: 0.6
|
188 |
-
// streamer,
|
189 |
});
|
190 |
-
let rawName = tokenizer
|
191 |
-
.decode(outputTokens[0], { skip_special_tokens: false })
|
192 |
-
.trim();
|
193 |
|
194 |
-
|
195 |
-
|
|
|
|
|
196 |
|
197 |
if (rawName.includes(THINK_TAG)) {
|
198 |
-
// take everything after the last </think>
|
199 |
rawName = rawName.substring(rawName.lastIndexOf(THINK_TAG) + THINK_TAG.length).trim();
|
200 |
}
|
201 |
if (rawName.includes(END_TAG)) {
|
202 |
-
// take everything before the first <|im_end|>
|
203 |
rawName = rawName.substring(0, rawName.indexOf(END_TAG)).trim();
|
204 |
}
|
205 |
-
|
206 |
clusterNames.push(rawName || `Cluster ${c + 1}`);
|
207 |
-
|
208 |
-
|
209 |
-
traces[c].name = clusterNames[c];
|
210 |
-
}
|
211 |
Plotly.react("plot-scatter", traces, {
|
212 |
xaxis: { title: "UMAP-1", scaleanchor: "y", scaleratio: 1 },
|
213 |
yaxis: { title: "UMAP-2", scaleanchor: "x", scaleratio: 1 },
|
@@ -217,9 +221,13 @@ document.getElementById("kmeans-btn").onclick = async () => {
|
|
217 |
title: `K-Means Clustering (k=${k})`,
|
218 |
legend: { x: 1.05, y: 0.5, orientation: "v", xanchor: "left", yanchor: "middle" }
|
219 |
});
|
220 |
-
|
221 |
-
document.getElementById("input").value = clustered.map((g, i) =>
|
222 |
-
|
|
|
|
|
223 |
document.getElementById("run").onclick();
|
224 |
}
|
|
|
|
|
225 |
};
|
|
|
18 |
|
19 |
// Extract cluster names from lines starting with ##
|
20 |
const clusterNames = text.split(/\n/)
|
21 |
+
.map(x => x.trim())
|
22 |
+
.filter(x => x && x.startsWith('##'))
|
23 |
+
.map(x => x.replace(/^##\s*/, ''));
|
24 |
|
25 |
|
26 |
const groupEmbeddings = [];
|
|
|
54 |
sim.push(row);
|
55 |
}
|
56 |
// If clusterNames exist and match group count, use as axis labels
|
57 |
+
let xLabels = clusterNames && clusterNames.length === n ? clusterNames : Array.from({ length: n }, (_, i) => `Group ${i + 1}`);
|
58 |
const data = [{ z: sim, type: "heatmap", colorscale: "Viridis", zmin: 0.7, zmax: 1, x: xLabels, y: xLabels }];
|
59 |
Plotly.newPlot("plot-heatmap", data, {
|
60 |
xaxis: { title: "Group", scaleanchor: "y", scaleratio: 1 },
|
|
|
66 |
});
|
67 |
};
|
68 |
|
|
|
69 |
document.getElementById("kmeans-btn").onclick = async () => {
|
70 |
const progressBar = document.getElementById("progress-bar");
|
71 |
const progressBarInner = document.getElementById("progress-bar-inner");
|
72 |
progressBar.style.display = "block";
|
73 |
+
progressBarInner.style.width = "0%"; // Set to 0% at the start
|
74 |
|
75 |
const text = document.getElementById("input").value;
|
76 |
+
const lines = text.split(/\n/).map(x => x.trim()).filter(x => x && !x.startsWith("##"));
|
|
|
|
|
|
|
77 |
const prompts = lines.map(s => `Instruct: ${task}\nQuery:${s}`);
|
78 |
const out = await embed(prompts, { pooling: "mean", normalize: true });
|
79 |
+
const embeddings = typeof out.tolist === "function" ? out.tolist() : out.data;
|
80 |
+
|
81 |
+
const n = embeddings.length;
|
82 |
+
if (n < 2) return;
|
83 |
+
|
84 |
+
const requestedK = parseInt(document.getElementById("kmeans-k").value) || 3;
|
85 |
+
const k = Math.max(2, Math.min(requestedK, n));
|
86 |
+
const dim = embeddings[0].length;
|
87 |
|
|
|
|
|
|
|
88 |
let centroids = Array.from({ length: k }, () => embeddings[Math.floor(Math.random() * n)].slice());
|
89 |
let labels = new Array(n).fill(0);
|
90 |
+
|
91 |
+
const reseed = () => {
|
92 |
+
let bestIdx = 0, bestDist = -1;
|
93 |
+
for (let i = 0; i < n; ++i) {
|
94 |
+
let minDist = Infinity;
|
95 |
+
for (let c = 0; c < k; ++c) {
|
96 |
+
let dist = 0;
|
97 |
+
for (let d = 0; d < dim; ++d)
|
98 |
+
dist += (embeddings[i][d] - centroids[c][d]) ** 2;
|
99 |
+
if (dist < minDist) minDist = dist;
|
100 |
+
}
|
101 |
+
if (minDist > bestDist) {
|
102 |
+
bestDist = minDist;
|
103 |
+
bestIdx = i;
|
104 |
+
}
|
105 |
+
}
|
106 |
+
return embeddings[bestIdx].slice();
|
107 |
+
};
|
108 |
+
|
109 |
for (let iter = 0; iter < 100; ++iter) {
|
110 |
+
let changed = false;
|
111 |
for (let i = 0; i < n; ++i) {
|
112 |
let best = 0, bestDist = Infinity;
|
113 |
for (let c = 0; c < k; ++c) {
|
114 |
let dist = 0;
|
115 |
+
for (let d = 0; d < dim; ++d)
|
116 |
+
dist += (embeddings[i][d] - centroids[c][d]) ** 2;
|
117 |
+
if (dist < bestDist) {
|
118 |
+
bestDist = dist;
|
119 |
+
best = c;
|
120 |
+
}
|
121 |
+
}
|
122 |
+
if (labels[i] !== best) {
|
123 |
+
labels[i] = best;
|
124 |
+
changed = true;
|
125 |
}
|
|
|
126 |
}
|
127 |
+
|
128 |
centroids = Array.from({ length: k }, () => new Array(dim).fill(0));
|
129 |
const counts = new Array(k).fill(0);
|
130 |
for (let i = 0; i < n; ++i) {
|
131 |
counts[labels[i]]++;
|
132 |
+
for (let d = 0; d < dim; ++d)
|
133 |
+
centroids[labels[i]][d] += embeddings[i][d];
|
134 |
+
}
|
135 |
+
for (let c = 0; c < k; ++c) {
|
136 |
+
if (counts[c] === 0) {
|
137 |
+
centroids[c] = reseed();
|
138 |
+
} else {
|
139 |
+
for (let d = 0; d < dim; ++d)
|
140 |
+
centroids[c][d] /= counts[c];
|
141 |
+
}
|
142 |
}
|
143 |
+
if (!changed) break;
|
144 |
}
|
145 |
+
|
146 |
const nNeighbors = Math.max(1, Math.min(lines.length - 1, 15));
|
147 |
const umap = new UMAP({ nComponents: 2, nNeighbors, minDist: 0.1 });
|
148 |
const proj = umap.fit(embeddings);
|
149 |
+
|
150 |
+
const clustered = Array.from({ length: k }, () => []);
|
151 |
+
for (let i = 0; i < n; ++i)
|
152 |
+
clustered[labels[i]].push(lines[i]);
|
153 |
+
|
154 |
const colors = ["red", "blue", "green", "orange", "purple", "cyan", "magenta", "yellow", "brown", "black", "lime", "navy", "teal", "olive", "maroon", "pink", "gray", "gold", "aqua", "indigo"];
|
155 |
const placeholderNames = Array.from({ length: k }, (_, c) => `Cluster ${c + 1}`);
|
156 |
+
const traces = Array.from({ length: k }, (_, c) => ({
|
157 |
+
x: [], y: [], text: [],
|
158 |
+
mode: "markers", type: "scatter",
|
159 |
+
name: placeholderNames[c],
|
160 |
+
marker: { color: colors[c % colors.length], size: 12, line: { width: 1, color: "#333" } }
|
161 |
}));
|
162 |
for (let i = 0; i < n; ++i) {
|
163 |
traces[labels[i]].x.push(proj[i][0]);
|
|
|
173 |
title: `K-Means Clustering (k=${k})`,
|
174 |
legend: { x: 1.05, y: 0.5, orientation: "v", xanchor: "left", yanchor: "middle" }
|
175 |
});
|
176 |
+
|
177 |
const clusterNames = [];
|
178 |
for (let c = 0; c < k; ++c) {
|
179 |
+
progressBarInner.style.width = `${Math.round(((c + 1) / k) * 100)}%`;
|
180 |
+
|
181 |
const joined = clustered[c].join("\n");
|
182 |
const messages = [
|
183 |
+
{ role: "system", content: prompt_cluster },
|
|
|
|
|
184 |
{ role: "user", content: `Input:\n${joined}\nOutput:` }
|
185 |
];
|
186 |
|
|
|
189 |
return_dict: true,
|
190 |
enable_thinking: false,
|
191 |
});
|
192 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
const outputTokens = await model.generate({
|
194 |
...inputs,
|
195 |
max_new_tokens: 1024,
|
196 |
do_sample: true,
|
197 |
+
temperature: 0.6
|
|
|
198 |
});
|
|
|
|
|
|
|
199 |
|
200 |
+
let rawName = tokenizer.decode(outputTokens[0], { skip_special_tokens: false }).trim();
|
201 |
+
|
202 |
+
const THINK_TAG = "</think>";
|
203 |
+
const END_TAG = "<|im_end|>";
|
204 |
|
205 |
if (rawName.includes(THINK_TAG)) {
|
|
|
206 |
rawName = rawName.substring(rawName.lastIndexOf(THINK_TAG) + THINK_TAG.length).trim();
|
207 |
}
|
208 |
if (rawName.includes(END_TAG)) {
|
|
|
209 |
rawName = rawName.substring(0, rawName.indexOf(END_TAG)).trim();
|
210 |
}
|
211 |
+
|
212 |
clusterNames.push(rawName || `Cluster ${c + 1}`);
|
213 |
+
traces[c].name = clusterNames[c];
|
214 |
+
|
|
|
|
|
215 |
Plotly.react("plot-scatter", traces, {
|
216 |
xaxis: { title: "UMAP-1", scaleanchor: "y", scaleratio: 1 },
|
217 |
yaxis: { title: "UMAP-2", scaleanchor: "x", scaleratio: 1 },
|
|
|
221 |
title: `K-Means Clustering (k=${k})`,
|
222 |
legend: { x: 1.05, y: 0.5, orientation: "v", xanchor: "left", yanchor: "middle" }
|
223 |
});
|
224 |
+
|
225 |
+
document.getElementById("input").value = clustered.map((g, i) =>
|
226 |
+
`## ${clusterNames[i]}\n${g.join("\n")}`
|
227 |
+
).join("\n\n\n");
|
228 |
+
|
229 |
document.getElementById("run").onclick();
|
230 |
}
|
231 |
+
|
232 |
+
progressBarInner.style.width = "100%"; // Set to 100% after all clusters are named
|
233 |
};
|