sohei1l commited on
Commit
f01c9d3
·
1 Parent(s): c836ca5

Implement adaptive local classifier that learns from user feedback

Browse files
Files changed (3) hide show
  1. src/App.css +13 -0
  2. src/App.jsx +78 -5
  3. src/localClassifier.js +205 -0
src/App.css CHANGED
@@ -147,6 +147,19 @@ header p {
147
  opacity: 0.6;
148
  }
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  .tag-controls {
151
  display: flex;
152
  gap: 0.25rem;
 
147
  opacity: 0.6;
148
  }
149
 
150
+ .tag.local {
151
+ background: linear-gradient(45deg, #9b59b6, #8e44ad);
152
+ }
153
+
154
+ .tag.blended {
155
+ background: linear-gradient(45deg, #f39c12, #e67e22);
156
+ }
157
+
158
+ .source-indicator {
159
+ margin-left: 0.5rem;
160
+ font-size: 0.8em;
161
+ }
162
+
163
  .tag-controls {
164
  display: flex;
165
  gap: 0.25rem;
src/App.jsx CHANGED
@@ -1,6 +1,7 @@
1
  import { useState, useRef, useEffect } from 'react'
2
  import CLAPProcessor from './clapProcessor'
3
  import UserFeedbackStore from './userFeedbackStore'
 
4
  import './App.css'
5
 
6
  function App() {
@@ -18,11 +19,16 @@ function App() {
18
  const chunksRef = useRef([])
19
  const clapProcessorRef = useRef(null)
20
  const feedbackStoreRef = useRef(null)
 
21
 
22
  useEffect(() => {
23
  const initializeStore = async () => {
24
  feedbackStoreRef.current = new UserFeedbackStore()
25
  await feedbackStoreRef.current.initialize()
 
 
 
 
26
  loadCustomTags()
27
  }
28
  initializeStore()
@@ -107,13 +113,53 @@ function App() {
107
  const generatedTags = await clapProcessorRef.current.processAudio(audioBuffer)
108
 
109
  // Store basic audio info for later use
110
- setAudioFeatures({
111
  sampleRate: audioBuffer.sampleRate,
112
  duration: audioBuffer.duration,
113
  numberOfChannels: audioBuffer.numberOfChannels
114
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
- setTags(generatedTags.map(tag => ({ ...tag, userFeedback: null })))
117
  } catch (err) {
118
  console.error('Error processing audio:', err)
119
  setError('Failed to process audio. Using fallback tags.')
@@ -139,6 +185,17 @@ function App() {
139
  feedback,
140
  audioHash
141
  )
 
 
 
 
 
 
 
 
 
 
 
142
  } catch (error) {
143
  console.error('Error saving tag feedback:', error)
144
  }
@@ -151,7 +208,8 @@ function App() {
151
  label: newTag.trim(),
152
  confidence: 1.0,
153
  userFeedback: 'custom',
154
- isCustom: true
 
155
  }
156
 
157
  setTags(prev => [...prev, customTag])
@@ -159,6 +217,18 @@ function App() {
159
  try {
160
  await feedbackStoreRef.current.saveCustomTag(newTag.trim())
161
  await feedbackStoreRef.current.saveTagFeedback(newTag.trim(), 'custom', audioHash)
 
 
 
 
 
 
 
 
 
 
 
 
162
  loadCustomTags()
163
  } catch (error) {
164
  console.error('Error saving custom tag:', error)
@@ -236,8 +306,11 @@ function App() {
236
  <div className="tags">
237
  {tags.map((tag, index) => (
238
  <div key={index} className={`tag-item ${tag.userFeedback ? 'has-feedback' : ''}`}>
239
- <span className={`tag ${tag.isCustom ? 'custom' : ''} ${tag.userFeedback === 'negative' ? 'negative' : ''}`}>
240
  {tag.label} ({Math.round(tag.confidence * 100)}%)
 
 
 
241
  </span>
242
  {!tag.isCustom && (
243
  <div className="tag-controls">
 
1
  import { useState, useRef, useEffect } from 'react'
2
  import CLAPProcessor from './clapProcessor'
3
  import UserFeedbackStore from './userFeedbackStore'
4
+ import LocalClassifier from './localClassifier'
5
  import './App.css'
6
 
7
  function App() {
 
19
  const chunksRef = useRef([])
20
  const clapProcessorRef = useRef(null)
21
  const feedbackStoreRef = useRef(null)
22
+ const localClassifierRef = useRef(null)
23
 
24
  useEffect(() => {
25
  const initializeStore = async () => {
26
  feedbackStoreRef.current = new UserFeedbackStore()
27
  await feedbackStoreRef.current.initialize()
28
+
29
+ localClassifierRef.current = new LocalClassifier()
30
+ localClassifierRef.current.loadModel()
31
+
32
  loadCustomTags()
33
  }
34
  initializeStore()
 
113
  const generatedTags = await clapProcessorRef.current.processAudio(audioBuffer)
114
 
115
  // Store basic audio info for later use
116
+ const features = {
117
  sampleRate: audioBuffer.sampleRate,
118
  duration: audioBuffer.duration,
119
  numberOfChannels: audioBuffer.numberOfChannels
120
+ }
121
+ setAudioFeatures(features)
122
+
123
+ // Apply local classifier adjustments
124
+ let finalTags = generatedTags.map(tag => ({ ...tag, userFeedback: null }))
125
+
126
+ if (localClassifierRef.current) {
127
+ const simpleFeatures = localClassifierRef.current.extractSimpleFeatures(features)
128
+ const allPossibleTags = [...generatedTags.map(t => t.label), ...customTags]
129
+ const localPredictions = localClassifierRef.current.predictAll(simpleFeatures, allPossibleTags)
130
+
131
+ // Merge CLAP predictions with local classifier predictions
132
+ const mergedTags = new Map()
133
+
134
+ // Add CLAP tags
135
+ for (const tag of generatedTags) {
136
+ mergedTags.set(tag.label, { ...tag, source: 'clap' })
137
+ }
138
+
139
+ // Add or adjust with local predictions
140
+ for (const pred of localPredictions) {
141
+ if (mergedTags.has(pred.tag)) {
142
+ // Blend CLAP and local predictions
143
+ const existing = mergedTags.get(pred.tag)
144
+ existing.confidence = (existing.confidence + pred.confidence) / 2
145
+ existing.source = 'blended'
146
+ } else if (pred.confidence > 0.6) {
147
+ // Add high-confidence local predictions
148
+ mergedTags.set(pred.tag, {
149
+ label: pred.tag,
150
+ confidence: pred.confidence,
151
+ source: 'local',
152
+ userFeedback: null
153
+ })
154
+ }
155
+ }
156
+
157
+ finalTags = Array.from(mergedTags.values())
158
+ .sort((a, b) => b.confidence - a.confidence)
159
+ .slice(0, 8) // Keep top 8 tags
160
+ }
161
 
162
+ setTags(finalTags)
163
  } catch (err) {
164
  console.error('Error processing audio:', err)
165
  setError('Failed to process audio. Using fallback tags.')
 
185
  feedback,
186
  audioHash
187
  )
188
+
189
+ // Train local classifier on this feedback
190
+ if (localClassifierRef.current && audioFeatures) {
191
+ const simpleFeatures = localClassifierRef.current.extractSimpleFeatures(audioFeatures)
192
+ localClassifierRef.current.trainOnFeedback(
193
+ simpleFeatures,
194
+ updatedTags[tagIndex].label,
195
+ feedback
196
+ )
197
+ localClassifierRef.current.saveModel()
198
+ }
199
  } catch (error) {
200
  console.error('Error saving tag feedback:', error)
201
  }
 
208
  label: newTag.trim(),
209
  confidence: 1.0,
210
  userFeedback: 'custom',
211
+ isCustom: true,
212
+ source: 'custom'
213
  }
214
 
215
  setTags(prev => [...prev, customTag])
 
217
  try {
218
  await feedbackStoreRef.current.saveCustomTag(newTag.trim())
219
  await feedbackStoreRef.current.saveTagFeedback(newTag.trim(), 'custom', audioHash)
220
+
221
+ // Train local classifier on custom tag
222
+ if (localClassifierRef.current && audioFeatures) {
223
+ const simpleFeatures = localClassifierRef.current.extractSimpleFeatures(audioFeatures)
224
+ localClassifierRef.current.trainOnFeedback(
225
+ simpleFeatures,
226
+ newTag.trim(),
227
+ 'custom'
228
+ )
229
+ localClassifierRef.current.saveModel()
230
+ }
231
+
232
  loadCustomTags()
233
  } catch (error) {
234
  console.error('Error saving custom tag:', error)
 
306
  <div className="tags">
307
  {tags.map((tag, index) => (
308
  <div key={index} className={`tag-item ${tag.userFeedback ? 'has-feedback' : ''}`}>
309
+ <span className={`tag ${tag.isCustom ? 'custom' : ''} ${tag.userFeedback === 'negative' ? 'negative' : ''} ${tag.source || 'clap'}`}>
310
  {tag.label} ({Math.round(tag.confidence * 100)}%)
311
+ {tag.source === 'local' && <span className="source-indicator">🧠</span>}
312
+ {tag.source === 'blended' && <span className="source-indicator">⚡</span>}
313
+ {tag.source === 'custom' && <span className="source-indicator">✨</span>}
314
  </span>
315
  {!tag.isCustom && (
316
  <div className="tag-controls">
src/localClassifier.js ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class LocalClassifier {
2
+ constructor() {
3
+ this.weights = new Map(); // tag -> weight vector
4
+ this.biases = new Map(); // tag -> bias
5
+ this.learningRate = 0.01;
6
+ this.featureDim = 512; // CLAP embedding dimension
7
+ this.isInitialized = false;
8
+ }
9
+
10
+ initialize(featureDim = 512) {
11
+ this.featureDim = featureDim;
12
+ this.isInitialized = true;
13
+ }
14
+
15
+ // Simple logistic regression training
16
+ trainOnFeedback(features, tag, feedback) {
17
+ if (!this.isInitialized) {
18
+ this.initialize();
19
+ }
20
+
21
+ // Convert feedback to target value
22
+ let target;
23
+ switch (feedback) {
24
+ case 'positive':
25
+ target = 1.0;
26
+ break;
27
+ case 'negative':
28
+ target = 0.0;
29
+ break;
30
+ case 'custom':
31
+ target = 1.0;
32
+ break;
33
+ default:
34
+ return; // Skip unknown feedback
35
+ }
36
+
37
+ // Initialize weights for new tag
38
+ if (!this.weights.has(tag)) {
39
+ this.weights.set(tag, new Array(this.featureDim).fill(0).map(() =>
40
+ (Math.random() - 0.5) * 0.01
41
+ ));
42
+ this.biases.set(tag, 0);
43
+ }
44
+
45
+ const weights = this.weights.get(tag);
46
+ const bias = this.biases.get(tag);
47
+
48
+ // Forward pass
49
+ let logit = bias;
50
+ for (let i = 0; i < features.length; i++) {
51
+ logit += weights[i] * features[i];
52
+ }
53
+
54
+ // Sigmoid activation
55
+ const prediction = 1 / (1 + Math.exp(-logit));
56
+
57
+ // Compute gradient
58
+ const error = prediction - target;
59
+
60
+ // Update weights and bias
61
+ for (let i = 0; i < features.length; i++) {
62
+ weights[i] -= this.learningRate * error * features[i];
63
+ }
64
+ this.biases.set(tag, bias - this.learningRate * error);
65
+
66
+ // Store updated weights
67
+ this.weights.set(tag, weights);
68
+ }
69
+
70
+ // Predict confidence for a tag given features
71
+ predict(features, tag) {
72
+ if (!this.weights.has(tag)) {
73
+ return null; // No training data for this tag
74
+ }
75
+
76
+ const weights = this.weights.get(tag);
77
+ const bias = this.biases.get(tag);
78
+
79
+ let logit = bias;
80
+ for (let i = 0; i < Math.min(features.length, weights.length); i++) {
81
+ logit += weights[i] * features[i];
82
+ }
83
+
84
+ // Sigmoid activation
85
+ return 1 / (1 + Math.exp(-logit));
86
+ }
87
+
88
+ // Get all predictions for given features
89
+ predictAll(features, candidateTags) {
90
+ const predictions = [];
91
+
92
+ for (const tag of candidateTags) {
93
+ const confidence = this.predict(features, tag);
94
+ if (confidence !== null) {
95
+ predictions.push({ tag, confidence });
96
+ }
97
+ }
98
+
99
+ return predictions.sort((a, b) => b.confidence - a.confidence);
100
+ }
101
+
102
+ // Retrain on batch of feedback data
103
+ retrainOnBatch(feedbackData) {
104
+ for (const item of feedbackData) {
105
+ if (item.audioFeatures && item.correctedTags) {
106
+ // Create simple features from audio metadata
107
+ const features = this.extractSimpleFeatures(item.audioFeatures);
108
+
109
+ // Train on corrected tags
110
+ for (const tagData of item.correctedTags) {
111
+ this.trainOnFeedback(features, tagData.tag, tagData.feedback);
112
+ }
113
+ }
114
+ }
115
+ }
116
+
117
+ // Extract simple features from audio metadata
118
+ extractSimpleFeatures(audioFeatures) {
119
+ // Create a simple feature vector from audio metadata
120
+ // In a real implementation, this would use actual CLAP embeddings
121
+ const features = new Array(this.featureDim).fill(0);
122
+
123
+ if (audioFeatures) {
124
+ // Use basic audio properties to create pseudo-features
125
+ features[0] = audioFeatures.duration / 60; // Duration in minutes
126
+ features[1] = audioFeatures.sampleRate / 48000; // Normalized sample rate
127
+ features[2] = audioFeatures.numberOfChannels; // Number of channels
128
+
129
+ // Fill remaining with small random values based on hash of properties
130
+ const seed = this.simpleHash(JSON.stringify(audioFeatures));
131
+ for (let i = 3; i < this.featureDim; i++) {
132
+ features[i] = this.seededRandom(seed + i) * 0.1;
133
+ }
134
+ }
135
+
136
+ return features;
137
+ }
138
+
139
+ // Simple hash function for seeded random
140
+ simpleHash(str) {
141
+ let hash = 0;
142
+ for (let i = 0; i < str.length; i++) {
143
+ const char = str.charCodeAt(i);
144
+ hash = ((hash << 5) - hash) + char;
145
+ hash = hash & hash; // Convert to 32-bit integer
146
+ }
147
+ return Math.abs(hash);
148
+ }
149
+
150
+ // Seeded random number generator
151
+ seededRandom(seed) {
152
+ const x = Math.sin(seed) * 10000;
153
+ return x - Math.floor(x);
154
+ }
155
+
156
+ // Save model to localStorage
157
+ saveModel() {
158
+ const modelData = {
159
+ weights: Object.fromEntries(this.weights),
160
+ biases: Object.fromEntries(this.biases),
161
+ featureDim: this.featureDim,
162
+ learningRate: this.learningRate
163
+ };
164
+
165
+ localStorage.setItem('clipTaggerModel', JSON.stringify(modelData));
166
+ }
167
+
168
+ // Load model from localStorage
169
+ loadModel() {
170
+ const saved = localStorage.getItem('clipTaggerModel');
171
+ if (saved) {
172
+ try {
173
+ const modelData = JSON.parse(saved);
174
+ this.weights = new Map(Object.entries(modelData.weights));
175
+ this.biases = new Map(Object.entries(modelData.biases));
176
+ this.featureDim = modelData.featureDim || 512;
177
+ this.learningRate = modelData.learningRate || 0.01;
178
+ this.isInitialized = true;
179
+ return true;
180
+ } catch (error) {
181
+ console.error('Error loading model:', error);
182
+ }
183
+ }
184
+ return false;
185
+ }
186
+
187
+ // Get model statistics
188
+ getModelStats() {
189
+ return {
190
+ trainedTags: this.weights.size,
191
+ featureDim: this.featureDim,
192
+ learningRate: this.learningRate,
193
+ tags: Array.from(this.weights.keys())
194
+ };
195
+ }
196
+
197
+ // Clear the model
198
+ clearModel() {
199
+ this.weights.clear();
200
+ this.biases.clear();
201
+ localStorage.removeItem('clipTaggerModel');
202
+ }
203
+ }
204
+
205
+ export default LocalClassifier;