ysharma HF Staff committed on
Commit
5738139
·
verified ·
1 Parent(s): 7bab86d

Create mcp_client.py

Browse files
Files changed (1) hide show
  1. mcp_client.py +759 -0
mcp_client.py ADDED
@@ -0,0 +1,759 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MCP Client implementation for Universal MCP Client - Fixed Version
3
+ """
4
+ import asyncio
5
+ import json
6
+ import re
7
+ import logging
8
+ import traceback
9
+ from typing import Dict, Optional, Tuple, List, Any
10
+ from openai import OpenAI
11
+
12
+ # Import the proper MCP client components
13
+ from mcp import ClientSession
14
+ from mcp.client.sse import sse_client
15
+
16
+ from config import MCPServerConfig, AppConfig, HTTPX_AVAILABLE
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ class UniversalMCPClient:
21
+ """Universal MCP Client using HuggingFace Inference Providers instead of Anthropic"""
22
+
23
def __init__(self):
    """Create an empty client and, if a token is configured, the inference client.

    ``hf_client`` is an OpenAI-compatible client aimed at the HuggingFace
    router. It is only built when ``AppConfig.HF_TOKEN`` is truthy and stays
    ``None`` otherwise (a warning is logged in that case).
    """
    # Registered MCP servers keyed by server name.
    self.servers: Dict[str, MCPServerConfig] = {}
    # Track enabled/disabled servers; a name missing here counts as enabled.
    self.enabled_servers: Dict[str, bool] = {}
    self.hf_client: Optional[OpenAI] = None
    self.current_provider: Optional[str] = None
    self.current_model: Optional[str] = None
    # Cache of tool metadata per server name, filled when a server is probed.
    self.server_tools: Dict[str, Dict[str, Any]] = {}
    # Local file path -> uploaded URL map that is consulted when tool
    # arguments are rewritten; defined here so the attribute always exists.
    self.chat_handler_file_mapping: Dict[str, str] = {}

    # Initialize HF Inference Client if token is available
    if AppConfig.HF_TOKEN:
        self.hf_client = OpenAI(
            base_url="https://router.huggingface.co/v1",
            api_key=AppConfig.HF_TOKEN,
        )
        logger.info("✅ HuggingFace Inference client initialized")
    else:
        logger.warning("⚠️ HF_TOKEN not found")
40
+
41
def enable_server(self, server_name: str, enabled: bool = True):
    """Enable or disable a registered server.

    Args:
        server_name: Key of the server in ``self.servers``.
        enabled: ``True`` to enable the server, ``False`` to disable it.

    Unknown server names leave ``enabled_servers`` untouched; a warning is
    logged so the ignored request is visible instead of silent.
    """
    if server_name not in self.servers:
        logger.warning(f"⚠️ Cannot {'enable' if enabled else 'disable'} unknown server {server_name}")
        return

    self.enabled_servers[server_name] = enabled
    logger.info(f"🔧 Server {server_name} {'enabled' if enabled else 'disabled'}")
46
+
47
def get_enabled_servers(self) -> Dict[str, MCPServerConfig]:
    """Return the subset of registered servers that are currently enabled.

    A server with no entry in ``enabled_servers`` is treated as enabled.
    """
    active = {}
    for server_name, server_config in self.servers.items():
        if not self.enabled_servers.get(server_name, True):
            continue
        active[server_name] = server_config
    return active
51
+
52
def remove_all_servers(self) -> int:
    """Forget every registered server together with its cached state.

    Returns:
        The number of servers that were registered before clearing.
    """
    count = len(self.servers)
    # Clear in place so any outside reference to these dicts stays valid.
    for registry in (self.servers, self.enabled_servers, self.server_tools):
        registry.clear()
    logger.info(f"🗑️ Removed all {count} servers")
    return count
60
+
61
def set_model_and_provider(self, provider_id: str, model_id: str) -> None:
    """Record which provider and model later completions should use.

    Args:
        provider_id: Identifier of the inference provider.
        model_id: Identifier of the model.
    """
    self.current_provider = provider_id
    self.current_model = model_id
    logger.info(f"🔧 Set provider: {provider_id}, model: {model_id}")
66
+
67
def get_model_endpoint(self) -> str:
    """Return the endpoint identifier for the currently selected model.

    Delegates to ``AppConfig.get_model_endpoint`` with the current model and
    provider.

    Raises:
        ValueError: If either the provider or the model has not been set.
    """
    provider, model = self.current_provider, self.current_model
    if not (provider and model):
        raise ValueError("Provider and model must be set before making API calls")
    return AppConfig.get_model_endpoint(model, provider)
73
+
74
async def add_server_async(self, config: MCPServerConfig) -> Tuple[bool, str]:
    """Normalize a server URL, probe it over MCP and register it on success.

    The URL in ``config`` is rewritten in place to the canonical
    ``<base>/gradio_api/mcp/sse`` form before the connection test, and
    ``config.space_id`` is filled in when the host is a ``*.hf.space`` host.

    Args:
        config: Server configuration; ``config.url`` may be a bare base URL
            or already include one of the Gradio MCP endpoint suffixes.

    Returns:
        ``(True, message)`` when the server answered and was registered
        (enabled by default), otherwise ``(False, message)``.
    """
    try:
        logger.info(f"🔧 Adding MCP server: {config.name} at {config.url}")

        # Clean and validate URL - handle various input formats
        original_url = config.url.strip()

        # Drop trailing slashes BEFORE looking for endpoint suffixes, so that
        # ".../gradio_api/mcp/sse/" is recognized and not appended to again.
        base_url = original_url.rstrip("/")
        for endpoint in ("/gradio_api/mcp/sse", "/gradio_api/mcp"):
            if base_url.endswith(endpoint):
                base_url = base_url[:-len(endpoint)].rstrip("/")
                break

        # Construct proper MCP URL
        mcp_url = f"{base_url}/gradio_api/mcp/sse"

        logger.info(f"🔧 Original URL: {original_url}")
        logger.info(f"🔧 Base URL: {base_url}")
        logger.info(f"🔧 MCP URL: {mcp_url}")

        # Extract space ID if it's a HuggingFace space. The ID comes from the
        # host name, never from a trailing path segment.
        host = base_url.split("://", 1)[-1].split("/", 1)[0]
        if host.endswith(".hf.space"):
            space_id = host[:-len(".hf.space")]
            if '-' in space_id:
                # Format: username-spacename.hf.space
                config.space_id = space_id.replace('-', '/', 1)
            else:
                config.space_id = space_id
            logger.info(f"📍 Detected HF Space ID: {config.space_id}")

        # Update config with proper MCP URL
        config.url = mcp_url

        # Test MCP connection and cache tools
        success, message = await self._test_mcp_connection(config)

        if not success:
            logger.error(f"❌ Failed to connect to MCP server {config.name}: {message}")
            return False, f"❌ Failed to add server: {config.name}\n{message}"

        self.servers[config.name] = config
        self.enabled_servers[config.name] = True  # Enable by default
        logger.info(f"✅ MCP Server {config.name} added successfully")
        return True, f"✅ Successfully added MCP server: {config.name}\n{message}"

    except Exception as e:
        # Boundary handler: callers receive a (False, message) tuple rather
        # than an exception, so log the traceback here.
        error_msg = f"Failed to add server {config.name}: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return False, f"❌ {error_msg}"
131
+
132
async def _test_mcp_connection(self, config: MCPServerConfig) -> Tuple[bool, str]:
    """Connect to an MCP server over SSE, list its tools and cache them.

    Tool metadata is stored in ``self.server_tools[config.name]`` only when
    the server exposes at least one tool, so a rejected server leaves no
    empty cache entry behind.

    Args:
        config: Server configuration; ``config.url`` must already be the
            SSE endpoint URL.

    Returns:
        ``(True, summary)`` listing the tools found, or ``(False, reason)``
        on timeout, connection failure or when no tools are exposed.
    """
    try:
        logger.info(f"🔍 Testing MCP connection to {config.url}")

        async with sse_client(config.url, timeout=20.0) as (read_stream, write_stream):
            async with ClientSession(read_stream, write_stream) as session:
                # Initialize MCP session
                logger.info("🔧 Initializing MCP session...")
                await session.initialize()

                # List available tools
                logger.info("📋 Listing available tools...")
                tools = await session.list_tools()

                # Reject tool-less servers before touching the cache.
                if not tools.tools:
                    return False, "No tools found on MCP server"

                server_tools = {}
                tool_info = []
                for tool in tools.tools:
                    input_schema = getattr(tool, 'inputSchema', None)
                    server_tools[tool.name] = {
                        'description': tool.description,
                        'schema': input_schema,
                    }
                    tool_info.append(f" - {tool.name}: {tool.description}")
                    logger.info(f" 📍 Tool: {tool.name}")
                    logger.info(f" Description: {tool.description}")
                    if input_schema:
                        logger.info(f" Input Schema: {input_schema}")

                # Cache tools for this server
                self.server_tools[config.name] = server_tools

                message = f"Connected successfully!\nFound {len(tools.tools)} tools:\n" + "\n".join(tool_info)
                return True, message

    except asyncio.TimeoutError:
        return False, "Connection timeout - server may be sleeping or unreachable"
    except Exception as e:
        logger.error(f"MCP connection failed: {e}")
        logger.error(traceback.format_exc())
        return False, f"Connection failed: {str(e)}"
177
+
178
async def call_mcp_tool_async(self, server_name: str, tool_name: str, arguments: dict) -> Tuple[bool, str]:
    """Call a tool on a specific registered MCP server.

    A fresh SSE connection and MCP session are opened for every call.

    Args:
        server_name: Key of the target server in ``self.servers``.
        tool_name: Name of the tool to invoke.
        arguments: Arguments passed to the tool.

    Returns:
        ``(True, text)`` with the first content item of the result (its
        ``text`` attribute when present, otherwise its string form), or
        ``(False, error_message)`` when the server is unknown, the call
        times out or raises, the tool flags its result as an error, or no
        content comes back.
    """
    logger.info(f"🔧 MCP Tool Call - Server: {server_name}, Tool: {tool_name}")
    logger.info(f"🔧 Arguments: {arguments}")

    if server_name not in self.servers:
        error_msg = f"Server {server_name} not found. Available servers: {list(self.servers.keys())}"
        logger.error(f"❌ {error_msg}")
        return False, error_msg

    config = self.servers[server_name]
    logger.info(f"🔧 Using server config: {config.url}")

    try:
        logger.info(f"🔗 Connecting to MCP server at {config.url}")
        async with sse_client(config.url, timeout=30.0) as (read_stream, write_stream):
            async with ClientSession(read_stream, write_stream) as session:
                # Initialize MCP session
                logger.info("🔧 Initializing MCP session...")
                await session.initialize()

                # Call the tool
                logger.info(f"🔧 Calling tool {tool_name} with arguments: {arguments}")
                result = await session.call_tool(tool_name, arguments)

                # Extract result content (first item only)
                has_content = bool(result.content)
                result_text = ""
                if has_content:
                    first_item = result.content[0]
                    result_text = first_item.text if hasattr(first_item, 'text') else str(first_item)

                # A tool can answer with content yet mark the result as an
                # error; that must not be reported as a success.
                if getattr(result, 'isError', False):
                    error_msg = f"Tool {tool_name} reported an error: {result_text or 'no details provided'}"
                    logger.error(f"❌ {error_msg}")
                    return False, error_msg

                if not has_content:
                    error_msg = "No content returned from tool"
                    logger.error(f"❌ {error_msg}")
                    return False, error_msg

                logger.info(f"✅ Tool call successful, result length: {len(result_text)}")
                logger.info(f"📋 Result preview: {result_text[:200]}...")
                return True, result_text

    except asyncio.TimeoutError:
        error_msg = f"Tool call timeout for {tool_name} on {server_name}"
        logger.error(f"❌ {error_msg}")
        return False, error_msg
    except Exception as e:
        error_msg = f"Tool call failed: {str(e)}"
        logger.error(f"❌ MCP tool call failed: {e}")
        logger.error(traceback.format_exc())
        return False, error_msg
223
+
224
def generate_chat_completion(self, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
    """Generate a chat completion using HuggingFace Inference Providers.

    Args:
        messages: Chat messages as role/content dicts. The list and its
            dicts are copied before use, so the caller's data is never
            modified.
        **kwargs: Extra parameters forwarded to
            ``chat.completions.create``. ``max_tokens`` (default 8192),
            ``temperature`` (default 0.3) and ``stream`` (default False) are
            honoured. ``reasoning_effort`` (default
            ``AppConfig.DEFAULT_REASONING_EFFORT``) is NOT forwarded; it is
            written into the system prompt instead.

    Returns:
        Whatever ``self.hf_client.chat.completions.create`` returns (the
        annotation is kept as-is for interface compatibility).

    Raises:
        ValueError: If the client is not initialized or no provider/model
            has been selected.
        Exception: Any error raised by the API call is logged and re-raised.
    """
    if not self.hf_client:
        raise ValueError("HuggingFace client not initialized. Please set HF_TOKEN.")

    if not self.current_provider or not self.current_model:
        raise ValueError("Provider and model must be set before making API calls")

    # Get the model endpoint
    model_endpoint = self.get_model_endpoint()

    # Take reasoning_effort out of kwargs BEFORE the remaining kwargs are
    # forwarded: it is applied through the system prompt below and must not
    # reach the API call as a parameter.
    if "reasoning_effort" in kwargs:
        reasoning_effort = kwargs.pop("reasoning_effort")
    else:
        reasoning_effort = AppConfig.DEFAULT_REASONING_EFFORT

    # Work on copies so repeated calls do not keep appending the reasoning
    # hint to the caller's own system message.
    messages = [dict(message) for message in messages]

    if reasoning_effort:
        # For GPT OSS models, we can set reasoning in system prompt
        system_message = next((msg for msg in messages if msg.get("role") == "system"), None)
        if system_message:
            system_message["content"] += f"\n\nReasoning: {reasoning_effort}"
        else:
            messages.insert(0, {
                "role": "system",
                "content": f"You are a helpful AI assistant. Reasoning: {reasoning_effort}"
            })

    # Set up default parameters for GPT OSS models with higher limits
    params = {
        "model": model_endpoint,
        "messages": messages,
        "max_tokens": kwargs.pop("max_tokens", 8192),
        "temperature": kwargs.pop("temperature", 0.3),
        "stream": kwargs.pop("stream", False),
    }

    # Add any remaining kwargs
    params.update(kwargs)

    try:
        logger.info(f"🤖 Calling {model_endpoint} via {self.current_provider}")
        return self.hf_client.chat.completions.create(**params)
    except Exception as e:
        logger.error(f"HF Inference API call failed: {e}")
        raise
272
+
273
def generate_chat_completion_with_mcp_tools(self, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
    """Generate a chat completion, letting the model call one MCP tool.

    Flow: describe the enabled servers and their cached tools in a system
    message, ask the model, and if its answer contains a ``"use_tool": true``
    JSON request, run that tool and ask the model again with the tool result.

    Args:
        messages: Chat messages as role/content dicts. They are copied
            before the tool instructions are added, so the caller's list and
            dicts are not modified.
        **kwargs: Forwarded to ``generate_chat_completion`` only when no MCP
            server is enabled; the tool-aware path uses fixed token limits.

    Returns:
        The completion response object. When a tool request was handled, the
        returned object carries a ``_tool_execution`` dict with ``server``,
        ``tool``, ``result`` and ``success`` keys.
    """
    enabled_servers = self.get_enabled_servers()
    if not enabled_servers:
        # No enabled MCP servers available, use regular completion
        logger.info("🤖 No enabled MCP servers available, using regular chat completion")
        return self.generate_chat_completion(messages, **kwargs)

    logger.info(f"🔧 Processing chat with {len(enabled_servers)} enabled MCP servers available")

    # Add system message about available tools with exact tool names
    tool_descriptions = []
    exact_tool_mappings = []
    for server_name, config in enabled_servers.items():
        tool_descriptions.append(f"- **{server_name}**: {config.description}")
        # Add exact tool names if we have them cached
        for tool_name, tool_info in self.server_tools.get(server_name, {}).items():
            exact_tool_mappings.append(f" * Server '{server_name}' has tool '{tool_name}': {tool_info['description']}")

    # Get the actual server name (not the space ID)
    server_list = ", ".join(f'"{name}"' for name in enabled_servers)

    tools_system_msg = f"""
