ping98k commited on
Commit
3f8e7a3
Β·
1 Parent(s): 935873d

Implement balanced K-Means clustering; add new clustering type option in UI and update event handler to support balanced K-Means functionality.

Browse files
Files changed (3) hide show
  1. clustering.js +42 -0
  2. index.html +5 -0
  3. main.js +9 -3
clustering.js CHANGED
@@ -1,5 +1,45 @@
1
  import { UMAP } from "https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm";
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  function kmeansPlusPlusInit(embeddings, k) {
4
  const n = embeddings.length;
5
  const dim = embeddings[0].length;
@@ -102,3 +142,5 @@ export function runUMAP(embeddings, nNeighbors = 15) {
102
  });
103
  return umap.fit(embeddings);
104
  }
 
 
 
1
  import { UMAP } from "https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm";
2
 
3
+ function balancedKMeans(emb, k, beta = 1, maxIter = 100) {
4
+ const n = emb.length, d = emb[0].length;
5
+ let cent = kmeansPlusPlusInit(emb, k);
6
+ const lab = new Uint32Array(n);
7
+ const cnt = new Uint32Array(k);
8
+
9
+ for (let iter = 0; iter < maxIter; ++iter) {
10
+ let moved = false;
11
+
12
+ // ── assignment with size penalty ──
13
+ cnt.fill(0);
14
+ for (let i = 0; i < n; ++i) {
15
+ let best = 0, bestCost = Infinity;
16
+ for (let c = 0; c < k; ++c) {
17
+ let dist = 0;
18
+ for (let j = 0; j < d; ++j) {
19
+ const diff = emb[i][j] - cent[c][j];
20
+ dist += diff * diff;
21
+ }
22
+ const sizePenalty = beta * (2 * cnt[c] + 1);
23
+ const cost = dist + sizePenalty;
24
+ if (cost < bestCost) { bestCost = cost; best = c; }
25
+ }
26
+ if (lab[i] !== best) { lab[i] = best; moved = true; }
27
+ cnt[lab[i]]++;
28
+ }
29
+
30
+ // ── update centroids ──
31
+ cent = Array.from({ length: k }, () => new Array(d).fill(0));
32
+ for (let i = 0; i < n; ++i)
33
+ for (let j = 0; j < d; ++j) cent[lab[i]][j] += emb[i][j];
34
+ for (let c = 0; c < k; ++c)
35
+ if (cnt[c]) for (let j = 0; j < d; ++j) cent[c][j] /= cnt[c];
36
+
37
+ if (!moved) break;
38
+ }
39
+ return { labels: Array.from(lab), centroids: cent };
40
+ }
41
+
42
+
43
  function kmeansPlusPlusInit(embeddings, k) {
44
  const n = embeddings.length;
45
  const dim = embeddings[0].length;
 
142
  });
143
  return umap.fit(embeddings);
144
  }
145
+
146
+ export { balancedKMeans };
index.html CHANGED
@@ -58,6 +58,11 @@
58
  </script>
59
  <label for="kmeans-k" style="margin-left:10px;">Clusters:</label>
60
  <input id="kmeans-k" type="number" min="2" max="100" value="7" style="width:60px;">
 
 
 
 
 
61
  <button id="kmeans-btn">K-Means Clustering</button>
62
  <button id="heatmap-btn">Similarity Heatmap</button>
63
  <div id="progress-bar">
 
58
  </script>
59
  <label for="kmeans-k" style="margin-left:10px;">Clusters:</label>
60
  <input id="kmeans-k" type="number" min="2" max="100" value="7" style="width:60px;">
61
+ <label for="kmeans-type" style="margin-left:10px;">Clustering Type:</label>
62
+ <select id="kmeans-type" style="width:180px;">
63
+ <option value="kmeans">K-Means (standard)</option>
64
+ <option value="balancedKMeans">Balanced K-Means</option>
65
+ </select>
66
  <button id="kmeans-btn">K-Means Clustering</button>
67
  <button id="heatmap-btn">Similarity Heatmap</button>
68
  <div id="progress-bar">
main.js CHANGED
@@ -1,5 +1,5 @@
1
  import { getGroupEmbeddings, getLineEmbeddings } from './embedding.js';
2
- import { kmeans } from './clustering.js';
3
  import { plotHeatmap, plotScatter, updateScatter } from './plotting.js';
4
  import { nameCluster } from './cluster_naming.js';
5
  import { prompt_cluster } from './prompt_cluster.js';
@@ -57,8 +57,14 @@ document.getElementById("kmeans-btn").onclick = async () => {
57
  if (n < 2) return;
58
  const requestedK = parseInt(document.getElementById("kmeans-k").value) || 3;
59
  const k = Math.max(2, Math.min(requestedK, n));
60
- // K-Means clustering
61
- const { labels } = kmeans(embeddings, k);
 
 
 
 
 
 
62
  // UMAP projection
63
  const { UMAP } = await import('https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm');
64
  const nNeighbors = Math.max(1, Math.min(lines.length - 1, 15));
 
1
  import { getGroupEmbeddings, getLineEmbeddings } from './embedding.js';
2
+ import { kmeans, balancedKMeans } from './clustering.js';
3
  import { plotHeatmap, plotScatter, updateScatter } from './plotting.js';
4
  import { nameCluster } from './cluster_naming.js';
5
  import { prompt_cluster } from './prompt_cluster.js';
 
57
  if (n < 2) return;
58
  const requestedK = parseInt(document.getElementById("kmeans-k").value) || 3;
59
  const k = Math.max(2, Math.min(requestedK, n));
60
+ // Read clustering type
61
+ const clusteringType = document.getElementById("kmeans-type").value;
62
+ let labels;
63
+ if (clusteringType === "balancedKMeans") {
64
+ labels = balancedKMeans(embeddings, k).labels;
65
+ } else {
66
+ labels = kmeans(embeddings, k).labels;
67
+ }
68
  // UMAP projection
69
  const { UMAP } = await import('https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm');
70
  const nNeighbors = Math.max(1, Math.min(lines.length - 1, 15));