Implement adaptive local classifier that learns from user feedback

Files changed:
- src/App.css (+13, -0)
- src/App.jsx (+78, -5)
- src/localClassifier.js (+205, -0)
src/App.css (CHANGED)

@@ -147,6 +147,19 @@ header p {
   opacity: 0.6;
 }
 
+.tag.local {
+  background: linear-gradient(45deg, #9b59b6, #8e44ad);
+}
+
+.tag.blended {
+  background: linear-gradient(45deg, #f39c12, #e67e22);
+}
+
+.source-indicator {
+  margin-left: 0.5rem;
+  font-size: 0.8em;
+}
+
 .tag-controls {
   display: flex;
   gap: 0.25rem;
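The three new rules are hooks for the per-tag provenance that the App.jsx changes below introduce: each tag object gains a source field ('clap', 'local', 'blended', or 'custom') that is written into the tag's class list, and the emoji badge is wrapped in .source-indicator. A minimal sketch of the markup these styles target, using the same tag shape as the diff below (illustrative only, not part of the commit):

  <span className={`tag ${tag.source || 'clap'}`}>
    {tag.label} ({Math.round(tag.confidence * 100)}%)
    {tag.source === 'local' && <span className="source-indicator">🧠</span>}
  </span>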
src/App.jsx (CHANGED)

@@ -1,6 +1,7 @@
 import { useState, useRef, useEffect } from 'react'
 import CLAPProcessor from './clapProcessor'
 import UserFeedbackStore from './userFeedbackStore'
+import LocalClassifier from './localClassifier'
 import './App.css'
 
 function App() {
@@ -18,11 +19,16 @@ function App() {
   const chunksRef = useRef([])
   const clapProcessorRef = useRef(null)
   const feedbackStoreRef = useRef(null)
+  const localClassifierRef = useRef(null)
 
   useEffect(() => {
     const initializeStore = async () => {
       feedbackStoreRef.current = new UserFeedbackStore()
       await feedbackStoreRef.current.initialize()
+
+      localClassifierRef.current = new LocalClassifier()
+      localClassifierRef.current.loadModel()
+
       loadCustomTags()
     }
     initializeStore()
@@ -107,13 +113,53 @@ function App() {
       const generatedTags = await clapProcessorRef.current.processAudio(audioBuffer)
 
       // Store basic audio info for later use
-
+      const features = {
         sampleRate: audioBuffer.sampleRate,
         duration: audioBuffer.duration,
         numberOfChannels: audioBuffer.numberOfChannels
-      }
+      }
+      setAudioFeatures(features)
+
+      // Apply local classifier adjustments
+      let finalTags = generatedTags.map(tag => ({ ...tag, userFeedback: null }))
+
+      if (localClassifierRef.current) {
+        const simpleFeatures = localClassifierRef.current.extractSimpleFeatures(features)
+        const allPossibleTags = [...generatedTags.map(t => t.label), ...customTags]
+        const localPredictions = localClassifierRef.current.predictAll(simpleFeatures, allPossibleTags)
+
+        // Merge CLAP predictions with local classifier predictions
+        const mergedTags = new Map()
+
+        // Add CLAP tags
+        for (const tag of generatedTags) {
+          mergedTags.set(tag.label, { ...tag, source: 'clap' })
+        }
+
+        // Add or adjust with local predictions
+        for (const pred of localPredictions) {
+          if (mergedTags.has(pred.tag)) {
+            // Blend CLAP and local predictions
+            const existing = mergedTags.get(pred.tag)
+            existing.confidence = (existing.confidence + pred.confidence) / 2
+            existing.source = 'blended'
+          } else if (pred.confidence > 0.6) {
+            // Add high-confidence local predictions
+            mergedTags.set(pred.tag, {
+              label: pred.tag,
+              confidence: pred.confidence,
+              source: 'local',
+              userFeedback: null
+            })
+          }
+        }
+
+        finalTags = Array.from(mergedTags.values())
+          .sort((a, b) => b.confidence - a.confidence)
+          .slice(0, 8) // Keep top 8 tags
+      }
 
-      setTags(
+      setTags(finalTags)
     } catch (err) {
       console.error('Error processing audio:', err)
       setError('Failed to process audio. Using fallback tags.')
@@ -139,6 +185,17 @@ function App() {
         feedback,
         audioHash
       )
+
+      // Train local classifier on this feedback
+      if (localClassifierRef.current && audioFeatures) {
+        const simpleFeatures = localClassifierRef.current.extractSimpleFeatures(audioFeatures)
+        localClassifierRef.current.trainOnFeedback(
+          simpleFeatures,
+          updatedTags[tagIndex].label,
+          feedback
+        )
+        localClassifierRef.current.saveModel()
+      }
     } catch (error) {
       console.error('Error saving tag feedback:', error)
     }
@@ -151,7 +208,8 @@ function App() {
       label: newTag.trim(),
       confidence: 1.0,
       userFeedback: 'custom',
-      isCustom: true
+      isCustom: true,
+      source: 'custom'
     }
 
     setTags(prev => [...prev, customTag])
@@ -159,6 +217,18 @@ function App() {
     try {
       await feedbackStoreRef.current.saveCustomTag(newTag.trim())
       await feedbackStoreRef.current.saveTagFeedback(newTag.trim(), 'custom', audioHash)
+
+      // Train local classifier on custom tag
+      if (localClassifierRef.current && audioFeatures) {
+        const simpleFeatures = localClassifierRef.current.extractSimpleFeatures(audioFeatures)
+        localClassifierRef.current.trainOnFeedback(
+          simpleFeatures,
+          newTag.trim(),
+          'custom'
+        )
+        localClassifierRef.current.saveModel()
+      }
+
       loadCustomTags()
     } catch (error) {
       console.error('Error saving custom tag:', error)
@@ -236,8 +306,11 @@ function App() {
           <div className="tags">
             {tags.map((tag, index) => (
               <div key={index} className={`tag-item ${tag.userFeedback ? 'has-feedback' : ''}`}>
-                <span className={`tag ${tag.isCustom ? 'custom' : ''} ${tag.userFeedback === 'negative' ? 'negative' : ''}`}>
+                <span className={`tag ${tag.isCustom ? 'custom' : ''} ${tag.userFeedback === 'negative' ? 'negative' : ''} ${tag.source || 'clap'}`}>
                   {tag.label} ({Math.round(tag.confidence * 100)}%)
+                  {tag.source === 'local' && <span className="source-indicator">🧠</span>}
+                  {tag.source === 'blended' && <span className="source-indicator">⚡</span>}
+                  {tag.source === 'custom' && <span className="source-indicator">✨</span>}
                 </span>
                 {!tag.isCustom && (
                   <div className="tag-controls">
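The key addition in processAudio is the merge step: CLAP output is the baseline, any tag the local classifier also scores gets the simple average of the two confidences and is marked 'blended', and tags only the local model predicts are kept when their confidence clears 0.6. A self-contained sketch of that policy, pulled out of the component for readability (the function name and plain-array inputs are illustrative, not part of the commit):

  // clapTags: [{ label, confidence }], localPredictions: [{ tag, confidence }]
  function mergeTagPredictions(clapTags, localPredictions, maxTags = 8) {
    const merged = new Map()
    // CLAP predictions form the baseline
    for (const tag of clapTags) {
      merged.set(tag.label, { ...tag, source: 'clap', userFeedback: null })
    }
    for (const pred of localPredictions) {
      const existing = merged.get(pred.tag)
      if (existing) {
        // Equal-weight average of CLAP and local confidence
        existing.confidence = (existing.confidence + pred.confidence) / 2
        existing.source = 'blended'
      } else if (pred.confidence > 0.6) {
        // Local-only tags must clear a confidence threshold to appear
        merged.set(pred.tag, { label: pred.tag, confidence: pred.confidence, source: 'local', userFeedback: null })
      }
    }
    return Array.from(merged.values())
      .sort((a, b) => b.confidence - a.confidence)
      .slice(0, maxTags)
  }

Note that the plain average gives the local score the same weight as CLAP regardless of how much feedback the local model has seen; a count-based weight would be one way to soften that early on.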
src/localClassifier.js (ADDED, 205 lines)

class LocalClassifier {
  constructor() {
    this.weights = new Map(); // tag -> weight vector
    this.biases = new Map(); // tag -> bias
    this.learningRate = 0.01;
    this.featureDim = 512; // CLAP embedding dimension
    this.isInitialized = false;
  }

  initialize(featureDim = 512) {
    this.featureDim = featureDim;
    this.isInitialized = true;
  }

  // Simple logistic regression training
  trainOnFeedback(features, tag, feedback) {
    if (!this.isInitialized) {
      this.initialize();
    }

    // Convert feedback to target value
    let target;
    switch (feedback) {
      case 'positive':
        target = 1.0;
        break;
      case 'negative':
        target = 0.0;
        break;
      case 'custom':
        target = 1.0;
        break;
      default:
        return; // Skip unknown feedback
    }

    // Initialize weights for new tag
    if (!this.weights.has(tag)) {
      this.weights.set(tag, new Array(this.featureDim).fill(0).map(() =>
        (Math.random() - 0.5) * 0.01
      ));
      this.biases.set(tag, 0);
    }

    const weights = this.weights.get(tag);
    const bias = this.biases.get(tag);

    // Forward pass
    let logit = bias;
    for (let i = 0; i < features.length; i++) {
      logit += weights[i] * features[i];
    }

    // Sigmoid activation
    const prediction = 1 / (1 + Math.exp(-logit));

    // Compute gradient
    const error = prediction - target;

    // Update weights and bias
    for (let i = 0; i < features.length; i++) {
      weights[i] -= this.learningRate * error * features[i];
    }
    this.biases.set(tag, bias - this.learningRate * error);

    // Store updated weights
    this.weights.set(tag, weights);
  }

  // Predict confidence for a tag given features
  predict(features, tag) {
    if (!this.weights.has(tag)) {
      return null; // No training data for this tag
    }

    const weights = this.weights.get(tag);
    const bias = this.biases.get(tag);

    let logit = bias;
    for (let i = 0; i < Math.min(features.length, weights.length); i++) {
      logit += weights[i] * features[i];
    }

    // Sigmoid activation
    return 1 / (1 + Math.exp(-logit));
  }

  // Get all predictions for given features
  predictAll(features, candidateTags) {
    const predictions = [];

    for (const tag of candidateTags) {
      const confidence = this.predict(features, tag);
      if (confidence !== null) {
        predictions.push({ tag, confidence });
      }
    }

    return predictions.sort((a, b) => b.confidence - a.confidence);
  }

  // Retrain on batch of feedback data
  retrainOnBatch(feedbackData) {
    for (const item of feedbackData) {
      if (item.audioFeatures && item.correctedTags) {
        // Create simple features from audio metadata
        const features = this.extractSimpleFeatures(item.audioFeatures);

        // Train on corrected tags
        for (const tagData of item.correctedTags) {
          this.trainOnFeedback(features, tagData.tag, tagData.feedback);
        }
      }
    }
  }

  // Extract simple features from audio metadata
  extractSimpleFeatures(audioFeatures) {
    // Create a simple feature vector from audio metadata
    // In a real implementation, this would use actual CLAP embeddings
    const features = new Array(this.featureDim).fill(0);

    if (audioFeatures) {
      // Use basic audio properties to create pseudo-features
      features[0] = audioFeatures.duration / 60; // Duration in minutes
      features[1] = audioFeatures.sampleRate / 48000; // Normalized sample rate
      features[2] = audioFeatures.numberOfChannels; // Number of channels

      // Fill remaining with small random values based on hash of properties
      const seed = this.simpleHash(JSON.stringify(audioFeatures));
      for (let i = 3; i < this.featureDim; i++) {
        features[i] = this.seededRandom(seed + i) * 0.1;
      }
    }

    return features;
  }

  // Simple hash function for seeded random
  simpleHash(str) {
    let hash = 0;
    for (let i = 0; i < str.length; i++) {
      const char = str.charCodeAt(i);
      hash = ((hash << 5) - hash) + char;
      hash = hash & hash; // Convert to 32-bit integer
    }
    return Math.abs(hash);
  }

  // Seeded random number generator
  seededRandom(seed) {
    const x = Math.sin(seed) * 10000;
    return x - Math.floor(x);
  }

  // Save model to localStorage
  saveModel() {
    const modelData = {
      weights: Object.fromEntries(this.weights),
      biases: Object.fromEntries(this.biases),
      featureDim: this.featureDim,
      learningRate: this.learningRate
    };

    localStorage.setItem('clipTaggerModel', JSON.stringify(modelData));
  }

  // Load model from localStorage
  loadModel() {
    const saved = localStorage.getItem('clipTaggerModel');
    if (saved) {
      try {
        const modelData = JSON.parse(saved);
        this.weights = new Map(Object.entries(modelData.weights));
        this.biases = new Map(Object.entries(modelData.biases));
        this.featureDim = modelData.featureDim || 512;
        this.learningRate = modelData.learningRate || 0.01;
        this.isInitialized = true;
        return true;
      } catch (error) {
        console.error('Error loading model:', error);
      }
    }
    return false;
  }

  // Get model statistics
  getModelStats() {
    return {
      trainedTags: this.weights.size,
      featureDim: this.featureDim,
      learningRate: this.learningRate,
      tags: Array.from(this.weights.keys())
    };
  }

  // Clear the model
  clearModel() {
    this.weights.clear();
    this.biases.clear();
    localStorage.removeItem('clipTaggerModel');
  }
}

export default LocalClassifier;
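LocalClassifier is one independent logistic regression per tag, updated online: each feedback event triggers a single gradient step (sigmoid prediction, error = prediction - target, each weight nudged by learningRate * error * feature). A short usage sketch, assuming a browser context where localStorage exists; the feature vector and tag names are placeholders for illustration:

  import LocalClassifier from './localClassifier'

  const classifier = new LocalClassifier()
  classifier.loadModel() // restores any previously saved weights from localStorage

  // One feedback event: a feature vector plus a tag label and feedback type
  const features = new Array(512).fill(0).map(() => Math.random() * 0.1) // placeholder features
  classifier.trainOnFeedback(features, 'speech', 'positive')
  classifier.trainOnFeedback(features, 'music', 'negative')
  classifier.saveModel() // persists weights and biases to localStorage

  // Rank candidate tags; only tags with at least one training step are scored
  const ranked = classifier.predictAll(features, ['speech', 'music', 'applause'])
  // e.g. [{ tag: 'speech', confidence: 0.51 }, { tag: 'music', confidence: 0.49 }]; 'applause' is skipped because it was never trained
  console.log(classifier.getModelStats()) // { trainedTags: 2, featureDim: 512, ... }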