Spaces:
Running
Running
ping98k
commited on
Commit
Β·
3f8e7a3
1
Parent(s):
935873d
Implement balanced K-Means clustering; add new clustering type option in UI and update event handler to support balanced K-Means functionality.
Browse files- clustering.js +42 -0
- index.html +5 -0
- main.js +9 -3
clustering.js
CHANGED
@@ -1,5 +1,45 @@
|
|
1 |
import { UMAP } from "https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm";
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
function kmeansPlusPlusInit(embeddings, k) {
|
4 |
const n = embeddings.length;
|
5 |
const dim = embeddings[0].length;
|
@@ -102,3 +142,5 @@ export function runUMAP(embeddings, nNeighbors = 15) {
|
|
102 |
});
|
103 |
return umap.fit(embeddings);
|
104 |
}
|
|
|
|
|
|
1 |
import { UMAP } from "https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm";
|
2 |
|
3 |
+
function balancedKMeans(emb, k, beta = 1, maxIter = 100) {
|
4 |
+
const n = emb.length, d = emb[0].length;
|
5 |
+
let cent = kmeansPlusPlusInit(emb, k);
|
6 |
+
const lab = new Uint32Array(n);
|
7 |
+
const cnt = new Uint32Array(k);
|
8 |
+
|
9 |
+
for (let iter = 0; iter < maxIter; ++iter) {
|
10 |
+
let moved = false;
|
11 |
+
|
12 |
+
// ββ assignment with size penalty ββ
|
13 |
+
cnt.fill(0);
|
14 |
+
for (let i = 0; i < n; ++i) {
|
15 |
+
let best = 0, bestCost = Infinity;
|
16 |
+
for (let c = 0; c < k; ++c) {
|
17 |
+
let dist = 0;
|
18 |
+
for (let j = 0; j < d; ++j) {
|
19 |
+
const diff = emb[i][j] - cent[c][j];
|
20 |
+
dist += diff * diff;
|
21 |
+
}
|
22 |
+
const sizePenalty = beta * (2 * cnt[c] + 1);
|
23 |
+
const cost = dist + sizePenalty;
|
24 |
+
if (cost < bestCost) { bestCost = cost; best = c; }
|
25 |
+
}
|
26 |
+
if (lab[i] !== best) { lab[i] = best; moved = true; }
|
27 |
+
cnt[lab[i]]++;
|
28 |
+
}
|
29 |
+
|
30 |
+
// ββ update centroids ββ
|
31 |
+
cent = Array.from({ length: k }, () => new Array(d).fill(0));
|
32 |
+
for (let i = 0; i < n; ++i)
|
33 |
+
for (let j = 0; j < d; ++j) cent[lab[i]][j] += emb[i][j];
|
34 |
+
for (let c = 0; c < k; ++c)
|
35 |
+
if (cnt[c]) for (let j = 0; j < d; ++j) cent[c][j] /= cnt[c];
|
36 |
+
|
37 |
+
if (!moved) break;
|
38 |
+
}
|
39 |
+
return { labels: Array.from(lab), centroids: cent };
|
40 |
+
}
|
41 |
+
|
42 |
+
|
43 |
function kmeansPlusPlusInit(embeddings, k) {
|
44 |
const n = embeddings.length;
|
45 |
const dim = embeddings[0].length;
|
|
|
142 |
});
|
143 |
return umap.fit(embeddings);
|
144 |
}
|
145 |
+
|
146 |
+
export { balancedKMeans };
|
index.html
CHANGED
@@ -58,6 +58,11 @@
|
|
58 |
</script>
|
59 |
<label for="kmeans-k" style="margin-left:10px;">Clusters:</label>
|
60 |
<input id="kmeans-k" type="number" min="2" max="100" value="7" style="width:60px;">
|
|
|
|
|
|
|
|
|
|
|
61 |
<button id="kmeans-btn">K-Means Clustering</button>
|
62 |
<button id="heatmap-btn">Similarity Heatmap</button>
|
63 |
<div id="progress-bar">
|
|
|
58 |
</script>
|
59 |
<label for="kmeans-k" style="margin-left:10px;">Clusters:</label>
|
60 |
<input id="kmeans-k" type="number" min="2" max="100" value="7" style="width:60px;">
|
61 |
+
<label for="kmeans-type" style="margin-left:10px;">Clustering Type:</label>
|
62 |
+
<select id="kmeans-type" style="width:180px;">
|
63 |
+
<option value="kmeans">K-Means (standard)</option>
|
64 |
+
<option value="balancedKMeans">Balanced K-Means</option>
|
65 |
+
</select>
|
66 |
<button id="kmeans-btn">K-Means Clustering</button>
|
67 |
<button id="heatmap-btn">Similarity Heatmap</button>
|
68 |
<div id="progress-bar">
|
main.js
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import { getGroupEmbeddings, getLineEmbeddings } from './embedding.js';
|
2 |
-
import { kmeans } from './clustering.js';
|
3 |
import { plotHeatmap, plotScatter, updateScatter } from './plotting.js';
|
4 |
import { nameCluster } from './cluster_naming.js';
|
5 |
import { prompt_cluster } from './prompt_cluster.js';
|
@@ -57,8 +57,14 @@ document.getElementById("kmeans-btn").onclick = async () => {
|
|
57 |
if (n < 2) return;
|
58 |
const requestedK = parseInt(document.getElementById("kmeans-k").value) || 3;
|
59 |
const k = Math.max(2, Math.min(requestedK, n));
|
60 |
-
//
|
61 |
-
const
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
// UMAP projection
|
63 |
const { UMAP } = await import('https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm');
|
64 |
const nNeighbors = Math.max(1, Math.min(lines.length - 1, 15));
|
|
|
1 |
import { getGroupEmbeddings, getLineEmbeddings } from './embedding.js';
|
2 |
+
import { kmeans, balancedKMeans } from './clustering.js';
|
3 |
import { plotHeatmap, plotScatter, updateScatter } from './plotting.js';
|
4 |
import { nameCluster } from './cluster_naming.js';
|
5 |
import { prompt_cluster } from './prompt_cluster.js';
|
|
|
57 |
if (n < 2) return;
|
58 |
const requestedK = parseInt(document.getElementById("kmeans-k").value) || 3;
|
59 |
const k = Math.max(2, Math.min(requestedK, n));
|
60 |
+
// Read clustering type
|
61 |
+
const clusteringType = document.getElementById("kmeans-type").value;
|
62 |
+
let labels;
|
63 |
+
if (clusteringType === "balancedKMeans") {
|
64 |
+
labels = balancedKMeans(embeddings, k).labels;
|
65 |
+
} else {
|
66 |
+
labels = kmeans(embeddings, k).labels;
|
67 |
+
}
|
68 |
// UMAP projection
|
69 |
const { UMAP } = await import('https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm');
|
70 |
const nNeighbors = Math.max(1, Math.min(lines.length - 1, 15));
|