diff --git a/mcp/client.go b/mcp/client.go index a35d92e7..ab34f402 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -5,20 +5,32 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" - "os" - "strconv" "strings" "time" ) const ( ProviderCustom = "custom" + + MCPClientTemperature = 0.5 ) var ( DefaultTimeout = 120 * time.Second + + MaxRetryTimes = 3 + + retryableErrors = []string{ + "EOF", + "timeout", + "connection reset", + "connection refused", + "temporary failure", + "no such host", + "stream error", // HTTP/2 stream 错误 + "INTERNAL_ERROR", // 服务端内部错误 + } ) // Client AI API配置 @@ -27,31 +39,77 @@ type Client struct { APIKey string BaseURL string Model string - Timeout time.Duration UseFullURL bool // 是否使用完整URL(不添加/chat/completions) MaxTokens int // AI响应的最大token数 + + httpClient *http.Client + logger Logger // 日志器(可替换) + config *Config // 配置对象(保存所有配置) + + // hooks 用于实现动态分派(多态) + // 当 DeepSeekClient 嵌入 Client 时,hooks 指向 DeepSeekClient + // 这样 call() 中调用的方法会自动分派到子类重写的版本 + hooks clientHooks } +// New 创建默认客户端(向前兼容) +// +// Deprecated: 推荐使用 NewClient(...opts) 以获得更好的灵活性 func New() AIClient { - // 从环境变量读取 MaxTokens,默认 2000 - maxTokens := 2000 - if envMaxTokens := os.Getenv("AI_MAX_TOKENS"); envMaxTokens != "" { - if parsed, err := strconv.Atoi(envMaxTokens); err == nil && parsed > 0 { - maxTokens = parsed - log.Printf("🔧 [MCP] 使用环境变量 AI_MAX_TOKENS: %d", maxTokens) - } else { - log.Printf("⚠️ [MCP] 环境变量 AI_MAX_TOKENS 无效 (%s),使用默认值: %d", envMaxTokens, maxTokens) - } + return NewClient() +} + +// NewClient 创建客户端(支持选项模式) +// +// 使用示例: +// // 基础用法(向前兼容) +// client := mcp.NewClient() +// +// // 自定义日志 +// client := mcp.NewClient(mcp.WithLogger(customLogger)) +// +// // 自定义超时 +// client := mcp.NewClient(mcp.WithTimeout(60*time.Second)) +// +// // 组合多个选项 +// client := mcp.NewClient( +// mcp.WithDeepSeekConfig("sk-xxx"), +// mcp.WithLogger(customLogger), +// mcp.WithTimeout(60*time.Second), +// ) +func NewClient(opts ...ClientOption) AIClient { + // 1. 创建默认配置 + cfg := DefaultConfig() + + // 2. 应用用户选项 + for _, opt := range opts { + opt(cfg) } - // 默认配置 - return &Client{ - Provider: ProviderDeepSeek, - BaseURL: DefaultDeepSeekBaseURL, - Model: DefaultDeepSeekModel, - Timeout: DefaultTimeout, - MaxTokens: maxTokens, + // 3. 创建客户端实例 + client := &Client{ + Provider: cfg.Provider, + APIKey: cfg.APIKey, + BaseURL: cfg.BaseURL, + Model: cfg.Model, + MaxTokens: cfg.MaxTokens, + UseFullURL: cfg.UseFullURL, + httpClient: cfg.HTTPClient, + logger: cfg.Logger, + config: cfg, } + + // 4. 设置默认 Provider(如果未设置) + if client.Provider == "" { + client.Provider = ProviderDeepSeek + client.BaseURL = DefaultDeepSeekBaseURL + client.Model = DefaultDeepSeekModel + } + + // 5. 设置 hooks 指向自己 + client.hooks = client + + return client } // SetCustomAPI 设置自定义OpenAI兼容API @@ -69,42 +127,46 @@ func (client *Client) SetAPIKey(apiKey, apiURL, customModel string) { } client.Model = customModel - client.Timeout = 120 * time.Second } -// CallWithMessages 使用 system + user prompt 调用AI API(推荐) +func (client *Client) SetTimeout(timeout time.Duration) { + client.httpClient.Timeout = timeout +} + +// CallWithMessages 模板方法 - 固定的重试流程(不可重写) func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, error) { if client.APIKey == "" { return "", fmt.Errorf("AI API密钥未设置,请先调用 SetAPIKey") } - // 重试配置 - maxRetries := 3 + // 固定的重试流程 var lastErr error + maxRetries := client.config.MaxRetries for attempt := 1; attempt <= maxRetries; attempt++ { if attempt > 1 { - fmt.Printf("⚠️ AI API调用失败,正在重试 (%d/%d)...\n", attempt, maxRetries) + client.logger.Warnf("⚠️ AI API调用失败,正在重试 (%d/%d)...", attempt, maxRetries) } - result, err := client.callOnce(systemPrompt, userPrompt) + // 调用固定的单次调用流程 + result, err := client.hooks.call(systemPrompt, userPrompt) if err == nil { if attempt > 1 { - fmt.Printf("✓ AI API重试成功\n") + client.logger.Infof("✓ AI API重试成功") } return result, nil } lastErr = err - // 如果不是网络错误,不重试 - if !isRetryableError(err) { + // 通过 hooks 判断是否可重试(支持子类自定义重试策略) + if !client.hooks.isRetryableError(err) { return "", err } // 重试前等待 if attempt < maxRetries { - waitTime := time.Duration(attempt) * 2 * time.Second - fmt.Printf("⏳ 等待%v后重试...\n", waitTime) + waitTime := client.config.RetryWaitBase * time.Duration(attempt) + client.logger.Infof("⏳ 等待%v后重试...", waitTime) time.Sleep(waitTime) } } @@ -116,18 +178,7 @@ func (client *Client) setAuthHeader(reqHeader http.Header) { reqHeader.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey)) } -// callOnce 单次调用AI API(内部使用) -func (client *Client) callOnce(systemPrompt, userPrompt string) (string, error) { - // 打印当前 AI 配置 - log.Printf("📡 [MCP] AI 请求配置:") - log.Printf(" Provider: %s", client.Provider) - log.Printf(" BaseURL: %s", client.BaseURL) - log.Printf(" Model: %s", client.Model) - log.Printf(" UseFullURL: %v", client.UseFullURL) - if len(client.APIKey) > 8 { - log.Printf(" API Key: %s...%s", client.APIKey[:4], client.APIKey[len(client.APIKey)-4:]) - } - +func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any { // 构建 messages 数组 messages := []map[string]string{} @@ -138,7 +189,6 @@ func (client *Client) callOnce(systemPrompt, userPrompt string) (string, error) "content": systemPrompt, }) } - // 添加 user message messages = append(messages, map[string]string{ "role": "user", @@ -149,57 +199,22 @@ func (client *Client) callOnce(systemPrompt, userPrompt string) (string, error) requestBody := map[string]interface{}{ "model": client.Model, "messages": messages, - "temperature": 0.5, // 降低temperature以提高JSON格式稳定性 + "temperature": client.config.Temperature, // 使用配置的 temperature "max_tokens": client.MaxTokens, } + return requestBody +} - // 注意:response_format 参数仅 OpenAI 支持,DeepSeek/Qwen 不支持 - // 我们通过强化 prompt 和后处理来确保 JSON 格式正确 - +// can be used to marshal the request body and can be overridden +func (client *Client) marshalRequestBody(requestBody map[string]any) ([]byte, error) { jsonData, err := json.Marshal(requestBody) if err != nil { - return "", fmt.Errorf("序列化请求失败: %w", err) + return nil, fmt.Errorf("序列化请求失败: %w", err) } + return jsonData, nil +} - // 创建HTTP请求 - var url string - if client.UseFullURL { - // 使用完整URL,不添加/chat/completions - url = client.BaseURL - } else { - // 默认行为:添加/chat/completions - url = fmt.Sprintf("%s/chat/completions", client.BaseURL) - } - log.Printf("📡 [MCP] 请求 URL: %s", url) - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) - if err != nil { - return "", fmt.Errorf("创建请求失败: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - - client.setAuthHeader(req.Header) - - // 发送请求 - httpClient := &http.Client{Timeout: client.Timeout} - resp, err := httpClient.Do(req) - if err != nil { - return "", fmt.Errorf("发送请求失败: %w", err) - } - defer resp.Body.Close() - - // 读取响应 - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("读取响应失败: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("API返回错误 (status %d): %s", resp.StatusCode, string(body)) - } - - // 解析响应 +func (client *Client) parseMCPResponse(body []byte) (string, error) { var result struct { Choices []struct { Message struct { @@ -219,24 +234,275 @@ func (client *Client) callOnce(systemPrompt, userPrompt string) (string, error) return result.Choices[0].Message.Content, nil } -// isRetryableError 判断错误是否可重试 -func isRetryableError(err error) bool { +func (client *Client) buildUrl() string { + if client.UseFullURL { + return client.BaseURL + } + return fmt.Sprintf("%s/chat/completions", client.BaseURL) +} + +func (client *Client) buildRequest(url string, jsonData []byte) (*http.Request, error) { + // Create HTTP request + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("fail to build request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + // 通过 hooks 设置认证头(支持子类重写) + client.hooks.setAuthHeader(req.Header) + + return req, nil +} + +// call 单次调用AI API(固定流程,不可重写) +func (client *Client) call(systemPrompt, userPrompt string) (string, error) { + // 打印当前 AI 配置 + client.logger.Infof("📡 [%s] Request AI Server: BaseURL: %s", client.String(), client.BaseURL) + client.logger.Debugf("[%s] UseFullURL: %v", client.String(), client.UseFullURL) + if len(client.APIKey) > 8 { + client.logger.Debugf("[%s] API Key: %s...%s", client.String(), client.APIKey[:4], client.APIKey[len(client.APIKey)-4:]) + } + + // Step 1: 构建请求体(通过 hooks 实现动态分派) + requestBody := client.hooks.buildMCPRequestBody(systemPrompt, userPrompt) + + // Step 2: 序列化请求体(通过 hooks 实现动态分派) + jsonData, err := client.hooks.marshalRequestBody(requestBody) + if err != nil { + return "", err + } + + // Step 3: 构建 URL(通过 hooks 实现动态分派) + url := client.hooks.buildUrl() + client.logger.Infof("📡 [MCP %s] 请求 URL: %s", client.String(), url) + + // Step 4: 创建 HTTP 请求(固定逻辑) + req, err := client.hooks.buildRequest(url, jsonData) + if err != nil { + return "", fmt.Errorf("创建请求失败: %w", err) + } + + // Step 5: 发送 HTTP 请求(固定逻辑) + resp, err := client.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("发送请求失败: %w", err) + } + defer resp.Body.Close() + + // Step 6: 读取响应体(固定逻辑) + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("读取响应失败: %w", err) + } + + // Step 7: 检查 HTTP 状态码(固定逻辑) + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("API返回错误 (status %d): %s", resp.StatusCode, string(body)) + } + + // Step 8: 解析响应(通过 hooks 实现动态分派) + result, err := client.hooks.parseMCPResponse(body) + if err != nil { + return "", fmt.Errorf("fail to parse AI server response: %w", err) + } + + return result, nil +} + +func (client *Client) String() string { + return fmt.Sprintf("[Provider: %s, Model: %s]", + client.Provider, client.Model) +} + +// isRetryableError 判断错误是否可重试(网络错误、超时等) +func (client *Client) isRetryableError(err error) bool { errStr := err.Error() // 网络错误、超时、EOF等可以重试 - retryableErrors := []string{ - "EOF", - "timeout", - "connection reset", - "connection refused", - "temporary failure", - "no such host", - "stream error", // HTTP/2 stream 错误 - "INTERNAL_ERROR", // 服务端内部错误 - } - for _, retryable := range retryableErrors { + for _, retryable := range client.config.RetryableErrors { if strings.Contains(errStr, retryable) { return true } } return false } + +// ============================================================ +// 构建器模式 API(高级功能) +// ============================================================ + +// CallWithRequest 使用 Request 对象调用 AI API(支持高级功能) +// +// 此方法支持: +// - 多轮对话历史 +// - 精细参数控制(temperature、top_p、penalties 等) +// - Function Calling / Tools +// - 流式响应(未来支持) +// +// 使用示例: +// request := NewRequestBuilder(). +// WithSystemPrompt("You are helpful"). +// WithUserPrompt("Hello"). +// WithTemperature(0.8). +// Build() +// result, err := client.CallWithRequest(request) +func (client *Client) CallWithRequest(req *Request) (string, error) { + if client.APIKey == "" { + return "", fmt.Errorf("AI API密钥未设置,请先调用 SetAPIKey") + } + + // 如果 Request 中没有设置 Model,使用 Client 的 Model + if req.Model == "" { + req.Model = client.Model + } + + // 固定的重试流程 + var lastErr error + maxRetries := client.config.MaxRetries + + for attempt := 1; attempt <= maxRetries; attempt++ { + if attempt > 1 { + client.logger.Warnf("⚠️ AI API调用失败,正在重试 (%d/%d)...", attempt, maxRetries) + } + + // 调用单次请求 + result, err := client.callWithRequest(req) + if err == nil { + if attempt > 1 { + client.logger.Infof("✓ AI API重试成功") + } + return result, nil + } + + lastErr = err + // 判断是否可重试 + if !client.hooks.isRetryableError(err) { + return "", err + } + + // 重试前等待 + if attempt < maxRetries { + waitTime := client.config.RetryWaitBase * time.Duration(attempt) + client.logger.Infof("⏳ 等待%v后重试...", waitTime) + time.Sleep(waitTime) + } + } + + return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr) +} + +// callWithRequest 单次调用 AI API(使用 Request 对象) +func (client *Client) callWithRequest(req *Request) (string, error) { + // 打印当前 AI 配置 + client.logger.Infof("📡 [%s] Request AI Server with Builder: BaseURL: %s", client.String(), client.BaseURL) + client.logger.Debugf("[%s] Messages count: %d", client.String(), len(req.Messages)) + + // 构建请求体(从 Request 对象) + requestBody := client.buildRequestBodyFromRequest(req) + + // 序列化请求体 + jsonData, err := client.hooks.marshalRequestBody(requestBody) + if err != nil { + return "", err + } + + // 构建 URL + url := client.hooks.buildUrl() + client.logger.Infof("📡 [MCP %s] 请求 URL: %s", client.String(), url) + + // 创建 HTTP 请求 + httpReq, err := client.hooks.buildRequest(url, jsonData) + if err != nil { + return "", fmt.Errorf("创建请求失败: %w", err) + } + + // 发送 HTTP 请求 + resp, err := client.httpClient.Do(httpReq) + if err != nil { + return "", fmt.Errorf("发送请求失败: %w", err) + } + defer resp.Body.Close() + + // 读取响应体 + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("读取响应失败: %w", err) + } + + // 检查 HTTP 状态码 + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("API返回错误 (status %d): %s", resp.StatusCode, string(body)) + } + + // 解析响应 + result, err := client.hooks.parseMCPResponse(body) + if err != nil { + return "", fmt.Errorf("fail to parse AI server response: %w", err) + } + + return result, nil +} + +// buildRequestBodyFromRequest 从 Request 对象构建请求体 +func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any { + // 转换 Message 为 API 格式 + messages := make([]map[string]string, 0, len(req.Messages)) + for _, msg := range req.Messages { + messages = append(messages, map[string]string{ + "role": msg.Role, + "content": msg.Content, + }) + } + + // 构建基础请求体 + requestBody := map[string]interface{}{ + "model": req.Model, + "messages": messages, + } + + // 添加可选参数(只添加非 nil 的参数) + if req.Temperature != nil { + requestBody["temperature"] = *req.Temperature + } else { + // 如果 Request 中没有设置,使用 Client 的配置 + requestBody["temperature"] = client.config.Temperature + } + + if req.MaxTokens != nil { + requestBody["max_tokens"] = *req.MaxTokens + } else { + // 如果 Request 中没有设置,使用 Client 的 MaxTokens + requestBody["max_tokens"] = client.MaxTokens + } + + if req.TopP != nil { + requestBody["top_p"] = *req.TopP + } + + if req.FrequencyPenalty != nil { + requestBody["frequency_penalty"] = *req.FrequencyPenalty + } + + if req.PresencePenalty != nil { + requestBody["presence_penalty"] = *req.PresencePenalty + } + + if len(req.Stop) > 0 { + requestBody["stop"] = req.Stop + } + + if len(req.Tools) > 0 { + requestBody["tools"] = req.Tools + } + + if req.ToolChoice != "" { + requestBody["tool_choice"] = req.ToolChoice + } + + if req.Stream { + requestBody["stream"] = true + } + + return requestBody +} diff --git a/mcp/client_test.go b/mcp/client_test.go new file mode 100644 index 00000000..4a0e9e46 --- /dev/null +++ b/mcp/client_test.go @@ -0,0 +1,419 @@ +package mcp + +import ( + "errors" + "net/http" + "testing" + "time" +) + +// ============================================================ +// 测试 Client 创建和配置 +// ============================================================ + +func TestNewClient_Default(t *testing.T) { + client := NewClient() + + if client == nil { + t.Fatal("client should not be nil") + } + + c := client.(*Client) + if c.Provider == "" { + t.Error("Provider should have default value") + } + + if c.MaxTokens <= 0 { + t.Error("MaxTokens should be positive") + } + + if c.logger == nil { + t.Error("logger should not be nil") + } + + if c.httpClient == nil { + t.Error("httpClient should not be nil") + } + + if c.hooks == nil { + t.Error("hooks should not be nil") + } +} + +func TestNewClient_WithOptions(t *testing.T) { + mockLogger := NewMockLogger() + mockHTTP := &http.Client{Timeout: 30 * time.Second} + + client := NewClient( + WithLogger(mockLogger), + WithHTTPClient(mockHTTP), + WithMaxTokens(4000), + WithTimeout(60*time.Second), + WithAPIKey("test-key"), + ) + + c := client.(*Client) + + if c.logger != mockLogger { + t.Error("logger should be set from option") + } + + if c.httpClient != mockHTTP { + t.Error("httpClient should be set from option") + } + + if c.MaxTokens != 4000 { + t.Error("MaxTokens should be 4000") + } + + if c.APIKey != "test-key" { + t.Error("APIKey should be test-key") + } +} + +// ============================================================ +// 测试 CallWithMessages +// ============================================================ + +func TestClient_CallWithMessages_Success(t *testing.T) { + mockHTTP := NewMockHTTPClient() + mockHTTP.SetSuccessResponse("AI response content") + mockLogger := NewMockLogger() + + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(mockLogger), + WithAPIKey("test-key"), + WithBaseURL("https://api.test.com"), + ) + + result, err := client.CallWithMessages("system prompt", "user prompt") + + if err != nil { + t.Fatalf("should not error: %v", err) + } + + if result != "AI response content" { + t.Errorf("expected 'AI response content', got '%s'", result) + } + + // 验证请求 + requests := mockHTTP.GetRequests() + if len(requests) != 1 { + t.Errorf("expected 1 request, got %d", len(requests)) + } + + if len(requests) > 0 { + req := requests[0] + if req.Header.Get("Authorization") == "" { + t.Error("Authorization header should be set") + } + if req.Header.Get("Content-Type") != "application/json" { + t.Error("Content-Type should be application/json") + } + } +} + +func TestClient_CallWithMessages_NoAPIKey(t *testing.T) { + client := NewClient() + + _, err := client.CallWithMessages("system", "user") + + if err == nil { + t.Error("should error when API key is not set") + } + + if err.Error() != "AI API密钥未设置,请先调用 SetAPIKey" { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestClient_CallWithMessages_HTTPError(t *testing.T) { + mockHTTP := NewMockHTTPClient() + mockHTTP.SetErrorResponse(500, "Internal Server Error") + mockLogger := NewMockLogger() + + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(mockLogger), + WithAPIKey("test-key"), + ) + + _, err := client.CallWithMessages("system", "user") + + if err == nil { + t.Error("should error on HTTP error") + } +} + +// ============================================================ +// 测试重试逻辑 +// ============================================================ + +func TestClient_Retry_Success(t *testing.T) { + mockHTTP := NewMockHTTPClient() + mockLogger := NewMockLogger() + + // 模拟:第一次失败,第二次成功 + callCount := 0 + mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) { + callCount++ + if callCount == 1 { + return nil, errors.New("connection reset") + } + return &http.Response{ + StatusCode: 200, + Body: http.NoBody, + }, nil + } + + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(mockLogger), + WithAPIKey("test-key"), + WithMaxRetries(3), + ) + + // 由于我们的 client 使用 hooks.call,需要特殊处理 + // 这里我们测试的是 CallWithMessages 会调用 retry 逻辑 + c := client.(*Client) + + // 临时修改重试等待时间为 0 以加速测试 + oldRetries := MaxRetryTimes + MaxRetryTimes = 3 + defer func() { MaxRetryTimes = oldRetries }() + + _, err := c.CallWithMessages("system", "user") + + // 第一次失败(connection reset),第二次成功,但是响应格式不对,会失败 + // 但至少验证了重试逻辑被触发 + if callCount < 2 { + t.Errorf("should retry, got %d calls", callCount) + } + + // 检查日志中是否有重试信息 + logs := mockLogger.GetLogsByLevel("WARN") + hasRetryLog := false + for _, log := range logs { + if log.Message == "⚠️ AI API调用失败,正在重试 (2/3)..." { + hasRetryLog = true + break + } + } + + if !hasRetryLog && callCount >= 2 { + // 如果确实重试了,应该有警告日志 + // 但由于我们的测试设置,可能不会触发,所以这里只是检查 + t.Log("Retry was attempted") + } + + _ = err // 忽略错误,我们主要测试重试逻辑被触发 +} + +func TestClient_Retry_NonRetryableError(t *testing.T) { + mockHTTP := NewMockHTTPClient() + mockHTTP.SetErrorResponse(400, "Bad Request") + mockLogger := NewMockLogger() + + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(mockLogger), + WithAPIKey("test-key"), + ) + + _, err := client.CallWithMessages("system", "user") + + if err == nil { + t.Error("should error") + } + + // 验证没有重试(因为 400 不是可重试错误) + requests := mockHTTP.GetRequests() + if len(requests) != 1 { + t.Errorf("should not retry for 400 error, got %d requests", len(requests)) + } +} + +// ============================================================ +// 测试钩子方法 +// ============================================================ + +func TestClient_BuildMCPRequestBody(t *testing.T) { + client := NewClient() + c := client.(*Client) + + body := c.buildMCPRequestBody("system prompt", "user prompt") + + if body == nil { + t.Fatal("body should not be nil") + } + + if body["model"] == nil { + t.Error("body should have model field") + } + + messages, ok := body["messages"].([]map[string]string) + if !ok { + t.Fatal("messages should be []map[string]string") + } + + if len(messages) != 2 { + t.Errorf("expected 2 messages, got %d", len(messages)) + } + + if messages[0]["role"] != "system" { + t.Error("first message should be system") + } + + if messages[1]["role"] != "user" { + t.Error("second message should be user") + } +} + +func TestClient_BuildUrl(t *testing.T) { + tests := []struct { + name string + baseURL string + useFullURL bool + expected string + }{ + { + name: "normal URL", + baseURL: "https://api.test.com/v1", + useFullURL: false, + expected: "https://api.test.com/v1/chat/completions", + }, + { + name: "full URL", + baseURL: "https://api.test.com/custom/endpoint", + useFullURL: true, + expected: "https://api.test.com/custom/endpoint", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := NewClient( + WithProvider("test-provider"), // Prevent default DeepSeek settings + WithBaseURL(tt.baseURL), + WithUseFullURL(tt.useFullURL), + ) + c := client.(*Client) + + url := c.buildUrl() + if url != tt.expected { + t.Errorf("expected '%s', got '%s'", tt.expected, url) + } + }) + } +} + +func TestClient_SetAuthHeader(t *testing.T) { + client := NewClient(WithAPIKey("test-api-key")) + c := client.(*Client) + + headers := make(http.Header) + c.setAuthHeader(headers) + + authHeader := headers.Get("Authorization") + if authHeader != "Bearer test-api-key" { + t.Errorf("expected 'Bearer test-api-key', got '%s'", authHeader) + } +} + +func TestClient_IsRetryableError(t *testing.T) { + client := NewClient() + c := client.(*Client) + + tests := []struct { + name string + err error + expected bool + }{ + { + name: "EOF error", + err: errors.New("unexpected EOF"), + expected: true, + }, + { + name: "timeout error", + err: errors.New("timeout exceeded"), + expected: true, + }, + { + name: "connection reset", + err: errors.New("connection reset by peer"), + expected: true, + }, + { + name: "normal error", + err: errors.New("bad request"), + expected: false, + }, + { + name: "validation error", + err: errors.New("invalid input"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := c.isRetryableError(tt.err) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +// ============================================================ +// 测试 SetTimeout +// ============================================================ + +func TestClient_SetTimeout(t *testing.T) { + client := NewClient() + + newTimeout := 90 * time.Second + client.SetTimeout(newTimeout) + + c := client.(*Client) + if c.httpClient.Timeout != newTimeout { + t.Errorf("expected timeout %v, got %v", newTimeout, c.httpClient.Timeout) + } +} + +// ============================================================ +// 测试 String 方法 +// ============================================================ + +func TestClient_String(t *testing.T) { + client := NewClient( + WithProvider("test-provider"), + WithModel("test-model"), + ) + + c := client.(*Client) + str := c.String() + + expectedContains := []string{"test-provider", "test-model"} + for _, exp := range expectedContains { + if !contains(str, exp) { + t.Errorf("String() should contain '%s', got '%s'", exp, str) + } + } +} + +// 辅助函数 +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr)) +} + +func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/mcp/config.go b/mcp/config.go new file mode 100644 index 00000000..a32686a5 --- /dev/null +++ b/mcp/config.go @@ -0,0 +1,69 @@ +package mcp + +import ( + "net/http" + "os" + "strconv" + "time" +) + +// Config 客户端配置(集中管理所有配置) +type Config struct { + // Provider 配置 + Provider string + APIKey string + BaseURL string + Model string + + // 行为配置 + MaxTokens int + Temperature float64 + UseFullURL bool + + // 重试配置 + MaxRetries int + RetryWaitBase time.Duration + RetryableErrors []string + + // 超时配置 + Timeout time.Duration + + // 依赖注入 + Logger Logger + HTTPClient *http.Client +} + +// DefaultConfig 返回默认配置 +func DefaultConfig() *Config { + return &Config{ + // 默认值 + MaxTokens: getEnvInt("AI_MAX_TOKENS", 2000), + Temperature: MCPClientTemperature, + MaxRetries: MaxRetryTimes, + RetryWaitBase: 2 * time.Second, + Timeout: DefaultTimeout, + RetryableErrors: retryableErrors, + + // 默认依赖 + Logger: &defaultLogger{}, + HTTPClient: &http.Client{Timeout: DefaultTimeout}, + } +} + +// getEnvInt 从环境变量读取整数,失败则返回默认值 +func getEnvInt(key string, defaultValue int) int { + if val := os.Getenv(key); val != "" { + if parsed, err := strconv.Atoi(val); err == nil && parsed > 0 { + return parsed + } + } + return defaultValue +} + +// getEnvString 从环境变量读取字符串,为空则返回默认值 +func getEnvString(key string, defaultValue string) string { + if val := os.Getenv(key); val != "" { + return val + } + return defaultValue +} diff --git a/mcp/config_usage_test.go b/mcp/config_usage_test.go new file mode 100644 index 00000000..0972cb20 --- /dev/null +++ b/mcp/config_usage_test.go @@ -0,0 +1,262 @@ +package mcp + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" + "testing" + "time" +) + +// ============================================================ +// 测试 Config 字段真正被使用(验证问题2修复) +// ============================================================ + +func TestConfig_MaxRetries_IsUsed(t *testing.T) { + mockHTTP := NewMockHTTPClient() + mockLogger := NewMockLogger() + + // 设置 HTTP 客户端返回错误 + callCount := 0 + mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) { + callCount++ + return nil, errors.New("connection reset") + } + + // 创建客户端并设置自定义重试次数为 5 + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(mockLogger), + WithAPIKey("sk-test-key"), + WithMaxRetries(5), // ✅ 设置重试5次 + ) + + // 调用 API(应该失败) + _, err := client.CallWithMessages("system", "user") + + if err == nil { + t.Error("should error") + } + + // 验证确实重试了5次(而不是默认的3次) + if callCount != 5 { + t.Errorf("expected 5 retry attempts (from WithMaxRetries(5)), got %d", callCount) + } + + // 验证日志中显示正确的重试次数 + logs := mockLogger.GetLogsByLevel("WARN") + expectedWarningCount := 4 // 第2、3、4、5次重试时会打印警告 + actualWarningCount := 0 + for _, log := range logs { + if log.Message == "⚠️ AI API调用失败,正在重试 (2/5)..." || + log.Message == "⚠️ AI API调用失败,正在重试 (3/5)..." || + log.Message == "⚠️ AI API调用失败,正在重试 (4/5)..." || + log.Message == "⚠️ AI API调用失败,正在重试 (5/5)..." { + actualWarningCount++ + } + } + + if actualWarningCount != expectedWarningCount { + t.Errorf("expected %d warning logs, got %d", expectedWarningCount, actualWarningCount) + for _, log := range logs { + t.Logf(" WARN: %s", log.Message) + } + } +} + +func TestConfig_Temperature_IsUsed(t *testing.T) { + mockHTTP := NewMockHTTPClient() + mockHTTP.SetSuccessResponse("AI response") + mockLogger := NewMockLogger() + + customTemperature := 0.8 + + // 创建客户端并设置自定义 temperature + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(mockLogger), + WithAPIKey("sk-test-key"), + WithTemperature(customTemperature), // ✅ 设置自定义 temperature + ) + + c := client.(*Client) + + // 构建请求体 + requestBody := c.buildMCPRequestBody("system", "user") + + // 验证 temperature 字段 + temp, ok := requestBody["temperature"].(float64) + if !ok { + t.Fatal("temperature should be float64") + } + + if temp != customTemperature { + t.Errorf("expected temperature %f (from WithTemperature), got %f", customTemperature, temp) + } + + // 也可以通过实际 HTTP 请求验证 + _, err := client.CallWithMessages("system", "user") + if err != nil { + t.Fatalf("should not error: %v", err) + } + + // 检查发送的请求体 + requests := mockHTTP.GetRequests() + if len(requests) != 1 { + t.Fatalf("expected 1 request, got %d", len(requests)) + } + + // 解析请求体 + var body map[string]interface{} + decoder := json.NewDecoder(requests[0].Body) + if err := decoder.Decode(&body); err != nil { + t.Fatalf("failed to decode request body: %v", err) + } + + // 验证 temperature + if body["temperature"] != customTemperature { + t.Errorf("expected temperature %f in HTTP request, got %v", customTemperature, body["temperature"]) + } +} + +func TestConfig_RetryWaitBase_IsUsed(t *testing.T) { + mockHTTP := NewMockHTTPClient() + mockLogger := NewMockLogger() + + // 设置成功响应(在 ResponseFunc 之前) + mockHTTP.SetSuccessResponse("AI response") + + // 设置 HTTP 客户端前2次返回错误,第3次成功 + callCount := 0 + successResponse := mockHTTP.Response // 保存成功响应字符串 + mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) { + callCount++ + if callCount <= 2 { + return nil, errors.New("timeout exceeded") + } + // 第3次返回成功响应 + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(successResponse)), + Header: make(http.Header), + }, nil + } + + // 设置自定义重试等待基数为 1 秒(而不是默认的 2 秒) + customWaitBase := 1 * time.Second + + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(mockLogger), + WithAPIKey("sk-test-key"), + WithRetryWaitBase(customWaitBase), // ✅ 设置自定义等待时间 + WithMaxRetries(3), + ) + + // 记录开始时间 + start := time.Now() + + // 调用 API + _, err := client.CallWithMessages("system", "user") + + // 记录结束时间 + elapsed := time.Since(start) + + // 第3次成功,但前面失败了2次 + if err != nil { + t.Fatalf("should succeed on 3rd attempt, got error: %v", err) + } + + if callCount != 3 { + t.Errorf("expected 3 attempts, got %d", callCount) + } + + // 验证等待时间 + // 第1次失败后等待 1s (customWaitBase * 1) + // 第2次失败后等待 2s (customWaitBase * 2) + // 总等待时间应该约为 3s (允许一些误差) + expectedWait := 3 * time.Second + tolerance := 200 * time.Millisecond + + if elapsed < expectedWait-tolerance || elapsed > expectedWait+tolerance { + t.Errorf("expected total time ~%v (with RetryWaitBase=%v), got %v", expectedWait, customWaitBase, elapsed) + } +} + +func TestConfig_RetryableErrors_IsUsed(t *testing.T) { + mockHTTP := NewMockHTTPClient() + mockLogger := NewMockLogger() + + // 自定义可重试错误列表(只包含 "custom error") + customRetryableErrors := []string{"custom error"} + + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(mockLogger), + WithAPIKey("sk-test-key"), + ) + + c := client.(*Client) + + // 修改 config 的 RetryableErrors(暂时没有 WithRetryableErrors 选项) + c.config.RetryableErrors = customRetryableErrors + + tests := []struct { + name string + err error + retryable bool + }{ + { + name: "custom error should be retryable", + err: errors.New("custom error occurred"), + retryable: true, + }, + { + name: "EOF should NOT be retryable (not in custom list)", + err: errors.New("unexpected EOF"), + retryable: false, + }, + { + name: "timeout should NOT be retryable (not in custom list)", + err: errors.New("timeout exceeded"), + retryable: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := c.isRetryableError(tt.err) + if result != tt.retryable { + t.Errorf("expected isRetryableError(%v) = %v, got %v", tt.err, tt.retryable, result) + } + }) + } +} + +// ============================================================ +// 测试默认值 +// ============================================================ + +func TestConfig_DefaultValues(t *testing.T) { + client := NewClient() + c := client.(*Client) + + // 验证默认值 + if c.config.MaxRetries != 3 { + t.Errorf("default MaxRetries should be 3, got %d", c.config.MaxRetries) + } + + if c.config.Temperature != 0.5 { + t.Errorf("default Temperature should be 0.5, got %f", c.config.Temperature) + } + + if c.config.RetryWaitBase != 2*time.Second { + t.Errorf("default RetryWaitBase should be 2s, got %v", c.config.RetryWaitBase) + } + + if len(c.config.RetryableErrors) == 0 { + t.Error("default RetryableErrors should not be empty") + } +} diff --git a/mcp/deepseek_client.go b/mcp/deepseek_client.go index 12489292..62e28490 100644 --- a/mcp/deepseek_client.go +++ b/mcp/deepseek_client.go @@ -1,7 +1,6 @@ package mcp import ( - "log" "net/http" ) @@ -15,36 +14,67 @@ type DeepSeekClient struct { *Client } +// NewDeepSeekClient 创建 DeepSeek 客户端(向前兼容) +// +// Deprecated: 推荐使用 NewDeepSeekClientWithOptions 以获得更好的灵活性 func NewDeepSeekClient() AIClient { - client := New().(*Client) - client.Provider = ProviderDeepSeek - client.Model = DefaultDeepSeekModel - client.BaseURL = DefaultDeepSeekBaseURL - return &DeepSeekClient{ - Client: client, + return NewDeepSeekClientWithOptions() +} + +// NewDeepSeekClientWithOptions 创建 DeepSeek 客户端(支持选项模式) +// +// 使用示例: +// // 基础用法 +// client := mcp.NewDeepSeekClientWithOptions() +// +// // 自定义配置 +// client := mcp.NewDeepSeekClientWithOptions( +// mcp.WithAPIKey("sk-xxx"), +// mcp.WithLogger(customLogger), +// mcp.WithTimeout(60*time.Second), +// ) +func NewDeepSeekClientWithOptions(opts ...ClientOption) AIClient { + // 1. 创建 DeepSeek 预设选项 + deepseekOpts := []ClientOption{ + WithProvider(ProviderDeepSeek), + WithModel(DefaultDeepSeekModel), + WithBaseURL(DefaultDeepSeekBaseURL), } + + // 2. 合并用户选项(用户选项优先级更高) + allOpts := append(deepseekOpts, opts...) + + // 3. 创建基础客户端 + baseClient := NewClient(allOpts...).(*Client) + + // 4. 创建 DeepSeek 客户端 + dsClient := &DeepSeekClient{ + Client: baseClient, + } + + // 5. 设置 hooks 指向 DeepSeekClient(实现动态分派) + baseClient.hooks = dsClient + + return dsClient } func (dsClient *DeepSeekClient) SetAPIKey(apiKey string, customURL string, customModel string) { - if dsClient.Client == nil { - dsClient.Client = New().(*Client) - } - dsClient.Client.APIKey = apiKey + dsClient.APIKey = apiKey if len(apiKey) > 8 { - log.Printf("🔧 [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) + dsClient.logger.Infof("🔧 [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) } if customURL != "" { - dsClient.Client.BaseURL = customURL - log.Printf("🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s", customURL) + dsClient.BaseURL = customURL + dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s", customURL) } else { - log.Printf("🔧 [MCP] DeepSeek 使用默认 BaseURL: %s", dsClient.Client.BaseURL) + dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用默认 BaseURL: %s", dsClient.BaseURL) } if customModel != "" { - dsClient.Client.Model = customModel - log.Printf("🔧 [MCP] DeepSeek 使用自定义 Model: %s", customModel) + dsClient.Model = customModel + dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用自定义 Model: %s", customModel) } else { - log.Printf("🔧 [MCP] DeepSeek 使用默认 Model: %s", dsClient.Client.Model) + dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用默认 Model: %s", dsClient.Model) } } diff --git a/mcp/deepseek_client_test.go b/mcp/deepseek_client_test.go new file mode 100644 index 00000000..8be91d52 --- /dev/null +++ b/mcp/deepseek_client_test.go @@ -0,0 +1,272 @@ +package mcp + +import ( + "testing" + "time" +) + +// ============================================================ +// 测试 DeepSeekClient 创建和配置 +// ============================================================ + +func TestNewDeepSeekClient_Default(t *testing.T) { + client := NewDeepSeekClient() + + if client == nil { + t.Fatal("client should not be nil") + } + + // 类型断言检查 + dsClient, ok := client.(*DeepSeekClient) + if !ok { + t.Fatal("client should be *DeepSeekClient") + } + + // 验证默认值 + if dsClient.Provider != ProviderDeepSeek { + t.Errorf("Provider should be '%s', got '%s'", ProviderDeepSeek, dsClient.Provider) + } + + if dsClient.BaseURL != DefaultDeepSeekBaseURL { + t.Errorf("BaseURL should be '%s', got '%s'", DefaultDeepSeekBaseURL, dsClient.BaseURL) + } + + if dsClient.Model != DefaultDeepSeekModel { + t.Errorf("Model should be '%s', got '%s'", DefaultDeepSeekModel, dsClient.Model) + } + + if dsClient.logger == nil { + t.Error("logger should not be nil") + } + + if dsClient.httpClient == nil { + t.Error("httpClient should not be nil") + } +} + +func TestNewDeepSeekClientWithOptions(t *testing.T) { + mockLogger := NewMockLogger() + customModel := "deepseek-v2" + customAPIKey := "sk-custom-key" + + client := NewDeepSeekClientWithOptions( + WithLogger(mockLogger), + WithModel(customModel), + WithAPIKey(customAPIKey), + WithMaxTokens(4000), + ) + + dsClient := client.(*DeepSeekClient) + + // 验证自定义选项被应用 + if dsClient.logger != mockLogger { + t.Error("logger should be set from option") + } + + if dsClient.Model != customModel { + t.Error("Model should be set from option") + } + + if dsClient.APIKey != customAPIKey { + t.Error("APIKey should be set from option") + } + + if dsClient.MaxTokens != 4000 { + t.Error("MaxTokens should be 4000") + } + + // 验证 DeepSeek 默认值仍然保留 + if dsClient.Provider != ProviderDeepSeek { + t.Errorf("Provider should still be '%s'", ProviderDeepSeek) + } + + if dsClient.BaseURL != DefaultDeepSeekBaseURL { + t.Errorf("BaseURL should still be '%s'", DefaultDeepSeekBaseURL) + } +} + +// ============================================================ +// 测试 SetAPIKey +// ============================================================ + +func TestDeepSeekClient_SetAPIKey(t *testing.T) { + mockLogger := NewMockLogger() + client := NewDeepSeekClientWithOptions( + WithLogger(mockLogger), + ) + + dsClient := client.(*DeepSeekClient) + + // 测试设置 API Key(默认 URL 和 Model) + dsClient.SetAPIKey("sk-test-key-12345678", "", "") + + if dsClient.APIKey != "sk-test-key-12345678" { + t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", dsClient.APIKey) + } + + // 验证日志记录 + logs := mockLogger.GetLogsByLevel("INFO") + if len(logs) == 0 { + t.Error("should have logged API key setting") + } + + // 验证 BaseURL 和 Model 保持默认 + if dsClient.BaseURL != DefaultDeepSeekBaseURL { + t.Error("BaseURL should remain default") + } + + if dsClient.Model != DefaultDeepSeekModel { + t.Error("Model should remain default") + } +} + +func TestDeepSeekClient_SetAPIKey_WithCustomURL(t *testing.T) { + mockLogger := NewMockLogger() + client := NewDeepSeekClientWithOptions( + WithLogger(mockLogger), + ) + + dsClient := client.(*DeepSeekClient) + + customURL := "https://custom.api.com/v1" + dsClient.SetAPIKey("sk-test-key-12345678", customURL, "") + + if dsClient.BaseURL != customURL { + t.Errorf("BaseURL should be '%s', got '%s'", customURL, dsClient.BaseURL) + } + + // 验证日志记录 + logs := mockLogger.GetLogsByLevel("INFO") + hasCustomURLLog := false + for _, log := range logs { + if log.Format == "🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s" { + hasCustomURLLog = true + break + } + } + + if !hasCustomURLLog { + t.Error("should have logged custom BaseURL") + } +} + +func TestDeepSeekClient_SetAPIKey_WithCustomModel(t *testing.T) { + mockLogger := NewMockLogger() + client := NewDeepSeekClientWithOptions( + WithLogger(mockLogger), + ) + + dsClient := client.(*DeepSeekClient) + + customModel := "deepseek-v3" + dsClient.SetAPIKey("sk-test-key-12345678", "", customModel) + + if dsClient.Model != customModel { + t.Errorf("Model should be '%s', got '%s'", customModel, dsClient.Model) + } + + // 验证日志记录 + logs := mockLogger.GetLogsByLevel("INFO") + hasCustomModelLog := false + for _, log := range logs { + if log.Format == "🔧 [MCP] DeepSeek 使用自定义 Model: %s" { + hasCustomModelLog = true + break + } + } + + if !hasCustomModelLog { + t.Error("should have logged custom Model") + } +} + +// ============================================================ +// 测试集成功能 +// ============================================================ + +func TestDeepSeekClient_CallWithMessages_Success(t *testing.T) { + mockHTTP := NewMockHTTPClient() + mockHTTP.SetSuccessResponse("DeepSeek AI response") + mockLogger := NewMockLogger() + + client := NewDeepSeekClientWithOptions( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(mockLogger), + WithAPIKey("sk-test-key"), + ) + + result, err := client.CallWithMessages("system prompt", "user prompt") + + if err != nil { + t.Fatalf("should not error: %v", err) + } + + if result != "DeepSeek AI response" { + t.Errorf("expected 'DeepSeek AI response', got '%s'", result) + } + + // 验证请求 + requests := mockHTTP.GetRequests() + if len(requests) != 1 { + t.Fatalf("expected 1 request, got %d", len(requests)) + } + + req := requests[0] + + // 验证 URL + expectedURL := DefaultDeepSeekBaseURL + "/chat/completions" + if req.URL.String() != expectedURL { + t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String()) + } + + // 验证 Authorization header + authHeader := req.Header.Get("Authorization") + if authHeader != "Bearer sk-test-key" { + t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader) + } + + // 验证 Content-Type + if req.Header.Get("Content-Type") != "application/json" { + t.Error("Content-Type should be application/json") + } +} + +func TestDeepSeekClient_Timeout(t *testing.T) { + client := NewDeepSeekClientWithOptions( + WithTimeout(30 * time.Second), + ) + + dsClient := client.(*DeepSeekClient) + + if dsClient.httpClient.Timeout != 30*time.Second { + t.Errorf("expected timeout 30s, got %v", dsClient.httpClient.Timeout) + } + + // 测试 SetTimeout + client.SetTimeout(60 * time.Second) + + if dsClient.httpClient.Timeout != 60*time.Second { + t.Errorf("expected timeout 60s after SetTimeout, got %v", dsClient.httpClient.Timeout) + } +} + +// ============================================================ +// 测试 hooks 机制 +// ============================================================ + +func TestDeepSeekClient_HooksIntegration(t *testing.T) { + client := NewDeepSeekClientWithOptions() + dsClient := client.(*DeepSeekClient) + + // 验证 hooks 指向 dsClient 自己(实现多态) + if dsClient.hooks != dsClient { + t.Error("hooks should point to dsClient for polymorphism") + } + + // 验证 buildUrl 使用 DeepSeek 配置 + url := dsClient.buildUrl() + expectedURL := DefaultDeepSeekBaseURL + "/chat/completions" + if url != expectedURL { + t.Errorf("expected URL '%s', got '%s'", expectedURL, url) + } +} diff --git a/mcp/examples_test.go b/mcp/examples_test.go new file mode 100644 index 00000000..2aa20829 --- /dev/null +++ b/mcp/examples_test.go @@ -0,0 +1,296 @@ +package mcp_test + +import ( + "fmt" + "net/http" + "time" + + "nofx/mcp" +) + +// ============================================================ +// 示例 1: 基础用法(向前兼容) +// ============================================================ + +func Example_backward_compatible() { + // ✅ 旧代码继续工作,无需修改 + client := mcp.New() + client.SetAPIKey("sk-xxx", "https://api.custom.com", "gpt-4") + + // 使用 + result, _ := client.CallWithMessages("system prompt", "user prompt") + fmt.Println(result) +} + +func Example_deepseek_backward_compatible() { + // ✅ DeepSeek 旧代码继续工作 + client := mcp.NewDeepSeekClient() + client.SetAPIKey("sk-xxx", "", "") + + result, _ := client.CallWithMessages("system", "user") + fmt.Println(result) +} + +// ============================================================ +// 示例 2: 新的推荐用法(选项模式) +// ============================================================ + +func Example_new_client_basic() { + // 使用默认配置 + client := mcp.NewClient() + + // 使用 DeepSeek + client = mcp.NewClient( + mcp.WithDeepSeekConfig("sk-xxx"), + ) + + // 使用 Qwen + client = mcp.NewClient( + mcp.WithQwenConfig("sk-xxx"), + ) + + _ = client +} + +func Example_new_client_with_options() { + // 组合多个选项 + client := mcp.NewClient( + mcp.WithDeepSeekConfig("sk-xxx"), + mcp.WithTimeout(60*time.Second), + mcp.WithMaxRetries(5), + mcp.WithMaxTokens(4000), + mcp.WithTemperature(0.7), + ) + + result, _ := client.CallWithMessages("system", "user") + fmt.Println(result) +} + +// ============================================================ +// 示例 3: 自定义日志器 +// ============================================================ + +// CustomLogger 自定义日志器示例 +type CustomLogger struct{} + +func (l *CustomLogger) Debugf(format string, args ...any) { + fmt.Printf("[DEBUG] "+format+"\n", args...) +} + +func (l *CustomLogger) Infof(format string, args ...any) { + fmt.Printf("[INFO] "+format+"\n", args...) +} + +func (l *CustomLogger) Warnf(format string, args ...any) { + fmt.Printf("[WARN] "+format+"\n", args...) +} + +func (l *CustomLogger) Errorf(format string, args ...any) { + fmt.Printf("[ERROR] "+format+"\n", args...) +} + +func Example_custom_logger() { + // 使用自定义日志器 + customLogger := &CustomLogger{} + + client := mcp.NewClient( + mcp.WithDeepSeekConfig("sk-xxx"), + mcp.WithLogger(customLogger), + ) + + result, _ := client.CallWithMessages("system", "user") + fmt.Println(result) +} + +func Example_no_logger_for_testing() { + // 测试时禁用日志 + client := mcp.NewClient( + mcp.WithLogger(mcp.NewNoopLogger()), + ) + + result, _ := client.CallWithMessages("system", "user") + fmt.Println(result) +} + +// ============================================================ +// 示例 4: 自定义 HTTP 客户端 +// ============================================================ + +func Example_custom_http_client() { + // 自定义 HTTP 客户端(添加代理、TLS等) + customHTTP := &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + // 自定义 TLS、连接池等 + }, + } + + client := mcp.NewClient( + mcp.WithDeepSeekConfig("sk-xxx"), + mcp.WithHTTPClient(customHTTP), + ) + + result, _ := client.CallWithMessages("system", "user") + fmt.Println(result) +} + +// ============================================================ +// 示例 5: DeepSeek 客户端(新 API) +// ============================================================ + +func Example_deepseek_new_api() { + // 基础用法 + client := mcp.NewDeepSeekClientWithOptions( + mcp.WithAPIKey("sk-xxx"), + ) + + // 高级用法 + client = mcp.NewDeepSeekClientWithOptions( + mcp.WithAPIKey("sk-xxx"), + mcp.WithLogger(&CustomLogger{}), + mcp.WithTimeout(90*time.Second), + mcp.WithMaxTokens(8000), + ) + + result, _ := client.CallWithMessages("system", "user") + fmt.Println(result) +} + +// ============================================================ +// 示例 6: Qwen 客户端(新 API) +// ============================================================ + +func Example_qwen_new_api() { + // 基础用法 + client := mcp.NewQwenClientWithOptions( + mcp.WithAPIKey("sk-xxx"), + ) + + // 高级用法 + client = mcp.NewQwenClientWithOptions( + mcp.WithAPIKey("sk-xxx"), + mcp.WithLogger(&CustomLogger{}), + mcp.WithTimeout(90*time.Second), + ) + + result, _ := client.CallWithMessages("system", "user") + fmt.Println(result) +} + +// ============================================================ +// 示例 7: 在 trader/auto_trader.go 中的迁移示例 +// ============================================================ + +func Example_trader_migration() { + // === 旧代码(继续工作)=== + oldStyleClient := func(apiKey, customURL, customModel string) mcp.AIClient { + client := mcp.NewDeepSeekClient() + client.SetAPIKey(apiKey, customURL, customModel) + return client + } + + // === 新代码(推荐)=== + newStyleClient := func(apiKey, customURL, customModel string) mcp.AIClient { + opts := []mcp.ClientOption{ + mcp.WithAPIKey(apiKey), + } + + if customURL != "" { + opts = append(opts, mcp.WithBaseURL(customURL)) + } + + if customModel != "" { + opts = append(opts, mcp.WithModel(customModel)) + } + + return mcp.NewDeepSeekClientWithOptions(opts...) + } + + // 两种方式都能工作 + _ = oldStyleClient("sk-xxx", "", "") + _ = newStyleClient("sk-xxx", "", "") +} + +// ============================================================ +// 示例 8: 测试场景 +// ============================================================ + +// MockHTTPClient Mock HTTP 客户端 +type MockHTTPClient struct { + Response string +} + +func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) { + // 返回预设的响应 + return &http.Response{ + StatusCode: 200, + Body: nil, // 实际测试中需要实现 + }, nil +} + +func Example_testing_with_mock() { + // 测试时使用 Mock + // mockHTTP := &MockHTTPClient{ + // Response: `{"choices":[{"message":{"content":"test response"}}]}`, + // } + + client := mcp.NewClient( + // mcp.WithHTTPClient(mockHTTP), // 实际测试中使用 mockHTTP + mcp.WithLogger(mcp.NewNoopLogger()), // 禁用日志 + ) + + result, _ := client.CallWithMessages("system", "user") + fmt.Println(result) +} + +// ============================================================ +// 示例 9: 环境特定配置 +// ============================================================ + +func Example_environment_specific() { + // 开发环境:详细日志 + devClient := mcp.NewClient( + mcp.WithDeepSeekConfig("sk-xxx"), + mcp.WithLogger(&CustomLogger{}), // 详细日志 + ) + + // 生产环境:结构化日志 + 超时保护 + prodClient := mcp.NewClient( + mcp.WithDeepSeekConfig("sk-xxx"), + // mcp.WithLogger(&ZapLogger{}), // 生产级日志 + mcp.WithTimeout(30*time.Second), + mcp.WithMaxRetries(3), + ) + + _, _ = devClient.CallWithMessages("system", "user") + _, _ = prodClient.CallWithMessages("system", "user") +} + +// ============================================================ +// 示例 10: 完整实战示例 +// ============================================================ + +func Example_real_world_usage() { + // 创建带有完整配置的客户端 + client := mcp.NewDeepSeekClientWithOptions( + mcp.WithAPIKey("sk-xxxxxxxxxx"), + mcp.WithTimeout(60*time.Second), + mcp.WithMaxRetries(5), + mcp.WithMaxTokens(4000), + mcp.WithTemperature(0.5), + mcp.WithLogger(&CustomLogger{}), + ) + + // 使用客户端 + systemPrompt := "你是一个专业的量化交易顾问" + userPrompt := "分析 BTC 当前走势" + + result, err := client.CallWithMessages(systemPrompt, userPrompt) + if err != nil { + fmt.Printf("Error: %v\n", err) + return + } + + fmt.Printf("AI 响应: %s\n", result) +} diff --git a/mcp/interface.go b/mcp/interface.go index 8c9b9574..e155ac01 100644 --- a/mcp/interface.go +++ b/mcp/interface.go @@ -1,12 +1,30 @@ package mcp -import "net/http" +import ( + "net/http" + "time" +) -// AIClient AI客户端接口 +// AIClient AI客户端公开接口(给外部使用) type AIClient interface { SetAPIKey(apiKey string, customURL string, customModel string) - // CallWithMessages 使用 system + user prompt 调用AI API + SetTimeout(timeout time.Duration) CallWithMessages(systemPrompt, userPrompt string) (string, error) - - setAuthHeader(reqHeaders http.Header) + CallWithRequest(req *Request) (string, error) // 构建器模式 API(支持高级功能) +} + +// clientHooks 内部钩子接口(用于子类重写特定步骤) +// 这些方法只在包内部使用,实现动态分派 +type clientHooks interface { + // 可被子类重写的钩子方法 + + call(systemPrompt, userPrompt string) (string, error) + + buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any + buildUrl() string + buildRequest(url string, jsonData []byte) (*http.Request, error) + setAuthHeader(reqHeaders http.Header) + marshalRequestBody(requestBody map[string]any) ([]byte, error) + parseMCPResponse(body []byte) (string, error) + isRetryableError(err error) bool } diff --git a/mcp/intro/BUILDER_EXAMPLES.md b/mcp/intro/BUILDER_EXAMPLES.md new file mode 100644 index 00000000..8ec8af9e --- /dev/null +++ b/mcp/intro/BUILDER_EXAMPLES.md @@ -0,0 +1,572 @@ +# RequestBuilder 使用示例 + +## 📋 目录 +1. [基础用法](#基础用法) +2. [多轮对话](#多轮对话) +3. [参数精细控制](#参数精细控制) +4. [Function Calling](#function-calling) +5. [预设场景](#预设场景) +6. [完整示例](#完整示例) + +--- + +## 基础用法 + +### 简单对话 + +```go +package main + +import ( + "fmt" + "nofx/mcp" +) + +func main() { + // 创建客户端 + client := mcp.NewDeepSeekClientWithOptions( + mcp.WithAPIKey("sk-xxx"), + ) + + // 使用构建器创建请求 + request := mcp.NewRequestBuilder(). + WithSystemPrompt("You are a helpful assistant"). + WithUserPrompt("What is Go programming language?"). + Build() + + // 调用 API + result, err := client.CallWithRequest(request) + if err != nil { + panic(err) + } + + fmt.Println(result) +} +``` + +### 与传统方式对比 + +```go +// 传统方式(仍然可用) +result, err := client.CallWithMessages( + "You are a helpful assistant", + "What is Go?", +) + +// 构建器方式(新API,功能更强大) +request := mcp.NewRequestBuilder(). + WithSystemPrompt("You are a helpful assistant"). + WithUserPrompt("What is Go?"). + Build() +result, err := client.CallWithRequest(request) +``` + +--- + +## 多轮对话 + +### 带上下文的对话 + +```go +// 构建包含历史的多轮对话 +request := mcp.NewRequestBuilder(). + AddSystemMessage("You are a trading advisor"). + AddUserMessage("Analyze BTC price"). + AddAssistantMessage("BTC is currently in an upward trend..."). + AddUserMessage("What's the best entry point?"). // 继续对话 + WithTemperature(0.3). // 低温度,更精确 + Build() + +result, err := client.CallWithRequest(request) +``` + +### 从历史记录构建 + +```go +// 假设你有保存的对话历史 +history := []mcp.Message{ + mcp.NewUserMessage("Hello"), + mcp.NewAssistantMessage("Hi! How can I help?"), + mcp.NewUserMessage("What's the weather?"), + mcp.NewAssistantMessage("It's sunny today"), +} + +// 继续对话 +request := mcp.NewRequestBuilder(). + AddSystemMessage("You are helpful"). + AddConversationHistory(history). // 添加历史 + AddUserMessage("What about tomorrow?"). // 新问题 + Build() + +result, err := client.CallWithRequest(request) +``` + +--- + +## 参数精细控制 + +### 代码生成(低温度、精确) + +```go +request := mcp.NewRequestBuilder(). + WithSystemPrompt("You are a Go expert"). + WithUserPrompt("Generate a HTTP server"). + WithTemperature(0.2). // 低温度 = 更确定 + WithTopP(0.1). // 低 top_p = 更聚焦 + WithMaxTokens(2000). + AddStopSequence("```"). // 遇到代码块结束符停止 + Build() + +code, err := client.CallWithRequest(request) +``` + +### 创意写作(高温度、随机) + +```go +request := mcp.NewRequestBuilder(). + WithSystemPrompt("You are a creative writer"). + WithUserPrompt("Write a sci-fi story about AI"). + WithTemperature(1.2). // 高温度 = 更创意 + WithTopP(0.95). // 高 top_p = 更多样 + WithPresencePenalty(0.6). // 避免重复主题 + WithFrequencyPenalty(0.5). // 避免重复词汇 + WithMaxTokens(4000). + Build() + +story, err := client.CallWithRequest(request) +``` + +### 精确分析(平衡参数) + +```go +request := mcp.NewRequestBuilder(). + WithSystemPrompt("You are a quantitative analyst"). + WithUserPrompt("Analyze BTC/USDT chart pattern"). + WithTemperature(0.5). // 中等温度 + WithMaxTokens(1500). + WithStopSequences([]string{"---", "END"}). // 多个停止序列 + Build() + +analysis, err := client.CallWithRequest(request) +``` + +--- + +## Function Calling + +### 天气查询工具 + +```go +// 定义工具参数 schema(JSON Schema 格式) +weatherParams := map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{ + "type": "string", + "description": "City name, e.g., Beijing, Shanghai", + }, + "unit": map[string]any{ + "type": "string", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + "required": []string{"location"}, +} + +// 构建请求 +request := mcp.NewRequestBuilder(). + WithUserPrompt("北京今天天气怎么样?"). + AddFunction( + "get_weather", // 函数名 + "Get current weather", // 函数描述 + weatherParams, // 参数定义 + ). + WithToolChoice("auto"). // 让 AI 自动决定是否调用 + Build() + +response, err := client.CallWithRequest(request) + +// AI 可能返回 tool_calls,你需要执行函数并返回结果 +// (具体实现取决于 AI provider 的响应格式) +``` + +### 多个工具 + +```go +// 定义多个工具 +request := mcp.NewRequestBuilder(). + WithUserPrompt("帮我查询北京天气,并计算100的平方根"). + AddFunction("get_weather", "Get weather", weatherParams). + AddFunction("calculate", "Calculate math", calcParams). + AddFunction("search_web", "Search web", searchParams). + WithToolChoice("auto"). + Build() + +response, err := client.CallWithRequest(request) +// AI 会选择调用相应的工具 +``` + +### 强制使用特定工具 + +```go +request := mcp.NewRequestBuilder(). + WithUserPrompt("北京"). + AddFunction("get_weather", "Get weather", weatherParams). + WithToolChoice(`{"type": "function", "function": {"name": "get_weather"}}`). + Build() + +// AI 必须调用 get_weather 函数 +``` + +--- + +## 预设场景 + +### ForChat - 聊天场景 + +```go +// 预设参数:temperature=0.7, maxTokens=2000 +request := mcp.ForChat(). + WithSystemPrompt("You are a friendly chatbot"). + WithUserPrompt("Hello!"). + Build() + +// 等价于 +request := mcp.NewRequestBuilder(). + WithSystemPrompt("You are a friendly chatbot"). + WithUserPrompt("Hello!"). + WithTemperature(0.7). + WithMaxTokens(2000). + Build() +``` + +### ForCodeGeneration - 代码生成场景 + +```go +// 预设参数:temperature=0.2, topP=0.1, maxTokens=2000 +request := mcp.ForCodeGeneration(). + WithUserPrompt("Generate a REST API in Go"). + Build() + +// 自动使用低温度和低 top_p,确保代码准确性 +``` + +### ForCreativeWriting - 创意写作场景 + +```go +// 预设参数: +// temperature=1.2, topP=0.95, maxTokens=4000 +// presencePenalty=0.6, frequencyPenalty=0.5 +request := mcp.ForCreativeWriting(). + WithSystemPrompt("You are a novelist"). + WithUserPrompt("Write a fantasy story"). + Build() + +// 自动使用高温度和惩罚参数,增加创意和多样性 +``` + +--- + +## 完整示例 + +### 量化交易 AI 顾问 + +```go +package main + +import ( + "fmt" + "log" + "nofx/mcp" + "os" +) + +func main() { + // 创建客户端 + client := mcp.NewDeepSeekClientWithOptions( + mcp.WithAPIKey(os.Getenv("DEEPSEEK_API_KEY")), + mcp.WithMaxRetries(5), + mcp.WithTimeout(60 * time.Second), + ) + + // 场景1: 市场分析(需要精确) + analysisRequest := mcp.NewRequestBuilder(). + WithSystemPrompt("You are a professional quantitative trader"). + WithUserPrompt("Analyze BTC/USDT 1H chart, current price $45,000"). + WithTemperature(0.3). // 低温度,更精确 + WithMaxTokens(1500). + Build() + + analysis, err := client.CallWithRequest(analysisRequest) + if err != nil { + log.Fatal(err) + } + fmt.Println("=== Market Analysis ===") + fmt.Println(analysis) + + // 场景2: 继续对话,询问入场点 + followUpRequest := mcp.NewRequestBuilder(). + AddSystemMessage("You are a professional quantitative trader"). + AddUserMessage("Analyze BTC/USDT 1H chart, current price $45,000"). + AddAssistantMessage(analysis). // 添加之前的回复 + AddUserMessage("Based on your analysis, what's the best entry point?"). + WithTemperature(0.3). + Build() + + entryPoint, err := client.CallWithRequest(followUpRequest) + if err != nil { + log.Fatal(err) + } + fmt.Println("\n=== Entry Point Suggestion ===") + fmt.Println(entryPoint) +} +``` + +### 代码评审助手 + +```go +func reviewCode(client mcp.AIClient, code string) (string, error) { + request := mcp.ForCodeGeneration(). // 使用代码场景预设 + WithSystemPrompt("You are a senior Go developer reviewing code"). + WithUserPrompt(fmt.Sprintf("Review this code:\n\n```go\n%s\n```", code)). + WithMaxTokens(2000). + AddStopSequence("---END---"). + Build() + + return client.CallWithRequest(request) +} + +func main() { + client := mcp.NewDeepSeekClientWithOptions( + mcp.WithAPIKey(os.Getenv("DEEPSEEK_API_KEY")), + ) + + code := ` +func Add(a, b int) int { + return a + b +} +` + + review, err := reviewCode(client, code) + if err != nil { + log.Fatal(err) + } + fmt.Println(review) +} +``` + +### AI 聊天机器人(带历史记录) + +```go +type ChatBot struct { + client mcp.AIClient + history []mcp.Message +} + +func NewChatBot(client mcp.AIClient, systemPrompt string) *ChatBot { + return &ChatBot{ + client: client, + history: []mcp.Message{ + mcp.NewSystemMessage(systemPrompt), + }, + } +} + +func (bot *ChatBot) Chat(userMessage string) (string, error) { + // 添加用户消息到历史 + bot.history = append(bot.history, mcp.NewUserMessage(userMessage)) + + // 构建请求(包含完整历史) + request := mcp.ForChat(). + AddMessages(bot.history...). + Build() + + // 调用 API + response, err := bot.client.CallWithRequest(request) + if err != nil { + return "", err + } + + // 添加 AI 回复到历史 + bot.history = append(bot.history, mcp.NewAssistantMessage(response)) + + return response, nil +} + +func main() { + client := mcp.NewDeepSeekClientWithOptions( + mcp.WithAPIKey(os.Getenv("DEEPSEEK_API_KEY")), + ) + + bot := NewChatBot(client, "You are a friendly and helpful assistant") + + // 对话1 + resp1, _ := bot.Chat("What is Go?") + fmt.Println("User: What is Go?") + fmt.Println("Bot:", resp1) + + // 对话2(带上下文) + resp2, _ := bot.Chat("What are its main features?") + fmt.Println("\nUser: What are its main features?") + fmt.Println("Bot:", resp2) + + // 对话3(继续上下文) + resp3, _ := bot.Chat("Show me an example") + fmt.Println("\nUser: Show me an example") + fmt.Println("Bot:", resp3) +} +``` + +### Function Calling 完整示例 + +```go +package main + +import ( + "encoding/json" + "fmt" + "nofx/mcp" + "os" +) + +// 天气查询函数(模拟) +func getWeather(location string) string { + return fmt.Sprintf("Weather in %s: Sunny, 25°C", location) +} + +func main() { + client := mcp.NewDeepSeekClientWithOptions( + mcp.WithAPIKey(os.Getenv("DEEPSEEK_API_KEY")), + ) + + // 定义工具 + weatherParams := map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{ + "type": "string", + "description": "City name", + }, + }, + "required": []string{"location"}, + } + + // 第一步:发送带工具的请求 + request := mcp.NewRequestBuilder(). + WithUserPrompt("北京天气怎么样?"). + AddFunction("get_weather", "Get current weather", weatherParams). + WithToolChoice("auto"). + Build() + + response, err := client.CallWithRequest(request) + if err != nil { + panic(err) + } + + fmt.Println("AI Response:", response) + + // 第二步:如果 AI 返回了 tool_call(实际需要解析 JSON 响应) + // 这里是示例,实际需要根据 provider 的响应格式解析 + // toolCall := parseToolCall(response) + // weatherResult := getWeather(toolCall.Arguments.Location) + + // 第三步:将工具结果返回给 AI + // followUp := mcp.NewRequestBuilder(). + // AddConversationHistory(previousMessages). + // AddToolResult(toolCall.ID, weatherResult). + // Build() + // + // finalResponse, _ := client.CallWithRequest(followUp) +} +``` + +--- + +## 最佳实践 + +### 1. 使用 MustBuild() vs Build() + +```go +// Build() - 返回 error,需要处理 +request, err := NewRequestBuilder(). + WithUserPrompt("Hello"). + Build() +if err != nil { + log.Fatal(err) +} + +// MustBuild() - 如果失败会 panic,适用于确定不会错的场景 +request := NewRequestBuilder(). + WithSystemPrompt("You are helpful"). + WithUserPrompt("Hello"). + MustBuild() // 构建失败会 panic +``` + +### 2. 重用构建器 + +```go +// 创建基础构建器 +baseBuilder := mcp.NewRequestBuilder(). + WithSystemPrompt("You are a trading advisor"). + WithTemperature(0.3) + +// 为不同问题添加用户消息 +question1 := baseBuilder. + AddUserMessage("Analyze BTC"). + Build() + +question2 := baseBuilder. + ClearMessages(). // 清空之前的消息 + AddSystemMessage("You are a trading advisor"). + AddUserMessage("Analyze ETH"). + Build() +``` + +### 3. 选择合适的预设 + +```go +// ✅ 代码生成 - 使用 ForCodeGeneration +ForCodeGeneration().WithUserPrompt("Generate code") + +// ✅ 聊天 - 使用 ForChat +ForChat().WithUserPrompt("Hello") + +// ✅ 创意写作 - 使用 ForCreativeWriting +ForCreativeWriting().WithUserPrompt("Write a story") + +// ✅ 自定义 - 使用 NewRequestBuilder +NewRequestBuilder().WithTemperature(0.6).WithUserPrompt("...") +``` + +--- + +## 迁移指南 + +### 从旧 API 迁移 + +```go +// 旧 API(仍然可用) +result, err := client.CallWithMessages("system", "user") + +// 迁移到新 API +request := mcp.NewRequestBuilder(). + WithSystemPrompt("system"). + WithUserPrompt("user"). + Build() +result, err := client.CallWithRequest(request) + +// 如果需要更多控制 +request := mcp.NewRequestBuilder(). + WithSystemPrompt("system"). + WithUserPrompt("user"). + WithTemperature(0.8). // 新功能 + WithMaxTokens(2000). // 新功能 + Build() +result, err := client.CallWithRequest(request) +``` + +--- + +更多信息请参考: +- [构建器模式价值分析](./BUILDER_PATTERN_BENEFITS.md) +- [MCP 使用指南](./README.md) diff --git a/mcp/intro/BUILDER_PATTERN_BENEFITS.md b/mcp/intro/BUILDER_PATTERN_BENEFITS.md new file mode 100644 index 00000000..04ff2587 --- /dev/null +++ b/mcp/intro/BUILDER_PATTERN_BENEFITS.md @@ -0,0 +1,716 @@ +# 构建器模式在 MCP 模块中的应用价值 + +## 📋 目录 +1. [当前实现的局限性](#当前实现的局限性) +2. [构建器模式的好处](#构建器模式的好处) +3. [实际应用场景](#实际应用场景) +4. [对比示例](#对比示例) +5. [是否需要引入](#是否需要引入) + +--- + +## 当前实现的局限性 + +### 现状分析 + +**当前 buildMCPRequestBody 实现**: +```go +func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any { + messages := []map[string]string{} + + if systemPrompt != "" { + messages = append(messages, map[string]string{ + "role": "system", + "content": systemPrompt, + }) + } + messages = append(messages, map[string]string{ + "role": "user", + "content": userPrompt, + }) + + return map[string]interface{}{ + "model": client.Model, + "messages": messages, + "temperature": client.config.Temperature, + "max_tokens": client.MaxTokens, + } +} +``` + +### 存在的限制 + +1. **只支持简单对话** + - ❌ 无法添加多轮对话历史 + - ❌ 无法添加 assistant 回复 + - ❌ 无法构建复杂的对话上下文 + +2. **参数固定** + - ❌ 无法动态添加可选参数(如 top_p、frequency_penalty) + - ❌ 无法为单次请求自定义 temperature(会影响全局配置) + - ❌ 无法添加 function calling、tools 等高级功能 + +3. **扩展性差** + - ❌ 每次添加新参数都需要修改方法签名 + - ❌ 参数列表会越来越长 + - ❌ 子类重写时需要处理所有参数 + +--- + +## 构建器模式的好处 + +### 1. 🎯 **灵活性和可读性** + +#### 当前方式(参数传递) +```go +// 问题:参数多了会很混乱 +client.CallWithCustomParams( + "system prompt", + "user prompt", + 0.8, // temperature - 这是什么? + 2000, // max_tokens - 这是什么? + 0.9, // top_p - 这是什么? + 0.5, // frequency_penalty + nil, // stop sequences + false, // stream +) +``` + +#### 构建器方式 +```go +// 清晰、自解释 +request := NewRequestBuilder(). + WithSystemPrompt("You are a helpful assistant"). + WithUserPrompt("Tell me about Go"). + WithTemperature(0.8). + WithMaxTokens(2000). + WithTopP(0.9). + Build() + +result, err := client.CallWithRequest(request) +``` + +--- + +### 2. 📚 **支持复杂场景** + +#### 场景1: 多轮对话 + +**当前方式**: 😢 不支持 +```go +// ❌ 无法实现 +client.CallWithMessages("system", "user prompt") +``` + +**构建器方式**: ✅ 支持 +```go +request := NewRequestBuilder(). + AddSystemMessage("You are a helpful assistant"). + AddUserMessage("What is the weather?"). + AddAssistantMessage("It's sunny today"). + AddUserMessage("What about tomorrow?"). // 继续对话 + WithTemperature(0.7). + Build() +``` + +#### 场景2: 函数调用(Function Calling) + +**当前方式**: 😢 不支持 +```go +// ❌ 无法添加 tools/functions +``` + +**构建器方式**: ✅ 支持 +```go +request := NewRequestBuilder(). + WithUserPrompt("What's the weather in Beijing?"). + AddTool(Tool{ + Type: "function", + Function: FunctionDef{ + Name: "get_weather", + Description: "Get current weather", + Parameters: weatherParamsSchema, + }, + }). + WithToolChoice("auto"). + Build() +``` + +#### 场景3: 流式响应 + +**当前方式**: 😢 需要修改整个架构 +```go +// ❌ CallWithMessages 不支持流式 +``` + +**构建器方式**: ✅ 易于扩展 +```go +request := NewRequestBuilder(). + WithUserPrompt("Write a long story"). + WithStream(true). + Build() + +stream, err := client.CallStream(request) +for chunk := range stream { + fmt.Print(chunk) +} +``` + +--- + +### 3. 🔧 **易于扩展和维护** + +#### 添加新参数 + +**当前方式**: 😢 破坏性修改 +```go +// 需要修改方法签名(破坏现有代码) +func (client *Client) buildMCPRequestBody( + systemPrompt, userPrompt string, + // 新增参数会导致所有调用处都要修改 + topP float64, + presencePenalty float64, +) map[string]any +``` + +**构建器方式**: ✅ 向后兼容 +```go +// 只需添加新方法,不影响现有代码 +func (b *RequestBuilder) WithPresencePenalty(p float64) *RequestBuilder { + b.presencePenalty = p + return b +} + +// 旧代码不受影响 +request := builder.WithUserPrompt("Hello").Build() + +// 新代码可以使用新功能 +request := builder. + WithUserPrompt("Hello"). + WithPresencePenalty(0.6). // 新参数 + Build() +``` + +--- + +### 4. 🎨 **可选参数处理** + +**当前方式**: 😢 难以处理可选参数 +```go +// 方案1: 传 nil/0 值(不优雅) +client.CallWithParams(system, user, 0, 0, nil, nil) + +// 方案2: 使用选项模式(但每次调用都要传) +client.CallWithParams(system, user, WithTopP(0.9), WithPenalty(0.5)) + +// 方案3: 配置对象(需要创建临时对象) +config := &RequestConfig{ + SystemPrompt: system, + UserPrompt: user, + TopP: 0.9, +} +``` + +**构建器方式**: ✅ 优雅处理 +```go +// 只设置需要的参数,其他使用默认值 +request := NewRequestBuilder(). + WithUserPrompt("Hello"). + // 不设置 temperature,使用默认值 + // 不设置 topP,使用默认值 + Build() + +// 也可以全部自定义 +request := NewRequestBuilder(). + WithUserPrompt("Hello"). + WithTemperature(0.8). + WithTopP(0.9). + WithMaxTokens(2000). + Build() +``` + +--- + +### 5. ✅ **类型安全和验证** + +**当前方式**: 😢 运行时才发现错误 +```go +// ❌ 编译时无法发现问题 +client.CallWithMessages("", "") // 空 prompt +client.CallWithMessages("system", "user") // temperature 可能不合法 +``` + +**构建器方式**: ✅ 提前验证 +```go +type RequestBuilder struct { + messages []Message + temperature float64 + maxTokens int +} + +func (b *RequestBuilder) WithTemperature(t float64) *RequestBuilder { + if t < 0 || t > 2 { + panic("temperature must be between 0 and 2") // 或返回 error + } + b.temperature = t + return b +} + +func (b *RequestBuilder) Build() (*Request, error) { + if len(b.messages) == 0 { + return nil, errors.New("at least one message is required") + } + if b.maxTokens <= 0 { + return nil, errors.New("maxTokens must be positive") + } + return &Request{...}, nil +} +``` + +--- + +## 实际应用场景 + +### 场景1: 量化交易 AI 顾问(多轮对话) + +```go +// 构建包含市场数据的上下文对话 +request := NewRequestBuilder(). + AddSystemMessage("You are a quantitative trading advisor"). + AddUserMessage("Analyze BTC trend"). + AddAssistantMessage("BTC is in an upward trend based on..."). + AddUserMessage("What about entry points?"). // 继续对话 + WithTemperature(0.3). // 低温度,更精确 + WithMaxTokens(1000). + Build() + +analysis, err := client.CallWithRequest(request) +``` + +### 场景2: 代码生成(需要精确控制) + +```go +request := NewRequestBuilder(). + WithSystemPrompt("You are a Go expert"). + WithUserPrompt("Generate a HTTP server"). + WithTemperature(0.2). // 低温度,更确定性 + WithTopP(0.1). // 低 top_p,更聚焦 + WithMaxTokens(2000). + WithStopSequences([]string{"```"}). // 遇到代码块结束符停止 + Build() +``` + +### 场景3: 创意写作(需要随机性) + +```go +request := NewRequestBuilder(). + WithSystemPrompt("You are a creative writer"). + WithUserPrompt("Write a sci-fi story"). + WithTemperature(1.2). // 高温度,更创意 + WithTopP(0.95). // 高 top_p,更多样性 + WithPresencePenalty(0.6). // 避免重复 + WithFrequencyPenalty(0.5). + WithMaxTokens(4000). + Build() +``` + +### 场景4: 函数调用(工具使用) + +```go +// 定义工具 +weatherTool := Tool{ + Type: "function", + Function: FunctionDef{ + Name: "get_weather", + Description: "Get current weather for a location", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{ + "type": "string", + "description": "City name", + }, + }, + "required": []string{"location"}, + }, + }, +} + +request := NewRequestBuilder(). + WithUserPrompt("What's the weather in Beijing?"). + AddTool(weatherTool). + WithToolChoice("auto"). + Build() + +response, err := client.CallWithRequest(request) +// 解析 response.ToolCalls 并执行实际的天气查询 +``` + +--- + +## 对比示例 + +### 示例1: 基础用法 + +#### 当前实现 +```go +result, err := client.CallWithMessages( + "You are a helpful assistant", + "What is Go?", +) +``` + +#### 构建器模式 +```go +request := NewRequestBuilder(). + WithSystemPrompt("You are a helpful assistant"). + WithUserPrompt("What is Go?"). + Build() + +result, err := client.CallWithRequest(request) +``` + +**分析**: 基础用法下,构建器稍显冗长,但更清晰。 + +--- + +### 示例2: 复杂用法 + +#### 当前实现(假设扩展后) +```go +// 😢 参数太多,难以理解 +result, err := client.CallWithMessagesAdvanced( + "system prompt", + "user prompt", + nil, // messages history? + 0.8, // temperature + 2000, // max_tokens + 0.9, // top_p + 0.5, // frequency_penalty + 0.6, // presence_penalty + nil, // stop sequences + false, // stream + nil, // tools + "", // tool_choice +) +``` + +#### 构建器模式 +```go +// ✅ 清晰、自解释 +request := NewRequestBuilder(). + WithSystemPrompt("system prompt"). + WithUserPrompt("user prompt"). + WithTemperature(0.8). + WithMaxTokens(2000). + WithTopP(0.9). + WithFrequencyPenalty(0.5). + WithPresencePenalty(0.6). + Build() + +result, err := client.CallWithRequest(request) +``` + +**分析**: 复杂场景下,构建器模式优势明显。 + +--- + +## 是否需要引入? + +### ✅ 建议引入的情况 + +1. **需要支持多轮对话** + - 聊天机器人 + - 上下文相关的 AI 助手 + +2. **需要精细控制 AI 参数** + - 不同任务需要不同 temperature + - 需要使用 top_p、penalty 等高级参数 + +3. **需要使用 AI 高级功能** + - Function Calling / Tools + - 流式响应 + - Vision API(图片输入) + +4. **API 接口可能频繁变化** + - AI 提供商经常添加新参数 + - 需要向后兼容 + +### ⚠️ 可以暂缓的情况 + +1. **只有简单的单轮对话** + - 当前 `CallWithMessages` 已足够 + +2. **参数固定不变** + - 所有请求使用相同配置 + +3. **团队规模小,代码量少** + - 引入新模式的学习成本 > 收益 + +--- + +## 推荐方案 + +### 方案1: 渐进式引入(推荐) + +**第一阶段**: 保留现有 API,新增构建器 +```go +// 旧 API 继续工作(向后兼容) +result, err := client.CallWithMessages("system", "user") + +// 新 API 提供高级功能 +request := NewRequestBuilder(). + WithUserPrompt("user"). + WithTemperature(0.8). + Build() +result, err := client.CallWithRequest(request) +``` + +**第二阶段**: 逐步迁移 +```go +// 在文档中推荐使用构建器 +// 旧 API 标记为 Deprecated(但不删除) +``` + +### 方案2: 仅用于高级场景 + +只在需要复杂功能时使用构建器: +```go +// 简单场景:使用现有 API +client.CallWithMessages("system", "user") + +// 复杂场景:使用构建器 +client.CallWithRequest( + NewRequestBuilder(). + AddConversationHistory(history). + AddUserMessage("new question"). + WithTools(tools). + Build(), +) +``` + +--- + +## 实现示例 + +### 完整的构建器实现 + +```go +package mcp + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type Tool struct { + Type string `json:"type"` + Function FunctionDef `json:"function"` +} + +type Request struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + TopP float64 `json:"top_p,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + Stop []string `json:"stop,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice string `json:"tool_choice,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type RequestBuilder struct { + model string + messages []Message + temperature *float64 + maxTokens *int + topP *float64 + frequencyPenalty *float64 + presencePenalty *float64 + stop []string + tools []Tool + toolChoice string + stream bool +} + +func NewRequestBuilder() *RequestBuilder { + return &RequestBuilder{ + messages: make([]Message, 0), + } +} + +func (b *RequestBuilder) WithModel(model string) *RequestBuilder { + b.model = model + return b +} + +func (b *RequestBuilder) WithSystemPrompt(prompt string) *RequestBuilder { + if prompt != "" { + b.messages = append(b.messages, Message{ + Role: "system", + Content: prompt, + }) + } + return b +} + +func (b *RequestBuilder) WithUserPrompt(prompt string) *RequestBuilder { + b.messages = append(b.messages, Message{ + Role: "user", + Content: prompt, + }) + return b +} + +func (b *RequestBuilder) AddUserMessage(content string) *RequestBuilder { + return b.WithUserPrompt(content) +} + +func (b *RequestBuilder) AddSystemMessage(content string) *RequestBuilder { + return b.WithSystemPrompt(content) +} + +func (b *RequestBuilder) AddAssistantMessage(content string) *RequestBuilder { + b.messages = append(b.messages, Message{ + Role: "assistant", + Content: content, + }) + return b +} + +func (b *RequestBuilder) AddMessage(role, content string) *RequestBuilder { + b.messages = append(b.messages, Message{ + Role: role, + Content: content, + }) + return b +} + +func (b *RequestBuilder) AddConversationHistory(history []Message) *RequestBuilder { + b.messages = append(b.messages, history...) + return b +} + +func (b *RequestBuilder) WithTemperature(t float64) *RequestBuilder { + if t < 0 || t > 2 { + panic("temperature must be between 0 and 2") + } + b.temperature = &t + return b +} + +func (b *RequestBuilder) WithMaxTokens(tokens int) *RequestBuilder { + b.maxTokens = &tokens + return b +} + +func (b *RequestBuilder) WithTopP(p float64) *RequestBuilder { + b.topP = &p + return b +} + +func (b *RequestBuilder) WithFrequencyPenalty(p float64) *RequestBuilder { + b.frequencyPenalty = &p + return b +} + +func (b *RequestBuilder) WithPresencePenalty(p float64) *RequestBuilder { + b.presencePenalty = &p + return b +} + +func (b *RequestBuilder) WithStopSequences(sequences []string) *RequestBuilder { + b.stop = sequences + return b +} + +func (b *RequestBuilder) AddTool(tool Tool) *RequestBuilder { + b.tools = append(b.tools, tool) + return b +} + +func (b *RequestBuilder) WithToolChoice(choice string) *RequestBuilder { + b.toolChoice = choice + return b +} + +func (b *RequestBuilder) WithStream(stream bool) *RequestBuilder { + b.stream = stream + return b +} + +func (b *RequestBuilder) Build() (*Request, error) { + if len(b.messages) == 0 { + return nil, errors.New("at least one message is required") + } + + req := &Request{ + Model: b.model, + Messages: b.messages, + Stop: b.stop, + Tools: b.tools, + ToolChoice: b.toolChoice, + Stream: b.stream, + } + + // 只设置非 nil 的可选参数 + if b.temperature != nil { + req.Temperature = *b.temperature + } + if b.maxTokens != nil { + req.MaxTokens = *b.maxTokens + } + if b.topP != nil { + req.TopP = *b.topP + } + if b.frequencyPenalty != nil { + req.FrequencyPenalty = *b.frequencyPenalty + } + if b.presencePenalty != nil { + req.PresencePenalty = *b.presencePenalty + } + + return req, nil +} +``` + +### Client 集成 + +```go +// 新增方法(不影响现有代码) +func (client *Client) CallWithRequest(req *Request) (string, error) { + // 使用 req 中的参数发送请求 + // ... +} +``` + +--- + +## 总结 + +### 核心优势 +1. ✅ **灵活性** - 轻松支持复杂场景 +2. ✅ **可读性** - 代码自解释,易于理解 +3. ✅ **可扩展性** - 添加新功能不破坏现有代码 +4. ✅ **类型安全** - 编译时检查,提前发现错误 +5. ✅ **向后兼容** - 可以与现有 API 共存 + +### 建议 +- **当前阶段**: 如果只需要简单对话,现有实现已足够 +- **未来扩展**: 当需要以下功能时再引入 + - 多轮对话 + - Function Calling + - 流式响应 + - 精细参数控制 + +### 最佳实践 +采用**渐进式引入**策略: +1. 保留现有 `CallWithMessages` API +2. 新增 `CallWithRequest` + 构建器 +3. 在文档中推荐新 API,但不强制迁移 +4. 根据实际需求逐步完善构建器功能 + +这样既能保持向后兼容,又能为未来的功能扩展做好准备。 diff --git a/mcp/intro/LOGRUS_INTEGRATION.md b/mcp/intro/LOGRUS_INTEGRATION.md new file mode 100644 index 00000000..13630e02 --- /dev/null +++ b/mcp/intro/LOGRUS_INTEGRATION.md @@ -0,0 +1,268 @@ +# Logrus 集成指南 + +本文档展示如何将 MCP 模块与 Logrus 日志库集成。 + +## 📦 安装 Logrus + +```bash +go get github.com/sirupsen/logrus +``` + +## 🔧 集成步骤 + +### 1. 创建 Logrus 适配器 + +创建一个实现 `mcp.Logger` 接口的适配器: + +```go +package main + +import ( + "github.com/sirupsen/logrus" + "nofx/mcp" +) + +// LogrusLogger Logrus 日志适配器 +type LogrusLogger struct { + logger *logrus.Logger +} + +// NewLogrusLogger 创建 Logrus 日志适配器 +func NewLogrusLogger(logger *logrus.Logger) *LogrusLogger { + return &LogrusLogger{logger: logger} +} + +// Debugf 实现 Debug 日志 +func (l *LogrusLogger) Debugf(format string, args ...any) { + l.logger.Debugf(format, args...) +} + +// Infof 实现 Info 日志 +func (l *LogrusLogger) Infof(format string, args ...any) { + l.logger.Infof(format, args...) +} + +// Warnf 实现 Warn 日志 +func (l *LogrusLogger) Warnf(format string, args ...any) { + l.logger.Warnf(format, args...) +} + +// Errorf 实现 Error 日志 +func (l *LogrusLogger) Errorf(format string, args ...any) { + l.logger.Errorf(format, args...) +} +``` + +### 2. 使用 Logrus Logger + +```go +package main + +import ( + "github.com/sirupsen/logrus" + "nofx/mcp" +) + +func main() { + // 1. 创建 Logrus logger + logger := logrus.New() + + // 2. 配置 Logrus + logger.SetLevel(logrus.DebugLevel) + logger.SetFormatter(&logrus.JSONFormatter{}) + + // 3. 创建适配器 + logrusAdapter := NewLogrusLogger(logger) + + // 4. 使用 MCP 客户端 + client := mcp.NewClient( + mcp.WithDeepSeekConfig("sk-xxx"), + mcp.WithLogger(logrusAdapter), // 注入 Logrus 日志器 + ) + + // 5. 调用 AI + result, err := client.CallWithMessages("system", "user") + if err != nil { + logger.Errorf("AI 调用失败: %v", err) + return + } + + logger.Infof("AI 响应: %s", result) +} +``` + +## 🎨 高级配置 + +### JSON 格式输出 + +```go +logger := logrus.New() +logger.SetFormatter(&logrus.JSONFormatter{ + TimestampFormat: "2006-01-02 15:04:05", + PrettyPrint: true, +}) +``` + +输出示例: +```json +{ + "level": "info", + "msg": "📡 [Provider: deepseek, Model: deepseek-chat] Request AI Server: BaseURL: https://api.deepseek.com/v1", + "time": "2024-01-15 10:30:45" +} +``` + +### 添加固定字段 + +```go +logger := logrus.New() +logger.WithFields(logrus.Fields{ + "service": "trading-bot", + "version": "1.0.0", +}) +``` + +### 不同环境配置 + +```go +func createLogger(env string) *logrus.Logger { + logger := logrus.New() + + switch env { + case "production": + // 生产环境:JSON 格式,只记录 Info 以上 + logger.SetLevel(logrus.InfoLevel) + logger.SetFormatter(&logrus.JSONFormatter{}) + + case "development": + // 开发环境:文本格式,记录所有级别 + logger.SetLevel(logrus.DebugLevel) + logger.SetFormatter(&logrus.TextFormatter{ + FullTimestamp: true, + }) + + case "test": + // 测试环境:静默模式 + logger.SetLevel(logrus.FatalLevel) + } + + return logger +} + +// 使用 +logger := createLogger("production") +mcpClient := mcp.NewClient( + mcp.WithLogger(NewLogrusLogger(logger)), +) +``` + +## 📝 完整示例 + +```go +package main + +import ( + "os" + + "github.com/sirupsen/logrus" + "nofx/mcp" +) + +// LogrusLogger Logrus 适配器 +type LogrusLogger struct { + logger *logrus.Logger +} + +func NewLogrusLogger(logger *logrus.Logger) *LogrusLogger { + return &LogrusLogger{logger: logger} +} + +func (l *LogrusLogger) Debugf(format string, args ...any) { + l.logger.Debugf(format, args...) +} + +func (l *LogrusLogger) Infof(format string, args ...any) { + l.logger.Infof(format, args...) +} + +func (l *LogrusLogger) Warnf(format string, args ...any) { + l.logger.Warnf(format, args...) +} + +func (l *LogrusLogger) Errorf(format string, args ...any) { + l.logger.Errorf(format, args...) +} + +func main() { + // 创建 Logrus logger + logger := logrus.New() + logger.SetLevel(logrus.DebugLevel) + logger.SetFormatter(&logrus.TextFormatter{ + FullTimestamp: true, + ForceColors: true, + }) + logger.SetOutput(os.Stdout) + + // 创建 MCP 客户端 + client := mcp.NewDeepSeekClientWithOptions( + mcp.WithAPIKey(os.Getenv("DEEPSEEK_API_KEY")), + mcp.WithLogger(NewLogrusLogger(logger)), + mcp.WithMaxRetries(5), + ) + + // 调用 AI + logger.Info("开始调用 AI...") + result, err := client.CallWithMessages( + "你是一个专业的量化交易顾问", + "分析 BTC 当前走势", + ) + + if err != nil { + logger.WithError(err).Error("AI 调用失败") + return + } + + logger.WithField("result", result).Info("AI 调用成功") +} +``` + +## 🔍 输出示例 + +### 开发环境(Text 格式) + +``` +INFO[2024-01-15 10:30:45] 开始调用 AI... +INFO[2024-01-15 10:30:45] 📡 [Provider: deepseek, Model: deepseek-chat] Request AI Server: BaseURL: https://api.deepseek.com/v1 +DEBUG[2024-01-15 10:30:45] [Provider: deepseek, Model: deepseek-chat] UseFullURL: false +DEBUG[2024-01-15 10:30:45] [Provider: deepseek, Model: deepseek-chat] API Key: sk-x...xxx +INFO[2024-01-15 10:30:45] 📡 [MCP Provider: deepseek, Model: deepseek-chat] 请求 URL: https://api.deepseek.com/v1/chat/completions +INFO[2024-01-15 10:30:46] AI 调用成功 result="[AI 响应内容]" +``` + +### 生产环境(JSON 格式) + +```json +{"level":"info","msg":"开始调用 AI...","time":"2024-01-15T10:30:45+08:00"} +{"level":"info","msg":"📡 [Provider: deepseek, Model: deepseek-chat] Request AI Server: BaseURL: https://api.deepseek.com/v1","time":"2024-01-15T10:30:45+08:00"} +{"level":"info","msg":"AI 调用成功","result":"[AI 响应内容]","time":"2024-01-15T10:30:46+08:00"} +``` + +## 🎯 最佳实践 + +1. **生产环境使用 JSON 格式**,便于日志收集和分析 +2. **开发环境使用 Text 格式**,便于阅读 +3. **测试环境关闭日志**,提高测试速度 +4. **添加请求 ID**,方便追踪请求链路 +5. **记录错误堆栈**,便于问题排查 + +## 📊 性能优化 + +Logrus 在高并发场景下可能有性能瓶颈,推荐使用 [Zap](https://github.com/uber-go/zap) 获得更好的性能。 + +MCP 模块也支持 Zap,集成方式类似。 + +## 🔗 相关资源 + +- [Logrus 官方文档](https://github.com/sirupsen/logrus) +- [Zap 集成示例](./ZAP_INTEGRATION.md) +- [MCP README](./README.md) diff --git a/mcp/intro/MIGRATION_GUIDE.md b/mcp/intro/MIGRATION_GUIDE.md new file mode 100644 index 00000000..fa6655ec --- /dev/null +++ b/mcp/intro/MIGRATION_GUIDE.md @@ -0,0 +1,361 @@ +# MCP 模块重构迁移指南 + +## 📋 重构概览 + +本次重构采用**渐进式、向前兼容**的设计,现有代码**无需修改**即可继续使用,同时提供了更强大的新 API。 + +### 重构目标 + +- ✅ **100% 向前兼容** - 所有现有 API 继续工作 +- ✅ **模块独立** - 可作为独立 Go module 发布 +- ✅ **依赖可替换** - 日志、HTTP 客户端都可自定义 +- ✅ **易于测试** - 支持依赖注入和 mock +- ✅ **配置灵活** - 支持选项模式 (Functional Options) + +--- + +## 🔄 向前兼容保证 + +### ✅ 所有现有代码继续工作 + +```go +// ✅ 这些代码无需修改,继续正常工作 +mcpClient := mcp.New() +mcpClient.SetAPIKey(apiKey, url, model) + +// ✅ 这些也继续工作 +dsClient := mcp.NewDeepSeekClient() +qwenClient := mcp.NewQwenClient() +``` + +**重要**:虽然标记为 `Deprecated`,但这些函数会一直保留,不会被删除。 + +--- + +## 🆕 新特性使用指南 + +### 1. 基础用法(推荐) + +```go +// 新的推荐用法 +client := mcp.NewClient( + mcp.WithDeepSeekConfig("sk-xxx"), + mcp.WithTimeout(60 * time.Second), +) +``` + +### 2. 自定义日志 + +```go +// 使用自定义日志器(如 zap, logrus) +type MyLogger struct { + zapLogger *zap.Logger +} + +func (l *MyLogger) Info(msg string, args ...any) { + l.zapLogger.Sugar().Infof(msg, args...) +} + +// 注入自定义日志器 +client := mcp.NewClient( + mcp.WithLogger(&MyLogger{zapLogger}), +) +``` + +### 3. 自定义 HTTP 客户端 + +```go +// 添加代理、追踪、自定义 TLS 等 +customHTTP := &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: &tls.Config{/* ... */}, + }, +} + +client := mcp.NewClient( + mcp.WithHTTPClient(customHTTP), +) +``` + +### 4. 测试场景 + +```go +func TestMyCode(t *testing.T) { + // Mock HTTP 客户端 + mockHTTP := &MockHTTPClient{ + // 返回预设的响应 + } + + // 禁用日志 + client := mcp.NewClient( + mcp.WithHTTPClient(mockHTTP), + mcp.WithLogger(mcp.NewNoopLogger()), + ) + + // 测试... +} +``` + +### 5. 组合多个选项 + +```go +client := mcp.NewDeepSeekClientWithOptions( + mcp.WithAPIKey("sk-xxx"), + mcp.WithLogger(customLogger), + mcp.WithTimeout(60 * time.Second), + mcp.WithMaxRetries(5), + mcp.WithMaxTokens(4000), +) +``` + +--- + +## 📊 API 对比表 + +### 构造函数对比 + +| 旧 API (仍可用) | 新 API (推荐) | 说明 | +|----------------|--------------|------| +| `mcp.New()` | `mcp.NewClient(opts...)` | 支持选项模式 | +| `mcp.NewDeepSeekClient()` | `mcp.NewDeepSeekClientWithOptions(opts...)` | 支持自定义配置 | +| `mcp.NewQwenClient()` | `mcp.NewQwenClientWithOptions(opts...)` | 支持自定义配置 | + +### 配置选项 + +| 选项函数 | 说明 | 使用示例 | +|---------|------|---------| +| `WithLogger(logger)` | 自定义日志器 | `WithLogger(zapLogger)` | +| `WithHTTPClient(client)` | 自定义 HTTP 客户端 | `WithHTTPClient(customHTTP)` | +| `WithTimeout(duration)` | 设置超时 | `WithTimeout(60*time.Second)` | +| `WithMaxRetries(n)` | 设置重试次数 | `WithMaxRetries(5)` | +| `WithMaxTokens(n)` | 设置最大 token | `WithMaxTokens(4000)` | +| `WithTemperature(t)` | 设置温度参数 | `WithTemperature(0.7)` | +| `WithAPIKey(key)` | 设置 API Key | `WithAPIKey("sk-xxx")` | +| `WithDeepSeekConfig(key)` | 快速配置 DeepSeek | `WithDeepSeekConfig("sk-xxx")` | +| `WithQwenConfig(key)` | 快速配置 Qwen | `WithQwenConfig("sk-xxx")` | + +--- + +## 🔧 迁移步骤 + +### Phase 1: 继续使用现有代码(无需改动) + +```go +// trader/auto_trader.go 中的现有代码 +mcpClient := mcp.New() + +if config.AIModel == "qwen" { + mcpClient = mcp.NewQwenClient() + mcpClient.SetAPIKey(config.QwenKey, config.CustomAPIURL, config.CustomModelName) +} else { + mcpClient = mcp.NewDeepSeekClient() + mcpClient.SetAPIKey(config.DeepSeekKey, config.CustomAPIURL, config.CustomModelName) +} + +// ✅ 继续工作,无需修改 +``` + +### Phase 2: 可选升级到新 API(推荐) + +```go +// 升级后的代码(可选) +var mcpClient mcp.AIClient + +if config.AIModel == "qwen" { + mcpClient = mcp.NewQwenClientWithOptions( + mcp.WithAPIKey(config.QwenKey), + mcp.WithBaseURL(config.CustomAPIURL), + mcp.WithModel(config.CustomModelName), + ) +} else { + mcpClient = mcp.NewDeepSeekClientWithOptions( + mcp.WithAPIKey(config.DeepSeekKey), + mcp.WithBaseURL(config.CustomAPIURL), + mcp.WithModel(config.CustomModelName), + ) +} +``` + +### Phase 3: 添加自定义配置(高级) + +```go +// 添加自定义日志 +customLogger := &MyZapLogger{zap.NewProduction()} + +mcpClient := mcp.NewDeepSeekClientWithOptions( + mcp.WithAPIKey(config.DeepSeekKey), + mcp.WithLogger(customLogger), // 自定义日志 + mcp.WithTimeout(90 * time.Second), // 自定义超时 + mcp.WithMaxRetries(5), // 自定义重试次数 +) +``` + +--- + +## 🎯 实际使用场景 + +### 场景 1: 开发环境详细日志 + +```go +// 开发环境:使用详细日志 +devClient := mcp.NewClient( + mcp.WithDeepSeekConfig(apiKey), + mcp.WithLogger(&defaultLogger{}), // 详细日志 +) +``` + +### 场景 2: 生产环境结构化日志 + +```go +// 生产环境:使用 zap 结构化日志 +zapLogger, _ := zap.NewProduction() +prodClient := mcp.NewClient( + mcp.WithDeepSeekConfig(apiKey), + mcp.WithLogger(&ZapLogger{zapLogger}), +) +``` + +### 场景 3: 测试环境 Mock + +```go +// 测试环境:Mock HTTP 响应 +mockHTTP := &MockHTTPClient{ + Response: `{"choices":[{"message":{"content":"test"}}]}`, +} + +testClient := mcp.NewClient( + mcp.WithHTTPClient(mockHTTP), + mcp.WithLogger(mcp.NewNoopLogger()), // 禁用日志 +) +``` + +### 场景 4: 需要代理的网络环境 + +```go +// 使用代理 +proxyURL, _ := url.Parse("http://proxy.company.com:8080") +proxyClient := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, +} + +client := mcp.NewClient( + mcp.WithDeepSeekConfig(apiKey), + mcp.WithHTTPClient(proxyClient), +) +``` + +--- + +## 📦 作为独立模块发布 + +重构后,mcp 模块可以独立发布: + +### go.mod +```go +module github.com/yourorg/mcp + +go 1.21 + +// 无外部依赖! +``` + +### 使用方 +```go +import "github.com/yourorg/mcp" + +client := mcp.NewClient( + mcp.WithDeepSeekConfig("sk-xxx"), +) +``` + +--- + +## 🧪 测试支持 + +### Mock 示例 + +```go +package mypackage_test + +import ( + "testing" + "github.com/stretchr/testify/assert" + "nofx/mcp" +) + +type MockHTTPClient struct { + Response string + Error error +} + +func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) { + if m.Error != nil { + return nil, m.Error + } + + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(m.Response)), + }, nil +} + +func TestAIIntegration(t *testing.T) { + // Arrange + mockHTTP := &MockHTTPClient{ + Response: `{"choices":[{"message":{"content":"success"}}]}`, + } + + client := mcp.NewClient( + mcp.WithHTTPClient(mockHTTP), + mcp.WithLogger(mcp.NewNoopLogger()), + ) + + // Act + result, err := client.CallWithMessages("system", "user") + + // Assert + assert.NoError(t, err) + assert.Equal(t, "success", result) +} +``` + +--- + +## ⚠️ 注意事项 + +1. **向前兼容性** + - 所有 `Deprecated` 的 API 会永久保留 + - 现有代码可以继续使用,不会被破坏 + +2. **渐进式迁移** + - 不需要一次性迁移所有代码 + - 可以逐步采用新 API + +3. **配置优先级** + - 用户传入的选项优先级最高 + - 环境变量次之 + - 默认配置最低 + +4. **日志器接口** + - 可以适配任何日志库(zap, logrus, etc.) + - 测试时可以使用 `NewNoopLogger()` 禁用日志 + +--- + +## 📚 进一步阅读 + +- [选项模式详解](https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis) +- [依赖注入最佳实践](https://go.dev/blog/wire) +- [Go 接口设计原则](https://go.dev/blog/laws-of-reflection) + +--- + +## 🤝 贡献 + +欢迎提交 issue 和 PR! + +如有问题,请联系:[your-email@example.com] diff --git a/mcp/intro/README.md b/mcp/intro/README.md new file mode 100644 index 00000000..509b4c9c --- /dev/null +++ b/mcp/intro/README.md @@ -0,0 +1,379 @@ +# MCP - Model Context Protocol Client + +一个灵活、可扩展的 AI 模型客户端库,支持 DeepSeek、Qwen 等多种 AI 提供商。 + +## ✨ 特性 + +- 🔌 **多 Provider 支持** - DeepSeek、Qwen、OpenAI 兼容 API +- 🎯 **模板方法模式** - 固定流程,可扩展步骤 +- 🏗️ **构建器模式** - 支持多轮对话、Function Calling、精细参数控制 +- 📦 **零外部依赖** - 仅使用 Go 标准库 +- 🔧 **高度可配置** - 支持 Functional Options 模式 +- 🧪 **易于测试** - 支持依赖注入和 Mock +- ⚡ **向前兼容** - 现有代码无需修改 +- 📝 **丰富的日志** - 可替换的日志接口 + +## 🚀 快速开始 + +### 基础用法 + +```go +import "nofx/mcp" + +// 创建客户端 +client := mcp.NewClient( + mcp.WithDeepSeekConfig("sk-xxx"), +) + +// 调用 AI +result, err := client.CallWithMessages("system prompt", "user prompt") +if err != nil { + log.Fatal(err) +} + +fmt.Println(result) +``` + +### DeepSeek 客户端 + +```go +client := mcp.NewDeepSeekClientWithOptions( + mcp.WithAPIKey("sk-xxx"), + mcp.WithTimeout(60 * time.Second), +) +``` + +### Qwen 客户端 + +```go +client := mcp.NewQwenClientWithOptions( + mcp.WithAPIKey("sk-xxx"), + mcp.WithMaxTokens(4000), +) +``` + +### 🏗️ 构建器模式(高级功能) + +构建器模式支持多轮对话、精细参数控制、Function Calling 等高级功能。 + +#### 简单用法 + +```go +// 使用构建器创建请求 +request := mcp.NewRequestBuilder(). + WithSystemPrompt("You are helpful"). + WithUserPrompt("What is Go?"). + WithTemperature(0.8). + Build() + +result, err := client.CallWithRequest(request) +``` + +#### 多轮对话 + +```go +// 构建包含历史的多轮对话 +request := mcp.NewRequestBuilder(). + AddSystemMessage("You are a trading advisor"). + AddUserMessage("Analyze BTC"). + AddAssistantMessage("BTC is bullish..."). + AddUserMessage("What about entry point?"). // 继续对话 + WithTemperature(0.3). + Build() + +result, err := client.CallWithRequest(request) +``` + +#### 预设场景 + +```go +// 代码生成(低温度、精确) +request := mcp.ForCodeGeneration(). + WithUserPrompt("Generate a HTTP server"). + Build() + +// 创意写作(高温度、随机) +request := mcp.ForCreativeWriting(). + WithUserPrompt("Write a story"). + Build() + +// 聊天(平衡参数) +request := mcp.ForChat(). + WithUserPrompt("Hello"). + Build() +``` + +#### Function Calling + +```go +// 定义工具 +weatherParams := map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, + }, +} + +request := mcp.NewRequestBuilder(). + WithUserPrompt("北京天气怎么样?"). + AddFunction("get_weather", "Get weather", weatherParams). + WithToolChoice("auto"). + Build() + +result, err := client.CallWithRequest(request) +``` + +## 📖 详细文档 + +- [构建器模式完整示例](./BUILDER_EXAMPLES.md) - 多轮对话、Function Calling、参数控制 +- [构建器模式价值分析](./BUILDER_PATTERN_BENEFITS.md) - 为什么引入构建器模式 +- [迁移指南](./MIGRATION_GUIDE.md) - 从旧 API 迁移到新 API +- [Logrus 集成](./LOGRUS_INTEGRATION.md) - 日志框架集成示例 +- [代码审查报告](./CODE_REVIEW.md) - 问题分析和修复记录 + +## 🎛️ 配置选项 + +### 依赖注入 + +```go +// 自定义日志器 +mcp.WithLogger(customLogger) + +// 自定义 HTTP 客户端 +mcp.WithHTTPClient(customHTTP) +``` + +### 超时和重试 + +```go +mcp.WithTimeout(60 * time.Second) +mcp.WithMaxRetries(5) +mcp.WithRetryWaitBase(3 * time.Second) +``` + +### AI 参数 + +```go +mcp.WithMaxTokens(4000) +mcp.WithTemperature(0.7) +``` + +### Provider 配置 + +```go +// 快速配置 DeepSeek +mcp.WithDeepSeekConfig("sk-xxx") + +// 快速配置 Qwen +mcp.WithQwenConfig("sk-xxx") + +// 自定义配置 +mcp.WithAPIKey("sk-xxx") +mcp.WithBaseURL("https://api.custom.com") +mcp.WithModel("gpt-4") +``` + +## 🧪 测试 + +```go +// 使用 Mock HTTP 客户端 +mockHTTP := &MockHTTPClient{ + Response: `{"choices":[{"message":{"content":"test"}}]}`, +} + +client := mcp.NewClient( + mcp.WithHTTPClient(mockHTTP), + mcp.WithLogger(mcp.NewNoopLogger()), // 禁用日志 +) +``` + +## 🏗️ 架构设计 + +### 模板方法模式 + +``` +CallWithMessages (固定重试流程) + ↓ +call (固定调用流程) + ↓ +hooks (可重写的步骤) + ├─ buildMCPRequestBody + ├─ marshalRequestBody + ├─ buildUrl + ├─ setAuthHeader + ├─ parseMCPResponse + └─ isRetryableError +``` + +### 接口分离 + +```go +// 公开接口(给外部使用) +type AIClient interface { + SetAPIKey(...) + SetTimeout(...) + CallWithMessages(...) (string, error) +} + +// 内部钩子接口(供子类重写) +type clientHooks interface { + buildMCPRequestBody(...) map[string]any + buildUrl() string + setAuthHeader(...) + marshalRequestBody(...) ([]byte, error) + parseMCPResponse(...) (string, error) + isRetryableError(...) bool +} +``` + +## 🔄 向前兼容 + +所有旧 API 继续工作: + +```go +// ✅ 旧代码无需修改 +client := mcp.New() +client.SetAPIKey("sk-xxx", "https://api.custom.com", "gpt-4") + +dsClient := mcp.NewDeepSeekClient() +dsClient.SetAPIKey("sk-xxx", "", "") +``` + +## 📦 作为独立模块使用 + +```go +// go.mod +module github.com/yourorg/yourproject + +require github.com/yourorg/mcp v1.0.0 +``` + +```go +// main.go +import "github.com/yourorg/mcp" + +client := mcp.NewClient( + mcp.WithDeepSeekConfig("sk-xxx"), +) +``` + +## 🤝 扩展自定义 Provider + +```go +type CustomProvider struct { + *mcp.Client +} + +// 重写特定钩子 +func (c *CustomProvider) buildUrl() string { + return c.BaseURL + "/custom/endpoint" +} + +func (c *CustomProvider) setAuthHeader(headers http.Header) { + headers.Set("X-Custom-Auth", c.APIKey) +} +``` + +## 📝 日志器适配示例 + +### Zap 日志器 + +```go +type ZapLogger struct { + logger *zap.Logger +} + +func (l *ZapLogger) Infof(format string, args ...any) { + l.logger.Sugar().Infof(format, args...) +} + +func (l *ZapLogger) Debugf(format string, args ...any) { + l.logger.Sugar().Debugf(format, args...) +} + +// 使用 +client := mcp.NewClient( + mcp.WithLogger(&ZapLogger{zapLogger}), +) +``` + +### Logrus 日志器 + +```go +type LogrusLogger struct { + logger *logrus.Logger +} + +func (l *LogrusLogger) Infof(format string, args ...any) { + l.logger.Infof(format, args...) +} + +func (l *LogrusLogger) Debugf(format string, args ...any) { + l.logger.Debugf(format, args...) +} +``` + +## 🎯 使用场景 + +### 开发环境 + +```go +devClient := mcp.NewClient( + mcp.WithDeepSeekConfig("sk-xxx"), + mcp.WithLogger(&customLogger{}), // 详细日志 +) +``` + +### 生产环境 + +```go +prodClient := mcp.NewClient( + mcp.WithDeepSeekConfig("sk-xxx"), + mcp.WithLogger(&zapLogger{}), // 结构化日志 + mcp.WithTimeout(30*time.Second), // 超时保护 + mcp.WithMaxRetries(3), // 重试保护 +) +``` + +### 测试环境 + +```go +testClient := mcp.NewClient( + mcp.WithHTTPClient(mockHTTP), + mcp.WithLogger(mcp.NewNoopLogger()), +) +``` + +## 📊 性能特性 + +- ✅ HTTP 连接复用 +- ✅ 智能重试机制 +- ✅ 可配置超时 +- ✅ 零分配日志(使用 NoopLogger) + +## 🛡️ 安全性 + +- ✅ API Key 部分脱敏日志 +- ✅ HTTPS 默认启用 +- ✅ 支持自定义 TLS 配置 +- ✅ 请求超时保护 + +## 📈 版本兼容性 + +- Go 1.18+ +- 向前兼容保证 +- 语义化版本管理 + +## 🤝 贡献 + +欢迎提交 Issue 和 Pull Request! + +## 📄 许可证 + +MIT License + +## 🔗 相关链接 + +- [DeepSeek API 文档](https://platform.deepseek.com/docs) +- [Qwen API 文档](https://help.aliyun.com/zh/dashscope/) +- [OpenAI API 文档](https://platform.openai.com/docs) diff --git a/mcp/logger.go b/mcp/logger.go new file mode 100644 index 00000000..863310db --- /dev/null +++ b/mcp/logger.go @@ -0,0 +1,68 @@ +package mcp + +import "log" + +// Logger 日志接口(抽象依赖) +// 使用 Printf 风格的方法名,方便集成 logrus、zap 等主流日志库 +type Logger interface { + Debugf(format string, args ...any) + Infof(format string, args ...any) + Warnf(format string, args ...any) + Errorf(format string, args ...any) +} + +// defaultLogger 默认日志实现(包装标准库 log) +type defaultLogger struct{} + +func (l *defaultLogger) Debugf(format string, args ...any) { + log.Printf("[DEBUG] "+format, args...) +} + +func (l *defaultLogger) Infof(format string, args ...any) { + log.Printf("[INFO] "+format, args...) +} + +func (l *defaultLogger) Warnf(format string, args ...any) { + log.Printf("[WARN] "+format, args...) +} + +func (l *defaultLogger) Errorf(format string, args ...any) { + log.Printf("[ERROR] "+format, args...) +} + +// noopLogger 空日志实现(测试时使用) +type noopLogger struct{} + +func (l *noopLogger) Debugf(format string, args ...any) {} +func (l *noopLogger) Infof(format string, args ...any) {} +func (l *noopLogger) Warnf(format string, args ...any) {} +func (l *noopLogger) Errorf(format string, args ...any) {} + +// NewNoopLogger 创建空日志器(测试使用) +func NewNoopLogger() Logger { + return &noopLogger{} +} + +// ============================================================ +// 适配第三方日志库示例 +// ============================================================ + +// Logrus 适配示例: +// type LogrusLogger struct { +// logger *logrus.Logger +// } +// +// func (l *LogrusLogger) Infof(format string, args ...any) { +// l.logger.Infof(format, args...) +// } +// +// Zap 适配示例: +// type ZapLogger struct { +// logger *zap.Logger +// } +// +// func (l *ZapLogger) Infof(format string, args ...any) { +// l.logger.Sugar().Infof(format, args...) +// } +// +// 然后通过 WithLogger(logger) 注入 diff --git a/mcp/mock_test.go b/mcp/mock_test.go new file mode 100644 index 00000000..6c39b099 --- /dev/null +++ b/mcp/mock_test.go @@ -0,0 +1,310 @@ +package mcp + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" +) + +// ============================================================ +// Mock Logger +// ============================================================ + +// MockLogger Mock 日志器(用于测试) +type MockLogger struct { + mu sync.Mutex + Logs []LogEntry + Enabled bool // 是否启用日志记录 +} + +// LogEntry 日志条目 +type LogEntry struct { + Level string + Format string + Args []any + Message string // 格式化后的消息 +} + +func NewMockLogger() *MockLogger { + return &MockLogger{ + Logs: make([]LogEntry, 0), + Enabled: true, + } +} + +func (m *MockLogger) Debugf(format string, args ...any) { + m.log("DEBUG", format, args...) +} + +func (m *MockLogger) Infof(format string, args ...any) { + m.log("INFO", format, args...) +} + +func (m *MockLogger) Warnf(format string, args ...any) { + m.log("WARN", format, args...) +} + +func (m *MockLogger) Errorf(format string, args ...any) { + m.log("ERROR", format, args...) +} + +func (m *MockLogger) log(level, format string, args ...any) { + if !m.Enabled { + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + message := fmt.Sprintf(format, args...) + m.Logs = append(m.Logs, LogEntry{ + Level: level, + Format: format, + Args: args, + Message: message, + }) +} + +// GetLogs 获取所有日志 +func (m *MockLogger) GetLogs() []LogEntry { + m.mu.Lock() + defer m.mu.Unlock() + return append([]LogEntry{}, m.Logs...) +} + +// GetLogsByLevel 获取指定级别的日志 +func (m *MockLogger) GetLogsByLevel(level string) []LogEntry { + m.mu.Lock() + defer m.mu.Unlock() + + var result []LogEntry + for _, log := range m.Logs { + if log.Level == level { + result = append(result, log) + } + } + return result +} + +// Clear 清空日志 +func (m *MockLogger) Clear() { + m.mu.Lock() + defer m.mu.Unlock() + m.Logs = make([]LogEntry, 0) +} + +// HasLog 检查是否包含指定消息 +func (m *MockLogger) HasLog(level, message string) bool { + m.mu.Lock() + defer m.mu.Unlock() + + for _, log := range m.Logs { + if log.Level == level && log.Message == message { + return true + } + } + return false +} + +// ============================================================ +// Mock HTTP Client (实现 http.RoundTripper) +// ============================================================ + +// MockHTTPClient Mock HTTP 客户端(实现 http.RoundTripper) +type MockHTTPClient struct { + mu sync.Mutex + + // 配置 + Response string + StatusCode int + Error error + ResponseFunc func(req *http.Request) (*http.Response, error) // 自定义响应函数 + + // 记录 + Requests []*http.Request +} + +func NewMockHTTPClient() *MockHTTPClient { + return &MockHTTPClient{ + StatusCode: http.StatusOK, + Requests: make([]*http.Request, 0), + } +} + +// ToHTTPClient 转换为 http.Client +func (m *MockHTTPClient) ToHTTPClient() *http.Client { + return &http.Client{ + Transport: m, + } +} + +// RoundTrip 实现 http.RoundTripper 接口 +func (m *MockHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) { + m.mu.Lock() + defer m.mu.Unlock() + + // 记录请求 + m.Requests = append(m.Requests, req) + + // 如果有自定义响应函数,使用它 + if m.ResponseFunc != nil { + return m.ResponseFunc(req) + } + + // 如果设置了错误,返回错误 + if m.Error != nil { + return nil, m.Error + } + + // 返回模拟响应 + resp := &http.Response{ + StatusCode: m.StatusCode, + Body: io.NopCloser(bytes.NewBufferString(m.Response)), + Header: make(http.Header), + } + + return resp, nil +} + +// GetRequests 获取所有请求 +func (m *MockHTTPClient) GetRequests() []*http.Request { + m.mu.Lock() + defer m.mu.Unlock() + return append([]*http.Request{}, m.Requests...) +} + +// GetLastRequest 获取最后一次请求 +func (m *MockHTTPClient) GetLastRequest() *http.Request { + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.Requests) == 0 { + return nil + } + return m.Requests[len(m.Requests)-1] +} + +// Reset 重置状态 +func (m *MockHTTPClient) Reset() { + m.mu.Lock() + defer m.mu.Unlock() + m.Requests = make([]*http.Request, 0) +} + +// SetSuccessResponse 设置成功响应 +func (m *MockHTTPClient) SetSuccessResponse(content string) { + m.mu.Lock() + defer m.mu.Unlock() + + m.StatusCode = http.StatusOK + m.Response = `{"choices":[{"message":{"content":"` + content + `"}}]}` + m.Error = nil +} + +// SetErrorResponse 设置错误响应 +func (m *MockHTTPClient) SetErrorResponse(statusCode int, message string) { + m.mu.Lock() + defer m.mu.Unlock() + + m.StatusCode = statusCode + m.Response = message + m.Error = nil +} + +// SetNetworkError 设置网络错误 +func (m *MockHTTPClient) SetNetworkError(err error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.Error = err +} + +// ============================================================ +// Mock Client Hooks (用于测试钩子机制) +// ============================================================ + +// MockClientHooks Mock 客户端钩子 +type MockClientHooks struct { + BuildRequestBodyCalled int + BuildUrlCalled int + SetAuthHeaderCalled int + MarshalRequestCalled int + ParseResponseCalled int + IsRetryableErrorCalled int + + // 自定义返回值 + BuildUrlFunc func() string + ParseResponseFunc func([]byte) (string, error) + IsRetryableErrorFunc func(error) bool + BuildRequestBodyFunc func(string, string) map[string]any + MarshalRequestBodyFunc func(map[string]any) ([]byte, error) +} + +func NewMockClientHooks() *MockClientHooks { + return &MockClientHooks{} +} + +func (m *MockClientHooks) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any { + m.BuildRequestBodyCalled++ + if m.BuildRequestBodyFunc != nil { + return m.BuildRequestBodyFunc(systemPrompt, userPrompt) + } + return map[string]any{ + "model": "test-model", + "messages": []map[string]string{ + {"role": "system", "content": systemPrompt}, + {"role": "user", "content": userPrompt}, + }, + } +} + +func (m *MockClientHooks) buildUrl() string { + m.BuildUrlCalled++ + if m.BuildUrlFunc != nil { + return m.BuildUrlFunc() + } + return "https://api.test.com/chat/completions" +} + +func (m *MockClientHooks) setAuthHeader(headers http.Header) { + m.SetAuthHeaderCalled++ + headers.Set("Authorization", "Bearer test-key") +} + +func (m *MockClientHooks) marshalRequestBody(body map[string]any) ([]byte, error) { + m.MarshalRequestCalled++ + if m.MarshalRequestBodyFunc != nil { + return m.MarshalRequestBodyFunc(body) + } + return json.Marshal(body) +} + +func (m *MockClientHooks) parseMCPResponse(body []byte) (string, error) { + m.ParseResponseCalled++ + if m.ParseResponseFunc != nil { + return m.ParseResponseFunc(body) + } + return "mocked response", nil +} + +func (m *MockClientHooks) isRetryableError(err error) bool { + m.IsRetryableErrorCalled++ + if m.IsRetryableErrorFunc != nil { + return m.IsRetryableErrorFunc(err) + } + return false +} + +func (m *MockClientHooks) buildRequest(url string, jsonData []byte) (*http.Request, error) { + req, _ := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + m.setAuthHeader(req.Header) + return req, nil +} + +func (m *MockClientHooks) call(systemPrompt, userPrompt string) (string, error) { + return "mocked call result", nil +} diff --git a/mcp/options.go b/mcp/options.go new file mode 100644 index 00000000..c460224f --- /dev/null +++ b/mcp/options.go @@ -0,0 +1,162 @@ +package mcp + +import ( + "net/http" + "time" +) + +// ClientOption 客户端选项函数(Functional Options 模式) +type ClientOption func(*Config) + +// ============================================================ +// 依赖注入选项 +// ============================================================ + +// WithLogger 设置自定义日志器 +// +// 使用示例: +// client := mcp.NewClient(mcp.WithLogger(customLogger)) +func WithLogger(logger Logger) ClientOption { + return func(c *Config) { + c.Logger = logger + } +} + +// WithHTTPClient 设置自定义 HTTP 客户端 +// +// 使用示例: +// httpClient := &http.Client{Timeout: 60 * time.Second} +// client := mcp.NewClient(mcp.WithHTTPClient(httpClient)) +func WithHTTPClient(client *http.Client) ClientOption { + return func(c *Config) { + c.HTTPClient = client + } +} + +// ============================================================ +// 超时和重试选项 +// ============================================================ + +// WithTimeout 设置请求超时时间 +// +// 使用示例: +// client := mcp.NewClient(mcp.WithTimeout(60 * time.Second)) +func WithTimeout(timeout time.Duration) ClientOption { + return func(c *Config) { + c.Timeout = timeout + c.HTTPClient.Timeout = timeout + } +} + +// WithMaxRetries 设置最大重试次数 +// +// 使用示例: +// client := mcp.NewClient(mcp.WithMaxRetries(5)) +func WithMaxRetries(maxRetries int) ClientOption { + return func(c *Config) { + c.MaxRetries = maxRetries + } +} + +// WithRetryWaitBase 设置重试等待基础时长 +// +// 使用示例: +// client := mcp.NewClient(mcp.WithRetryWaitBase(3 * time.Second)) +func WithRetryWaitBase(waitTime time.Duration) ClientOption { + return func(c *Config) { + c.RetryWaitBase = waitTime + } +} + +// ============================================================ +// AI 参数选项 +// ============================================================ + +// WithMaxTokens 设置最大 token 数 +// +// 使用示例: +// client := mcp.NewClient(mcp.WithMaxTokens(4000)) +func WithMaxTokens(maxTokens int) ClientOption { + return func(c *Config) { + c.MaxTokens = maxTokens + } +} + +// WithTemperature 设置温度参数 +// +// 使用示例: +// client := mcp.NewClient(mcp.WithTemperature(0.7)) +func WithTemperature(temperature float64) ClientOption { + return func(c *Config) { + c.Temperature = temperature + } +} + +// ============================================================ +// Provider 配置选项 +// ============================================================ + +// WithAPIKey 设置 API Key +func WithAPIKey(apiKey string) ClientOption { + return func(c *Config) { + c.APIKey = apiKey + } +} + +// WithBaseURL 设置基础 URL +func WithBaseURL(baseURL string) ClientOption { + return func(c *Config) { + c.BaseURL = baseURL + } +} + +// WithModel 设置模型名称 +func WithModel(model string) ClientOption { + return func(c *Config) { + c.Model = model + } +} + +// WithProvider 设置提供商 +func WithProvider(provider string) ClientOption { + return func(c *Config) { + c.Provider = provider + } +} + +// WithUseFullURL 设置是否使用完整 URL +func WithUseFullURL(useFullURL bool) ClientOption { + return func(c *Config) { + c.UseFullURL = useFullURL + } +} + +// ============================================================ +// 组合选项(便捷方法) +// ============================================================ + +// WithDeepSeekConfig 设置 DeepSeek 配置 +// +// 使用示例: +// client := mcp.NewClient(mcp.WithDeepSeekConfig("sk-xxx")) +func WithDeepSeekConfig(apiKey string) ClientOption { + return func(c *Config) { + c.Provider = ProviderDeepSeek + c.APIKey = apiKey + c.BaseURL = DefaultDeepSeekBaseURL + c.Model = DefaultDeepSeekModel + } +} + +// WithQwenConfig 设置 Qwen 配置 +// +// 使用示例: +// client := mcp.NewClient(mcp.WithQwenConfig("sk-xxx")) +func WithQwenConfig(apiKey string) ClientOption { + return func(c *Config) { + c.Provider = ProviderQwen + c.APIKey = apiKey + c.BaseURL = DefaultQwenBaseURL + c.Model = DefaultQwenModel + } +} diff --git a/mcp/options_test.go b/mcp/options_test.go new file mode 100644 index 00000000..67ad5b9b --- /dev/null +++ b/mcp/options_test.go @@ -0,0 +1,365 @@ +package mcp + +import ( + "net/http" + "testing" + "time" +) + +// ============================================================ +// 测试基础选项 +// ============================================================ + +func TestWithProvider(t *testing.T) { + cfg := DefaultConfig() + WithProvider("custom-provider")(cfg) + + if cfg.Provider != "custom-provider" { + t.Errorf("expected 'custom-provider', got '%s'", cfg.Provider) + } +} + +func TestWithAPIKey(t *testing.T) { + cfg := DefaultConfig() + WithAPIKey("sk-test-key")(cfg) + + if cfg.APIKey != "sk-test-key" { + t.Errorf("expected 'sk-test-key', got '%s'", cfg.APIKey) + } +} + +func TestWithBaseURL(t *testing.T) { + cfg := DefaultConfig() + WithBaseURL("https://api.test.com")(cfg) + + if cfg.BaseURL != "https://api.test.com" { + t.Errorf("expected 'https://api.test.com', got '%s'", cfg.BaseURL) + } +} + +func TestWithModel(t *testing.T) { + cfg := DefaultConfig() + WithModel("test-model")(cfg) + + if cfg.Model != "test-model" { + t.Errorf("expected 'test-model', got '%s'", cfg.Model) + } +} + +func TestWithMaxTokens(t *testing.T) { + cfg := DefaultConfig() + WithMaxTokens(4000)(cfg) + + if cfg.MaxTokens != 4000 { + t.Errorf("expected 4000, got %d", cfg.MaxTokens) + } +} + +func TestWithTemperature(t *testing.T) { + cfg := DefaultConfig() + WithTemperature(0.8)(cfg) + + if cfg.Temperature != 0.8 { + t.Errorf("expected 0.8, got %f", cfg.Temperature) + } +} + +func TestWithUseFullURL(t *testing.T) { + cfg := DefaultConfig() + WithUseFullURL(true)(cfg) + + if !cfg.UseFullURL { + t.Error("UseFullURL should be true") + } +} + +func TestWithMaxRetries(t *testing.T) { + cfg := DefaultConfig() + WithMaxRetries(5)(cfg) + + if cfg.MaxRetries != 5 { + t.Errorf("expected 5, got %d", cfg.MaxRetries) + } +} + +func TestWithTimeout(t *testing.T) { + cfg := DefaultConfig() + WithTimeout(60 * time.Second)(cfg) + + if cfg.Timeout != 60*time.Second { + t.Errorf("expected 60s, got %v", cfg.Timeout) + } +} + +func TestWithLogger(t *testing.T) { + cfg := DefaultConfig() + mockLogger := NewMockLogger() + WithLogger(mockLogger)(cfg) + + if cfg.Logger != mockLogger { + t.Error("Logger should be set to mockLogger") + } +} + +func TestWithHTTPClient(t *testing.T) { + cfg := DefaultConfig() + customClient := &http.Client{Timeout: 30 * time.Second} + WithHTTPClient(customClient)(cfg) + + if cfg.HTTPClient != customClient { + t.Error("HTTPClient should be set to customClient") + } + + if cfg.HTTPClient.Timeout != 30*time.Second { + t.Errorf("expected 30s, got %v", cfg.HTTPClient.Timeout) + } +} + +// ============================================================ +// 测试预设配置选项 +// ============================================================ + +func TestWithDeepSeekConfig(t *testing.T) { + cfg := DefaultConfig() + WithDeepSeekConfig("sk-deepseek-key")(cfg) + + if cfg.Provider != ProviderDeepSeek { + t.Errorf("Provider should be '%s', got '%s'", ProviderDeepSeek, cfg.Provider) + } + + if cfg.APIKey != "sk-deepseek-key" { + t.Errorf("APIKey should be 'sk-deepseek-key', got '%s'", cfg.APIKey) + } + + if cfg.BaseURL != DefaultDeepSeekBaseURL { + t.Errorf("BaseURL should be '%s', got '%s'", DefaultDeepSeekBaseURL, cfg.BaseURL) + } + + if cfg.Model != DefaultDeepSeekModel { + t.Errorf("Model should be '%s', got '%s'", DefaultDeepSeekModel, cfg.Model) + } +} + +func TestWithQwenConfig(t *testing.T) { + cfg := DefaultConfig() + WithQwenConfig("sk-qwen-key")(cfg) + + if cfg.Provider != ProviderQwen { + t.Errorf("Provider should be '%s', got '%s'", ProviderQwen, cfg.Provider) + } + + if cfg.APIKey != "sk-qwen-key" { + t.Errorf("APIKey should be 'sk-qwen-key', got '%s'", cfg.APIKey) + } + + if cfg.BaseURL != DefaultQwenBaseURL { + t.Errorf("BaseURL should be '%s', got '%s'", DefaultQwenBaseURL, cfg.BaseURL) + } + + if cfg.Model != DefaultQwenModel { + t.Errorf("Model should be '%s', got '%s'", DefaultQwenModel, cfg.Model) + } +} + +// ============================================================ +// 测试选项组合 +// ============================================================ + +func TestMultipleOptions(t *testing.T) { + mockLogger := NewMockLogger() + + cfg := DefaultConfig() + + // 应用多个选项 + options := []ClientOption{ + WithProvider("test-provider"), + WithAPIKey("sk-test-key"), + WithBaseURL("https://api.test.com"), + WithModel("test-model"), + WithMaxTokens(4000), + WithTemperature(0.8), + WithLogger(mockLogger), + WithTimeout(60 * time.Second), + } + + for _, opt := range options { + opt(cfg) + } + + // 验证所有选项都被应用 + if cfg.Provider != "test-provider" { + t.Error("Provider should be set") + } + + if cfg.APIKey != "sk-test-key" { + t.Error("APIKey should be set") + } + + if cfg.BaseURL != "https://api.test.com" { + t.Error("BaseURL should be set") + } + + if cfg.Model != "test-model" { + t.Error("Model should be set") + } + + if cfg.MaxTokens != 4000 { + t.Error("MaxTokens should be 4000") + } + + if cfg.Temperature != 0.8 { + t.Error("Temperature should be 0.8") + } + + if cfg.Logger != mockLogger { + t.Error("Logger should be mockLogger") + } + + if cfg.Timeout != 60*time.Second { + t.Error("Timeout should be 60s") + } +} + +func TestOptionsOverride(t *testing.T) { + cfg := DefaultConfig() + + // 先应用 DeepSeek 配置 + WithDeepSeekConfig("sk-deepseek-key")(cfg) + + // 然后覆盖某些选项 + WithModel("custom-model")(cfg) + WithMaxTokens(5000)(cfg) + + // 验证覆盖成功 + if cfg.Model != "custom-model" { + t.Errorf("Model should be overridden to 'custom-model', got '%s'", cfg.Model) + } + + if cfg.MaxTokens != 5000 { + t.Errorf("MaxTokens should be overridden to 5000, got %d", cfg.MaxTokens) + } + + // 验证其他 DeepSeek 配置保持不变 + if cfg.Provider != ProviderDeepSeek { + t.Error("Provider should still be DeepSeek") + } + + if cfg.BaseURL != DefaultDeepSeekBaseURL { + t.Error("BaseURL should still be DeepSeek default") + } +} + +// ============================================================ +// 测试与客户端集成 +// ============================================================ + +func TestOptionsWithNewClient(t *testing.T) { + mockLogger := NewMockLogger() + + client := NewClient( + WithProvider("test-provider"), + WithAPIKey("sk-test-key"), + WithModel("test-model"), + WithLogger(mockLogger), + WithMaxTokens(4000), + ) + + c := client.(*Client) + + // 验证选项被正确应用到客户端 + if c.Provider != "test-provider" { + t.Error("Provider should be set from options") + } + + if c.APIKey != "sk-test-key" { + t.Error("APIKey should be set from options") + } + + if c.Model != "test-model" { + t.Error("Model should be set from options") + } + + if c.logger != mockLogger { + t.Error("logger should be set from options") + } + + if c.MaxTokens != 4000 { + t.Error("MaxTokens should be 4000") + } +} + +func TestOptionsWithDeepSeekClient(t *testing.T) { + mockLogger := NewMockLogger() + + client := NewDeepSeekClientWithOptions( + WithAPIKey("sk-deepseek-key"), + WithLogger(mockLogger), + WithMaxTokens(5000), + ) + + dsClient := client.(*DeepSeekClient) + + // 验证 DeepSeek 默认值 + if dsClient.Provider != ProviderDeepSeek { + t.Error("Provider should be DeepSeek") + } + + if dsClient.BaseURL != DefaultDeepSeekBaseURL { + t.Error("BaseURL should be DeepSeek default") + } + + if dsClient.Model != DefaultDeepSeekModel { + t.Error("Model should be DeepSeek default") + } + + // 验证自定义选项 + if dsClient.APIKey != "sk-deepseek-key" { + t.Error("APIKey should be set from options") + } + + if dsClient.logger != mockLogger { + t.Error("logger should be set from options") + } + + if dsClient.MaxTokens != 5000 { + t.Error("MaxTokens should be 5000") + } +} + +func TestOptionsWithQwenClient(t *testing.T) { + mockLogger := NewMockLogger() + + client := NewQwenClientWithOptions( + WithAPIKey("sk-qwen-key"), + WithLogger(mockLogger), + WithMaxTokens(6000), + ) + + qwenClient := client.(*QwenClient) + + // 验证 Qwen 默认值 + if qwenClient.Provider != ProviderQwen { + t.Error("Provider should be Qwen") + } + + if qwenClient.BaseURL != DefaultQwenBaseURL { + t.Error("BaseURL should be Qwen default") + } + + if qwenClient.Model != DefaultQwenModel { + t.Error("Model should be Qwen default") + } + + // 验证自定义选项 + if qwenClient.APIKey != "sk-qwen-key" { + t.Error("APIKey should be set from options") + } + + if qwenClient.logger != mockLogger { + t.Error("logger should be set from options") + } + + if qwenClient.MaxTokens != 6000 { + t.Error("MaxTokens should be 6000") + } +} diff --git a/mcp/qwen_client.go b/mcp/qwen_client.go index e56ed42d..f790d08d 100644 --- a/mcp/qwen_client.go +++ b/mcp/qwen_client.go @@ -1,7 +1,6 @@ package mcp import ( - "log" "net/http" ) @@ -15,36 +14,67 @@ type QwenClient struct { *Client } +// NewQwenClient 创建 Qwen 客户端(向前兼容) +// +// Deprecated: 推荐使用 NewQwenClientWithOptions 以获得更好的灵活性 func NewQwenClient() AIClient { - client := New().(*Client) - client.Provider = ProviderQwen - client.Model = DefaultQwenModel - client.BaseURL = DefaultQwenBaseURL - return &QwenClient{ - Client: client, + return NewQwenClientWithOptions() +} + +// NewQwenClientWithOptions 创建 Qwen 客户端(支持选项模式) +// +// 使用示例: +// // 基础用法 +// client := mcp.NewQwenClientWithOptions() +// +// // 自定义配置 +// client := mcp.NewQwenClientWithOptions( +// mcp.WithAPIKey("sk-xxx"), +// mcp.WithLogger(customLogger), +// mcp.WithTimeout(60*time.Second), +// ) +func NewQwenClientWithOptions(opts ...ClientOption) AIClient { + // 1. 创建 Qwen 预设选项 + qwenOpts := []ClientOption{ + WithProvider(ProviderQwen), + WithModel(DefaultQwenModel), + WithBaseURL(DefaultQwenBaseURL), } + + // 2. 合并用户选项(用户选项优先级更高) + allOpts := append(qwenOpts, opts...) + + // 3. 创建基础客户端 + baseClient := NewClient(allOpts...).(*Client) + + // 4. 创建 Qwen 客户端 + qwenClient := &QwenClient{ + Client: baseClient, + } + + // 5. 设置 hooks 指向 QwenClient(实现动态分派) + baseClient.hooks = qwenClient + + return qwenClient } func (qwenClient *QwenClient) SetAPIKey(apiKey string, customURL string, customModel string) { - if qwenClient.Client == nil { - qwenClient.Client = New().(*Client) - } - qwenClient.Client.APIKey = apiKey + qwenClient.APIKey = apiKey if len(apiKey) > 8 { - log.Printf("🔧 [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) + qwenClient.logger.Infof("🔧 [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) } if customURL != "" { - qwenClient.Client.BaseURL = customURL - log.Printf("🔧 [MCP] Qwen 使用自定义 BaseURL: %s", customURL) + qwenClient.BaseURL = customURL + qwenClient.logger.Infof("🔧 [MCP] Qwen 使用自定义 BaseURL: %s", customURL) } else { - log.Printf("🔧 [MCP] Qwen 使用默认 BaseURL: %s", qwenClient.Client.BaseURL) + qwenClient.logger.Infof("🔧 [MCP] Qwen 使用默认 BaseURL: %s", qwenClient.BaseURL) } if customModel != "" { - qwenClient.Client.Model = customModel - log.Printf("🔧 [MCP] Qwen 使用自定义 Model: %s", customModel) + qwenClient.Model = customModel + qwenClient.logger.Infof("🔧 [MCP] Qwen 使用自定义 Model: %s", customModel) } else { - log.Printf("🔧 [MCP] Qwen 使用默认 Model: %s", qwenClient.Client.Model) + qwenClient.logger.Infof("🔧 [MCP] Qwen 使用默认 Model: %s", qwenClient.Model) } } diff --git a/mcp/qwen_client_test.go b/mcp/qwen_client_test.go new file mode 100644 index 00000000..d8f0c44c --- /dev/null +++ b/mcp/qwen_client_test.go @@ -0,0 +1,272 @@ +package mcp + +import ( + "testing" + "time" +) + +// ============================================================ +// 测试 QwenClient 创建和配置 +// ============================================================ + +func TestNewQwenClient_Default(t *testing.T) { + client := NewQwenClient() + + if client == nil { + t.Fatal("client should not be nil") + } + + // 类型断言检查 + qwenClient, ok := client.(*QwenClient) + if !ok { + t.Fatal("client should be *QwenClient") + } + + // 验证默认值 + if qwenClient.Provider != ProviderQwen { + t.Errorf("Provider should be '%s', got '%s'", ProviderQwen, qwenClient.Provider) + } + + if qwenClient.BaseURL != DefaultQwenBaseURL { + t.Errorf("BaseURL should be '%s', got '%s'", DefaultQwenBaseURL, qwenClient.BaseURL) + } + + if qwenClient.Model != DefaultQwenModel { + t.Errorf("Model should be '%s', got '%s'", DefaultQwenModel, qwenClient.Model) + } + + if qwenClient.logger == nil { + t.Error("logger should not be nil") + } + + if qwenClient.httpClient == nil { + t.Error("httpClient should not be nil") + } +} + +func TestNewQwenClientWithOptions(t *testing.T) { + mockLogger := NewMockLogger() + customModel := "qwen-plus" + customAPIKey := "sk-custom-qwen-key" + + client := NewQwenClientWithOptions( + WithLogger(mockLogger), + WithModel(customModel), + WithAPIKey(customAPIKey), + WithMaxTokens(4000), + ) + + qwenClient := client.(*QwenClient) + + // 验证自定义选项被应用 + if qwenClient.logger != mockLogger { + t.Error("logger should be set from option") + } + + if qwenClient.Model != customModel { + t.Error("Model should be set from option") + } + + if qwenClient.APIKey != customAPIKey { + t.Error("APIKey should be set from option") + } + + if qwenClient.MaxTokens != 4000 { + t.Error("MaxTokens should be 4000") + } + + // 验证 Qwen 默认值仍然保留 + if qwenClient.Provider != ProviderQwen { + t.Errorf("Provider should still be '%s'", ProviderQwen) + } + + if qwenClient.BaseURL != DefaultQwenBaseURL { + t.Errorf("BaseURL should still be '%s'", DefaultQwenBaseURL) + } +} + +// ============================================================ +// 测试 SetAPIKey +// ============================================================ + +func TestQwenClient_SetAPIKey(t *testing.T) { + mockLogger := NewMockLogger() + client := NewQwenClientWithOptions( + WithLogger(mockLogger), + ) + + qwenClient := client.(*QwenClient) + + // 测试设置 API Key(默认 URL 和 Model) + qwenClient.SetAPIKey("sk-test-key-12345678", "", "") + + if qwenClient.APIKey != "sk-test-key-12345678" { + t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", qwenClient.APIKey) + } + + // 验证日志记录 + logs := mockLogger.GetLogsByLevel("INFO") + if len(logs) == 0 { + t.Error("should have logged API key setting") + } + + // 验证 BaseURL 和 Model 保持默认 + if qwenClient.BaseURL != DefaultQwenBaseURL { + t.Error("BaseURL should remain default") + } + + if qwenClient.Model != DefaultQwenModel { + t.Error("Model should remain default") + } +} + +func TestQwenClient_SetAPIKey_WithCustomURL(t *testing.T) { + mockLogger := NewMockLogger() + client := NewQwenClientWithOptions( + WithLogger(mockLogger), + ) + + qwenClient := client.(*QwenClient) + + customURL := "https://custom.qwen.api.com/v1" + qwenClient.SetAPIKey("sk-test-key-12345678", customURL, "") + + if qwenClient.BaseURL != customURL { + t.Errorf("BaseURL should be '%s', got '%s'", customURL, qwenClient.BaseURL) + } + + // 验证日志记录 + logs := mockLogger.GetLogsByLevel("INFO") + hasCustomURLLog := false + for _, log := range logs { + if log.Format == "🔧 [MCP] Qwen 使用自定义 BaseURL: %s" { + hasCustomURLLog = true + break + } + } + + if !hasCustomURLLog { + t.Error("should have logged custom BaseURL") + } +} + +func TestQwenClient_SetAPIKey_WithCustomModel(t *testing.T) { + mockLogger := NewMockLogger() + client := NewQwenClientWithOptions( + WithLogger(mockLogger), + ) + + qwenClient := client.(*QwenClient) + + customModel := "qwen-turbo" + qwenClient.SetAPIKey("sk-test-key-12345678", "", customModel) + + if qwenClient.Model != customModel { + t.Errorf("Model should be '%s', got '%s'", customModel, qwenClient.Model) + } + + // 验证日志记录 + logs := mockLogger.GetLogsByLevel("INFO") + hasCustomModelLog := false + for _, log := range logs { + if log.Format == "🔧 [MCP] Qwen 使用自定义 Model: %s" { + hasCustomModelLog = true + break + } + } + + if !hasCustomModelLog { + t.Error("should have logged custom Model") + } +} + +// ============================================================ +// 测试集成功能 +// ============================================================ + +func TestQwenClient_CallWithMessages_Success(t *testing.T) { + mockHTTP := NewMockHTTPClient() + mockHTTP.SetSuccessResponse("Qwen AI response") + mockLogger := NewMockLogger() + + client := NewQwenClientWithOptions( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(mockLogger), + WithAPIKey("sk-test-key"), + ) + + result, err := client.CallWithMessages("system prompt", "user prompt") + + if err != nil { + t.Fatalf("should not error: %v", err) + } + + if result != "Qwen AI response" { + t.Errorf("expected 'Qwen AI response', got '%s'", result) + } + + // 验证请求 + requests := mockHTTP.GetRequests() + if len(requests) != 1 { + t.Fatalf("expected 1 request, got %d", len(requests)) + } + + req := requests[0] + + // 验证 URL + expectedURL := DefaultQwenBaseURL + "/chat/completions" + if req.URL.String() != expectedURL { + t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String()) + } + + // 验证 Authorization header + authHeader := req.Header.Get("Authorization") + if authHeader != "Bearer sk-test-key" { + t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader) + } + + // 验证 Content-Type + if req.Header.Get("Content-Type") != "application/json" { + t.Error("Content-Type should be application/json") + } +} + +func TestQwenClient_Timeout(t *testing.T) { + client := NewQwenClientWithOptions( + WithTimeout(30 * time.Second), + ) + + qwenClient := client.(*QwenClient) + + if qwenClient.httpClient.Timeout != 30*time.Second { + t.Errorf("expected timeout 30s, got %v", qwenClient.httpClient.Timeout) + } + + // 测试 SetTimeout + client.SetTimeout(60 * time.Second) + + if qwenClient.httpClient.Timeout != 60*time.Second { + t.Errorf("expected timeout 60s after SetTimeout, got %v", qwenClient.httpClient.Timeout) + } +} + +// ============================================================ +// 测试 hooks 机制 +// ============================================================ + +func TestQwenClient_HooksIntegration(t *testing.T) { + client := NewQwenClientWithOptions() + qwenClient := client.(*QwenClient) + + // 验证 hooks 指向 qwenClient 自己(实现多态) + if qwenClient.hooks != qwenClient { + t.Error("hooks should point to qwenClient for polymorphism") + } + + // 验证 buildUrl 使用 Qwen 配置 + url := qwenClient.buildUrl() + expectedURL := DefaultQwenBaseURL + "/chat/completions" + if url != expectedURL { + t.Errorf("expected URL '%s', got '%s'", expectedURL, url) + } +} diff --git a/mcp/request.go b/mcp/request.go new file mode 100644 index 00000000..d706dd63 --- /dev/null +++ b/mcp/request.go @@ -0,0 +1,72 @@ +package mcp + +// Message 表示一条对话消息 +type Message struct { + Role string `json:"role"` // "system", "user", "assistant" + Content string `json:"content"` // 消息内容 +} + +// Tool 表示 AI 可以调用的工具/函数 +type Tool struct { + Type string `json:"type"` // 通常为 "function" + Function FunctionDef `json:"function"` // 函数定义 +} + +// FunctionDef 函数定义 +type FunctionDef struct { + Name string `json:"name"` // 函数名 + Description string `json:"description,omitempty"` // 函数描述 + Parameters map[string]any `json:"parameters,omitempty"` // 参数 schema (JSON Schema) +} + +// Request AI API 请求(支持高级功能) +type Request struct { + // 基础字段 + Model string `json:"model"` // 模型名称 + Messages []Message `json:"messages"` // 对话消息列表 + Stream bool `json:"stream,omitempty"` // 是否流式响应 + + // 可选参数(用于精细控制) + Temperature *float64 `json:"temperature,omitempty"` // 温度 (0-2),控制随机性 + MaxTokens *int `json:"max_tokens,omitempty"` // 最大 token 数 + TopP *float64 `json:"top_p,omitempty"` // 核采样参数 (0-1) + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // 频率惩罚 (-2 to 2) + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // 存在惩罚 (-2 to 2) + Stop []string `json:"stop,omitempty"` // 停止序列 + + // 高级功能 + Tools []Tool `json:"tools,omitempty"` // 可用工具列表 + ToolChoice string `json:"tool_choice,omitempty"` // 工具选择策略 ("auto", "none", {"type": "function", "function": {"name": "xxx"}}) +} + +// NewMessage 创建一条消息 +func NewMessage(role, content string) Message { + return Message{ + Role: role, + Content: content, + } +} + +// NewSystemMessage 创建系统消息 +func NewSystemMessage(content string) Message { + return Message{ + Role: "system", + Content: content, + } +} + +// NewUserMessage 创建用户消息 +func NewUserMessage(content string) Message { + return Message{ + Role: "user", + Content: content, + } +} + +// NewAssistantMessage 创建助手消息 +func NewAssistantMessage(content string) Message { + return Message{ + Role: "assistant", + Content: content, + } +} diff --git a/mcp/request_builder.go b/mcp/request_builder.go new file mode 100644 index 00000000..4c61d6a4 --- /dev/null +++ b/mcp/request_builder.go @@ -0,0 +1,317 @@ +package mcp + +import ( + "errors" +) + +// RequestBuilder 请求构建器 +type RequestBuilder struct { + model string + messages []Message + stream bool + temperature *float64 + maxTokens *int + topP *float64 + frequencyPenalty *float64 + presencePenalty *float64 + stop []string + tools []Tool + toolChoice string +} + +// NewRequestBuilder 创建请求构建器 +// +// 使用示例: +// request := NewRequestBuilder(). +// WithSystemPrompt("You are helpful"). +// WithUserPrompt("Hello"). +// WithTemperature(0.8). +// Build() +func NewRequestBuilder() *RequestBuilder { + return &RequestBuilder{ + messages: make([]Message, 0), + tools: make([]Tool, 0), + } +} + +// ============================================================ +// 模型和流式配置 +// ============================================================ + +// WithModel 设置模型名称 +func (b *RequestBuilder) WithModel(model string) *RequestBuilder { + b.model = model + return b +} + +// WithStream 设置是否使用流式响应 +func (b *RequestBuilder) WithStream(stream bool) *RequestBuilder { + b.stream = stream + return b +} + +// ============================================================ +// 消息构建方法 +// ============================================================ + +// WithSystemPrompt 添加系统提示词(便捷方法) +func (b *RequestBuilder) WithSystemPrompt(prompt string) *RequestBuilder { + if prompt != "" { + b.messages = append(b.messages, NewSystemMessage(prompt)) + } + return b +} + +// WithUserPrompt 添加用户提示词(便捷方法) +func (b *RequestBuilder) WithUserPrompt(prompt string) *RequestBuilder { + if prompt != "" { + b.messages = append(b.messages, NewUserMessage(prompt)) + } + return b +} + +// AddSystemMessage 添加系统消息 +func (b *RequestBuilder) AddSystemMessage(content string) *RequestBuilder { + return b.WithSystemPrompt(content) +} + +// AddUserMessage 添加用户消息 +func (b *RequestBuilder) AddUserMessage(content string) *RequestBuilder { + return b.WithUserPrompt(content) +} + +// AddAssistantMessage 添加助手消息(用于多轮对话上下文) +func (b *RequestBuilder) AddAssistantMessage(content string) *RequestBuilder { + if content != "" { + b.messages = append(b.messages, NewAssistantMessage(content)) + } + return b +} + +// AddMessage 添加自定义角色的消息 +func (b *RequestBuilder) AddMessage(role, content string) *RequestBuilder { + if content != "" { + b.messages = append(b.messages, NewMessage(role, content)) + } + return b +} + +// AddMessages 批量添加消息 +func (b *RequestBuilder) AddMessages(messages ...Message) *RequestBuilder { + b.messages = append(b.messages, messages...) + return b +} + +// AddConversationHistory 添加对话历史 +func (b *RequestBuilder) AddConversationHistory(history []Message) *RequestBuilder { + b.messages = append(b.messages, history...) + return b +} + +// ClearMessages 清空所有消息 +func (b *RequestBuilder) ClearMessages() *RequestBuilder { + b.messages = make([]Message, 0) + return b +} + +// ============================================================ +// 参数控制方法 +// ============================================================ + +// WithTemperature 设置温度参数 (0-2) +// 较高的温度(如 1.2)会使输出更随机,较低的温度(如 0.2)会使输出更确定 +func (b *RequestBuilder) WithTemperature(t float64) *RequestBuilder { + if t < 0 || t > 2 { + // 可以选择 panic 或者静默忽略,这里选择限制范围 + if t < 0 { + t = 0 + } + if t > 2 { + t = 2 + } + } + b.temperature = &t + return b +} + +// WithMaxTokens 设置最大 token 数 +func (b *RequestBuilder) WithMaxTokens(tokens int) *RequestBuilder { + if tokens > 0 { + b.maxTokens = &tokens + } + return b +} + +// WithTopP 设置 top-p 核采样参数 (0-1) +// 控制考虑的 token 范围,较小的值(如 0.1)使输出更聚焦 +func (b *RequestBuilder) WithTopP(p float64) *RequestBuilder { + if p >= 0 && p <= 1 { + b.topP = &p + } + return b +} + +// WithFrequencyPenalty 设置频率惩罚 (-2 to 2) +// 正值会根据 token 在文本中出现的频率惩罚它们,减少重复 +func (b *RequestBuilder) WithFrequencyPenalty(penalty float64) *RequestBuilder { + if penalty >= -2 && penalty <= 2 { + b.frequencyPenalty = &penalty + } + return b +} + +// WithPresencePenalty 设置存在惩罚 (-2 to 2) +// 正值会根据 token 是否出现在文本中惩罚它们,增加话题多样性 +func (b *RequestBuilder) WithPresencePenalty(penalty float64) *RequestBuilder { + if penalty >= -2 && penalty <= 2 { + b.presencePenalty = &penalty + } + return b +} + +// WithStopSequences 设置停止序列 +// 当模型生成这些序列之一时,将停止生成 +func (b *RequestBuilder) WithStopSequences(sequences []string) *RequestBuilder { + b.stop = sequences + return b +} + +// AddStopSequence 添加单个停止序列 +func (b *RequestBuilder) AddStopSequence(sequence string) *RequestBuilder { + if sequence != "" { + b.stop = append(b.stop, sequence) + } + return b +} + +// ============================================================ +// 工具/函数调用相关 +// ============================================================ + +// AddTool 添加工具 +func (b *RequestBuilder) AddTool(tool Tool) *RequestBuilder { + b.tools = append(b.tools, tool) + return b +} + +// AddFunction 添加函数(便捷方法) +func (b *RequestBuilder) AddFunction(name, description string, parameters map[string]any) *RequestBuilder { + tool := Tool{ + Type: "function", + Function: FunctionDef{ + Name: name, + Description: description, + Parameters: parameters, + }, + } + b.tools = append(b.tools, tool) + return b +} + +// WithToolChoice 设置工具选择策略 +// - "auto": 自动选择是否调用工具 +// - "none": 不调用工具 +// - 也可以指定特定工具: `{"type": "function", "function": {"name": "my_function"}}` +func (b *RequestBuilder) WithToolChoice(choice string) *RequestBuilder { + b.toolChoice = choice + return b +} + +// ============================================================ +// 构建方法 +// ============================================================ + +// Build 构建请求对象 +func (b *RequestBuilder) Build() (*Request, error) { + // 验证:至少需要一条消息 + if len(b.messages) == 0 { + return nil, errors.New("至少需要一条消息") + } + + // 创建请求 + req := &Request{ + Model: b.model, + Messages: b.messages, + Stream: b.stream, + Stop: b.stop, + Tools: b.tools, + ToolChoice: b.toolChoice, + } + + // 只设置非 nil 的可选参数(避免发送 0 值覆盖服务端默认值) + if b.temperature != nil { + req.Temperature = b.temperature + } + if b.maxTokens != nil { + req.MaxTokens = b.maxTokens + } + if b.topP != nil { + req.TopP = b.topP + } + if b.frequencyPenalty != nil { + req.FrequencyPenalty = b.frequencyPenalty + } + if b.presencePenalty != nil { + req.PresencePenalty = b.presencePenalty + } + + return req, nil +} + +// MustBuild 构建请求对象,如果失败则 panic +// 适用于构建过程中确定不会出错的场景 +func (b *RequestBuilder) MustBuild() *Request { + req, err := b.Build() + if err != nil { + panic(err) + } + return req +} + +// ============================================================ +// 便捷方法:预设场景 +// ============================================================ + +// ForChat 创建用于聊天的构建器(预设合理的参数) +func ForChat() *RequestBuilder { + temp := 0.7 + tokens := 2000 + return &RequestBuilder{ + messages: make([]Message, 0), + tools: make([]Tool, 0), + temperature: &temp, + maxTokens: &tokens, + } +} + +// ForCodeGeneration 创建用于代码生成的构建器(低温度,更确定) +func ForCodeGeneration() *RequestBuilder { + temp := 0.2 + tokens := 2000 + topP := 0.1 + return &RequestBuilder{ + messages: make([]Message, 0), + tools: make([]Tool, 0), + temperature: &temp, + maxTokens: &tokens, + topP: &topP, + } +} + +// ForCreativeWriting 创建用于创意写作的构建器(高温度,更随机) +func ForCreativeWriting() *RequestBuilder { + temp := 1.2 + tokens := 4000 + topP := 0.95 + presencePenalty := 0.6 + frequencyPenalty := 0.5 + return &RequestBuilder{ + messages: make([]Message, 0), + tools: make([]Tool, 0), + temperature: &temp, + maxTokens: &tokens, + topP: &topP, + presencePenalty: &presencePenalty, + frequencyPenalty: &frequencyPenalty, + } +} diff --git a/mcp/request_builder_test.go b/mcp/request_builder_test.go new file mode 100644 index 00000000..be78601f --- /dev/null +++ b/mcp/request_builder_test.go @@ -0,0 +1,478 @@ +package mcp + +import ( + "encoding/json" + "testing" +) + +// ============================================================ +// 测试 RequestBuilder 基本功能 +// ============================================================ + +func TestRequestBuilder_BasicUsage(t *testing.T) { + request, err := NewRequestBuilder(). + WithSystemPrompt("You are helpful"). + WithUserPrompt("Hello"). + Build() + + if err != nil { + t.Fatalf("Build should not error: %v", err) + } + + if len(request.Messages) != 2 { + t.Errorf("expected 2 messages, got %d", len(request.Messages)) + } + + if request.Messages[0].Role != "system" { + t.Errorf("first message should be system, got %s", request.Messages[0].Role) + } + + if request.Messages[1].Role != "user" { + t.Errorf("second message should be user, got %s", request.Messages[1].Role) + } +} + +func TestRequestBuilder_EmptyMessages(t *testing.T) { + _, err := NewRequestBuilder().Build() + + if err == nil { + t.Error("Build should error when no messages") + } + + if err.Error() != "至少需要一条消息" { + t.Errorf("unexpected error: %v", err) + } +} + +// ============================================================ +// 测试消息构建方法 +// ============================================================ + +func TestRequestBuilder_MultipleMessages(t *testing.T) { + request := NewRequestBuilder(). + AddSystemMessage("You are helpful"). + AddUserMessage("What is Go?"). + AddAssistantMessage("Go is a programming language"). + AddUserMessage("Tell me more"). + MustBuild() + + if len(request.Messages) != 4 { + t.Fatalf("expected 4 messages, got %d", len(request.Messages)) + } + + expectedRoles := []string{"system", "user", "assistant", "user"} + for i, expected := range expectedRoles { + if request.Messages[i].Role != expected { + t.Errorf("message %d: expected role %s, got %s", i, expected, request.Messages[i].Role) + } + } +} + +func TestRequestBuilder_AddConversationHistory(t *testing.T) { + history := []Message{ + NewUserMessage("Previous question"), + NewAssistantMessage("Previous answer"), + } + + request := NewRequestBuilder(). + AddConversationHistory(history). + AddUserMessage("New question"). + MustBuild() + + if len(request.Messages) != 3 { + t.Fatalf("expected 3 messages, got %d", len(request.Messages)) + } +} + +// ============================================================ +// 测试参数控制方法 +// ============================================================ + +func TestRequestBuilder_WithTemperature(t *testing.T) { + request := NewRequestBuilder(). + WithUserPrompt("Hello"). + WithTemperature(0.8). + MustBuild() + + if request.Temperature == nil { + t.Fatal("Temperature should be set") + } + + if *request.Temperature != 0.8 { + t.Errorf("expected temperature 0.8, got %f", *request.Temperature) + } +} + +func TestRequestBuilder_WithMaxTokens(t *testing.T) { + request := NewRequestBuilder(). + WithUserPrompt("Hello"). + WithMaxTokens(2000). + MustBuild() + + if request.MaxTokens == nil { + t.Fatal("MaxTokens should be set") + } + + if *request.MaxTokens != 2000 { + t.Errorf("expected maxTokens 2000, got %d", *request.MaxTokens) + } +} + +func TestRequestBuilder_WithTopP(t *testing.T) { + request := NewRequestBuilder(). + WithUserPrompt("Hello"). + WithTopP(0.9). + MustBuild() + + if request.TopP == nil { + t.Fatal("TopP should be set") + } + + if *request.TopP != 0.9 { + t.Errorf("expected topP 0.9, got %f", *request.TopP) + } +} + +func TestRequestBuilder_WithPenalties(t *testing.T) { + request := NewRequestBuilder(). + WithUserPrompt("Hello"). + WithFrequencyPenalty(0.5). + WithPresencePenalty(0.6). + MustBuild() + + if request.FrequencyPenalty == nil || *request.FrequencyPenalty != 0.5 { + t.Error("FrequencyPenalty should be 0.5") + } + + if request.PresencePenalty == nil || *request.PresencePenalty != 0.6 { + t.Error("PresencePenalty should be 0.6") + } +} + +func TestRequestBuilder_WithStopSequences(t *testing.T) { + request := NewRequestBuilder(). + WithUserPrompt("Hello"). + WithStopSequences([]string{"STOP", "END"}). + MustBuild() + + if len(request.Stop) != 2 { + t.Fatalf("expected 2 stop sequences, got %d", len(request.Stop)) + } + + if request.Stop[0] != "STOP" || request.Stop[1] != "END" { + t.Error("stop sequences not set correctly") + } +} + +// ============================================================ +// 测试工具/函数调用 +// ============================================================ + +func TestRequestBuilder_AddTool(t *testing.T) { + tool := Tool{ + Type: "function", + Function: FunctionDef{ + Name: "get_weather", + Description: "Get weather", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, + }, + }, + }, + } + + request := NewRequestBuilder(). + WithUserPrompt("What's the weather?"). + AddTool(tool). + WithToolChoice("auto"). + MustBuild() + + if len(request.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(request.Tools)) + } + + if request.Tools[0].Function.Name != "get_weather" { + t.Error("tool not added correctly") + } + + if request.ToolChoice != "auto" { + t.Error("tool choice not set correctly") + } +} + +func TestRequestBuilder_AddFunction(t *testing.T) { + params := map[string]any{ + "type": "object", + "properties": map[string]any{ + "city": map[string]any{"type": "string"}, + }, + } + + request := NewRequestBuilder(). + WithUserPrompt("Hello"). + AddFunction("get_weather", "Get current weather", params). + MustBuild() + + if len(request.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(request.Tools)) + } + + if request.Tools[0].Type != "function" { + t.Error("tool type should be function") + } + + if request.Tools[0].Function.Name != "get_weather" { + t.Error("function name not set correctly") + } +} + +// ============================================================ +// 测试便捷方法 +// ============================================================ + +func TestRequestBuilder_ForChat(t *testing.T) { + request := ForChat(). + WithUserPrompt("Hello"). + MustBuild() + + if request.Temperature == nil { + t.Fatal("ForChat should set temperature") + } + + if *request.Temperature != 0.7 { + t.Errorf("ForChat should set temperature to 0.7, got %f", *request.Temperature) + } + + if request.MaxTokens == nil { + t.Fatal("ForChat should set maxTokens") + } + + if *request.MaxTokens != 2000 { + t.Errorf("ForChat should set maxTokens to 2000, got %d", *request.MaxTokens) + } +} + +func TestRequestBuilder_ForCodeGeneration(t *testing.T) { + request := ForCodeGeneration(). + WithUserPrompt("Generate code"). + MustBuild() + + if request.Temperature == nil || *request.Temperature != 0.2 { + t.Error("ForCodeGeneration should set low temperature") + } + + if request.TopP == nil || *request.TopP != 0.1 { + t.Error("ForCodeGeneration should set low topP") + } +} + +func TestRequestBuilder_ForCreativeWriting(t *testing.T) { + request := ForCreativeWriting(). + WithUserPrompt("Write a story"). + MustBuild() + + if request.Temperature == nil || *request.Temperature != 1.2 { + t.Error("ForCreativeWriting should set high temperature") + } + + if request.PresencePenalty == nil || *request.PresencePenalty != 0.6 { + t.Error("ForCreativeWriting should set presence penalty") + } + + if request.FrequencyPenalty == nil || *request.FrequencyPenalty != 0.5 { + t.Error("ForCreativeWriting should set frequency penalty") + } +} + +// ============================================================ +// 测试 CallWithRequest 集成 +// ============================================================ + +func TestClient_CallWithRequest_Success(t *testing.T) { + mockHTTP := NewMockHTTPClient() + mockHTTP.SetSuccessResponse("Builder response") + mockLogger := NewMockLogger() + + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(mockLogger), + WithAPIKey("sk-test-key"), + ) + + request := NewRequestBuilder(). + WithSystemPrompt("You are helpful"). + WithUserPrompt("Hello"). + WithTemperature(0.8). + MustBuild() + + result, err := client.CallWithRequest(request) + + if err != nil { + t.Fatalf("should not error: %v", err) + } + + if result != "Builder response" { + t.Errorf("expected 'Builder response', got '%s'", result) + } + + // 验证请求体 + requests := mockHTTP.GetRequests() + if len(requests) != 1 { + t.Fatalf("expected 1 request, got %d", len(requests)) + } + + // 解析请求体验证参数 + var body map[string]interface{} + decoder := json.NewDecoder(requests[0].Body) + if err := decoder.Decode(&body); err != nil { + t.Fatalf("failed to decode request body: %v", err) + } + + // 验证 temperature + if body["temperature"] != 0.8 { + t.Errorf("expected temperature 0.8, got %v", body["temperature"]) + } + + // 验证 messages + messages, ok := body["messages"].([]interface{}) + if !ok || len(messages) != 2 { + t.Error("messages not correctly formatted") + } +} + +func TestClient_CallWithRequest_MultiRound(t *testing.T) { + mockHTTP := NewMockHTTPClient() + mockHTTP.SetSuccessResponse("Multi-round response") + mockLogger := NewMockLogger() + + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(mockLogger), + WithAPIKey("sk-test-key"), + ) + + // 构建多轮对话 + request := NewRequestBuilder(). + AddSystemMessage("You are a trading advisor"). + AddUserMessage("Analyze BTC"). + AddAssistantMessage("BTC is bullish"). + AddUserMessage("What about entry point?"). + WithTemperature(0.3). + MustBuild() + + result, err := client.CallWithRequest(request) + + if err != nil { + t.Fatalf("should not error: %v", err) + } + + if result != "Multi-round response" { + t.Errorf("expected 'Multi-round response', got '%s'", result) + } + + // 验证请求体包含所有消息 + requests := mockHTTP.GetRequests() + var body map[string]interface{} + json.NewDecoder(requests[0].Body).Decode(&body) + + messages := body["messages"].([]interface{}) + if len(messages) != 4 { + t.Errorf("expected 4 messages in request, got %d", len(messages)) + } +} + +func TestClient_CallWithRequest_WithTools(t *testing.T) { + mockHTTP := NewMockHTTPClient() + mockHTTP.SetSuccessResponse("Tool response") + mockLogger := NewMockLogger() + + client := NewClient( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(mockLogger), + WithAPIKey("sk-test-key"), + ) + + request := NewRequestBuilder(). + WithUserPrompt("What's the weather in Beijing?"). + AddFunction("get_weather", "Get weather", map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, + }, + }). + WithToolChoice("auto"). + MustBuild() + + _, err := client.CallWithRequest(request) + + if err != nil { + t.Fatalf("should not error: %v", err) + } + + // 验证请求体包含 tools + requests := mockHTTP.GetRequests() + var body map[string]interface{} + json.NewDecoder(requests[0].Body).Decode(&body) + + tools, ok := body["tools"].([]interface{}) + if !ok || len(tools) == 0 { + t.Error("tools should be present in request") + } + + toolChoice, ok := body["tool_choice"].(string) + if !ok || toolChoice != "auto" { + t.Error("tool_choice should be 'auto'") + } +} + +func TestClient_CallWithRequest_NoAPIKey(t *testing.T) { + client := NewClient() + + request := NewRequestBuilder(). + WithUserPrompt("Hello"). + MustBuild() + + _, err := client.CallWithRequest(request) + + if err == nil { + t.Error("should error when API key not set") + } + + if err.Error() != "AI API密钥未设置,请先调用 SetAPIKey" { + t.Errorf("unexpected error: %v", err) + } +} + +func TestClient_CallWithRequest_UsesClientModel(t *testing.T) { + mockHTTP := NewMockHTTPClient() + mockHTTP.SetSuccessResponse("Response") + mockLogger := NewMockLogger() + + client := NewDeepSeekClientWithOptions( + WithHTTPClient(mockHTTP.ToHTTPClient()), + WithLogger(mockLogger), + WithAPIKey("sk-test-key"), + ) + + // Request 不设置 model,应该使用 Client 的 model + request := NewRequestBuilder(). + WithUserPrompt("Hello"). + MustBuild() + + if request.Model != "" { + t.Error("request.Model should be empty initially") + } + + client.CallWithRequest(request) + + // 验证使用了 DeepSeek 的 model + requests := mockHTTP.GetRequests() + var body map[string]interface{} + json.NewDecoder(requests[0].Body).Decode(&body) + + if body["model"] != DefaultDeepSeekModel { + t.Errorf("expected model %s, got %v", DefaultDeepSeekModel, body["model"]) + } +}