# Medgmma / app.py
import base64
import io
import json
import os
import tempfile
import time

import torch
import whisper
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq, AutoModelForImageTextToText
from transformers import AutoTokenizer, AutoModelForCausalLM
from deep_translator import GoogleTranslator
#import uvicorn
#from asgiref.wsgi import WsgiToAsgi
#from huggingface_hub import login
#from huggingface_hub import hf_hub_download
#from llama_cpp import Llama

app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})  # Enable CORS for all routes
#asgi_app = WsgiToAsgi(app)  # for uvicorn compatibility
"""hugging_face_token = os.getenv("kk")
if not hugging_face_token:
raise EnvironmentError("HUGGINGFACE_TOKEN environment variable not set.")
#login(hugging_face_token)
"""
"""model_name = "unsloth/medgemma-4b-it-GGUF"
model_file = "medgemma-4b-it-Q8_0.gguf" # this is the specific model file we'll use in this example. It's a 4-bit quant, but other levels of quantization are available in the model repo if preferred
model_path = hf_hub_download(model_name, filename=model_file)
llm = Llama(
model_path=model_path, # Update this to your local model path
n_ctx=8192,
n_threads=12,
temperature=0.7,
)
"""
# Load MedGemma model (4B) on startup
model_id = "google/medgemma-4b-pt"
model_medg = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
processor = AutoProcessor.from_pretrained(model_id)


@app.route('/analyze-image', methods=['POST'])
def analyze_image():
    image = None
    image_path = None

    # Get optional image
    file = request.files.get('file')
    if file:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
            file.save(tmp.name)
            image_path = tmp.name
            try:
                image = Image.open(tmp.name).convert("RGB")
            except Exception as e:
                return jsonify({'error': f'Invalid image: {str(e)}'}), 400

    # Get optional chat history (from a form field or a JSON body)
    try:
        chat_history = request.form.get("chat_history")
        if chat_history:
            chat_history = json.loads(chat_history)
        else:
            chat_history = (request.get_json(silent=True) or {}).get("chat_history", [])
    except Exception as e:
        return jsonify({'error': f'Invalid or missing chat_history: {str(e)}'}), 400

    # Build text prompt (from chat history)
    prompt_parts = []
    for msg in chat_history:
        role = msg.get("role", "").strip().lower()
        content = msg.get("content", "").strip()
        if role == "system":
            prompt_parts.append(content)
        elif role == "user":
            prompt_parts.append(f"User: {content}")
        elif role == "assistant":
            prompt_parts.append(f"Assistant: {content}")
    combined_prompt = "\n".join(prompt_parts).strip()

    if not image and not combined_prompt:
        return jsonify({'error': 'You must provide either an image or a prompt.'}), 400

    # Final model prompt
    if image:
        combined_prompt = f"<start_of_image> {combined_prompt or 'Findings:'}"
    else:
        combined_prompt = combined_prompt or "Findings:"
    model_prompt = f" {combined_prompt or 'Response:'}"

    # Prepare input to model
    inputs = processor(
        text=model_prompt,
        images=image if image else None,
        return_tensors="pt"
    ).to(model_medg.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]
    print(model_prompt)

    with torch.inference_mode():
        generation = model_medg.generate(
            **inputs,
            max_new_tokens=1000,
            do_sample=False
        )
        generation = generation[0][input_len:]

    decoded = processor.decode(generation, skip_special_tokens=True)

    if image_path and os.path.exists(image_path):
        os.remove(image_path)

    return jsonify({"result": decoded.strip()})


@app.route('/med-llm', methods=['POST'])
def med_llm():
    uploaded_file = request.files.get('file')
    if not uploaded_file:
        return jsonify({'error': 'No file uploaded'}), 400

    mime_type = uploaded_file.mimetype
    print(f"Received file type: {mime_type}")

    with tempfile.NamedTemporaryFile(delete=False, suffix='.bin') as tmp:
        uploaded_file.save(tmp.name)

    if mime_type.startswith('image/'):
        mock_response = "📷 MedGemma Image Analysis: Detected mild cardiomegaly."
    elif mime_type.startswith('audio/'):
        mock_response = "🎧 MedGemma Audio Analysis: Suggests potential wheezing."
    else:
        mock_response = "Unsupported file type."

    return jsonify({'result': mock_response})


model = whisper.load_model("base")  # You can change to 'small', 'medium', etc.


@app.route('/transcribe-stream', methods=['POST'])
def transcribe_stream():
    # Save the audio file from request
    audio_file = request.files.get('audio')
    if not audio_file:
        return "Missing audio file", 400

    # Save to temp file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        audio_path = tmp.name
        audio_file.save(audio_path)

    def generate():
        # Transcribe using Whisper (non-streaming), then emit word by word
        result = model.transcribe(audio_path)
        for word in result['text'].split():
            yield f"data: {word}\n\n"
            time.sleep(0.3)  # Simulate streaming
        os.remove(audio_path)  # Clean up

    return Response(generate(), mimetype='text/event-stream')
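

# Illustrative client sketch (not part of the app): reading the word-by-word
# Server-Sent Events emitted by /transcribe-stream. The URL, audio filename,
# and the 'requests' dependency are assumptions.
def _example_transcribe_stream_request():
    import requests  # assumed available on the client side

    with open("sample.wav", "rb") as f:  # hypothetical local recording
        resp = requests.post(
            "http://localhost:7860/transcribe-stream",
            files={"audio": f},
            stream=True,
        )
    for line in resp.iter_lines(decode_unicode=True):
        # Each SSE event carries one word after the "data: " prefix
        if line and line.startswith("data: "):
            print(line[len("data: "):], end=" ", flush=True)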


def translate_to_arabic(text):
    try:
        translated = GoogleTranslator(source='auto', target='ar').translate(text)
        return translated
    except Exception as e:
        print(f"Translation failed: {e}")
        return text


@app.route('/chat-translate', methods=['POST'])
def chat_translate():
    try:
        data = request.get_json()
        chat_history = data.get('chat_history', [])
        translate = request.args.get("translate") == "true"

        # Join messages into a prompt; keep the last model response as the result
        prompt_parts = []
        result = ""
        for msg in chat_history:
            role = msg.get("role", "").strip().lower()
            content = msg.get("content", "").strip()
            if role == "system":
                prompt_parts.append(f"System:\n{content}")
            elif role == "user":
                prompt_parts.append(f"User:\n{content}")
            elif role == "model":
                prompt_parts.append(f"Assistant:\n{content}")
                result = content  # Use the last model response as the result

        full_prompt = "\n\n".join(prompt_parts)

        # Simulated LLM response
        # result = "Simulated model answer. Possible findings: * Infection * Fluid accumulation. Next steps: * Follow-up test."

        if translate:
            result = translate_to_arabic(result)

        return jsonify({"result": result})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=7860)