|
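# Flask backend exposing MedGemma image analysis, a mock med-llm endpoint,
# Whisper speech-to-text streaming, and Arabic chat translation.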
import json
import os
import tempfile
import time

import torch
import whisper
from deep_translator import GoogleTranslator
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

app = Flask(__name__)

# Allow cross-origin requests from any origin (convenient for local development).
CORS(app, resources={r"/*": {"origins": "*"}})
|
# Optional: authenticate with Hugging Face if the model is gated.
# hugging_face_token = os.getenv("HUGGINGFACE_TOKEN")
# if not hugging_face_token:
#     raise EnvironmentError("HUGGINGFACE_TOKEN environment variable not set.")
# login(hugging_face_token)

# Alternative: load a quantized GGUF build via llama-cpp-python instead of
# transformers. medgemma-4b-it-Q8_0.gguf is an 8-bit quant; other quantization
# levels are available in the model repo if preferred.
# model_path = hf_hub_download("unsloth/medgemma-4b-it-GGUF", filename="medgemma-4b-it-Q8_0.gguf")
# llm = Llama(model_path=model_path, n_ctx=8192, n_threads=12, temperature=0.7)
|
# MedGemma vision-language model and its processor, loaded once at startup
# (bfloat16 weights on CPU).
model_id = "google/medgemma-4b-pt"

model_medg = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
processor = AutoProcessor.from_pretrained(model_id)
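# Note: a 4B-parameter model in bfloat16 needs on the order of 8 GB of RAM for
# the weights alone, so CPU inference here works but is slow.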
|
@app.route('/analyze-image', methods=['POST'])
def analyze_image():
    image = None
    image_path = None

    # Optional image upload, saved to a temp file so it can be removed after
    # inference.
    file = request.files.get('file')
    if file:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
            file.save(tmp.name)
            image_path = tmp.name
        try:
            image = Image.open(image_path).convert("RGB")
        except Exception as e:
            os.remove(image_path)
            return jsonify({'error': f'Invalid image: {str(e)}'}), 400
|
    # chat_history arrives either as a JSON string in form data or as a list in
    # a JSON body; normalize both to a Python list.
    try:
        chat_history = request.form.get("chat_history")
        if chat_history:
            chat_history = json.loads(chat_history)
        else:
            data = request.get_json(silent=True) or {}
            chat_history = data.get("chat_history", [])
    except Exception as e:
        return jsonify({'error': f'Invalid or missing chat_history: {str(e)}'}), 400
|
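    # Flatten the structured chat history into a single plain-text prompt.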
    prompt_parts = []
    for msg in chat_history:
        role = msg.get("role", "").strip().lower()
        content = msg.get("content", "").strip()
        if role == "system":
            prompt_parts.append(content)
        elif role == "user":
            prompt_parts.append(f"User: {content}")
        elif role == "assistant":
            prompt_parts.append(f"Assistant: {content}")

    combined_prompt = "\n".join(prompt_parts).strip()

    if not image and not combined_prompt:
        return jsonify({'error': 'You must provide either an image or a prompt.'}), 400
|
    # MedGemma expects an explicit image token ahead of the text whenever an
    # image is supplied.
    if image:
        model_prompt = f"<start_of_image> {combined_prompt or 'Findings:'}"
    else:
        model_prompt = combined_prompt or "Findings:"

    inputs = processor(
        text=model_prompt,
        images=image if image else None,
        return_tensors="pt"
    ).to(model_medg.device, dtype=torch.bfloat16)

    # Prompt length, used to strip the prompt tokens from the generated output.
    input_len = inputs["input_ids"].shape[-1]
|
    # Greedy decoding; keep only the newly generated tokens.
    with torch.inference_mode():
        generation = model_medg.generate(
            **inputs,
            max_new_tokens=1000,
            do_sample=False
        )
    generation = generation[0][input_len:]

    decoded = processor.decode(generation, skip_special_tokens=True)

    # Clean up the temporary upload now that inference is done.
    if image_path and os.path.exists(image_path):
        os.remove(image_path)

    return jsonify({"result": decoded.strip()})
|
@app.route('/med-llm', methods=['POST'])
def med_llm():
    # Stub endpoint: returns a canned response based on MIME type; no model is
    # actually invoked.
    uploaded_file = request.files.get('file')
    if not uploaded_file:
        return jsonify({'error': 'No file uploaded'}), 400

    mime_type = uploaded_file.mimetype
    print(f"Received file type: {mime_type}")

    with tempfile.NamedTemporaryFile(delete=False, suffix='.bin') as tmp:
        uploaded_file.save(tmp.name)
        tmp_path = tmp.name

    if mime_type.startswith('image/'):
        mock_response = "📷 MedGemma Image Analysis: Detected mild cardiomegaly."
    elif mime_type.startswith('audio/'):
        mock_response = "🎧 MedGemma Audio Analysis: Suggests potential wheezing."
    else:
        mock_response = "Unsupported file type."

    # The saved file is never read; remove it so the stub does not leak temp files.
    os.remove(tmp_path)

    return jsonify({'result': mock_response})
|
# Whisper speech-to-text model, loaded once at startup.
whisper_model = whisper.load_model("base")
|
@app.route('/transcribe-stream', methods=['POST'])
def transcribe_stream():
    audio_file = request.files.get('audio')
    if not audio_file:
        return "Missing audio file", 400

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        audio_path = tmp.name
        audio_file.save(audio_path)
|
    def generate():
        # Transcribe the whole clip, then stream it back word by word as
        # server-sent events so the client can render progressive output.
        result = whisper_model.transcribe(audio_path)
        for word in result['text'].split():
            yield f"data: {word}\n\n"
            time.sleep(0.3)  # pace the words to simulate live transcription

        # Delete the temp file once streaming finishes.
        os.remove(audio_path)

    return Response(generate(), mimetype='text/event-stream')
|
def translate_to_arabic(text):
    # Best-effort translation; falls back to the original text on failure.
    try:
        return GoogleTranslator(source='auto', target='ar').translate(text)
    except Exception as e:
        print(f"Translation failed: {e}")
        return text
|
@app.route('/chat-translate', methods=['POST'])
def chat_translate():
    try:
        data = request.get_json(silent=True) or {}
        chat_history = data.get('chat_history', [])
        translate = request.args.get("translate") == "true"
|
        # Track the most recent model reply; that is what gets returned.
        result = ""
        prompt_parts = []
        for msg in chat_history:
            role = msg.get("role", "").strip().lower()
            content = msg.get("content", "").strip()
            if role == "system":
                prompt_parts.append(f"System:\n{content}")
            elif role == "user":
                prompt_parts.append(f"User:\n{content}")
            elif role == "model":
                prompt_parts.append(f"Assistant:\n{content}")
                result = content

        # Built but currently unused; kept for a future call into the model.
        full_prompt = "\n\n".join(prompt_parts)
|
        if translate:
            result = translate_to_arabic(result)

        return jsonify({"result": result})

    except Exception as e:
        return jsonify({"error": str(e)}), 500

if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=7860)
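
# Example requests (hypothetical file names; server assumed to be running on
# localhost:7860):
#
#   curl -X POST http://localhost:7860/analyze-image \
#        -F "file=@scan.png" \
#        -F 'chat_history=[{"role": "user", "content": "Describe the findings"}]'
#
#   curl -N -X POST http://localhost:7860/transcribe-stream -F "audio=@dictation.wav"
#
#   curl -X POST "http://localhost:7860/chat-translate?translate=true" \
#        -H "Content-Type: application/json" \
#        -d '{"chat_history": [{"role": "model", "content": "Take rest and fluids."}]}'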