import base64
import io
import json
import os
import tempfile
import time

import torch
import whisper
from deep_translator import GoogleTranslator
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

#import uvicorn
#from asgiref.wsgi import WsgiToAsgi
#from llama_cpp import Llama
#from huggingface_hub import login, hf_hub_download

app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})  # Enable CORS for all routes
#asgi_app = WsgiToAsgi(app)  # for uvicorn compatibility
"""hugging_face_token = os.getenv("kk")
if not hugging_face_token:
raise EnvironmentError("HUGGINGFACE_TOKEN environment variable not set.")
#login(hugging_face_token)
"""
"""model_name = "unsloth/medgemma-4b-it-GGUF"
model_file = "medgemma-4b-it-Q8_0.gguf" # this is the specific model file we'll use in this example. It's a 4-bit quant, but other levels of quantization are available in the model repo if preferred
model_path = hf_hub_download(model_name, filename=model_file)
llm = Llama(
model_path=model_path, # Update this to your local model path
n_ctx=8192,
n_threads=12,
temperature=0.7,
)
"""
model_id = "google/medgemma-4b-pt"
model_medg = AutoModelForImageTextToText.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="cpu",
)
processor = AutoProcessor.from_pretrained(model_id)
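
# Example request (illustrative only; host/port taken from app.run below, the file name
# and message text are placeholders):
#   curl -X POST http://localhost:7860/analyze-image \
#        -F "file=@xray.png" \
#        -F 'chat_history=[{"role": "user", "content": "Describe the findings."}]'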
@app.route('/analyze-image', methods=['POST'])
def analyze_image():
    image = None
    image_path = None

    # Get optional image
    file = request.files.get('file')
    if file:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
            file.save(tmp.name)
            image_path = tmp.name
        try:
            image = Image.open(image_path).convert("RGB")
        except Exception as e:
            return jsonify({'error': f'Invalid image: {str(e)}'}), 400
    # Get optional chat history (multipart form field first, JSON body as a fallback)
    try:
        chat_history = request.form.get("chat_history")
        if chat_history:
            chat_history = json.loads(chat_history)
        else:
            chat_history = (request.get_json(silent=True) or {}).get("chat_history", [])
    except Exception as e:
        return jsonify({'error': f'Invalid or missing chat_history: {str(e)}'}), 400
    # Build text prompt (from chat history)
    prompt_parts = []
    for msg in chat_history:
        role = msg.get("role", "").strip().lower()
        content = msg.get("content", "").strip()
        if role == "system":
            prompt_parts.append(content)
        elif role == "user":
            prompt_parts.append(f"User: {content}")
        elif role == "assistant":
            prompt_parts.append(f"Assistant: {content}")
    combined_prompt = "\n".join(prompt_parts).strip()

    if not image and not combined_prompt:
        return jsonify({'error': 'You must provide either an image or a prompt.'}), 400

    # Final model prompt: prepend the image token when an image is supplied
    if image:
        combined_prompt = f"<start_of_image> {combined_prompt or 'Findings:'}"
    else:
        combined_prompt = combined_prompt or "Findings:"
    model_prompt = f" {combined_prompt}"
    # Prepare input to model
    inputs = processor(
        text=model_prompt,
        images=image if image else None,
        return_tensors="pt"
    ).to(model_medg.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]
    print(model_prompt)

    with torch.inference_mode():
        generation = model_medg.generate(
            **inputs,
            max_new_tokens=1000,
            do_sample=False
        )
        generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)

    if image_path and os.path.exists(image_path):
        os.remove(image_path)
    return jsonify({"result": decoded.strip()})
@app.route('/med-llm', methods=['POST'])
def med_llm():
    uploaded_file = request.files.get('file')
    if not uploaded_file:
        return jsonify({'error': 'No file uploaded'}), 400

    mime_type = uploaded_file.mimetype
    print(f"Received file type: {mime_type}")

    with tempfile.NamedTemporaryFile(delete=False, suffix='.bin') as tmp:
        uploaded_file.save(tmp.name)

    if mime_type.startswith('image/'):
        mock_response = "📷 MedGemma Image Analysis: Detected mild cardiomegaly."
    elif mime_type.startswith('audio/'):
        mock_response = "🎧 MedGemma Audio Analysis: Suggests potential wheezing."
    else:
        mock_response = "Unsupported file type."

    os.remove(tmp.name)  # The mock response never reads the file, so clean it up
    return jsonify({'result': mock_response})
model = whisper.load_model("base") # You can change to 'small', 'medium', etc.
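
# Example request (illustrative only; -N disables curl buffering so the SSE words arrive
# as they are emitted, the file name is a placeholder):
#   curl -N -X POST http://localhost:7860/transcribe-stream -F "audio=@sample.wav"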
@app.route('/transcribe-stream', methods=['POST'])
def transcribe_stream():
    # Get the audio file from the request
    audio_file = request.files.get('audio')
    if not audio_file:
        return "Missing audio file", 400

    # Save to a temp file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        audio_path = tmp.name
        audio_file.save(audio_path)

    def generate():
        # Transcribe using Whisper (non-streaming), then emit word by word as SSE
        result = model.transcribe(audio_path)
        for word in result['text'].split():
            yield f"data: {word}\n\n"
            time.sleep(0.3)  # Simulate streaming
        os.remove(audio_path)  # Clean up

    return Response(generate(), mimetype='text/event-stream')
def translate_to_arabic(text):
    try:
        translated = GoogleTranslator(source='auto', target='ar').translate(text)
        return translated
    except Exception as e:
        print(f"Translation failed: {e}")
        return text
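
# Example request (illustrative only; the "translate" query flag and field names match
# the handler below, the message text is a placeholder):
#   curl -X POST "http://localhost:7860/chat-translate?translate=true" \
#        -H "Content-Type: application/json" \
#        -d '{"chat_history": [{"role": "user", "content": "What does the scan show?"},
#                              {"role": "model", "content": "Possible mild effusion."}]}'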
@app.route('/chat-translate', methods=['POST'])
def chat_translate():
    try:
        data = request.get_json()
        chat_history = data.get('chat_history', [])
        translate = request.args.get("translate") == "true"
        result = ""  # Falls back to an empty answer if no model message is present

        # Join messages into a prompt
        prompt_parts = []
        for msg in chat_history:
            role = msg.get("role", "").strip().lower()
            content = msg.get("content", "").strip()
            if role == "system":
                prompt_parts.append(f"System:\n{content}")
            elif role == "user":
                prompt_parts.append(f"User:\n{content}")
            elif role == "model":
                prompt_parts.append(f"Assistant:\n{content}")
                result = content  # Use the last model response as the result
        full_prompt = "\n\n".join(prompt_parts)

        # Simulated LLM response
        # result = "Simulated model answer. Possible findings: * Infection * Fluid accumulation. Next steps: * Follow-up test."
        if translate:
            result = translate_to_arabic(result)
        return jsonify({"result": result})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=7860)