You have access to the following MCP tools:
{chr(10).join(tool_descriptions)}
EXACT TOOL MAPPINGS:
{chr(10).join(exact_tool_mappings) if exact_tool_mappings else "Loading tool mappings..."}
IMPORTANT SERVER NAMES: {server_list}
When you need to use a tool, respond with ONLY a JSON object in this EXACT format:
{{"use_tool": true, "server": "exact_server_name", "tool": "exact_tool_name", "arguments": {{"param": "value"}}}}
CRITICAL INSTRUCTIONS:
- Use ONLY the exact server names from this list: {server_list}
- Use the exact tool names as shown in the mappings above
- Always include all required parameters in the arguments
- Do not include any other text before or after the JSON
- Make sure the JSON is complete and properly formatted
If you don't need to use a tool, respond normally without any JSON.
"""

    # Copy every message dict: a shallow list copy would let the "+=" below
    # rewrite the caller's own system message.
    enhanced_messages = [dict(message) for message in messages]
    if enhanced_messages and enhanced_messages[0].get("role") == "system":
        enhanced_messages[0]["content"] += "\n\n" + tools_system_msg
    else:
        enhanced_messages.insert(0, {"role": "system", "content": tools_system_msg})

    # Get initial response with higher token limit
    logger.info("🤖 Getting initial response from LLM...")
    response = self.generate_chat_completion(enhanced_messages, max_tokens=8192)
    # The message content may be None; treat that as empty text.
    response_text = response.choices[0].message.content or ""

    logger.info(f"🤖 LLM Response (length: {len(response_text)}): {response_text}")

    # Check if the response indicates tool usage (tolerating JSON whitespace)
    if not re.search(r'"use_tool"\s*:\s*true', response_text):
        logger.info("💬 No tool usage detected, returning normal response")
        return response

    logger.info("🔧 Tool usage detected, parsing JSON...")
    tool_request = self._extract_tool_json(response_text)
    if not tool_request:
        # Fallback: try to extract tool info manually
        logger.info("🔧 JSON parsing failed, trying manual extraction...")
        tool_request = self._manual_tool_extraction(response_text)
    if not tool_request:
        logger.warning("⚠️ Failed to parse tool request JSON")
        return response

    server_name = tool_request.get("server")
    tool_name = tool_request.get("tool")
    arguments = tool_request.get("arguments")
    # The model controls these values; normalize anything malformed so the
    # lookups and string operations below cannot raise.
    if not isinstance(server_name, str):
        server_name = None
    if not isinstance(tool_name, str):
        tool_name = None
    if not isinstance(arguments, dict):
        arguments = {}

    def failed(result_message):
        # Attach failure details to the first response and hand it back.
        response._tool_execution = {
            "server": server_name,
            "tool": tool_name,
            "result": result_message,
            "success": False
        }
        return response

    # Replace any local file paths in arguments with uploaded URLs
    file_mapping = getattr(self, 'chat_handler_file_mapping', None) or {}
    for arg_key, arg_value in list(arguments.items()):
        if not (isinstance(arg_value, str) and arg_value.startswith('/tmp/gradio/')):
            continue
        # Check if we have an uploaded URL for this local path
        for local_path, uploaded_url in file_mapping.items():
            if local_path in arg_value or arg_value in local_path:
                logger.info(f"🔄 Replacing local path {arg_value} with uploaded URL {uploaded_url}")
                arguments[arg_key] = uploaded_url
                break

    logger.info(f"🔧 Tool request - Server: {server_name}, Tool: {tool_name}, Args: {arguments}")

    if server_name not in self.servers:
        available_servers = list(self.servers.keys())
        logger.error(f"❌ Server '{server_name}' not found. Available servers: {available_servers}")
        # Try to find a matching server by space_id or similar name
        matching_server = None
        if server_name:
            for srv_name, srv_config in self.servers.items():
                if (srv_config.space_id and server_name in srv_config.space_id) or server_name in srv_name:
                    matching_server = srv_name
                    logger.info(f"🔧 Found matching server: {matching_server}")
                    break

        if matching_server and self.enabled_servers.get(matching_server, True):
            server_name = matching_server
            logger.info(f"🔧 Using corrected server name: {server_name}")
        else:
            # Return error response with server name correction
            enabled_names = [name for name, enabled in self.enabled_servers.items() if enabled]
            return failed(f"Server '{server_name}' not found or disabled. Available enabled servers: {enabled_names}")
    elif not self.enabled_servers.get(server_name, True):
        logger.error(f"❌ Server '{server_name}' is disabled")
        return failed(f"Server '{server_name}' is currently disabled")

    # Validate tool name exists for this server
    known_tools = self.server_tools.get(server_name)
    if known_tools is not None and tool_name not in known_tools:
        available_tools = list(known_tools.keys())
        logger.warning(f"⚠️ Tool '{tool_name}' not found for server '{server_name}'. Available tools: {available_tools}")

        if len(available_tools) == 1:
            # Use the first available tool if there's only one
            tool_name = available_tools[0]
            logger.info(f"🔧 Using only available tool: {tool_name}")
        elif tool_name:
            # Or try to find a similar tool name
            wanted = tool_name.lower()
            for available_tool in available_tools:
                candidate = available_tool.lower()
                if wanted in candidate or candidate in wanted:
                    tool_name = available_tool
                    logger.info(f"🔧 Found similar tool name: {tool_name}")
                    break

    # Call the MCP tool. asyncio.run creates a fresh event loop, closes it
    # afterwards and does not leave a closed loop installed as the current
    # one. Like run_until_complete, it cannot be used from inside a running
    # event loop.
    success, result = asyncio.run(
        self.call_mcp_tool_async(server_name, tool_name, arguments)
    )

    if not success:
        logger.error(f"❌ Tool call failed: {result}")
        # Return original response with error info
        return failed(result)

    logger.info(f"✅ Tool call successful, result length: {len(str(result))}")

    # Add tool result to conversation and get final response with better prompting
    enhanced_messages.append({"role": "assistant", "content": response_text})
    enhanced_messages.append({"role": "user", "content": f"Tool '{tool_name}' from server '{server_name}' completed successfully. Result: {result}\n\nPlease provide a helpful response based on this tool result. If the result contains media URLs, present them appropriately."})

    # Remove the tool instruction from the system message for the final response
    final_messages = [dict(message) for message in enhanced_messages]
    if final_messages[0].get("role") == "system":
        final_messages[0]["content"] = final_messages[0]["content"].split("You have access to the following MCP tools:")[0].strip()

    logger.info("🤖 Getting final response with tool result...")
    final_response = self.generate_chat_completion(final_messages, max_tokens=4096)

    # Store tool execution info for the chat handler
    final_response._tool_execution = {
        "server": server_name,
        "tool": tool_name,
        "result": result,
        "success": True
    }
    return final_response
467
+
468
def _extract_tool_json(self, text: str) -> Optional[Dict[str, Any]]:
    """Extract the tool-request JSON object from an LLM response.

    Two strategies are tried in order:

    1. Decode a JSON value starting at every ``{`` in the text and accept
       the first object whose ``use_tool`` is ``True``. Using the JSON
       decoder (rather than regexes) copes with arbitrary nesting, braces
       inside strings, pretty-printing and surrounding prose.
    2. If nothing decodes, ask ``_reconstruct_json_from_start`` to repair a
       truncated request and parse its output.

    Args:
        text: Raw model output.

    Returns:
        The parsed request dict, or ``None`` when no valid request is found.
    """
    logger.info(f"🔍 Full LLM response text: {text}")

    decoder = json.JSONDecoder()

    # Strategy 1: try a real JSON decode at each opening brace
    position = text.find("{")
    while position != -1:
        try:
            candidate, _ = decoder.raw_decode(text, position)
        except (ValueError, RecursionError):
            candidate = None
        if isinstance(candidate, dict) and candidate.get("use_tool") is True:
            logger.info(f"✅ Valid tool request parsed: {candidate}")
            return candidate
        position = text.find("{", position + 1)

    # Strategy 2: the JSON may be truncated; try to reconstruct it
    reconstructed = self._reconstruct_json_from_start(text)
    if reconstructed:
        try:
            parsed = json.loads(reconstructed.strip())
        except ValueError as e:
            logger.warning(f"⚠️ JSON parse error after reconstruction: {e}")
            logger.warning(f"⚠️ Problematic JSON: {reconstructed}")
        else:
            # Validate it's a tool request
            if isinstance(parsed, dict) and parsed.get("use_tool") is True:
                logger.info(f"✅ Valid tool request parsed: {parsed}")
                return parsed

    logger.error("❌ Failed to extract valid JSON from response")
    return None
521
+
522
def _manual_tool_extraction(self, text: str) -> Optional[Dict[str, Any]]:
    """Recover a tool request with regexes when JSON parsing has failed.

    The ``server`` and ``tool`` names are required. Arguments are collected
    as flat key/value pairs; string values (including empty ones and ones
    with escaped quotes), numbers, booleans and ``null`` are supported. The
    closing brace of ``arguments`` is optional because this fallback mainly
    sees truncated output.

    Args:
        text: Raw model output.

    Returns:
        A request dict with ``use_tool``, ``server``, ``tool`` and
        ``arguments`` keys, or ``None`` when server or tool is missing.
    """
    logger.info("🔧 Attempting manual tool extraction...")

    if not isinstance(text, str):
        logger.error(f"❌ Manual extraction failed: expected str, got {type(text).__name__}")
        return None

    # Extract server and tool names
    server_match = re.search(r'"server"\s*:\s*"([^"]+)"', text)
    tool_match = re.search(r'"tool"\s*:\s*"([^"]+)"', text)
    if not server_match or not tool_match:
        logger.warning("⚠️ Could not find server or tool in manual extraction")
        return None

    # Try to extract arguments
    arguments: Dict[str, Any] = {}
    args_match = re.search(r'"arguments"\s*:\s*\{([^}]*)', text)
    if args_match:
        # A value is a JSON string, a JSON number, or true/false/null.
        pair_pattern = (
            r'"([^"]+)"\s*:\s*'
            r'("(?:[^"\\]|\\.)*"|-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?|true|false|null)'
        )
        for key, raw_value in re.findall(pair_pattern, args_match.group(1)):
            try:
                arguments[key] = json.loads(raw_value)
            except ValueError:
                # Not decodable as JSON (e.g. a bad escape): keep the raw text.
                arguments[key] = raw_value.strip('"')

    manual_request = {
        "use_tool": True,
        "server": server_match.group(1),
        "tool": tool_match.group(1),
        "arguments": arguments
    }

    logger.info(f"🔧 Manual extraction successful: {manual_request}")
    return manual_request
563
+
564
def _reconstruct_json_from_start(self, text: str) -> Optional[str]:
    """Cut out, or repair, the JSON object that starts with ``"use_tool": true``.

    The text is scanned from the opening brace while tracking string and
    escape state, so braces inside string values are ignored.

    * If the object closes, the exact substring of that object is returned.
    * If the text ends first (truncated output), the object is repaired: a
      dangling escape character is dropped, an open string is closed, a
      trailing comma is removed, and one ``}`` is appended per object that is
      still open.

    Args:
        text: Raw model output.

    Returns:
        A JSON string candidate (not guaranteed to parse), or ``None`` when
        the text contains no ``{"use_tool": true`` start.
    """
    # Find start of JSON (whitespace around the key and colon is allowed)
    match = re.search(r'\{\s*"use_tool"\s*:\s*true', text)
    if not match:
        return None

    json_part = text[match.start():]
    logger.info(f"🔧 Reconstructing JSON from: {json_part[:200]}...")

    depth = 0
    in_string = False
    escape_next = False

    for index, char in enumerate(json_part):
        if escape_next:
            escape_next = False
            continue
        if in_string:
            if char == '\\':
                escape_next = True
            elif char == '"':
                in_string = False
            continue
        if char == '"':
            in_string = True
        elif char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                reconstructed = json_part[:index + 1]
                logger.info(f"🔧 Reconstructed JSON: {reconstructed}")
                return reconstructed

    if depth <= 0:
        return None

    # Truncated: close whatever is still open, using the scanner's state so
    # braces that appear inside strings are not counted.
    repaired = json_part
    if escape_next:
        repaired = repaired[:-1]  # drop the dangling backslash
    if in_string:
        repaired += '"'
    else:
        repaired = repaired.rstrip().rstrip(',')
    reconstructed = repaired + '}' * depth
    logger.info(f"🔧 Added {depth} closing braces: {reconstructed}")
    return reconstructed
619
+
620
def _extract_media_from_mcp_response(self, result_text: str, config: MCPServerConfig) -> Optional[str]:
    """Find a media URL in the text returned by an MCP tool.

    Detection steps, first hit wins:

    1. JSON payloads: an object, or each object in an array, is checked for
       ``{'image'|'audio'|'video': {'url': ...}}`` and then a direct
       ``'url'`` key; the URL is resolved with ``_resolve_media_url``.
    2. Gradio file URL patterns anywhere in the text.
    3. Any ``http(s)`` URL accepted by ``AppConfig.is_media_file``.
    4. A ``data:`` URL.
    5. A bare path or filename accepted by ``AppConfig.is_media_file``,
       turned into ``<base>/file=<filename>``.

    Args:
        result_text: Text result of the tool call.
        config: Configuration of the server that produced the result; its
            URL minus ``/gradio_api/mcp/sse`` is used as base URL.

    Returns:
        The media URL, or ``None`` when nothing is detected or the input is
        not a string.
    """
    if not isinstance(result_text, str):
        logger.info(f"🔍 Non-string result: {type(result_text)}")
        return None

    base_url = config.url.replace("/gradio_api/mcp/sse", "")
    stripped = result_text.strip()
    logger.info(f"🔍 Processing MCP result for media: {result_text[:300]}...")
    logger.info(f"🔍 Base URL: {base_url}")

    # 1. Try to parse as JSON (most Gradio MCP servers return structured data)
    if stripped.startswith(('[', '{')):
        logger.info("🔍 Attempting JSON parse...")
        try:
            data = json.loads(stripped)
        except json.JSONDecodeError:
            logger.info("🔍 Not valid JSON, trying other formats...")
        else:
            logger.info(f"🔍 Parsed JSON structure: {data}")
            # Treat a single object like a one-element array so both shapes
            # share one code path; every array item is inspected in order.
            items = data if isinstance(data, list) else [data]
            for item in items:
                if not isinstance(item, dict):
                    continue

                # Check for nested media structure
                for media_type in ('image', 'audio', 'video'):
                    media_data = item.get(media_type)
                    if isinstance(media_data, dict) and isinstance(media_data.get('url'), str):
                        url = media_data['url'].strip('\'"')  # Clean quotes
                        logger.info(f"🎯 Found {media_type} URL: {url}")
                        return self._resolve_media_url(url, base_url)

                # Check for direct URL
                if isinstance(item.get('url'), str):
                    url = item['url'].strip('\'"')  # Clean quotes
                    logger.info(f"🎯 Found direct URL: {url}")
                    return self._resolve_media_url(url, base_url)

    # 2. Check for Gradio file URLs (common pattern) with better cleaning
    gradio_file_patterns = [
        r'https://[^/]+\.hf\.space/gradio_api/file=/[^/]+/[^/]+/[^"\s\',]+',
        r'https://[^/]+\.hf\.space/file=[^"\s\',]+',
        r'/gradio_api/file=/[^"\s\',]+'
    ]

    for pattern in gradio_file_patterns:
        match = re.search(pattern, result_text)
        if match:
            url = match.group(0).rstrip('\'",:;')  # Remove trailing punctuation
            logger.info(f"🎯 Found Gradio file URL: {url}")
            if url.startswith('/'):
                url = f"{base_url}{url}"
            return url

    # 3. Check for simple HTTP URLs in the text
    for url in re.findall(r'https?://[^\s"<>]+', result_text):
        if AppConfig.is_media_file(url):
            logger.info(f"🎯 Found HTTP media URL: {url}")
            return url

    # 4. Check for data URLs (base64 encoded media)
    if stripped.startswith('data:'):
        logger.info("🎯 Found data URL")
        return stripped

    # 5. For simple file paths, create proper Gradio URLs
    if AppConfig.is_media_file(result_text):
        # Keep only the filename; use the stripped text so surrounding
        # whitespace never ends up inside the URL.
        filename = stripped.rsplit('/', 1)[-1]
        media_url = f"{base_url}/file={filename}"
        logger.info(f"🎯 Created media URL from filename: {media_url}")
        return media_url

    logger.info("❌ No media detected in result")
    return None
726
+
727
def _resolve_media_url(self, url: str, base_url: str) -> str:
    """Turn a possibly relative media reference into an absolute URL.

    Absolute (``http...``) and ``data:`` URLs are returned unchanged. The
    ``file=`` forms are mapped under ``<base_url>/gradio_api``, and anything
    else is exposed as ``<base_url>/file=<url>``.
    """
    if url.startswith(('http', 'data:')):
        return url

    # (prefix, text inserted between base_url and url), checked in order
    gradio_forms = (
        ('/gradio_api/file=', ''),
        ('/file=', '/gradio_api'),
        ('file=', '/gradio_api/'),
    )
    for prefix, joiner in gradio_forms:
        if url.startswith(prefix):
            return base_url + joiner + url

    # Absolute paths and bare names are both served through file=
    return f"{base_url}/file={url}"
741
+
742
def get_server_status(self) -> Dict[str, str]:
    """Build a human-readable status line for every configured server.

    Returns:
        A mapping of server name to a status string that combines a fixed
        "Connected (MCP Protocol)" label with the file-upload compatibility
        hint from ``_check_file_upload_compatibility``. No connection is
        made here; the label is static.
    """
    return {
        name: f"✅ Connected (MCP Protocol) - {self._check_file_upload_compatibility(config)}"
        for name, config in self.servers.items()
    }
749
+
750
def _check_file_upload_compatibility(self, config: MCPServerConfig) -> str:
    """Guess from the server URL whether file uploads are likely to work.

    This is a heuristic on ``config.url`` only; nothing is sent to the
    server. The URL is matched case-insensitively.

    Returns:
        A short label prefixed with a coloured-circle emoji.
    """
    url = config.url.lower()
    if "hf.space" in url:
        return "🟡 Hugging Face Space (usually compatible)"
    if "gradio" in url:
        return "🟢 Gradio server (likely compatible)"
    if "localhost" in url or "127.0.0.1" in url:
        return "🟢 Local server (file access available)"
    return "🔴 Remote server (may need public URLs)"