Files
nofx/mcp/client.go
T
2025-12-08 01:43:22 +08:00

509 lines
14 KiB
Go

package mcp
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// Provider identifier and temperature constant for the MCP client.
const (
// ProviderCustom is the Provider value SetAPIKey assigns when the client is
// pointed at a user-supplied endpoint.
ProviderCustom = "custom"
// MCPClientTemperature is a temperature value of 0.5.
// NOTE(review): not referenced by the code visible in this part of the file,
// which sends client.config.Temperature; presumably the default that
// DefaultConfig installs — confirm.
MCPClientTemperature = 0.5
)
// Package-level defaults.
// NOTE(review): the retry flow visible in this file reads its settings from
// client.config rather than from these variables directly; presumably
// DefaultConfig seeds the config from them — confirm.
var (
// DefaultTimeout is a duration of 120 seconds.
// NOTE(review): presumably the default HTTP client timeout — confirm.
DefaultTimeout = 120 * time.Second
// MaxRetryTimes is 3.
// NOTE(review): presumably the default for Config.MaxRetries — confirm.
MaxRetryTimes = 3
// retryableErrors lists error-message fragments.
// NOTE(review): looks like the default for Config.RetryableErrors, which
// isRetryableError matches as substrings of err.Error() — confirm.
retryableErrors = []string{
"EOF",
"timeout",
"connection reset",
"connection refused",
"temporary failure",
"no such host",
"stream error", // HTTP/2 stream error
"INTERNAL_ERROR", // Server internal error
}
)
// Client is an AI chat-completion API client. It carries the connection
// settings (provider, API key, base URL, model), the HTTP client and logger
// used for requests, the Config it was built from, and the hooks through which
// the overridable request steps are dispatched.
type Client struct {
Provider string // Provider identifier; SetAPIKey sets it to ProviderCustom
APIKey string // API key; sent as a Bearer token by setAuthHeader
BaseURL string // Base endpoint URL; buildUrl appends /chat/completions unless UseFullURL is set
Model string // Model name placed in the request body
UseFullURL bool // When true, BaseURL is used as the request URL as is (nothing is appended)
MaxTokens int // Value sent as max_tokens in the request body
httpClient *http.Client // HTTP client used to send requests
logger Logger // Logger used for request and retry messages (replaceable)
config *Config // Config the client was created from (retry settings, temperature, ...)
// hooks provides dynamic dispatch for the overridable steps: the request
// flow calls client.hooks.X(...) instead of client.X(...). NewClient points
// hooks at the Client itself. Per the original note, a type that embeds
// Client (e.g. DeepSeekClient) points hooks at itself so that its overriding
// methods are the ones invoked.
hooks clientHooks
}
// New returns a client built with the default configuration. It is kept for
// backward compatibility and is equivalent to calling NewClient without
// options.
//
// Deprecated: Recommend using NewClient(...opts) for better flexibility
func New() AIClient {
	defaultClient := NewClient()
	return defaultClient
}
// NewClient creates a Client from the default configuration with opts applied
// in order, and returns it as an AIClient.
//
// Nil entries in opts are skipped, so a conditionally assembled option list
// cannot panic the constructor. When the resulting configuration has no
// provider, the client falls back to DeepSeek together with its default base
// URL and model.
//
// Usage examples:
//
//	// Basic usage (backward compatible)
//	client := mcp.NewClient()
//
//	// Custom logger
//	client := mcp.NewClient(mcp.WithLogger(customLogger))
//
//	// Custom timeout
//	client := mcp.NewClient(mcp.WithTimeout(60*time.Second))
//
//	// Combine multiple options
//	client := mcp.NewClient(
//		mcp.WithDeepSeekConfig("sk-xxx"),
//		mcp.WithLogger(customLogger),
//		mcp.WithTimeout(60*time.Second),
//	)
func NewClient(opts ...ClientOption) AIClient {
	// 1. Start from the defaults, then let each caller-supplied option adjust them.
	cfg := DefaultConfig()
	for _, opt := range opts {
		if opt == nil {
			// Calling a nil option would panic; treat it as "no change".
			continue
		}
		opt(cfg)
	}

	// 2. Copy the resolved settings onto the client and keep the config itself
	// for values that are read later (retry settings, temperature).
	client := &Client{
		Provider:   cfg.Provider,
		APIKey:     cfg.APIKey,
		BaseURL:    cfg.BaseURL,
		Model:      cfg.Model,
		MaxTokens:  cfg.MaxTokens,
		UseFullURL: cfg.UseFullURL,
		httpClient: cfg.HTTPClient,
		logger:     cfg.Logger,
		config:     cfg,
	}

	// 3. Without an explicit provider, fall back to DeepSeek. The base URL and
	// model are replaced along with it.
	if client.Provider == "" {
		client.Provider = ProviderDeepSeek
		client.BaseURL = DefaultDeepSeekBaseURL
		client.Model = DefaultDeepSeekModel
	}

	// 4. Dispatch the overridable steps to this instance by default.
	client.hooks = client
	return client
}
// SetAPIKey switches the client to a custom OpenAI-compatible API: it sets the
// provider to ProviderCustom and stores the given key, endpoint URL and model.
//
// A trailing "#" on apiURL marks it as a complete request URL: the "#" is
// stripped and UseFullURL is turned on so that /chat/completions is not
// appended. Without the marker the URL is stored as is and UseFullURL is
// turned off.
func (client *Client) SetAPIKey(apiKey, apiURL, customModel string) {
	// TrimSuffix leaves the URL unchanged when the marker is absent, so the
	// flag and the stored URL can be derived independently.
	isFullURL := strings.HasSuffix(apiURL, "#")

	client.Provider = ProviderCustom
	client.APIKey = apiKey
	client.Model = customModel
	client.UseFullURL = isFullURL
	client.BaseURL = strings.TrimSuffix(apiURL, "#")
}
// SetTimeout sets the overall per-request timeout on the client's HTTP client.
//
// If the client has no HTTP client yet, one is created first instead of
// dereferencing a nil pointer. Note that an existing HTTP client is modified
// in place, so the new timeout is also seen by anything else sharing that
// *http.Client.
func (client *Client) SetTimeout(timeout time.Duration) {
	if client.httpClient == nil {
		client.httpClient = &http.Client{}
	}
	client.httpClient.Timeout = timeout
}
// CallWithMessages sends one system/user prompt pair to the AI API and returns
// the reply text, retrying on retryable failures.
//
// This is the fixed retry template: a single attempt is delegated to
// client.hooks.call and the retry decision to client.hooks.isRetryableError.
// Between attempts it sleeps for RetryWaitBase multiplied by the attempt
// number.
//
// It returns an error when no API key is set, immediately returns the
// underlying error when it is not retryable, and otherwise returns the last
// error wrapped once all attempts are used up. At least one attempt is always
// made, even if the configured MaxRetries is zero or negative.
func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, error) {
	if client.APIKey == "" {
		return "", fmt.Errorf("AI API key not set, please call SetAPIKey first")
	}

	// A non-positive MaxRetries would skip the loop entirely: the API would
	// never be contacted and the final error would wrap a nil error.
	maxRetries := client.config.MaxRetries
	if maxRetries < 1 {
		maxRetries = 1
	}

	var lastErr error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		if attempt > 1 {
			client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
		}

		// Single attempt, dispatched through hooks.
		result, err := client.hooks.call(systemPrompt, userPrompt)
		if err == nil {
			if attempt > 1 {
				client.logger.Infof("✓ AI API retry succeeded")
			}
			return result, nil
		}
		lastErr = err

		// Non-retryable errors are returned unwrapped right away.
		if !client.hooks.isRetryableError(err) {
			return "", err
		}

		// Wait before the next attempt; there is nothing to wait for after the last one.
		if attempt < maxRetries {
			waitTime := client.config.RetryWaitBase * time.Duration(attempt)
			client.logger.Infof("⏳ Waiting %v before retry...", waitTime)
			time.Sleep(waitTime)
		}
	}
	return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
}
// setAuthHeader stores the client's API key in reqHeader as a Bearer token
// under the Authorization key, replacing any value already present. It is
// invoked through hooks by buildRequest.
func (client *Client) setAuthHeader(reqHeader http.Header) {
	bearer := fmt.Sprintf("Bearer %s", client.APIKey)
	reqHeader.Set("Authorization", bearer)
}
// buildMCPRequestBody assembles the chat-completion request body for one
// system/user prompt pair.
//
// The messages list holds the system message (only when systemPrompt is
// non-empty) followed by the user message. Model and max_tokens come from the
// client; temperature comes from the client's config.
func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
	// At most two entries: an optional system message, then the user message.
	conversation := make([]map[string]string, 0, 2)
	if systemPrompt != "" {
		systemMessage := map[string]string{"role": "system", "content": systemPrompt}
		conversation = append(conversation, systemMessage)
	}
	userMessage := map[string]string{"role": "user", "content": userPrompt}
	conversation = append(conversation, userMessage)

	return map[string]interface{}{
		"model":       client.Model,
		"messages":    conversation,
		"temperature": client.config.Temperature,
		"max_tokens":  client.MaxTokens,
	}
}
// marshalRequestBody encodes requestBody as JSON. It is invoked through hooks,
// so a type that supplies its own hooks can replace the encoding step.
//
// On failure it returns a nil slice and the encoding error wrapped with
// context.
func (client *Client) marshalRequestBody(requestBody map[string]any) ([]byte, error) {
	encoded, err := json.Marshal(requestBody)
	if err == nil {
		return encoded, nil
	}
	return nil, fmt.Errorf("failed to serialize request: %w", err)
}
// parseMCPResponse extracts the reply text from a chat-completion response
// body: the message content of the first entry in "choices".
//
// It returns an error when body is not valid JSON for that shape or when the
// choices list is missing or empty.
func (client *Client) parseMCPResponse(body []byte) (string, error) {
	// Only the fields that are actually read are declared; everything else in
	// the response is ignored by the decoder.
	type chatMessage struct {
		Content string `json:"content"`
	}
	type chatChoice struct {
		Message chatMessage `json:"message"`
	}
	var decoded struct {
		Choices []chatChoice `json:"choices"`
	}

	if err := json.Unmarshal(body, &decoded); err != nil {
		return "", fmt.Errorf("failed to parse response: %w", err)
	}
	if len(decoded.Choices) < 1 {
		return "", fmt.Errorf("API returned empty response")
	}

	first := decoded.Choices[0]
	return first.Message.Content, nil
}
// buildUrl returns the URL requests are sent to.
//
// When UseFullURL is set (SetAPIKey does this for URLs ending in "#"), BaseURL
// is returned unchanged. Otherwise /chat/completions is appended to BaseURL,
// after removing any trailing slashes so that a base such as
// "https://host/v1/" does not produce "https://host/v1//chat/completions".
func (client *Client) buildUrl() string {
	if client.UseFullURL {
		return client.BaseURL
	}
	return strings.TrimRight(client.BaseURL, "/") + "/chat/completions"
}
// buildRequest creates the POST request for url with jsonData as its body.
//
// The Content-Type header is set to application/json and the authentication
// header is added through hooks.setAuthHeader, so the auth scheme can be
// replaced by whatever hooks points at. It returns a wrapped error when the
// request cannot be constructed.
func (client *Client) buildRequest(url string, jsonData []byte) (*http.Request, error) {
	httpReq, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("fail to build request: %w", err)
	}

	headers := httpReq.Header
	headers.Set("Content-Type", "application/json")
	client.hooks.setAuthHeader(headers)
	return httpReq, nil
}
// call performs a single AI API round trip for one system/user prompt pair and
// returns the reply text. CallWithMessages invokes it through hooks.call and
// wraps it in the retry loop.
//
// The sequence is fixed: log the configuration, build and serialize the body,
// build the URL and the HTTP request, send it, read the whole response, require
// status 200, then parse. The individual build/serialize/parse steps are
// dispatched through hooks; sending and reading are done here directly.
//
// A serialization error is returned as is; every other failure is wrapped with
// a message naming the step that failed. A non-200 status is reported as an
// error that includes the status code and the response body.
func (client *Client) call(systemPrompt, userPrompt string) (string, error) {
	// Log the active configuration. The API key is only shown as its first
	// and last four characters, and only when it is longer than eight.
	label := client.String()
	client.logger.Infof("📡 [%s] Request AI Server: BaseURL: %s", label, client.BaseURL)
	client.logger.Debugf("[%s] UseFullURL: %v", label, client.UseFullURL)
	if key := client.APIKey; len(key) > 8 {
		client.logger.Debugf("[%s] API Key: %s...%s", label, key[:4], key[len(key)-4:])
	}

	// Build the payload and encode it.
	payload, err := client.hooks.marshalRequestBody(client.hooks.buildMCPRequestBody(systemPrompt, userPrompt))
	if err != nil {
		return "", err
	}

	// Resolve the endpoint and prepare the HTTP request.
	endpoint := client.hooks.buildUrl()
	client.logger.Infof("📡 [MCP %s] Request URL: %s", client.String(), endpoint)
	httpReq, err := client.hooks.buildRequest(endpoint, payload)
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}

	// Send it and read the complete response before looking at the status, so
	// the body is available for the error message.
	httpResp, err := client.httpClient.Do(httpReq)
	if err != nil {
		return "", fmt.Errorf("failed to send request: %w", err)
	}
	defer httpResp.Body.Close()

	raw, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read response: %w", err)
	}
	if status := httpResp.StatusCode; status != http.StatusOK {
		return "", fmt.Errorf("API returned error (status %d): %s", status, string(raw))
	}

	// Extract the reply text.
	content, err := client.hooks.parseMCPResponse(raw)
	if err != nil {
		return "", fmt.Errorf("fail to parse AI server response: %w", err)
	}
	return content, nil
}
// String returns a short bracketed description of the client naming its
// provider and model, used as the label in log messages.
func (client *Client) String() string {
	const layout = "[Provider: %s, Model: %s]"
	return fmt.Sprintf(layout, client.Provider, client.Model)
}
// isRetryableError reports whether err is worth retrying, i.e. whether its
// message contains one of the fragments in the client's configured
// RetryableErrors list (network errors, timeouts, EOF and the like).
//
// Matching ignores case. This matters for timeouts in particular: net/http
// reports an expired client timeout with the text "Client.Timeout exceeded",
// which a case-sensitive search for "timeout" does not find. A nil error is
// never retryable.
func (client *Client) isRetryableError(err error) bool {
	if err == nil {
		return false
	}

	message := strings.ToLower(err.Error())
	for _, fragment := range client.config.RetryableErrors {
		if strings.Contains(message, strings.ToLower(fragment)) {
			return true
		}
	}
	return false
}
// ============================================================
// Builder Pattern API (Advanced Features)
// ============================================================
// CallWithRequest calls the AI API using a Request object (supports advanced
// features) and returns the reply text, retrying on retryable failures.
//
// This method supports:
//   - Multi-turn conversation history
//   - Fine-grained parameter control (temperature, top_p, penalties, etc.)
//   - Function Calling / Tools
//   - Streaming response (future support)
//
// When req.Model is empty the client's Model is used. The caller's Request is
// not modified: the default is filled in on a shallow copy, so the same
// Request can be reused with a client configured for a different model.
//
// It returns an error when no API key is set or req is nil, immediately
// returns the underlying error when hooks.isRetryableError rejects it, and
// otherwise returns the last error wrapped once all attempts are used up.
// Between attempts it sleeps for RetryWaitBase multiplied by the attempt
// number. At least one attempt is always made, even if the configured
// MaxRetries is zero or negative.
//
// Usage example:
//
//	request := NewRequestBuilder().
//		WithSystemPrompt("You are helpful").
//		WithUserPrompt("Hello").
//		WithTemperature(0.8).
//		Build()
//	result, err := client.CallWithRequest(request)
func (client *Client) CallWithRequest(req *Request) (string, error) {
	if client.APIKey == "" {
		return "", fmt.Errorf("AI API key not set, please call SetAPIKey first")
	}
	if req == nil {
		return "", fmt.Errorf("request must not be nil")
	}

	// Apply the client's model on a copy rather than writing into the
	// caller's Request.
	if req.Model == "" {
		withModel := *req
		withModel.Model = client.Model
		req = &withModel
	}

	// A non-positive MaxRetries would skip the loop entirely: the API would
	// never be contacted and the final error would wrap a nil error.
	maxRetries := client.config.MaxRetries
	if maxRetries < 1 {
		maxRetries = 1
	}

	var lastErr error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		if attempt > 1 {
			client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
		}

		// Single attempt.
		result, err := client.callWithRequest(req)
		if err == nil {
			if attempt > 1 {
				client.logger.Infof("✓ AI API retry succeeded")
			}
			return result, nil
		}
		lastErr = err

		// Non-retryable errors are returned unwrapped right away.
		if !client.hooks.isRetryableError(err) {
			return "", err
		}

		// Wait before the next attempt; there is nothing to wait for after the last one.
		if attempt < maxRetries {
			waitTime := client.config.RetryWaitBase * time.Duration(attempt)
			client.logger.Infof("⏳ Waiting %v before retry...", waitTime)
			time.Sleep(waitTime)
		}
	}
	return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
}
// callWithRequest performs a single AI API round trip for a Request object and
// returns the reply text. CallWithRequest wraps it in the retry loop.
//
// The body is built from req by buildRequestBodyFromRequest; serialization,
// URL and HTTP request construction, and response parsing are dispatched
// through hooks. The whole response is read before the status is checked, and
// any status other than 200 is reported as an error that includes the status
// code and the response body.
//
// A serialization error is returned as is; every other failure is wrapped with
// a message naming the step that failed.
func (client *Client) callWithRequest(req *Request) (string, error) {
	// Log the active configuration.
	label := client.String()
	client.logger.Infof("📡 [%s] Request AI Server with Builder: BaseURL: %s", label, client.BaseURL)
	client.logger.Debugf("[%s] Messages count: %d", label, len(req.Messages))

	// Build the payload from the Request and encode it.
	payload, err := client.hooks.marshalRequestBody(client.buildRequestBodyFromRequest(req))
	if err != nil {
		return "", err
	}

	// Resolve the endpoint and prepare the HTTP request.
	endpoint := client.hooks.buildUrl()
	client.logger.Infof("📡 [MCP %s] Request URL: %s", client.String(), endpoint)
	outgoing, err := client.hooks.buildRequest(endpoint, payload)
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}

	// Send it and read the complete response before looking at the status, so
	// the body is available for the error message.
	httpResp, err := client.httpClient.Do(outgoing)
	if err != nil {
		return "", fmt.Errorf("failed to send request: %w", err)
	}
	defer httpResp.Body.Close()

	raw, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read response: %w", err)
	}
	if status := httpResp.StatusCode; status != http.StatusOK {
		return "", fmt.Errorf("API returned error (status %d): %s", status, string(raw))
	}

	// Extract the reply text.
	content, err := client.hooks.parseMCPResponse(raw)
	if err != nil {
		return "", fmt.Errorf("fail to parse AI server response: %w", err)
	}
	return content, nil
}
// buildRequestBodyFromRequest converts a Request into the request body map
// sent to the API.
//
// Every message is reduced to its role and content. Model and messages are
// always present. Temperature and max_tokens are always present too, taken
// from the Request when set there and from the client (config temperature,
// MaxTokens) otherwise. The remaining parameters (top_p, frequency_penalty,
// presence_penalty, stop, tools, tool_choice, stream) are included only when
// the Request sets them.
func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any {
	// One role/content entry per message, in the original order.
	apiMessages := make([]map[string]string, len(req.Messages))
	for i := range req.Messages {
		apiMessages[i] = map[string]string{
			"role":    req.Messages[i].Role,
			"content": req.Messages[i].Content,
		}
	}

	body := map[string]interface{}{
		"model":    req.Model,
		"messages": apiMessages,
	}

	// Parameters with a client-side fallback.
	if req.Temperature == nil {
		body["temperature"] = client.config.Temperature
	} else {
		body["temperature"] = *req.Temperature
	}
	if req.MaxTokens == nil {
		body["max_tokens"] = client.MaxTokens
	} else {
		body["max_tokens"] = *req.MaxTokens
	}

	// Purely optional sampling parameters: sent only when explicitly set.
	if value := req.TopP; value != nil {
		body["top_p"] = *value
	}
	if value := req.FrequencyPenalty; value != nil {
		body["frequency_penalty"] = *value
	}
	if value := req.PresencePenalty; value != nil {
		body["presence_penalty"] = *value
	}

	// Optional collections and flags: sent only when non-empty or enabled.
	if len(req.Stop) != 0 {
		body["stop"] = req.Stop
	}
	if len(req.Tools) != 0 {
		body["tools"] = req.Tools
	}
	if req.ToolChoice != "" {
		body["tool_choice"] = req.ToolChoice
	}
	if req.Stream {
		body["stream"] = true
	}
	return body
}