besucoder committed
Commit c4d47f7 · verified · 1 Parent(s): 51c7e1a
Files changed (1)
  1. app.py +100 -88
app.py CHANGED
@@ -1,3 +1,4 @@
+# app.py
 import gradio as gr
 import wikipedia
 import numpy as np
@@ -6,13 +7,12 @@ from langdetect import detect
 from gtts import gTTS
 from transformers import pipeline
 from sentence_transformers import SentenceTransformer
-import tempfile, os
-import torch
+import tempfile
 import speech_recognition as sr
-from functools import lru_cache
 from pydub import AudioSegment
+from functools import lru_cache
 
-# ===== Model Setup =====
+# --- Load models ---
 models = {}
 def load_models():
     models['encoder'] = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
@@ -20,10 +20,9 @@ def load_models():
     for lang in ['fr', 'ar', 'zh', 'es']:
         models[f'en_to_{lang}'] = pipeline('translation_en_to_' + lang, model=f'Helsinki-NLP/opus-mt-en-{lang}')
     models['answer_gen'] = pipeline('text2text-generation', model='google/flan-t5-base', max_length=1024)
-
 load_models()
 
-# ===== Utility Functions =====
+# --- Utility functions ---
 def detect_language(text):
     try:
         return detect(text)
@@ -41,108 +40,121 @@ def translate(text, src, tgt):
 
 def tts_play(text, lang):
     tts = gTTS(text=text, lang=lang)
-    path = tempfile.mktemp(suffix=".mp3")
+    path = tempfile.mktemp(suffix='.mp3')
     tts.save(path)
     return path
 
 def chunk_text(text, max_words=100):
     sentences = text.split('. ')
-    chunks, current_chunk, current_len = [], [], 0
+    chunks, current, length = [], [], 0
    for sent in sentences:
        words = sent.split()
-        if current_len + len(words) > max_words:
-            chunks.append('. '.join(current_chunk))
-            current_chunk = [sent]
-            current_len = len(words)
+        if length + len(words) > max_words:
+            chunks.append(' '.join(current))
+            current, length = [sent], len(words)
         else:
-            current_chunk.append(sent)
-            current_len += len(words)
-    if current_chunk:
-        chunks.append('. '.join(current_chunk))
+            current.append(sent)
+            length += len(words)
+    if current: chunks.append(' '.join(current))
     return chunks
 
 def build_faiss_index(chunks, model):
-    embeddings = model.encode(chunks, convert_to_numpy=True)
-    index = faiss.IndexFlatL2(embeddings.shape[1])
-    index.add(embeddings)
+    emb = model.encode(chunks, convert_to_numpy=True)
+    index = faiss.IndexFlatL2(emb.shape[1])
+    index.add(emb)
     return index
 
 @lru_cache(maxsize=20)
 def prepare_faiss_for_topic(topic):
     wikipedia.set_lang('en')
     page = wikipedia.page(topic)
-    content = page.content[:5000]
-    chunks = chunk_text(content)
-    index = build_faiss_index(chunks, models['encoder'])
-    return chunks, index
-
-def retrieve_context(question, index, chunks, model, top_k=5):
-    q_emb = model.encode([question], convert_to_numpy=True)
-    _, indices = index.search(q_emb, top_k)
-    return ' '.join([chunks[i] for i in indices[0]])
-
-# ===== Main Inference Function =====
-def qa_system(audio, text_question, topic, output_lang):
-    question = ""
-    if audio is not None:
+    chunks = chunk_text(page.content)  # Use full content, no slicing limit
+    return chunks, build_faiss_index(chunks, models['encoder'])
+
+def retrieve_context(q, idx, chunks, model, top_k=5):
+    emb = model.encode([q], convert_to_numpy=True)
+    _, inds = idx.search(emb, top_k)
+    return ' '.join(chunks[i] for i in inds[0])
+
+# --- Main Q&A function ---
+def qa_system(audio, text_q, topic, lang, history):
+    q = ''
+    if audio:
         try:
             r = sr.Recognizer()
-            audio_wav_path = tempfile.mktemp(suffix=".wav")
-            sound = AudioSegment.from_file(audio)
-            sound.export(audio_wav_path, format="wav")
-            with sr.AudioFile(audio_wav_path) as source:
-                audio_data = r.record(source)
-            question = r.recognize_google(audio_data)
+            wav = tempfile.mktemp('.wav')
+            AudioSegment.from_file(audio).export(wav, format='wav')
+            with sr.AudioFile(wav) as src:
+                q = r.recognize_google(r.record(src))
         except Exception as e:
-            return f"❌ Could not understand the audio: {e}", None, None
-    elif text_question:
-        question = text_question.strip()
+            return f"❌ Could not transcribe audio: {e}", None, history, ''
+    elif text_q:
+        q = text_q.strip()
     else:
-        return "❌ Please provide a voice or text question.", None, None
-
-    input_lang = detect_language(question)
+        return '❌ Please speak or type your question.', None, history, ''
 
     try:
-        chunks, faiss_index = prepare_faiss_for_topic(topic)
-    except:
-        return "❌ Error loading topic from Wikipedia.", None, None
-
-    context = retrieve_context(question, faiss_index, chunks, models['encoder'], top_k=5)
-    question_en = translate(question, input_lang, 'en')
-    prompt = f"Answer based on the context:\nContext: {context}\nQuestion: {question_en}"
-    answer_en = models['answer_gen'](prompt)[0]['generated_text']
-
-    if output_lang == 'en':
-        answer = answer_en
-    elif output_lang == 'am':
-        answer = "Amharic translation not supported."
-    else:
-        answer = translate(answer_en, 'en', output_lang)
-
-    audio_path = tts_play(answer, output_lang)
-    return f"You asked: {question}\n\nAnswer: {answer}", audio_path, answer
-
-# ===== Gradio UI =====
-lang_options = ['en', 'am', 'fr', 'ar', 'es', 'zh']
-
-demo = gr.Interface(
-    fn=qa_system,
-    inputs=[
-        gr.Audio(type="filepath", label="🎤 Ask your Question by Voice (optional)"),
-        gr.Textbox(label="✍️ Or type your Question here (optional)"),
-        gr.Textbox(value="Artificial intelligence", label="📚 Wikipedia Topic"),
-        gr.Dropdown(choices=lang_options, value='en', label="🌍 Output Language")
-    ],
-    outputs=[
-        gr.Textbox(label="🤖 Answer Output"),
-        gr.Audio(label="🔊 Answer Playback"),
-        gr.Textbox(label="📝 Translated Answer Text")
-    ],
-    title="🌍 Multilingual Voice/Text Q&A Assistant",
-    description="""
-    <h3 style='text-align: center; font-weight: bold; font-style: italic;'>👋 Welcome to the Multilingual Wikipedia Q&A Assistant</h3>
-    <p style='text-align: center;'>Ask questions by voice or text in different languages, and get spoken and translated answers using AI + Wikipedia.</p>
-    """
-)
-
-demo.launch()
+        chunks, idx = prepare_faiss_for_topic(topic)
+    except Exception as e:
+        return f'Error loading content: {e}', None, history, ''
+
+    ctx = retrieve_context(q, idx, chunks, models['encoder'])
+    q_en = translate(q, detect_language(q), 'en')
+
+    # Debug prints — remove in production
+    print("Question (original):", q)
+    print("Question (English):", q_en)
+    print("Retrieved context snippet:", ctx[:500], "...\n")
+
+    prompt = f"Context:\n{ctx}\n\nQuestion: {q_en}\nAnswer:"
+    ans_en = models['answer_gen'](prompt)[0]['generated_text']
+
+    print("Generated answer (English):", ans_en)
+
+    ans = ans_en if lang == 'en' else translate(ans_en, 'en', lang)
+    audio_path = tts_play(ans, lang)
+    history.append((q, ans))
+    chat = '\n\n'.join(f"Q{i+1}: {x}\nA{i+1}: {y}" for i,(x,y) in enumerate(history))
+    return f'You asked: {q}\n\nAnswer: {ans}', audio_path, history, chat
+
+def clear_all():
+    return None, '', None, [], ''
+
+# --- Gradio UI with styling ---
+css_style = """
+.gradio-container {
+    background-color: #cce7ff !important;  /* Light blue */
+    border: 3px solid #000022 !important;  /* Balanced blue-black border */
+    border-radius: 12px;
+    padding: 20px;
+}
+"""
+
+with gr.Blocks(css=css_style) as demo:
+    gr.Markdown("""
+    <h1 style='color:#003366; text-align:center; margin-bottom: 0;'>🌐 Multilingual Wikipedia Q&A Assistant</h1>
+    <p style='text-align:center; font-size:16px; margin-top: 0;'>Ask your questions by typing or speaking, and get answers in your language!</p>
+    """)
+
+    state = gr.State([])
+
+    with gr.Row():
+        ai = gr.Audio(type='filepath', label='🎤 Speak your question')
+        ti = gr.Textbox(lines=3, placeholder='Or type your question here')
+
+    with gr.Row():
+        tp = gr.Textbox(value='Artificial intelligence', label='Wikipedia Topic')
+        lg = gr.Dropdown(['en','am','fr','ar','es','zh'], value='en', label='Output Language')
+
+    with gr.Row():
+        sb = gr.Button('🔍 Get Answer')
+        cb = gr.Button('🗑️ Clear All')

+    ao = gr.Textbox(label='🤖 Answer')
+    av = gr.Audio(label='🔊 Listen Answer')
+    cd = gr.Markdown(label='🗂️ Chat History')
+
+    sb.click(qa_system, inputs=[ai, ti, tp, lg, state], outputs=[ao, av, state, cd])
+    cb.click(clear_all, outputs=[ai, ti, tp, state, cd])
+
+demo.launch(share=True)
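Both sides of this diff create temp-file paths with `tempfile.mktemp`, which the Python docs flag as deprecated: it returns a name without creating the file, so another process can grab the path first. A minimal sketch of the safer `NamedTemporaryFile` pattern for the gTTS output; this is an illustration under the commit's imports, not part of the commit itself:

```python
import tempfile
from gtts import gTTS

def tts_play(text, lang):
    tts = gTTS(text=text, lang=lang)
    # NamedTemporaryFile actually creates the file, avoiding the mktemp race;
    # delete=False keeps it on disk so Gradio can stream it back to the browser.
    with tempfile.NamedTemporaryFile(suffix='.mp3', delete=False) as f:
        path = f.name
    tts.save(path)
    return path
```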
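The retrieval path the commit keeps is: encode the chunks, add them to a `faiss.IndexFlatL2`, search with the encoded question, and join the top hits. A self-contained sketch of that flow with the same encoder the app loads; it assumes `faiss-cpu` and `sentence-transformers` are installed:

```python
import faiss
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
chunks = [
    "FAISS performs fast nearest-neighbor search over dense vectors.",
    "gTTS converts text into spoken MP3 audio.",
]
emb = encoder.encode(chunks, convert_to_numpy=True)

index = faiss.IndexFlatL2(emb.shape[1])  # exact L2 search, as in build_faiss_index
index.add(emb)

q = encoder.encode(["What does FAISS do?"], convert_to_numpy=True)
_, inds = index.search(q, 1)  # returns distances, indices; top-1 here, top_k=5 in the app
print(chunks[inds[0][0]])     # -> the FAISS sentence
```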
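`prepare_faiss_for_topic` stays wrapped in `@lru_cache(maxsize=20)`, so up to 20 topics keep their chunks and FAISS index in memory, and repeat questions skip the Wikipedia download and re-encoding. A small illustration of that behavior, assuming it runs in the same process where the function is defined (importing app.py directly would also launch the UI):

```python
# First call misses the cache: downloads the page, chunks it, builds the index.
chunks1, idx1 = prepare_faiss_for_topic('Artificial intelligence')
# Second call hits the cache: the identical tuple comes back untouched.
chunks2, idx2 = prepare_faiss_for_topic('Artificial intelligence')
assert idx1 is idx2  # lru_cache returns the very same index object

print(prepare_faiss_for_topic.cache_info())
# e.g. CacheInfo(hits=1, misses=1, maxsize=20, currsize=1)
```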