first attempt
Browse files- README-HF.md +15 -11
- README.md +20 -5
- app.py +238 -46
- demo.py +25 -12
- social_graph.json +7 -0
- utils.py +144 -67
README-HF.md
CHANGED
@@ -1,16 +1,19 @@
|
|
1 |
-
# AAC
|
2 |
|
3 |
-
An Augmentative and Alternative Communication (AAC) system that uses a social graph to provide contextually relevant suggestions for
|
4 |
|
5 |
## About
|
6 |
|
7 |
-
This demo
|
8 |
|
9 |
## Features
|
10 |
|
11 |
-
- **Person-Specific Suggestions**: Select who you
|
|
|
12 |
- **Context-Aware**: Uses a social graph to understand relationships and common topics
|
13 |
-
- **
|
|
|
|
|
14 |
- **British Context**: Designed with British English and NHS healthcare context in mind
|
15 |
- **MND-Specific**: Tailored for the needs of someone with Motor Neurone Disease
|
16 |
- **Expandable**: Easily improve the system by enhancing the social graph JSON file
|
@@ -30,12 +33,13 @@ The current social graph represents a British person with MND who:
|
|
30 |
|
31 |
## How to Use
|
32 |
|
33 |
-
1. Select
|
34 |
-
2.
|
35 |
-
3.
|
36 |
-
4.
|
37 |
-
5.
|
38 |
-
6.
|
|
|
39 |
|
40 |
## Customizing the Social Graph
|
41 |
|
|
|
1 |
+
# Will's AAC Communication Aid
|
2 |
|
3 |
+
An Augmentative and Alternative Communication (AAC) system that uses a social graph to provide contextually relevant suggestions for Will, a user with Motor Neurone Disease (MND).
|
4 |
|
5 |
## About
|
6 |
|
7 |
+
This demo simulates an AAC system from Will's perspective (a 38-year-old with MND). The system allows Will to select who he's talking to, optionally choose a conversation topic, and get appropriate responses based on what the other person has said. All suggestions are tailored to the relationship and conversation context, using British English and NHS healthcare terminology where appropriate.
|
8 |
|
9 |
## Features
|
10 |
|
11 |
+
- **Person-Specific Suggestions**: Select who you (Will) are talking to and get tailored responses
|
12 |
+
- **Topic Selection**: Choose conversation topics relevant to your relationship
|
13 |
- **Context-Aware**: Uses a social graph to understand relationships and common topics
|
14 |
+
- **Speech Recognition**: Record what others have said to you and have it transcribed
|
15 |
+
- **Auto-Detection**: Automatically detect conversation type from what others say
|
16 |
+
- **Multiple Response Types**: Get AI-generated responses, common phrases, or category-specific utterances
|
17 |
- **British Context**: Designed with British English and NHS healthcare context in mind
|
18 |
- **MND-Specific**: Tailored for the needs of someone with Motor Neurone Disease
|
19 |
- **Expandable**: Easily improve the system by enhancing the social graph JSON file
|
|
|
33 |
|
34 |
## How to Use
|
35 |
|
36 |
+
1. Select who you (Will) are talking to from the dropdown menu
|
37 |
+
2. Optionally select a conversation topic
|
38 |
+
3. View the relationship context information
|
39 |
+
4. Enter what the other person said to you, or record audio
|
40 |
+
5. If you record audio, click "Transcribe" to convert it to text
|
41 |
+
6. Choose how you want to respond (auto-detect, AI-generated, common phrases, etc.)
|
42 |
+
7. Click "Generate My Responses" to get contextually relevant suggestions
|
43 |
|
44 |
## Customizing the Social Graph
|
45 |
|
README.md
CHANGED
@@ -37,16 +37,31 @@ python app.py
|
|
37 |
|
38 |
## How It Works
|
39 |
|
40 |
-
1. **Social Graph**: The system uses a JSON-based social graph (`social_graph.json`) that contains information about
|
|
|
|
|
|
|
41 |
|
42 |
-
2. **Context Retrieval**: When you select
|
43 |
|
44 |
-
3. **
|
|
|
|
|
45 |
- A language model (Flan-T5)
|
46 |
-
- Common phrases
|
47 |
- General utterance categories (greetings, needs, emotions, questions)
|
48 |
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
## Customizing the Social Graph
|
52 |
|
|
|
37 |
|
38 |
## How It Works
|
39 |
|
40 |
+
1. **Social Graph**: The system uses a JSON-based social graph (`social_graph.json`) that contains information about:
|
41 |
+
- Will (the AAC user) - a 38-year-old with MND
|
42 |
+
- People in Will's life (family, healthcare providers, friends, colleagues)
|
43 |
+
- Relationships, common topics, and phrases
|
44 |
|
45 |
+
2. **Context Retrieval**: When you select who you are (as someone talking to Will), the system retrieves relevant context information from the social graph.
|
46 |
|
47 |
+
3. **Conversation Input**: You enter or record what you've said to Will in the conversation.
|
48 |
+
|
49 |
+
4. **Suggestion Generation**: Based on who you are and what you've said, the system generates appropriate responses for Will using:
|
50 |
- A language model (Flan-T5)
|
51 |
+
- Common phrases Will might say to you
|
52 |
- General utterance categories (greetings, needs, emotions, questions)
|
53 |
|
54 |
+
5. **User Interface**: The Gradio interface provides an intuitive way to simulate conversations with Will and see what an AAC system might suggest for him to say.
|
55 |
+
|
56 |
+
## How to Use
|
57 |
+
|
58 |
+
1. Select who you (Will) are talking to from the dropdown menu
|
59 |
+
2. Optionally select a conversation topic
|
60 |
+
3. View the relationship context information
|
61 |
+
4. Enter what the other person said to you, or record audio
|
62 |
+
5. If you record audio, click "Transcribe" to convert it to text
|
63 |
+
6. Choose how you want to respond (auto-detect, AI-generated, common phrases, etc.)
|
64 |
+
7. Click "Generate My Responses" to get contextually relevant suggestions
|
65 |
|
66 |
## Customizing the Social Graph
|
67 |
|
app.py
CHANGED
@@ -6,14 +6,24 @@ from utils import SocialGraphManager, SuggestionGenerator
|
|
6 |
|
7 |
# Initialize the social graph manager and suggestion generator
|
8 |
social_graph = SocialGraphManager("social_graph.json")
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
# Initialize Whisper model (using the smallest model for speed)
|
12 |
try:
|
13 |
whisper_model = whisper.load_model("tiny")
|
14 |
whisper_loaded = True
|
15 |
-
except Exception
|
16 |
-
print(f"Warning: Could not load Whisper model: {e}")
|
17 |
whisper_loaded = False
|
18 |
|
19 |
|
@@ -25,7 +35,22 @@ def format_person_display(person):
|
|
25 |
def get_people_choices():
|
26 |
"""Get formatted choices for the people dropdown."""
|
27 |
people = social_graph.get_people_list()
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
|
31 |
def get_suggestion_categories():
|
@@ -38,49 +63,164 @@ def get_suggestion_categories():
|
|
38 |
def on_person_change(person_id):
|
39 |
"""Handle person selection change."""
|
40 |
if not person_id:
|
41 |
-
return "", ""
|
42 |
|
43 |
person_context = social_graph.get_person_context(person_id)
|
44 |
-
|
45 |
-
|
46 |
-
)
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
# Get common phrases for this person
|
52 |
phrases = person_context.get("common_phrases", [])
|
53 |
phrases_text = "\n\n".join(phrases)
|
54 |
|
55 |
-
|
|
|
56 |
|
|
|
57 |
|
58 |
-
|
|
|
59 |
"""Generate suggestions based on the selected person and user input."""
|
|
|
|
|
|
|
|
|
60 |
if not person_id:
|
61 |
-
|
|
|
62 |
|
63 |
person_context = social_graph.get_person_context(person_id)
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
if suggestion_type == "model":
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
# If suggestion type is "common_phrases", use the person's common phrases
|
73 |
elif suggestion_type == "common_phrases":
|
74 |
phrases = social_graph.get_relevant_phrases(person_id, user_input)
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
# If suggestion type is a category from common_utterances
|
78 |
elif suggestion_type in get_suggestion_categories():
|
|
|
79 |
utterances = social_graph.get_common_utterances(suggestion_type)
|
80 |
-
|
|
|
|
|
|
|
81 |
|
82 |
# Default fallback
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
|
86 |
def transcribe_audio(audio_path):
|
@@ -92,40 +232,75 @@ def transcribe_audio(audio_path):
|
|
92 |
# Transcribe the audio
|
93 |
result = whisper_model.transcribe(audio_path)
|
94 |
return result["text"]
|
95 |
-
except Exception
|
96 |
-
print(f"Error transcribing audio: {e}")
|
97 |
return "Could not transcribe audio. Please try again."
|
98 |
|
99 |
|
100 |
# Create the Gradio interface
|
101 |
-
with gr.Blocks(title="AAC
|
102 |
-
gr.Markdown("# AAC
|
103 |
gr.Markdown(
|
104 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
)
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
with gr.Row():
|
108 |
with gr.Column(scale=1):
|
109 |
# Person selection
|
110 |
person_dropdown = gr.Dropdown(
|
111 |
-
choices=get_people_choices(),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
)
|
113 |
|
114 |
# Context display
|
115 |
-
context_display = gr.Markdown(label="Context
|
116 |
|
117 |
# User input section
|
118 |
with gr.Row():
|
119 |
user_input = gr.Textbox(
|
120 |
-
label="
|
121 |
-
placeholder="
|
122 |
lines=3,
|
123 |
)
|
124 |
|
125 |
# Audio input
|
126 |
with gr.Row():
|
127 |
audio_input = gr.Audio(
|
128 |
-
label="Or record
|
129 |
type="filepath",
|
130 |
sources=["microphone"],
|
131 |
)
|
@@ -133,39 +308,56 @@ with gr.Blocks(title="AAC Social Graph Assistant") as demo:
|
|
133 |
|
134 |
# Suggestion type selection
|
135 |
suggestion_type = gr.Radio(
|
136 |
-
choices=[
|
137 |
-
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
)
|
140 |
|
141 |
# Generate button
|
142 |
-
generate_btn = gr.Button("Generate
|
143 |
|
144 |
with gr.Column(scale=1):
|
145 |
# Common phrases
|
146 |
common_phrases = gr.Textbox(
|
147 |
-
label="Common Phrases",
|
148 |
-
placeholder="Common phrases will appear here...",
|
149 |
lines=5,
|
150 |
)
|
151 |
|
152 |
# Suggestions output
|
153 |
-
suggestions_output = gr.
|
154 |
-
label="Suggested
|
155 |
-
|
156 |
-
lines=8,
|
157 |
)
|
158 |
|
159 |
# Set up event handlers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
person_dropdown.change(
|
161 |
-
|
162 |
inputs=[person_dropdown],
|
163 |
-
outputs=[context_display, common_phrases],
|
164 |
)
|
165 |
|
|
|
166 |
generate_btn.click(
|
167 |
generate_suggestions,
|
168 |
-
inputs=[person_dropdown, user_input, suggestion_type],
|
169 |
outputs=[suggestions_output],
|
170 |
)
|
171 |
|
|
|
6 |
|
7 |
# Initialize the social graph manager and suggestion generator
|
8 |
social_graph = SocialGraphManager("social_graph.json")
|
9 |
+
|
10 |
+
# Initialize the suggestion generator with distilgpt2
|
11 |
+
suggestion_generator = SuggestionGenerator("distilgpt2")
|
12 |
+
|
13 |
+
# Test the model to make sure it's working
|
14 |
+
test_result = suggestion_generator.test_model()
|
15 |
+
print(f"Model test result: {test_result}")
|
16 |
+
|
17 |
+
# If the model didn't load, use the fallback responses
|
18 |
+
if not suggestion_generator.model_loaded:
|
19 |
+
print("Model failed to load, using fallback responses...")
|
20 |
+
# The SuggestionGenerator class has built-in fallback responses
|
21 |
|
22 |
# Initialize Whisper model (using the smallest model for speed)
|
23 |
try:
|
24 |
whisper_model = whisper.load_model("tiny")
|
25 |
whisper_loaded = True
|
26 |
+
except Exception:
|
|
|
27 |
whisper_loaded = False
|
28 |
|
29 |
|
|
|
35 |
def get_people_choices():
|
36 |
"""Get formatted choices for the people dropdown."""
|
37 |
people = social_graph.get_people_list()
|
38 |
+
choices = {}
|
39 |
+
for person in people:
|
40 |
+
display_name = format_person_display(person)
|
41 |
+
person_id = person["id"]
|
42 |
+
choices[display_name] = person_id
|
43 |
+
return choices
|
44 |
+
|
45 |
+
|
46 |
+
def get_topics_for_person(person_id):
|
47 |
+
"""Get topics for a specific person."""
|
48 |
+
if not person_id:
|
49 |
+
return []
|
50 |
+
|
51 |
+
person_context = social_graph.get_person_context(person_id)
|
52 |
+
topics = person_context.get("topics", [])
|
53 |
+
return topics
|
54 |
|
55 |
|
56 |
def get_suggestion_categories():
|
|
|
63 |
def on_person_change(person_id):
|
64 |
"""Handle person selection change."""
|
65 |
if not person_id:
|
66 |
+
return "", "", []
|
67 |
|
68 |
person_context = social_graph.get_person_context(person_id)
|
69 |
+
|
70 |
+
# Create a more user-friendly context display
|
71 |
+
name = person_context.get("name", "")
|
72 |
+
role = person_context.get("role", "")
|
73 |
+
frequency = person_context.get("frequency", "")
|
74 |
+
context_text = person_context.get("context", "")
|
75 |
+
|
76 |
+
context_info = f"""### I'm talking to: {name}
|
77 |
+
**Relationship:** {role}
|
78 |
+
**How often we talk:** {frequency}
|
79 |
+
|
80 |
+
**Our relationship:** {context_text}
|
81 |
+
"""
|
82 |
|
83 |
# Get common phrases for this person
|
84 |
phrases = person_context.get("common_phrases", [])
|
85 |
phrases_text = "\n\n".join(phrases)
|
86 |
|
87 |
+
# Get topics for this person
|
88 |
+
topics = person_context.get("topics", [])
|
89 |
|
90 |
+
return context_info, phrases_text, topics
|
91 |
|
92 |
+
|
93 |
+
def generate_suggestions(person_id, user_input, suggestion_type, selected_topic=None):
|
94 |
"""Generate suggestions based on the selected person and user input."""
|
95 |
+
print(
|
96 |
+
f"Generating suggestions with: person_id={person_id}, user_input={user_input}, suggestion_type={suggestion_type}, selected_topic={selected_topic}"
|
97 |
+
)
|
98 |
+
|
99 |
if not person_id:
|
100 |
+
print("No person_id provided")
|
101 |
+
return "Please select who you're talking to first."
|
102 |
|
103 |
person_context = social_graph.get_person_context(person_id)
|
104 |
+
print(f"Person context: {person_context}")
|
105 |
+
|
106 |
+
# Try to infer conversation type if user input is provided
|
107 |
+
inferred_category = None
|
108 |
+
if user_input and suggestion_type == "auto_detect":
|
109 |
+
# Simple keyword matching for now - could be enhanced with ML
|
110 |
+
user_input_lower = user_input.lower()
|
111 |
+
if any(
|
112 |
+
word in user_input_lower
|
113 |
+
for word in ["hi", "hello", "morning", "afternoon", "evening"]
|
114 |
+
):
|
115 |
+
inferred_category = "greetings"
|
116 |
+
elif any(
|
117 |
+
word in user_input_lower
|
118 |
+
for word in ["feel", "tired", "happy", "sad", "frustrated"]
|
119 |
+
):
|
120 |
+
inferred_category = "emotions"
|
121 |
+
elif any(
|
122 |
+
word in user_input_lower
|
123 |
+
for word in ["need", "want", "help", "water", "toilet", "loo"]
|
124 |
+
):
|
125 |
+
inferred_category = "needs"
|
126 |
+
elif any(
|
127 |
+
word in user_input_lower
|
128 |
+
for word in ["what", "how", "when", "where", "why", "did"]
|
129 |
+
):
|
130 |
+
inferred_category = "questions"
|
131 |
+
elif any(
|
132 |
+
word in user_input_lower
|
133 |
+
for word in ["remember", "used to", "back then", "when we"]
|
134 |
+
):
|
135 |
+
inferred_category = "reminiscing"
|
136 |
+
elif any(
|
137 |
+
word in user_input_lower
|
138 |
+
for word in ["code", "program", "software", "app", "tech"]
|
139 |
+
):
|
140 |
+
inferred_category = "tech_talk"
|
141 |
+
elif any(
|
142 |
+
word in user_input_lower
|
143 |
+
for word in ["plan", "schedule", "appointment", "tomorrow", "later"]
|
144 |
+
):
|
145 |
+
inferred_category = "organization"
|
146 |
+
|
147 |
+
# Add topic to context if selected
|
148 |
+
if selected_topic:
|
149 |
+
person_context["selected_topic"] = selected_topic
|
150 |
+
|
151 |
+
# Format the output with multiple suggestions
|
152 |
+
result = ""
|
153 |
+
|
154 |
+
# If suggestion type is "model", use the language model for multiple suggestions
|
155 |
if suggestion_type == "model":
|
156 |
+
print("Using model for suggestions")
|
157 |
+
# Generate 3 different suggestions
|
158 |
+
suggestions = []
|
159 |
+
for i in range(3):
|
160 |
+
print(f"Generating suggestion {i+1}/3")
|
161 |
+
try:
|
162 |
+
suggestion = suggestion_generator.generate_suggestion(
|
163 |
+
person_context, user_input, temperature=0.7
|
164 |
+
)
|
165 |
+
print(f"Generated suggestion: {suggestion}")
|
166 |
+
suggestions.append(suggestion)
|
167 |
+
except Exception as e:
|
168 |
+
print(f"Error generating suggestion: {e}")
|
169 |
+
suggestions.append("Error generating suggestion")
|
170 |
+
|
171 |
+
result = "### AI-Generated Responses:\n\n"
|
172 |
+
for i, suggestion in enumerate(suggestions, 1):
|
173 |
+
result += f"{i}. {suggestion}\n\n"
|
174 |
+
|
175 |
+
print(f"Final result: {result}")
|
176 |
|
177 |
# If suggestion type is "common_phrases", use the person's common phrases
|
178 |
elif suggestion_type == "common_phrases":
|
179 |
phrases = social_graph.get_relevant_phrases(person_id, user_input)
|
180 |
+
result = "### My Common Phrases with this Person:\n\n"
|
181 |
+
for i, phrase in enumerate(phrases, 1):
|
182 |
+
result += f"{i}. {phrase}\n\n"
|
183 |
+
|
184 |
+
# If suggestion type is "auto_detect", use the inferred category or default to model
|
185 |
+
elif suggestion_type == "auto_detect":
|
186 |
+
print(f"Auto-detect mode, inferred category: {inferred_category}")
|
187 |
+
if inferred_category:
|
188 |
+
utterances = social_graph.get_common_utterances(inferred_category)
|
189 |
+
print(f"Got utterances for category {inferred_category}: {utterances}")
|
190 |
+
result = f"### Auto-detected category: {inferred_category.replace('_', ' ').title()}\n\n"
|
191 |
+
for i, utterance in enumerate(utterances, 1):
|
192 |
+
result += f"{i}. {utterance}\n\n"
|
193 |
+
else:
|
194 |
+
print("No category inferred, falling back to model")
|
195 |
+
# Fall back to model if we couldn't infer a category
|
196 |
+
try:
|
197 |
+
suggestion = suggestion_generator.generate_suggestion(
|
198 |
+
person_context, user_input
|
199 |
+
)
|
200 |
+
print(f"Generated fallback suggestion: {suggestion}")
|
201 |
+
result = "### AI-Generated Response (no category detected):\n\n"
|
202 |
+
result += f"1. {suggestion}\n\n"
|
203 |
+
except Exception as e:
|
204 |
+
print(f"Error generating fallback suggestion: {e}")
|
205 |
+
result = "### Could not generate a response:\n\n"
|
206 |
+
result += "1. Sorry, I couldn't generate a suggestion at this time.\n\n"
|
207 |
|
208 |
# If suggestion type is a category from common_utterances
|
209 |
elif suggestion_type in get_suggestion_categories():
|
210 |
+
print(f"Using category: {suggestion_type}")
|
211 |
utterances = social_graph.get_common_utterances(suggestion_type)
|
212 |
+
print(f"Got utterances: {utterances}")
|
213 |
+
result = f"### {suggestion_type.replace('_', ' ').title()} Phrases:\n\n"
|
214 |
+
for i, utterance in enumerate(utterances, 1):
|
215 |
+
result += f"{i}. {utterance}\n\n"
|
216 |
|
217 |
# Default fallback
|
218 |
+
else:
|
219 |
+
print(f"No handler for suggestion type: {suggestion_type}")
|
220 |
+
result = "No suggestions available. Please try a different option."
|
221 |
+
|
222 |
+
print(f"Returning result: {result[:100]}...")
|
223 |
+
return result
|
224 |
|
225 |
|
226 |
def transcribe_audio(audio_path):
|
|
|
232 |
# Transcribe the audio
|
233 |
result = whisper_model.transcribe(audio_path)
|
234 |
return result["text"]
|
235 |
+
except Exception:
|
|
|
236 |
return "Could not transcribe audio. Please try again."
|
237 |
|
238 |
|
239 |
# Create the Gradio interface
|
240 |
+
with gr.Blocks(title="Will's AAC Communication Aid") as demo:
|
241 |
+
gr.Markdown("# Will's AAC Communication Aid")
|
242 |
gr.Markdown(
|
243 |
+
"""
|
244 |
+
This demo simulates an AAC system from Will's perspective (a 38-year-old with MND).
|
245 |
+
|
246 |
+
**How to use this demo:**
|
247 |
+
1. Select who you (Will) are talking to from the dropdown
|
248 |
+
2. Optionally select a conversation topic
|
249 |
+
3. Enter or record what the other person said to you
|
250 |
+
4. Get suggested responses based on your relationship with that person
|
251 |
+
"""
|
252 |
)
|
253 |
|
254 |
+
# Display information about Will
|
255 |
+
with gr.Accordion("About Me (Will)", open=False):
|
256 |
+
gr.Markdown(
|
257 |
+
"""
|
258 |
+
I'm Will, a 38-year-old computer programmer from Manchester with MND (diagnosed 5 months ago).
|
259 |
+
I live with my wife Emma and two children (Mabel, 4 and Billy, 7).
|
260 |
+
Originally from South East London, I enjoy technology, Manchester United, and have fond memories of cycling and hiking.
|
261 |
+
I'm increasingly using this AAC system as my speech becomes more difficult.
|
262 |
+
"""
|
263 |
+
)
|
264 |
+
|
265 |
with gr.Row():
|
266 |
with gr.Column(scale=1):
|
267 |
# Person selection
|
268 |
person_dropdown = gr.Dropdown(
|
269 |
+
choices=get_people_choices(),
|
270 |
+
label="I'm talking to:",
|
271 |
+
info="Select who you (Will) are talking to",
|
272 |
+
)
|
273 |
+
|
274 |
+
# Get topics for the selected person
|
275 |
+
def get_filtered_topics(person_id):
|
276 |
+
if not person_id:
|
277 |
+
return []
|
278 |
+
person_context = social_graph.get_person_context(person_id)
|
279 |
+
return person_context.get("topics", [])
|
280 |
+
|
281 |
+
# Topic selection dropdown
|
282 |
+
topic_dropdown = gr.Dropdown(
|
283 |
+
choices=[], # Will be populated when a person is selected
|
284 |
+
label="Topic (optional):",
|
285 |
+
info="Select a topic relevant to this person",
|
286 |
+
allow_custom_value=True,
|
287 |
)
|
288 |
|
289 |
# Context display
|
290 |
+
context_display = gr.Markdown(label="Relationship Context")
|
291 |
|
292 |
# User input section
|
293 |
with gr.Row():
|
294 |
user_input = gr.Textbox(
|
295 |
+
label="What they said to me:",
|
296 |
+
placeholder='Examples:\n"How was your physio session today?"\n"The kids are asking if you want to watch a movie tonight"\n"I\'ve been looking at that new AAC software you mentioned"',
|
297 |
lines=3,
|
298 |
)
|
299 |
|
300 |
# Audio input
|
301 |
with gr.Row():
|
302 |
audio_input = gr.Audio(
|
303 |
+
label="Or record what they said:",
|
304 |
type="filepath",
|
305 |
sources=["microphone"],
|
306 |
)
|
|
|
308 |
|
309 |
# Suggestion type selection
|
310 |
suggestion_type = gr.Radio(
|
311 |
+
choices=[
|
312 |
+
"auto_detect",
|
313 |
+
"model",
|
314 |
+
"common_phrases",
|
315 |
+
]
|
316 |
+
+ get_suggestion_categories(),
|
317 |
+
value="model", # Default to model for better results
|
318 |
+
label="How should I respond?",
|
319 |
+
info="Choose what kind of responses you want (model = AI-generated)",
|
320 |
)
|
321 |
|
322 |
# Generate button
|
323 |
+
generate_btn = gr.Button("Generate My Responses", variant="primary")
|
324 |
|
325 |
with gr.Column(scale=1):
|
326 |
# Common phrases
|
327 |
common_phrases = gr.Textbox(
|
328 |
+
label="My Common Phrases",
|
329 |
+
placeholder="Common phrases I often use with this person will appear here...",
|
330 |
lines=5,
|
331 |
)
|
332 |
|
333 |
# Suggestions output
|
334 |
+
suggestions_output = gr.Markdown(
|
335 |
+
label="My Suggested Responses",
|
336 |
+
value="Suggested responses will appear here...",
|
|
|
337 |
)
|
338 |
|
339 |
# Set up event handlers
|
340 |
+
def handle_person_change(person_id):
|
341 |
+
"""Handle person selection change and update UI elements."""
|
342 |
+
context_info, phrases_text, _ = on_person_change(person_id)
|
343 |
+
|
344 |
+
# Get topics for this person
|
345 |
+
topics = get_filtered_topics(person_id)
|
346 |
+
|
347 |
+
# Update the context, phrases, and topic dropdown
|
348 |
+
return context_info, phrases_text, gr.update(choices=topics)
|
349 |
+
|
350 |
+
# Set up the person change event
|
351 |
person_dropdown.change(
|
352 |
+
handle_person_change,
|
353 |
inputs=[person_dropdown],
|
354 |
+
outputs=[context_display, common_phrases, topic_dropdown],
|
355 |
)
|
356 |
|
357 |
+
# Set up the generate button click event
|
358 |
generate_btn.click(
|
359 |
generate_suggestions,
|
360 |
+
inputs=[person_dropdown, user_input, suggestion_type, topic_dropdown],
|
361 |
outputs=[suggestions_output],
|
362 |
)
|
363 |
|
demo.py
CHANGED
@@ -1,27 +1,40 @@
|
|
1 |
-
from transformers import pipeline
|
2 |
import json
|
3 |
|
4 |
# Load model
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
rag_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
|
9 |
|
10 |
# Load KG
|
11 |
with open("social_graph.json", "r") as f:
|
12 |
kg = json.load(f)
|
13 |
|
14 |
# Build context
|
15 |
-
person = kg["people"]["
|
16 |
-
context =
|
17 |
|
18 |
# User input
|
19 |
-
query = "What should I say to
|
20 |
|
21 |
# RAG-style prompt
|
22 |
-
prompt =
|
23 |
-
|
|
|
|
|
|
|
24 |
|
25 |
# Generate
|
26 |
-
response = rag_pipeline(
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import pipeline
|
2 |
import json
|
3 |
|
4 |
# Load model
|
5 |
+
|
6 |
+
# Use a simpler approach with a pre-built pipeline
|
7 |
+
rag_pipeline = pipeline("text-generation", model="distilgpt2")
|
|
|
8 |
|
9 |
# Load KG
|
10 |
with open("social_graph.json", "r") as f:
|
11 |
kg = json.load(f)
|
12 |
|
13 |
# Build context
|
14 |
+
person = kg["people"]["billy"] # Using Billy instead of Bob
|
15 |
+
context = person["context"]
|
16 |
|
17 |
# User input
|
18 |
+
query = "What should I say to Billy?"
|
19 |
|
20 |
# RAG-style prompt
|
21 |
+
prompt = """I am Will, a 38-year-old father with MND (Motor Neuron Disease). I have a 7-year-old son named Billy who loves Manchester United football.
|
22 |
+
|
23 |
+
Billy just asked me: "Dad, did you see the United match last night?"
|
24 |
+
|
25 |
+
My response to Billy:"""
|
26 |
|
27 |
# Generate
|
28 |
+
response = rag_pipeline(
|
29 |
+
prompt,
|
30 |
+
max_length=100, # Longer output
|
31 |
+
temperature=0.9, # More creative
|
32 |
+
do_sample=True,
|
33 |
+
num_return_sequences=1,
|
34 |
+
top_p=0.92, # More focused sampling
|
35 |
+
top_k=50, # Limit vocabulary
|
36 |
+
)
|
37 |
+
print("Generated response:")
|
38 |
+
# For text-generation models, we need to extract just the generated part (not the prompt)
|
39 |
+
generated_text = response[0]["generated_text"][len(prompt) :]
|
40 |
+
print(generated_text)
|
social_graph.json
CHANGED
@@ -1,4 +1,11 @@
|
|
1 |
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
"people": {
|
3 |
"emma": {
|
4 |
"name": "Emma",
|
|
|
1 |
{
|
2 |
+
"aac_user": {
|
3 |
+
"name": "Will",
|
4 |
+
"age": 38,
|
5 |
+
"location": "Manchester",
|
6 |
+
"background": "Originally from South East London, now living in Manchester with wife Emma and two children (Mabel, 4 and Billy, 7). Diagnosed with MND 5 months ago. Works as a computer programmer with accommodations for condition. Enjoys technology, Manchester United, and has fond memories of cycling and hiking in the Peak District and Lake District. Was active in Scouts growing up.",
|
7 |
+
"communication_needs": "Increasingly using AAC as speech becomes more difficult. Voice banking in progress. Prefers British English expressions and technical terminology when appropriate. When talking to familar people has a tendency to swear but in a friendly way."
|
8 |
+
},
|
9 |
"people": {
|
10 |
"emma": {
|
11 |
"name": "Emma",
|
utils.py
CHANGED
@@ -3,186 +3,263 @@ import random
|
|
3 |
from typing import Dict, List, Any, Optional, Tuple
|
4 |
from sentence_transformers import SentenceTransformer
|
5 |
import numpy as np
|
6 |
-
|
|
|
|
|
7 |
|
8 |
class SocialGraphManager:
|
9 |
"""Manages the social graph and provides context for the AAC system."""
|
10 |
-
|
11 |
def __init__(self, graph_path: str = "social_graph.json"):
|
12 |
"""Initialize the social graph manager.
|
13 |
-
|
14 |
Args:
|
15 |
graph_path: Path to the social graph JSON file
|
16 |
"""
|
17 |
self.graph_path = graph_path
|
18 |
self.graph = self._load_graph()
|
19 |
-
|
20 |
# Initialize sentence transformer for semantic matching
|
21 |
try:
|
22 |
-
self.sentence_model = SentenceTransformer(
|
|
|
|
|
23 |
self.embeddings_cache = {}
|
24 |
self._initialize_embeddings()
|
25 |
except Exception as e:
|
26 |
-
print(f"Warning: Could not load sentence transformer model: {e}")
|
27 |
self.sentence_model = None
|
28 |
-
|
29 |
def _load_graph(self) -> Dict[str, Any]:
|
30 |
"""Load the social graph from the JSON file."""
|
31 |
try:
|
32 |
with open(self.graph_path, "r") as f:
|
33 |
return json.load(f)
|
34 |
-
except Exception
|
35 |
-
print(f"Error loading social graph: {e}")
|
36 |
return {"people": {}, "places": [], "topics": []}
|
37 |
-
|
38 |
def _initialize_embeddings(self):
|
39 |
"""Initialize embeddings for topics and phrases in the social graph."""
|
40 |
if not self.sentence_model:
|
41 |
return
|
42 |
-
|
43 |
# Create embeddings for topics
|
44 |
topics = self.graph.get("topics", [])
|
45 |
for topic in topics:
|
46 |
if topic not in self.embeddings_cache:
|
47 |
self.embeddings_cache[topic] = self.sentence_model.encode(topic)
|
48 |
-
|
49 |
# Create embeddings for common phrases
|
50 |
for person_id, person_data in self.graph.get("people", {}).items():
|
51 |
for phrase in person_data.get("common_phrases", []):
|
52 |
if phrase not in self.embeddings_cache:
|
53 |
self.embeddings_cache[phrase] = self.sentence_model.encode(phrase)
|
54 |
-
|
55 |
# Create embeddings for common utterances
|
56 |
for category, utterances in self.graph.get("common_utterances", {}).items():
|
57 |
for utterance in utterances:
|
58 |
if utterance not in self.embeddings_cache:
|
59 |
-
self.embeddings_cache[utterance] = self.sentence_model.encode(
|
60 |
-
|
|
|
|
|
61 |
def get_people_list(self) -> List[Dict[str, str]]:
|
62 |
"""Get a list of people from the social graph with their names and roles."""
|
63 |
people = []
|
64 |
for person_id, person_data in self.graph.get("people", {}).items():
|
65 |
-
people.append(
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
70 |
return people
|
71 |
-
|
72 |
def get_person_context(self, person_id: str) -> Dict[str, Any]:
|
73 |
"""Get context information for a specific person."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
if person_id not in self.graph.get("people", {}):
|
75 |
return {}
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
80 |
"""Get relevant phrases for a specific person based on user input."""
|
81 |
if person_id not in self.graph.get("people", {}):
|
82 |
return []
|
83 |
-
|
84 |
person_data = self.graph["people"][person_id]
|
85 |
phrases = person_data.get("common_phrases", [])
|
86 |
-
|
87 |
# If no user input, return random phrases
|
88 |
if not user_input or not self.sentence_model:
|
89 |
return random.sample(phrases, min(3, len(phrases)))
|
90 |
-
|
91 |
# Use semantic search to find relevant phrases
|
92 |
user_embedding = self.sentence_model.encode(user_input)
|
93 |
phrase_scores = []
|
94 |
-
|
95 |
for phrase in phrases:
|
96 |
if phrase in self.embeddings_cache:
|
97 |
phrase_embedding = self.embeddings_cache[phrase]
|
98 |
else:
|
99 |
phrase_embedding = self.sentence_model.encode(phrase)
|
100 |
self.embeddings_cache[phrase] = phrase_embedding
|
101 |
-
|
102 |
similarity = np.dot(user_embedding, phrase_embedding) / (
|
103 |
np.linalg.norm(user_embedding) * np.linalg.norm(phrase_embedding)
|
104 |
)
|
105 |
phrase_scores.append((phrase, similarity))
|
106 |
-
|
107 |
# Sort by similarity score and return top phrases
|
108 |
phrase_scores.sort(key=lambda x: x[1], reverse=True)
|
109 |
return [phrase for phrase, _ in phrase_scores[:3]]
|
110 |
-
|
111 |
def get_common_utterances(self, category: Optional[str] = None) -> List[str]:
|
112 |
"""Get common utterances from the social graph, optionally filtered by category."""
|
113 |
utterances = []
|
114 |
-
|
115 |
if "common_utterances" not in self.graph:
|
116 |
return utterances
|
117 |
-
|
118 |
if category and category in self.graph["common_utterances"]:
|
119 |
return self.graph["common_utterances"][category]
|
120 |
-
|
121 |
# If no category specified, return a sample from each category
|
122 |
for category_utterances in self.graph["common_utterances"].values():
|
123 |
-
utterances.extend(
|
124 |
-
|
125 |
-
|
|
|
126 |
return utterances
|
127 |
|
|
|
128 |
class SuggestionGenerator:
|
129 |
"""Generates contextual suggestions for the AAC system."""
|
130 |
-
|
131 |
-
def __init__(self, model_name: str = "
|
132 |
"""Initialize the suggestion generator.
|
133 |
-
|
134 |
Args:
|
135 |
model_name: Name of the HuggingFace model to use
|
136 |
"""
|
|
|
|
|
|
|
137 |
try:
|
138 |
-
|
139 |
-
|
140 |
-
self.generator = pipeline("
|
141 |
-
model=self.model,
|
142 |
-
tokenizer=self.tokenizer)
|
143 |
self.model_loaded = True
|
|
|
144 |
except Exception as e:
|
145 |
-
print(f"
|
146 |
self.model_loaded = False
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
"""Generate a contextually appropriate suggestion.
|
153 |
-
|
154 |
Args:
|
155 |
person_context: Context information about the person
|
156 |
user_input: Optional user input to consider
|
157 |
max_length: Maximum length of the generated suggestion
|
158 |
-
|
|
|
159 |
Returns:
|
160 |
A generated suggestion string
|
161 |
"""
|
162 |
if not self.model_loaded:
|
163 |
-
|
164 |
-
|
|
|
|
|
|
|
|
|
165 |
# Extract context information
|
166 |
name = person_context.get("name", "")
|
167 |
role = person_context.get("role", "")
|
168 |
topics = ", ".join(person_context.get("topics", []))
|
169 |
context = person_context.get("context", "")
|
170 |
-
|
|
|
171 |
# Build prompt
|
172 |
-
prompt = f"""
|
173 |
-
|
174 |
-
Topics of interest: {topics}
|
175 |
"""
|
176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
if user_input:
|
178 |
-
prompt += f
|
179 |
-
|
180 |
-
prompt += "
|
181 |
-
|
182 |
# Generate suggestion
|
183 |
try:
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
except Exception as e:
|
187 |
print(f"Error generating suggestion: {e}")
|
188 |
return "Could not generate a suggestion. Please try again."
|
|
|
3 |
from typing import Dict, List, Any, Optional, Tuple
|
4 |
from sentence_transformers import SentenceTransformer
|
5 |
import numpy as np
|
6 |
+
|
7 |
+
from transformers import pipeline
|
8 |
+
|
9 |
|
10 |
class SocialGraphManager:
|
11 |
"""Manages the social graph and provides context for the AAC system."""
|
12 |
+
|
13 |
def __init__(self, graph_path: str = "social_graph.json"):
|
14 |
"""Initialize the social graph manager.
|
15 |
+
|
16 |
Args:
|
17 |
graph_path: Path to the social graph JSON file
|
18 |
"""
|
19 |
self.graph_path = graph_path
|
20 |
self.graph = self._load_graph()
|
21 |
+
|
22 |
# Initialize sentence transformer for semantic matching
|
23 |
try:
|
24 |
+
self.sentence_model = SentenceTransformer(
|
25 |
+
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
26 |
+
)
|
27 |
self.embeddings_cache = {}
|
28 |
self._initialize_embeddings()
|
29 |
except Exception as e:
|
|
|
30 |
self.sentence_model = None
|
31 |
+
|
32 |
def _load_graph(self) -> Dict[str, Any]:
|
33 |
"""Load the social graph from the JSON file."""
|
34 |
try:
|
35 |
with open(self.graph_path, "r") as f:
|
36 |
return json.load(f)
|
37 |
+
except Exception:
|
|
|
38 |
return {"people": {}, "places": [], "topics": []}
|
39 |
+
|
40 |
def _initialize_embeddings(self):
|
41 |
"""Initialize embeddings for topics and phrases in the social graph."""
|
42 |
if not self.sentence_model:
|
43 |
return
|
44 |
+
|
45 |
# Create embeddings for topics
|
46 |
topics = self.graph.get("topics", [])
|
47 |
for topic in topics:
|
48 |
if topic not in self.embeddings_cache:
|
49 |
self.embeddings_cache[topic] = self.sentence_model.encode(topic)
|
50 |
+
|
51 |
# Create embeddings for common phrases
|
52 |
for person_id, person_data in self.graph.get("people", {}).items():
|
53 |
for phrase in person_data.get("common_phrases", []):
|
54 |
if phrase not in self.embeddings_cache:
|
55 |
self.embeddings_cache[phrase] = self.sentence_model.encode(phrase)
|
56 |
+
|
57 |
# Create embeddings for common utterances
|
58 |
for category, utterances in self.graph.get("common_utterances", {}).items():
|
59 |
for utterance in utterances:
|
60 |
if utterance not in self.embeddings_cache:
|
61 |
+
self.embeddings_cache[utterance] = self.sentence_model.encode(
|
62 |
+
utterance
|
63 |
+
)
|
64 |
+
|
65 |
def get_people_list(self) -> List[Dict[str, str]]:
|
66 |
"""Get a list of people from the social graph with their names and roles."""
|
67 |
people = []
|
68 |
for person_id, person_data in self.graph.get("people", {}).items():
|
69 |
+
people.append(
|
70 |
+
{
|
71 |
+
"id": person_id,
|
72 |
+
"name": person_data.get("name", person_id),
|
73 |
+
"role": person_data.get("role", ""),
|
74 |
+
}
|
75 |
+
)
|
76 |
return people
|
77 |
+
|
78 |
def get_person_context(self, person_id: str) -> Dict[str, Any]:
|
79 |
"""Get context information for a specific person."""
|
80 |
+
# Check if the person_id contains a display name (e.g., "Emma (wife)")
|
81 |
+
# and try to extract the actual ID
|
82 |
+
if person_id not in self.graph.get("people", {}):
|
83 |
+
# Try to find the person by name
|
84 |
+
for pid, pdata in self.graph.get("people", {}).items():
|
85 |
+
name = pdata.get("name", "")
|
86 |
+
role = pdata.get("role", "")
|
87 |
+
if f"{name} ({role})" == person_id:
|
88 |
+
person_id = pid
|
89 |
+
break
|
90 |
+
|
91 |
+
# If still not found, return empty dict
|
92 |
if person_id not in self.graph.get("people", {}):
|
93 |
return {}
|
94 |
+
|
95 |
+
person_data = self.graph["people"][person_id]
|
96 |
+
return person_data
|
97 |
+
|
98 |
+
def get_relevant_phrases(
|
99 |
+
self, person_id: str, user_input: Optional[str] = None
|
100 |
+
) -> List[str]:
|
101 |
"""Get relevant phrases for a specific person based on user input."""
|
102 |
if person_id not in self.graph.get("people", {}):
|
103 |
return []
|
104 |
+
|
105 |
person_data = self.graph["people"][person_id]
|
106 |
phrases = person_data.get("common_phrases", [])
|
107 |
+
|
108 |
# If no user input, return random phrases
|
109 |
if not user_input or not self.sentence_model:
|
110 |
return random.sample(phrases, min(3, len(phrases)))
|
111 |
+
|
112 |
# Use semantic search to find relevant phrases
|
113 |
user_embedding = self.sentence_model.encode(user_input)
|
114 |
phrase_scores = []
|
115 |
+
|
116 |
for phrase in phrases:
|
117 |
if phrase in self.embeddings_cache:
|
118 |
phrase_embedding = self.embeddings_cache[phrase]
|
119 |
else:
|
120 |
phrase_embedding = self.sentence_model.encode(phrase)
|
121 |
self.embeddings_cache[phrase] = phrase_embedding
|
122 |
+
|
123 |
similarity = np.dot(user_embedding, phrase_embedding) / (
|
124 |
np.linalg.norm(user_embedding) * np.linalg.norm(phrase_embedding)
|
125 |
)
|
126 |
phrase_scores.append((phrase, similarity))
|
127 |
+
|
128 |
# Sort by similarity score and return top phrases
|
129 |
phrase_scores.sort(key=lambda x: x[1], reverse=True)
|
130 |
return [phrase for phrase, _ in phrase_scores[:3]]
|
131 |
+
|
132 |
def get_common_utterances(self, category: Optional[str] = None) -> List[str]:
|
133 |
"""Get common utterances from the social graph, optionally filtered by category."""
|
134 |
utterances = []
|
135 |
+
|
136 |
if "common_utterances" not in self.graph:
|
137 |
return utterances
|
138 |
+
|
139 |
if category and category in self.graph["common_utterances"]:
|
140 |
return self.graph["common_utterances"][category]
|
141 |
+
|
142 |
# If no category specified, return a sample from each category
|
143 |
for category_utterances in self.graph["common_utterances"].values():
|
144 |
+
utterances.extend(
|
145 |
+
random.sample(category_utterances, min(2, len(category_utterances)))
|
146 |
+
)
|
147 |
+
|
148 |
return utterances
|
149 |
|
150 |
+
|
151 |
class SuggestionGenerator:
|
152 |
"""Generates contextual suggestions for the AAC system."""
|
153 |
+
|
154 |
+
def __init__(self, model_name: str = "distilgpt2"):
|
155 |
"""Initialize the suggestion generator.
|
156 |
+
|
157 |
Args:
|
158 |
model_name: Name of the HuggingFace model to use
|
159 |
"""
|
160 |
+
self.model_name = model_name
|
161 |
+
self.model_loaded = False
|
162 |
+
|
163 |
try:
|
164 |
+
print(f"Loading model: {model_name}")
|
165 |
+
# Use a simpler approach with a pre-built pipeline
|
166 |
+
self.generator = pipeline("text-generation", model=model_name)
|
|
|
|
|
167 |
self.model_loaded = True
|
168 |
+
print(f"Model loaded successfully: {model_name}")
|
169 |
except Exception as e:
|
170 |
+
print(f"Error loading model: {e}")
|
171 |
self.model_loaded = False
|
172 |
+
|
173 |
+
# Fallback responses if model fails to load or generate
|
174 |
+
self.fallback_responses = [
|
175 |
+
"I'm not sure how to respond to that.",
|
176 |
+
"That's interesting. Tell me more.",
|
177 |
+
"I'd like to talk about that further.",
|
178 |
+
"I appreciate you sharing that with me.",
|
179 |
+
]
|
180 |
+
|
181 |
+
def test_model(self) -> str:
|
182 |
+
"""Test if the model is working correctly."""
|
183 |
+
if not self.model_loaded:
|
184 |
+
return "Model not loaded"
|
185 |
+
|
186 |
+
try:
|
187 |
+
test_prompt = "I am Will. My son Billy asked about football. I respond:"
|
188 |
+
print(f"Testing model with prompt: {test_prompt}")
|
189 |
+
response = self.generator(test_prompt, max_length=30, do_sample=True)
|
190 |
+
result = response[0]["generated_text"][len(test_prompt) :]
|
191 |
+
print(f"Test response: {result}")
|
192 |
+
return f"Model test successful: {result}"
|
193 |
+
except Exception as e:
|
194 |
+
print(f"Error testing model: {e}")
|
195 |
+
return f"Model test failed: {str(e)}"
|
196 |
+
|
197 |
+
def generate_suggestion(
|
198 |
+
self,
|
199 |
+
person_context: Dict[str, Any],
|
200 |
+
user_input: Optional[str] = None,
|
201 |
+
max_length: int = 50,
|
202 |
+
temperature: float = 0.7,
|
203 |
+
) -> str:
|
204 |
"""Generate a contextually appropriate suggestion.
|
205 |
+
|
206 |
Args:
|
207 |
person_context: Context information about the person
|
208 |
user_input: Optional user input to consider
|
209 |
max_length: Maximum length of the generated suggestion
|
210 |
+
temperature: Controls randomness in generation (higher = more random)
|
211 |
+
|
212 |
Returns:
|
213 |
A generated suggestion string
|
214 |
"""
|
215 |
if not self.model_loaded:
|
216 |
+
# Use fallback responses if model isn't loaded
|
217 |
+
import random
|
218 |
+
|
219 |
+
print("Model not loaded, using fallback responses")
|
220 |
+
return random.choice(self.fallback_responses)
|
221 |
+
|
222 |
# Extract context information
|
223 |
name = person_context.get("name", "")
|
224 |
role = person_context.get("role", "")
|
225 |
topics = ", ".join(person_context.get("topics", []))
|
226 |
context = person_context.get("context", "")
|
227 |
+
selected_topic = person_context.get("selected_topic", "")
|
228 |
+
|
229 |
# Build prompt
|
230 |
+
prompt = f"""I am Will, a person with MND (Motor Neuron Disease).
|
231 |
+
I'm talking to {name}, who is my {role}.
|
|
|
232 |
"""
|
233 |
+
|
234 |
+
if context:
|
235 |
+
prompt += f"Context: {context}\n"
|
236 |
+
|
237 |
+
if topics:
|
238 |
+
prompt += f"Topics of interest: {topics}\n"
|
239 |
+
|
240 |
+
if selected_topic:
|
241 |
+
prompt += f"We're currently talking about: {selected_topic}\n"
|
242 |
+
|
243 |
if user_input:
|
244 |
+
prompt += f'\n{name} just said to me: "{user_input}"\n'
|
245 |
+
|
246 |
+
prompt += "\nMy response:"
|
247 |
+
|
248 |
# Generate suggestion
|
249 |
try:
|
250 |
+
print(f"Generating suggestion with prompt: {prompt}")
|
251 |
+
response = self.generator(
|
252 |
+
prompt,
|
253 |
+
max_length=len(prompt.split()) + max_length,
|
254 |
+
temperature=temperature,
|
255 |
+
do_sample=True,
|
256 |
+
top_p=0.92,
|
257 |
+
top_k=50,
|
258 |
+
)
|
259 |
+
# Extract only the generated part, not the prompt
|
260 |
+
result = response[0]["generated_text"][len(prompt) :]
|
261 |
+
print(f"Generated response: {result}")
|
262 |
+
return result.strip()
|
263 |
except Exception as e:
|
264 |
print(f"Error generating suggestion: {e}")
|
265 |
return "Could not generate a suggestion. Please try again."
|