refactor: standardize code comments

This commit is contained in:
tinkle-community
2025-12-08 01:40:48 +08:00
parent 0636ced476
commit a12c0ae8c9
103 changed files with 5466 additions and 5468 deletions
+99 -99
View File
@@ -28,65 +28,65 @@ var (
"connection refused",
"temporary failure",
"no such host",
"stream error", // HTTP/2 stream 错误
"INTERNAL_ERROR", // 服务端内部错误
"stream error", // HTTP/2 stream error
"INTERNAL_ERROR", // Server internal error
}
)
// Client AI API配置
// Client AI API configuration
type Client struct {
Provider string
APIKey string
BaseURL string
Model string
UseFullURL bool // 是否使用完整URL(不添加/chat/completions
MaxTokens int // AI响应的最大token数
UseFullURL bool // Whether to use full URL (without appending /chat/completions)
MaxTokens int // Maximum tokens for AI response
httpClient *http.Client
logger Logger // 日志器(可替换)
config *Config // 配置对象(保存所有配置)
logger Logger // Logger (replaceable)
config *Config // Config object (stores all configurations)
// hooks 用于实现动态分派(多态)
// DeepSeekClient 嵌入 Client 时,hooks 指向 DeepSeekClient
// 这样 call() 中调用的方法会自动分派到子类重写的版本
// hooks are used to implement dynamic dispatch (polymorphism)
// When DeepSeekClient embeds Client, hooks point to DeepSeekClient
// This way methods called in call() are automatically dispatched to the overridden version in subclass
hooks clientHooks
}
// New 创建默认客户端(向前兼容)
// New creates default client (backward compatible)
//
// Deprecated: 推荐使用 NewClient(...opts) 以获得更好的灵活性
// Deprecated: Recommend using NewClient(...opts) for better flexibility
func New() AIClient {
return NewClient()
}
// NewClient 创建客户端(支持选项模式)
// NewClient creates client (supports options pattern)
//
// 使用示例:
// // 基础用法(向前兼容)
// Usage examples:
// // Basic usage (backward compatible)
// client := mcp.NewClient()
//
// // 自定义日志
// // Custom logger
// client := mcp.NewClient(mcp.WithLogger(customLogger))
//
// // 自定义超时
// // Custom timeout
// client := mcp.NewClient(mcp.WithTimeout(60*time.Second))
//
// // 组合多个选项
// // Combine multiple options
// client := mcp.NewClient(
// mcp.WithDeepSeekConfig("sk-xxx"),
// mcp.WithLogger(customLogger),
// mcp.WithTimeout(60*time.Second),
// )
func NewClient(opts ...ClientOption) AIClient {
// 1. 创建默认配置
// 1. Create default config
cfg := DefaultConfig()
// 2. 应用用户选项
// 2. Apply user options
for _, opt := range opts {
opt(cfg)
}
// 3. 创建客户端实例
// 3. Create client instance
client := &Client{
Provider: cfg.Provider,
APIKey: cfg.APIKey,
@@ -99,25 +99,25 @@ func NewClient(opts ...ClientOption) AIClient {
config: cfg,
}
// 4. 设置默认 Provider(如果未设置)
// 4. Set default Provider (if not set)
if client.Provider == "" {
client.Provider = ProviderDeepSeek
client.BaseURL = DefaultDeepSeekBaseURL
client.Model = DefaultDeepSeekModel
}
// 5. 设置 hooks 指向自己
// 5. Set hooks to point to self
client.hooks = client
return client
}
// SetCustomAPI 设置自定义OpenAI兼容API
// SetCustomAPI sets custom OpenAI-compatible API
func (client *Client) SetAPIKey(apiKey, apiURL, customModel string) {
client.Provider = ProviderCustom
client.APIKey = apiKey
// 检查URL是否以#结尾,如果是则使用完整URL(不添加/chat/completions
// Check if URL ends with #, if so use full URL (without appending /chat/completions)
if strings.HasSuffix(apiURL, "#") {
client.BaseURL = strings.TrimSuffix(apiURL, "#")
client.UseFullURL = true
@@ -133,45 +133,45 @@ func (client *Client) SetTimeout(timeout time.Duration) {
client.httpClient.Timeout = timeout
}
// CallWithMessages 模板方法 - 固定的重试流程(不可重写)
// CallWithMessages template method - fixed retry flow (cannot be overridden)
func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, error) {
if client.APIKey == "" {
return "", fmt.Errorf("AI API密钥未设置,请先调用 SetAPIKey")
return "", fmt.Errorf("AI API key not set, please call SetAPIKey first")
}
// 固定的重试流程
// Fixed retry flow
var lastErr error
maxRetries := client.config.MaxRetries
for attempt := 1; attempt <= maxRetries; attempt++ {
if attempt > 1 {
client.logger.Warnf("⚠️ AI API调用失败,正在重试 (%d/%d)...", attempt, maxRetries)
client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
}
// 调用固定的单次调用流程
// Call the fixed single-call flow
result, err := client.hooks.call(systemPrompt, userPrompt)
if err == nil {
if attempt > 1 {
client.logger.Infof("✓ AI API重试成功")
client.logger.Infof("✓ AI API retry succeeded")
}
return result, nil
}
lastErr = err
// 通过 hooks 判断是否可重试(支持子类自定义重试策略)
// Check if error is retryable via hooks (supports custom retry strategy in subclass)
if !client.hooks.isRetryableError(err) {
return "", err
}
// 重试前等待
// Wait before retry
if attempt < maxRetries {
waitTime := client.config.RetryWaitBase * time.Duration(attempt)
client.logger.Infof("⏳ 等待%v后重试...", waitTime)
client.logger.Infof("⏳ Waiting %v before retry...", waitTime)
time.Sleep(waitTime)
}
}
return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
}
func (client *Client) setAuthHeader(reqHeader http.Header) {
@@ -179,27 +179,27 @@ func (client *Client) setAuthHeader(reqHeader http.Header) {
}
func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
// 构建 messages 数组
// Build messages array
messages := []map[string]string{}
// 如果有 system prompt,添加 system message
// If system prompt exists, add system message
if systemPrompt != "" {
messages = append(messages, map[string]string{
"role": "system",
"content": systemPrompt,
})
}
// 添加 user message
// Add user message
messages = append(messages, map[string]string{
"role": "user",
"content": userPrompt,
})
// 构建请求体
// Build request body
requestBody := map[string]interface{}{
"model": client.Model,
"messages": messages,
"temperature": client.config.Temperature, // 使用配置的 temperature
"temperature": client.config.Temperature, // Use configured temperature
"max_tokens": client.MaxTokens,
}
return requestBody
@@ -209,7 +209,7 @@ func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[s
func (client *Client) marshalRequestBody(requestBody map[string]any) ([]byte, error) {
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
return nil, fmt.Errorf("failed to serialize request: %w", err)
}
return jsonData, nil
}
@@ -224,11 +224,11 @@ func (client *Client) parseMCPResponse(body []byte) (string, error) {
}
if err := json.Unmarshal(body, &result); err != nil {
return "", fmt.Errorf("解析响应失败: %w", err)
return "", fmt.Errorf("failed to parse response: %w", err)
}
if len(result.Choices) == 0 {
return "", fmt.Errorf("API返回空响应")
return "", fmt.Errorf("API returned empty response")
}
return result.Choices[0].Message.Content, nil
@@ -250,59 +250,59 @@ func (client *Client) buildRequest(url string, jsonData []byte) (*http.Request,
req.Header.Set("Content-Type", "application/json")
// 通过 hooks 设置认证头(支持子类重写)
// Set auth header via hooks (supports overriding in subclass)
client.hooks.setAuthHeader(req.Header)
return req, nil
}
// call 单次调用AI API(固定流程,不可重写)
// call single AI API call (fixed flow, cannot be overridden)
func (client *Client) call(systemPrompt, userPrompt string) (string, error) {
// 打印当前 AI 配置
// Print current AI configuration
client.logger.Infof("📡 [%s] Request AI Server: BaseURL: %s", client.String(), client.BaseURL)
client.logger.Debugf("[%s] UseFullURL: %v", client.String(), client.UseFullURL)
if len(client.APIKey) > 8 {
client.logger.Debugf("[%s] API Key: %s...%s", client.String(), client.APIKey[:4], client.APIKey[len(client.APIKey)-4:])
}
// Step 1: 构建请求体(通过 hooks 实现动态分派)
// Step 1: Build request body (via hooks for dynamic dispatch)
requestBody := client.hooks.buildMCPRequestBody(systemPrompt, userPrompt)
// Step 2: 序列化请求体(通过 hooks 实现动态分派)
// Step 2: Serialize request body (via hooks for dynamic dispatch)
jsonData, err := client.hooks.marshalRequestBody(requestBody)
if err != nil {
return "", err
}
// Step 3: 构建 URL(通过 hooks 实现动态分派)
// Step 3: Build URL (via hooks for dynamic dispatch)
url := client.hooks.buildUrl()
client.logger.Infof("📡 [MCP %s] 请求 URL: %s", client.String(), url)
client.logger.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
// Step 4: 创建 HTTP 请求(固定逻辑)
// Step 4: Create HTTP request (fixed logic)
req, err := client.hooks.buildRequest(url, jsonData)
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
return "", fmt.Errorf("failed to create request: %w", err)
}
// Step 5: 发送 HTTP 请求(固定逻辑)
// Step 5: Send HTTP request (fixed logic)
resp, err := client.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("发送请求失败: %w", err)
return "", fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
// Step 6: 读取响应体(固定逻辑)
// Step 6: Read response body (fixed logic)
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("读取响应失败: %w", err)
return "", fmt.Errorf("failed to read response: %w", err)
}
// Step 7: 检查 HTTP 状态码(固定逻辑)
// Step 7: Check HTTP status code (fixed logic)
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("API返回错误 (status %d): %s", resp.StatusCode, string(body))
return "", fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body))
}
// Step 8: 解析响应(通过 hooks 实现动态分派)
// Step 8: Parse response (via hooks for dynamic dispatch)
result, err := client.hooks.parseMCPResponse(body)
if err != nil {
return "", fmt.Errorf("fail to parse AI server response: %w", err)
@@ -316,10 +316,10 @@ func (client *Client) String() string {
client.Provider, client.Model)
}
// isRetryableError 判断错误是否可重试(网络错误、超时等)
// isRetryableError determines if error is retryable (network errors, timeouts, etc.)
func (client *Client) isRetryableError(err error) bool {
errStr := err.Error()
// 网络错误、超时、EOF等可以重试
// Network errors, timeouts, EOF, etc. can be retried
for _, retryable := range client.config.RetryableErrors {
if strings.Contains(errStr, retryable) {
return true
@@ -329,18 +329,18 @@ func (client *Client) isRetryableError(err error) bool {
}
// ============================================================
// 构建器模式 API(高级功能)
// Builder Pattern API (Advanced Features)
// ============================================================
// CallWithRequest 使用 Request 对象调用 AI API(支持高级功能)
// CallWithRequest calls AI API using Request object (supports advanced features)
//
// 此方法支持:
// - 多轮对话历史
// - 精细参数控制(temperaturetop_ppenalties 等)
// This method supports:
// - Multi-turn conversation history
// - Fine-grained parameter control (temperature, top_p, penalties, etc.)
// - Function Calling / Tools
// - 流式响应(未来支持)
// - Streaming response (future support)
//
// 使用示例:
// Usage example:
// request := NewRequestBuilder().
// WithSystemPrompt("You are helpful").
// WithUserPrompt("Hello").
@@ -349,93 +349,93 @@ func (client *Client) isRetryableError(err error) bool {
// result, err := client.CallWithRequest(request)
func (client *Client) CallWithRequest(req *Request) (string, error) {
if client.APIKey == "" {
return "", fmt.Errorf("AI API密钥未设置,请先调用 SetAPIKey")
return "", fmt.Errorf("AI API key not set, please call SetAPIKey first")
}
// 如果 Request 中没有设置 Model,使用 Client Model
// If Model is not set in Request, use Client's Model
if req.Model == "" {
req.Model = client.Model
}
// 固定的重试流程
// Fixed retry flow
var lastErr error
maxRetries := client.config.MaxRetries
for attempt := 1; attempt <= maxRetries; attempt++ {
if attempt > 1 {
client.logger.Warnf("⚠️ AI API调用失败,正在重试 (%d/%d)...", attempt, maxRetries)
client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
}
// 调用单次请求
// Call single request
result, err := client.callWithRequest(req)
if err == nil {
if attempt > 1 {
client.logger.Infof("✓ AI API重试成功")
client.logger.Infof("✓ AI API retry succeeded")
}
return result, nil
}
lastErr = err
// 判断是否可重试
// Check if error is retryable
if !client.hooks.isRetryableError(err) {
return "", err
}
// 重试前等待
// Wait before retry
if attempt < maxRetries {
waitTime := client.config.RetryWaitBase * time.Duration(attempt)
client.logger.Infof("⏳ 等待%v后重试...", waitTime)
client.logger.Infof("⏳ Waiting %v before retry...", waitTime)
time.Sleep(waitTime)
}
}
return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
}
// callWithRequest 单次调用 AI API(使用 Request 对象)
// callWithRequest single AI API call (using Request object)
func (client *Client) callWithRequest(req *Request) (string, error) {
// 打印当前 AI 配置
// Print current AI configuration
client.logger.Infof("📡 [%s] Request AI Server with Builder: BaseURL: %s", client.String(), client.BaseURL)
client.logger.Debugf("[%s] Messages count: %d", client.String(), len(req.Messages))
// 构建请求体(从 Request 对象)
// Build request body (from Request object)
requestBody := client.buildRequestBodyFromRequest(req)
// 序列化请求体
// Serialize request body
jsonData, err := client.hooks.marshalRequestBody(requestBody)
if err != nil {
return "", err
}
// 构建 URL
// Build URL
url := client.hooks.buildUrl()
client.logger.Infof("📡 [MCP %s] 请求 URL: %s", client.String(), url)
client.logger.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
// 创建 HTTP 请求
// Create HTTP request
httpReq, err := client.hooks.buildRequest(url, jsonData)
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
return "", fmt.Errorf("failed to create request: %w", err)
}
// 发送 HTTP 请求
// Send HTTP request
resp, err := client.httpClient.Do(httpReq)
if err != nil {
return "", fmt.Errorf("发送请求失败: %w", err)
return "", fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
// 读取响应体
// Read response body
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("读取响应失败: %w", err)
return "", fmt.Errorf("failed to read response: %w", err)
}
// 检查 HTTP 状态码
// Check HTTP status code
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("API返回错误 (status %d): %s", resp.StatusCode, string(body))
return "", fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body))
}
// 解析响应
// Parse response
result, err := client.hooks.parseMCPResponse(body)
if err != nil {
return "", fmt.Errorf("fail to parse AI server response: %w", err)
@@ -444,9 +444,9 @@ func (client *Client) callWithRequest(req *Request) (string, error) {
return result, nil
}
// buildRequestBodyFromRequest Request 对象构建请求体
// buildRequestBodyFromRequest builds request body from Request object
func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any {
// 转换 Message API 格式
// Convert Message to API format
messages := make([]map[string]string, 0, len(req.Messages))
for _, msg := range req.Messages {
messages = append(messages, map[string]string{
@@ -455,24 +455,24 @@ func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any {
})
}
// 构建基础请求体
// Build basic request body
requestBody := map[string]interface{}{
"model": req.Model,
"messages": messages,
}
// 添加可选参数(只添加非 nil 的参数)
// Add optional parameters (only add non-nil parameters)
if req.Temperature != nil {
requestBody["temperature"] = *req.Temperature
} else {
// 如果 Request 中没有设置,使用 Client 的配置
// If not set in Request, use Client's configuration
requestBody["temperature"] = client.config.Temperature
}
if req.MaxTokens != nil {
requestBody["max_tokens"] = *req.MaxTokens
} else {
// 如果 Request 中没有设置,使用 Client MaxTokens
// If not set in Request, use Client's MaxTokens
requestBody["max_tokens"] = client.MaxTokens
}
+21 -21
View File
@@ -8,7 +8,7 @@ import (
)
// ============================================================
// 测试 Client 创建和配置
// Test Client Creation and Configuration
// ============================================================
func TestNewClient_Default(t *testing.T) {
@@ -72,7 +72,7 @@ func TestNewClient_WithOptions(t *testing.T) {
}
// ============================================================
// 测试 CallWithMessages
// Test CallWithMessages
// ============================================================
func TestClient_CallWithMessages_Success(t *testing.T) {
@@ -97,7 +97,7 @@ func TestClient_CallWithMessages_Success(t *testing.T) {
t.Errorf("expected 'AI response content', got '%s'", result)
}
// 验证请求
// Verify request
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Errorf("expected 1 request, got %d", len(requests))
@@ -123,7 +123,7 @@ func TestClient_CallWithMessages_NoAPIKey(t *testing.T) {
t.Error("should error when API key is not set")
}
if err.Error() != "AI API密钥未设置,请先调用 SetAPIKey" {
if err.Error() != "AI API key not set, please call SetAPIKey first" {
t.Errorf("unexpected error message: %v", err)
}
}
@@ -147,14 +147,14 @@ func TestClient_CallWithMessages_HTTPError(t *testing.T) {
}
// ============================================================
// 测试重试逻辑
// Test Retry Logic
// ============================================================
func TestClient_Retry_Success(t *testing.T) {
mockHTTP := NewMockHTTPClient()
mockLogger := NewMockLogger()
// 模拟:第一次失败,第二次成功
// Simulate: first call fails, second call succeeds
callCount := 0
mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) {
callCount++
@@ -174,40 +174,40 @@ func TestClient_Retry_Success(t *testing.T) {
WithMaxRetries(3),
)
// 由于我们的 client 使用 hooks.call,需要特殊处理
// 这里我们测试的是 CallWithMessages 会调用 retry 逻辑
// Since our client uses hooks.call, need special handling
// Here we test that CallWithMessages will invoke retry logic
c := client.(*Client)
// 临时修改重试等待时间为 0 以加速测试
// Temporarily modify retry wait time to 0 to speed up test
oldRetries := MaxRetryTimes
MaxRetryTimes = 3
defer func() { MaxRetryTimes = oldRetries }()
_, err := c.CallWithMessages("system", "user")
// 第一次失败(connection reset),第二次成功,但是响应格式不对,会失败
// 但至少验证了重试逻辑被触发
// First fails (connection reset), second succeeds, but response format is wrong, will fail
// But at least verify retry logic was triggered
if callCount < 2 {
t.Errorf("should retry, got %d calls", callCount)
}
// 检查日志中是否有重试信息
// Check if there's retry information in logs
logs := mockLogger.GetLogsByLevel("WARN")
hasRetryLog := false
for _, log := range logs {
if log.Message == "⚠️ AI API调用失败,正在重试 (2/3)..." {
if log.Message == "⚠️ AI API call failed, retrying (2/3)..." {
hasRetryLog = true
break
}
}
if !hasRetryLog && callCount >= 2 {
// 如果确实重试了,应该有警告日志
// 但由于我们的测试设置,可能不会触发,所以这里只是检查
// If retry was indeed attempted, there should be warning logs
// But due to our test setup, it may not trigger, so just check here
t.Log("Retry was attempted")
}
_ = err // 忽略错误,我们主要测试重试逻辑被触发
_ = err // Ignore error, we mainly test retry logic was triggered
}
func TestClient_Retry_NonRetryableError(t *testing.T) {
@@ -227,7 +227,7 @@ func TestClient_Retry_NonRetryableError(t *testing.T) {
t.Error("should error")
}
// 验证没有重试(因为 400 不是可重试错误)
// Verify no retry (because 400 is not a retryable error)
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Errorf("should not retry for 400 error, got %d requests", len(requests))
@@ -235,7 +235,7 @@ func TestClient_Retry_NonRetryableError(t *testing.T) {
}
// ============================================================
// 测试钩子方法
// Test Hook Methods
// ============================================================
func TestClient_BuildMCPRequestBody(t *testing.T) {
@@ -368,7 +368,7 @@ func TestClient_IsRetryableError(t *testing.T) {
}
// ============================================================
// 测试 SetTimeout
// Test SetTimeout
// ============================================================
func TestClient_SetTimeout(t *testing.T) {
@@ -384,7 +384,7 @@ func TestClient_SetTimeout(t *testing.T) {
}
// ============================================================
// 测试 String 方法
// Test String Method
// ============================================================
func TestClient_String(t *testing.T) {
@@ -404,7 +404,7 @@ func TestClient_String(t *testing.T) {
}
}
// 辅助函数
// Helper function
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr))
}
+11 -11
View File
@@ -9,36 +9,36 @@ import (
"nofx/logger"
)
// Config 客户端配置(集中管理所有配置)
// Config client configuration (centralized management of all configurations)
type Config struct {
// Provider 配置
// Provider configuration
Provider string
APIKey string
BaseURL string
Model string
// 行为配置
// Behavior configuration
MaxTokens int
Temperature float64
UseFullURL bool
// 重试配置
// Retry configuration
MaxRetries int
RetryWaitBase time.Duration
RetryableErrors []string
// 超时配置
// Timeout configuration
Timeout time.Duration
// 依赖注入
// Dependency injection
Logger Logger
HTTPClient *http.Client
}
// DefaultConfig 返回默认配置
// DefaultConfig returns default configuration
func DefaultConfig() *Config {
return &Config{
// 默认值
// Default values
MaxTokens: getEnvInt("AI_MAX_TOKENS", 2000),
Temperature: MCPClientTemperature,
MaxRetries: MaxRetryTimes,
@@ -46,13 +46,13 @@ func DefaultConfig() *Config {
Timeout: DefaultTimeout,
RetryableErrors: retryableErrors,
// 默认依赖(使用全局 logger
// Default dependencies (use global logger)
Logger: logger.NewMCPLogger(),
HTTPClient: &http.Client{Timeout: DefaultTimeout},
}
}
// getEnvInt 从环境变量读取整数,失败则返回默认值
// getEnvInt reads integer from environment variable, returns default value if failed
func getEnvInt(key string, defaultValue int) int {
if val := os.Getenv(key); val != "" {
if parsed, err := strconv.Atoi(val); err == nil && parsed > 0 {
@@ -62,7 +62,7 @@ func getEnvInt(key string, defaultValue int) int {
return defaultValue
}
// getEnvString 从环境变量读取字符串,为空则返回默认值
// getEnvString reads string from environment variable, returns default value if empty
func getEnvString(key string, defaultValue string) string {
if val := os.Getenv(key); val != "" {
return val
+38 -38
View File
@@ -11,49 +11,49 @@ import (
)
// ============================================================
// 测试 Config 字段真正被使用(验证问题2修复)
// Test Config Fields Are Actually Used (Verify Issue 2 Fix)
// ============================================================
func TestConfig_MaxRetries_IsUsed(t *testing.T) {
mockHTTP := NewMockHTTPClient()
mockLogger := NewMockLogger()
// 设置 HTTP 客户端返回错误
// Set HTTP client to return error
callCount := 0
mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) {
callCount++
return nil, errors.New("connection reset")
}
// 创建客户端并设置自定义重试次数为 5
// Create client and set custom retry count to 5
client := NewClient(
WithHTTPClient(mockHTTP.ToHTTPClient()),
WithLogger(mockLogger),
WithAPIKey("sk-test-key"),
WithMaxRetries(5), // ✅ 设置重试5次
WithMaxRetries(5), // Set to retry 5 times
)
// 调用 API(应该失败)
// Call API (should fail)
_, err := client.CallWithMessages("system", "user")
if err == nil {
t.Error("should error")
}
// 验证确实重试了5次(而不是默认的3次)
// Verify indeed retried 5 times (not the default 3 times)
if callCount != 5 {
t.Errorf("expected 5 retry attempts (from WithMaxRetries(5)), got %d", callCount)
}
// 验证日志中显示正确的重试次数
// Verify logs show correct retry count
logs := mockLogger.GetLogsByLevel("WARN")
expectedWarningCount := 4 // 第2、3、4、5次重试时会打印警告
expectedWarningCount := 4 // Warnings will be printed on 2nd, 3rd, 4th, 5th retry
actualWarningCount := 0
for _, log := range logs {
if log.Message == "⚠️ AI API调用失败,正在重试 (2/5)..." ||
log.Message == "⚠️ AI API调用失败,正在重试 (3/5)..." ||
log.Message == "⚠️ AI API调用失败,正在重试 (4/5)..." ||
log.Message == "⚠️ AI API调用失败,正在重试 (5/5)..." {
if log.Message == "⚠️ AI API call failed, retrying (2/5)..." ||
log.Message == "⚠️ AI API call failed, retrying (3/5)..." ||
log.Message == "⚠️ AI API call failed, retrying (4/5)..." ||
log.Message == "⚠️ AI API call failed, retrying (5/5)..." {
actualWarningCount++
}
}
@@ -73,20 +73,20 @@ func TestConfig_Temperature_IsUsed(t *testing.T) {
customTemperature := 0.8
// 创建客户端并设置自定义 temperature
// Create client and set custom temperature
client := NewClient(
WithHTTPClient(mockHTTP.ToHTTPClient()),
WithLogger(mockLogger),
WithAPIKey("sk-test-key"),
WithTemperature(customTemperature), // ✅ 设置自定义 temperature
WithTemperature(customTemperature), // Set custom temperature
)
c := client.(*Client)
// 构建请求体
// Build request body
requestBody := c.buildMCPRequestBody("system", "user")
// 验证 temperature 字段
// Verify temperature field
temp, ok := requestBody["temperature"].(float64)
if !ok {
t.Fatal("temperature should be float64")
@@ -96,26 +96,26 @@ func TestConfig_Temperature_IsUsed(t *testing.T) {
t.Errorf("expected temperature %f (from WithTemperature), got %f", customTemperature, temp)
}
// 也可以通过实际 HTTP 请求验证
// Can also verify through actual HTTP request
_, err := client.CallWithMessages("system", "user")
if err != nil {
t.Fatalf("should not error: %v", err)
}
// 检查发送的请求体
// Check sent request body
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Fatalf("expected 1 request, got %d", len(requests))
}
// 解析请求体
// Parse request body
var body map[string]interface{}
decoder := json.NewDecoder(requests[0].Body)
if err := decoder.Decode(&body); err != nil {
t.Fatalf("failed to decode request body: %v", err)
}
// 验证 temperature
// Verify temperature
if body["temperature"] != customTemperature {
t.Errorf("expected temperature %f in HTTP request, got %v", customTemperature, body["temperature"])
}
@@ -125,18 +125,18 @@ func TestConfig_RetryWaitBase_IsUsed(t *testing.T) {
mockHTTP := NewMockHTTPClient()
mockLogger := NewMockLogger()
// 设置成功响应(在 ResponseFunc 之前)
// Set success response (before ResponseFunc)
mockHTTP.SetSuccessResponse("AI response")
// 设置 HTTP 客户端前2次返回错误,第3次成功
// Set HTTP client to return error first 2 times, success on 3rd time
callCount := 0
successResponse := mockHTTP.Response // 保存成功响应字符串
successResponse := mockHTTP.Response // Save success response string
mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) {
callCount++
if callCount <= 2 {
return nil, errors.New("timeout exceeded")
}
// 第3次返回成功响应
// 3rd time return success response
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewBufferString(successResponse)),
@@ -144,27 +144,27 @@ func TestConfig_RetryWaitBase_IsUsed(t *testing.T) {
}, nil
}
// 设置自定义重试等待基数为 1 秒(而不是默认的 2 秒)
// Set custom retry wait base to 1 second (instead of default 2 seconds)
customWaitBase := 1 * time.Second
client := NewClient(
WithHTTPClient(mockHTTP.ToHTTPClient()),
WithLogger(mockLogger),
WithAPIKey("sk-test-key"),
WithRetryWaitBase(customWaitBase), // ✅ 设置自定义等待时间
WithRetryWaitBase(customWaitBase), // Set custom wait time
WithMaxRetries(3),
)
// 记录开始时间
// Record start time
start := time.Now()
// 调用 API
// Call API
_, err := client.CallWithMessages("system", "user")
// 记录结束时间
// Record end time
elapsed := time.Since(start)
// 第3次成功,但前面失败了2次
// 3rd time succeeds, but failed 2 times before
if err != nil {
t.Fatalf("should succeed on 3rd attempt, got error: %v", err)
}
@@ -173,10 +173,10 @@ func TestConfig_RetryWaitBase_IsUsed(t *testing.T) {
t.Errorf("expected 3 attempts, got %d", callCount)
}
// 验证等待时间
// 第1次失败后等待 1s (customWaitBase * 1)
// 第2次失败后等待 2s (customWaitBase * 2)
// 总等待时间应该约为 3s (允许一些误差)
// Verify wait time
// After 1st failure wait 1s (customWaitBase * 1)
// After 2nd failure wait 2s (customWaitBase * 2)
// Total wait time should be about 3s (allow some error)
expectedWait := 3 * time.Second
tolerance := 200 * time.Millisecond
@@ -189,7 +189,7 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) {
mockHTTP := NewMockHTTPClient()
mockLogger := NewMockLogger()
// 自定义可重试错误列表(只包含 "custom error"
// Custom retryable error list (only contains "custom error")
customRetryableErrors := []string{"custom error"}
client := NewClient(
@@ -200,7 +200,7 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) {
c := client.(*Client)
// 修改 config RetryableErrors(暂时没有 WithRetryableErrors 选项)
// Modify config's RetryableErrors (no WithRetryableErrors option yet)
c.config.RetryableErrors = customRetryableErrors
tests := []struct {
@@ -236,14 +236,14 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) {
}
// ============================================================
// 测试默认值
// Test Default Values
// ============================================================
func TestConfig_DefaultValues(t *testing.T) {
client := NewClient()
c := client.(*Client)
// 验证默认值
// Verify default values
if c.config.MaxRetries != 3 {
t.Errorf("default MaxRetries should be 3, got %d", c.config.MaxRetries)
}
+15 -15
View File
@@ -14,45 +14,45 @@ type DeepSeekClient struct {
*Client
}
// NewDeepSeekClient 创建 DeepSeek 客户端(向前兼容)
// NewDeepSeekClient creates DeepSeek client (backward compatible)
//
// Deprecated: 推荐使用 NewDeepSeekClientWithOptions 以获得更好的灵活性
// Deprecated: Recommend using NewDeepSeekClientWithOptions for better flexibility
func NewDeepSeekClient() AIClient {
return NewDeepSeekClientWithOptions()
}
// NewDeepSeekClientWithOptions 创建 DeepSeek 客户端(支持选项模式)
// NewDeepSeekClientWithOptions creates DeepSeek client (supports options pattern)
//
// 使用示例:
// // 基础用法
// Usage examples:
// // Basic usage
// client := mcp.NewDeepSeekClientWithOptions()
//
// // 自定义配置
// // Custom configuration
// client := mcp.NewDeepSeekClientWithOptions(
// mcp.WithAPIKey("sk-xxx"),
// mcp.WithLogger(customLogger),
// mcp.WithTimeout(60*time.Second),
// )
func NewDeepSeekClientWithOptions(opts ...ClientOption) AIClient {
// 1. 创建 DeepSeek 预设选项
// 1. Create DeepSeek preset options
deepseekOpts := []ClientOption{
WithProvider(ProviderDeepSeek),
WithModel(DefaultDeepSeekModel),
WithBaseURL(DefaultDeepSeekBaseURL),
}
// 2. 合并用户选项(用户选项优先级更高)
// 2. Merge user options (user options have higher priority)
allOpts := append(deepseekOpts, opts...)
// 3. 创建基础客户端
// 3. Create base client
baseClient := NewClient(allOpts...).(*Client)
// 4. 创建 DeepSeek 客户端
// 4. Create DeepSeek client
dsClient := &DeepSeekClient{
Client: baseClient,
}
// 5. 设置 hooks 指向 DeepSeekClient(实现动态分派)
// 5. Set hooks to point to DeepSeekClient (implement dynamic dispatch)
baseClient.hooks = dsClient
return dsClient
@@ -66,15 +66,15 @@ func (dsClient *DeepSeekClient) SetAPIKey(apiKey string, customURL string, custo
}
if customURL != "" {
dsClient.BaseURL = customURL
dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s", customURL)
dsClient.logger.Infof("🔧 [MCP] DeepSeek using custom BaseURL: %s", customURL)
} else {
dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用默认 BaseURL: %s", dsClient.BaseURL)
dsClient.logger.Infof("🔧 [MCP] DeepSeek using default BaseURL: %s", dsClient.BaseURL)
}
if customModel != "" {
dsClient.Model = customModel
dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用自定义 Model: %s", customModel)
dsClient.logger.Infof("🔧 [MCP] DeepSeek using custom Model: %s", customModel)
} else {
dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用默认 Model: %s", dsClient.Model)
dsClient.logger.Infof("🔧 [MCP] DeepSeek using default Model: %s", dsClient.Model)
}
}
+22 -22
View File
@@ -6,7 +6,7 @@ import (
)
// ============================================================
// 测试 DeepSeekClient 创建和配置
// Test DeepSeekClient Creation and Configuration
// ============================================================
func TestNewDeepSeekClient_Default(t *testing.T) {
@@ -16,13 +16,13 @@ func TestNewDeepSeekClient_Default(t *testing.T) {
t.Fatal("client should not be nil")
}
// 类型断言检查
// Type assertion check
dsClient, ok := client.(*DeepSeekClient)
if !ok {
t.Fatal("client should be *DeepSeekClient")
}
// 验证默认值
// Verify default values
if dsClient.Provider != ProviderDeepSeek {
t.Errorf("Provider should be '%s', got '%s'", ProviderDeepSeek, dsClient.Provider)
}
@@ -58,7 +58,7 @@ func TestNewDeepSeekClientWithOptions(t *testing.T) {
dsClient := client.(*DeepSeekClient)
// 验证自定义选项被应用
// Verify custom options are applied
if dsClient.logger != mockLogger {
t.Error("logger should be set from option")
}
@@ -75,7 +75,7 @@ func TestNewDeepSeekClientWithOptions(t *testing.T) {
t.Error("MaxTokens should be 4000")
}
// 验证 DeepSeek 默认值仍然保留
// Verify DeepSeek default values are retained
if dsClient.Provider != ProviderDeepSeek {
t.Errorf("Provider should still be '%s'", ProviderDeepSeek)
}
@@ -86,7 +86,7 @@ func TestNewDeepSeekClientWithOptions(t *testing.T) {
}
// ============================================================
// 测试 SetAPIKey
// Test SetAPIKey
// ============================================================
func TestDeepSeekClient_SetAPIKey(t *testing.T) {
@@ -97,20 +97,20 @@ func TestDeepSeekClient_SetAPIKey(t *testing.T) {
dsClient := client.(*DeepSeekClient)
// 测试设置 API Key(默认 URL Model
// Test setting API Key (default URL and Model)
dsClient.SetAPIKey("sk-test-key-12345678", "", "")
if dsClient.APIKey != "sk-test-key-12345678" {
t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", dsClient.APIKey)
}
// 验证日志记录
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
if len(logs) == 0 {
t.Error("should have logged API key setting")
}
// 验证 BaseURL Model 保持默认
// Verify BaseURL and Model remain default
if dsClient.BaseURL != DefaultDeepSeekBaseURL {
t.Error("BaseURL should remain default")
}
@@ -135,11 +135,11 @@ func TestDeepSeekClient_SetAPIKey_WithCustomURL(t *testing.T) {
t.Errorf("BaseURL should be '%s', got '%s'", customURL, dsClient.BaseURL)
}
// 验证日志记录
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
hasCustomURLLog := false
for _, log := range logs {
if log.Format == "🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s" {
if log.Format == "🔧 [MCP] DeepSeek using custom BaseURL: %s" {
hasCustomURLLog = true
break
}
@@ -165,11 +165,11 @@ func TestDeepSeekClient_SetAPIKey_WithCustomModel(t *testing.T) {
t.Errorf("Model should be '%s', got '%s'", customModel, dsClient.Model)
}
// 验证日志记录
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
hasCustomModelLog := false
for _, log := range logs {
if log.Format == "🔧 [MCP] DeepSeek 使用自定义 Model: %s" {
if log.Format == "🔧 [MCP] DeepSeek using custom Model: %s" {
hasCustomModelLog = true
break
}
@@ -181,7 +181,7 @@ func TestDeepSeekClient_SetAPIKey_WithCustomModel(t *testing.T) {
}
// ============================================================
// 测试集成功能
// Test Integration Features
// ============================================================
func TestDeepSeekClient_CallWithMessages_Success(t *testing.T) {
@@ -205,7 +205,7 @@ func TestDeepSeekClient_CallWithMessages_Success(t *testing.T) {
t.Errorf("expected 'DeepSeek AI response', got '%s'", result)
}
// 验证请求
// Verify request
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Fatalf("expected 1 request, got %d", len(requests))
@@ -213,19 +213,19 @@ func TestDeepSeekClient_CallWithMessages_Success(t *testing.T) {
req := requests[0]
// 验证 URL
// Verify URL
expectedURL := DefaultDeepSeekBaseURL + "/chat/completions"
if req.URL.String() != expectedURL {
t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String())
}
// 验证 Authorization header
// Verify Authorization header
authHeader := req.Header.Get("Authorization")
if authHeader != "Bearer sk-test-key" {
t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader)
}
// 验证 Content-Type
// Verify Content-Type
if req.Header.Get("Content-Type") != "application/json" {
t.Error("Content-Type should be application/json")
}
@@ -242,7 +242,7 @@ func TestDeepSeekClient_Timeout(t *testing.T) {
t.Errorf("expected timeout 30s, got %v", dsClient.httpClient.Timeout)
}
// 测试 SetTimeout
// Test SetTimeout
client.SetTimeout(60 * time.Second)
if dsClient.httpClient.Timeout != 60*time.Second {
@@ -251,19 +251,19 @@ func TestDeepSeekClient_Timeout(t *testing.T) {
}
// ============================================================
// 测试 hooks 机制
// Test hooks Mechanism
// ============================================================
func TestDeepSeekClient_HooksIntegration(t *testing.T) {
client := NewDeepSeekClientWithOptions()
dsClient := client.(*DeepSeekClient)
// 验证 hooks 指向 dsClient 自己(实现多态)
// Verify hooks point to dsClient itself (implements polymorphism)
if dsClient.hooks != dsClient {
t.Error("hooks should point to dsClient for polymorphism")
}
// 验证 buildUrl 使用 DeepSeek 配置
// Verify buildUrl uses DeepSeek configuration
url := dsClient.buildUrl()
expectedURL := DefaultDeepSeekBaseURL + "/chat/completions"
if url != expectedURL {
+44 -44
View File
@@ -9,21 +9,21 @@ import (
)
// ============================================================
// 示例 1: 基础用法(向前兼容)
// Example 1: Basic Usage (Backward Compatible)
// ============================================================
func Example_backward_compatible() {
// ✅ 旧代码继续工作,无需修改
// Old code continues to work without modification
client := mcp.New()
client.SetAPIKey("sk-xxx", "https://api.custom.com", "gpt-4")
// 使用
// Usage
result, _ := client.CallWithMessages("system prompt", "user prompt")
fmt.Println(result)
}
func Example_deepseek_backward_compatible() {
// DeepSeek 旧代码继续工作
// DeepSeek old code continues to work
client := mcp.NewDeepSeekClient()
client.SetAPIKey("sk-xxx", "", "")
@@ -32,19 +32,19 @@ func Example_deepseek_backward_compatible() {
}
// ============================================================
// 示例 2: 新的推荐用法(选项模式)
// Example 2: New Recommended Usage (Options Pattern)
// ============================================================
func Example_new_client_basic() {
// 使用默认配置
// Use default configuration
client := mcp.NewClient()
// 使用 DeepSeek
// Use DeepSeek
client = mcp.NewClient(
mcp.WithDeepSeekConfig("sk-xxx"),
)
// 使用 Qwen
// Use Qwen
client = mcp.NewClient(
mcp.WithQwenConfig("sk-xxx"),
)
@@ -53,7 +53,7 @@ func Example_new_client_basic() {
}
func Example_new_client_with_options() {
// 组合多个选项
// Combine multiple options
client := mcp.NewClient(
mcp.WithDeepSeekConfig("sk-xxx"),
mcp.WithTimeout(60*time.Second),
@@ -67,10 +67,10 @@ func Example_new_client_with_options() {
}
// ============================================================
// 示例 3: 自定义日志器
// Example 3: Custom Logger
// ============================================================
// CustomLogger 自定义日志器示例
// CustomLogger custom logger example
type CustomLogger struct{}
func (l *CustomLogger) Debugf(format string, args ...any) {
@@ -90,7 +90,7 @@ func (l *CustomLogger) Errorf(format string, args ...any) {
}
func Example_custom_logger() {
// 使用自定义日志器
// Use custom logger
customLogger := &CustomLogger{}
client := mcp.NewClient(
@@ -103,7 +103,7 @@ func Example_custom_logger() {
}
func Example_no_logger_for_testing() {
// 测试时禁用日志
// Disable logging during testing
client := mcp.NewClient(
mcp.WithLogger(mcp.NewNoopLogger()),
)
@@ -113,16 +113,16 @@ func Example_no_logger_for_testing() {
}
// ============================================================
// 示例 4: 自定义 HTTP 客户端
// Example 4: Custom HTTP Client
// ============================================================
func Example_custom_http_client() {
// 自定义 HTTP 客户端(添加代理、TLS等)
// Custom HTTP client (add proxy, TLS, etc.)
customHTTP := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
// 自定义 TLS、连接池等
// Custom TLS, connection pool, etc.
},
}
@@ -136,16 +136,16 @@ func Example_custom_http_client() {
}
// ============================================================
// 示例 5: DeepSeek 客户端(新 API
// Example 5: DeepSeek Client (New API)
// ============================================================
func Example_deepseek_new_api() {
// 基础用法
// Basic usage
client := mcp.NewDeepSeekClientWithOptions(
mcp.WithAPIKey("sk-xxx"),
)
// 高级用法
// Advanced usage
client = mcp.NewDeepSeekClientWithOptions(
mcp.WithAPIKey("sk-xxx"),
mcp.WithLogger(&CustomLogger{}),
@@ -158,16 +158,16 @@ func Example_deepseek_new_api() {
}
// ============================================================
// 示例 6: Qwen 客户端(新 API
// Example 6: Qwen Client (New API)
// ============================================================
func Example_qwen_new_api() {
// 基础用法
// Basic usage
client := mcp.NewQwenClientWithOptions(
mcp.WithAPIKey("sk-xxx"),
)
// 高级用法
// Advanced usage
client = mcp.NewQwenClientWithOptions(
mcp.WithAPIKey("sk-xxx"),
mcp.WithLogger(&CustomLogger{}),
@@ -179,18 +179,18 @@ func Example_qwen_new_api() {
}
// ============================================================
// 示例 7: 在 trader/auto_trader.go 中的迁移示例
// Example 7: Migration Example in trader/auto_trader.go
// ============================================================
func Example_trader_migration() {
// === 旧代码(继续工作)===
// Old code (continues to work)
oldStyleClient := func(apiKey, customURL, customModel string) mcp.AIClient {
client := mcp.NewDeepSeekClient()
client.SetAPIKey(apiKey, customURL, customModel)
return client
}
// === 新代码(推荐)===
// New code (recommended)
newStyleClient := func(apiKey, customURL, customModel string) mcp.AIClient {
opts := []mcp.ClientOption{
mcp.WithAPIKey(apiKey),
@@ -207,37 +207,37 @@ func Example_trader_migration() {
return mcp.NewDeepSeekClientWithOptions(opts...)
}
// 两种方式都能工作
// Both approaches work
_ = oldStyleClient("sk-xxx", "", "")
_ = newStyleClient("sk-xxx", "", "")
}
// ============================================================
// 示例 8: 测试场景
// Example 8: Testing Scenarios
// ============================================================
// MockHTTPClient Mock HTTP 客户端
// MockHTTPClient Mock HTTP client
type MockHTTPClient struct {
Response string
}
func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
// 返回预设的响应
// Return preset response
return &http.Response{
StatusCode: 200,
Body: nil, // 实际测试中需要实现
Body: nil, // Need to implement in actual tests
}, nil
}
func Example_testing_with_mock() {
// 测试时使用 Mock
// Use Mock during testing
// mockHTTP := &MockHTTPClient{
// Response: `{"choices":[{"message":{"content":"test response"}}]}`,
// }
client := mcp.NewClient(
// mcp.WithHTTPClient(mockHTTP), // 实际测试中使用 mockHTTP
mcp.WithLogger(mcp.NewNoopLogger()), // 禁用日志
// mcp.WithHTTPClient(mockHTTP), // Use mockHTTP in actual tests
mcp.WithLogger(mcp.NewNoopLogger()), // Disable logging
)
result, _ := client.CallWithMessages("system", "user")
@@ -245,20 +245,20 @@ func Example_testing_with_mock() {
}
// ============================================================
// 示例 9: 环境特定配置
// Example 9: Environment-Specific Configuration
// ============================================================
func Example_environment_specific() {
// 开发环境:详细日志
// Development environment: detailed logging
devClient := mcp.NewClient(
mcp.WithDeepSeekConfig("sk-xxx"),
mcp.WithLogger(&CustomLogger{}), // 详细日志
mcp.WithLogger(&CustomLogger{}), // Detailed logging
)
// 生产环境:结构化日志 + 超时保护
// Production environment: structured logging + timeout protection
prodClient := mcp.NewClient(
mcp.WithDeepSeekConfig("sk-xxx"),
// mcp.WithLogger(&ZapLogger{}), // 生产级日志
// mcp.WithLogger(&ZapLogger{}), // Production-grade logging
mcp.WithTimeout(30*time.Second),
mcp.WithMaxRetries(3),
)
@@ -268,11 +268,11 @@ func Example_environment_specific() {
}
// ============================================================
// 示例 10: 完整实战示例
// Example 10: Complete Real-World Example
// ============================================================
func Example_real_world_usage() {
// 创建带有完整配置的客户端
// Create client with complete configuration
client := mcp.NewDeepSeekClientWithOptions(
mcp.WithAPIKey("sk-xxxxxxxxxx"),
mcp.WithTimeout(60*time.Second),
@@ -282,9 +282,9 @@ func Example_real_world_usage() {
mcp.WithLogger(&CustomLogger{}),
)
// 使用客户端
systemPrompt := "你是一个专业的量化交易顾问"
userPrompt := "分析 BTC 当前走势"
// Use client
systemPrompt := "You are a professional quantitative trading advisor"
userPrompt := "Analyze current BTC trend"
result, err := client.CallWithMessages(systemPrompt, userPrompt)
if err != nil {
@@ -292,5 +292,5 @@ func Example_real_world_usage() {
return
}
fmt.Printf("AI 响应: %s\n", result)
fmt.Printf("AI response: %s\n", result)
}
+5 -5
View File
@@ -5,18 +5,18 @@ import (
"time"
)
// AIClient AI客户端公开接口(给外部使用)
// AIClient public AI client interface (for external use)
type AIClient interface {
SetAPIKey(apiKey string, customURL string, customModel string)
SetTimeout(timeout time.Duration)
CallWithMessages(systemPrompt, userPrompt string) (string, error)
CallWithRequest(req *Request) (string, error) // 构建器模式 API(支持高级功能)
CallWithRequest(req *Request) (string, error) // Builder pattern API (supports advanced features)
}
// clientHooks 内部钩子接口(用于子类重写特定步骤)
// 这些方法只在包内部使用,实现动态分派
// clientHooks internal hook interface (for subclass to override specific steps)
// These methods are only used inside the package to implement dynamic dispatch
type clientHooks interface {
// 可被子类重写的钩子方法
// Hook methods that can be overridden by subclass
call(systemPrompt, userPrompt string) (string, error)
+5 -5
View File
@@ -1,8 +1,8 @@
package mcp
// Logger 日志接口(抽象依赖)
// 使用 Printf 风格的方法名,方便集成 logruszap 等主流日志库
// 默认使用全局 logger 包(见 mcp/config.go
// Logger interface (abstract dependency)
// Uses Printf-style method names for easy integration with mainstream logging libraries like logrus, zap, etc.
// Default uses global logger package (see mcp/config.go)
type Logger interface {
Debugf(format string, args ...any)
Infof(format string, args ...any)
@@ -10,7 +10,7 @@ type Logger interface {
Errorf(format string, args ...any)
}
// noopLogger 空日志实现(测试时使用)
// noopLogger no-op logger implementation (used in tests)
type noopLogger struct{}
func (l *noopLogger) Debugf(format string, args ...any) {}
@@ -18,7 +18,7 @@ func (l *noopLogger) Infof(format string, args ...any) {}
func (l *noopLogger) Warnf(format string, args ...any) {}
func (l *noopLogger) Errorf(format string, args ...any) {}
// NewNoopLogger 创建空日志器(测试使用)
// NewNoopLogger creates no-op logger (for testing)
func NewNoopLogger() Logger {
return &noopLogger{}
}
+28 -28
View File
@@ -13,19 +13,19 @@ import (
// Mock Logger
// ============================================================
// MockLogger Mock 日志器(用于测试)
// MockLogger Mock logger (for testing)
type MockLogger struct {
mu sync.Mutex
Logs []LogEntry
Enabled bool // 是否启用日志记录
Enabled bool // Whether logging is enabled
}
// LogEntry 日志条目
// LogEntry log entry
type LogEntry struct {
Level string
Format string
Args []any
Message string // 格式化后的消息
Message string // Formatted message
}
func NewMockLogger() *MockLogger {
@@ -68,14 +68,14 @@ func (m *MockLogger) log(level, format string, args ...any) {
})
}
// GetLogs 获取所有日志
// GetLogs gets all logs
func (m *MockLogger) GetLogs() []LogEntry {
m.mu.Lock()
defer m.mu.Unlock()
return append([]LogEntry{}, m.Logs...)
}
// GetLogsByLevel 获取指定级别的日志
// GetLogsByLevel gets logs by specified level
func (m *MockLogger) GetLogsByLevel(level string) []LogEntry {
m.mu.Lock()
defer m.mu.Unlock()
@@ -89,14 +89,14 @@ func (m *MockLogger) GetLogsByLevel(level string) []LogEntry {
return result
}
// Clear 清空日志
// Clear clears all logs
func (m *MockLogger) Clear() {
m.mu.Lock()
defer m.mu.Unlock()
m.Logs = make([]LogEntry, 0)
}
// HasLog 检查是否包含指定消息
// HasLog checks if contains specified message
func (m *MockLogger) HasLog(level, message string) bool {
m.mu.Lock()
defer m.mu.Unlock()
@@ -110,20 +110,20 @@ func (m *MockLogger) HasLog(level, message string) bool {
}
// ============================================================
// Mock HTTP Client (实现 http.RoundTripper)
// Mock HTTP Client (implements http.RoundTripper)
// ============================================================
// MockHTTPClient Mock HTTP 客户端(实现 http.RoundTripper
// MockHTTPClient Mock HTTP client (implements http.RoundTripper)
type MockHTTPClient struct {
mu sync.Mutex
// 配置
// Configuration
Response string
StatusCode int
Error error
ResponseFunc func(req *http.Request) (*http.Response, error) // 自定义响应函数
ResponseFunc func(req *http.Request) (*http.Response, error) // Custom response function
// 记录
// Recording
Requests []*http.Request
}
@@ -134,32 +134,32 @@ func NewMockHTTPClient() *MockHTTPClient {
}
}
// ToHTTPClient 转换为 http.Client
// ToHTTPClient converts to http.Client
func (m *MockHTTPClient) ToHTTPClient() *http.Client {
return &http.Client{
Transport: m,
}
}
// RoundTrip 实现 http.RoundTripper 接口
// RoundTrip implements http.RoundTripper interface
func (m *MockHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) {
m.mu.Lock()
defer m.mu.Unlock()
// 记录请求
// Record request
m.Requests = append(m.Requests, req)
// 如果有自定义响应函数,使用它
// If custom response function exists, use it
if m.ResponseFunc != nil {
return m.ResponseFunc(req)
}
// 如果设置了错误,返回错误
// If error is set, return error
if m.Error != nil {
return nil, m.Error
}
// 返回模拟响应
// Return mock response
resp := &http.Response{
StatusCode: m.StatusCode,
Body: io.NopCloser(bytes.NewBufferString(m.Response)),
@@ -169,14 +169,14 @@ func (m *MockHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) {
return resp, nil
}
// GetRequests 获取所有请求
// GetRequests gets all requests
func (m *MockHTTPClient) GetRequests() []*http.Request {
m.mu.Lock()
defer m.mu.Unlock()
return append([]*http.Request{}, m.Requests...)
}
// GetLastRequest 获取最后一次请求
// GetLastRequest gets last request
func (m *MockHTTPClient) GetLastRequest() *http.Request {
m.mu.Lock()
defer m.mu.Unlock()
@@ -187,14 +187,14 @@ func (m *MockHTTPClient) GetLastRequest() *http.Request {
return m.Requests[len(m.Requests)-1]
}
// Reset 重置状态
// Reset resets state
func (m *MockHTTPClient) Reset() {
m.mu.Lock()
defer m.mu.Unlock()
m.Requests = make([]*http.Request, 0)
}
// SetSuccessResponse 设置成功响应
// SetSuccessResponse sets success response
func (m *MockHTTPClient) SetSuccessResponse(content string) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -204,7 +204,7 @@ func (m *MockHTTPClient) SetSuccessResponse(content string) {
m.Error = nil
}
// SetErrorResponse 设置错误响应
// SetErrorResponse sets error response
func (m *MockHTTPClient) SetErrorResponse(statusCode int, message string) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -214,7 +214,7 @@ func (m *MockHTTPClient) SetErrorResponse(statusCode int, message string) {
m.Error = nil
}
// SetNetworkError 设置网络错误
// SetNetworkError sets network error
func (m *MockHTTPClient) SetNetworkError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -223,10 +223,10 @@ func (m *MockHTTPClient) SetNetworkError(err error) {
}
// ============================================================
// Mock Client Hooks (用于测试钩子机制)
// Mock Client Hooks (for testing hook mechanism)
// ============================================================
// MockClientHooks Mock 客户端钩子
// MockClientHooks Mock client hooks
type MockClientHooks struct {
BuildRequestBodyCalled int
BuildUrlCalled int
@@ -235,7 +235,7 @@ type MockClientHooks struct {
ParseResponseCalled int
IsRetryableErrorCalled int
// 自定义返回值
// Custom return values
BuildUrlFunc func() string
ParseResponseFunc func([]byte) (string, error)
IsRetryableErrorFunc func(error) bool
+29 -29
View File
@@ -5,16 +5,16 @@ import (
"time"
)
// ClientOption 客户端选项函数(Functional Options 模式)
// ClientOption client option function (Functional Options pattern)
type ClientOption func(*Config)
// ============================================================
// 依赖注入选项
// Dependency Injection Options
// ============================================================
// WithLogger 设置自定义日志器
// WithLogger sets custom logger
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithLogger(customLogger))
func WithLogger(logger Logger) ClientOption {
return func(c *Config) {
@@ -22,9 +22,9 @@ func WithLogger(logger Logger) ClientOption {
}
}
// WithHTTPClient 设置自定义 HTTP 客户端
// WithHTTPClient sets custom HTTP client
//
// 使用示例:
// Usage example:
// httpClient := &http.Client{Timeout: 60 * time.Second}
// client := mcp.NewClient(mcp.WithHTTPClient(httpClient))
func WithHTTPClient(client *http.Client) ClientOption {
@@ -34,12 +34,12 @@ func WithHTTPClient(client *http.Client) ClientOption {
}
// ============================================================
// 超时和重试选项
// Timeout and Retry Options
// ============================================================
// WithTimeout 设置请求超时时间
// WithTimeout sets request timeout duration
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithTimeout(60 * time.Second))
func WithTimeout(timeout time.Duration) ClientOption {
return func(c *Config) {
@@ -48,9 +48,9 @@ func WithTimeout(timeout time.Duration) ClientOption {
}
}
// WithMaxRetries 设置最大重试次数
// WithMaxRetries sets maximum retry count
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithMaxRetries(5))
func WithMaxRetries(maxRetries int) ClientOption {
return func(c *Config) {
@@ -58,9 +58,9 @@ func WithMaxRetries(maxRetries int) ClientOption {
}
}
// WithRetryWaitBase 设置重试等待基础时长
// WithRetryWaitBase sets base retry wait duration
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithRetryWaitBase(3 * time.Second))
func WithRetryWaitBase(waitTime time.Duration) ClientOption {
return func(c *Config) {
@@ -69,12 +69,12 @@ func WithRetryWaitBase(waitTime time.Duration) ClientOption {
}
// ============================================================
// AI 参数选项
// AI Parameter Options
// ============================================================
// WithMaxTokens 设置最大 token
// WithMaxTokens sets maximum token count
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithMaxTokens(4000))
func WithMaxTokens(maxTokens int) ClientOption {
return func(c *Config) {
@@ -82,9 +82,9 @@ func WithMaxTokens(maxTokens int) ClientOption {
}
}
// WithTemperature 设置温度参数
// WithTemperature sets temperature parameter
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithTemperature(0.7))
func WithTemperature(temperature float64) ClientOption {
return func(c *Config) {
@@ -93,38 +93,38 @@ func WithTemperature(temperature float64) ClientOption {
}
// ============================================================
// Provider 配置选项
// Provider Configuration Options
// ============================================================
// WithAPIKey 设置 API Key
// WithAPIKey sets API Key
func WithAPIKey(apiKey string) ClientOption {
return func(c *Config) {
c.APIKey = apiKey
}
}
// WithBaseURL 设置基础 URL
// WithBaseURL sets base URL
func WithBaseURL(baseURL string) ClientOption {
return func(c *Config) {
c.BaseURL = baseURL
}
}
// WithModel 设置模型名称
// WithModel sets model name
func WithModel(model string) ClientOption {
return func(c *Config) {
c.Model = model
}
}
// WithProvider 设置提供商
// WithProvider sets provider
func WithProvider(provider string) ClientOption {
return func(c *Config) {
c.Provider = provider
}
}
// WithUseFullURL 设置是否使用完整 URL
// WithUseFullURL sets whether to use full URL
func WithUseFullURL(useFullURL bool) ClientOption {
return func(c *Config) {
c.UseFullURL = useFullURL
@@ -132,12 +132,12 @@ func WithUseFullURL(useFullURL bool) ClientOption {
}
// ============================================================
// 组合选项(便捷方法)
// Combined Options (Convenience Methods)
// ============================================================
// WithDeepSeekConfig 设置 DeepSeek 配置
// WithDeepSeekConfig sets DeepSeek configuration
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithDeepSeekConfig("sk-xxx"))
func WithDeepSeekConfig(apiKey string) ClientOption {
return func(c *Config) {
@@ -148,9 +148,9 @@ func WithDeepSeekConfig(apiKey string) ClientOption {
}
}
// WithQwenConfig 设置 Qwen 配置
// WithQwenConfig sets Qwen configuration
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithQwenConfig("sk-xxx"))
func WithQwenConfig(apiKey string) ClientOption {
return func(c *Config) {
+15 -15
View File
@@ -7,7 +7,7 @@ import (
)
// ============================================================
// 测试基础选项
// Test Basic Options
// ============================================================
func TestWithProvider(t *testing.T) {
@@ -116,7 +116,7 @@ func TestWithHTTPClient(t *testing.T) {
}
// ============================================================
// 测试预设配置选项
// Test Preset Configuration Options
// ============================================================
func TestWithDeepSeekConfig(t *testing.T) {
@@ -162,7 +162,7 @@ func TestWithQwenConfig(t *testing.T) {
}
// ============================================================
// 测试选项组合
// Test Options Combination
// ============================================================
func TestMultipleOptions(t *testing.T) {
@@ -170,7 +170,7 @@ func TestMultipleOptions(t *testing.T) {
cfg := DefaultConfig()
// 应用多个选项
// Apply multiple options
options := []ClientOption{
WithProvider("test-provider"),
WithAPIKey("sk-test-key"),
@@ -186,7 +186,7 @@ func TestMultipleOptions(t *testing.T) {
opt(cfg)
}
// 验证所有选项都被应用
// Verify all options are applied
if cfg.Provider != "test-provider" {
t.Error("Provider should be set")
}
@@ -223,14 +223,14 @@ func TestMultipleOptions(t *testing.T) {
func TestOptionsOverride(t *testing.T) {
cfg := DefaultConfig()
// 先应用 DeepSeek 配置
// First apply DeepSeek configuration
WithDeepSeekConfig("sk-deepseek-key")(cfg)
// 然后覆盖某些选项
// Then override some options
WithModel("custom-model")(cfg)
WithMaxTokens(5000)(cfg)
// 验证覆盖成功
// Verify override succeeded
if cfg.Model != "custom-model" {
t.Errorf("Model should be overridden to 'custom-model', got '%s'", cfg.Model)
}
@@ -239,7 +239,7 @@ func TestOptionsOverride(t *testing.T) {
t.Errorf("MaxTokens should be overridden to 5000, got %d", cfg.MaxTokens)
}
// 验证其他 DeepSeek 配置保持不变
// Verify other DeepSeek configurations remain unchanged
if cfg.Provider != ProviderDeepSeek {
t.Error("Provider should still be DeepSeek")
}
@@ -250,7 +250,7 @@ func TestOptionsOverride(t *testing.T) {
}
// ============================================================
// 测试与客户端集成
// Test Integration with Client
// ============================================================
func TestOptionsWithNewClient(t *testing.T) {
@@ -266,7 +266,7 @@ func TestOptionsWithNewClient(t *testing.T) {
c := client.(*Client)
// 验证选项被正确应用到客户端
// Verify options are correctly applied to client
if c.Provider != "test-provider" {
t.Error("Provider should be set from options")
}
@@ -299,7 +299,7 @@ func TestOptionsWithDeepSeekClient(t *testing.T) {
dsClient := client.(*DeepSeekClient)
// 验证 DeepSeek 默认值
// Verify DeepSeek default values
if dsClient.Provider != ProviderDeepSeek {
t.Error("Provider should be DeepSeek")
}
@@ -312,7 +312,7 @@ func TestOptionsWithDeepSeekClient(t *testing.T) {
t.Error("Model should be DeepSeek default")
}
// 验证自定义选项
// Verify custom options
if dsClient.APIKey != "sk-deepseek-key" {
t.Error("APIKey should be set from options")
}
@@ -337,7 +337,7 @@ func TestOptionsWithQwenClient(t *testing.T) {
qwenClient := client.(*QwenClient)
// 验证 Qwen 默认值
// Verify Qwen default values
if qwenClient.Provider != ProviderQwen {
t.Error("Provider should be Qwen")
}
@@ -350,7 +350,7 @@ func TestOptionsWithQwenClient(t *testing.T) {
t.Error("Model should be Qwen default")
}
// 验证自定义选项
// Verify custom options
if qwenClient.APIKey != "sk-qwen-key" {
t.Error("APIKey should be set from options")
}
+15 -15
View File
@@ -14,45 +14,45 @@ type QwenClient struct {
*Client
}
// NewQwenClient 创建 Qwen 客户端(向前兼容)
// NewQwenClient creates Qwen client (backward compatible)
//
// Deprecated: 推荐使用 NewQwenClientWithOptions 以获得更好的灵活性
// Deprecated: Recommend using NewQwenClientWithOptions for better flexibility
func NewQwenClient() AIClient {
return NewQwenClientWithOptions()
}
// NewQwenClientWithOptions 创建 Qwen 客户端(支持选项模式)
// NewQwenClientWithOptions creates Qwen client (supports options pattern)
//
// 使用示例:
// // 基础用法
// Usage examples:
// // Basic usage
// client := mcp.NewQwenClientWithOptions()
//
// // 自定义配置
// // Custom configuration
// client := mcp.NewQwenClientWithOptions(
// mcp.WithAPIKey("sk-xxx"),
// mcp.WithLogger(customLogger),
// mcp.WithTimeout(60*time.Second),
// )
func NewQwenClientWithOptions(opts ...ClientOption) AIClient {
// 1. 创建 Qwen 预设选项
// 1. Create Qwen preset options
qwenOpts := []ClientOption{
WithProvider(ProviderQwen),
WithModel(DefaultQwenModel),
WithBaseURL(DefaultQwenBaseURL),
}
// 2. 合并用户选项(用户选项优先级更高)
// 2. Merge user options (user options have higher priority)
allOpts := append(qwenOpts, opts...)
// 3. 创建基础客户端
// 3. Create base client
baseClient := NewClient(allOpts...).(*Client)
// 4. 创建 Qwen 客户端
// 4. Create Qwen client
qwenClient := &QwenClient{
Client: baseClient,
}
// 5. 设置 hooks 指向 QwenClient(实现动态分派)
// 5. Set hooks to point to QwenClient (implement dynamic dispatch)
baseClient.hooks = qwenClient
return qwenClient
@@ -66,15 +66,15 @@ func (qwenClient *QwenClient) SetAPIKey(apiKey string, customURL string, customM
}
if customURL != "" {
qwenClient.BaseURL = customURL
qwenClient.logger.Infof("🔧 [MCP] Qwen 使用自定义 BaseURL: %s", customURL)
qwenClient.logger.Infof("🔧 [MCP] Qwen using custom BaseURL: %s", customURL)
} else {
qwenClient.logger.Infof("🔧 [MCP] Qwen 使用默认 BaseURL: %s", qwenClient.BaseURL)
qwenClient.logger.Infof("🔧 [MCP] Qwen using default BaseURL: %s", qwenClient.BaseURL)
}
if customModel != "" {
qwenClient.Model = customModel
qwenClient.logger.Infof("🔧 [MCP] Qwen 使用自定义 Model: %s", customModel)
qwenClient.logger.Infof("🔧 [MCP] Qwen using custom Model: %s", customModel)
} else {
qwenClient.logger.Infof("🔧 [MCP] Qwen 使用默认 Model: %s", qwenClient.Model)
qwenClient.logger.Infof("🔧 [MCP] Qwen using default Model: %s", qwenClient.Model)
}
}
+22 -22
View File
@@ -6,7 +6,7 @@ import (
)
// ============================================================
// 测试 QwenClient 创建和配置
// Test QwenClient Creation and Configuration
// ============================================================
func TestNewQwenClient_Default(t *testing.T) {
@@ -16,13 +16,13 @@ func TestNewQwenClient_Default(t *testing.T) {
t.Fatal("client should not be nil")
}
// 类型断言检查
// Type assertion check
qwenClient, ok := client.(*QwenClient)
if !ok {
t.Fatal("client should be *QwenClient")
}
// 验证默认值
// Verify default values
if qwenClient.Provider != ProviderQwen {
t.Errorf("Provider should be '%s', got '%s'", ProviderQwen, qwenClient.Provider)
}
@@ -58,7 +58,7 @@ func TestNewQwenClientWithOptions(t *testing.T) {
qwenClient := client.(*QwenClient)
// 验证自定义选项被应用
// Verify custom options are applied
if qwenClient.logger != mockLogger {
t.Error("logger should be set from option")
}
@@ -75,7 +75,7 @@ func TestNewQwenClientWithOptions(t *testing.T) {
t.Error("MaxTokens should be 4000")
}
// 验证 Qwen 默认值仍然保留
// Verify Qwen default values are retained
if qwenClient.Provider != ProviderQwen {
t.Errorf("Provider should still be '%s'", ProviderQwen)
}
@@ -86,7 +86,7 @@ func TestNewQwenClientWithOptions(t *testing.T) {
}
// ============================================================
// 测试 SetAPIKey
// Test SetAPIKey
// ============================================================
func TestQwenClient_SetAPIKey(t *testing.T) {
@@ -97,20 +97,20 @@ func TestQwenClient_SetAPIKey(t *testing.T) {
qwenClient := client.(*QwenClient)
// 测试设置 API Key(默认 URL Model
// Test setting API Key (default URL and Model)
qwenClient.SetAPIKey("sk-test-key-12345678", "", "")
if qwenClient.APIKey != "sk-test-key-12345678" {
t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", qwenClient.APIKey)
}
// 验证日志记录
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
if len(logs) == 0 {
t.Error("should have logged API key setting")
}
// 验证 BaseURL Model 保持默认
// Verify BaseURL and Model remain default
if qwenClient.BaseURL != DefaultQwenBaseURL {
t.Error("BaseURL should remain default")
}
@@ -135,11 +135,11 @@ func TestQwenClient_SetAPIKey_WithCustomURL(t *testing.T) {
t.Errorf("BaseURL should be '%s', got '%s'", customURL, qwenClient.BaseURL)
}
// 验证日志记录
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
hasCustomURLLog := false
for _, log := range logs {
if log.Format == "🔧 [MCP] Qwen 使用自定义 BaseURL: %s" {
if log.Format == "🔧 [MCP] Qwen using custom BaseURL: %s" {
hasCustomURLLog = true
break
}
@@ -165,11 +165,11 @@ func TestQwenClient_SetAPIKey_WithCustomModel(t *testing.T) {
t.Errorf("Model should be '%s', got '%s'", customModel, qwenClient.Model)
}
// 验证日志记录
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
hasCustomModelLog := false
for _, log := range logs {
if log.Format == "🔧 [MCP] Qwen 使用自定义 Model: %s" {
if log.Format == "🔧 [MCP] Qwen using custom Model: %s" {
hasCustomModelLog = true
break
}
@@ -181,7 +181,7 @@ func TestQwenClient_SetAPIKey_WithCustomModel(t *testing.T) {
}
// ============================================================
// 测试集成功能
// Test Integration Features
// ============================================================
func TestQwenClient_CallWithMessages_Success(t *testing.T) {
@@ -205,7 +205,7 @@ func TestQwenClient_CallWithMessages_Success(t *testing.T) {
t.Errorf("expected 'Qwen AI response', got '%s'", result)
}
// 验证请求
// Verify request
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Fatalf("expected 1 request, got %d", len(requests))
@@ -213,19 +213,19 @@ func TestQwenClient_CallWithMessages_Success(t *testing.T) {
req := requests[0]
// 验证 URL
// Verify URL
expectedURL := DefaultQwenBaseURL + "/chat/completions"
if req.URL.String() != expectedURL {
t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String())
}
// 验证 Authorization header
// Verify Authorization header
authHeader := req.Header.Get("Authorization")
if authHeader != "Bearer sk-test-key" {
t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader)
}
// 验证 Content-Type
// Verify Content-Type
if req.Header.Get("Content-Type") != "application/json" {
t.Error("Content-Type should be application/json")
}
@@ -242,7 +242,7 @@ func TestQwenClient_Timeout(t *testing.T) {
t.Errorf("expected timeout 30s, got %v", qwenClient.httpClient.Timeout)
}
// 测试 SetTimeout
// Test SetTimeout
client.SetTimeout(60 * time.Second)
if qwenClient.httpClient.Timeout != 60*time.Second {
@@ -251,19 +251,19 @@ func TestQwenClient_Timeout(t *testing.T) {
}
// ============================================================
// 测试 hooks 机制
// Test hooks Mechanism
// ============================================================
func TestQwenClient_HooksIntegration(t *testing.T) {
client := NewQwenClientWithOptions()
qwenClient := client.(*QwenClient)
// 验证 hooks 指向 qwenClient 自己(实现多态)
// Verify hooks point to qwenClient itself (implements polymorphism)
if qwenClient.hooks != qwenClient {
t.Error("hooks should point to qwenClient for polymorphism")
}
// 验证 buildUrl 使用 Qwen 配置
// Verify buildUrl uses Qwen configuration
url := qwenClient.buildUrl()
expectedURL := DefaultQwenBaseURL + "/chat/completions"
if url != expectedURL {
+28 -28
View File
@@ -1,45 +1,45 @@
package mcp
// Message 表示一条对话消息
// Message represents a conversation message
type Message struct {
Role string `json:"role"` // "system", "user", "assistant"
Content string `json:"content"` // 消息内容
Content string `json:"content"` // Message content
}
// Tool 表示 AI 可以调用的工具/函数
// Tool represents a tool/function that AI can call
type Tool struct {
Type string `json:"type"` // 通常为 "function"
Function FunctionDef `json:"function"` // 函数定义
Type string `json:"type"` // Usually "function"
Function FunctionDef `json:"function"` // Function definition
}
// FunctionDef 函数定义
// FunctionDef function definition
type FunctionDef struct {
Name string `json:"name"` // 函数名
Description string `json:"description,omitempty"` // 函数描述
Parameters map[string]any `json:"parameters,omitempty"` // 参数 schema (JSON Schema)
Name string `json:"name"` // Function name
Description string `json:"description,omitempty"` // Function description
Parameters map[string]any `json:"parameters,omitempty"` // Parameter schema (JSON Schema)
}
// Request AI API 请求(支持高级功能)
// Request AI API request (supports advanced features)
type Request struct {
// 基础字段
Model string `json:"model"` // 模型名称
Messages []Message `json:"messages"` // 对话消息列表
Stream bool `json:"stream,omitempty"` // 是否流式响应
// Basic fields
Model string `json:"model"` // Model name
Messages []Message `json:"messages"` // Conversation message list
Stream bool `json:"stream,omitempty"` // Whether to stream response
// 可选参数(用于精细控制)
Temperature *float64 `json:"temperature,omitempty"` // 温度 (0-2),控制随机性
MaxTokens *int `json:"max_tokens,omitempty"` // 最大 token
TopP *float64 `json:"top_p,omitempty"` // 核采样参数 (0-1)
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // 频率惩罚 (-2 to 2)
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // 存在惩罚 (-2 to 2)
Stop []string `json:"stop,omitempty"` // 停止序列
// Optional parameters (for fine-grained control)
Temperature *float64 `json:"temperature,omitempty"` // Temperature (0-2), controls randomness
MaxTokens *int `json:"max_tokens,omitempty"` // Maximum token count
TopP *float64 `json:"top_p,omitempty"` // Nucleus sampling parameter (0-1)
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Frequency penalty (-2 to 2)
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Presence penalty (-2 to 2)
Stop []string `json:"stop,omitempty"` // Stop sequences
// 高级功能
Tools []Tool `json:"tools,omitempty"` // 可用工具列表
ToolChoice string `json:"tool_choice,omitempty"` // 工具选择策略 ("auto", "none", {"type": "function", "function": {"name": "xxx"}})
// Advanced features
Tools []Tool `json:"tools,omitempty"` // Available tools list
ToolChoice string `json:"tool_choice,omitempty"` // Tool choice strategy ("auto", "none", {"type": "function", "function": {"name": "xxx"}})
}
// NewMessage 创建一条消息
// NewMessage creates a message
func NewMessage(role, content string) Message {
return Message{
Role: role,
@@ -47,7 +47,7 @@ func NewMessage(role, content string) Message {
}
}
// NewSystemMessage 创建系统消息
// NewSystemMessage creates a system message
func NewSystemMessage(content string) Message {
return Message{
Role: "system",
@@ -55,7 +55,7 @@ func NewSystemMessage(content string) Message {
}
}
// NewUserMessage 创建用户消息
// NewUserMessage creates a user message
func NewUserMessage(content string) Message {
return Message{
Role: "user",
@@ -63,7 +63,7 @@ func NewUserMessage(content string) Message {
}
}
// NewAssistantMessage 创建助手消息
// NewAssistantMessage creates an assistant message
func NewAssistantMessage(content string) Message {
return Message{
Role: "assistant",
+49 -49
View File
@@ -4,7 +4,7 @@ import (
"errors"
)
// RequestBuilder 请求构建器
// RequestBuilder request builder
type RequestBuilder struct {
model string
messages []Message
@@ -19,9 +19,9 @@ type RequestBuilder struct {
toolChoice string
}
// NewRequestBuilder 创建请求构建器
// NewRequestBuilder creates request builder
//
// 使用示例:
// Usage example:
// request := NewRequestBuilder().
// WithSystemPrompt("You are helpful").
// WithUserPrompt("Hello").
@@ -35,26 +35,26 @@ func NewRequestBuilder() *RequestBuilder {
}
// ============================================================
// 模型和流式配置
// Model and Stream Configuration
// ============================================================
// WithModel 设置模型名称
// WithModel sets model name
func (b *RequestBuilder) WithModel(model string) *RequestBuilder {
b.model = model
return b
}
// WithStream 设置是否使用流式响应
// WithStream sets whether to use streaming response
func (b *RequestBuilder) WithStream(stream bool) *RequestBuilder {
b.stream = stream
return b
}
// ============================================================
// 消息构建方法
// Message Building Methods
// ============================================================
// WithSystemPrompt 添加系统提示词(便捷方法)
// WithSystemPrompt adds system prompt (convenience method)
func (b *RequestBuilder) WithSystemPrompt(prompt string) *RequestBuilder {
if prompt != "" {
b.messages = append(b.messages, NewSystemMessage(prompt))
@@ -62,7 +62,7 @@ func (b *RequestBuilder) WithSystemPrompt(prompt string) *RequestBuilder {
return b
}
// WithUserPrompt 添加用户提示词(便捷方法)
// WithUserPrompt adds user prompt (convenience method)
func (b *RequestBuilder) WithUserPrompt(prompt string) *RequestBuilder {
if prompt != "" {
b.messages = append(b.messages, NewUserMessage(prompt))
@@ -70,17 +70,17 @@ func (b *RequestBuilder) WithUserPrompt(prompt string) *RequestBuilder {
return b
}
// AddSystemMessage 添加系统消息
// AddSystemMessage adds system message
func (b *RequestBuilder) AddSystemMessage(content string) *RequestBuilder {
return b.WithSystemPrompt(content)
}
// AddUserMessage 添加用户消息
// AddUserMessage adds user message
func (b *RequestBuilder) AddUserMessage(content string) *RequestBuilder {
return b.WithUserPrompt(content)
}
// AddAssistantMessage 添加助手消息(用于多轮对话上下文)
// AddAssistantMessage adds assistant message (for multi-turn conversation context)
func (b *RequestBuilder) AddAssistantMessage(content string) *RequestBuilder {
if content != "" {
b.messages = append(b.messages, NewAssistantMessage(content))
@@ -88,7 +88,7 @@ func (b *RequestBuilder) AddAssistantMessage(content string) *RequestBuilder {
return b
}
// AddMessage 添加自定义角色的消息
// AddMessage adds message with custom role
func (b *RequestBuilder) AddMessage(role, content string) *RequestBuilder {
if content != "" {
b.messages = append(b.messages, NewMessage(role, content))
@@ -96,33 +96,33 @@ func (b *RequestBuilder) AddMessage(role, content string) *RequestBuilder {
return b
}
// AddMessages 批量添加消息
// AddMessages adds messages in batch
func (b *RequestBuilder) AddMessages(messages ...Message) *RequestBuilder {
b.messages = append(b.messages, messages...)
return b
}
// AddConversationHistory 添加对话历史
// AddConversationHistory adds conversation history
func (b *RequestBuilder) AddConversationHistory(history []Message) *RequestBuilder {
b.messages = append(b.messages, history...)
return b
}
// ClearMessages 清空所有消息
// ClearMessages clears all messages
func (b *RequestBuilder) ClearMessages() *RequestBuilder {
b.messages = make([]Message, 0)
return b
}
// ============================================================
// 参数控制方法
// Parameter Control Methods
// ============================================================
// WithTemperature 设置温度参数 (0-2)
// 较高的温度(如 1.2)会使输出更随机,较低的温度(如 0.2)会使输出更确定
// WithTemperature sets temperature parameter (0-2)
// Higher temperature (e.g. 1.2) makes output more random, lower temperature (e.g. 0.2) makes output more deterministic
func (b *RequestBuilder) WithTemperature(t float64) *RequestBuilder {
if t < 0 || t > 2 {
// 可以选择 panic 或者静默忽略,这里选择限制范围
// Can choose to panic or silently ignore, here we choose to limit the range
if t < 0 {
t = 0
}
@@ -134,7 +134,7 @@ func (b *RequestBuilder) WithTemperature(t float64) *RequestBuilder {
return b
}
// WithMaxTokens 设置最大 token
// WithMaxTokens sets maximum token count
func (b *RequestBuilder) WithMaxTokens(tokens int) *RequestBuilder {
if tokens > 0 {
b.maxTokens = &tokens
@@ -142,8 +142,8 @@ func (b *RequestBuilder) WithMaxTokens(tokens int) *RequestBuilder {
return b
}
// WithTopP 设置 top-p 核采样参数 (0-1)
// 控制考虑的 token 范围,较小的值(如 0.1)使输出更聚焦
// WithTopP sets top-p nucleus sampling parameter (0-1)
// Controls the range of tokens considered, smaller values (e.g. 0.1) make output more focused
func (b *RequestBuilder) WithTopP(p float64) *RequestBuilder {
if p >= 0 && p <= 1 {
b.topP = &p
@@ -151,8 +151,8 @@ func (b *RequestBuilder) WithTopP(p float64) *RequestBuilder {
return b
}
// WithFrequencyPenalty 设置频率惩罚 (-2 to 2)
// 正值会根据 token 在文本中出现的频率惩罚它们,减少重复
// WithFrequencyPenalty sets frequency penalty (-2 to 2)
// Positive values penalize tokens based on their frequency in the text, reducing repetition
func (b *RequestBuilder) WithFrequencyPenalty(penalty float64) *RequestBuilder {
if penalty >= -2 && penalty <= 2 {
b.frequencyPenalty = &penalty
@@ -160,8 +160,8 @@ func (b *RequestBuilder) WithFrequencyPenalty(penalty float64) *RequestBuilder {
return b
}
// WithPresencePenalty 设置存在惩罚 (-2 to 2)
// 正值会根据 token 是否出现在文本中惩罚它们,增加话题多样性
// WithPresencePenalty sets presence penalty (-2 to 2)
// Positive values penalize tokens based on whether they appear in the text, increasing topic diversity
func (b *RequestBuilder) WithPresencePenalty(penalty float64) *RequestBuilder {
if penalty >= -2 && penalty <= 2 {
b.presencePenalty = &penalty
@@ -169,14 +169,14 @@ func (b *RequestBuilder) WithPresencePenalty(penalty float64) *RequestBuilder {
return b
}
// WithStopSequences 设置停止序列
// 当模型生成这些序列之一时,将停止生成
// WithStopSequences sets stop sequences
// Model will stop generating when it generates one of these sequences
func (b *RequestBuilder) WithStopSequences(sequences []string) *RequestBuilder {
b.stop = sequences
return b
}
// AddStopSequence 添加单个停止序列
// AddStopSequence adds a single stop sequence
func (b *RequestBuilder) AddStopSequence(sequence string) *RequestBuilder {
if sequence != "" {
b.stop = append(b.stop, sequence)
@@ -185,16 +185,16 @@ func (b *RequestBuilder) AddStopSequence(sequence string) *RequestBuilder {
}
// ============================================================
// 工具/函数调用相关
// Tool/Function Calling Related
// ============================================================
// AddTool 添加工具
// AddTool adds a tool
func (b *RequestBuilder) AddTool(tool Tool) *RequestBuilder {
b.tools = append(b.tools, tool)
return b
}
// AddFunction 添加函数(便捷方法)
// AddFunction adds a function (convenience method)
func (b *RequestBuilder) AddFunction(name, description string, parameters map[string]any) *RequestBuilder {
tool := Tool{
Type: "function",
@@ -208,27 +208,27 @@ func (b *RequestBuilder) AddFunction(name, description string, parameters map[st
return b
}
// WithToolChoice 设置工具选择策略
// - "auto": 自动选择是否调用工具
// - "none": 不调用工具
// - 也可以指定特定工具: `{"type": "function", "function": {"name": "my_function"}}`
// WithToolChoice sets tool choice strategy
// - "auto": automatically choose whether to call tools
// - "none": don't call tools
// - Can also specify a specific tool: `{"type": "function", "function": {"name": "my_function"}}`
func (b *RequestBuilder) WithToolChoice(choice string) *RequestBuilder {
b.toolChoice = choice
return b
}
// ============================================================
// 构建方法
// Build Methods
// ============================================================
// Build 构建请求对象
// Build builds request object
func (b *RequestBuilder) Build() (*Request, error) {
// 验证:至少需要一条消息
// Validation: at least one message is required
if len(b.messages) == 0 {
return nil, errors.New("至少需要一条消息")
return nil, errors.New("at least one message is required")
}
// 创建请求
// Create request
req := &Request{
Model: b.model,
Messages: b.messages,
@@ -238,7 +238,7 @@ func (b *RequestBuilder) Build() (*Request, error) {
ToolChoice: b.toolChoice,
}
// 只设置非 nil 的可选参数(避免发送 0 值覆盖服务端默认值)
// Only set non-nil optional parameters (avoid sending 0 values that override server defaults)
if b.temperature != nil {
req.Temperature = b.temperature
}
@@ -258,8 +258,8 @@ func (b *RequestBuilder) Build() (*Request, error) {
return req, nil
}
// MustBuild 构建请求对象,如果失败则 panic
// 适用于构建过程中确定不会出错的场景
// MustBuild builds request object, panics if failed
// Suitable for scenarios where build is guaranteed not to fail
func (b *RequestBuilder) MustBuild() *Request {
req, err := b.Build()
if err != nil {
@@ -269,10 +269,10 @@ func (b *RequestBuilder) MustBuild() *Request {
}
// ============================================================
// 便捷方法:预设场景
// Convenience Methods: Preset Scenarios
// ============================================================
// ForChat 创建用于聊天的构建器(预设合理的参数)
// ForChat creates builder for chat (preset with reasonable parameters)
func ForChat() *RequestBuilder {
temp := 0.7
tokens := 2000
@@ -284,7 +284,7 @@ func ForChat() *RequestBuilder {
}
}
// ForCodeGeneration 创建用于代码生成的构建器(低温度,更确定)
// ForCodeGeneration creates builder for code generation (low temperature, more deterministic)
func ForCodeGeneration() *RequestBuilder {
temp := 0.2
tokens := 2000
@@ -298,7 +298,7 @@ func ForCodeGeneration() *RequestBuilder {
}
}
// ForCreativeWriting 创建用于创意写作的构建器(高温度,更随机)
// ForCreativeWriting creates builder for creative writing (high temperature, more random)
func ForCreativeWriting() *RequestBuilder {
temp := 1.2
tokens := 4000
+17 -17
View File
@@ -6,7 +6,7 @@ import (
)
// ============================================================
// 测试 RequestBuilder 基本功能
// Test RequestBuilder Basic Features
// ============================================================
func TestRequestBuilder_BasicUsage(t *testing.T) {
@@ -39,13 +39,13 @@ func TestRequestBuilder_EmptyMessages(t *testing.T) {
t.Error("Build should error when no messages")
}
if err.Error() != "至少需要一条消息" {
if err.Error() != "at least one message is required" {
t.Errorf("unexpected error: %v", err)
}
}
// ============================================================
// 测试消息构建方法
// Test Message Building Methods
// ============================================================
func TestRequestBuilder_MultipleMessages(t *testing.T) {
@@ -85,7 +85,7 @@ func TestRequestBuilder_AddConversationHistory(t *testing.T) {
}
// ============================================================
// 测试参数控制方法
// Test Parameter Control Methods
// ============================================================
func TestRequestBuilder_WithTemperature(t *testing.T) {
@@ -165,7 +165,7 @@ func TestRequestBuilder_WithStopSequences(t *testing.T) {
}
// ============================================================
// 测试工具/函数调用
// Test Tool/Function Calling
// ============================================================
func TestRequestBuilder_AddTool(t *testing.T) {
@@ -229,7 +229,7 @@ func TestRequestBuilder_AddFunction(t *testing.T) {
}
// ============================================================
// 测试便捷方法
// Test Convenience Methods
// ============================================================
func TestRequestBuilder_ForChat(t *testing.T) {
@@ -287,7 +287,7 @@ func TestRequestBuilder_ForCreativeWriting(t *testing.T) {
}
// ============================================================
// 测试 CallWithRequest 集成
// Test CallWithRequest Integration
// ============================================================
func TestClient_CallWithRequest_Success(t *testing.T) {
@@ -317,25 +317,25 @@ func TestClient_CallWithRequest_Success(t *testing.T) {
t.Errorf("expected 'Builder response', got '%s'", result)
}
// 验证请求体
// Verify request body
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Fatalf("expected 1 request, got %d", len(requests))
}
// 解析请求体验证参数
// Parse request body to verify parameters
var body map[string]interface{}
decoder := json.NewDecoder(requests[0].Body)
if err := decoder.Decode(&body); err != nil {
t.Fatalf("failed to decode request body: %v", err)
}
// 验证 temperature
// Verify temperature
if body["temperature"] != 0.8 {
t.Errorf("expected temperature 0.8, got %v", body["temperature"])
}
// 验证 messages
// Verify messages
messages, ok := body["messages"].([]interface{})
if !ok || len(messages) != 2 {
t.Error("messages not correctly formatted")
@@ -353,7 +353,7 @@ func TestClient_CallWithRequest_MultiRound(t *testing.T) {
WithAPIKey("sk-test-key"),
)
// 构建多轮对话
// Build multi-round conversation
request := NewRequestBuilder().
AddSystemMessage("You are a trading advisor").
AddUserMessage("Analyze BTC").
@@ -372,7 +372,7 @@ func TestClient_CallWithRequest_MultiRound(t *testing.T) {
t.Errorf("expected 'Multi-round response', got '%s'", result)
}
// 验证请求体包含所有消息
// Verify request body contains all messages
requests := mockHTTP.GetRequests()
var body map[string]interface{}
json.NewDecoder(requests[0].Body).Decode(&body)
@@ -411,7 +411,7 @@ func TestClient_CallWithRequest_WithTools(t *testing.T) {
t.Fatalf("should not error: %v", err)
}
// 验证请求体包含 tools
// Verify request body contains tools
requests := mockHTTP.GetRequests()
var body map[string]interface{}
json.NewDecoder(requests[0].Body).Decode(&body)
@@ -440,7 +440,7 @@ func TestClient_CallWithRequest_NoAPIKey(t *testing.T) {
t.Error("should error when API key not set")
}
if err.Error() != "AI API密钥未设置,请先调用 SetAPIKey" {
if err.Error() != "AI API key not set, please call SetAPIKey first" {
t.Errorf("unexpected error: %v", err)
}
}
@@ -456,7 +456,7 @@ func TestClient_CallWithRequest_UsesClientModel(t *testing.T) {
WithAPIKey("sk-test-key"),
)
// Request 不设置 model,应该使用 Client model
// Request does not set model, should use Client's model
request := NewRequestBuilder().
WithUserPrompt("Hello").
MustBuild()
@@ -467,7 +467,7 @@ func TestClient_CallWithRequest_UsesClientModel(t *testing.T) {
client.CallWithRequest(request)
// 验证使用了 DeepSeek model
// Verify DeepSeek's model is used
requests := mockHTTP.GetRequests()
var body map[string]interface{}
json.NewDecoder(requests[0].Body).Decode(&body)