Spaces:
Running
Running
ping98k
commited on
Commit
Β·
65b6e6f
1
Parent(s):
3f8e7a3
Refactor balanced K-Means implementation; update beta parameter default value, handle empty embeddings, and ensure proper centroid updates. Enhance UI to include beta input and adjust event handler to read beta value for clustering.
Browse files- clustering.js +15 -6
- index.html +7 -5
- main.js +3 -2
clustering.js
CHANGED
@@ -1,16 +1,20 @@
|
|
1 |
import { UMAP } from "https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm";
|
2 |
|
3 |
-
function balancedKMeans(emb, k, beta =
|
|
|
|
|
4 |
const n = emb.length, d = emb[0].length;
|
|
|
|
|
5 |
let cent = kmeansPlusPlusInit(emb, k);
|
6 |
-
const lab = new Uint32Array(n);
|
7 |
const cnt = new Uint32Array(k);
|
8 |
|
9 |
for (let iter = 0; iter < maxIter; ++iter) {
|
10 |
let moved = false;
|
|
|
11 |
|
12 |
// ββ assignment with size penalty ββ
|
13 |
-
cnt.fill(0);
|
14 |
for (let i = 0; i < n; ++i) {
|
15 |
let best = 0, bestCost = Infinity;
|
16 |
for (let c = 0; c < k; ++c) {
|
@@ -24,18 +28,23 @@ function balancedKMeans(emb, k, beta = 1, maxIter = 100) {
|
|
24 |
if (cost < bestCost) { bestCost = cost; best = c; }
|
25 |
}
|
26 |
if (lab[i] !== best) { lab[i] = best; moved = true; }
|
27 |
-
cnt[
|
28 |
}
|
29 |
|
30 |
// ββ update centroids ββ
|
31 |
cent = Array.from({ length: k }, () => new Array(d).fill(0));
|
32 |
for (let i = 0; i < n; ++i)
|
33 |
for (let j = 0; j < d; ++j) cent[lab[i]][j] += emb[i][j];
|
|
|
34 |
for (let c = 0; c < k; ++c)
|
35 |
-
if (cnt[c])
|
|
|
|
|
|
|
36 |
|
37 |
-
if (!moved) break;
|
38 |
}
|
|
|
39 |
return { labels: Array.from(lab), centroids: cent };
|
40 |
}
|
41 |
|
|
|
1 |
import { UMAP } from "https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm";
|
2 |
|
3 |
+
function balancedKMeans(emb, k, beta = 0.01, maxIter = 100) {
|
4 |
+
if (!emb.length) return { labels: [], centroids: [] };
|
5 |
+
|
6 |
const n = emb.length, d = emb[0].length;
|
7 |
+
k = Math.max(2, Math.min(k, n)); // guard k β€ n
|
8 |
+
|
9 |
let cent = kmeansPlusPlusInit(emb, k);
|
10 |
+
const lab = new Uint32Array(n).fill(k); // start βunassignedβ
|
11 |
const cnt = new Uint32Array(k);
|
12 |
|
13 |
for (let iter = 0; iter < maxIter; ++iter) {
|
14 |
let moved = false;
|
15 |
+
cnt.fill(0);
|
16 |
|
17 |
// ββ assignment with size penalty ββ
|
|
|
18 |
for (let i = 0; i < n; ++i) {
|
19 |
let best = 0, bestCost = Infinity;
|
20 |
for (let c = 0; c < k; ++c) {
|
|
|
28 |
if (cost < bestCost) { bestCost = cost; best = c; }
|
29 |
}
|
30 |
if (lab[i] !== best) { lab[i] = best; moved = true; }
|
31 |
+
cnt[best]++;
|
32 |
}
|
33 |
|
34 |
// ββ update centroids ββ
|
35 |
cent = Array.from({ length: k }, () => new Array(d).fill(0));
|
36 |
for (let i = 0; i < n; ++i)
|
37 |
for (let j = 0; j < d; ++j) cent[lab[i]][j] += emb[i][j];
|
38 |
+
|
39 |
for (let c = 0; c < k; ++c)
|
40 |
+
if (cnt[c]) {
|
41 |
+
const inv = 1 / cnt[c];
|
42 |
+
for (let j = 0; j < d; ++j) cent[c][j] *= inv;
|
43 |
+
}
|
44 |
|
45 |
+
if (!moved) break; // converged
|
46 |
}
|
47 |
+
|
48 |
return { labels: Array.from(lab), centroids: cent };
|
49 |
}
|
50 |
|
index.html
CHANGED
@@ -53,17 +53,19 @@
|
|
53 |
<h1>Embedding Similarity Heatmap</h1>
|
54 |
<textarea id="input"></textarea>
|
55 |
<script type="module">
|
56 |
-
|
57 |
-
|
58 |
</script>
|
59 |
<label for="kmeans-k" style="margin-left:10px;">Clusters:</label>
|
60 |
<input id="kmeans-k" type="number" min="2" max="100" value="7" style="width:60px;">
|
|
|
|
|
61 |
<label for="kmeans-type" style="margin-left:10px;">Clustering Type:</label>
|
62 |
<select id="kmeans-type" style="width:180px;">
|
63 |
-
|
64 |
-
|
65 |
</select>
|
66 |
-
<button id="kmeans-btn">
|
67 |
<button id="heatmap-btn">Similarity Heatmap</button>
|
68 |
<div id="progress-bar">
|
69 |
<div id="progress-bar-inner"></div>
|
|
|
53 |
<h1>Embedding Similarity Heatmap</h1>
|
54 |
<textarea id="input"></textarea>
|
55 |
<script type="module">
|
56 |
+
import { sentences } from './sentences.js';
|
57 |
+
document.getElementById("input").value = sentences.join("\n");
|
58 |
</script>
|
59 |
<label for="kmeans-k" style="margin-left:10px;">Clusters:</label>
|
60 |
<input id="kmeans-k" type="number" min="2" max="100" value="7" style="width:60px;">
|
61 |
+
<label for="kmeans-beta" style="margin-left:10px;">Beta (balance):</label>
|
62 |
+
<input id="kmeans-beta" type="number" min="0" max="10" step="0.0001" value="0.01" style="width:80px;">
|
63 |
<label for="kmeans-type" style="margin-left:10px;">Clustering Type:</label>
|
64 |
<select id="kmeans-type" style="width:180px;">
|
65 |
+
<option value="balancedKMeans">Balanced K-Means</option>
|
66 |
+
<option value="kmeans">K-Means (standard)</option>
|
67 |
</select>
|
68 |
+
<button id="kmeans-btn">Clustering</button>
|
69 |
<button id="heatmap-btn">Similarity Heatmap</button>
|
70 |
<div id="progress-bar">
|
71 |
<div id="progress-bar-inner"></div>
|
main.js
CHANGED
@@ -57,11 +57,12 @@ document.getElementById("kmeans-btn").onclick = async () => {
|
|
57 |
if (n < 2) return;
|
58 |
const requestedK = parseInt(document.getElementById("kmeans-k").value) || 3;
|
59 |
const k = Math.max(2, Math.min(requestedK, n));
|
60 |
-
// Read clustering type
|
61 |
const clusteringType = document.getElementById("kmeans-type").value;
|
|
|
62 |
let labels;
|
63 |
if (clusteringType === "balancedKMeans") {
|
64 |
-
labels = balancedKMeans(embeddings, k).labels;
|
65 |
} else {
|
66 |
labels = kmeans(embeddings, k).labels;
|
67 |
}
|
|
|
57 |
if (n < 2) return;
|
58 |
const requestedK = parseInt(document.getElementById("kmeans-k").value) || 3;
|
59 |
const k = Math.max(2, Math.min(requestedK, n));
|
60 |
+
// Read clustering type and beta
|
61 |
const clusteringType = document.getElementById("kmeans-type").value;
|
62 |
+
const beta = parseFloat(document.getElementById("kmeans-beta").value) || 0.01;
|
63 |
let labels;
|
64 |
if (clusteringType === "balancedKMeans") {
|
65 |
+
labels = balancedKMeans(embeddings, k, beta).labels;
|
66 |
} else {
|
67 |
labels = kmeans(embeddings, k).labels;
|
68 |
}
|