ping98k commited on
Commit
65b6e6f
Β·
1 Parent(s): 3f8e7a3

Refactor balanced K-Means implementation; update beta parameter default value, handle empty embeddings, and ensure proper centroid updates. Enhance UI to include beta input and adjust event handler to read beta value for clustering.

Browse files
Files changed (3) hide show
  1. clustering.js +15 -6
  2. index.html +7 -5
  3. main.js +3 -2
clustering.js CHANGED
@@ -1,16 +1,20 @@
1
  import { UMAP } from "https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm";
2
 
3
- function balancedKMeans(emb, k, beta = 1, maxIter = 100) {
 
 
4
  const n = emb.length, d = emb[0].length;
 
 
5
  let cent = kmeansPlusPlusInit(emb, k);
6
- const lab = new Uint32Array(n);
7
  const cnt = new Uint32Array(k);
8
 
9
  for (let iter = 0; iter < maxIter; ++iter) {
10
  let moved = false;
 
11
 
12
  // ── assignment with size penalty ──
13
- cnt.fill(0);
14
  for (let i = 0; i < n; ++i) {
15
  let best = 0, bestCost = Infinity;
16
  for (let c = 0; c < k; ++c) {
@@ -24,18 +28,23 @@ function balancedKMeans(emb, k, beta = 1, maxIter = 100) {
24
  if (cost < bestCost) { bestCost = cost; best = c; }
25
  }
26
  if (lab[i] !== best) { lab[i] = best; moved = true; }
27
- cnt[lab[i]]++;
28
  }
29
 
30
  // ── update centroids ──
31
  cent = Array.from({ length: k }, () => new Array(d).fill(0));
32
  for (let i = 0; i < n; ++i)
33
  for (let j = 0; j < d; ++j) cent[lab[i]][j] += emb[i][j];
 
34
  for (let c = 0; c < k; ++c)
35
- if (cnt[c]) for (let j = 0; j < d; ++j) cent[c][j] /= cnt[c];
 
 
 
36
 
37
- if (!moved) break;
38
  }
 
39
  return { labels: Array.from(lab), centroids: cent };
40
  }
41
 
 
1
  import { UMAP } from "https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm";
2
 
3
+ function balancedKMeans(emb, k, beta = 0.01, maxIter = 100) {
4
+ if (!emb.length) return { labels: [], centroids: [] };
5
+
6
  const n = emb.length, d = emb[0].length;
7
+ k = Math.max(2, Math.min(k, n)); // guard k ≀ n
8
+
9
  let cent = kmeansPlusPlusInit(emb, k);
10
+ const lab = new Uint32Array(n).fill(k); // start β€œunassigned”
11
  const cnt = new Uint32Array(k);
12
 
13
  for (let iter = 0; iter < maxIter; ++iter) {
14
  let moved = false;
15
+ cnt.fill(0);
16
 
17
  // ── assignment with size penalty ──
 
18
  for (let i = 0; i < n; ++i) {
19
  let best = 0, bestCost = Infinity;
20
  for (let c = 0; c < k; ++c) {
 
28
  if (cost < bestCost) { bestCost = cost; best = c; }
29
  }
30
  if (lab[i] !== best) { lab[i] = best; moved = true; }
31
+ cnt[best]++;
32
  }
33
 
34
  // ── update centroids ──
35
  cent = Array.from({ length: k }, () => new Array(d).fill(0));
36
  for (let i = 0; i < n; ++i)
37
  for (let j = 0; j < d; ++j) cent[lab[i]][j] += emb[i][j];
38
+
39
  for (let c = 0; c < k; ++c)
40
+ if (cnt[c]) {
41
+ const inv = 1 / cnt[c];
42
+ for (let j = 0; j < d; ++j) cent[c][j] *= inv;
43
+ }
44
 
45
+ if (!moved) break; // converged
46
  }
47
+
48
  return { labels: Array.from(lab), centroids: cent };
49
  }
50
 
index.html CHANGED
@@ -53,17 +53,19 @@
53
  <h1>Embedding Similarity Heatmap</h1>
54
  <textarea id="input"></textarea>
55
  <script type="module">
56
- import { sentences } from './sentences.js';
57
- document.getElementById("input").value = sentences.join("\n");
58
  </script>
59
  <label for="kmeans-k" style="margin-left:10px;">Clusters:</label>
60
  <input id="kmeans-k" type="number" min="2" max="100" value="7" style="width:60px;">
 
 
61
  <label for="kmeans-type" style="margin-left:10px;">Clustering Type:</label>
62
  <select id="kmeans-type" style="width:180px;">
63
- <option value="kmeans">K-Means (standard)</option>
64
- <option value="balancedKMeans">Balanced K-Means</option>
65
  </select>
66
- <button id="kmeans-btn">K-Means Clustering</button>
67
  <button id="heatmap-btn">Similarity Heatmap</button>
68
  <div id="progress-bar">
69
  <div id="progress-bar-inner"></div>
 
53
  <h1>Embedding Similarity Heatmap</h1>
54
  <textarea id="input"></textarea>
55
  <script type="module">
56
+ import { sentences } from './sentences.js';
57
+ document.getElementById("input").value = sentences.join("\n");
58
  </script>
59
  <label for="kmeans-k" style="margin-left:10px;">Clusters:</label>
60
  <input id="kmeans-k" type="number" min="2" max="100" value="7" style="width:60px;">
61
+ <label for="kmeans-beta" style="margin-left:10px;">Beta (balance):</label>
62
+ <input id="kmeans-beta" type="number" min="0" max="10" step="0.0001" value="0.01" style="width:80px;">
63
  <label for="kmeans-type" style="margin-left:10px;">Clustering Type:</label>
64
  <select id="kmeans-type" style="width:180px;">
65
+ <option value="balancedKMeans">Balanced K-Means</option>
66
+ <option value="kmeans">K-Means (standard)</option>
67
  </select>
68
+ <button id="kmeans-btn">Clustering</button>
69
  <button id="heatmap-btn">Similarity Heatmap</button>
70
  <div id="progress-bar">
71
  <div id="progress-bar-inner"></div>
main.js CHANGED
@@ -57,11 +57,12 @@ document.getElementById("kmeans-btn").onclick = async () => {
57
  if (n < 2) return;
58
  const requestedK = parseInt(document.getElementById("kmeans-k").value) || 3;
59
  const k = Math.max(2, Math.min(requestedK, n));
60
- // Read clustering type
61
  const clusteringType = document.getElementById("kmeans-type").value;
 
62
  let labels;
63
  if (clusteringType === "balancedKMeans") {
64
- labels = balancedKMeans(embeddings, k).labels;
65
  } else {
66
  labels = kmeans(embeddings, k).labels;
67
  }
 
57
  if (n < 2) return;
58
  const requestedK = parseInt(document.getElementById("kmeans-k").value) || 3;
59
  const k = Math.max(2, Math.min(requestedK, n));
60
+ // Read clustering type and beta
61
  const clusteringType = document.getElementById("kmeans-type").value;
62
+ const beta = parseFloat(document.getElementById("kmeans-beta").value) || 0.01;
63
  let labels;
64
  if (clusteringType === "balancedKMeans") {
65
+ labels = balancedKMeans(embeddings, k, beta).labels;
66
  } else {
67
  labels = kmeans(embeddings, k).labels;
68
  }