willwade committed on
Commit deb6f27 · 1 Parent(s): 7b20b20

first attempt

Files changed (6):
  1. README-HF.md +15 -11
  2. README.md +20 -5
  3. app.py +238 -46
  4. demo.py +25 -12
  5. social_graph.json +7 -0
  6. utils.py +144 -67
README-HF.md CHANGED
@@ -1,16 +1,19 @@
-# AAC Social Graph Assistant for MND
+# Will's AAC Communication Aid
 
-An Augmentative and Alternative Communication (AAC) system that uses a social graph to provide contextually relevant suggestions for users with Motor Neurone Disease (MND).
+An Augmentative and Alternative Communication (AAC) system that uses a social graph to provide contextually relevant suggestions for Will, a user with Motor Neurone Disease (MND).
 
 ## About
 
-This demo showcases an AAC system that uses a social graph to provide contextually relevant suggestions for users with MND. The system allows users to select who they are talking to and provides suggestions based on the relationship and common topics of conversation, tailored to the British context with NHS healthcare terminology.
+This demo simulates an AAC system from Will's perspective (a 38-year-old with MND). The system allows Will to select who he's talking to, optionally choose a conversation topic, and get appropriate responses based on what the other person has said. All suggestions are tailored to the relationship and conversation context, using British English and NHS healthcare terminology where appropriate.
 
 ## Features
 
-- **Person-Specific Suggestions**: Select who you're talking to and get suggestions tailored to that relationship
+- **Person-Specific Suggestions**: Select who you (Will) are talking to and get tailored responses
+- **Topic Selection**: Choose conversation topics relevant to your relationship
 - **Context-Aware**: Uses a social graph to understand relationships and common topics
+- **Speech Recognition**: Record what others have said to you and have it transcribed
+- **Auto-Detection**: Automatically detect conversation type from what others say
+- **Multiple Response Types**: Get AI-generated responses, common phrases, or category-specific utterances
-- **Multiple Suggestion Types**: Get suggestions from a language model, common phrases, or predefined utterance categories
 - **British Context**: Designed with British English and NHS healthcare context in mind
 - **MND-Specific**: Tailored for the needs of someone with Motor Neurone Disease
 - **Expandable**: Easily improve the system by enhancing the social graph JSON file
@@ -30,12 +33,13 @@ The current social graph represents a British person with MND who:
 
 ## How to Use
 
-1. Select a person from the dropdown menu
-2. View their context information
-3. Optionally enter current conversation context or record audio
-4. If you record audio, click "Transcribe" to convert it to text
-5. Choose a suggestion type
-6. Click "Generate Suggestions" to get contextually relevant phrases
+1. Select who you (Will) are talking to from the dropdown menu
+2. Optionally select a conversation topic
+3. View the relationship context information
+4. Enter what the other person said to you, or record audio
+5. If you record audio, click "Transcribe" to convert it to text
+6. Choose how you want to respond (auto-detect, AI-generated, common phrases, etc.)
+7. Click "Generate My Responses" to get contextually relevant suggestions
 
 ## Customizing the Social Graph
 
README.md CHANGED
@@ -37,16 +37,31 @@ python app.py
 
 ## How It Works
 
-1. **Social Graph**: The system uses a JSON-based social graph (`social_graph.json`) that contains information about people, their relationships, common topics, and phrases.
+1. **Social Graph**: The system uses a JSON-based social graph (`social_graph.json`) that contains information about:
+   - Will (the AAC user) - a 38-year-old with MND
+   - People in Will's life (family, healthcare providers, friends, colleagues)
+   - Relationships, common topics, and phrases
 
-2. **Context Retrieval**: When you select a person, the system retrieves relevant context information from the social graph.
+2. **Context Retrieval**: When you select who you are (as someone talking to Will), the system retrieves relevant context information from the social graph.
 
-3. **Suggestion Generation**: Based on the selected person and optional conversation context, the system generates suggestions using:
+3. **Conversation Input**: You enter or record what you've said to Will in the conversation.
+
+4. **Suggestion Generation**: Based on who you are and what you've said, the system generates appropriate responses for Will using:
    - A language model (Flan-T5)
-   - Common phrases associated with the person
+   - Common phrases Will might say to you
    - General utterance categories (greetings, needs, emotions, questions)
 
-4. **User Interface**: The Gradio interface provides an intuitive way to interact with the system, select people, and get suggestions.
+5. **User Interface**: The Gradio interface provides an intuitive way to simulate conversations with Will and see what an AAC system might suggest for him to say.
+
+## How to Use
+
+1. Select who you (Will) are talking to from the dropdown menu
+2. Optionally select a conversation topic
+3. View the relationship context information
+4. Enter what the other person said to you, or record audio
+5. If you record audio, click "Transcribe" to convert it to text
+6. Choose how you want to respond (auto-detect, AI-generated, common phrases, etc.)
+7. Click "Generate My Responses" to get contextually relevant suggestions
 
 ## Customizing the Social Graph
 
app.py CHANGED
@@ -6,14 +6,24 @@ from utils import SocialGraphManager, SuggestionGenerator
 
 # Initialize the social graph manager and suggestion generator
 social_graph = SocialGraphManager("social_graph.json")
-suggestion_generator = SuggestionGenerator()
+
+# Initialize the suggestion generator with distilgpt2
+suggestion_generator = SuggestionGenerator("distilgpt2")
+
+# Test the model to make sure it's working
+test_result = suggestion_generator.test_model()
+print(f"Model test result: {test_result}")
+
+# If the model didn't load, use the fallback responses
+if not suggestion_generator.model_loaded:
+    print("Model failed to load, using fallback responses...")
+    # The SuggestionGenerator class has built-in fallback responses
 
 # Initialize Whisper model (using the smallest model for speed)
 try:
     whisper_model = whisper.load_model("tiny")
     whisper_loaded = True
-except Exception as e:
-    print(f"Warning: Could not load Whisper model: {e}")
+except Exception:
     whisper_loaded = False
 
 
@@ -25,7 +35,22 @@ def format_person_display(person):
 def get_people_choices():
     """Get formatted choices for the people dropdown."""
     people = social_graph.get_people_list()
-    return {format_person_display(person): person["id"] for person in people}
+    choices = {}
+    for person in people:
+        display_name = format_person_display(person)
+        person_id = person["id"]
+        choices[display_name] = person_id
+    return choices
+
+
+def get_topics_for_person(person_id):
+    """Get topics for a specific person."""
+    if not person_id:
+        return []
+
+    person_context = social_graph.get_person_context(person_id)
+    topics = person_context.get("topics", [])
+    return topics
 
 
 def get_suggestion_categories():
@@ -38,49 +63,164 @@ def get_suggestion_categories():
 def on_person_change(person_id):
     """Handle person selection change."""
     if not person_id:
-        return "", ""
+        return "", "", []
 
     person_context = social_graph.get_person_context(person_id)
-    context_info = (
-        f"**{person_context.get('name', '')}** - {person_context.get('role', '')}\n\n"
-    )
-    context_info += f"**Topics:** {', '.join(person_context.get('topics', []))}\n\n"
-    context_info += f"**Frequency:** {person_context.get('frequency', '')}\n\n"
-    context_info += f"**Context:** {person_context.get('context', '')}"
+
+    # Create a more user-friendly context display
+    name = person_context.get("name", "")
+    role = person_context.get("role", "")
+    frequency = person_context.get("frequency", "")
+    context_text = person_context.get("context", "")
+
+    context_info = f"""### I'm talking to: {name}
+**Relationship:** {role}
+**How often we talk:** {frequency}
+
+**Our relationship:** {context_text}
+"""
 
     # Get common phrases for this person
     phrases = person_context.get("common_phrases", [])
     phrases_text = "\n\n".join(phrases)
 
-    return context_info, phrases_text
+    # Get topics for this person
+    topics = person_context.get("topics", [])
+
+    return context_info, phrases_text, topics
 
 
-def generate_suggestions(person_id, user_input, suggestion_type):
+def generate_suggestions(person_id, user_input, suggestion_type, selected_topic=None):
     """Generate suggestions based on the selected person and user input."""
+    print(
+        f"Generating suggestions with: person_id={person_id}, user_input={user_input}, suggestion_type={suggestion_type}, selected_topic={selected_topic}"
+    )
+
     if not person_id:
-        return "Please select a person first."
+        print("No person_id provided")
+        return "Please select who you're talking to first."
 
     person_context = social_graph.get_person_context(person_id)
-
-    # If suggestion type is "model", use the language model
+    print(f"Person context: {person_context}")
+
+    # Try to infer conversation type if user input is provided
+    inferred_category = None
+    if user_input and suggestion_type == "auto_detect":
+        # Simple keyword matching for now - could be enhanced with ML
+        user_input_lower = user_input.lower()
+        if any(
+            word in user_input_lower
+            for word in ["hi", "hello", "morning", "afternoon", "evening"]
+        ):
+            inferred_category = "greetings"
+        elif any(
+            word in user_input_lower
+            for word in ["feel", "tired", "happy", "sad", "frustrated"]
+        ):
+            inferred_category = "emotions"
+        elif any(
+            word in user_input_lower
+            for word in ["need", "want", "help", "water", "toilet", "loo"]
+        ):
+            inferred_category = "needs"
+        elif any(
+            word in user_input_lower
+            for word in ["what", "how", "when", "where", "why", "did"]
+        ):
+            inferred_category = "questions"
+        elif any(
+            word in user_input_lower
+            for word in ["remember", "used to", "back then", "when we"]
+        ):
+            inferred_category = "reminiscing"
+        elif any(
+            word in user_input_lower
+            for word in ["code", "program", "software", "app", "tech"]
+        ):
+            inferred_category = "tech_talk"
+        elif any(
+            word in user_input_lower
+            for word in ["plan", "schedule", "appointment", "tomorrow", "later"]
+        ):
+            inferred_category = "organization"
+
+    # Add topic to context if selected
+    if selected_topic:
+        person_context["selected_topic"] = selected_topic
+
+    # Format the output with multiple suggestions
+    result = ""
+
+    # If suggestion type is "model", use the language model for multiple suggestions
     if suggestion_type == "model":
-        suggestion = suggestion_generator.generate_suggestion(
-            person_context, user_input
-        )
-        return suggestion
+        print("Using model for suggestions")
+        # Generate 3 different suggestions
+        suggestions = []
+        for i in range(3):
+            print(f"Generating suggestion {i+1}/3")
+            try:
+                suggestion = suggestion_generator.generate_suggestion(
+                    person_context, user_input, temperature=0.7
+                )
+                print(f"Generated suggestion: {suggestion}")
+                suggestions.append(suggestion)
+            except Exception as e:
+                print(f"Error generating suggestion: {e}")
+                suggestions.append("Error generating suggestion")
+
+        result = "### AI-Generated Responses:\n\n"
+        for i, suggestion in enumerate(suggestions, 1):
+            result += f"{i}. {suggestion}\n\n"
+
+        print(f"Final result: {result}")
 
     # If suggestion type is "common_phrases", use the person's common phrases
     elif suggestion_type == "common_phrases":
         phrases = social_graph.get_relevant_phrases(person_id, user_input)
-        return "\n\n".join(phrases)
+        result = "### My Common Phrases with this Person:\n\n"
+        for i, phrase in enumerate(phrases, 1):
+            result += f"{i}. {phrase}\n\n"
+
+    # If suggestion type is "auto_detect", use the inferred category or default to model
+    elif suggestion_type == "auto_detect":
+        print(f"Auto-detect mode, inferred category: {inferred_category}")
+        if inferred_category:
+            utterances = social_graph.get_common_utterances(inferred_category)
+            print(f"Got utterances for category {inferred_category}: {utterances}")
+            result = f"### Auto-detected category: {inferred_category.replace('_', ' ').title()}\n\n"
+            for i, utterance in enumerate(utterances, 1):
+                result += f"{i}. {utterance}\n\n"
+        else:
+            print("No category inferred, falling back to model")
+            # Fall back to model if we couldn't infer a category
+            try:
+                suggestion = suggestion_generator.generate_suggestion(
+                    person_context, user_input
+                )
+                print(f"Generated fallback suggestion: {suggestion}")
+                result = "### AI-Generated Response (no category detected):\n\n"
+                result += f"1. {suggestion}\n\n"
+            except Exception as e:
+                print(f"Error generating fallback suggestion: {e}")
+                result = "### Could not generate a response:\n\n"
+                result += "1. Sorry, I couldn't generate a suggestion at this time.\n\n"
 
     # If suggestion type is a category from common_utterances
     elif suggestion_type in get_suggestion_categories():
+        print(f"Using category: {suggestion_type}")
         utterances = social_graph.get_common_utterances(suggestion_type)
-        return "\n\n".join(utterances)
+        print(f"Got utterances: {utterances}")
+        result = f"### {suggestion_type.replace('_', ' ').title()} Phrases:\n\n"
+        for i, utterance in enumerate(utterances, 1):
+            result += f"{i}. {utterance}\n\n"
 
     # Default fallback
-    return "No suggestions available."
+    else:
+        print(f"No handler for suggestion type: {suggestion_type}")
+        result = "No suggestions available. Please try a different option."
+
+    print(f"Returning result: {result[:100]}...")
+    return result
 
 
 def transcribe_audio(audio_path):
@@ -92,40 +232,75 @@ def transcribe_audio(audio_path):
         # Transcribe the audio
         result = whisper_model.transcribe(audio_path)
         return result["text"]
-    except Exception as e:
-        print(f"Error transcribing audio: {e}")
+    except Exception:
         return "Could not transcribe audio. Please try again."
 
 
 # Create the Gradio interface
-with gr.Blocks(title="AAC Social Graph Assistant") as demo:
-    gr.Markdown("# AAC Social Graph Assistant")
+with gr.Blocks(title="Will's AAC Communication Aid") as demo:
+    gr.Markdown("# Will's AAC Communication Aid")
     gr.Markdown(
-        "Select who you're talking to, and get contextually relevant suggestions."
+        """
+        This demo simulates an AAC system from Will's perspective (a 38-year-old with MND).
+
+        **How to use this demo:**
+        1. Select who you (Will) are talking to from the dropdown
+        2. Optionally select a conversation topic
+        3. Enter or record what the other person said to you
+        4. Get suggested responses based on your relationship with that person
+        """
     )
 
+    # Display information about Will
+    with gr.Accordion("About Me (Will)", open=False):
+        gr.Markdown(
+            """
+            I'm Will, a 38-year-old computer programmer from Manchester with MND (diagnosed 5 months ago).
+            I live with my wife Emma and two children (Mabel, 4 and Billy, 7).
+            Originally from South East London, I enjoy technology, Manchester United, and have fond memories of cycling and hiking.
+            I'm increasingly using this AAC system as my speech becomes more difficult.
+            """
+        )
+
     with gr.Row():
        with gr.Column(scale=1):
            # Person selection
            person_dropdown = gr.Dropdown(
-                choices=get_people_choices(), label="Who are you talking to?"
+                choices=get_people_choices(),
+                label="I'm talking to:",
+                info="Select who you (Will) are talking to",
+            )
+
+            # Get topics for the selected person
+            def get_filtered_topics(person_id):
+                if not person_id:
+                    return []
+                person_context = social_graph.get_person_context(person_id)
+                return person_context.get("topics", [])
+
+            # Topic selection dropdown
+            topic_dropdown = gr.Dropdown(
+                choices=[],  # Will be populated when a person is selected
+                label="Topic (optional):",
+                info="Select a topic relevant to this person",
+                allow_custom_value=True,
            )
 
            # Context display
-            context_display = gr.Markdown(label="Context Information")
+            context_display = gr.Markdown(label="Relationship Context")
 
            # User input section
            with gr.Row():
                user_input = gr.Textbox(
-                    label="Your current conversation (optional)",
-                    placeholder="Type or paste current conversation context here...",
+                    label="What they said to me:",
+                    placeholder='Examples:\n"How was your physio session today?"\n"The kids are asking if you want to watch a movie tonight"\n"I\'ve been looking at that new AAC software you mentioned"',
                    lines=3,
                )
 
            # Audio input
            with gr.Row():
                audio_input = gr.Audio(
-                    label="Or record your conversation",
+                    label="Or record what they said:",
                    type="filepath",
                    sources=["microphone"],
                )
@@ -133,39 +308,56 @@ with gr.Blocks(title="AAC Social Graph Assistant") as demo:
 
            # Suggestion type selection
            suggestion_type = gr.Radio(
-                choices=["model", "common_phrases"] + get_suggestion_categories(),
-                value="model",
-                label="Suggestion Type",
+                choices=[
+                    "auto_detect",
+                    "model",
+                    "common_phrases",
+                ]
+                + get_suggestion_categories(),
+                value="model",  # Default to model for better results
+                label="How should I respond?",
+                info="Choose what kind of responses you want (model = AI-generated)",
            )
 
            # Generate button
-            generate_btn = gr.Button("Generate Suggestions", variant="primary")
+            generate_btn = gr.Button("Generate My Responses", variant="primary")
 
        with gr.Column(scale=1):
            # Common phrases
            common_phrases = gr.Textbox(
-                label="Common Phrases",
-                placeholder="Common phrases will appear here...",
+                label="My Common Phrases",
+                placeholder="Common phrases I often use with this person will appear here...",
                lines=5,
            )
 
            # Suggestions output
-            suggestions_output = gr.Textbox(
-                label="Suggested Phrases",
-                placeholder="Suggestions will appear here...",
-                lines=8,
+            suggestions_output = gr.Markdown(
+                label="My Suggested Responses",
+                value="Suggested responses will appear here...",
            )
 
    # Set up event handlers
+    def handle_person_change(person_id):
+        """Handle person selection change and update UI elements."""
+        context_info, phrases_text, _ = on_person_change(person_id)
+
+        # Get topics for this person
+        topics = get_filtered_topics(person_id)
+
+        # Update the context, phrases, and topic dropdown
+        return context_info, phrases_text, gr.update(choices=topics)
+
+    # Set up the person change event
    person_dropdown.change(
-        on_person_change,
+        handle_person_change,
        inputs=[person_dropdown],
-        outputs=[context_display, common_phrases],
+        outputs=[context_display, common_phrases, topic_dropdown],
    )
 
+    # Set up the generate button click event
    generate_btn.click(
        generate_suggestions,
-        inputs=[person_dropdown, user_input, suggestion_type],
+        inputs=[person_dropdown, user_input, suggestion_type, topic_dropdown],
        outputs=[suggestions_output],
    )
 
demo.py CHANGED
@@ -1,27 +1,40 @@
-from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import pipeline
 import json
 
 # Load model
-model_name = "google/flan-t5-base"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-rag_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
+
+# Use a simpler approach with a pre-built pipeline
+rag_pipeline = pipeline("text-generation", model="distilgpt2")
 
 # Load KG
 with open("social_graph.json", "r") as f:
     kg = json.load(f)
 
 # Build context
-person = kg["people"]["bob"]
-context = f"Bob is the user's son. They talk about football weekly. Last conversation was about coaching changes."
+person = kg["people"]["billy"]  # Using Billy instead of Bob
+context = person["context"]
 
 # User input
-query = "What should I say to Bob?"
+query = "What should I say to Billy?"
 
 # RAG-style prompt
-prompt = f"""Context: {context}
-User wants to say something appropriate to Bob. Suggest a phrase:"""
+prompt = """I am Will, a 38-year-old father with MND (Motor Neuron Disease). I have a 7-year-old son named Billy who loves Manchester United football.
+
+Billy just asked me: "Dad, did you see the United match last night?"
+
+My response to Billy:"""
 
 # Generate
-response = rag_pipeline(prompt, max_length=50)
-print(response[0]["generated_text"])
+response = rag_pipeline(
+    prompt,
+    max_length=100,  # Longer output
+    temperature=0.9,  # More creative
+    do_sample=True,
+    num_return_sequences=1,
+    top_p=0.92,  # More focused sampling
+    top_k=50,  # Limit vocabulary
+)
+print("Generated response:")
+# For text-generation models, we need to extract just the generated part (not the prompt)
+generated_text = response[0]["generated_text"][len(prompt) :]
+print(generated_text)
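The final comment in demo.py matters because `text-generation` pipelines (unlike the `text2text-generation` pipeline used before) return the prompt concatenated with the continuation, so the script slices it off with `[len(prompt):]`. The same step in isolation, using plain strings so no model is needed:

```python
def strip_prompt(full_output, prompt):
    """Drop the echoed prompt so only the model's continuation remains."""
    if full_output.startswith(prompt):
        return full_output[len(prompt):]
    return full_output  # defensive: output may not always echo the prompt


prompt = "My response to Billy:"
full_output = prompt + " Yes, what a match that was!"
print(strip_prompt(full_output, prompt))  # " Yes, what a match that was!"
```

The transformers text-generation pipeline also accepts `return_full_text=False` to get only the continuation, which would make this slicing unnecessary.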
social_graph.json CHANGED
@@ -1,4 +1,11 @@
 {
+  "aac_user": {
+    "name": "Will",
+    "age": 38,
+    "location": "Manchester",
+    "background": "Originally from South East London, now living in Manchester with wife Emma and two children (Mabel, 4 and Billy, 7). Diagnosed with MND 5 months ago. Works as a computer programmer with accommodations for condition. Enjoys technology, Manchester United, and has fond memories of cycling and hiking in the Peak District and Lake District. Was active in Scouts growing up.",
+    "communication_needs": "Increasingly using AAC as speech becomes more difficult. Voice banking in progress. Prefers British English expressions and technical terminology when appropriate. When talking to familiar people has a tendency to swear but in a friendly way."
+  },
   "people": {
     "emma": {
       "name": "Emma",
utils.py CHANGED
@@ -3,186 +3,263 @@ import random
 from typing import Dict, List, Any, Optional, Tuple
 from sentence_transformers import SentenceTransformer
 import numpy as np
-from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
 
 class SocialGraphManager:
     """Manages the social graph and provides context for the AAC system."""
-
     def __init__(self, graph_path: str = "social_graph.json"):
         """Initialize the social graph manager.
-
         Args:
             graph_path: Path to the social graph JSON file
         """
         self.graph_path = graph_path
         self.graph = self._load_graph()
-
         # Initialize sentence transformer for semantic matching
         try:
-            self.sentence_model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
             self.embeddings_cache = {}
             self._initialize_embeddings()
         except Exception as e:
-            print(f"Warning: Could not load sentence transformer model: {e}")
             self.sentence_model = None
-
     def _load_graph(self) -> Dict[str, Any]:
         """Load the social graph from the JSON file."""
         try:
             with open(self.graph_path, "r") as f:
                 return json.load(f)
-        except Exception as e:
-            print(f"Error loading social graph: {e}")
             return {"people": {}, "places": [], "topics": []}
-
     def _initialize_embeddings(self):
         """Initialize embeddings for topics and phrases in the social graph."""
         if not self.sentence_model:
             return
-
         # Create embeddings for topics
         topics = self.graph.get("topics", [])
         for topic in topics:
             if topic not in self.embeddings_cache:
                 self.embeddings_cache[topic] = self.sentence_model.encode(topic)
-
         # Create embeddings for common phrases
         for person_id, person_data in self.graph.get("people", {}).items():
             for phrase in person_data.get("common_phrases", []):
                 if phrase not in self.embeddings_cache:
                     self.embeddings_cache[phrase] = self.sentence_model.encode(phrase)
-
         # Create embeddings for common utterances
         for category, utterances in self.graph.get("common_utterances", {}).items():
             for utterance in utterances:
                 if utterance not in self.embeddings_cache:
-                    self.embeddings_cache[utterance] = self.sentence_model.encode(utterance)
-
     def get_people_list(self) -> List[Dict[str, str]]:
         """Get a list of people from the social graph with their names and roles."""
         people = []
         for person_id, person_data in self.graph.get("people", {}).items():
-            people.append({
-                "id": person_id,
-                "name": person_data.get("name", person_id),
-                "role": person_data.get("role", "")
-            })
         return people
-
     def get_person_context(self, person_id: str) -> Dict[str, Any]:
         """Get context information for a specific person."""
         if person_id not in self.graph.get("people", {}):
             return {}
-
-        return self.graph["people"][person_id]
-
-    def get_relevant_phrases(self, person_id: str, user_input: Optional[str] = None) -> List[str]:
         """Get relevant phrases for a specific person based on user input."""
         if person_id not in self.graph.get("people", {}):
             return []
-
         person_data = self.graph["people"][person_id]
         phrases = person_data.get("common_phrases", [])
-
         # If no user input, return random phrases
         if not user_input or not self.sentence_model:
             return random.sample(phrases, min(3, len(phrases)))
-
         # Use semantic search to find relevant phrases
         user_embedding = self.sentence_model.encode(user_input)
         phrase_scores = []
-
         for phrase in phrases:
             if phrase in self.embeddings_cache:
                 phrase_embedding = self.embeddings_cache[phrase]
             else:
                 phrase_embedding = self.sentence_model.encode(phrase)
                 self.embeddings_cache[phrase] = phrase_embedding
-
             similarity = np.dot(user_embedding, phrase_embedding) / (
                 np.linalg.norm(user_embedding) * np.linalg.norm(phrase_embedding)
             )
             phrase_scores.append((phrase, similarity))
-
         # Sort by similarity score and return top phrases
         phrase_scores.sort(key=lambda x: x[1], reverse=True)
         return [phrase for phrase, _ in phrase_scores[:3]]
-
     def get_common_utterances(self, category: Optional[str] = None) -> List[str]:
         """Get common utterances from the social graph, optionally filtered by category."""
         utterances = []
-
         if "common_utterances" not in self.graph:
             return utterances
-
         if category and category in self.graph["common_utterances"]:
             return self.graph["common_utterances"][category]
-
         # If no category specified, return a sample from each category
         for category_utterances in self.graph["common_utterances"].values():
-            utterances.extend(random.sample(category_utterances,
-                min(2, len(category_utterances))))
-
         return utterances
 
 class SuggestionGenerator:
     """Generates contextual suggestions for the AAC system."""
-
-    def __init__(self, model_name: str = "google/flan-t5-base"):
         """Initialize the suggestion generator.
-
         Args:
             model_name: Name of the HuggingFace model to use
         """
         try:
-            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-            self.generator = pipeline("text2text-generation",
-                model=self.model,
-                tokenizer=self.tokenizer)
             self.model_loaded = True
         except Exception as e:
-            print(f"Warning: Could not load model {model_name}: {e}")
             self.model_loaded = False
-
-    def generate_suggestion(self,
-        person_context: Dict[str, Any],
-        user_input: Optional[str] = None,
-        max_length: int = 50) -> str:
         """Generate a contextually appropriate suggestion.
-
         Args:
             person_context: Context information about the person
             user_input: Optional user input to consider
             max_length: Maximum length of the generated suggestion
-
         Returns:
             A generated suggestion string
         """
         if not self.model_loaded:
-            return "Model not loaded. Please check your installation."
-
         # Extract context information
166
  name = person_context.get("name", "")
167
  role = person_context.get("role", "")
168
  topics = ", ".join(person_context.get("topics", []))
169
  context = person_context.get("context", "")
170
-
 
171
  # Build prompt
172
- prompt = f"""Context: {context}
173
- Person: {name} ({role})
174
- Topics of interest: {topics}
175
  """
176
-
 
 
 
 
 
 
 
 
 
177
  if user_input:
178
- prompt += f"Current conversation: {user_input}\n"
179
-
180
- prompt += "Generate an appropriate phrase to say to this person:"
181
-
182
  # Generate suggestion
183
  try:
184
- response = self.generator(prompt, max_length=max_length)
185
- return response[0]["generated_text"]
 
 
 
 
 
 
 
 
 
 
 
186
  except Exception as e:
187
  print(f"Error generating suggestion: {e}")
188
  return "Could not generate a suggestion. Please try again."
 
  from typing import Dict, List, Any, Optional, Tuple
  from sentence_transformers import SentenceTransformer
  import numpy as np
+
+ from transformers import pipeline
+

  class SocialGraphManager:
      """Manages the social graph and provides context for the AAC system."""
+
      def __init__(self, graph_path: str = "social_graph.json"):
          """Initialize the social graph manager.
+
          Args:
              graph_path: Path to the social graph JSON file
          """
          self.graph_path = graph_path
          self.graph = self._load_graph()
+
          # Initialize sentence transformer for semantic matching
          try:
+             self.sentence_model = SentenceTransformer(
+                 "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+             )
              self.embeddings_cache = {}
              self._initialize_embeddings()
          except Exception as e:
              self.sentence_model = None
+
      def _load_graph(self) -> Dict[str, Any]:
          """Load the social graph from the JSON file."""
          try:
              with open(self.graph_path, "r") as f:
                  return json.load(f)
+         except Exception:
              return {"people": {}, "places": [], "topics": []}
+
      def _initialize_embeddings(self):
          """Initialize embeddings for topics and phrases in the social graph."""
          if not self.sentence_model:
              return
+
          # Create embeddings for topics
          topics = self.graph.get("topics", [])
          for topic in topics:
              if topic not in self.embeddings_cache:
                  self.embeddings_cache[topic] = self.sentence_model.encode(topic)
+
          # Create embeddings for common phrases
          for person_id, person_data in self.graph.get("people", {}).items():
              for phrase in person_data.get("common_phrases", []):
                  if phrase not in self.embeddings_cache:
                      self.embeddings_cache[phrase] = self.sentence_model.encode(phrase)
+
          # Create embeddings for common utterances
          for category, utterances in self.graph.get("common_utterances", {}).items():
              for utterance in utterances:
                  if utterance not in self.embeddings_cache:
+                     self.embeddings_cache[utterance] = self.sentence_model.encode(
+                         utterance
+                     )
+
      def get_people_list(self) -> List[Dict[str, str]]:
          """Get a list of people from the social graph with their names and roles."""
          people = []
          for person_id, person_data in self.graph.get("people", {}).items():
+             people.append(
+                 {
+                     "id": person_id,
+                     "name": person_data.get("name", person_id),
+                     "role": person_data.get("role", ""),
+                 }
+             )
          return people
+
      def get_person_context(self, person_id: str) -> Dict[str, Any]:
          """Get context information for a specific person."""
+         # Check if the person_id contains a display name (e.g., "Emma (wife)")
+         # and try to extract the actual ID
+         if person_id not in self.graph.get("people", {}):
+             # Try to find the person by name
+             for pid, pdata in self.graph.get("people", {}).items():
+                 name = pdata.get("name", "")
+                 role = pdata.get("role", "")
+                 if f"{name} ({role})" == person_id:
+                     person_id = pid
+                     break
+
+         # If still not found, return an empty dict
          if person_id not in self.graph.get("people", {}):
              return {}
+
+         person_data = self.graph["people"][person_id]
+         return person_data
+
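The display-name fallback in `get_person_context` can be exercised on its own; a minimal sketch with toy graph data (`resolve_person_id` is a hypothetical helper for illustration, not part of `utils.py`):

```python
def resolve_person_id(person_id, people):
    # Accept either a raw graph ID or a "Name (role)" display string,
    # mirroring the fallback lookup in get_person_context.
    if person_id not in people:
        for pid, pdata in people.items():
            display = f"{pdata.get('name', '')} ({pdata.get('role', '')})"
            if display == person_id:
                return pid
    return person_id if person_id in people else None

# Toy stand-in for the "people" section of social_graph.json
people = {
    "emma": {"name": "Emma", "role": "wife"},
    "billy": {"name": "Billy", "role": "son"},
}
print(resolve_person_id("Emma (wife)", people))  # -> emma
print(resolve_person_id("billy", people))        # -> billy
print(resolve_person_id("unknown", people))      # -> None
```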
+     def get_relevant_phrases(
+         self, person_id: str, user_input: Optional[str] = None
+     ) -> List[str]:
          """Get relevant phrases for a specific person based on user input."""
          if person_id not in self.graph.get("people", {}):
              return []
+
          person_data = self.graph["people"][person_id]
          phrases = person_data.get("common_phrases", [])
+
          # If no user input, return random phrases
          if not user_input or not self.sentence_model:
              return random.sample(phrases, min(3, len(phrases)))
+
          # Use semantic search to find relevant phrases
          user_embedding = self.sentence_model.encode(user_input)
          phrase_scores = []
+
          for phrase in phrases:
              if phrase in self.embeddings_cache:
                  phrase_embedding = self.embeddings_cache[phrase]
              else:
                  phrase_embedding = self.sentence_model.encode(phrase)
                  self.embeddings_cache[phrase] = phrase_embedding
+
              similarity = np.dot(user_embedding, phrase_embedding) / (
                  np.linalg.norm(user_embedding) * np.linalg.norm(phrase_embedding)
              )
              phrase_scores.append((phrase, similarity))
+
          # Sort by similarity score and return top phrases
          phrase_scores.sort(key=lambda x: x[1], reverse=True)
          return [phrase for phrase, _ in phrase_scores[:3]]
+
      def get_common_utterances(self, category: Optional[str] = None) -> List[str]:
          """Get common utterances from the social graph, optionally filtered by category."""
          utterances = []
+
          if "common_utterances" not in self.graph:
              return utterances
+
          if category and category in self.graph["common_utterances"]:
              return self.graph["common_utterances"][category]
+
          # If no category specified, return a sample from each category
          for category_utterances in self.graph["common_utterances"].values():
+             utterances.extend(
+                 random.sample(category_utterances, min(2, len(category_utterances)))
+             )
+
          return utterances

+
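`get_relevant_phrases` ranks cached embeddings by cosine similarity to the user's input. The ranking step in isolation, as a toy sketch with hand-made 2-D vectors standing in for SentenceTransformer output (`rank_by_similarity` is a hypothetical helper, not part of `utils.py`):

```python
import numpy as np

def rank_by_similarity(user_vec, phrase_vecs, top_n=3):
    """Rank phrases by cosine similarity to the user's input vector."""
    scores = []
    for phrase, vec in phrase_vecs.items():
        # Same cosine-similarity formula as get_relevant_phrases
        sim = float(
            np.dot(user_vec, vec) / (np.linalg.norm(user_vec) * np.linalg.norm(vec))
        )
        scores.append((phrase, sim))
    scores.sort(key=lambda x: x[1], reverse=True)
    return [phrase for phrase, _ in scores[:top_n]]

# Toy 2-D "embeddings" in place of real sentence-transformer vectors
phrase_vecs = {
    "How was the football?": np.array([1.0, 0.1]),
    "I need my medication": np.array([0.0, 1.0]),
    "Tell me about school": np.array([0.9, 0.3]),
}
user_vec = np.array([1.0, 0.0])
print(rank_by_similarity(user_vec, phrase_vecs, top_n=2))
# -> ['How was the football?', 'Tell me about school']
```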
  class SuggestionGenerator:
      """Generates contextual suggestions for the AAC system."""
+
+     def __init__(self, model_name: str = "distilgpt2"):
          """Initialize the suggestion generator.
+
          Args:
              model_name: Name of the HuggingFace model to use
          """
+         self.model_name = model_name
+         self.model_loaded = False
+
          try:
+             print(f"Loading model: {model_name}")
+             # Use a simpler approach with a pre-built pipeline
+             self.generator = pipeline("text-generation", model=model_name)
              self.model_loaded = True
+             print(f"Model loaded successfully: {model_name}")
          except Exception as e:
+             print(f"Error loading model: {e}")
              self.model_loaded = False
+
+         # Fallback responses if the model fails to load or generate
+         self.fallback_responses = [
+             "I'm not sure how to respond to that.",
+             "That's interesting. Tell me more.",
+             "I'd like to talk about that further.",
+             "I appreciate you sharing that with me.",
+         ]
+
+     def test_model(self) -> str:
+         """Test if the model is working correctly."""
+         if not self.model_loaded:
+             return "Model not loaded"
+
+         try:
+             test_prompt = "I am Will. My son Billy asked about football. I respond:"
+             print(f"Testing model with prompt: {test_prompt}")
+             response = self.generator(test_prompt, max_length=30, do_sample=True)
+             result = response[0]["generated_text"][len(test_prompt) :]
+             print(f"Test response: {result}")
+             return f"Model test successful: {result}"
+         except Exception as e:
+             print(f"Error testing model: {e}")
+             return f"Model test failed: {str(e)}"
+
+     def generate_suggestion(
+         self,
+         person_context: Dict[str, Any],
+         user_input: Optional[str] = None,
+         max_length: int = 50,
+         temperature: float = 0.7,
+     ) -> str:
          """Generate a contextually appropriate suggestion.
+
          Args:
              person_context: Context information about the person
              user_input: Optional user input to consider
              max_length: Maximum length of the generated suggestion
+             temperature: Controls randomness in generation (higher = more random)
+
          Returns:
              A generated suggestion string
          """
          if not self.model_loaded:
+             # Use fallback responses if the model isn't loaded
+             import random
+
+             print("Model not loaded, using fallback responses")
+             return random.choice(self.fallback_responses)
+
          # Extract context information
          name = person_context.get("name", "")
          role = person_context.get("role", "")
          topics = ", ".join(person_context.get("topics", []))
          context = person_context.get("context", "")
+         selected_topic = person_context.get("selected_topic", "")
+
          # Build prompt
+         prompt = f"""I am Will, a person with MND (Motor Neurone Disease).
+ I'm talking to {name}, who is my {role}.
  """
+
+         if context:
+             prompt += f"Context: {context}\n"
+
+         if topics:
+             prompt += f"Topics of interest: {topics}\n"
+
+         if selected_topic:
+             prompt += f"We're currently talking about: {selected_topic}\n"
+
          if user_input:
+             prompt += f'\n{name} just said to me: "{user_input}"\n'
+
+         prompt += "\nMy response:"
+
          # Generate suggestion
          try:
+             print(f"Generating suggestion with prompt: {prompt}")
+             response = self.generator(
+                 prompt,
+                 max_length=len(prompt.split()) + max_length,
+                 temperature=temperature,
+                 do_sample=True,
+                 top_p=0.92,
+                 top_k=50,
+             )
+             # Extract only the generated part, not the prompt
+             result = response[0]["generated_text"][len(prompt) :]
+             print(f"Generated response: {result}")
+             return result.strip()
          except Exception as e:
              print(f"Error generating suggestion: {e}")
              return "Could not generate a suggestion. Please try again."
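Causal text-generation pipelines return the prompt followed by the continuation, which is why `generate_suggestion` slices with `len(prompt)`. A sketch of that extraction step, with a stub standing in for the real transformers pipeline (`fake_generator` is hypothetical, for illustration only):

```python
def fake_generator(prompt, **kwargs):
    # Stub mimicking the transformers text-generation pipeline's output shape:
    # a list of dicts whose "generated_text" echoes the prompt plus new text.
    return [{"generated_text": prompt + " I'd love to hear about the match."}]

prompt = 'Billy just said to me: "We won!"\n\nMy response:'
response = fake_generator(prompt, max_length=50)
# Slice off the echoed prompt to keep only the model's continuation
result = response[0]["generated_text"][len(prompt):].strip()
print(result)  # -> I'd love to hear about the match.
```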