"""
Pydantic models for the MCP tool calling protocol.
Defines the data structures for MCP server configuration, tool definitions,
call tracking, ranking, and orchestration state.
"""
from __future__ import annotations
import hashlib
import json
import time
from enum import Enum
from pathlib import Path
from typing import Any, Optional
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# MCP Server Configuration
# ---------------------------------------------------------------------------
[docs]
class MCPTransport(str, Enum):
    """Transport mechanisms an MCP server can be reached over.

    MCP specification revision 2025-06-18 deprecates SSE in favour of
    Streamable HTTP, so no SSE member exists here.

    * ``STDIO`` -- local, per-user integrations.
    * ``STREAMABLE_HTTP`` -- the recommended transport for remote servers.
    * ``HTTP`` -- plain ``"http"`` value (used as the default transport by
      ``MCPServerConfig`` in this module).

    Being a ``str`` mixin, every member compares equal to its raw value.
    """

    # Declaration order is observable (iteration order) -- do not reorder.
    STDIO = "stdio"
    STREAMABLE_HTTP = "streamable-http"
    HTTP = "http"
[docs]
class MCPServerConfig(BaseModel):
    """Connection and runtime settings for one MCP server.

    Presumably ``url``/``headers`` apply to the HTTP-based transports and
    ``command``/``args``/``env`` to ``STDIO``; this model does not enforce
    that pairing -- every connection field is optional.
    """

    # --- identity -----------------------------------------------------------
    name: str
    description: str = Field(default="")
    transport: MCPTransport = Field(default=MCPTransport.HTTP)

    # --- connection details -------------------------------------------------
    url: Optional[str] = Field(default=None)
    command: Optional[str] = Field(default=None)
    args: list[str] = Field(default_factory=list)
    env: dict[str, str] = Field(default_factory=dict)
    headers: dict[str, str] = Field(default_factory=dict)

    # --- toggles and labels -------------------------------------------------
    enabled: bool = Field(default=True)
    tags: list[str] = Field(default_factory=list)

    # --- rate limiting ------------------------------------------------------
    max_concurrent: int = Field(default=10)
    timeout_seconds: int = Field(default=30)
[docs]
class MCPServersFile(BaseModel):
    """Top-level document model for the ``mcp_servers.json`` config file.

    ``servers`` maps a key (presumably the server name -- it is not checked
    against ``MCPServerConfig.name``) to that server's configuration.
    """

    version: str = Field(default="1.0")  # config-file schema version string
    servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
# ---------------------------------------------------------------------------
# Tool Definitions (reduced from LiteLLM spec)
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Tool Call Tracking
# ---------------------------------------------------------------------------
[docs]
class ErrorCategory(str, Enum):
    """Buckets that tool-call errors are sorted into to choose a retry strategy.

    The intended handling of each bucket is noted beside its member.
    """

    TRANSIENT = "transient"  # e.g. 429/502/503 or timeout -> back off, then retry
    SERVER = "server"  # e.g. 500 or unidentified server fault -> circuit breaker
    CLIENT = "client"  # e.g. 400/404 or validation -> never retry; fix the args
    UNKNOWN = "unknown"  # could not be classified
# ---------------------------------------------------------------------------
# Circuit Breaker State
# ---------------------------------------------------------------------------
[docs]
class CircuitState(str, Enum):
    """The three states of the circuit-breaker pattern.

    Each member's comment gives the behaviour expected while in that state.
    """

    CLOSED = "closed"  # requests flow normally; failures are being counted
    OPEN = "open"  # requests are blocked until the cooldown elapses
    HALF_OPEN = "half_open"  # a single probe request is let through
# ---------------------------------------------------------------------------
# Ranking Index
# ---------------------------------------------------------------------------
[docs]
class ToolRankEntry(BaseModel):
    """One tool's row in the ranking index.

    ``score`` is the value ``RankingIndex`` sorts on. The ``*_score``
    components are presumably its inputs; how they are combined is not
    defined in this model.
    """

    # Which tool on which server this row describes.
    tool_name: str
    server_name: str

    # Overall ranking value followed by its individual components.
    score: float = Field(default=0.0)
    relevance_score: float = Field(default=0.0)
    reliability_score: float = Field(default=0.0)
    latency_score: float = Field(default=0.0)
    inertia_score: float = Field(default=0.0)

    # Free-form labels; RankingIndex.for_query matches queries against them.
    tags: list[str] = Field(default_factory=list)
