# Imports
import gradio as gr
import wikipedia
import numpy as np
import faiss
from langdetect import detect
from gtts import gTTS
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import tempfile
import speech_recognition as sr
from functools import lru_cache
from pydub import AudioSegment
# ===== Model Setup =====
models = {}
def load_models():
    models['encoder'] = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
    models['to_en'] = pipeline('translation', model='Helsinki-NLP/opus-mt-mul-en')
    for lang in ['fr', 'ar', 'zh', 'es']:
        models[f'en_to_{lang}'] = pipeline(f'translation_en_to_{lang}', model=f'Helsinki-NLP/opus-mt-en-{lang}')
    models['answer_gen'] = pipeline('text2text-generation', model='google/flan-t5-base', max_length=1024)  # increased length
load_models()
# ===== Utility Functions =====
def detect_language(text):
    try:
        return detect(text)
    except Exception:  # langdetect raises on empty/undetectable input
        return 'en'
def translate(text, src, tgt):
    # Pivot through English: mul->en first, then en->tgt if a model exists
    if src == tgt:
        return text
    if src != 'en':
        text = models['to_en'](text)[0]['translation_text']
    if f'en_to_{tgt}' in models:
        return models[f'en_to_{tgt}'](text)[0]['translation_text']
    return text
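# Example (sketch): because translate() pivots through English, a French
# question headed to Spanish takes two hops:
#   translate("Bonjour le monde", "fr", "es")
#   -> models['to_en'] yields roughly "Hello world"
#   -> models['en_to_es'] yields roughly "Hola mundo"
# Exact wording depends on the Opus-MT checkpoints.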
def tts_play(text, lang):
    tts = gTTS(text=text, lang=lang)
    # tempfile.mktemp is deprecated and race-prone; NamedTemporaryFile is safer
    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
        path = f.name
    tts.save(path)
    return path
def chunk_text(text, max_words=100):  # increased chunk size
    sentences = text.split('. ')
    chunks, current_chunk, current_len = [], [], 0
    for sent in sentences:
        words = sent.split()
        # Also require a non-empty current_chunk so a first sentence longer
        # than max_words does not emit an empty chunk
        if current_len + len(words) > max_words and current_chunk:
            chunks.append('. '.join(current_chunk))
            current_chunk = [sent]
            current_len = len(words)
        else:
            current_chunk.append(sent)
            current_len += len(words)
    if current_chunk:
        chunks.append('. '.join(current_chunk))
    return chunks
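# Example: with max_words=2, chunk_text("A big dog. It ran. It slept.")
# splits on '. ' and greedily packs sentences into chunks, giving
# ['A big dog', 'It ran', 'It slept.'] -- the trailing period survives only
# on the final sentence because splitting on '. ' strips the others.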
def build_faiss_index(chunks, model):
    embeddings = model.encode(chunks, convert_to_numpy=True)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index
@lru_cache(maxsize=20)
def prepare_faiss_for_topic(topic):
    wikipedia.set_lang('en')
    page = wikipedia.page(topic)
    content = page.content[:5000]  # increase content for better answers
    chunks = chunk_text(content)
    index = build_faiss_index(chunks, models['encoder'])
    return chunks, index
def retrieve_context(question, index, chunks, model, top_k=5):  # increased top_k
    q_emb = model.encode([question], convert_to_numpy=True)
    _, indices = index.search(q_emb, top_k)
    return ' '.join([chunks[i] for i in indices[0]])
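# Example: after chunks, idx = prepare_faiss_for_topic("Artificial intelligence"),
# retrieve_context("Who coined the term AI?", idx, chunks, models['encoder'])
# embeds the question, finds the 5 nearest chunks by L2 distance, and joins
# them into a single context string for the generator prompt below.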
# ===== Main Inference Function =====
def qa_system(audio, text_question, topic, output_lang):
    question = ""
    if audio is not None:
        try:
            r = sr.Recognizer()
            # Convert whatever Gradio recorded to WAV for speech_recognition
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                audio_wav_path = f.name
            sound = AudioSegment.from_file(audio)
            sound.export(audio_wav_path, format="wav")
            with sr.AudioFile(audio_wav_path) as source:
                audio_data = r.record(source)
            question = r.recognize_google(audio_data)
        except Exception as e:
            return f"❌ Could not understand the audio: {e}", None, None
    elif text_question:
        question = text_question.strip()
    else:
        return "❌ Please provide a voice or text question.", None, None
    input_lang = detect_language(question)
    try:
        chunks, faiss_index = prepare_faiss_for_topic(topic)
    except Exception as e:
        return f"❌ Error loading topic from Wikipedia: {e}", None, None
    context = retrieve_context(question, faiss_index, chunks, models['encoder'], top_k=5)
    question_en = translate(question, input_lang, 'en')
    prompt = f"Answer based on the context:\nContext: {context}\nQuestion: {question_en}"
    answer_en = models['answer_gen'](prompt)[0]['generated_text']
    if output_lang == 'en':
        answer = answer_en
    elif output_lang == 'am':
        answer = "Amharic translation not supported."
    else:
        answer = translate(answer_en, 'en', output_lang)
    # The 'am' branch returns an English message, so speak it in English
    tts_lang = 'en' if output_lang == 'am' else output_lang
    audio_path = tts_play(answer, tts_lang)
    return f"You asked: {question}\n\nAnswer: {answer}", audio_path, answer
# ===== Gradio UI =====
lang_options = ['en', 'am', 'fr', 'ar', 'es', 'zh']
demo = gr.Interface(
    fn=qa_system,
    inputs=[
        gr.Audio(type="filepath", label="🎤 Ask your Question by Voice (optional)"),
        gr.Textbox(label="✍️ Or type your Question here (optional)"),
        gr.Textbox(value="Artificial intelligence", label="📚 Wikipedia Topic"),
        gr.Dropdown(choices=lang_options, value='en', label="🌍 Output Language")
    ],
    outputs=[
        gr.Textbox(label="🤖 Answer Output"),
        gr.Audio(label="🔊 Answer Playback"),
        gr.Textbox(label="📝 Translated Answer Text")
    ],
    title="🌍 Multilingual Voice/Text Q&A Assistant",
    description="""
    <h3 style='text-align: center; font-weight: bold; font-style: italic;'>👋 Welcome to the Multilingual Wikipedia Q&A Assistant</h3>
    <p style='text-align: center;'>You can ask questions using voice or text in different languages, and get spoken and translated answers using AI + Wikipedia. 🌐</p>
    """
)
# Launch the app
demo.launch()
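# Quick manual test without the UI (hypothetical values; models must be
# loaded first, and demo.launch() above should be commented out):
# text_out, audio_path, answer = qa_system(
#     audio=None,
#     text_question="What is artificial intelligence?",
#     topic="Artificial intelligence",
#     output_lang="fr",
# )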