import json
import os
import tempfile
import time

import torch
import whisper
from deep_translator import GoogleTranslator
from flask import Flask, Response, jsonify, request
from flask_cors import CORS
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})  # Enable CORS for all routes
# Load MedGemma model (4B) on startup

# Optional ASGI wrapper for running under uvicorn, kept for reference:
#import uvicorn
#from asgiref.wsgi import WsgiToAsgi
#asgi_app = WsgiToAsgi(app)  # for uvicorn compatibility

# Alternative GGUF loading path via llama_cpp, kept for reference:
#from huggingface_hub import login, hf_hub_download
#from llama_cpp import Llama
"""hugging_face_token = os.getenv("kk")
if not hugging_face_token:
    raise EnvironmentError("HUGGINGFACE_TOKEN environment variable not set.")

#login(hugging_face_token)
"""
"""model_name = "unsloth/medgemma-4b-it-GGUF"
model_file = "medgemma-4b-it-Q8_0.gguf" # this is the specific model file we'll use in this example. It's a 4-bit quant, but other levels of quantization are available in the model repo if preferred
model_path = hf_hub_download(model_name, filename=model_file)
llm = Llama(
    model_path=model_path,  # Update this to your local model path
    n_ctx=8192,
    n_threads=12,
    temperature=0.7,
)
"""
model_id = "google/medgemma-4b-pt"

model_medg = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
processor = AutoProcessor.from_pretrained(model_id)
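
# The "-pt" suffix denotes the pretrained (non-instruction-tuned) MedGemma
# checkpoint, so /analyze-image below prompts it free-form and prepends the
# "<start_of_image>" placeholder token whenever an image is attached.
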
@app.route('/analyze-image', methods=['POST'])
def analyze_image():
    image = None
    image_path = None

    # Get optional image
    file = request.files.get('file')
    if file:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
            file.save(tmp.name)
            image_path = tmp.name
            try:
                image = Image.open(tmp.name).convert("RGB")
            except Exception as e:
                return jsonify({'error': f'Invalid image: {str(e)}'}), 400

    # Get optional chat history: a JSON list of {role, content} messages,
    # passed either as a form field or in a JSON body
    try:
        chat_history = request.form.get("chat_history")
        if chat_history:
            chat_history = json.loads(chat_history)
        else:
            data = request.get_json(silent=True) or {}
            chat_history = data.get("chat_history", [])
            if isinstance(chat_history, str):
                chat_history = json.loads(chat_history)
        if not isinstance(chat_history, list):
            raise ValueError("chat_history must be a list of messages")
    except Exception as e:
        return jsonify({'error': f'Invalid or missing chat_history: {str(e)}'}), 400

    # Build text prompt (from chat history)
    prompt_parts = []
    for msg in chat_history:
        role = msg.get("role", "").strip().lower()
        content = msg.get("content", "").strip()
        if role == "system":
            prompt_parts.append(content)
        elif role == "user":
            prompt_parts.append(f"User: {content}")
        elif role == "assistant":
            prompt_parts.append(f"Assistant: {content}")

    combined_prompt = "\n".join(prompt_parts).strip()

    if not image and not combined_prompt:
        return jsonify({'error': 'You must provide either an image or a prompt.'}), 400

    # Final model prompt: prepend the image placeholder token when an image is supplied
    if image:
        model_prompt = f"<start_of_image> {combined_prompt or 'Findings:'}"
    else:
        model_prompt = combined_prompt or "Findings:"
    # Prepare input to model
    inputs = processor(
        text=model_prompt,
        images=image if image else None,
        return_tensors="pt"
    ).to(model_medg.device, dtype=torch.bfloat16)

    input_len = inputs["input_ids"].shape[-1]
    print(model_prompt)

    with torch.inference_mode():
        generation = model_medg.generate(
            **inputs,
            max_new_tokens=1000,
            do_sample=False
        )
        generation = generation[0][input_len:]

    decoded = processor.decode(generation, skip_special_tokens=True)

    if image_path and os.path.exists(image_path):
        os.remove(image_path)

    return jsonify({"result": decoded.strip()})

@app.route('/med-llm', methods=['POST'])
def med_llm():
    uploaded_file = request.files.get('file')
    if not uploaded_file:
        return jsonify({'error': 'No file uploaded'}), 400

    mime_type = uploaded_file.mimetype
    print(f"Received file type: {mime_type}")

    # Placeholder analysis: the upload is saved to a temp file but not processed yet
    with tempfile.NamedTemporaryFile(delete=False, suffix='.bin') as tmp:
        uploaded_file.save(tmp.name)
        tmp_path = tmp.name

    if mime_type.startswith('image/'):
        mock_response = "📷 MedGemma Image Analysis: Detected mild cardiomegaly."
    elif mime_type.startswith('audio/'):
        mock_response = "🎧 MedGemma Audio Analysis: Suggests potential wheezing."
    else:
        mock_response = "Unsupported file type."

    os.remove(tmp_path)  # Clean up the unused temporary file
    return jsonify({'result': mock_response})
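
# Example request for the mock /med-llm endpoint above (hypothetical file name):
#   curl -X POST http://localhost:7860/med-llm -F "file=@chest_xray.png"
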
model = whisper.load_model("base")  # You can change to 'small', 'medium', etc.
@app.route('/transcribe-stream', methods=['POST'])
def transcribe_stream():
    # Save the audio file from request
    audio_file = request.files.get('audio')
    if not audio_file:
        return "Missing audio file", 400

    # Save to temp file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        audio_path = tmp.name
        audio_file.save(audio_path)

    def generate():
        # Transcribe using Whisper (non-streaming)
        result = model.transcribe(audio_path)
        for word in result['text'].split():
            yield f"data: {word}\n\n"
            time.sleep(0.3)  # Simulate streaming

        os.remove(audio_path)  # Clean up

    return Response(generate(), mimetype='text/event-stream')
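
# Example request for /transcribe-stream (hypothetical file name); the response
# is a Server-Sent Events stream with one "data: <word>" event per word:
#   curl -N -X POST http://localhost:7860/transcribe-stream -F "audio=@recording.wav"
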
def translate_to_arabic(text):
    try:
        translated = GoogleTranslator(source='auto', target='ar').translate(text)
        return translated
    except Exception as e:
        print(f"Translation failed: {e}")
        return text

@app.route('/chat-translate', methods=['POST'])
def chat_translate():
    try:
        data = request.get_json()
        chat_history = data.get('chat_history', [])
        translate = request.args.get("translate") == "true"

        # Join messages into a prompt; remember the last model reply as the result
        prompt_parts = []
        result = ""
        for msg in chat_history:
            role = msg.get("role", "").strip().lower()
            content = msg.get("content", "").strip()
            if role == "system":
                prompt_parts.append(f"System:\n{content}")
            elif role == "user":
                prompt_parts.append(f"User:\n{content}")
            elif role == "model":
                prompt_parts.append(f"Assistant:\n{content}")
                result = content  # Use the last model response as the result

        # full_prompt is assembled for a future real LLM call; this endpoint
        # currently simulates the response by echoing the last model message
        full_prompt = "\n\n".join(prompt_parts)

        # Simulated LLM response:
        # result = "Simulated model answer. Possible findings: * Infection * Fluid accumulation. Next steps: * Follow-up test."

        if translate:
            result = translate_to_arabic(result)

        return jsonify({"result": result})

    except Exception as e:
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=7860)