mirror of
https://github.com/laoxong/nofx.git
synced 2026-06-04 09:58:22 +08:00
refactor: standardize code comments
This commit is contained in:
+99
-99
@@ -28,65 +28,65 @@ var (
|
||||
"connection refused",
|
||||
"temporary failure",
|
||||
"no such host",
|
||||
"stream error", // HTTP/2 stream 错误
|
||||
"INTERNAL_ERROR", // 服务端内部错误
|
||||
"stream error", // HTTP/2 stream error
|
||||
"INTERNAL_ERROR", // Server internal error
|
||||
}
|
||||
)
|
||||
|
||||
// Client AI API配置
|
||||
// Client AI API configuration
|
||||
type Client struct {
|
||||
Provider string
|
||||
APIKey string
|
||||
BaseURL string
|
||||
Model string
|
||||
UseFullURL bool // 是否使用完整URL(不添加/chat/completions)
|
||||
MaxTokens int // AI响应的最大token数
|
||||
UseFullURL bool // Whether to use full URL (without appending /chat/completions)
|
||||
MaxTokens int // Maximum tokens for AI response
|
||||
|
||||
httpClient *http.Client
|
||||
logger Logger // 日志器(可替换)
|
||||
config *Config // 配置对象(保存所有配置)
|
||||
logger Logger // Logger (replaceable)
|
||||
config *Config // Config object (stores all configurations)
|
||||
|
||||
// hooks 用于实现动态分派(多态)
|
||||
// 当 DeepSeekClient 嵌入 Client 时,hooks 指向 DeepSeekClient
|
||||
// 这样 call() 中调用的方法会自动分派到子类重写的版本
|
||||
// hooks are used to implement dynamic dispatch (polymorphism)
|
||||
// When DeepSeekClient embeds Client, hooks point to DeepSeekClient
|
||||
// This way methods called in call() are automatically dispatched to the overridden version in subclass
|
||||
hooks clientHooks
|
||||
}
|
||||
|
||||
// New 创建默认客户端(向前兼容)
|
||||
// New creates default client (backward compatible)
|
||||
//
|
||||
// Deprecated: 推荐使用 NewClient(...opts) 以获得更好的灵活性
|
||||
// Deprecated: Recommend using NewClient(...opts) for better flexibility
|
||||
func New() AIClient {
|
||||
return NewClient()
|
||||
}
|
||||
|
||||
// NewClient 创建客户端(支持选项模式)
|
||||
// NewClient creates client (supports options pattern)
|
||||
//
|
||||
// 使用示例:
|
||||
// // 基础用法(向前兼容)
|
||||
// Usage examples:
|
||||
// // Basic usage (backward compatible)
|
||||
// client := mcp.NewClient()
|
||||
//
|
||||
// // 自定义日志
|
||||
// // Custom logger
|
||||
// client := mcp.NewClient(mcp.WithLogger(customLogger))
|
||||
//
|
||||
// // 自定义超时
|
||||
// // Custom timeout
|
||||
// client := mcp.NewClient(mcp.WithTimeout(60*time.Second))
|
||||
//
|
||||
// // 组合多个选项
|
||||
// // Combine multiple options
|
||||
// client := mcp.NewClient(
|
||||
// mcp.WithDeepSeekConfig("sk-xxx"),
|
||||
// mcp.WithLogger(customLogger),
|
||||
// mcp.WithTimeout(60*time.Second),
|
||||
// )
|
||||
func NewClient(opts ...ClientOption) AIClient {
|
||||
// 1. 创建默认配置
|
||||
// 1. Create default config
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// 2. 应用用户选项
|
||||
// 2. Apply user options
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
|
||||
// 3. 创建客户端实例
|
||||
// 3. Create client instance
|
||||
client := &Client{
|
||||
Provider: cfg.Provider,
|
||||
APIKey: cfg.APIKey,
|
||||
@@ -99,25 +99,25 @@ func NewClient(opts ...ClientOption) AIClient {
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
// 4. 设置默认 Provider(如果未设置)
|
||||
// 4. Set default Provider (if not set)
|
||||
if client.Provider == "" {
|
||||
client.Provider = ProviderDeepSeek
|
||||
client.BaseURL = DefaultDeepSeekBaseURL
|
||||
client.Model = DefaultDeepSeekModel
|
||||
}
|
||||
|
||||
// 5. 设置 hooks 指向自己
|
||||
// 5. Set hooks to point to self
|
||||
client.hooks = client
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// SetCustomAPI 设置自定义OpenAI兼容API
|
||||
// SetCustomAPI sets custom OpenAI-compatible API
|
||||
func (client *Client) SetAPIKey(apiKey, apiURL, customModel string) {
|
||||
client.Provider = ProviderCustom
|
||||
client.APIKey = apiKey
|
||||
|
||||
// 检查URL是否以#结尾,如果是则使用完整URL(不添加/chat/completions)
|
||||
// Check if URL ends with #, if so use full URL (without appending /chat/completions)
|
||||
if strings.HasSuffix(apiURL, "#") {
|
||||
client.BaseURL = strings.TrimSuffix(apiURL, "#")
|
||||
client.UseFullURL = true
|
||||
@@ -133,45 +133,45 @@ func (client *Client) SetTimeout(timeout time.Duration) {
|
||||
client.httpClient.Timeout = timeout
|
||||
}
|
||||
|
||||
// CallWithMessages 模板方法 - 固定的重试流程(不可重写)
|
||||
// CallWithMessages template method - fixed retry flow (cannot be overridden)
|
||||
func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, error) {
|
||||
if client.APIKey == "" {
|
||||
return "", fmt.Errorf("AI API密钥未设置,请先调用 SetAPIKey")
|
||||
return "", fmt.Errorf("AI API key not set, please call SetAPIKey first")
|
||||
}
|
||||
|
||||
// 固定的重试流程
|
||||
// Fixed retry flow
|
||||
var lastErr error
|
||||
maxRetries := client.config.MaxRetries
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
if attempt > 1 {
|
||||
client.logger.Warnf("⚠️ AI API调用失败,正在重试 (%d/%d)...", attempt, maxRetries)
|
||||
client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
|
||||
}
|
||||
|
||||
// 调用固定的单次调用流程
|
||||
// Call the fixed single-call flow
|
||||
result, err := client.hooks.call(systemPrompt, userPrompt)
|
||||
if err == nil {
|
||||
if attempt > 1 {
|
||||
client.logger.Infof("✓ AI API重试成功")
|
||||
client.logger.Infof("✓ AI API retry succeeded")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
// 通过 hooks 判断是否可重试(支持子类自定义重试策略)
|
||||
// Check if error is retryable via hooks (supports custom retry strategy in subclass)
|
||||
if !client.hooks.isRetryableError(err) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 重试前等待
|
||||
// Wait before retry
|
||||
if attempt < maxRetries {
|
||||
waitTime := client.config.RetryWaitBase * time.Duration(attempt)
|
||||
client.logger.Infof("⏳ 等待%v后重试...", waitTime)
|
||||
client.logger.Infof("⏳ Waiting %v before retry...", waitTime)
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
|
||||
return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func (client *Client) setAuthHeader(reqHeader http.Header) {
|
||||
@@ -179,27 +179,27 @@ func (client *Client) setAuthHeader(reqHeader http.Header) {
|
||||
}
|
||||
|
||||
func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
|
||||
// 构建 messages 数组
|
||||
// Build messages array
|
||||
messages := []map[string]string{}
|
||||
|
||||
// 如果有 system prompt,添加 system message
|
||||
// If system prompt exists, add system message
|
||||
if systemPrompt != "" {
|
||||
messages = append(messages, map[string]string{
|
||||
"role": "system",
|
||||
"content": systemPrompt,
|
||||
})
|
||||
}
|
||||
// 添加 user message
|
||||
// Add user message
|
||||
messages = append(messages, map[string]string{
|
||||
"role": "user",
|
||||
"content": userPrompt,
|
||||
})
|
||||
|
||||
// 构建请求体
|
||||
// Build request body
|
||||
requestBody := map[string]interface{}{
|
||||
"model": client.Model,
|
||||
"messages": messages,
|
||||
"temperature": client.config.Temperature, // 使用配置的 temperature
|
||||
"temperature": client.config.Temperature, // Use configured temperature
|
||||
"max_tokens": client.MaxTokens,
|
||||
}
|
||||
return requestBody
|
||||
@@ -209,7 +209,7 @@ func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[s
|
||||
func (client *Client) marshalRequestBody(requestBody map[string]any) ([]byte, error) {
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
return nil, fmt.Errorf("failed to serialize request: %w", err)
|
||||
}
|
||||
return jsonData, nil
|
||||
}
|
||||
@@ -224,11 +224,11 @@ func (client *Client) parseMCPResponse(body []byte) (string, error) {
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Choices) == 0 {
|
||||
return "", fmt.Errorf("API返回空响应")
|
||||
return "", fmt.Errorf("API returned empty response")
|
||||
}
|
||||
|
||||
return result.Choices[0].Message.Content, nil
|
||||
@@ -250,59 +250,59 @@ func (client *Client) buildRequest(url string, jsonData []byte) (*http.Request,
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 通过 hooks 设置认证头(支持子类重写)
|
||||
// Set auth header via hooks (supports overriding in subclass)
|
||||
client.hooks.setAuthHeader(req.Header)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// call 单次调用AI API(固定流程,不可重写)
|
||||
// call single AI API call (fixed flow, cannot be overridden)
|
||||
func (client *Client) call(systemPrompt, userPrompt string) (string, error) {
|
||||
// 打印当前 AI 配置
|
||||
// Print current AI configuration
|
||||
client.logger.Infof("📡 [%s] Request AI Server: BaseURL: %s", client.String(), client.BaseURL)
|
||||
client.logger.Debugf("[%s] UseFullURL: %v", client.String(), client.UseFullURL)
|
||||
if len(client.APIKey) > 8 {
|
||||
client.logger.Debugf("[%s] API Key: %s...%s", client.String(), client.APIKey[:4], client.APIKey[len(client.APIKey)-4:])
|
||||
}
|
||||
|
||||
// Step 1: 构建请求体(通过 hooks 实现动态分派)
|
||||
// Step 1: Build request body (via hooks for dynamic dispatch)
|
||||
requestBody := client.hooks.buildMCPRequestBody(systemPrompt, userPrompt)
|
||||
|
||||
// Step 2: 序列化请求体(通过 hooks 实现动态分派)
|
||||
// Step 2: Serialize request body (via hooks for dynamic dispatch)
|
||||
jsonData, err := client.hooks.marshalRequestBody(requestBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Step 3: 构建 URL(通过 hooks 实现动态分派)
|
||||
// Step 3: Build URL (via hooks for dynamic dispatch)
|
||||
url := client.hooks.buildUrl()
|
||||
client.logger.Infof("📡 [MCP %s] 请求 URL: %s", client.String(), url)
|
||||
client.logger.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
|
||||
|
||||
// Step 4: 创建 HTTP 请求(固定逻辑)
|
||||
// Step 4: Create HTTP request (fixed logic)
|
||||
req, err := client.hooks.buildRequest(url, jsonData)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Step 5: 发送 HTTP 请求(固定逻辑)
|
||||
// Step 5: Send HTTP request (fixed logic)
|
||||
resp, err := client.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("发送请求失败: %w", err)
|
||||
return "", fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Step 6: 读取响应体(固定逻辑)
|
||||
// Step 6: Read response body (fixed logic)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取响应失败: %w", err)
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
// Step 7: 检查 HTTP 状态码(固定逻辑)
|
||||
// Step 7: Check HTTP status code (fixed logic)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("API返回错误 (status %d): %s", resp.StatusCode, string(body))
|
||||
return "", fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Step 8: 解析响应(通过 hooks 实现动态分派)
|
||||
// Step 8: Parse response (via hooks for dynamic dispatch)
|
||||
result, err := client.hooks.parseMCPResponse(body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fail to parse AI server response: %w", err)
|
||||
@@ -316,10 +316,10 @@ func (client *Client) String() string {
|
||||
client.Provider, client.Model)
|
||||
}
|
||||
|
||||
// isRetryableError 判断错误是否可重试(网络错误、超时等)
|
||||
// isRetryableError determines if error is retryable (network errors, timeouts, etc.)
|
||||
func (client *Client) isRetryableError(err error) bool {
|
||||
errStr := err.Error()
|
||||
// 网络错误、超时、EOF等可以重试
|
||||
// Network errors, timeouts, EOF, etc. can be retried
|
||||
for _, retryable := range client.config.RetryableErrors {
|
||||
if strings.Contains(errStr, retryable) {
|
||||
return true
|
||||
@@ -329,18 +329,18 @@ func (client *Client) isRetryableError(err error) bool {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 构建器模式 API(高级功能)
|
||||
// Builder Pattern API (Advanced Features)
|
||||
// ============================================================
|
||||
|
||||
// CallWithRequest 使用 Request 对象调用 AI API(支持高级功能)
|
||||
// CallWithRequest calls AI API using Request object (supports advanced features)
|
||||
//
|
||||
// 此方法支持:
|
||||
// - 多轮对话历史
|
||||
// - 精细参数控制(temperature、top_p、penalties 等)
|
||||
// This method supports:
|
||||
// - Multi-turn conversation history
|
||||
// - Fine-grained parameter control (temperature, top_p, penalties, etc.)
|
||||
// - Function Calling / Tools
|
||||
// - 流式响应(未来支持)
|
||||
// - Streaming response (future support)
|
||||
//
|
||||
// 使用示例:
|
||||
// Usage example:
|
||||
// request := NewRequestBuilder().
|
||||
// WithSystemPrompt("You are helpful").
|
||||
// WithUserPrompt("Hello").
|
||||
@@ -349,93 +349,93 @@ func (client *Client) isRetryableError(err error) bool {
|
||||
// result, err := client.CallWithRequest(request)
|
||||
func (client *Client) CallWithRequest(req *Request) (string, error) {
|
||||
if client.APIKey == "" {
|
||||
return "", fmt.Errorf("AI API密钥未设置,请先调用 SetAPIKey")
|
||||
return "", fmt.Errorf("AI API key not set, please call SetAPIKey first")
|
||||
}
|
||||
|
||||
// 如果 Request 中没有设置 Model,使用 Client 的 Model
|
||||
// If Model is not set in Request, use Client's Model
|
||||
if req.Model == "" {
|
||||
req.Model = client.Model
|
||||
}
|
||||
|
||||
// 固定的重试流程
|
||||
// Fixed retry flow
|
||||
var lastErr error
|
||||
maxRetries := client.config.MaxRetries
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
if attempt > 1 {
|
||||
client.logger.Warnf("⚠️ AI API调用失败,正在重试 (%d/%d)...", attempt, maxRetries)
|
||||
client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
|
||||
}
|
||||
|
||||
// 调用单次请求
|
||||
// Call single request
|
||||
result, err := client.callWithRequest(req)
|
||||
if err == nil {
|
||||
if attempt > 1 {
|
||||
client.logger.Infof("✓ AI API重试成功")
|
||||
client.logger.Infof("✓ AI API retry succeeded")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
// 判断是否可重试
|
||||
// Check if error is retryable
|
||||
if !client.hooks.isRetryableError(err) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 重试前等待
|
||||
// Wait before retry
|
||||
if attempt < maxRetries {
|
||||
waitTime := client.config.RetryWaitBase * time.Duration(attempt)
|
||||
client.logger.Infof("⏳ 等待%v后重试...", waitTime)
|
||||
client.logger.Infof("⏳ Waiting %v before retry...", waitTime)
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
|
||||
return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// callWithRequest 单次调用 AI API(使用 Request 对象)
|
||||
// callWithRequest single AI API call (using Request object)
|
||||
func (client *Client) callWithRequest(req *Request) (string, error) {
|
||||
// 打印当前 AI 配置
|
||||
// Print current AI configuration
|
||||
client.logger.Infof("📡 [%s] Request AI Server with Builder: BaseURL: %s", client.String(), client.BaseURL)
|
||||
client.logger.Debugf("[%s] Messages count: %d", client.String(), len(req.Messages))
|
||||
|
||||
// 构建请求体(从 Request 对象)
|
||||
// Build request body (from Request object)
|
||||
requestBody := client.buildRequestBodyFromRequest(req)
|
||||
|
||||
// 序列化请求体
|
||||
// Serialize request body
|
||||
jsonData, err := client.hooks.marshalRequestBody(requestBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 构建 URL
|
||||
// Build URL
|
||||
url := client.hooks.buildUrl()
|
||||
client.logger.Infof("📡 [MCP %s] 请求 URL: %s", client.String(), url)
|
||||
client.logger.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
|
||||
|
||||
// 创建 HTTP 请求
|
||||
// Create HTTP request
|
||||
httpReq, err := client.hooks.buildRequest(url, jsonData)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// 发送 HTTP 请求
|
||||
// Send HTTP request
|
||||
resp, err := client.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("发送请求失败: %w", err)
|
||||
return "", fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应体
|
||||
// Read response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取响应失败: %w", err)
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
// 检查 HTTP 状态码
|
||||
// Check HTTP status code
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("API返回错误 (status %d): %s", resp.StatusCode, string(body))
|
||||
return "", fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
// Parse response
|
||||
result, err := client.hooks.parseMCPResponse(body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fail to parse AI server response: %w", err)
|
||||
@@ -444,9 +444,9 @@ func (client *Client) callWithRequest(req *Request) (string, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// buildRequestBodyFromRequest 从 Request 对象构建请求体
|
||||
// buildRequestBodyFromRequest builds request body from Request object
|
||||
func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any {
|
||||
// 转换 Message 为 API 格式
|
||||
// Convert Message to API format
|
||||
messages := make([]map[string]string, 0, len(req.Messages))
|
||||
for _, msg := range req.Messages {
|
||||
messages = append(messages, map[string]string{
|
||||
@@ -455,24 +455,24 @@ func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any {
|
||||
})
|
||||
}
|
||||
|
||||
// 构建基础请求体
|
||||
// Build basic request body
|
||||
requestBody := map[string]interface{}{
|
||||
"model": req.Model,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
// 添加可选参数(只添加非 nil 的参数)
|
||||
// Add optional parameters (only add non-nil parameters)
|
||||
if req.Temperature != nil {
|
||||
requestBody["temperature"] = *req.Temperature
|
||||
} else {
|
||||
// 如果 Request 中没有设置,使用 Client 的配置
|
||||
// If not set in Request, use Client's configuration
|
||||
requestBody["temperature"] = client.config.Temperature
|
||||
}
|
||||
|
||||
if req.MaxTokens != nil {
|
||||
requestBody["max_tokens"] = *req.MaxTokens
|
||||
} else {
|
||||
// 如果 Request 中没有设置,使用 Client 的 MaxTokens
|
||||
// If not set in Request, use Client's MaxTokens
|
||||
requestBody["max_tokens"] = client.MaxTokens
|
||||
}
|
||||
|
||||
|
||||
+21
-21
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// 测试 Client 创建和配置
|
||||
// Test Client Creation and Configuration
|
||||
// ============================================================
|
||||
|
||||
func TestNewClient_Default(t *testing.T) {
|
||||
@@ -72,7 +72,7 @@ func TestNewClient_WithOptions(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试 CallWithMessages
|
||||
// Test CallWithMessages
|
||||
// ============================================================
|
||||
|
||||
func TestClient_CallWithMessages_Success(t *testing.T) {
|
||||
@@ -97,7 +97,7 @@ func TestClient_CallWithMessages_Success(t *testing.T) {
|
||||
t.Errorf("expected 'AI response content', got '%s'", result)
|
||||
}
|
||||
|
||||
// 验证请求
|
||||
// Verify request
|
||||
requests := mockHTTP.GetRequests()
|
||||
if len(requests) != 1 {
|
||||
t.Errorf("expected 1 request, got %d", len(requests))
|
||||
@@ -123,7 +123,7 @@ func TestClient_CallWithMessages_NoAPIKey(t *testing.T) {
|
||||
t.Error("should error when API key is not set")
|
||||
}
|
||||
|
||||
if err.Error() != "AI API密钥未设置,请先调用 SetAPIKey" {
|
||||
if err.Error() != "AI API key not set, please call SetAPIKey first" {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -147,14 +147,14 @@ func TestClient_CallWithMessages_HTTPError(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试重试逻辑
|
||||
// Test Retry Logic
|
||||
// ============================================================
|
||||
|
||||
func TestClient_Retry_Success(t *testing.T) {
|
||||
mockHTTP := NewMockHTTPClient()
|
||||
mockLogger := NewMockLogger()
|
||||
|
||||
// 模拟:第一次失败,第二次成功
|
||||
// Simulate: first call fails, second call succeeds
|
||||
callCount := 0
|
||||
mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) {
|
||||
callCount++
|
||||
@@ -174,40 +174,40 @@ func TestClient_Retry_Success(t *testing.T) {
|
||||
WithMaxRetries(3),
|
||||
)
|
||||
|
||||
// 由于我们的 client 使用 hooks.call,需要特殊处理
|
||||
// 这里我们测试的是 CallWithMessages 会调用 retry 逻辑
|
||||
// Since our client uses hooks.call, need special handling
|
||||
// Here we test that CallWithMessages will invoke retry logic
|
||||
c := client.(*Client)
|
||||
|
||||
// 临时修改重试等待时间为 0 以加速测试
|
||||
// Temporarily modify retry wait time to 0 to speed up test
|
||||
oldRetries := MaxRetryTimes
|
||||
MaxRetryTimes = 3
|
||||
defer func() { MaxRetryTimes = oldRetries }()
|
||||
|
||||
_, err := c.CallWithMessages("system", "user")
|
||||
|
||||
// 第一次失败(connection reset),第二次成功,但是响应格式不对,会失败
|
||||
// 但至少验证了重试逻辑被触发
|
||||
// First fails (connection reset), second succeeds, but response format is wrong, will fail
|
||||
// But at least verify retry logic was triggered
|
||||
if callCount < 2 {
|
||||
t.Errorf("should retry, got %d calls", callCount)
|
||||
}
|
||||
|
||||
// 检查日志中是否有重试信息
|
||||
// Check if there's retry information in logs
|
||||
logs := mockLogger.GetLogsByLevel("WARN")
|
||||
hasRetryLog := false
|
||||
for _, log := range logs {
|
||||
if log.Message == "⚠️ AI API调用失败,正在重试 (2/3)..." {
|
||||
if log.Message == "⚠️ AI API call failed, retrying (2/3)..." {
|
||||
hasRetryLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasRetryLog && callCount >= 2 {
|
||||
// 如果确实重试了,应该有警告日志
|
||||
// 但由于我们的测试设置,可能不会触发,所以这里只是检查
|
||||
// If retry was indeed attempted, there should be warning logs
|
||||
// But due to our test setup, it may not trigger, so just check here
|
||||
t.Log("Retry was attempted")
|
||||
}
|
||||
|
||||
_ = err // 忽略错误,我们主要测试重试逻辑被触发
|
||||
_ = err // Ignore error, we mainly test retry logic was triggered
|
||||
}
|
||||
|
||||
func TestClient_Retry_NonRetryableError(t *testing.T) {
|
||||
@@ -227,7 +227,7 @@ func TestClient_Retry_NonRetryableError(t *testing.T) {
|
||||
t.Error("should error")
|
||||
}
|
||||
|
||||
// 验证没有重试(因为 400 不是可重试错误)
|
||||
// Verify no retry (because 400 is not a retryable error)
|
||||
requests := mockHTTP.GetRequests()
|
||||
if len(requests) != 1 {
|
||||
t.Errorf("should not retry for 400 error, got %d requests", len(requests))
|
||||
@@ -235,7 +235,7 @@ func TestClient_Retry_NonRetryableError(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试钩子方法
|
||||
// Test Hook Methods
|
||||
// ============================================================
|
||||
|
||||
func TestClient_BuildMCPRequestBody(t *testing.T) {
|
||||
@@ -368,7 +368,7 @@ func TestClient_IsRetryableError(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试 SetTimeout
|
||||
// Test SetTimeout
|
||||
// ============================================================
|
||||
|
||||
func TestClient_SetTimeout(t *testing.T) {
|
||||
@@ -384,7 +384,7 @@ func TestClient_SetTimeout(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试 String 方法
|
||||
// Test String Method
|
||||
// ============================================================
|
||||
|
||||
func TestClient_String(t *testing.T) {
|
||||
@@ -404,7 +404,7 @@ func TestClient_String(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
// Helper function
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr))
|
||||
}
|
||||
|
||||
+11
-11
@@ -9,36 +9,36 @@ import (
|
||||
"nofx/logger"
|
||||
)
|
||||
|
||||
// Config 客户端配置(集中管理所有配置)
|
||||
// Config client configuration (centralized management of all configurations)
|
||||
type Config struct {
|
||||
// Provider 配置
|
||||
// Provider configuration
|
||||
Provider string
|
||||
APIKey string
|
||||
BaseURL string
|
||||
Model string
|
||||
|
||||
// 行为配置
|
||||
// Behavior configuration
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
UseFullURL bool
|
||||
|
||||
// 重试配置
|
||||
// Retry configuration
|
||||
MaxRetries int
|
||||
RetryWaitBase time.Duration
|
||||
RetryableErrors []string
|
||||
|
||||
// 超时配置
|
||||
// Timeout configuration
|
||||
Timeout time.Duration
|
||||
|
||||
// 依赖注入
|
||||
// Dependency injection
|
||||
Logger Logger
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
// DefaultConfig 返回默认配置
|
||||
// DefaultConfig returns default configuration
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
// 默认值
|
||||
// Default values
|
||||
MaxTokens: getEnvInt("AI_MAX_TOKENS", 2000),
|
||||
Temperature: MCPClientTemperature,
|
||||
MaxRetries: MaxRetryTimes,
|
||||
@@ -46,13 +46,13 @@ func DefaultConfig() *Config {
|
||||
Timeout: DefaultTimeout,
|
||||
RetryableErrors: retryableErrors,
|
||||
|
||||
// 默认依赖(使用全局 logger)
|
||||
// Default dependencies (use global logger)
|
||||
Logger: logger.NewMCPLogger(),
|
||||
HTTPClient: &http.Client{Timeout: DefaultTimeout},
|
||||
}
|
||||
}
|
||||
|
||||
// getEnvInt 从环境变量读取整数,失败则返回默认值
|
||||
// getEnvInt reads integer from environment variable, returns default value if failed
|
||||
func getEnvInt(key string, defaultValue int) int {
|
||||
if val := os.Getenv(key); val != "" {
|
||||
if parsed, err := strconv.Atoi(val); err == nil && parsed > 0 {
|
||||
@@ -62,7 +62,7 @@ func getEnvInt(key string, defaultValue int) int {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// getEnvString 从环境变量读取字符串,为空则返回默认值
|
||||
// getEnvString reads string from environment variable, returns default value if empty
|
||||
func getEnvString(key string, defaultValue string) string {
|
||||
if val := os.Getenv(key); val != "" {
|
||||
return val
|
||||
|
||||
+38
-38
@@ -11,49 +11,49 @@ import (
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// 测试 Config 字段真正被使用(验证问题2修复)
|
||||
// Test Config Fields Are Actually Used (Verify Issue 2 Fix)
|
||||
// ============================================================
|
||||
|
||||
func TestConfig_MaxRetries_IsUsed(t *testing.T) {
|
||||
mockHTTP := NewMockHTTPClient()
|
||||
mockLogger := NewMockLogger()
|
||||
|
||||
// 设置 HTTP 客户端返回错误
|
||||
// Set HTTP client to return error
|
||||
callCount := 0
|
||||
mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) {
|
||||
callCount++
|
||||
return nil, errors.New("connection reset")
|
||||
}
|
||||
|
||||
// 创建客户端并设置自定义重试次数为 5
|
||||
// Create client and set custom retry count to 5
|
||||
client := NewClient(
|
||||
WithHTTPClient(mockHTTP.ToHTTPClient()),
|
||||
WithLogger(mockLogger),
|
||||
WithAPIKey("sk-test-key"),
|
||||
WithMaxRetries(5), // ✅ 设置重试5次
|
||||
WithMaxRetries(5), // Set to retry 5 times
|
||||
)
|
||||
|
||||
// 调用 API(应该失败)
|
||||
// Call API (should fail)
|
||||
_, err := client.CallWithMessages("system", "user")
|
||||
|
||||
if err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
|
||||
// 验证确实重试了5次(而不是默认的3次)
|
||||
// Verify indeed retried 5 times (not the default 3 times)
|
||||
if callCount != 5 {
|
||||
t.Errorf("expected 5 retry attempts (from WithMaxRetries(5)), got %d", callCount)
|
||||
}
|
||||
|
||||
// 验证日志中显示正确的重试次数
|
||||
// Verify logs show correct retry count
|
||||
logs := mockLogger.GetLogsByLevel("WARN")
|
||||
expectedWarningCount := 4 // 第2、3、4、5次重试时会打印警告
|
||||
expectedWarningCount := 4 // Warnings will be printed on 2nd, 3rd, 4th, 5th retry
|
||||
actualWarningCount := 0
|
||||
for _, log := range logs {
|
||||
if log.Message == "⚠️ AI API调用失败,正在重试 (2/5)..." ||
|
||||
log.Message == "⚠️ AI API调用失败,正在重试 (3/5)..." ||
|
||||
log.Message == "⚠️ AI API调用失败,正在重试 (4/5)..." ||
|
||||
log.Message == "⚠️ AI API调用失败,正在重试 (5/5)..." {
|
||||
if log.Message == "⚠️ AI API call failed, retrying (2/5)..." ||
|
||||
log.Message == "⚠️ AI API call failed, retrying (3/5)..." ||
|
||||
log.Message == "⚠️ AI API call failed, retrying (4/5)..." ||
|
||||
log.Message == "⚠️ AI API call failed, retrying (5/5)..." {
|
||||
actualWarningCount++
|
||||
}
|
||||
}
|
||||
@@ -73,20 +73,20 @@ func TestConfig_Temperature_IsUsed(t *testing.T) {
|
||||
|
||||
customTemperature := 0.8
|
||||
|
||||
// 创建客户端并设置自定义 temperature
|
||||
// Create client and set custom temperature
|
||||
client := NewClient(
|
||||
WithHTTPClient(mockHTTP.ToHTTPClient()),
|
||||
WithLogger(mockLogger),
|
||||
WithAPIKey("sk-test-key"),
|
||||
WithTemperature(customTemperature), // ✅ 设置自定义 temperature
|
||||
WithTemperature(customTemperature), // Set custom temperature
|
||||
)
|
||||
|
||||
c := client.(*Client)
|
||||
|
||||
// 构建请求体
|
||||
// Build request body
|
||||
requestBody := c.buildMCPRequestBody("system", "user")
|
||||
|
||||
// 验证 temperature 字段
|
||||
// Verify temperature field
|
||||
temp, ok := requestBody["temperature"].(float64)
|
||||
if !ok {
|
||||
t.Fatal("temperature should be float64")
|
||||
@@ -96,26 +96,26 @@ func TestConfig_Temperature_IsUsed(t *testing.T) {
|
||||
t.Errorf("expected temperature %f (from WithTemperature), got %f", customTemperature, temp)
|
||||
}
|
||||
|
||||
// 也可以通过实际 HTTP 请求验证
|
||||
// Can also verify through actual HTTP request
|
||||
_, err := client.CallWithMessages("system", "user")
|
||||
if err != nil {
|
||||
t.Fatalf("should not error: %v", err)
|
||||
}
|
||||
|
||||
// 检查发送的请求体
|
||||
// Check sent request body
|
||||
requests := mockHTTP.GetRequests()
|
||||
if len(requests) != 1 {
|
||||
t.Fatalf("expected 1 request, got %d", len(requests))
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
// Parse request body
|
||||
var body map[string]interface{}
|
||||
decoder := json.NewDecoder(requests[0].Body)
|
||||
if err := decoder.Decode(&body); err != nil {
|
||||
t.Fatalf("failed to decode request body: %v", err)
|
||||
}
|
||||
|
||||
// 验证 temperature
|
||||
// Verify temperature
|
||||
if body["temperature"] != customTemperature {
|
||||
t.Errorf("expected temperature %f in HTTP request, got %v", customTemperature, body["temperature"])
|
||||
}
|
||||
@@ -125,18 +125,18 @@ func TestConfig_RetryWaitBase_IsUsed(t *testing.T) {
|
||||
mockHTTP := NewMockHTTPClient()
|
||||
mockLogger := NewMockLogger()
|
||||
|
||||
// 设置成功响应(在 ResponseFunc 之前)
|
||||
// Set success response (before ResponseFunc)
|
||||
mockHTTP.SetSuccessResponse("AI response")
|
||||
|
||||
// 设置 HTTP 客户端前2次返回错误,第3次成功
|
||||
// Set HTTP client to return error first 2 times, success on 3rd time
|
||||
callCount := 0
|
||||
successResponse := mockHTTP.Response // 保存成功响应字符串
|
||||
successResponse := mockHTTP.Response // Save success response string
|
||||
mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) {
|
||||
callCount++
|
||||
if callCount <= 2 {
|
||||
return nil, errors.New("timeout exceeded")
|
||||
}
|
||||
// 第3次返回成功响应
|
||||
// 3rd time return success response
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(bytes.NewBufferString(successResponse)),
|
||||
@@ -144,27 +144,27 @@ func TestConfig_RetryWaitBase_IsUsed(t *testing.T) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 设置自定义重试等待基数为 1 秒(而不是默认的 2 秒)
|
||||
// Set custom retry wait base to 1 second (instead of default 2 seconds)
|
||||
customWaitBase := 1 * time.Second
|
||||
|
||||
client := NewClient(
|
||||
WithHTTPClient(mockHTTP.ToHTTPClient()),
|
||||
WithLogger(mockLogger),
|
||||
WithAPIKey("sk-test-key"),
|
||||
WithRetryWaitBase(customWaitBase), // ✅ 设置自定义等待时间
|
||||
WithRetryWaitBase(customWaitBase), // Set custom wait time
|
||||
WithMaxRetries(3),
|
||||
)
|
||||
|
||||
// 记录开始时间
|
||||
// Record start time
|
||||
start := time.Now()
|
||||
|
||||
// 调用 API
|
||||
// Call API
|
||||
_, err := client.CallWithMessages("system", "user")
|
||||
|
||||
// 记录结束时间
|
||||
// Record end time
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// 第3次成功,但前面失败了2次
|
||||
// 3rd time succeeds, but failed 2 times before
|
||||
if err != nil {
|
||||
t.Fatalf("should succeed on 3rd attempt, got error: %v", err)
|
||||
}
|
||||
@@ -173,10 +173,10 @@ func TestConfig_RetryWaitBase_IsUsed(t *testing.T) {
|
||||
t.Errorf("expected 3 attempts, got %d", callCount)
|
||||
}
|
||||
|
||||
// 验证等待时间
|
||||
// 第1次失败后等待 1s (customWaitBase * 1)
|
||||
// 第2次失败后等待 2s (customWaitBase * 2)
|
||||
// 总等待时间应该约为 3s (允许一些误差)
|
||||
// Verify wait time
|
||||
// After 1st failure wait 1s (customWaitBase * 1)
|
||||
// After 2nd failure wait 2s (customWaitBase * 2)
|
||||
// Total wait time should be about 3s (allow some error)
|
||||
expectedWait := 3 * time.Second
|
||||
tolerance := 200 * time.Millisecond
|
||||
|
||||
@@ -189,7 +189,7 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) {
|
||||
mockHTTP := NewMockHTTPClient()
|
||||
mockLogger := NewMockLogger()
|
||||
|
||||
// 自定义可重试错误列表(只包含 "custom error")
|
||||
// Custom retryable error list (only contains "custom error")
|
||||
customRetryableErrors := []string{"custom error"}
|
||||
|
||||
client := NewClient(
|
||||
@@ -200,7 +200,7 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) {
|
||||
|
||||
c := client.(*Client)
|
||||
|
||||
// 修改 config 的 RetryableErrors(暂时没有 WithRetryableErrors 选项)
|
||||
// Modify config's RetryableErrors (no WithRetryableErrors option yet)
|
||||
c.config.RetryableErrors = customRetryableErrors
|
||||
|
||||
tests := []struct {
|
||||
@@ -236,14 +236,14 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试默认值
|
||||
// Test Default Values
|
||||
// ============================================================
|
||||
|
||||
func TestConfig_DefaultValues(t *testing.T) {
|
||||
client := NewClient()
|
||||
c := client.(*Client)
|
||||
|
||||
// 验证默认值
|
||||
// Verify default values
|
||||
if c.config.MaxRetries != 3 {
|
||||
t.Errorf("default MaxRetries should be 3, got %d", c.config.MaxRetries)
|
||||
}
|
||||
|
||||
+15
-15
@@ -14,45 +14,45 @@ type DeepSeekClient struct {
|
||||
*Client
|
||||
}
|
||||
|
||||
// NewDeepSeekClient 创建 DeepSeek 客户端(向前兼容)
|
||||
// NewDeepSeekClient creates DeepSeek client (backward compatible)
|
||||
//
|
||||
// Deprecated: 推荐使用 NewDeepSeekClientWithOptions 以获得更好的灵活性
|
||||
// Deprecated: Recommend using NewDeepSeekClientWithOptions for better flexibility
|
||||
func NewDeepSeekClient() AIClient {
|
||||
return NewDeepSeekClientWithOptions()
|
||||
}
|
||||
|
||||
// NewDeepSeekClientWithOptions 创建 DeepSeek 客户端(支持选项模式)
|
||||
// NewDeepSeekClientWithOptions creates DeepSeek client (supports options pattern)
|
||||
//
|
||||
// 使用示例:
|
||||
// // 基础用法
|
||||
// Usage examples:
|
||||
// // Basic usage
|
||||
// client := mcp.NewDeepSeekClientWithOptions()
|
||||
//
|
||||
// // 自定义配置
|
||||
// // Custom configuration
|
||||
// client := mcp.NewDeepSeekClientWithOptions(
|
||||
// mcp.WithAPIKey("sk-xxx"),
|
||||
// mcp.WithLogger(customLogger),
|
||||
// mcp.WithTimeout(60*time.Second),
|
||||
// )
|
||||
func NewDeepSeekClientWithOptions(opts ...ClientOption) AIClient {
|
||||
// 1. 创建 DeepSeek 预设选项
|
||||
// 1. Create DeepSeek preset options
|
||||
deepseekOpts := []ClientOption{
|
||||
WithProvider(ProviderDeepSeek),
|
||||
WithModel(DefaultDeepSeekModel),
|
||||
WithBaseURL(DefaultDeepSeekBaseURL),
|
||||
}
|
||||
|
||||
// 2. 合并用户选项(用户选项优先级更高)
|
||||
// 2. Merge user options (user options have higher priority)
|
||||
allOpts := append(deepseekOpts, opts...)
|
||||
|
||||
// 3. 创建基础客户端
|
||||
// 3. Create base client
|
||||
baseClient := NewClient(allOpts...).(*Client)
|
||||
|
||||
// 4. 创建 DeepSeek 客户端
|
||||
// 4. Create DeepSeek client
|
||||
dsClient := &DeepSeekClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
// 5. 设置 hooks 指向 DeepSeekClient(实现动态分派)
|
||||
// 5. Set hooks to point to DeepSeekClient (implement dynamic dispatch)
|
||||
baseClient.hooks = dsClient
|
||||
|
||||
return dsClient
|
||||
@@ -66,15 +66,15 @@ func (dsClient *DeepSeekClient) SetAPIKey(apiKey string, customURL string, custo
|
||||
}
|
||||
if customURL != "" {
|
||||
dsClient.BaseURL = customURL
|
||||
dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s", customURL)
|
||||
dsClient.logger.Infof("🔧 [MCP] DeepSeek using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用默认 BaseURL: %s", dsClient.BaseURL)
|
||||
dsClient.logger.Infof("🔧 [MCP] DeepSeek using default BaseURL: %s", dsClient.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
dsClient.Model = customModel
|
||||
dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用自定义 Model: %s", customModel)
|
||||
dsClient.logger.Infof("🔧 [MCP] DeepSeek using custom Model: %s", customModel)
|
||||
} else {
|
||||
dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用默认 Model: %s", dsClient.Model)
|
||||
dsClient.logger.Infof("🔧 [MCP] DeepSeek using default Model: %s", dsClient.Model)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+22
-22
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// 测试 DeepSeekClient 创建和配置
|
||||
// Test DeepSeekClient Creation and Configuration
|
||||
// ============================================================
|
||||
|
||||
func TestNewDeepSeekClient_Default(t *testing.T) {
|
||||
@@ -16,13 +16,13 @@ func TestNewDeepSeekClient_Default(t *testing.T) {
|
||||
t.Fatal("client should not be nil")
|
||||
}
|
||||
|
||||
// 类型断言检查
|
||||
// Type assertion check
|
||||
dsClient, ok := client.(*DeepSeekClient)
|
||||
if !ok {
|
||||
t.Fatal("client should be *DeepSeekClient")
|
||||
}
|
||||
|
||||
// 验证默认值
|
||||
// Verify default values
|
||||
if dsClient.Provider != ProviderDeepSeek {
|
||||
t.Errorf("Provider should be '%s', got '%s'", ProviderDeepSeek, dsClient.Provider)
|
||||
}
|
||||
@@ -58,7 +58,7 @@ func TestNewDeepSeekClientWithOptions(t *testing.T) {
|
||||
|
||||
dsClient := client.(*DeepSeekClient)
|
||||
|
||||
// 验证自定义选项被应用
|
||||
// Verify custom options are applied
|
||||
if dsClient.logger != mockLogger {
|
||||
t.Error("logger should be set from option")
|
||||
}
|
||||
@@ -75,7 +75,7 @@ func TestNewDeepSeekClientWithOptions(t *testing.T) {
|
||||
t.Error("MaxTokens should be 4000")
|
||||
}
|
||||
|
||||
// 验证 DeepSeek 默认值仍然保留
|
||||
// Verify DeepSeek default values are retained
|
||||
if dsClient.Provider != ProviderDeepSeek {
|
||||
t.Errorf("Provider should still be '%s'", ProviderDeepSeek)
|
||||
}
|
||||
@@ -86,7 +86,7 @@ func TestNewDeepSeekClientWithOptions(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试 SetAPIKey
|
||||
// Test SetAPIKey
|
||||
// ============================================================
|
||||
|
||||
func TestDeepSeekClient_SetAPIKey(t *testing.T) {
|
||||
@@ -97,20 +97,20 @@ func TestDeepSeekClient_SetAPIKey(t *testing.T) {
|
||||
|
||||
dsClient := client.(*DeepSeekClient)
|
||||
|
||||
// 测试设置 API Key(默认 URL 和 Model)
|
||||
// Test setting API Key (default URL and Model)
|
||||
dsClient.SetAPIKey("sk-test-key-12345678", "", "")
|
||||
|
||||
if dsClient.APIKey != "sk-test-key-12345678" {
|
||||
t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", dsClient.APIKey)
|
||||
}
|
||||
|
||||
// 验证日志记录
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
if len(logs) == 0 {
|
||||
t.Error("should have logged API key setting")
|
||||
}
|
||||
|
||||
// 验证 BaseURL 和 Model 保持默认
|
||||
// Verify BaseURL and Model remain default
|
||||
if dsClient.BaseURL != DefaultDeepSeekBaseURL {
|
||||
t.Error("BaseURL should remain default")
|
||||
}
|
||||
@@ -135,11 +135,11 @@ func TestDeepSeekClient_SetAPIKey_WithCustomURL(t *testing.T) {
|
||||
t.Errorf("BaseURL should be '%s', got '%s'", customURL, dsClient.BaseURL)
|
||||
}
|
||||
|
||||
// 验证日志记录
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
hasCustomURLLog := false
|
||||
for _, log := range logs {
|
||||
if log.Format == "🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s" {
|
||||
if log.Format == "🔧 [MCP] DeepSeek using custom BaseURL: %s" {
|
||||
hasCustomURLLog = true
|
||||
break
|
||||
}
|
||||
@@ -165,11 +165,11 @@ func TestDeepSeekClient_SetAPIKey_WithCustomModel(t *testing.T) {
|
||||
t.Errorf("Model should be '%s', got '%s'", customModel, dsClient.Model)
|
||||
}
|
||||
|
||||
// 验证日志记录
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
hasCustomModelLog := false
|
||||
for _, log := range logs {
|
||||
if log.Format == "🔧 [MCP] DeepSeek 使用自定义 Model: %s" {
|
||||
if log.Format == "🔧 [MCP] DeepSeek using custom Model: %s" {
|
||||
hasCustomModelLog = true
|
||||
break
|
||||
}
|
||||
@@ -181,7 +181,7 @@ func TestDeepSeekClient_SetAPIKey_WithCustomModel(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试集成功能
|
||||
// Test Integration Features
|
||||
// ============================================================
|
||||
|
||||
func TestDeepSeekClient_CallWithMessages_Success(t *testing.T) {
|
||||
@@ -205,7 +205,7 @@ func TestDeepSeekClient_CallWithMessages_Success(t *testing.T) {
|
||||
t.Errorf("expected 'DeepSeek AI response', got '%s'", result)
|
||||
}
|
||||
|
||||
// 验证请求
|
||||
// Verify request
|
||||
requests := mockHTTP.GetRequests()
|
||||
if len(requests) != 1 {
|
||||
t.Fatalf("expected 1 request, got %d", len(requests))
|
||||
@@ -213,19 +213,19 @@ func TestDeepSeekClient_CallWithMessages_Success(t *testing.T) {
|
||||
|
||||
req := requests[0]
|
||||
|
||||
// 验证 URL
|
||||
// Verify URL
|
||||
expectedURL := DefaultDeepSeekBaseURL + "/chat/completions"
|
||||
if req.URL.String() != expectedURL {
|
||||
t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String())
|
||||
}
|
||||
|
||||
// 验证 Authorization header
|
||||
// Verify Authorization header
|
||||
authHeader := req.Header.Get("Authorization")
|
||||
if authHeader != "Bearer sk-test-key" {
|
||||
t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader)
|
||||
}
|
||||
|
||||
// 验证 Content-Type
|
||||
// Verify Content-Type
|
||||
if req.Header.Get("Content-Type") != "application/json" {
|
||||
t.Error("Content-Type should be application/json")
|
||||
}
|
||||
@@ -242,7 +242,7 @@ func TestDeepSeekClient_Timeout(t *testing.T) {
|
||||
t.Errorf("expected timeout 30s, got %v", dsClient.httpClient.Timeout)
|
||||
}
|
||||
|
||||
// 测试 SetTimeout
|
||||
// Test SetTimeout
|
||||
client.SetTimeout(60 * time.Second)
|
||||
|
||||
if dsClient.httpClient.Timeout != 60*time.Second {
|
||||
@@ -251,19 +251,19 @@ func TestDeepSeekClient_Timeout(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试 hooks 机制
|
||||
// Test hooks Mechanism
|
||||
// ============================================================
|
||||
|
||||
func TestDeepSeekClient_HooksIntegration(t *testing.T) {
|
||||
client := NewDeepSeekClientWithOptions()
|
||||
dsClient := client.(*DeepSeekClient)
|
||||
|
||||
// 验证 hooks 指向 dsClient 自己(实现多态)
|
||||
// Verify hooks point to dsClient itself (implements polymorphism)
|
||||
if dsClient.hooks != dsClient {
|
||||
t.Error("hooks should point to dsClient for polymorphism")
|
||||
}
|
||||
|
||||
// 验证 buildUrl 使用 DeepSeek 配置
|
||||
// Verify buildUrl uses DeepSeek configuration
|
||||
url := dsClient.buildUrl()
|
||||
expectedURL := DefaultDeepSeekBaseURL + "/chat/completions"
|
||||
if url != expectedURL {
|
||||
|
||||
+44
-44
@@ -9,21 +9,21 @@ import (
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// 示例 1: 基础用法(向前兼容)
|
||||
// Example 1: Basic Usage (Backward Compatible)
|
||||
// ============================================================
|
||||
|
||||
func Example_backward_compatible() {
|
||||
// ✅ 旧代码继续工作,无需修改
|
||||
// Old code continues to work without modification
|
||||
client := mcp.New()
|
||||
client.SetAPIKey("sk-xxx", "https://api.custom.com", "gpt-4")
|
||||
|
||||
// 使用
|
||||
// Usage
|
||||
result, _ := client.CallWithMessages("system prompt", "user prompt")
|
||||
fmt.Println(result)
|
||||
}
|
||||
|
||||
func Example_deepseek_backward_compatible() {
|
||||
// ✅ DeepSeek 旧代码继续工作
|
||||
// DeepSeek old code continues to work
|
||||
client := mcp.NewDeepSeekClient()
|
||||
client.SetAPIKey("sk-xxx", "", "")
|
||||
|
||||
@@ -32,19 +32,19 @@ func Example_deepseek_backward_compatible() {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 示例 2: 新的推荐用法(选项模式)
|
||||
// Example 2: New Recommended Usage (Options Pattern)
|
||||
// ============================================================
|
||||
|
||||
func Example_new_client_basic() {
|
||||
// 使用默认配置
|
||||
// Use default configuration
|
||||
client := mcp.NewClient()
|
||||
|
||||
// 使用 DeepSeek
|
||||
// Use DeepSeek
|
||||
client = mcp.NewClient(
|
||||
mcp.WithDeepSeekConfig("sk-xxx"),
|
||||
)
|
||||
|
||||
// 使用 Qwen
|
||||
// Use Qwen
|
||||
client = mcp.NewClient(
|
||||
mcp.WithQwenConfig("sk-xxx"),
|
||||
)
|
||||
@@ -53,7 +53,7 @@ func Example_new_client_basic() {
|
||||
}
|
||||
|
||||
func Example_new_client_with_options() {
|
||||
// 组合多个选项
|
||||
// Combine multiple options
|
||||
client := mcp.NewClient(
|
||||
mcp.WithDeepSeekConfig("sk-xxx"),
|
||||
mcp.WithTimeout(60*time.Second),
|
||||
@@ -67,10 +67,10 @@ func Example_new_client_with_options() {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 示例 3: 自定义日志器
|
||||
// Example 3: Custom Logger
|
||||
// ============================================================
|
||||
|
||||
// CustomLogger 自定义日志器示例
|
||||
// CustomLogger custom logger example
|
||||
type CustomLogger struct{}
|
||||
|
||||
func (l *CustomLogger) Debugf(format string, args ...any) {
|
||||
@@ -90,7 +90,7 @@ func (l *CustomLogger) Errorf(format string, args ...any) {
|
||||
}
|
||||
|
||||
func Example_custom_logger() {
|
||||
// 使用自定义日志器
|
||||
// Use custom logger
|
||||
customLogger := &CustomLogger{}
|
||||
|
||||
client := mcp.NewClient(
|
||||
@@ -103,7 +103,7 @@ func Example_custom_logger() {
|
||||
}
|
||||
|
||||
func Example_no_logger_for_testing() {
|
||||
// 测试时禁用日志
|
||||
// Disable logging during testing
|
||||
client := mcp.NewClient(
|
||||
mcp.WithLogger(mcp.NewNoopLogger()),
|
||||
)
|
||||
@@ -113,16 +113,16 @@ func Example_no_logger_for_testing() {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 示例 4: 自定义 HTTP 客户端
|
||||
// Example 4: Custom HTTP Client
|
||||
// ============================================================
|
||||
|
||||
func Example_custom_http_client() {
|
||||
// 自定义 HTTP 客户端(添加代理、TLS等)
|
||||
// Custom HTTP client (add proxy, TLS, etc.)
|
||||
customHTTP := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
// 自定义 TLS、连接池等
|
||||
// Custom TLS, connection pool, etc.
|
||||
},
|
||||
}
|
||||
|
||||
@@ -136,16 +136,16 @@ func Example_custom_http_client() {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 示例 5: DeepSeek 客户端(新 API)
|
||||
// Example 5: DeepSeek Client (New API)
|
||||
// ============================================================
|
||||
|
||||
func Example_deepseek_new_api() {
|
||||
// 基础用法
|
||||
// Basic usage
|
||||
client := mcp.NewDeepSeekClientWithOptions(
|
||||
mcp.WithAPIKey("sk-xxx"),
|
||||
)
|
||||
|
||||
// 高级用法
|
||||
// Advanced usage
|
||||
client = mcp.NewDeepSeekClientWithOptions(
|
||||
mcp.WithAPIKey("sk-xxx"),
|
||||
mcp.WithLogger(&CustomLogger{}),
|
||||
@@ -158,16 +158,16 @@ func Example_deepseek_new_api() {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 示例 6: Qwen 客户端(新 API)
|
||||
// Example 6: Qwen Client (New API)
|
||||
// ============================================================
|
||||
|
||||
func Example_qwen_new_api() {
|
||||
// 基础用法
|
||||
// Basic usage
|
||||
client := mcp.NewQwenClientWithOptions(
|
||||
mcp.WithAPIKey("sk-xxx"),
|
||||
)
|
||||
|
||||
// 高级用法
|
||||
// Advanced usage
|
||||
client = mcp.NewQwenClientWithOptions(
|
||||
mcp.WithAPIKey("sk-xxx"),
|
||||
mcp.WithLogger(&CustomLogger{}),
|
||||
@@ -179,18 +179,18 @@ func Example_qwen_new_api() {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 示例 7: 在 trader/auto_trader.go 中的迁移示例
|
||||
// Example 7: Migration Example in trader/auto_trader.go
|
||||
// ============================================================
|
||||
|
||||
func Example_trader_migration() {
|
||||
// === 旧代码(继续工作)===
|
||||
// Old code (continues to work)
|
||||
oldStyleClient := func(apiKey, customURL, customModel string) mcp.AIClient {
|
||||
client := mcp.NewDeepSeekClient()
|
||||
client.SetAPIKey(apiKey, customURL, customModel)
|
||||
return client
|
||||
}
|
||||
|
||||
// === 新代码(推荐)===
|
||||
// New code (recommended)
|
||||
newStyleClient := func(apiKey, customURL, customModel string) mcp.AIClient {
|
||||
opts := []mcp.ClientOption{
|
||||
mcp.WithAPIKey(apiKey),
|
||||
@@ -207,37 +207,37 @@ func Example_trader_migration() {
|
||||
return mcp.NewDeepSeekClientWithOptions(opts...)
|
||||
}
|
||||
|
||||
// 两种方式都能工作
|
||||
// Both approaches work
|
||||
_ = oldStyleClient("sk-xxx", "", "")
|
||||
_ = newStyleClient("sk-xxx", "", "")
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 示例 8: 测试场景
|
||||
// Example 8: Testing Scenarios
|
||||
// ============================================================
|
||||
|
||||
// MockHTTPClient Mock HTTP 客户端
|
||||
// MockHTTPClient Mock HTTP client
|
||||
type MockHTTPClient struct {
|
||||
Response string
|
||||
}
|
||||
|
||||
func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
// 返回预设的响应
|
||||
// Return preset response
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: nil, // 实际测试中需要实现
|
||||
Body: nil, // Need to implement in actual tests
|
||||
}, nil
|
||||
}
|
||||
|
||||
func Example_testing_with_mock() {
|
||||
// 测试时使用 Mock
|
||||
// Use Mock during testing
|
||||
// mockHTTP := &MockHTTPClient{
|
||||
// Response: `{"choices":[{"message":{"content":"test response"}}]}`,
|
||||
// }
|
||||
|
||||
client := mcp.NewClient(
|
||||
// mcp.WithHTTPClient(mockHTTP), // 实际测试中使用 mockHTTP
|
||||
mcp.WithLogger(mcp.NewNoopLogger()), // 禁用日志
|
||||
// mcp.WithHTTPClient(mockHTTP), // Use mockHTTP in actual tests
|
||||
mcp.WithLogger(mcp.NewNoopLogger()), // Disable logging
|
||||
)
|
||||
|
||||
result, _ := client.CallWithMessages("system", "user")
|
||||
@@ -245,20 +245,20 @@ func Example_testing_with_mock() {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 示例 9: 环境特定配置
|
||||
// Example 9: Environment-Specific Configuration
|
||||
// ============================================================
|
||||
|
||||
func Example_environment_specific() {
|
||||
// 开发环境:详细日志
|
||||
// Development environment: detailed logging
|
||||
devClient := mcp.NewClient(
|
||||
mcp.WithDeepSeekConfig("sk-xxx"),
|
||||
mcp.WithLogger(&CustomLogger{}), // 详细日志
|
||||
mcp.WithLogger(&CustomLogger{}), // Detailed logging
|
||||
)
|
||||
|
||||
// 生产环境:结构化日志 + 超时保护
|
||||
// Production environment: structured logging + timeout protection
|
||||
prodClient := mcp.NewClient(
|
||||
mcp.WithDeepSeekConfig("sk-xxx"),
|
||||
// mcp.WithLogger(&ZapLogger{}), // 生产级日志
|
||||
// mcp.WithLogger(&ZapLogger{}), // Production-grade logging
|
||||
mcp.WithTimeout(30*time.Second),
|
||||
mcp.WithMaxRetries(3),
|
||||
)
|
||||
@@ -268,11 +268,11 @@ func Example_environment_specific() {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 示例 10: 完整实战示例
|
||||
// Example 10: Complete Real-World Example
|
||||
// ============================================================
|
||||
|
||||
func Example_real_world_usage() {
|
||||
// 创建带有完整配置的客户端
|
||||
// Create client with complete configuration
|
||||
client := mcp.NewDeepSeekClientWithOptions(
|
||||
mcp.WithAPIKey("sk-xxxxxxxxxx"),
|
||||
mcp.WithTimeout(60*time.Second),
|
||||
@@ -282,9 +282,9 @@ func Example_real_world_usage() {
|
||||
mcp.WithLogger(&CustomLogger{}),
|
||||
)
|
||||
|
||||
// 使用客户端
|
||||
systemPrompt := "你是一个专业的量化交易顾问"
|
||||
userPrompt := "分析 BTC 当前走势"
|
||||
// Use client
|
||||
systemPrompt := "You are a professional quantitative trading advisor"
|
||||
userPrompt := "Analyze current BTC trend"
|
||||
|
||||
result, err := client.CallWithMessages(systemPrompt, userPrompt)
|
||||
if err != nil {
|
||||
@@ -292,5 +292,5 @@ func Example_real_world_usage() {
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("AI 响应: %s\n", result)
|
||||
fmt.Printf("AI response: %s\n", result)
|
||||
}
|
||||
|
||||
+5
-5
@@ -5,18 +5,18 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// AIClient AI客户端公开接口(给外部使用)
|
||||
// AIClient public AI client interface (for external use)
|
||||
type AIClient interface {
|
||||
SetAPIKey(apiKey string, customURL string, customModel string)
|
||||
SetTimeout(timeout time.Duration)
|
||||
CallWithMessages(systemPrompt, userPrompt string) (string, error)
|
||||
CallWithRequest(req *Request) (string, error) // 构建器模式 API(支持高级功能)
|
||||
CallWithRequest(req *Request) (string, error) // Builder pattern API (supports advanced features)
|
||||
}
|
||||
|
||||
// clientHooks 内部钩子接口(用于子类重写特定步骤)
|
||||
// 这些方法只在包内部使用,实现动态分派
|
||||
// clientHooks internal hook interface (for subclass to override specific steps)
|
||||
// These methods are only used inside the package to implement dynamic dispatch
|
||||
type clientHooks interface {
|
||||
// 可被子类重写的钩子方法
|
||||
// Hook methods that can be overridden by subclass
|
||||
|
||||
call(systemPrompt, userPrompt string) (string, error)
|
||||
|
||||
|
||||
+5
-5
@@ -1,8 +1,8 @@
|
||||
package mcp
|
||||
|
||||
// Logger 日志接口(抽象依赖)
|
||||
// 使用 Printf 风格的方法名,方便集成 logrus、zap 等主流日志库
|
||||
// 默认使用全局 logger 包(见 mcp/config.go)
|
||||
// Logger interface (abstract dependency)
|
||||
// Uses Printf-style method names for easy integration with mainstream logging libraries like logrus, zap, etc.
|
||||
// Default uses global logger package (see mcp/config.go)
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...any)
|
||||
Infof(format string, args ...any)
|
||||
@@ -10,7 +10,7 @@ type Logger interface {
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
|
||||
// noopLogger 空日志实现(测试时使用)
|
||||
// noopLogger no-op logger implementation (used in tests)
|
||||
type noopLogger struct{}
|
||||
|
||||
func (l *noopLogger) Debugf(format string, args ...any) {}
|
||||
@@ -18,7 +18,7 @@ func (l *noopLogger) Infof(format string, args ...any) {}
|
||||
func (l *noopLogger) Warnf(format string, args ...any) {}
|
||||
func (l *noopLogger) Errorf(format string, args ...any) {}
|
||||
|
||||
// NewNoopLogger 创建空日志器(测试使用)
|
||||
// NewNoopLogger creates no-op logger (for testing)
|
||||
func NewNoopLogger() Logger {
|
||||
return &noopLogger{}
|
||||
}
|
||||
|
||||
+28
-28
@@ -13,19 +13,19 @@ import (
|
||||
// Mock Logger
|
||||
// ============================================================
|
||||
|
||||
// MockLogger Mock 日志器(用于测试)
|
||||
// MockLogger Mock logger (for testing)
|
||||
type MockLogger struct {
|
||||
mu sync.Mutex
|
||||
Logs []LogEntry
|
||||
Enabled bool // 是否启用日志记录
|
||||
Enabled bool // Whether logging is enabled
|
||||
}
|
||||
|
||||
// LogEntry 日志条目
|
||||
// LogEntry log entry
|
||||
type LogEntry struct {
|
||||
Level string
|
||||
Format string
|
||||
Args []any
|
||||
Message string // 格式化后的消息
|
||||
Message string // Formatted message
|
||||
}
|
||||
|
||||
func NewMockLogger() *MockLogger {
|
||||
@@ -68,14 +68,14 @@ func (m *MockLogger) log(level, format string, args ...any) {
|
||||
})
|
||||
}
|
||||
|
||||
// GetLogs 获取所有日志
|
||||
// GetLogs gets all logs
|
||||
func (m *MockLogger) GetLogs() []LogEntry {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return append([]LogEntry{}, m.Logs...)
|
||||
}
|
||||
|
||||
// GetLogsByLevel 获取指定级别的日志
|
||||
// GetLogsByLevel gets logs by specified level
|
||||
func (m *MockLogger) GetLogsByLevel(level string) []LogEntry {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -89,14 +89,14 @@ func (m *MockLogger) GetLogsByLevel(level string) []LogEntry {
|
||||
return result
|
||||
}
|
||||
|
||||
// Clear 清空日志
|
||||
// Clear clears all logs
|
||||
func (m *MockLogger) Clear() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.Logs = make([]LogEntry, 0)
|
||||
}
|
||||
|
||||
// HasLog 检查是否包含指定消息
|
||||
// HasLog checks if contains specified message
|
||||
func (m *MockLogger) HasLog(level, message string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -110,20 +110,20 @@ func (m *MockLogger) HasLog(level, message string) bool {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Mock HTTP Client (实现 http.RoundTripper)
|
||||
// Mock HTTP Client (implements http.RoundTripper)
|
||||
// ============================================================
|
||||
|
||||
// MockHTTPClient Mock HTTP 客户端(实现 http.RoundTripper)
|
||||
// MockHTTPClient Mock HTTP client (implements http.RoundTripper)
|
||||
type MockHTTPClient struct {
|
||||
mu sync.Mutex
|
||||
|
||||
// 配置
|
||||
// Configuration
|
||||
Response string
|
||||
StatusCode int
|
||||
Error error
|
||||
ResponseFunc func(req *http.Request) (*http.Response, error) // 自定义响应函数
|
||||
ResponseFunc func(req *http.Request) (*http.Response, error) // Custom response function
|
||||
|
||||
// 记录
|
||||
// Recording
|
||||
Requests []*http.Request
|
||||
}
|
||||
|
||||
@@ -134,32 +134,32 @@ func NewMockHTTPClient() *MockHTTPClient {
|
||||
}
|
||||
}
|
||||
|
||||
// ToHTTPClient 转换为 http.Client
|
||||
// ToHTTPClient converts to http.Client
|
||||
func (m *MockHTTPClient) ToHTTPClient() *http.Client {
|
||||
return &http.Client{
|
||||
Transport: m,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip 实现 http.RoundTripper 接口
|
||||
// RoundTrip implements http.RoundTripper interface
|
||||
func (m *MockHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// 记录请求
|
||||
// Record request
|
||||
m.Requests = append(m.Requests, req)
|
||||
|
||||
// 如果有自定义响应函数,使用它
|
||||
// If custom response function exists, use it
|
||||
if m.ResponseFunc != nil {
|
||||
return m.ResponseFunc(req)
|
||||
}
|
||||
|
||||
// 如果设置了错误,返回错误
|
||||
// If error is set, return error
|
||||
if m.Error != nil {
|
||||
return nil, m.Error
|
||||
}
|
||||
|
||||
// 返回模拟响应
|
||||
// Return mock response
|
||||
resp := &http.Response{
|
||||
StatusCode: m.StatusCode,
|
||||
Body: io.NopCloser(bytes.NewBufferString(m.Response)),
|
||||
@@ -169,14 +169,14 @@ func (m *MockHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// GetRequests 获取所有请求
|
||||
// GetRequests gets all requests
|
||||
func (m *MockHTTPClient) GetRequests() []*http.Request {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return append([]*http.Request{}, m.Requests...)
|
||||
}
|
||||
|
||||
// GetLastRequest 获取最后一次请求
|
||||
// GetLastRequest gets last request
|
||||
func (m *MockHTTPClient) GetLastRequest() *http.Request {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -187,14 +187,14 @@ func (m *MockHTTPClient) GetLastRequest() *http.Request {
|
||||
return m.Requests[len(m.Requests)-1]
|
||||
}
|
||||
|
||||
// Reset 重置状态
|
||||
// Reset resets state
|
||||
func (m *MockHTTPClient) Reset() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.Requests = make([]*http.Request, 0)
|
||||
}
|
||||
|
||||
// SetSuccessResponse 设置成功响应
|
||||
// SetSuccessResponse sets success response
|
||||
func (m *MockHTTPClient) SetSuccessResponse(content string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -204,7 +204,7 @@ func (m *MockHTTPClient) SetSuccessResponse(content string) {
|
||||
m.Error = nil
|
||||
}
|
||||
|
||||
// SetErrorResponse 设置错误响应
|
||||
// SetErrorResponse sets error response
|
||||
func (m *MockHTTPClient) SetErrorResponse(statusCode int, message string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -214,7 +214,7 @@ func (m *MockHTTPClient) SetErrorResponse(statusCode int, message string) {
|
||||
m.Error = nil
|
||||
}
|
||||
|
||||
// SetNetworkError 设置网络错误
|
||||
// SetNetworkError sets network error
|
||||
func (m *MockHTTPClient) SetNetworkError(err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -223,10 +223,10 @@ func (m *MockHTTPClient) SetNetworkError(err error) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Mock Client Hooks (用于测试钩子机制)
|
||||
// Mock Client Hooks (for testing hook mechanism)
|
||||
// ============================================================
|
||||
|
||||
// MockClientHooks Mock 客户端钩子
|
||||
// MockClientHooks Mock client hooks
|
||||
type MockClientHooks struct {
|
||||
BuildRequestBodyCalled int
|
||||
BuildUrlCalled int
|
||||
@@ -235,7 +235,7 @@ type MockClientHooks struct {
|
||||
ParseResponseCalled int
|
||||
IsRetryableErrorCalled int
|
||||
|
||||
// 自定义返回值
|
||||
// Custom return values
|
||||
BuildUrlFunc func() string
|
||||
ParseResponseFunc func([]byte) (string, error)
|
||||
IsRetryableErrorFunc func(error) bool
|
||||
|
||||
+29
-29
@@ -5,16 +5,16 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientOption 客户端选项函数(Functional Options 模式)
|
||||
// ClientOption client option function (Functional Options pattern)
|
||||
type ClientOption func(*Config)
|
||||
|
||||
// ============================================================
|
||||
// 依赖注入选项
|
||||
// Dependency Injection Options
|
||||
// ============================================================
|
||||
|
||||
// WithLogger 设置自定义日志器
|
||||
// WithLogger sets custom logger
|
||||
//
|
||||
// 使用示例:
|
||||
// Usage example:
|
||||
// client := mcp.NewClient(mcp.WithLogger(customLogger))
|
||||
func WithLogger(logger Logger) ClientOption {
|
||||
return func(c *Config) {
|
||||
@@ -22,9 +22,9 @@ func WithLogger(logger Logger) ClientOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient 设置自定义 HTTP 客户端
|
||||
// WithHTTPClient sets custom HTTP client
|
||||
//
|
||||
// 使用示例:
|
||||
// Usage example:
|
||||
// httpClient := &http.Client{Timeout: 60 * time.Second}
|
||||
// client := mcp.NewClient(mcp.WithHTTPClient(httpClient))
|
||||
func WithHTTPClient(client *http.Client) ClientOption {
|
||||
@@ -34,12 +34,12 @@ func WithHTTPClient(client *http.Client) ClientOption {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 超时和重试选项
|
||||
// Timeout and Retry Options
|
||||
// ============================================================
|
||||
|
||||
// WithTimeout 设置请求超时时间
|
||||
// WithTimeout sets request timeout duration
|
||||
//
|
||||
// 使用示例:
|
||||
// Usage example:
|
||||
// client := mcp.NewClient(mcp.WithTimeout(60 * time.Second))
|
||||
func WithTimeout(timeout time.Duration) ClientOption {
|
||||
return func(c *Config) {
|
||||
@@ -48,9 +48,9 @@ func WithTimeout(timeout time.Duration) ClientOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxRetries 设置最大重试次数
|
||||
// WithMaxRetries sets maximum retry count
|
||||
//
|
||||
// 使用示例:
|
||||
// Usage example:
|
||||
// client := mcp.NewClient(mcp.WithMaxRetries(5))
|
||||
func WithMaxRetries(maxRetries int) ClientOption {
|
||||
return func(c *Config) {
|
||||
@@ -58,9 +58,9 @@ func WithMaxRetries(maxRetries int) ClientOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithRetryWaitBase 设置重试等待基础时长
|
||||
// WithRetryWaitBase sets base retry wait duration
|
||||
//
|
||||
// 使用示例:
|
||||
// Usage example:
|
||||
// client := mcp.NewClient(mcp.WithRetryWaitBase(3 * time.Second))
|
||||
func WithRetryWaitBase(waitTime time.Duration) ClientOption {
|
||||
return func(c *Config) {
|
||||
@@ -69,12 +69,12 @@ func WithRetryWaitBase(waitTime time.Duration) ClientOption {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// AI 参数选项
|
||||
// AI Parameter Options
|
||||
// ============================================================
|
||||
|
||||
// WithMaxTokens 设置最大 token 数
|
||||
// WithMaxTokens sets maximum token count
|
||||
//
|
||||
// 使用示例:
|
||||
// Usage example:
|
||||
// client := mcp.NewClient(mcp.WithMaxTokens(4000))
|
||||
func WithMaxTokens(maxTokens int) ClientOption {
|
||||
return func(c *Config) {
|
||||
@@ -82,9 +82,9 @@ func WithMaxTokens(maxTokens int) ClientOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithTemperature 设置温度参数
|
||||
// WithTemperature sets temperature parameter
|
||||
//
|
||||
// 使用示例:
|
||||
// Usage example:
|
||||
// client := mcp.NewClient(mcp.WithTemperature(0.7))
|
||||
func WithTemperature(temperature float64) ClientOption {
|
||||
return func(c *Config) {
|
||||
@@ -93,38 +93,38 @@ func WithTemperature(temperature float64) ClientOption {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Provider 配置选项
|
||||
// Provider Configuration Options
|
||||
// ============================================================
|
||||
|
||||
// WithAPIKey 设置 API Key
|
||||
// WithAPIKey sets API Key
|
||||
func WithAPIKey(apiKey string) ClientOption {
|
||||
return func(c *Config) {
|
||||
c.APIKey = apiKey
|
||||
}
|
||||
}
|
||||
|
||||
// WithBaseURL 设置基础 URL
|
||||
// WithBaseURL sets base URL
|
||||
func WithBaseURL(baseURL string) ClientOption {
|
||||
return func(c *Config) {
|
||||
c.BaseURL = baseURL
|
||||
}
|
||||
}
|
||||
|
||||
// WithModel 设置模型名称
|
||||
// WithModel sets model name
|
||||
func WithModel(model string) ClientOption {
|
||||
return func(c *Config) {
|
||||
c.Model = model
|
||||
}
|
||||
}
|
||||
|
||||
// WithProvider 设置提供商
|
||||
// WithProvider sets provider
|
||||
func WithProvider(provider string) ClientOption {
|
||||
return func(c *Config) {
|
||||
c.Provider = provider
|
||||
}
|
||||
}
|
||||
|
||||
// WithUseFullURL 设置是否使用完整 URL
|
||||
// WithUseFullURL sets whether to use full URL
|
||||
func WithUseFullURL(useFullURL bool) ClientOption {
|
||||
return func(c *Config) {
|
||||
c.UseFullURL = useFullURL
|
||||
@@ -132,12 +132,12 @@ func WithUseFullURL(useFullURL bool) ClientOption {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 组合选项(便捷方法)
|
||||
// Combined Options (Convenience Methods)
|
||||
// ============================================================
|
||||
|
||||
// WithDeepSeekConfig 设置 DeepSeek 配置
|
||||
// WithDeepSeekConfig sets DeepSeek configuration
|
||||
//
|
||||
// 使用示例:
|
||||
// Usage example:
|
||||
// client := mcp.NewClient(mcp.WithDeepSeekConfig("sk-xxx"))
|
||||
func WithDeepSeekConfig(apiKey string) ClientOption {
|
||||
return func(c *Config) {
|
||||
@@ -148,9 +148,9 @@ func WithDeepSeekConfig(apiKey string) ClientOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithQwenConfig 设置 Qwen 配置
|
||||
// WithQwenConfig sets Qwen configuration
|
||||
//
|
||||
// 使用示例:
|
||||
// Usage example:
|
||||
// client := mcp.NewClient(mcp.WithQwenConfig("sk-xxx"))
|
||||
func WithQwenConfig(apiKey string) ClientOption {
|
||||
return func(c *Config) {
|
||||
|
||||
+15
-15
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// 测试基础选项
|
||||
// Test Basic Options
|
||||
// ============================================================
|
||||
|
||||
func TestWithProvider(t *testing.T) {
|
||||
@@ -116,7 +116,7 @@ func TestWithHTTPClient(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试预设配置选项
|
||||
// Test Preset Configuration Options
|
||||
// ============================================================
|
||||
|
||||
func TestWithDeepSeekConfig(t *testing.T) {
|
||||
@@ -162,7 +162,7 @@ func TestWithQwenConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试选项组合
|
||||
// Test Options Combination
|
||||
// ============================================================
|
||||
|
||||
func TestMultipleOptions(t *testing.T) {
|
||||
@@ -170,7 +170,7 @@ func TestMultipleOptions(t *testing.T) {
|
||||
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// 应用多个选项
|
||||
// Apply multiple options
|
||||
options := []ClientOption{
|
||||
WithProvider("test-provider"),
|
||||
WithAPIKey("sk-test-key"),
|
||||
@@ -186,7 +186,7 @@ func TestMultipleOptions(t *testing.T) {
|
||||
opt(cfg)
|
||||
}
|
||||
|
||||
// 验证所有选项都被应用
|
||||
// Verify all options are applied
|
||||
if cfg.Provider != "test-provider" {
|
||||
t.Error("Provider should be set")
|
||||
}
|
||||
@@ -223,14 +223,14 @@ func TestMultipleOptions(t *testing.T) {
|
||||
func TestOptionsOverride(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// 先应用 DeepSeek 配置
|
||||
// First apply DeepSeek configuration
|
||||
WithDeepSeekConfig("sk-deepseek-key")(cfg)
|
||||
|
||||
// 然后覆盖某些选项
|
||||
// Then override some options
|
||||
WithModel("custom-model")(cfg)
|
||||
WithMaxTokens(5000)(cfg)
|
||||
|
||||
// 验证覆盖成功
|
||||
// Verify override succeeded
|
||||
if cfg.Model != "custom-model" {
|
||||
t.Errorf("Model should be overridden to 'custom-model', got '%s'", cfg.Model)
|
||||
}
|
||||
@@ -239,7 +239,7 @@ func TestOptionsOverride(t *testing.T) {
|
||||
t.Errorf("MaxTokens should be overridden to 5000, got %d", cfg.MaxTokens)
|
||||
}
|
||||
|
||||
// 验证其他 DeepSeek 配置保持不变
|
||||
// Verify other DeepSeek configurations remain unchanged
|
||||
if cfg.Provider != ProviderDeepSeek {
|
||||
t.Error("Provider should still be DeepSeek")
|
||||
}
|
||||
@@ -250,7 +250,7 @@ func TestOptionsOverride(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试与客户端集成
|
||||
// Test Integration with Client
|
||||
// ============================================================
|
||||
|
||||
func TestOptionsWithNewClient(t *testing.T) {
|
||||
@@ -266,7 +266,7 @@ func TestOptionsWithNewClient(t *testing.T) {
|
||||
|
||||
c := client.(*Client)
|
||||
|
||||
// 验证选项被正确应用到客户端
|
||||
// Verify options are correctly applied to client
|
||||
if c.Provider != "test-provider" {
|
||||
t.Error("Provider should be set from options")
|
||||
}
|
||||
@@ -299,7 +299,7 @@ func TestOptionsWithDeepSeekClient(t *testing.T) {
|
||||
|
||||
dsClient := client.(*DeepSeekClient)
|
||||
|
||||
// 验证 DeepSeek 默认值
|
||||
// Verify DeepSeek default values
|
||||
if dsClient.Provider != ProviderDeepSeek {
|
||||
t.Error("Provider should be DeepSeek")
|
||||
}
|
||||
@@ -312,7 +312,7 @@ func TestOptionsWithDeepSeekClient(t *testing.T) {
|
||||
t.Error("Model should be DeepSeek default")
|
||||
}
|
||||
|
||||
// 验证自定义选项
|
||||
// Verify custom options
|
||||
if dsClient.APIKey != "sk-deepseek-key" {
|
||||
t.Error("APIKey should be set from options")
|
||||
}
|
||||
@@ -337,7 +337,7 @@ func TestOptionsWithQwenClient(t *testing.T) {
|
||||
|
||||
qwenClient := client.(*QwenClient)
|
||||
|
||||
// 验证 Qwen 默认值
|
||||
// Verify Qwen default values
|
||||
if qwenClient.Provider != ProviderQwen {
|
||||
t.Error("Provider should be Qwen")
|
||||
}
|
||||
@@ -350,7 +350,7 @@ func TestOptionsWithQwenClient(t *testing.T) {
|
||||
t.Error("Model should be Qwen default")
|
||||
}
|
||||
|
||||
// 验证自定义选项
|
||||
// Verify custom options
|
||||
if qwenClient.APIKey != "sk-qwen-key" {
|
||||
t.Error("APIKey should be set from options")
|
||||
}
|
||||
|
||||
+15
-15
@@ -14,45 +14,45 @@ type QwenClient struct {
|
||||
*Client
|
||||
}
|
||||
|
||||
// NewQwenClient 创建 Qwen 客户端(向前兼容)
|
||||
// NewQwenClient creates Qwen client (backward compatible)
|
||||
//
|
||||
// Deprecated: 推荐使用 NewQwenClientWithOptions 以获得更好的灵活性
|
||||
// Deprecated: Recommend using NewQwenClientWithOptions for better flexibility
|
||||
func NewQwenClient() AIClient {
|
||||
return NewQwenClientWithOptions()
|
||||
}
|
||||
|
||||
// NewQwenClientWithOptions 创建 Qwen 客户端(支持选项模式)
|
||||
// NewQwenClientWithOptions creates Qwen client (supports options pattern)
|
||||
//
|
||||
// 使用示例:
|
||||
// // 基础用法
|
||||
// Usage examples:
|
||||
// // Basic usage
|
||||
// client := mcp.NewQwenClientWithOptions()
|
||||
//
|
||||
// // 自定义配置
|
||||
// // Custom configuration
|
||||
// client := mcp.NewQwenClientWithOptions(
|
||||
// mcp.WithAPIKey("sk-xxx"),
|
||||
// mcp.WithLogger(customLogger),
|
||||
// mcp.WithTimeout(60*time.Second),
|
||||
// )
|
||||
func NewQwenClientWithOptions(opts ...ClientOption) AIClient {
|
||||
// 1. 创建 Qwen 预设选项
|
||||
// 1. Create Qwen preset options
|
||||
qwenOpts := []ClientOption{
|
||||
WithProvider(ProviderQwen),
|
||||
WithModel(DefaultQwenModel),
|
||||
WithBaseURL(DefaultQwenBaseURL),
|
||||
}
|
||||
|
||||
// 2. 合并用户选项(用户选项优先级更高)
|
||||
// 2. Merge user options (user options have higher priority)
|
||||
allOpts := append(qwenOpts, opts...)
|
||||
|
||||
// 3. 创建基础客户端
|
||||
// 3. Create base client
|
||||
baseClient := NewClient(allOpts...).(*Client)
|
||||
|
||||
// 4. 创建 Qwen 客户端
|
||||
// 4. Create Qwen client
|
||||
qwenClient := &QwenClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
// 5. 设置 hooks 指向 QwenClient(实现动态分派)
|
||||
// 5. Set hooks to point to QwenClient (implement dynamic dispatch)
|
||||
baseClient.hooks = qwenClient
|
||||
|
||||
return qwenClient
|
||||
@@ -66,15 +66,15 @@ func (qwenClient *QwenClient) SetAPIKey(apiKey string, customURL string, customM
|
||||
}
|
||||
if customURL != "" {
|
||||
qwenClient.BaseURL = customURL
|
||||
qwenClient.logger.Infof("🔧 [MCP] Qwen 使用自定义 BaseURL: %s", customURL)
|
||||
qwenClient.logger.Infof("🔧 [MCP] Qwen using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
qwenClient.logger.Infof("🔧 [MCP] Qwen 使用默认 BaseURL: %s", qwenClient.BaseURL)
|
||||
qwenClient.logger.Infof("🔧 [MCP] Qwen using default BaseURL: %s", qwenClient.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
qwenClient.Model = customModel
|
||||
qwenClient.logger.Infof("🔧 [MCP] Qwen 使用自定义 Model: %s", customModel)
|
||||
qwenClient.logger.Infof("🔧 [MCP] Qwen using custom Model: %s", customModel)
|
||||
} else {
|
||||
qwenClient.logger.Infof("🔧 [MCP] Qwen 使用默认 Model: %s", qwenClient.Model)
|
||||
qwenClient.logger.Infof("🔧 [MCP] Qwen using default Model: %s", qwenClient.Model)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+22
-22
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// 测试 QwenClient 创建和配置
|
||||
// Test QwenClient Creation and Configuration
|
||||
// ============================================================
|
||||
|
||||
func TestNewQwenClient_Default(t *testing.T) {
|
||||
@@ -16,13 +16,13 @@ func TestNewQwenClient_Default(t *testing.T) {
|
||||
t.Fatal("client should not be nil")
|
||||
}
|
||||
|
||||
// 类型断言检查
|
||||
// Type assertion check
|
||||
qwenClient, ok := client.(*QwenClient)
|
||||
if !ok {
|
||||
t.Fatal("client should be *QwenClient")
|
||||
}
|
||||
|
||||
// 验证默认值
|
||||
// Verify default values
|
||||
if qwenClient.Provider != ProviderQwen {
|
||||
t.Errorf("Provider should be '%s', got '%s'", ProviderQwen, qwenClient.Provider)
|
||||
}
|
||||
@@ -58,7 +58,7 @@ func TestNewQwenClientWithOptions(t *testing.T) {
|
||||
|
||||
qwenClient := client.(*QwenClient)
|
||||
|
||||
// 验证自定义选项被应用
|
||||
// Verify custom options are applied
|
||||
if qwenClient.logger != mockLogger {
|
||||
t.Error("logger should be set from option")
|
||||
}
|
||||
@@ -75,7 +75,7 @@ func TestNewQwenClientWithOptions(t *testing.T) {
|
||||
t.Error("MaxTokens should be 4000")
|
||||
}
|
||||
|
||||
// 验证 Qwen 默认值仍然保留
|
||||
// Verify Qwen default values are retained
|
||||
if qwenClient.Provider != ProviderQwen {
|
||||
t.Errorf("Provider should still be '%s'", ProviderQwen)
|
||||
}
|
||||
@@ -86,7 +86,7 @@ func TestNewQwenClientWithOptions(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试 SetAPIKey
|
||||
// Test SetAPIKey
|
||||
// ============================================================
|
||||
|
||||
func TestQwenClient_SetAPIKey(t *testing.T) {
|
||||
@@ -97,20 +97,20 @@ func TestQwenClient_SetAPIKey(t *testing.T) {
|
||||
|
||||
qwenClient := client.(*QwenClient)
|
||||
|
||||
// 测试设置 API Key(默认 URL 和 Model)
|
||||
// Test setting API Key (default URL and Model)
|
||||
qwenClient.SetAPIKey("sk-test-key-12345678", "", "")
|
||||
|
||||
if qwenClient.APIKey != "sk-test-key-12345678" {
|
||||
t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", qwenClient.APIKey)
|
||||
}
|
||||
|
||||
// 验证日志记录
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
if len(logs) == 0 {
|
||||
t.Error("should have logged API key setting")
|
||||
}
|
||||
|
||||
// 验证 BaseURL 和 Model 保持默认
|
||||
// Verify BaseURL and Model remain default
|
||||
if qwenClient.BaseURL != DefaultQwenBaseURL {
|
||||
t.Error("BaseURL should remain default")
|
||||
}
|
||||
@@ -135,11 +135,11 @@ func TestQwenClient_SetAPIKey_WithCustomURL(t *testing.T) {
|
||||
t.Errorf("BaseURL should be '%s', got '%s'", customURL, qwenClient.BaseURL)
|
||||
}
|
||||
|
||||
// 验证日志记录
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
hasCustomURLLog := false
|
||||
for _, log := range logs {
|
||||
if log.Format == "🔧 [MCP] Qwen 使用自定义 BaseURL: %s" {
|
||||
if log.Format == "🔧 [MCP] Qwen using custom BaseURL: %s" {
|
||||
hasCustomURLLog = true
|
||||
break
|
||||
}
|
||||
@@ -165,11 +165,11 @@ func TestQwenClient_SetAPIKey_WithCustomModel(t *testing.T) {
|
||||
t.Errorf("Model should be '%s', got '%s'", customModel, qwenClient.Model)
|
||||
}
|
||||
|
||||
// 验证日志记录
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
hasCustomModelLog := false
|
||||
for _, log := range logs {
|
||||
if log.Format == "🔧 [MCP] Qwen 使用自定义 Model: %s" {
|
||||
if log.Format == "🔧 [MCP] Qwen using custom Model: %s" {
|
||||
hasCustomModelLog = true
|
||||
break
|
||||
}
|
||||
@@ -181,7 +181,7 @@ func TestQwenClient_SetAPIKey_WithCustomModel(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试集成功能
|
||||
// Test Integration Features
|
||||
// ============================================================
|
||||
|
||||
func TestQwenClient_CallWithMessages_Success(t *testing.T) {
|
||||
@@ -205,7 +205,7 @@ func TestQwenClient_CallWithMessages_Success(t *testing.T) {
|
||||
t.Errorf("expected 'Qwen AI response', got '%s'", result)
|
||||
}
|
||||
|
||||
// 验证请求
|
||||
// Verify request
|
||||
requests := mockHTTP.GetRequests()
|
||||
if len(requests) != 1 {
|
||||
t.Fatalf("expected 1 request, got %d", len(requests))
|
||||
@@ -213,19 +213,19 @@ func TestQwenClient_CallWithMessages_Success(t *testing.T) {
|
||||
|
||||
req := requests[0]
|
||||
|
||||
// 验证 URL
|
||||
// Verify URL
|
||||
expectedURL := DefaultQwenBaseURL + "/chat/completions"
|
||||
if req.URL.String() != expectedURL {
|
||||
t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String())
|
||||
}
|
||||
|
||||
// 验证 Authorization header
|
||||
// Verify Authorization header
|
||||
authHeader := req.Header.Get("Authorization")
|
||||
if authHeader != "Bearer sk-test-key" {
|
||||
t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader)
|
||||
}
|
||||
|
||||
// 验证 Content-Type
|
||||
// Verify Content-Type
|
||||
if req.Header.Get("Content-Type") != "application/json" {
|
||||
t.Error("Content-Type should be application/json")
|
||||
}
|
||||
@@ -242,7 +242,7 @@ func TestQwenClient_Timeout(t *testing.T) {
|
||||
t.Errorf("expected timeout 30s, got %v", qwenClient.httpClient.Timeout)
|
||||
}
|
||||
|
||||
// 测试 SetTimeout
|
||||
// Test SetTimeout
|
||||
client.SetTimeout(60 * time.Second)
|
||||
|
||||
if qwenClient.httpClient.Timeout != 60*time.Second {
|
||||
@@ -251,19 +251,19 @@ func TestQwenClient_Timeout(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试 hooks 机制
|
||||
// Test hooks Mechanism
|
||||
// ============================================================
|
||||
|
||||
func TestQwenClient_HooksIntegration(t *testing.T) {
|
||||
client := NewQwenClientWithOptions()
|
||||
qwenClient := client.(*QwenClient)
|
||||
|
||||
// 验证 hooks 指向 qwenClient 自己(实现多态)
|
||||
// Verify hooks point to qwenClient itself (implements polymorphism)
|
||||
if qwenClient.hooks != qwenClient {
|
||||
t.Error("hooks should point to qwenClient for polymorphism")
|
||||
}
|
||||
|
||||
// 验证 buildUrl 使用 Qwen 配置
|
||||
// Verify buildUrl uses Qwen configuration
|
||||
url := qwenClient.buildUrl()
|
||||
expectedURL := DefaultQwenBaseURL + "/chat/completions"
|
||||
if url != expectedURL {
|
||||
|
||||
+28
-28
@@ -1,45 +1,45 @@
|
||||
package mcp
|
||||
|
||||
// Message 表示一条对话消息
|
||||
// Message represents a conversation message
|
||||
type Message struct {
|
||||
Role string `json:"role"` // "system", "user", "assistant"
|
||||
Content string `json:"content"` // 消息内容
|
||||
Content string `json:"content"` // Message content
|
||||
}
|
||||
|
||||
// Tool 表示 AI 可以调用的工具/函数
|
||||
// Tool represents a tool/function that AI can call
|
||||
type Tool struct {
|
||||
Type string `json:"type"` // 通常为 "function"
|
||||
Function FunctionDef `json:"function"` // 函数定义
|
||||
Type string `json:"type"` // Usually "function"
|
||||
Function FunctionDef `json:"function"` // Function definition
|
||||
}
|
||||
|
||||
// FunctionDef 函数定义
|
||||
// FunctionDef function definition
|
||||
type FunctionDef struct {
|
||||
Name string `json:"name"` // 函数名
|
||||
Description string `json:"description,omitempty"` // 函数描述
|
||||
Parameters map[string]any `json:"parameters,omitempty"` // 参数 schema (JSON Schema)
|
||||
Name string `json:"name"` // Function name
|
||||
Description string `json:"description,omitempty"` // Function description
|
||||
Parameters map[string]any `json:"parameters,omitempty"` // Parameter schema (JSON Schema)
|
||||
}
|
||||
|
||||
// Request AI API 请求(支持高级功能)
|
||||
// Request AI API request (supports advanced features)
|
||||
type Request struct {
|
||||
// 基础字段
|
||||
Model string `json:"model"` // 模型名称
|
||||
Messages []Message `json:"messages"` // 对话消息列表
|
||||
Stream bool `json:"stream,omitempty"` // 是否流式响应
|
||||
// Basic fields
|
||||
Model string `json:"model"` // Model name
|
||||
Messages []Message `json:"messages"` // Conversation message list
|
||||
Stream bool `json:"stream,omitempty"` // Whether to stream response
|
||||
|
||||
// 可选参数(用于精细控制)
|
||||
Temperature *float64 `json:"temperature,omitempty"` // 温度 (0-2),控制随机性
|
||||
MaxTokens *int `json:"max_tokens,omitempty"` // 最大 token 数
|
||||
TopP *float64 `json:"top_p,omitempty"` // 核采样参数 (0-1)
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // 频率惩罚 (-2 to 2)
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // 存在惩罚 (-2 to 2)
|
||||
Stop []string `json:"stop,omitempty"` // 停止序列
|
||||
// Optional parameters (for fine-grained control)
|
||||
Temperature *float64 `json:"temperature,omitempty"` // Temperature (0-2), controls randomness
|
||||
MaxTokens *int `json:"max_tokens,omitempty"` // Maximum token count
|
||||
TopP *float64 `json:"top_p,omitempty"` // Nucleus sampling parameter (0-1)
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Frequency penalty (-2 to 2)
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Presence penalty (-2 to 2)
|
||||
Stop []string `json:"stop,omitempty"` // Stop sequences
|
||||
|
||||
// 高级功能
|
||||
Tools []Tool `json:"tools,omitempty"` // 可用工具列表
|
||||
ToolChoice string `json:"tool_choice,omitempty"` // 工具选择策略 ("auto", "none", {"type": "function", "function": {"name": "xxx"}})
|
||||
// Advanced features
|
||||
Tools []Tool `json:"tools,omitempty"` // Available tools list
|
||||
ToolChoice string `json:"tool_choice,omitempty"` // Tool choice strategy ("auto", "none", {"type": "function", "function": {"name": "xxx"}})
|
||||
}
|
||||
|
||||
// NewMessage 创建一条消息
|
||||
// NewMessage creates a message
|
||||
func NewMessage(role, content string) Message {
|
||||
return Message{
|
||||
Role: role,
|
||||
@@ -47,7 +47,7 @@ func NewMessage(role, content string) Message {
|
||||
}
|
||||
}
|
||||
|
||||
// NewSystemMessage 创建系统消息
|
||||
// NewSystemMessage creates a system message
|
||||
func NewSystemMessage(content string) Message {
|
||||
return Message{
|
||||
Role: "system",
|
||||
@@ -55,7 +55,7 @@ func NewSystemMessage(content string) Message {
|
||||
}
|
||||
}
|
||||
|
||||
// NewUserMessage 创建用户消息
|
||||
// NewUserMessage creates a user message
|
||||
func NewUserMessage(content string) Message {
|
||||
return Message{
|
||||
Role: "user",
|
||||
@@ -63,7 +63,7 @@ func NewUserMessage(content string) Message {
|
||||
}
|
||||
}
|
||||
|
||||
// NewAssistantMessage 创建助手消息
|
||||
// NewAssistantMessage creates an assistant message
|
||||
func NewAssistantMessage(content string) Message {
|
||||
return Message{
|
||||
Role: "assistant",
|
||||
|
||||
+49
-49
@@ -4,7 +4,7 @@ import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// RequestBuilder 请求构建器
|
||||
// RequestBuilder request builder
|
||||
type RequestBuilder struct {
|
||||
model string
|
||||
messages []Message
|
||||
@@ -19,9 +19,9 @@ type RequestBuilder struct {
|
||||
toolChoice string
|
||||
}
|
||||
|
||||
// NewRequestBuilder 创建请求构建器
|
||||
// NewRequestBuilder creates request builder
|
||||
//
|
||||
// 使用示例:
|
||||
// Usage example:
|
||||
// request := NewRequestBuilder().
|
||||
// WithSystemPrompt("You are helpful").
|
||||
// WithUserPrompt("Hello").
|
||||
@@ -35,26 +35,26 @@ func NewRequestBuilder() *RequestBuilder {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 模型和流式配置
|
||||
// Model and Stream Configuration
|
||||
// ============================================================
|
||||
|
||||
// WithModel 设置模型名称
|
||||
// WithModel sets model name
|
||||
func (b *RequestBuilder) WithModel(model string) *RequestBuilder {
|
||||
b.model = model
|
||||
return b
|
||||
}
|
||||
|
||||
// WithStream 设置是否使用流式响应
|
||||
// WithStream sets whether to use streaming response
|
||||
func (b *RequestBuilder) WithStream(stream bool) *RequestBuilder {
|
||||
b.stream = stream
|
||||
return b
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 消息构建方法
|
||||
// Message Building Methods
|
||||
// ============================================================
|
||||
|
||||
// WithSystemPrompt 添加系统提示词(便捷方法)
|
||||
// WithSystemPrompt adds system prompt (convenience method)
|
||||
func (b *RequestBuilder) WithSystemPrompt(prompt string) *RequestBuilder {
|
||||
if prompt != "" {
|
||||
b.messages = append(b.messages, NewSystemMessage(prompt))
|
||||
@@ -62,7 +62,7 @@ func (b *RequestBuilder) WithSystemPrompt(prompt string) *RequestBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// WithUserPrompt 添加用户提示词(便捷方法)
|
||||
// WithUserPrompt adds user prompt (convenience method)
|
||||
func (b *RequestBuilder) WithUserPrompt(prompt string) *RequestBuilder {
|
||||
if prompt != "" {
|
||||
b.messages = append(b.messages, NewUserMessage(prompt))
|
||||
@@ -70,17 +70,17 @@ func (b *RequestBuilder) WithUserPrompt(prompt string) *RequestBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// AddSystemMessage 添加系统消息
|
||||
// AddSystemMessage adds system message
|
||||
func (b *RequestBuilder) AddSystemMessage(content string) *RequestBuilder {
|
||||
return b.WithSystemPrompt(content)
|
||||
}
|
||||
|
||||
// AddUserMessage 添加用户消息
|
||||
// AddUserMessage adds user message
|
||||
func (b *RequestBuilder) AddUserMessage(content string) *RequestBuilder {
|
||||
return b.WithUserPrompt(content)
|
||||
}
|
||||
|
||||
// AddAssistantMessage 添加助手消息(用于多轮对话上下文)
|
||||
// AddAssistantMessage adds assistant message (for multi-turn conversation context)
|
||||
func (b *RequestBuilder) AddAssistantMessage(content string) *RequestBuilder {
|
||||
if content != "" {
|
||||
b.messages = append(b.messages, NewAssistantMessage(content))
|
||||
@@ -88,7 +88,7 @@ func (b *RequestBuilder) AddAssistantMessage(content string) *RequestBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// AddMessage 添加自定义角色的消息
|
||||
// AddMessage adds message with custom role
|
||||
func (b *RequestBuilder) AddMessage(role, content string) *RequestBuilder {
|
||||
if content != "" {
|
||||
b.messages = append(b.messages, NewMessage(role, content))
|
||||
@@ -96,33 +96,33 @@ func (b *RequestBuilder) AddMessage(role, content string) *RequestBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// AddMessages 批量添加消息
|
||||
// AddMessages adds messages in batch
|
||||
func (b *RequestBuilder) AddMessages(messages ...Message) *RequestBuilder {
|
||||
b.messages = append(b.messages, messages...)
|
||||
return b
|
||||
}
|
||||
|
||||
// AddConversationHistory 添加对话历史
|
||||
// AddConversationHistory adds conversation history
|
||||
func (b *RequestBuilder) AddConversationHistory(history []Message) *RequestBuilder {
|
||||
b.messages = append(b.messages, history...)
|
||||
return b
|
||||
}
|
||||
|
||||
// ClearMessages 清空所有消息
|
||||
// ClearMessages clears all messages
|
||||
func (b *RequestBuilder) ClearMessages() *RequestBuilder {
|
||||
b.messages = make([]Message, 0)
|
||||
return b
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 参数控制方法
|
||||
// Parameter Control Methods
|
||||
// ============================================================
|
||||
|
||||
// WithTemperature 设置温度参数 (0-2)
|
||||
// 较高的温度(如 1.2)会使输出更随机,较低的温度(如 0.2)会使输出更确定
|
||||
// WithTemperature sets temperature parameter (0-2)
|
||||
// Higher temperature (e.g. 1.2) makes output more random, lower temperature (e.g. 0.2) makes output more deterministic
|
||||
func (b *RequestBuilder) WithTemperature(t float64) *RequestBuilder {
|
||||
if t < 0 || t > 2 {
|
||||
// 可以选择 panic 或者静默忽略,这里选择限制范围
|
||||
// Can choose to panic or silently ignore, here we choose to limit the range
|
||||
if t < 0 {
|
||||
t = 0
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func (b *RequestBuilder) WithTemperature(t float64) *RequestBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// WithMaxTokens 设置最大 token 数
|
||||
// WithMaxTokens sets maximum token count
|
||||
func (b *RequestBuilder) WithMaxTokens(tokens int) *RequestBuilder {
|
||||
if tokens > 0 {
|
||||
b.maxTokens = &tokens
|
||||
@@ -142,8 +142,8 @@ func (b *RequestBuilder) WithMaxTokens(tokens int) *RequestBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// WithTopP 设置 top-p 核采样参数 (0-1)
|
||||
// 控制考虑的 token 范围,较小的值(如 0.1)使输出更聚焦
|
||||
// WithTopP sets top-p nucleus sampling parameter (0-1)
|
||||
// Controls the range of tokens considered, smaller values (e.g. 0.1) make output more focused
|
||||
func (b *RequestBuilder) WithTopP(p float64) *RequestBuilder {
|
||||
if p >= 0 && p <= 1 {
|
||||
b.topP = &p
|
||||
@@ -151,8 +151,8 @@ func (b *RequestBuilder) WithTopP(p float64) *RequestBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// WithFrequencyPenalty 设置频率惩罚 (-2 to 2)
|
||||
// 正值会根据 token 在文本中出现的频率惩罚它们,减少重复
|
||||
// WithFrequencyPenalty sets frequency penalty (-2 to 2)
|
||||
// Positive values penalize tokens based on their frequency in the text, reducing repetition
|
||||
func (b *RequestBuilder) WithFrequencyPenalty(penalty float64) *RequestBuilder {
|
||||
if penalty >= -2 && penalty <= 2 {
|
||||
b.frequencyPenalty = &penalty
|
||||
@@ -160,8 +160,8 @@ func (b *RequestBuilder) WithFrequencyPenalty(penalty float64) *RequestBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPresencePenalty 设置存在惩罚 (-2 to 2)
|
||||
// 正值会根据 token 是否出现在文本中惩罚它们,增加话题多样性
|
||||
// WithPresencePenalty sets presence penalty (-2 to 2)
|
||||
// Positive values penalize tokens based on whether they appear in the text, increasing topic diversity
|
||||
func (b *RequestBuilder) WithPresencePenalty(penalty float64) *RequestBuilder {
|
||||
if penalty >= -2 && penalty <= 2 {
|
||||
b.presencePenalty = &penalty
|
||||
@@ -169,14 +169,14 @@ func (b *RequestBuilder) WithPresencePenalty(penalty float64) *RequestBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
// WithStopSequences 设置停止序列
|
||||
// 当模型生成这些序列之一时,将停止生成
|
||||
// WithStopSequences sets stop sequences
|
||||
// Model will stop generating when it generates one of these sequences
|
||||
func (b *RequestBuilder) WithStopSequences(sequences []string) *RequestBuilder {
|
||||
b.stop = sequences
|
||||
return b
|
||||
}
|
||||
|
||||
// AddStopSequence 添加单个停止序列
|
||||
// AddStopSequence adds a single stop sequence
|
||||
func (b *RequestBuilder) AddStopSequence(sequence string) *RequestBuilder {
|
||||
if sequence != "" {
|
||||
b.stop = append(b.stop, sequence)
|
||||
@@ -185,16 +185,16 @@ func (b *RequestBuilder) AddStopSequence(sequence string) *RequestBuilder {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 工具/函数调用相关
|
||||
// Tool/Function Calling Related
|
||||
// ============================================================
|
||||
|
||||
// AddTool 添加工具
|
||||
// AddTool adds a tool
|
||||
func (b *RequestBuilder) AddTool(tool Tool) *RequestBuilder {
|
||||
b.tools = append(b.tools, tool)
|
||||
return b
|
||||
}
|
||||
|
||||
// AddFunction 添加函数(便捷方法)
|
||||
// AddFunction adds a function (convenience method)
|
||||
func (b *RequestBuilder) AddFunction(name, description string, parameters map[string]any) *RequestBuilder {
|
||||
tool := Tool{
|
||||
Type: "function",
|
||||
@@ -208,27 +208,27 @@ func (b *RequestBuilder) AddFunction(name, description string, parameters map[st
|
||||
return b
|
||||
}
|
||||
|
||||
// WithToolChoice 设置工具选择策略
|
||||
// - "auto": 自动选择是否调用工具
|
||||
// - "none": 不调用工具
|
||||
// - 也可以指定特定工具: `{"type": "function", "function": {"name": "my_function"}}`
|
||||
// WithToolChoice sets tool choice strategy
|
||||
// - "auto": automatically choose whether to call tools
|
||||
// - "none": don't call tools
|
||||
// - Can also specify a specific tool: `{"type": "function", "function": {"name": "my_function"}}`
|
||||
func (b *RequestBuilder) WithToolChoice(choice string) *RequestBuilder {
|
||||
b.toolChoice = choice
|
||||
return b
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 构建方法
|
||||
// Build Methods
|
||||
// ============================================================
|
||||
|
||||
// Build 构建请求对象
|
||||
// Build builds request object
|
||||
func (b *RequestBuilder) Build() (*Request, error) {
|
||||
// 验证:至少需要一条消息
|
||||
// Validation: at least one message is required
|
||||
if len(b.messages) == 0 {
|
||||
return nil, errors.New("至少需要一条消息")
|
||||
return nil, errors.New("at least one message is required")
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
// Create request
|
||||
req := &Request{
|
||||
Model: b.model,
|
||||
Messages: b.messages,
|
||||
@@ -238,7 +238,7 @@ func (b *RequestBuilder) Build() (*Request, error) {
|
||||
ToolChoice: b.toolChoice,
|
||||
}
|
||||
|
||||
// 只设置非 nil 的可选参数(避免发送 0 值覆盖服务端默认值)
|
||||
// Only set non-nil optional parameters (avoid sending 0 values that override server defaults)
|
||||
if b.temperature != nil {
|
||||
req.Temperature = b.temperature
|
||||
}
|
||||
@@ -258,8 +258,8 @@ func (b *RequestBuilder) Build() (*Request, error) {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// MustBuild 构建请求对象,如果失败则 panic
|
||||
// 适用于构建过程中确定不会出错的场景
|
||||
// MustBuild builds request object, panics if failed
|
||||
// Suitable for scenarios where build is guaranteed not to fail
|
||||
func (b *RequestBuilder) MustBuild() *Request {
|
||||
req, err := b.Build()
|
||||
if err != nil {
|
||||
@@ -269,10 +269,10 @@ func (b *RequestBuilder) MustBuild() *Request {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 便捷方法:预设场景
|
||||
// Convenience Methods: Preset Scenarios
|
||||
// ============================================================
|
||||
|
||||
// ForChat 创建用于聊天的构建器(预设合理的参数)
|
||||
// ForChat creates builder for chat (preset with reasonable parameters)
|
||||
func ForChat() *RequestBuilder {
|
||||
temp := 0.7
|
||||
tokens := 2000
|
||||
@@ -284,7 +284,7 @@ func ForChat() *RequestBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
// ForCodeGeneration 创建用于代码生成的构建器(低温度,更确定)
|
||||
// ForCodeGeneration creates builder for code generation (low temperature, more deterministic)
|
||||
func ForCodeGeneration() *RequestBuilder {
|
||||
temp := 0.2
|
||||
tokens := 2000
|
||||
@@ -298,7 +298,7 @@ func ForCodeGeneration() *RequestBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
// ForCreativeWriting 创建用于创意写作的构建器(高温度,更随机)
|
||||
// ForCreativeWriting creates builder for creative writing (high temperature, more random)
|
||||
func ForCreativeWriting() *RequestBuilder {
|
||||
temp := 1.2
|
||||
tokens := 4000
|
||||
|
||||
+17
-17
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// 测试 RequestBuilder 基本功能
|
||||
// Test RequestBuilder Basic Features
|
||||
// ============================================================
|
||||
|
||||
func TestRequestBuilder_BasicUsage(t *testing.T) {
|
||||
@@ -39,13 +39,13 @@ func TestRequestBuilder_EmptyMessages(t *testing.T) {
|
||||
t.Error("Build should error when no messages")
|
||||
}
|
||||
|
||||
if err.Error() != "至少需要一条消息" {
|
||||
if err.Error() != "at least one message is required" {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试消息构建方法
|
||||
// Test Message Building Methods
|
||||
// ============================================================
|
||||
|
||||
func TestRequestBuilder_MultipleMessages(t *testing.T) {
|
||||
@@ -85,7 +85,7 @@ func TestRequestBuilder_AddConversationHistory(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试参数控制方法
|
||||
// Test Parameter Control Methods
|
||||
// ============================================================
|
||||
|
||||
func TestRequestBuilder_WithTemperature(t *testing.T) {
|
||||
@@ -165,7 +165,7 @@ func TestRequestBuilder_WithStopSequences(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试工具/函数调用
|
||||
// Test Tool/Function Calling
|
||||
// ============================================================
|
||||
|
||||
func TestRequestBuilder_AddTool(t *testing.T) {
|
||||
@@ -229,7 +229,7 @@ func TestRequestBuilder_AddFunction(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试便捷方法
|
||||
// Test Convenience Methods
|
||||
// ============================================================
|
||||
|
||||
func TestRequestBuilder_ForChat(t *testing.T) {
|
||||
@@ -287,7 +287,7 @@ func TestRequestBuilder_ForCreativeWriting(t *testing.T) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 测试 CallWithRequest 集成
|
||||
// Test CallWithRequest Integration
|
||||
// ============================================================
|
||||
|
||||
func TestClient_CallWithRequest_Success(t *testing.T) {
|
||||
@@ -317,25 +317,25 @@ func TestClient_CallWithRequest_Success(t *testing.T) {
|
||||
t.Errorf("expected 'Builder response', got '%s'", result)
|
||||
}
|
||||
|
||||
// 验证请求体
|
||||
// Verify request body
|
||||
requests := mockHTTP.GetRequests()
|
||||
if len(requests) != 1 {
|
||||
t.Fatalf("expected 1 request, got %d", len(requests))
|
||||
}
|
||||
|
||||
// 解析请求体验证参数
|
||||
// Parse request body to verify parameters
|
||||
var body map[string]interface{}
|
||||
decoder := json.NewDecoder(requests[0].Body)
|
||||
if err := decoder.Decode(&body); err != nil {
|
||||
t.Fatalf("failed to decode request body: %v", err)
|
||||
}
|
||||
|
||||
// 验证 temperature
|
||||
// Verify temperature
|
||||
if body["temperature"] != 0.8 {
|
||||
t.Errorf("expected temperature 0.8, got %v", body["temperature"])
|
||||
}
|
||||
|
||||
// 验证 messages
|
||||
// Verify messages
|
||||
messages, ok := body["messages"].([]interface{})
|
||||
if !ok || len(messages) != 2 {
|
||||
t.Error("messages not correctly formatted")
|
||||
@@ -353,7 +353,7 @@ func TestClient_CallWithRequest_MultiRound(t *testing.T) {
|
||||
WithAPIKey("sk-test-key"),
|
||||
)
|
||||
|
||||
// 构建多轮对话
|
||||
// Build multi-round conversation
|
||||
request := NewRequestBuilder().
|
||||
AddSystemMessage("You are a trading advisor").
|
||||
AddUserMessage("Analyze BTC").
|
||||
@@ -372,7 +372,7 @@ func TestClient_CallWithRequest_MultiRound(t *testing.T) {
|
||||
t.Errorf("expected 'Multi-round response', got '%s'", result)
|
||||
}
|
||||
|
||||
// 验证请求体包含所有消息
|
||||
// Verify request body contains all messages
|
||||
requests := mockHTTP.GetRequests()
|
||||
var body map[string]interface{}
|
||||
json.NewDecoder(requests[0].Body).Decode(&body)
|
||||
@@ -411,7 +411,7 @@ func TestClient_CallWithRequest_WithTools(t *testing.T) {
|
||||
t.Fatalf("should not error: %v", err)
|
||||
}
|
||||
|
||||
// 验证请求体包含 tools
|
||||
// Verify request body contains tools
|
||||
requests := mockHTTP.GetRequests()
|
||||
var body map[string]interface{}
|
||||
json.NewDecoder(requests[0].Body).Decode(&body)
|
||||
@@ -440,7 +440,7 @@ func TestClient_CallWithRequest_NoAPIKey(t *testing.T) {
|
||||
t.Error("should error when API key not set")
|
||||
}
|
||||
|
||||
if err.Error() != "AI API密钥未设置,请先调用 SetAPIKey" {
|
||||
if err.Error() != "AI API key not set, please call SetAPIKey first" {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -456,7 +456,7 @@ func TestClient_CallWithRequest_UsesClientModel(t *testing.T) {
|
||||
WithAPIKey("sk-test-key"),
|
||||
)
|
||||
|
||||
// Request 不设置 model,应该使用 Client 的 model
|
||||
// Request does not set model, should use Client's model
|
||||
request := NewRequestBuilder().
|
||||
WithUserPrompt("Hello").
|
||||
MustBuild()
|
||||
@@ -467,7 +467,7 @@ func TestClient_CallWithRequest_UsesClientModel(t *testing.T) {
|
||||
|
||||
client.CallWithRequest(request)
|
||||
|
||||
// 验证使用了 DeepSeek 的 model
|
||||
// Verify DeepSeek's model is used
|
||||
requests := mockHTTP.GetRequests()
|
||||
var body map[string]interface{}
|
||||
json.NewDecoder(requests[0].Body).Decode(&body)
|
||||
|
||||
Reference in New Issue
Block a user