mirror of
https://github.com/laoxong/nofx.git
synced 2026-06-04 01:48:22 +08:00
7d58f56e49
- Add PostgreSQL + SQLite hybrid database support with automatic switching - Implement frontend AES-GCM + RSA-OAEP encryption for sensitive data - Add comprehensive DatabaseInterface with all required methods - Fix compilation issues with interface consistency - Update all database method signatures to use DatabaseInterface - Add missing UpdateTraderInitialBalance method to PostgreSQL implementation - Integrate RSA public key distribution via /api/config endpoint - Add frontend crypto service with proper error handling - Support graceful degradation between encrypted and plaintext transmission - Add directory creation for RSA keys and PEM parsing fixes - Test both SQLite and PostgreSQL modes successfully 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: tinkle-community <tinklefund@gmail.com>
308 lines
8.5 KiB
Go
308 lines
8.5 KiB
Go
package mcp
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"net/http"
|
||
"os"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// Provider AI提供商类型
|
||
type Provider string
|
||
|
||
const (
|
||
ProviderDeepSeek Provider = "deepseek"
|
||
ProviderQwen Provider = "qwen"
|
||
ProviderCustom Provider = "custom"
|
||
)
|
||
|
||
// Client AI API配置
|
||
type Client struct {
|
||
Provider Provider
|
||
APIKey string
|
||
BaseURL string
|
||
Model string
|
||
Timeout time.Duration
|
||
UseFullURL bool // 是否使用完整URL(不添加/chat/completions)
|
||
MaxTokens int // AI响应的最大token数
|
||
}
|
||
|
||
func New() *Client {
|
||
// 从环境变量读取 MaxTokens,默认 2000
|
||
maxTokens := 2000
|
||
if envMaxTokens := os.Getenv("AI_MAX_TOKENS"); envMaxTokens != "" {
|
||
if parsed, err := strconv.Atoi(envMaxTokens); err == nil && parsed > 0 {
|
||
maxTokens = parsed
|
||
log.Printf("🔧 [MCP] 使用环境变量 AI_MAX_TOKENS: %d", maxTokens)
|
||
} else {
|
||
log.Printf("⚠️ [MCP] 环境变量 AI_MAX_TOKENS 无效 (%s),使用默认值: %d", envMaxTokens, maxTokens)
|
||
}
|
||
}
|
||
|
||
// 默认配置
|
||
return &Client{
|
||
Provider: ProviderDeepSeek,
|
||
BaseURL: "https://api.deepseek.com/v1",
|
||
Model: "deepseek-chat",
|
||
Timeout: 120 * time.Second, // 增加到120秒,因为AI需要分析大量数据
|
||
MaxTokens: maxTokens,
|
||
}
|
||
}
|
||
|
||
// SetDeepSeekAPIKey 设置DeepSeek API密钥
|
||
// customURL 为空时使用默认URL,customModel 为空时使用默认模型
|
||
func (client *Client) SetDeepSeekAPIKey(apiKey string, customURL string, customModel string) {
|
||
client.Provider = ProviderDeepSeek
|
||
client.APIKey = apiKey
|
||
if customURL != "" {
|
||
client.BaseURL = customURL
|
||
log.Printf("🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s", customURL)
|
||
} else {
|
||
client.BaseURL = "https://api.deepseek.com/v1"
|
||
log.Printf("🔧 [MCP] DeepSeek 使用默认 BaseURL: %s", client.BaseURL)
|
||
}
|
||
if customModel != "" {
|
||
client.Model = customModel
|
||
log.Printf("🔧 [MCP] DeepSeek 使用自定义 Model: %s", customModel)
|
||
} else {
|
||
client.Model = "deepseek-chat"
|
||
log.Printf("🔧 [MCP] DeepSeek 使用默认 Model: %s", client.Model)
|
||
}
|
||
// 打印 API Key 的前后各4位用于验证
|
||
if len(apiKey) > 8 {
|
||
log.Printf("🔧 [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||
}
|
||
}
|
||
|
||
// SetQwenAPIKey 设置阿里云Qwen API密钥
|
||
// customURL 为空时使用默认URL,customModel 为空时使用默认模型
|
||
func (client *Client) SetQwenAPIKey(apiKey string, customURL string, customModel string) {
|
||
client.Provider = ProviderQwen
|
||
client.APIKey = apiKey
|
||
if customURL != "" {
|
||
client.BaseURL = customURL
|
||
log.Printf("🔧 [MCP] Qwen 使用自定义 BaseURL: %s", customURL)
|
||
} else {
|
||
client.BaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||
log.Printf("🔧 [MCP] Qwen 使用默认 BaseURL: %s", client.BaseURL)
|
||
}
|
||
if customModel != "" {
|
||
client.Model = customModel
|
||
log.Printf("🔧 [MCP] Qwen 使用自定义 Model: %s", customModel)
|
||
} else {
|
||
client.Model = "qwen3-max"
|
||
log.Printf("🔧 [MCP] Qwen 使用默认 Model: %s", client.Model)
|
||
}
|
||
// 打印 API Key 的前后各4位用于验证
|
||
if len(apiKey) > 8 {
|
||
log.Printf("🔧 [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||
}
|
||
}
|
||
|
||
// SetCustomAPI 设置自定义OpenAI兼容API
|
||
func (client *Client) SetCustomAPI(apiURL, apiKey, modelName string) {
|
||
client.Provider = ProviderCustom
|
||
client.APIKey = apiKey
|
||
|
||
// 检查URL是否以#结尾,如果是则使用完整URL(不添加/chat/completions)
|
||
if strings.HasSuffix(apiURL, "#") {
|
||
client.BaseURL = strings.TrimSuffix(apiURL, "#")
|
||
client.UseFullURL = true
|
||
} else {
|
||
client.BaseURL = apiURL
|
||
client.UseFullURL = false
|
||
}
|
||
|
||
client.Model = modelName
|
||
client.Timeout = 120 * time.Second
|
||
}
|
||
|
||
// SetClient 设置完整的AI配置(高级用户)
|
||
func (client *Client) SetClient(Client Client) {
|
||
if Client.Timeout == 0 {
|
||
Client.Timeout = 30 * time.Second
|
||
}
|
||
client = &Client
|
||
}
|
||
|
||
// CallWithMessages 使用 system + user prompt 调用AI API(推荐)
|
||
func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, error) {
|
||
if client.APIKey == "" {
|
||
return "", fmt.Errorf("AI API密钥未设置,请先调用 SetDeepSeekAPIKey() 或 SetQwenAPIKey()")
|
||
}
|
||
|
||
// 重试配置
|
||
maxRetries := 3
|
||
var lastErr error
|
||
|
||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||
if attempt > 1 {
|
||
fmt.Printf("⚠️ AI API调用失败,正在重试 (%d/%d)...\n", attempt, maxRetries)
|
||
}
|
||
|
||
result, err := client.callOnce(systemPrompt, userPrompt)
|
||
if err == nil {
|
||
if attempt > 1 {
|
||
fmt.Printf("✓ AI API重试成功\n")
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
lastErr = err
|
||
// 如果不是网络错误,不重试
|
||
if !isRetryableError(err) {
|
||
return "", err
|
||
}
|
||
|
||
// 重试前等待
|
||
if attempt < maxRetries {
|
||
waitTime := time.Duration(attempt) * 2 * time.Second
|
||
fmt.Printf("⏳ 等待%v后重试...\n", waitTime)
|
||
time.Sleep(waitTime)
|
||
}
|
||
}
|
||
|
||
return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
|
||
}
|
||
|
||
// callOnce 单次调用AI API(内部使用)
|
||
func (client *Client) callOnce(systemPrompt, userPrompt string) (string, error) {
|
||
// 打印当前 AI 配置
|
||
log.Printf("📡 [MCP] AI 请求配置:")
|
||
log.Printf(" Provider: %s", client.Provider)
|
||
log.Printf(" BaseURL: %s", client.BaseURL)
|
||
log.Printf(" Model: %s", client.Model)
|
||
log.Printf(" UseFullURL: %v", client.UseFullURL)
|
||
if len(client.APIKey) > 8 {
|
||
log.Printf(" API Key: %s...%s", client.APIKey[:4], client.APIKey[len(client.APIKey)-4:])
|
||
}
|
||
|
||
// 构建 messages 数组
|
||
messages := []map[string]string{}
|
||
|
||
// 如果有 system prompt,添加 system message
|
||
if systemPrompt != "" {
|
||
messages = append(messages, map[string]string{
|
||
"role": "system",
|
||
"content": systemPrompt,
|
||
})
|
||
}
|
||
|
||
// 添加 user message
|
||
messages = append(messages, map[string]string{
|
||
"role": "user",
|
||
"content": userPrompt,
|
||
})
|
||
|
||
// 构建请求体
|
||
requestBody := map[string]interface{}{
|
||
"model": client.Model,
|
||
"messages": messages,
|
||
"temperature": 0.5, // 降低temperature以提高JSON格式稳定性
|
||
"max_tokens": client.MaxTokens,
|
||
}
|
||
|
||
// 注意:response_format 参数仅 OpenAI 支持,DeepSeek/Qwen 不支持
|
||
// 我们通过强化 prompt 和后处理来确保 JSON 格式正确
|
||
|
||
jsonData, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||
}
|
||
|
||
// 创建HTTP请求
|
||
var url string
|
||
if client.UseFullURL {
|
||
// 使用完整URL,不添加/chat/completions
|
||
url = client.BaseURL
|
||
} else {
|
||
// 默认行为:添加/chat/completions
|
||
url = fmt.Sprintf("%s/chat/completions", client.BaseURL)
|
||
}
|
||
log.Printf("📡 [MCP] 请求 URL: %s", url)
|
||
|
||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
// 根据不同的Provider设置认证方式
|
||
switch client.Provider {
|
||
case ProviderDeepSeek:
|
||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey))
|
||
case ProviderQwen:
|
||
// 阿里云Qwen使用API-Key认证
|
||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey))
|
||
// 注意:如果使用的不是兼容模式,可能需要不同的认证方式
|
||
default:
|
||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey))
|
||
}
|
||
|
||
// 发送请求
|
||
httpClient := &http.Client{Timeout: client.Timeout}
|
||
resp, err := httpClient.Do(req)
|
||
if err != nil {
|
||
return "", fmt.Errorf("发送请求失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 读取响应
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取响应失败: %w", err)
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return "", fmt.Errorf("API返回错误 (status %d): %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
// 解析响应
|
||
var result struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
}
|
||
|
||
if err := json.Unmarshal(body, &result); err != nil {
|
||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||
}
|
||
|
||
if len(result.Choices) == 0 {
|
||
return "", fmt.Errorf("API返回空响应")
|
||
}
|
||
|
||
return result.Choices[0].Message.Content, nil
|
||
}
|
||
|
||
// isRetryableError 判断错误是否可重试
|
||
func isRetryableError(err error) bool {
|
||
errStr := err.Error()
|
||
// 网络错误、超时、EOF等可以重试
|
||
retryableErrors := []string{
|
||
"EOF",
|
||
"timeout",
|
||
"connection reset",
|
||
"connection refused",
|
||
"temporary failure",
|
||
"no such host",
|
||
"stream error", // HTTP/2 stream 错误
|
||
"INTERNAL_ERROR", // 服务端内部错误
|
||
}
|
||
for _, retryable := range retryableErrors {
|
||
if strings.Contains(errStr, retryable) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|