willwade committed on
Commit 36fcf07 · 1 Parent(s): 238c097

change to gemini

Files changed (6)
  1. app.py +181 -63
  2. llm_interface.py +345 -0
  3. print_test.py +3 -0
  4. requirements.txt +1 -0
  5. test_app.py +54 -0
  6. utils.py +552 -85
app.py CHANGED
@@ -1,29 +1,48 @@
1
  import gradio as gr
2
  import whisper
3
- import tempfile
4
- import os
5
- from utils import SocialGraphManager, SuggestionGenerator
 
6
 
7
  # Define available models
8
  AVAILABLE_MODELS = {
9
- "google/gemma-3-1b-it": "Gemma 3 1B-IT (Small, instruction-tuned)",
10
- "google/gemma-3-4b-it": "Gemma 3 4B-IT (Default, instruction-tuned)",
11
- "google/gemma-3-12b-it": "Gemma 3 12B-IT (Better quality, instruction-tuned)",
12
- "google/gemma-3-27b-it": "Gemma 3 27B-IT (Best quality, instruction-tuned)",
13
- "Qwen/Qwen1.5-0.5B": "Qwen 1.5 0.5B (Very small, efficient)",
14
- "Qwen/Qwen1.5-1.8B": "Qwen 1.5 1.8B (Small, good quality)",
15
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0": "TinyLlama 1.1B (Small, chat-tuned)",
16
- "microsoft/phi-3-mini-4k-instruct": "Phi-3 Mini (Small, instruction-tuned)",
17
- "microsoft/phi-2": "Phi-2 (Small, high quality for size)",
18
- "distilgpt2": "DistilGPT2 (Fast, smaller model)",
19
- "gpt2": "GPT-2 (Medium size, better quality)",
20
  }
21
 
22
  # Initialize the social graph manager
23
  social_graph = SocialGraphManager("social_graph.json")
24
 
25
- # Initialize the suggestion generator with Gemma 3 1B (default - smaller model to save memory)
26
- suggestion_generator = SuggestionGenerator("google/gemma-3-1b-it")
27
 
28
  # Test the model to make sure it's working
29
  test_result = suggestion_generator.test_model()
@@ -137,15 +156,28 @@ def change_model(model_name, progress=gr.Progress()):
137
  # Show progress indicator
138
  progress(0, desc=f"Loading model: {model_name}")
139
 
140
- # Try to load the new model
141
- success = suggestion_generator.load_model(model_name)
142
-
143
- if success:
144
- progress(1.0, desc=f"Model loaded: {model_name}")
145
- return f"Successfully switched to model: {model_name}"
146
- else:
147
- progress(1.0, desc="Model loading failed")
148
- return f"Failed to load model: {model_name}. Using fallback responses instead."
149
 
150
 
151
  def generate_suggestions(
@@ -153,7 +185,7 @@ def generate_suggestions(
153
  user_input,
154
  suggestion_type,
155
  selected_topic=None,
156
- model_name="google/gemma-3-1b-it",
157
  temperature=0.7,
158
  mood=3,
159
  progress=gr.Progress(),
@@ -232,6 +264,9 @@ def generate_suggestions(
232
  if selected_topic:
233
  person_context["selected_topic"] = selected_topic
234
 
 
 
 
235
  # Format the output with multiple suggestions
236
  result = ""
237
 
@@ -240,31 +275,40 @@ def generate_suggestions(
240
  print("Using model for suggestions")
241
  progress(0.2, desc="Preparing to generate suggestions...")
242
 
243
- # Generate 3 different suggestions
244
- suggestions = []
245
- for i in range(3):
246
- progress_value = 0.3 + (i * 0.2) # Progress from 30% to 70%
247
- progress(progress_value, desc=f"Generating suggestion {i+1}/3")
248
- print(f"Generating suggestion {i+1}/3")
249
- try:
250
- # Add mood to person context
251
- person_context["mood"] = mood
252
- suggestion = suggestion_generator.generate_suggestion(
253
- person_context, user_input, temperature=temperature
254
- )
255
- print(f"Generated suggestion: {suggestion}")
256
- suggestions.append(suggestion)
257
- except Exception as e:
258
- print(f"Error generating suggestion: {e}")
259
- suggestions.append("Error generating suggestion")
260
 
261
- result = (
262
- f"### AI-Generated Responses (using {suggestion_generator.model_name}):\n\n"
263
- )
264
- for i, suggestion in enumerate(suggestions, 1):
265
- result += f"{i}. {suggestion}\n\n"
266
 
267
- print(f"Final result: {result[:100]}...")
268
 
269
  # If suggestion type is "common_phrases", use the person's common phrases
270
  elif clean_suggestion_type == "common_phrases":
@@ -288,23 +332,87 @@ def generate_suggestions(
288
  progress(0.3, desc="No category detected, using model instead...")
289
  try:
290
  suggestions = []
 
 
 
291
  for i in range(3):
292
  progress_value = 0.4 + (i * 0.15) # Progress from 40% to 70%
293
  progress(
294
  progress_value, desc=f"Generating fallback suggestion {i+1}/3"
295
  )
296
- # Add mood to person context
297
- person_context["mood"] = mood
298
- suggestion = suggestion_generator.generate_suggestion(
299
- person_context, user_input, temperature=temperature
300
- )
301
- suggestions.append(suggestion)
302
-
303
- result = f"### AI-Generated Responses (no category detected, using {suggestion_generator.model_name}):\n\n"
304
  for i, suggestion in enumerate(suggestions, 1):
305
  result += f"{i}. {suggestion}\n\n"
306
  except Exception as e:
307
  print(f"Error generating fallback suggestion: {e}")
 
308
  result = "### Could not generate a response:\n\n"
309
  result += "1. Sorry, I couldn't generate a suggestion at this time.\n\n"
310
 
@@ -334,13 +442,19 @@ def generate_suggestions(
334
  print(f"Result type: {type(result)}")
335
  print(f"Result length: {len(result)}")
336
 
337
- # Complete the progress
338
- progress(1.0, desc="Completed!")
339
-
340
  # Make sure we're returning a non-empty string
341
  if not result or len(result.strip()) == 0:
342
  result = "No response was generated. Please try again with different settings."
343
 
344
  return result
345
 
346
 
@@ -462,9 +576,9 @@ with gr.Blocks(title="Will's AAC Communication Aid", css="custom.css") as demo:
462
  with gr.Row():
463
  model_dropdown = gr.Dropdown(
464
  choices=list(AVAILABLE_MODELS.keys()),
465
- value="google/gemma-3-1b-it",
466
  label="Language Model",
467
- info="Select which AI model to use for generating responses",
468
  )
469
 
470
  temperature_slider = gr.Slider(
@@ -556,4 +670,8 @@ with gr.Blocks(title="Will's AAC Communication Aid", css="custom.css") as demo:
556
 
557
  # Launch the app
558
  if __name__ == "__main__":
559
- demo.launch()
1
  import gradio as gr
2
  import whisper
3
+ import random
4
+ import time
5
+ from utils import SocialGraphManager
6
+ from llm_interface import LLMInterface
7
 
8
  # Define available models
9
  AVAILABLE_MODELS = {
10
+ # Gemini models (online API)
11
+ "gemini-1.5-flash-latest": "🌐 Gemini 1.5 Flash (Online API - Fast, Recommended)",
12
+ "gemini-1.5-pro-latest": "🌐 Gemini 1.5 Pro (Online API - High quality)",
13
+ # OpenAI models (if API key is set)
14
+ "gpt-3.5-turbo": "🌐 ChatGPT 3.5 (Online API)",
15
+ "gpt-4o-mini": "🌐 GPT-4o Mini (Online API - Fast)",
16
+ # Ollama models (if installed locally)
17
+ "ollama/gemma:7b": "💻 Gemma 7B (Offline - requires Ollama)",
18
+ "ollama/llama3:8b": "💻 Llama 3 8B (Offline - requires Ollama)",
 
 
19
  }
20
 
21
  # Initialize the social graph manager
22
  social_graph = SocialGraphManager("social_graph.json")
23
 
24
+ # Initialize the suggestion generator with a fast online model by default
25
+ print("Initializing with Gemini 1.5 Flash (online model)")
26
+ suggestion_generator = LLMInterface("gemini-1.5-flash-latest")
27
+
28
+ # Test the model to make sure it's working
29
+ print("Testing model connection...")
30
+ test_result = suggestion_generator.test_model()
31
+ print(f"Model test result: {test_result}")
32
+
33
+ # If the model didn't load, try Ollama as fallback
34
+ if not suggestion_generator.model_loaded:
35
+ print("Online model not available, trying Ollama model...")
36
+ suggestion_generator = LLMInterface("ollama/gemma:7b")
37
+ test_result = suggestion_generator.test_model()
38
+ print(f"Ollama model test result: {test_result}")
39
+
40
+ # If Ollama also fails, try OpenAI as fallback
41
+ if not suggestion_generator.model_loaded:
42
+ print("Ollama not available, trying OpenAI model...")
43
+ suggestion_generator = LLMInterface("gpt-3.5-turbo")
44
+ test_result = suggestion_generator.test_model()
45
+ print(f"OpenAI model test result: {test_result}")
46
 
47
  # Test the model to make sure it's working
48
  test_result = suggestion_generator.test_model()
 
156
  # Show progress indicator
157
  progress(0, desc=f"Loading model: {model_name}")
158
 
159
+ # Create a new LLMInterface with the selected model
160
+ try:
161
+ progress(0.3, desc=f"Initializing {model_name}...")
162
+ new_generator = LLMInterface(model_name)
163
+
164
+ # Test if the model works
165
+ progress(0.6, desc="Testing model connection...")
166
+ test_result = new_generator.test_model()
167
+ print(f"Model test result: {test_result}")
168
+
169
+ if new_generator.model_loaded:
170
+ # Replace the current generator with the new one
171
+ suggestion_generator = new_generator
172
+ progress(1.0, desc=f"Model loaded: {model_name}")
173
+ return f"Successfully switched to model: {model_name}"
174
+ else:
175
+ progress(1.0, desc="Model loading failed")
176
+ return f"Failed to load model: {model_name}. Using previous model instead."
177
+ except Exception as e:
178
+ print(f"Error changing model: {e}")
179
+ progress(1.0, desc="Error loading model")
180
+ return f"Error loading model: {model_name}. Using previous model instead."
181
 
182
 
183
  def generate_suggestions(
 
185
  user_input,
186
  suggestion_type,
187
  selected_topic=None,
188
+ model_name="gemini-1.5-flash",
189
  temperature=0.7,
190
  mood=3,
191
  progress=gr.Progress(),
 
264
  if selected_topic:
265
  person_context["selected_topic"] = selected_topic
266
 
267
+ # Add mood to person context
268
+ person_context["mood"] = mood
269
+
270
  # Format the output with multiple suggestions
271
  result = ""
272
 
 
275
  print("Using model for suggestions")
276
  progress(0.2, desc="Preparing to generate suggestions...")
277
 
278
+ # Generate suggestions using the LLM interface
279
+ try:
280
+ # Use the LLM interface to generate multiple suggestions
281
+ suggestions = suggestion_generator.generate_multiple_suggestions(
282
+ person_context=person_context,
283
+ user_input=user_input,
284
+ num_suggestions=3,
285
+ temperature=temperature,
286
+ progress_callback=lambda p, desc: progress(0.2 + (p * 0.7), desc=desc),
287
+ )
288
 
289
+ # Make sure we have at least one suggestion
290
+ if not suggestions:
291
+ suggestions = ["I'm not sure what to say about that."]
292
+
293
+ # Make sure we have exactly 3 suggestions (pad with fallbacks if needed)
294
+ while len(suggestions) < 3:
295
+ suggestions.append("I'm not sure what else to say about that.")
296
+
297
+ result = f"### AI-Generated Responses (using {suggestion_generator.model_name}):\n\n"
298
+ for i, suggestion in enumerate(suggestions, 1):
299
+ result += f"{i}. {suggestion}\n\n"
300
+
301
+ print(f"Final result: {result[:100]}...")
302
 
303
+ except Exception as e:
304
+ print(f"Error generating suggestions: {e}")
305
+ result = "### Error generating suggestions:\n\n"
306
+ result += "1. I'm having trouble generating responses right now.\n\n"
307
+ result += "2. Please try again or select a different model.\n\n"
308
+ result += "3. You might want to check your internet connection if using an online model.\n\n"
309
+
310
+ # Force a complete progress update before returning
311
+ progress(0.9, desc="Finalizing suggestions...")
312
 
313
  # If suggestion type is "common_phrases", use the person's common phrases
314
  elif clean_suggestion_type == "common_phrases":
 
332
  progress(0.3, desc="No category detected, using model instead...")
333
  try:
334
  suggestions = []
335
+ # Set a timeout for each suggestion generation (10 seconds)
336
+ timeout_per_suggestion = 10
337
+
338
  for i in range(3):
339
  progress_value = 0.4 + (i * 0.15) # Progress from 40% to 70%
340
  progress(
341
  progress_value, desc=f"Generating fallback suggestion {i+1}/3"
342
  )
343
+ try:
344
+ # Add mood to person context
345
+ person_context["mood"] = mood
346
+
347
+ # Set a start time for timeout tracking
348
+ start_time = time.time()
349
+
350
+ # Try to generate a suggestion with timeout
351
+ suggestion = None
352
+
353
+ # If model isn't loaded, use fallback immediately
354
+ if not suggestion_generator.model_loaded:
355
+ print("Model not loaded, using fallback response")
356
+ suggestion = random.choice(
357
+ suggestion_generator.fallback_responses
358
+ )
359
+ else:
360
+ # Try to generate with the model
361
+ suggestion = suggestion_generator.generate_suggestion(
362
+ person_context, user_input, temperature=temperature
363
+ )
364
+
365
+ # Check if generation took too long
366
+ if time.time() - start_time > timeout_per_suggestion:
367
+ print(
368
+ f"Fallback suggestion {i+1} generation timed out, using fallback"
369
+ )
370
+ suggestion = (
371
+ "I'm not sure what to say about that right now."
372
+ )
373
+
374
+ # Only add non-empty suggestions
375
+ if suggestion and suggestion.strip():
376
+ suggestions.append(suggestion.strip())
377
+ else:
378
+ print("Empty fallback suggestion received, using default")
379
+ suggestions.append("I'm not sure what to say about that.")
380
+
381
+ # Force a progress update after each suggestion
382
+ progress(
383
+ 0.4 + (i * 0.15) + 0.05,
384
+ desc=f"Completed fallback suggestion {i+1}/3",
385
+ )
386
+
387
+ except Exception as e:
388
+ print(f"Error generating fallback suggestion {i+1}: {e}")
389
+ suggestions.append("I'm having trouble responding to that.")
390
+ # Force a progress update even after error
391
+ progress(
392
+ 0.4 + (i * 0.15) + 0.05,
393
+ desc=f"Error in fallback suggestion {i+1}/3",
394
+ )
395
+
396
+ # Small delay to ensure UI updates
397
+ time.sleep(0.2)
398
+
399
+ # Make sure we have at least one suggestion
400
+ if not suggestions:
401
+ suggestions = ["I'm not sure what to say about that."]
402
+
403
+ # Make sure we have exactly 3 suggestions (pad with fallbacks if needed)
404
+ while len(suggestions) < 3:
405
+ suggestions.append("I'm not sure what else to say about that.")
406
+
407
+ # Force a progress update
408
+ progress(0.85, desc="Finalizing fallback suggestions...")
409
+
410
+ result = "### AI-Generated Responses (no category detected):\n\n"
411
  for i, suggestion in enumerate(suggestions, 1):
412
  result += f"{i}. {suggestion}\n\n"
413
  except Exception as e:
414
  print(f"Error generating fallback suggestion: {e}")
415
+ progress(0.9, desc="Error handling...")
416
  result = "### Could not generate a response:\n\n"
417
  result += "1. Sorry, I couldn't generate a suggestion at this time.\n\n"
418
 
 
442
  print(f"Result type: {type(result)}")
443
  print(f"Result length: {len(result)}")
444
 
 
 
 
445
  # Make sure we're returning a non-empty string
446
  if not result or len(result.strip()) == 0:
447
  result = "No response was generated. Please try again with different settings."
448
 
449
+ # Always complete the progress to 100% before returning
450
+ progress(1.0, desc="Completed!")
451
+
452
+ # Add a small delay to ensure UI updates properly
453
+ time.sleep(0.5)
454
+
455
+ # Print final status
456
+ print("Generation completed successfully, returning result")
457
+
458
  return result
459
 
460
 
 
576
  with gr.Row():
577
  model_dropdown = gr.Dropdown(
578
  choices=list(AVAILABLE_MODELS.keys()),
579
+ value="gemini-1.5-flash-latest",
580
  label="Language Model",
581
+ info="Select which AI model to use (🌐 = online API, 💻 = offline model)",
582
  )
583
 
584
  temperature_slider = gr.Slider(
 
670
 
671
  # Launch the app
672
  if __name__ == "__main__":
673
+ print("Starting application...")
674
+ try:
675
+ demo.launch()
676
+ except Exception as e:
677
+ print(f"Error launching application: {e}")
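The new generate_suggestions wiring above rescales each sub-task's progress into a slice of the overall gr.Progress bar through a lambda. A minimal, self-contained sketch of that pattern follows; the `progress` function below is only a stand-in for gr.Progress, for illustration:

```python
# Sketch of the nested-progress pattern used in generate_suggestions: the
# outer callback owns the 0.0-1.0 range, and each sub-task reports its own
# 0.0-1.0 progress through a lambda that maps it into a slice (0.2-0.9 here).
def progress(value, desc=""):
    # Stand-in for gr.Progress, for illustration only.
    print(f"[{value:.2f}] {desc}")

def generate_multiple(progress_callback=None):
    for i in range(3):
        if progress_callback:
            # Sub-task reports its own fraction of completion (0.0-1.0).
            progress_callback(i / 3, desc=f"Generating suggestion {i + 1}/3")

# Map the sub-task's 0.0-1.0 range onto the 0.2-0.9 slice of the outer bar.
generate_multiple(progress_callback=lambda p, desc: progress(0.2 + p * 0.7, desc=desc))
```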
llm_interface.py ADDED
@@ -0,0 +1,345 @@
1
+ """
2
+ LLM Interface for the AAC app using Simon Willison's LLM library.
3
+ """
4
+
5
+ import subprocess
6
+ import time
7
+ from typing import List, Optional, Dict, Any
8
+
9
+
10
+ class LLMInterface:
11
+ """Interface for Simon Willison's LLM tool."""
12
+
13
+ def __init__(
14
+ self,
15
+ model_name: str = "gemini-1.5-flash",
16
+ max_length: int = 150,
17
+ temperature: float = 0.7,
18
+ ):
19
+ """Initialize the LLM interface.
20
+
21
+ Args:
22
+ model_name: Name of the model to use
23
+ max_length: Maximum length of generated text
24
+ temperature: Controls randomness (higher = more random)
25
+ """
26
+ self.model_name = model_name
27
+ self.max_length = max_length
28
+ self.temperature = temperature
29
+ self.model_loaded = self._check_llm_installed()
30
+ self.fallback_responses = [
31
+ "I'm not sure how to respond to that.",
32
+ "That's interesting. Tell me more.",
33
+ "I'd like to talk about that further.",
34
+ "I appreciate you sharing that with me.",
35
+ "Could we talk about something else?",
36
+ "I need some time to think about that.",
37
+ ]
38
+
39
+ def _check_llm_installed(self) -> bool:
40
+ """Check if the LLM tool is installed and working."""
41
+ try:
42
+ result = subprocess.run(
43
+ ["llm", "--version"],
44
+ capture_output=True,
45
+ text=True,
46
+ timeout=5, # Add a timeout to prevent hanging
47
+ )
48
+ if result.returncode == 0:
49
+ print(f"LLM tool is installed: {result.stdout.strip()}")
50
+
51
+ # Also check if the model exists
52
+ try:
53
+ # Just check if the model is in the list of available models
54
+ model_check = subprocess.run(
55
+ ["llm", "models"],
56
+ capture_output=True,
57
+ text=True,
58
+ timeout=5,
59
+ )
60
+
61
+ if model_check.returncode == 0:
62
+ if self.model_name in model_check.stdout:
63
+ print(f"Model {self.model_name} is available")
64
+ return True
65
+ else:
66
+ print(
67
+ f"Model {self.model_name} not found in available models"
68
+ )
69
+ # Try to find similar models
70
+ if "gemini" in self.model_name.lower():
71
+ print("Available Gemini models:")
72
+ for line in model_check.stdout.splitlines():
73
+ if "gemini" in line.lower():
74
+ print(f" {line}")
75
+ return False
76
+ else:
77
+ print("Error checking available models")
78
+ return False
79
+
80
+ except Exception as model_error:
81
+ print(f"Error checking model availability: {model_error}")
82
+ return False
83
+ else:
84
+ print("LLM tool returned an error.")
85
+ return False
86
+ except subprocess.TimeoutExpired:
87
+ print("Timeout checking LLM tool installation")
88
+ return False
89
+ except Exception as e:
90
+ print(f"Error checking LLM tool: {e}")
91
+ return False
92
+
93
+ def _get_max_tokens_param(self) -> str:
94
+ """Get the appropriate max tokens parameter name for the model."""
95
+ if "gemini" in self.model_name.lower():
96
+ return "max_output_tokens"
97
+ else:
98
+ return "max_tokens"
99
+
100
+ def generate_suggestion(
101
+ self,
102
+ person_context: Dict[str, Any],
103
+ user_input: Optional[str] = None,
104
+ temperature: Optional[float] = None,
105
+ progress_callback=None,
106
+ ) -> str:
107
+ """Generate a suggestion based on the person context and user input.
108
+
109
+ Args:
110
+ person_context: Context information about the person
111
+ user_input: Optional user input to consider
112
+ temperature: Controls randomness in generation (higher = more random)
113
+ progress_callback: Optional callback function to report progress
114
+
115
+ Returns:
116
+ A generated suggestion string
117
+ """
118
+ if not self.model_loaded:
119
+ import random
120
+
121
+ return random.choice(self.fallback_responses)
122
+
123
+ # Extract context information
124
+ name = person_context.get("name", "")
125
+ role = person_context.get("role", "")
126
+ topics = person_context.get("topics", [])
127
+ context = person_context.get("context", "")
128
+ selected_topic = person_context.get("selected_topic", "")
129
+ common_phrases = person_context.get("common_phrases", [])
130
+ frequency = person_context.get("frequency", "")
131
+ mood = person_context.get("mood", 3) # Default to neutral mood (3)
132
+
133
+ # Get mood description
134
+ mood_descriptions = {
135
+ 1: "I'm feeling quite down and sad today. My responses might be more subdued.",
136
+ 2: "I'm feeling a bit low today. I might be less enthusiastic than usual.",
137
+ 3: "I'm feeling okay today - neither particularly happy nor sad.",
138
+ 4: "I'm feeling pretty good today. I'm in a positive mood.",
139
+ 5: "I'm feeling really happy and upbeat today! I'm in a great mood.",
140
+ }
141
+ mood_description = mood_descriptions.get(mood, mood_descriptions[3])
142
+
143
+ # Build enhanced prompt
144
+ prompt = f"""I am Will, a 38-year-old with MND (Motor Neuron Disease) from Manchester.
145
+ I am talking to {name}, who is my {role}.
146
+ About {name}: {context}
147
+ We typically talk about: {', '.join(topics)}
148
+ We communicate {frequency}.
149
+
150
+ My current mood: {mood_description}
151
+ """
152
+
153
+ # Add communication style based on relationship
154
+ if role in ["wife", "son", "daughter", "mother", "father"]:
155
+ prompt += "I communicate with my family in a warm, loving way, sometimes using inside jokes.\n"
156
+ elif role in ["doctor", "therapist", "nurse"]:
157
+ prompt += "I communicate with healthcare providers in a direct, informative way.\n"
158
+ elif role in ["best mate", "friend"]:
159
+ prompt += "I communicate with friends casually, often with humor and sometimes swearing.\n"
160
+ elif role in ["work colleague", "boss"]:
161
+ prompt += (
162
+ "I communicate with colleagues professionally but still friendly.\n"
163
+ )
164
+
165
+ # Add topic information if provided
166
+ if selected_topic:
167
+ prompt += f"\nWe are currently discussing {selected_topic}.\n"
168
+
169
+ # Add the user's message if provided, or set up for conversation initiation
170
+ if user_input:
171
+ # If user input is provided, we're responding to something
172
+ prompt += f'\n{name} just said to me: "{user_input}"\n'
173
+ prompt += f"I want to respond directly to what {name} just said.\n"
174
+ else:
175
+ # No user input means we're initiating a conversation
176
+ if selected_topic:
177
+ # If a topic is selected, initiate conversation about that topic
178
+ prompt += f"\nI'm about to start a conversation with {name} about {selected_topic}.\n"
179
+ prompt += f"I want to initiate a conversation about {selected_topic} in a natural way.\n"
180
+ else:
181
+ # Generic conversation starter
182
+ prompt += f"\nI'm about to start a conversation with {name}.\n"
183
+ prompt += "I want to initiate a conversation in a natural way based on our relationship.\n"
184
+
185
+ # Add the response prompt with specific guidance
186
+ if user_input:
187
+ # Responding to something
188
+ prompt += f"""
189
+ I am Will, the person with MND. I want to respond to {name}'s message: "{user_input}"
190
+ My response should be natural, brief (1-2 sentences), and directly relevant to what {name} just said.
191
+ I'll use language appropriate for our relationship and speak as myself (Will).
192
+
193
+ My response to {name}:"""
194
+ else:
195
+ # Initiating a conversation
196
+ prompt += f"""
197
+ I am Will, the person with MND. I want to start a conversation with {name}.
198
+ My conversation starter should be natural, brief (1-2 sentences), and appropriate for our relationship.
199
+ I'll speak in first person as myself (Will).
200
+
201
+ My conversation starter to {name}:"""
202
+
203
+ # Use the provided temperature or default
204
+ temp = temperature if temperature is not None else self.temperature
205
+
206
+ # Update progress if callback provided
207
+ if progress_callback:
208
+ progress_callback(0.3, desc="Sending prompt to LLM...")
209
+
210
+ try:
211
+ # Get the appropriate max tokens parameter
212
+ max_tokens_param = self._get_max_tokens_param()
213
+
214
+ # Call the LLM tool
215
+ result = subprocess.run(
216
+ [
217
+ "llm",
218
+ "-m",
219
+ self.model_name,
220
+ "-s",
221
+ f"temperature={temp}",
222
+ "-s",
223
+ f"{max_tokens_param}={self.max_length}",
224
+ prompt,
225
+ ],
226
+ capture_output=True,
227
+ text=True,
228
+ timeout=15, # Add timeout to prevent hanging
229
+ )
230
+
231
+ if progress_callback:
232
+ progress_callback(0.7, desc="Processing response...")
233
+
234
+ if result.returncode == 0:
235
+ # Get the generated text
236
+ generated = result.stdout.strip()
237
+
238
+ # Clean up the response if needed
239
+ if not generated:
240
+ generated = "I'm not sure what to say about that."
241
+
242
+ if progress_callback:
243
+ progress_callback(0.9, desc="Response generated successfully")
244
+
245
+ return generated
246
+ else:
247
+ print(f"Error from LLM tool: {result.stderr}")
248
+ if progress_callback:
249
+ progress_callback(0.9, desc="Error generating response")
250
+ return "I'm having trouble responding to that right now."
251
+ except subprocess.TimeoutExpired:
252
+ print("LLM generation timed out")
253
+ if progress_callback:
254
+ progress_callback(0.9, desc="Generation timed out")
255
+ return "I need more time to think about that."
256
+ except Exception as e:
257
+ print(f"Error generating with LLM tool: {e}")
258
+ if progress_callback:
259
+ progress_callback(0.9, desc="Error generating response")
260
+ return "I'm having trouble responding to that."
261
+
262
+ def generate_multiple_suggestions(
263
+ self,
264
+ person_context: Dict[str, Any],
265
+ user_input: Optional[str] = None,
266
+ num_suggestions: int = 3,
267
+ temperature: Optional[float] = None,
268
+ progress_callback=None,
269
+ ) -> List[str]:
270
+ """Generate multiple suggestions.
271
+
272
+ Args:
273
+ person_context: Context information about the person
274
+ user_input: Optional user input to consider
275
+ num_suggestions: Number of suggestions to generate
276
+ temperature: Controls randomness in generation
277
+ progress_callback: Optional callback function to report progress
278
+
279
+ Returns:
280
+ A list of generated suggestions
281
+ """
282
+ suggestions = []
283
+
284
+ for i in range(num_suggestions):
285
+ if progress_callback:
286
+ progress_callback(
287
+ 0.1 + (i * 0.3),
288
+ desc=f"Generating suggestion {i+1}/{num_suggestions}",
289
+ )
290
+
291
+ # Vary temperature slightly for each suggestion to increase diversity
292
+ temp_variation = 0.05 * (i - 1) # -0.05, 0, 0.05
293
+ temp = (
294
+ temperature if temperature is not None else self.temperature
295
+ ) + temp_variation
296
+
297
+ suggestion = self.generate_suggestion(
298
+ person_context,
299
+ user_input,
300
+ temperature=temp,
301
+ progress_callback=lambda p, desc: (
302
+ progress_callback(0.1 + (i * 0.3) + (p * 0.3), desc=desc)
303
+ if progress_callback
304
+ else None
305
+ ),
306
+ )
307
+
308
+ suggestions.append(suggestion)
309
+
310
+ # Small delay to ensure UI updates
311
+ time.sleep(0.2)
312
+
313
+ return suggestions
314
+
315
+ def test_model(self) -> str:
316
+ """Test if the model is working correctly."""
317
+ if not self.model_loaded:
318
+ return "LLM tool not available"
319
+
320
+ try:
321
+ # Create a simple test prompt
322
+ test_prompt = "Say hello in one word."
323
+
324
+ # Call the LLM tool
325
+ result = subprocess.run(
326
+ [
327
+ "llm",
328
+ "-m",
329
+ self.model_name,
330
+ "-s",
331
+ "temperature=0.7",
332
+ test_prompt,
333
+ ],
334
+ capture_output=True,
335
+ text=True,
336
+ timeout=10,
337
+ )
338
+
339
+ if result.returncode == 0:
340
+ response = result.stdout.strip()
341
+ return f"LLM test successful: {response}"
342
+ else:
343
+ return f"LLM test failed: {result.stderr}"
344
+ except Exception as e:
345
+ return f"LLM test error: {str(e)}"
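Because LLMInterface shells out to Simon Willison's `llm` CLI, it can be exercised directly from a Python shell without Gradio. A minimal usage sketch, assuming the `llm` CLI is installed and the chosen model appears in `llm models`; the person_context values are illustrative only:

```python
# Minimal usage sketch for llm_interface.LLMInterface. Assumes the `llm` CLI
# is installed and the selected model is listed by `llm models`.
from llm_interface import LLMInterface

generator = LLMInterface("gemini-1.5-flash-latest")
print(generator.test_model())  # success message or a fallback/error string

# These keys mirror what generate_suggestion() reads from the social graph;
# the values here are illustrative only.
person_context = {
    "name": "Billy",
    "role": "son",
    "topics": ["football", "school"],
    "context": "Billy is Will's 7-year-old son.",
    "frequency": "daily",
    "mood": 4,
}

suggestions = generator.generate_multiple_suggestions(
    person_context=person_context,
    user_input="Did you watch the match last night?",
    num_suggestions=3,
    temperature=0.7,
)
for i, suggestion in enumerate(suggestions, 1):
    print(f"{i}. {suggestion}")
```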
print_test.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ print("Hello, world!")
2
+ print("This is a test script.")
3
+ print("If you can see this, the terminal output is working.")
requirements.txt CHANGED
@@ -6,3 +6,4 @@ numpy>=1.24.0
6
  openai-whisper>=20231117
7
  bitsandbytes>=0.41.0
8
  accelerate>=0.21.0
 
 
6
  openai-whisper>=20231117
7
  bitsandbytes>=0.41.0
8
  accelerate>=0.21.0
9
+ google-generativeai>=0.3.0
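The new google-generativeai requirement backs the `gemini-api:` code path added to utils.py. A minimal sketch of the underlying call that path wraps, assuming a GEMINI_API_KEY environment variable is set; the prompt is illustrative only:

```python
# Sketch of the google-generativeai call wrapped by the "gemini-api:" path
# in utils.py. Assumes GEMINI_API_KEY is set in the environment.
import os
import google.generativeai as genai

genai.configure(api_key=os.environ["GEMINI_API_KEY"])
model = genai.GenerativeModel("gemini-1.5-flash")

response = model.generate_content(
    "Say hello in one word.",
    generation_config={"temperature": 0.7, "max_output_tokens": 100},
)
print(response.text)
```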
test_app.py ADDED
@@ -0,0 +1,54 @@
1
+ import sys
2
+ import os
3
+
4
+ print("Starting test...")
5
+
6
+ # Test importing the modules
7
+ try:
8
+ import gradio as gr
9
+ import whisper
10
+ import random
11
+ import time
12
+ from utils import SocialGraphManager, SuggestionGenerator
13
+ print("All modules imported successfully")
14
+ except Exception as e:
15
+ print(f"Error importing modules: {e}")
16
+ sys.exit(1)
17
+
18
+ # Test loading the social graph
19
+ try:
20
+ social_graph = SocialGraphManager("social_graph.json")
21
+ print("Social graph loaded successfully")
22
+ except Exception as e:
23
+ print(f"Error loading social graph: {e}")
24
+ sys.exit(1)
25
+
26
+ # Test initializing the suggestion generator
27
+ try:
28
+ suggestion_generator = SuggestionGenerator("distilgpt2") # Use a simpler model for testing
29
+ print("Suggestion generator initialized successfully")
30
+ except Exception as e:
31
+ print(f"Error initializing suggestion generator: {e}")
32
+ sys.exit(1)
33
+
34
+ # Test getting people from the social graph
35
+ try:
36
+ people = social_graph.get_people_list()
37
+ print(f"Found {len(people)} people in the social graph")
38
+ if people:
39
+ print(f"First person: {people[0]['name']} ({people[0]['role']})")
40
+ except Exception as e:
41
+ print(f"Error getting people from social graph: {e}")
42
+ sys.exit(1)
43
+
44
+ # Test getting person context
45
+ try:
46
+ if people:
47
+ person_id = people[0]['id']
48
+ person_context = social_graph.get_person_context(person_id)
49
+ print(f"Got context for {person_context.get('name', 'unknown')}")
50
+ except Exception as e:
51
+ print(f"Error getting person context: {e}")
52
+ sys.exit(1)
53
+
54
+ print("All tests passed successfully!")
utils.py CHANGED
@@ -1,9 +1,10 @@
1
  import json
2
  import random
3
- from typing import Dict, List, Any, Optional, Tuple
 
 
4
  from sentence_transformers import SentenceTransformer
5
  import numpy as np
6
-
7
  from transformers import pipeline
8
 
9
 
@@ -186,10 +187,10 @@ class SuggestionGenerator:
186
  ]
187
 
188
  def load_model(self, model_name: str) -> bool:
189
- """Load a Hugging Face model.
190
 
191
  Args:
192
- model_name: Name of the HuggingFace model to use
193
 
194
  Returns:
195
  bool: True if model loaded successfully, False otherwise
@@ -204,8 +205,48 @@ class SuggestionGenerator:
204
  self.model_loaded = True
205
  return True
206
 
207
  try:
208
- print(f"Loading model: {model_name}")
209
 
210
  # Check if this is a gated model that requires authentication
211
  is_gated_model = any(
@@ -217,7 +258,9 @@ class SuggestionGenerator:
217
  # Try to get token from environment
218
  import os
219
  import torch
 
220
  from transformers import BitsAndBytesConfig
 
221
 
222
  token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get(
223
  "HF_TOKEN"
@@ -232,56 +275,138 @@ class SuggestionGenerator:
232
  # Explicitly pass token to pipeline
233
  from transformers import AutoTokenizer, AutoModelForCausalLM
234
 
235
- try:
236
- # Configure 4-bit quantization to save memory
237
- quantization_config = BitsAndBytesConfig(
238
- load_in_4bit=True,
239
- bnb_4bit_compute_dtype=torch.float16,
240
- bnb_4bit_quant_type="nf4",
241
- bnb_4bit_use_double_quant=True,
242
- )
243
 
244
- tokenizer = AutoTokenizer.from_pretrained(
245
- model_name, token=token
246
- )
 
 
247
 
248
- # Load model with quantization
249
- model = AutoModelForCausalLM.from_pretrained(
250
- model_name,
251
- token=token,
252
- quantization_config=quantization_config,
253
- device_map="auto",
254
- )
 
255
 
256
- self.generator = pipeline(
257
- "text-generation",
258
- model=model,
259
- tokenizer=tokenizer,
260
- torch_dtype=torch.float16,
261
- )
262
- except Exception as e:
263
- print(f"Error loading gated model with token: {e}")
264
- print(
265
- "This may be due to not having accepted the model license or insufficient permissions."
266
- )
267
- print(
268
- "Please visit the model page on Hugging Face Hub and accept the license."
269
- )
270
  # Try loading without quantization as fallback
271
  try:
272
- print("Trying to load model without quantization...")
 
 
273
  tokenizer = AutoTokenizer.from_pretrained(
274
- model_name, token=token
275
  )
276
  model = AutoModelForCausalLM.from_pretrained(
277
- model_name, token=token
 
 
 
278
  )
279
- self.generator = pipeline(
280
- "text-generation", model=model, tokenizer=tokenizer
281
  )
282
  except Exception as e2:
283
  print(f"Fallback loading also failed: {e2}")
284
- raise e
 
 
 
 
 
285
  else:
286
  print("No Hugging Face token found in environment variables.")
287
  print(
@@ -297,7 +422,12 @@ class SuggestionGenerator:
297
  raise ValueError("Authentication token required for gated model")
298
  else:
299
  # For non-gated models, use the standard pipeline
300
- self.generator = pipeline("text-generation", model=model_name)
 
 
 
 
 
301
 
302
  # Cache the loaded model
303
  self.loaded_models[model_name] = self.generator
@@ -310,6 +440,71 @@ class SuggestionGenerator:
310
  self.model_loaded = False
311
  return False
312
 
313
  def _get_mood_description(self, mood_value: int) -> str:
314
  """Convert mood value (1-5) to a descriptive string.
315
 
@@ -336,16 +531,132 @@ class SuggestionGenerator:
336
  return "Model not loaded"
337
 
338
  try:
339
- test_prompt = "I am Will. My son Billy asked about football. I respond:"
340
  print(f"Testing model with prompt: {test_prompt}")
341
- response = self.generator(test_prompt, max_new_tokens=30, do_sample=True)
342
- full_text = response[0]["generated_text"]
343
- if len(test_prompt) < len(full_text):
344
- result = full_text[len(test_prompt) :]
 
345
  else:
346
- result = "No additional text generated"
347
- print(f"Test response: {result}")
348
- return f"Model test successful: {result}"
  except Exception as e:
350
  print(f"Error testing model: {e}")
351
  return f"Model test failed: {str(e)}"
@@ -486,14 +797,42 @@ My current mood: {self._get_mood_description(mood)}
486
  for marker in ["-it", "instruct", "chat", "phi-3", "phi-2"]
487
  )
488
 
489
- if is_instruction_model:
490
  # Use instruction format for instruction-tuned models
491
  if user_input:
492
  # Responding to something
493
  prompt += f"""
494
  <instruction>
495
- Respond to {name} in a way that is natural, brief (1-2 sentences), and directly relevant to what they just said.
496
- Use language appropriate for our relationship.
 
 
497
  </instruction>
498
 
499
  My response to {name}:"""
@@ -501,55 +840,183 @@ My response to {name}:"""
501
  # Initiating a conversation
502
  prompt += f"""
503
  <instruction>
504
- Start a conversation with {name} in a natural, brief (1-2 sentences) way.
505
- Use language appropriate for our relationship.
506
- If a topic was selected, focus on that topic.
 
507
  </instruction>
508
 
509
  My conversation starter to {name}:"""
510
  else:
511
- # Use standard format for non-instruction models
512
  if user_input:
513
  # Responding to something
514
  prompt += f"""
515
- I want to respond to {name} in a way that is natural, brief (1-2 sentences), and directly relevant to what they just said. I'll use language appropriate for our relationship.
 
 
516
 
517
  My response to {name}:"""
518
  else:
519
  # Initiating a conversation
520
  prompt += f"""
521
- I want to start a conversation with {name} in a natural, brief (1-2 sentences) way. I'll use language appropriate for our relationship.
 
 
522
 
523
  My conversation starter to {name}:"""
524
 
525
  # Generate suggestion
526
  try:
527
  print(f"Generating suggestion with prompt: {prompt}")
528
- # Use max_new_tokens instead of max_length to avoid the error
529
- response = self.generator(
530
- prompt,
531
- max_new_tokens=100, # Generate more tokens to ensure we get a response
532
- temperature=temperature,
533
- do_sample=True,
534
- top_p=0.92,
535
- top_k=50,
536
- # Only use truncation if we're providing a max_length
537
- truncation=False,
538
- )
539
- # Extract only the generated part, not the prompt
540
- full_text = response[0]["generated_text"]
541
- print(f"Full generated text length: {len(full_text)}")
542
- print(f"Prompt length: {len(prompt)}")
543
-
544
- # Make sure we're not trying to slice beyond the text length
545
- if len(prompt) < len(full_text):
546
- result = full_text[len(prompt) :]
547
- print(f"Generated response: {result}")
548
549
  else:
550
- # If the model didn't generate anything beyond the prompt
551
- print("Model didn't generate text beyond prompt")
552
553
  except Exception as e:
554
  print(f"Error generating suggestion: {e}")
555
  return "Could not generate a suggestion. Please try again."
 
1
  import json
2
  import random
3
+ import threading
4
+ import time
5
+ from typing import Dict, List, Any, Optional
6
  from sentence_transformers import SentenceTransformer
7
  import numpy as np
 
8
  from transformers import pipeline
9
 
10
 
 
187
  ]
188
 
189
  def load_model(self, model_name: str) -> bool:
190
+ """Load a model (either Hugging Face model or API-based model).
191
 
192
  Args:
193
+ model_name: Name of the model to use (HuggingFace model name or API identifier)
194
 
195
  Returns:
196
  bool: True if model loaded successfully, False otherwise
 
205
  self.model_loaded = True
206
  return True
207
 
208
+ # Check if this is a Gemini API model
209
+ if model_name.startswith("gemini-api:"):
210
+ try:
211
+ import os
212
+ import google.generativeai as genai
213
+
214
+ # Get API key from environment
215
+ api_key = os.environ.get("GEMINI_API_KEY")
216
+ if not api_key:
217
+ print("No GEMINI_API_KEY found in environment variables.")
218
+ print("Please set the GEMINI_API_KEY environment variable.")
219
+ return False
220
+
221
+ # Configure the Gemini API
222
+ genai.configure(api_key=api_key)
223
+
224
+ # Extract the specific model name after the prefix
225
+ gemini_model = model_name.split(":", 1)[1]
226
+ print(f"Using Gemini API with model: {gemini_model}")
227
+
228
+ # Store the model name and API client in the generator
229
+ self.generator = {
230
+ "type": "gemini-api",
231
+ "model": gemini_model,
232
+ "client": genai,
233
+ }
234
+
235
+ # Cache the API client
236
+ self.loaded_models[model_name] = self.generator
237
+
238
+ self.model_loaded = True
239
+ print(f"Gemini API configured successfully for model: {gemini_model}")
240
+ return True
241
+
242
+ except Exception as e:
243
+ print(f"Error configuring Gemini API: {e}")
244
+ self.model_loaded = False
245
+ return False
246
+
247
+ # Otherwise, try to load a Hugging Face model
248
  try:
249
+ print(f"Loading Hugging Face model: {model_name}")
250
 
251
  # Check if this is a gated model that requires authentication
252
  is_gated_model = any(
 
258
  # Try to get token from environment
259
  import os
260
  import torch
261
+ import time
262
  from transformers import BitsAndBytesConfig
263
+ from requests.exceptions import ConnectionError, Timeout, HTTPError
264
 
265
  token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get(
266
  "HF_TOKEN"
 
275
  # Explicitly pass token to pipeline
276
  from transformers import AutoTokenizer, AutoModelForCausalLM
277
 
278
+ # Implement retry mechanism for network issues
279
+ max_retries = 3
280
+ retry_delay = 2 # seconds
 
 
 
 
 
281
 
282
+ for attempt in range(max_retries):
283
+ try:
284
+ print(
285
+ f"Attempt {attempt+1}/{max_retries} to load model: {model_name}"
286
+ )
287
 
288
+ # First try to load just the tokenizer to check connectivity
289
+ print(f"Loading tokenizer for {model_name}...")
290
+ tokenizer = AutoTokenizer.from_pretrained(
291
+ model_name,
292
+ token=token,
293
+ use_fast=True,
294
+ local_files_only=False,
295
+ )
296
+ print(f"Tokenizer loaded successfully for {model_name}")
297
+
298
+ # Configure 4-bit quantization to save memory
299
+ print("Configuring quantization settings...")
300
+ quantization_config = BitsAndBytesConfig(
301
+ load_in_4bit=True,
302
+ bnb_4bit_compute_dtype=torch.float16,
303
+ bnb_4bit_quant_type="nf4",
304
+ bnb_4bit_use_double_quant=True,
305
+ )
306
 
307
+ # Load model with quantization
308
+ print(f"Loading model {model_name} with quantization...")
309
+ model = AutoModelForCausalLM.from_pretrained(
310
+ model_name,
311
+ token=token,
312
+ quantization_config=quantization_config,
313
+ device_map="auto",
314
+ low_cpu_mem_usage=True,
315
+ )
316
+ print(
317
+ f"Model {model_name} loaded successfully with quantization"
318
+ )
319
+
320
+ # Create pipeline
321
+ print("Creating text generation pipeline...")
322
+ self.generator = {
323
+ "type": "huggingface",
324
+ "pipeline": pipeline(
325
+ "text-generation",
326
+ model=model,
327
+ tokenizer=tokenizer,
328
+ torch_dtype=torch.float16,
329
+ ),
330
+ }
331
+ print("Pipeline created successfully")
332
+
333
+ # If we got here, loading succeeded
334
+ break
335
+
336
+ except (ConnectionError, Timeout, HTTPError) as network_error:
337
+ # Handle network-related errors with retries
338
+ print(
339
+ f"Network error loading model (attempt {attempt+1}/{max_retries}): {network_error}"
340
+ )
341
+ if attempt < max_retries - 1:
342
+ print(f"Retrying in {retry_delay} seconds...")
343
+ time.sleep(retry_delay)
344
+ retry_delay *= 2 # Exponential backoff
345
+ else:
346
+ print(
347
+ "Maximum retries reached, falling back to alternative loading method"
348
+ )
349
+ raise network_error
350
+
351
+ except (RuntimeError, ValueError, OSError) as e:
352
+ # Handle memory errors or other issues
353
+ print(
354
+ f"Error loading gated model with token (attempt {attempt+1}/{max_retries}): {e}"
355
+ )
356
+ print(
357
+ "This may be due to memory limitations, network issues, or insufficient permissions."
358
+ )
359
+
360
+ if "CUDA out of memory" in str(
361
+ e
362
+ ) or "DefaultCPUAllocator" in str(e):
363
+ print(
364
+ "Memory error detected. Trying with more aggressive memory optimization..."
365
+ )
366
+ break # Skip to non-quantized version with CPU offloading
367
+
368
+ if attempt < max_retries - 1:
369
+ print(f"Retrying in {retry_delay} seconds...")
370
+ time.sleep(retry_delay)
371
+ retry_delay *= 2 # Exponential backoff
372
+ else:
373
+ print(
374
+ "Maximum retries reached, falling back to alternative loading method"
375
+ )
376
+
377
+ # If the loop completed without success, try alternative loading methods
378
+ if not hasattr(self, "generator") or self.generator is None:
379
  # Try loading without quantization as fallback
380
  try:
381
+ print(
382
+ "Trying to load model without quantization (CPU only)..."
383
+ )
384
  tokenizer = AutoTokenizer.from_pretrained(
385
+ model_name, token=token, use_fast=True
386
  )
387
  model = AutoModelForCausalLM.from_pretrained(
388
+ model_name,
389
+ token=token,
390
+ device_map="cpu",
391
+ low_cpu_mem_usage=True,
392
  )
393
+ self.generator = {
394
+ "type": "huggingface",
395
+ "pipeline": pipeline(
396
+ "text-generation", model=model, tokenizer=tokenizer
397
+ ),
398
+ }
399
+ print(
400
+ "Successfully loaded model on CPU without quantization"
401
  )
402
  except Exception as e2:
403
  print(f"Fallback loading also failed: {e2}")
404
+ print(
405
+ "All loading attempts failed. Please try a different model or check your connection."
406
+ )
407
+ raise RuntimeError(
408
+ f"Failed to load model after multiple attempts: {str(e2)}"
409
+ )
410
  else:
411
  print("No Hugging Face token found in environment variables.")
412
  print(
 
422
  raise ValueError("Authentication token required for gated model")
423
  else:
424
  # For non-gated models, use the standard pipeline
425
+ from transformers import pipeline
426
+
427
+ self.generator = {
428
+ "type": "huggingface",
429
+ "pipeline": pipeline("text-generation", model=model_name),
430
+ }
431
 
432
  # Cache the loaded model
433
  self.loaded_models[model_name] = self.generator
 
440
  self.model_loaded = False
441
  return False
442
 
443
+ def _clean_small_model_response(self, response: str) -> str:
444
+ """Clean up responses from small models that often repeat instructions or generate nonsense.
445
+
446
+ Args:
447
+ response: The raw response from the model
448
+
449
+ Returns:
450
+ A cleaned response
451
+ """
452
+ # If response is too short, return as is
453
+ if len(response) < 5:
454
+ return response
455
+
456
+ # Remove common instruction repetitions
457
+ patterns_to_remove = [
458
+ "I want to respond to what",
459
+ "I'll use language appropriate for our relationship",
460
+ "I should speak in first person",
461
+ "I should use language appropriate",
462
+ "I want to respond directly",
463
+ "I'll speak as myself",
464
+ "I want to initiate a conversation",
465
+ "My response should be natural",
466
+ "My response to",
467
+ "Will's response to",
468
+ "Will says to",
469
+ ]
470
+
471
+ # Check for and remove these patterns
472
+ cleaned_response = response
473
+ for pattern in patterns_to_remove:
474
+ if pattern in cleaned_response:
475
+ # Find the first occurrence and remove everything from there
476
+ index = cleaned_response.find(pattern)
477
+ if index > 10: # Keep some beginning text if available
478
+ cleaned_response = cleaned_response[:index].strip()
479
+ else:
480
+ # If pattern is at the beginning, remove just that pattern
481
+ parts = cleaned_response.split(pattern, 1)
482
+ if len(parts) > 1:
483
+ cleaned_response = parts[1].strip()
484
+
485
+ # Remove any lines that are just the name repeated
486
+ lines = cleaned_response.split("\n")
487
+ cleaned_lines = []
488
+ for line in lines:
489
+ # Skip lines that are just a name repeated
490
+ if line.strip() and not all(
491
+ word == line.split()[0] for word in line.split()
492
+ ):
493
+ cleaned_lines.append(line)
494
+
495
+ cleaned_response = "\n".join(cleaned_lines).strip()
496
+
497
+ # If we've removed too much, use a fallback
498
+ if len(cleaned_response) < 5:
499
+ return "I'm not sure what to say about that."
500
+
501
+ # Limit to first 2 sentences to avoid rambling
502
+ sentences = cleaned_response.split(".")
503
+ if len(sentences) > 2:
504
+ cleaned_response = ".".join(sentences[:2]) + "."
505
+
506
+ return cleaned_response
507
+
508
  def _get_mood_description(self, mood_value: int) -> str:
509
  """Convert mood value (1-5) to a descriptive string.
510
 
 
531
  return "Model not loaded"
532
 
533
  try:
534
+ # Create a more explicit test prompt that clearly establishes Will's identity and role
535
+ test_prompt = """I am Will, a 38-year-old with MND (Motor Neuron Disease).
536
+ I am talking to my 7-year-old son Billy.
537
+ Billy just asked me about football.
538
+ I want to respond to Billy in a natural, brief way.
539
+
540
+ My response to Billy:"""
541
  print(f"Testing model with prompt: {test_prompt}")
542
+
543
+ # Check if we're using the Gemini API or a Hugging Face model
544
+ if (
545
+ isinstance(self.generator, dict)
546
+ and self.generator.get("type") == "gemini-api"
547
+ ):
548
+ try:
549
+ # Use Gemini API
550
+ genai = self.generator["client"]
551
+ model_name = self.generator["model"]
552
+
553
+ # Create a generative model
554
+ model = genai.GenerativeModel(model_name)
555
+
556
+ # Generate content with timeout
557
+ print("Sending test request to Gemini API...")
558
+
559
+ # Set a timeout for the test
560
+ import threading
561
+ import time
562
+
563
+ result = ["No response received yet"]
564
+ generation_complete = [False]
565
+
566
+ def generate_with_timeout():
567
+ try:
568
+ print("Starting Gemini API test request...")
569
+ response = model.generate_content(test_prompt)
570
+ print(f"Received response from Gemini API: {response}")
571
+
572
+ if response and hasattr(response, "text"):
573
+ result[0] = response.text
574
+ print(f"Extracted text from response: {result[0]}")
575
+ else:
576
+ result[0] = "No text in Gemini API response"
577
+ print("Response object has no text attribute")
578
+
579
+ generation_complete[0] = True
580
+ except Exception as e:
581
+ print(f"Error in Gemini test generation: {e}")
582
+ result[0] = f"Error: {str(e)}"
583
+ generation_complete[0] = True
584
+
585
+ # Start generation in a separate thread
586
+ generation_thread = threading.Thread(target=generate_with_timeout)
587
+ generation_thread.daemon = True
588
+ generation_thread.start()
589
+
590
+ # Wait for up to 10 seconds
591
+ timeout = 10
592
+ start_time = time.time()
593
+ while (
594
+ not generation_complete[0]
595
+ and time.time() - start_time < timeout
596
+ ):
597
+ print(
598
+ f"Waiting for Gemini API response... ({int(time.time() - start_time)}s)"
599
+ )
600
+ time.sleep(1)
601
+
602
+ if not generation_complete[0]:
603
+ print("Gemini API test request timed out")
604
+ return "Gemini API test timed out after 10 seconds"
605
+
606
+ print(f"Test response from Gemini API: {result[0]}")
607
+ return f"Gemini API test successful: {result[0]}"
608
+ except Exception as e:
609
+ print(f"Error testing Gemini API: {e}")
610
+ return f"Gemini API test failed: {str(e)}"
611
+
612
+ elif (
613
+ isinstance(self.generator, dict)
614
+ and self.generator.get("type") == "huggingface"
615
+ ):
616
+ # Use Hugging Face pipeline
617
+ pipeline = self.generator["pipeline"]
618
+ response = pipeline(test_prompt, max_new_tokens=30, do_sample=True)
619
+ full_text = response[0]["generated_text"]
620
+
621
+ if len(test_prompt) < len(full_text):
622
+ result = full_text[len(test_prompt) :].strip()
623
+
624
+ # Check if this is a small model that needs cleaning
625
+ is_small_model = any(
626
+ name in self.model_name.lower()
627
+ for name in ["distilgpt2", "gpt2-small", "tiny"]
628
+ )
629
+ if is_small_model:
630
+ result = self._clean_small_model_response(result)
631
+ else:
632
+ result = "No additional text generated"
633
+
634
+ print(f"Test response from Hugging Face: {result}")
635
+ return f"Hugging Face model test successful: {result}"
636
+
637
  else:
638
+ # Legacy format (for backward compatibility)
639
+ response = self.generator(
640
+ test_prompt, max_new_tokens=30, do_sample=True
641
+ )
642
+ full_text = response[0]["generated_text"]
643
+
644
+ if len(test_prompt) < len(full_text):
645
+ result = full_text[len(test_prompt) :].strip()
646
+
647
+ # Check if this is a small model that needs cleaning
648
+ is_small_model = any(
649
+ name in self.model_name.lower()
650
+ for name in ["distilgpt2", "gpt2-small", "tiny"]
651
+ )
652
+ if is_small_model:
653
+ result = self._clean_small_model_response(result)
654
+ else:
655
+ result = "No additional text generated"
656
+
657
+ print(f"Test response: {result}")
658
+ return f"Model test successful: {result}"
659
+
660
  except Exception as e:
661
  print(f"Error testing model: {e}")
662
  return f"Model test failed: {str(e)}"
 
797
  for marker in ["-it", "instruct", "chat", "phi-3", "phi-2"]
798
  )
799
 
800
+ # Check if this is a very small model that needs simpler prompts
801
+ is_small_model = any(
802
+ name in self.model_name.lower()
803
+ for name in ["distilgpt2", "gpt2-small", "tiny"]
804
+ )
805
+
806
+ if is_small_model:
807
+ # Use a much simpler format for very small models
808
+ if user_input:
809
+ # Responding to something
810
+ prompt += f"""
811
+ {name} said: "{user_input}"
812
+
813
+ Will's response:"""
814
+ else:
815
+ # Initiating a conversation
816
+ if selected_topic:
817
+ prompt += f"""
818
+ Will starts a conversation with {name} about {selected_topic}.
819
+
820
+ Will says:"""
821
+ else:
822
+ prompt += f"""
823
+ Will starts a conversation with {name}.
824
+
825
+ Will says:"""
826
+ elif is_instruction_model:
827
  # Use instruction format for instruction-tuned models
828
  if user_input:
829
  # Responding to something
830
  prompt += f"""
831
  <instruction>
832
+ I am Will, the person with MND. I need to respond to {name}'s message: "{user_input}"
833
+ My response should be natural, brief (1-2 sentences), and directly relevant to what {name} just said.
834
+ I should use language appropriate for our relationship.
835
+ I should speak in first person as myself (Will).
836
  </instruction>
837
 
838
  My response to {name}:"""
 
840
  # Initiating a conversation
841
  prompt += f"""
842
  <instruction>
843
+ I am Will, the person with MND. I need to start a conversation with {name}.
844
+ My conversation starter should be natural, brief (1-2 sentences), and appropriate for our relationship.
845
+ If a topic was selected, I should focus on that topic.
846
+ I should speak in first person as myself (Will).
847
  </instruction>
848
 
849
  My conversation starter to {name}:"""
850
  else:
851
+ # Use standard format for other models
852
  if user_input:
853
  # Responding to something
854
  prompt += f"""
855
+ I am Will, the person with MND. I want to respond to {name}'s message: "{user_input}"
856
+ My response should be natural, brief (1-2 sentences), and directly relevant to what {name} just said.
857
+ I'll use language appropriate for our relationship and speak as myself (Will).
858
 
859
  My response to {name}:"""
860
  else:
861
  # Initiating a conversation
862
  prompt += f"""
863
+ I am Will, the person with MND. I want to start a conversation with {name}.
864
+ My conversation starter should be natural, brief (1-2 sentences), and appropriate for our relationship.
865
+ I'll speak in first person as myself (Will).
866
 
867
  My conversation starter to {name}:"""
868
 
869
  # Generate suggestion
870
  try:
871
  print(f"Generating suggestion with prompt: {prompt}")
872
+
873
+ # Check if we're using the Gemini API or a Hugging Face model
874
+ if (
875
+ isinstance(self.generator, dict)
876
+ and self.generator.get("type") == "gemini-api"
877
+ ):
878
+ try:
879
+ # Use Gemini API
880
+ try:
881
+ genai = self.generator["client"]
882
+ model_name = self.generator["model"]
883
+
884
+ # Create a generative model
885
+ model = genai.GenerativeModel(model_name)
886
+
887
+ # Set generation config
888
+ generation_config = {
889
+ "temperature": temperature,
890
+ "top_p": 0.92,
891
+ "top_k": 50,
892
+ "max_output_tokens": 100,
893
+ }
894
+
895
+ # Generate content with timeout
896
+
897
+ result = [
898
+ "I'm thinking about what to say..."
899
+ ] # Default response
900
+ generation_complete = [False]
901
+
902
+ def generate_with_gemini():
903
+ try:
904
+ response = model.generate_content(
905
+ prompt, generation_config=generation_config
906
+ )
907
+
908
+ if response and hasattr(response, "text"):
909
+ result[0] = response.text.strip()
910
+ print(f"Gemini API response: {result[0]}")
911
+ else:
912
+ print("No response from Gemini API")
913
+
914
+ generation_complete[0] = True
915
+ except Exception as e:
916
+ print(f"Error in Gemini generation thread: {e}")
917
+ generation_complete[0] = True
918
+
919
+ # Start generation in a separate thread
920
+ generation_thread = threading.Thread(
921
+ target=generate_with_gemini
922
+ )
923
+ generation_thread.daemon = True
924
+ generation_thread.start()
925
+
926
+ # Wait for up to 10 seconds
927
+ timeout = 10
928
+ start_time = time.time()
929
+ while (
930
+ not generation_complete[0]
931
+ and time.time() - start_time < timeout
932
+ ):
933
+ time.sleep(0.1)
934
+
935
+ if not generation_complete[0]:
936
+ print("Gemini API request timed out")
937
+ return "I'm thinking about what to say... (API timeout)"
938
+
939
+ return result[0]
940
+ except Exception as e:
941
+ print(f"Error setting up Gemini API: {e}")
942
+ return (
943
+ "I'm having trouble connecting to the Gemini API right now."
944
+ )
945
+
946
+ except Exception as e:
947
+ print(f"Error generating with Gemini API: {e}")
948
+ return "Could not generate a suggestion with Gemini API. Please try again."
949
+
950
+ elif (
951
+ isinstance(self.generator, dict)
952
+ and self.generator.get("type") == "huggingface"
953
+ ):
954
+ # Use Hugging Face pipeline
955
+ pipeline = self.generator["pipeline"]
956
+
957
+ # Generate with Hugging Face
958
+ response = pipeline(
959
+ prompt,
960
+ max_new_tokens=100, # Generate more tokens to ensure we get a response
961
+ temperature=temperature,
962
+ do_sample=True,
963
+ top_p=0.92,
964
+ top_k=50,
965
+ truncation=False,
966
+ )
967
+
968
+ # Extract only the generated part, not the prompt
969
+ full_text = response[0]["generated_text"]
970
+ print(f"Full generated text length: {len(full_text)}")
971
+ print(f"Prompt length: {len(prompt)}")
972
+
973
+ # Make sure we're not trying to slice beyond the text length
974
+ if len(prompt) < len(full_text):
975
+ result = full_text[len(prompt) :].strip()
976
+
977
+ # Post-process the result for small models
978
+ if is_small_model:
979
+ result = self._clean_small_model_response(result)
980
+
981
+ print(f"Generated response: {result}")
982
+ return result
983
+ else:
984
+ # If the model didn't generate anything beyond the prompt
985
+ print("Model didn't generate text beyond prompt")
986
+ return "I'm thinking about what to say..."
987
+
988
  else:
989
+ # Legacy format (for backward compatibility)
990
+ response = self.generator(
991
+ prompt,
992
+ max_new_tokens=100,
993
+ temperature=temperature,
994
+ do_sample=True,
995
+ top_p=0.92,
996
+ top_k=50,
997
+ truncation=False,
998
+ )
999
+
1000
+ # Extract only the generated part, not the prompt
1001
+ full_text = response[0]["generated_text"]
1002
+ print(f"Full generated text length: {len(full_text)}")
1003
+ print(f"Prompt length: {len(prompt)}")
1004
+
1005
+ # Make sure we're not trying to slice beyond the text length
1006
+ if len(prompt) < len(full_text):
1007
+ result = full_text[len(prompt) :].strip()
1008
+
1009
+ # Post-process the result for small models
1010
+ if is_small_model:
1011
+ result = self._clean_small_model_response(result)
1012
+
1013
+ print(f"Generated response: {result}")
1014
+ return result
1015
+ else:
1016
+ # If the model didn't generate anything beyond the prompt
1017
+ print("Model didn't generate text beyond prompt")
1018
+ return "I'm thinking about what to say..."
1019
+
1020
  except Exception as e:
1021
  print(f"Error generating suggestion: {e}")
1022
  return "Could not generate a suggestion. Please try again."
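With these changes, utils.SuggestionGenerator routes model names prefixed with `gemini-api:` to the Gemini API instead of a local Hugging Face pipeline. A minimal sketch of that flow, assuming GEMINI_API_KEY is set; the initial model mirrors test_app.py and the person id is hypothetical:

```python
# Sketch of driving the updated utils.SuggestionGenerator via the new
# "gemini-api:" prefix. Assumes GEMINI_API_KEY is set; person id is hypothetical.
from utils import SocialGraphManager, SuggestionGenerator

graph = SocialGraphManager("social_graph.json")
generator = SuggestionGenerator("distilgpt2")  # initial model, as in test_app.py

# load_model() sends "gemini-api:<model>" names to google-generativeai
# rather than loading a local Hugging Face pipeline.
if generator.load_model("gemini-api:gemini-1.5-flash"):
    print(generator.test_model())
    person_context = graph.get_person_context("billy")  # hypothetical person id
    person_context["mood"] = 4
    suggestion = generator.generate_suggestion(
        person_context, "Did you watch the match last night?", temperature=0.7
    )
    print(suggestion)
```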