[docs]
class RankingIndex(BaseModel):
    """Collection of ``ToolRankEntry`` rows used to pick which tools to offer.

    ``last_updated`` is a Unix timestamp (``time.time()``) captured when the
    index is constructed; nothing in this class refreshes it.
    """

    entries: list[ToolRankEntry] = Field(default_factory=list)
    last_updated: float = Field(default_factory=time.time)

    def top_k(self, k: int = 10) -> list[ToolRankEntry]:
        """Return the ``k`` highest-scoring entries, best first.

        Entries with equal ``score`` keep their order from ``entries``
        (``sorted`` is stable).

        Args:
            k: Maximum number of entries to return. ``k <= 0`` yields ``[]``;
                without that guard a negative ``k`` would slice off the
                *lowest* entries and return almost the whole ranking.
        """
        if k <= 0:
            return []
        return sorted(self.entries, key=lambda e: e.score, reverse=True)[:k]

    def for_query(self, query: str, k: int = 10) -> list[ToolRankEntry]:
        """Return up to ``k`` entries relevant to ``query``, best first.

        Matching is case-insensitive. An entry is a candidate when any
        whitespace-separated word of the query occurs as a substring of the
        entry's text (tool name, server name and tags). Candidates are ranked
        by ``0.5 * entry.score + 0.5 * match_score``, where ``match_score`` is
        the fraction of the query's alphanumeric tokens that also appear as
        alphanumeric tokens of the entry text. Ties keep ``entries`` order.

        Tokens are split on every non-alphanumeric character, so identifiers
        like ``read_file`` or ``web-search`` contribute to ``match_score``.
        Splitting on whitespace alone would give such candidates a
        ``match_score`` of 0, ranking them on ``entry.score`` only.

        Args:
            query: Free-text query.
            k: Maximum number of entries to return; ``k <= 0`` yields ``[]``.
        """
        if k <= 0:
            return []

        def tokens(text: str) -> set[str]:
            # Any non-alphanumeric char (``_``, ``-``, ``.``, ``/`` ...) separates tokens.
            return set("".join(ch if ch.isalnum() else " " for ch in text).split())

        query_lower = query.lower()
        query_words = set(query_lower.split())
        query_tokens = tokens(query_lower)

        scored: list[tuple[float, ToolRankEntry]] = []
        for entry in self.entries:
            text = f"{entry.tool_name} {entry.server_name} {' '.join(entry.tags)}".lower()
            # A whole-word match implies substring containment, so this single
            # check admits exactly the candidates the word-overlap test would.
            if not any(w in text for w in query_words):
                continue
            match_score = len(query_tokens & tokens(text)) / max(len(query_tokens), 1)
            scored.append((entry.score * 0.5 + match_score * 0.5, entry))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [entry for _, entry in scored[:k]]
# ---------------------------------------------------------------------------
# Orchestrator State
# ---------------------------------------------------------------------------
[docs]
class OrchestrationSession(BaseModel):
    """Mutable state for one hub-and-spoke orchestration session.

    Tracks every ``ToolCallRecord`` produced in the session plus running
    aggregates (number of calls, deepest call depth), the session's
    ``RankingIndex`` and a wall-clock deadline for a watchdog.
    """

    session_id: str
    call_records: list[ToolCallRecord] = Field(default_factory=list)

    # Aggregates maintained by add_record().
    total_calls: int = Field(default=0)
    max_depth: int = Field(default=0)

    started_at: float = Field(default_factory=time.time)  # Unix timestamp at creation
    ranking_index: RankingIndex = Field(default_factory=RankingIndex)

    # Watchdog budget in wall-clock seconds, counted from ``started_at``.
    timeout_seconds: float = Field(default=1800.0)  # 30 minutes

    def add_record(self, record: ToolCallRecord) -> None:
        """Store ``record`` and fold it into the running aggregates.

        ``total_calls`` grows by one; ``max_depth`` is raised to
        ``record.depth`` only when that depth is strictly deeper.
        """
        self.call_records.append(record)
        self.total_calls += 1
        depth = record.depth
        # Assign only on a strict increase (pydantic marks assigned fields as "set").
        if depth > self.max_depth:
            self.max_depth = depth

    @property
    def is_timed_out(self) -> bool:
        """True once more than ``timeout_seconds`` have passed since ``started_at``."""
        elapsed = time.time() - self.started_at
        return elapsed > self.timeout_seconds
# ---------------------------------------------------------------------------
# LiteLLM API Spec Filter
# ---------------------------------------------------------------------------
[docs]
class LiteLLMEndpoint(BaseModel):
    """One LiteLLM API operation, trimmed to what tool calling needs.

    ``path`` and ``method`` identify the operation. The dict-valued fields
    hold untyped fragments, presumably lifted from the upstream API spec
    (their structure is not validated here).
    """

    # Operation identity.
    path: str
    method: str

    # Human-readable text.
    summary: str = Field(default="")
    description: str = Field(default="")

    # Schema fragments.
    parameters: dict[str, Any] = Field(default_factory=dict)
    request_body: dict[str, Any] = Field(default_factory=dict)
    response_schema: dict[str, Any] = Field(default_factory=dict)
[docs]
class LiteLLMFilteredSpec(BaseModel):
    """Reduced LiteLLM API spec holding only what tool calling needs.

    Also records the spec size before and after filtering so the savings can
    be reported through ``reduction_ratio``.
    """

    version: str = Field(default="")
    base_url: str = Field(default="")
    endpoints: list[LiteLLMEndpoint] = Field(default_factory=list)
    schemas: dict[str, Any] = Field(default_factory=dict)

    # Size bookkeeping, in bytes.
    total_original_size_bytes: int = Field(default=0)
    filtered_size_bytes: int = Field(default=0)

    @property
    def reduction_ratio(self) -> float:
        """Share of the original spec size that filtering removed.

        Computed as ``1 - filtered_size_bytes / total_original_size_bytes``;
        e.g. 1000 -> 250 bytes gives ``0.75``. Returns ``0.0`` when the
        original size is zero, to avoid dividing by zero.
        """
        original = self.total_original_size_bytes
        if original == 0:
            return 0.0
        kept_fraction = self.filtered_size_bytes / original
        return 1.0 - kept_fraction