diff --git a/.env.example b/.env.example index cd64fe4e..3c88c0ce 100644 --- a/.env.example +++ b/.env.example @@ -1,14 +1,46 @@ # NOFX Environment Variables Template # Copy this file to .env and modify the values as needed -# Ports Configuration -# Backend API server port (internal: 8080, external: configurable) +# =========================================== +# Server Configuration +# =========================================== + +# Backend API server port NOFX_BACKEND_PORT=8080 -# Frontend web interface port (Nginx listens on port 80 internally) +# Frontend web interface port NOFX_FRONTEND_PORT=3000 -# Timezone Setting -# System timezone for container time synchronization +# Timezone NOFX_TIMEZONE=Asia/Shanghai +# =========================================== +# Authentication (Required) +# =========================================== + +# JWT signing secret (any random string, at least 32 characters) +# Generate with: openssl rand -base64 32 +JWT_SECRET=your-jwt-secret-change-this-in-production + +# =========================================== +# Encryption Keys (Required) +# =========================================== + +# AES-256 data encryption key (Base64 encoded, 32 bytes) +# Used for encrypting sensitive data in database (API keys, secrets) +# Generate with: openssl rand -base64 32 +DATA_ENCRYPTION_KEY=your-base64-encoded-32-byte-key + +# RSA private key for client-server encryption (PEM format) +# Used for end-to-end encryption of sensitive data from browser +# Generate with: openssl genrsa 2048 +# Note: Replace newlines with \n for single-line format +RSA_PRIVATE_KEY=-----BEGIN RSA PRIVATE KEY-----\nYOUR_KEY_HERE\n-----END RSA PRIVATE KEY----- + +# =========================================== +# Optional: External Services +# =========================================== + +# Telegram notifications (optional) +# TELEGRAM_BOT_TOKEN=your-bot-token +# TELEGRAM_CHAT_ID=your-chat-id diff --git a/.gitignore b/.gitignore index 05dd17f4..a9ab8f41 100644 --- a/.gitignore +++ b/.gitignore @@ -30,8 +30,7 @@ Thumbs.db # 环境变量 .env config.json -config.db* -nofx.db +data.db* configbak.json # 决策日志 diff --git a/ENCRYPTION_README.md b/ENCRYPTION_README.md index 78655876..c2893b85 100644 --- a/ENCRYPTION_README.md +++ b/ENCRYPTION_README.md @@ -116,7 +116,7 @@ If needed, rollback is simple: ```bash # Restore backup -cp config.db.backup config.db +cp data.db.backup data.db # Comment out 3 lines in main.go # (encryption initialization) diff --git a/api/backtest.go b/api/backtest.go index 5cf04796..9e000dda 100644 --- a/api/backtest.go +++ b/api/backtest.go @@ -12,8 +12,8 @@ import ( "time" "nofx/backtest" - "nofx/config" "nofx/decision" + "nofx/store" "github.com/gin-gonic/gin" ) @@ -486,9 +486,6 @@ func (s *Server) ensureBacktestRunOwnership(runID, userID string) (*backtest.Run if owner == "" { return meta, nil } - if owner == "default" && userID == "admin" { - return meta, nil - } if owner != userID { return nil, errBacktestForbidden } @@ -514,7 +511,7 @@ func (s *Server) resolveBacktestAIConfig(cfg *backtest.BacktestConfig, userID st if cfg == nil { return fmt.Errorf("config is nil") } - if s.database == nil { + if s.store == nil { return fmt.Errorf("系统数据库未就绪,无法加载AI模型配置") } @@ -527,7 +524,7 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error { if cfg == nil { return fmt.Errorf("config is nil") } - if s.database == nil { + if s.store == nil { return fmt.Errorf("系统数据库未就绪,无法加载AI模型配置") } @@ -535,17 +532,17 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error { modelID := strings.TrimSpace(cfg.AIModelID) var ( - model *config.AIModelConfig + model *store.AIModel err error ) if modelID != "" { - model, err = s.database.GetAIModel(cfg.UserID, modelID) + model, err = s.store.AIModel().Get(cfg.UserID, modelID) if err != nil { return fmt.Errorf("加载AI模型失败: %w", err) } } else { - model, err = s.database.GetDefaultAIModel(cfg.UserID) + model, err = s.store.AIModel().GetDefault(cfg.UserID) if err != nil { return fmt.Errorf("未找到可用的AI模型: %w", err) } diff --git a/api/server.go b/api/server.go index 89b6013f..d71b8955 100644 --- a/api/server.go +++ b/api/server.go @@ -4,15 +4,15 @@ import ( "context" "encoding/json" "fmt" - "log" + "nofx/logger" "net" "net/http" "nofx/auth" "nofx/backtest" - "nofx/config" "nofx/crypto" "nofx/decision" "nofx/manager" + "nofx/store" "nofx/trader" "strconv" "strings" @@ -26,14 +26,15 @@ import ( type Server struct { router *gin.Engine traderManager *manager.TraderManager - database *config.Database + store *store.Store cryptoHandler *CryptoHandler backtestManager *backtest.Manager httpServer *http.Server port int } + // NewServer 创建API服务器 -func NewServer(traderManager *manager.TraderManager, database *config.Database, cryptoService *crypto.CryptoService, backtestManager *backtest.Manager, port int) *Server { +func NewServer(traderManager *manager.TraderManager, st *store.Store, cryptoService *crypto.CryptoService, backtestManager *backtest.Manager, port int) *Server { // 设置为Release模式(减少日志输出) gin.SetMode(gin.ReleaseMode) @@ -48,7 +49,7 @@ func NewServer(traderManager *manager.TraderManager, database *config.Database, s := &Server{ router: router, traderManager: traderManager, - database: database, + store: st, cryptoHandler: cryptoHandler, backtestManager: backtestManager, port: port, @@ -154,7 +155,6 @@ func (s *Server) setupRoutes() { protected.GET("/decisions", s.handleDecisions) protected.GET("/decisions/latest", s.handleLatestDecisions) protected.GET("/statistics", s.handleStatistics) - protected.GET("/performance", s.handlePerformance) } } } @@ -170,7 +170,7 @@ func (s *Server) handleHealth(c *gin.Context) { // handleGetSystemConfig 获取系统配置(客户端需要知道的配置) func (s *Server) handleGetSystemConfig(c *gin.Context) { // 获取默认币种 - defaultCoinsStr, _ := s.database.GetSystemConfig("default_coins") + defaultCoinsStr, _ := s.store.SystemConfig().Get("default_coins") var defaultCoins []string if defaultCoinsStr != "" { json.Unmarshal([]byte(defaultCoinsStr), &defaultCoins) @@ -181,8 +181,8 @@ func (s *Server) handleGetSystemConfig(c *gin.Context) { } // 获取杠杆配置 - btcEthLeverageStr, _ := s.database.GetSystemConfig("btc_eth_leverage") - altcoinLeverageStr, _ := s.database.GetSystemConfig("altcoin_leverage") + btcEthLeverageStr, _ := s.store.SystemConfig().Get("btc_eth_leverage") + altcoinLeverageStr, _ := s.store.SystemConfig().Get("altcoin_leverage") btcEthLeverage := 5 if val, err := strconv.Atoi(btcEthLeverageStr); err == nil && val > 0 { @@ -195,14 +195,19 @@ func (s *Server) handleGetSystemConfig(c *gin.Context) { } // 获取内测模式配置 - betaModeStr, _ := s.database.GetSystemConfig("beta_mode") + betaModeStr, _ := s.store.SystemConfig().Get("beta_mode") betaMode := betaModeStr == "true" + // 获取注册开关配置(默认开启) + registrationEnabledStr, _ := s.store.SystemConfig().Get("registration_enabled") + registrationEnabled := registrationEnabledStr != "false" + c.JSON(http.StatusOK, gin.H{ - "beta_mode": betaMode, - "default_coins": defaultCoins, - "btc_eth_leverage": btcEthLeverage, - "altcoin_leverage": altcoinLeverage, + "beta_mode": betaMode, + "registration_enabled": registrationEnabled, + "default_coins": defaultCoins, + "btc_eth_leverage": btcEthLeverage, + "altcoin_leverage": altcoinLeverage, }) } @@ -339,9 +344,9 @@ func (s *Server) getTraderFromQuery(c *gin.Context) (*manager.TraderManager, str traderID := c.Query("trader_id") // 确保用户的交易员已加载到内存中 - err := s.traderManager.LoadUserTraders(s.database, userID) + err := s.traderManager.LoadUserTradersFromStore(s.store, userID) if err != nil { - log.Printf("⚠️ 加载用户 %s 的交易员失败: %v", userID, err) + logger.Infof("⚠️ 加载用户 %s 的交易员失败: %v", userID, err) } if traderID == "" { @@ -352,7 +357,7 @@ func (s *Server) getTraderFromQuery(c *gin.Context) (*manager.TraderManager, str } // 获取用户的交易员列表,优先返回用户自己的交易员 - userTraders, err := s.database.GetTraders(userID) + userTraders, err := s.store.Trader().List(userID) if err == nil && len(userTraders) > 0 { traderID = userTraders[0].ID } else { @@ -493,7 +498,7 @@ func (s *Server) handleCreateTrader(c *gin.Context) { btcEthLeverage = req.BTCETHLeverage } else { // 从系统配置获取默认值 - if btcEthLeverageStr, _ := s.database.GetSystemConfig("btc_eth_leverage"); btcEthLeverageStr != "" { + if btcEthLeverageStr, _ := s.store.SystemConfig().Get("btc_eth_leverage"); btcEthLeverageStr != "" { if val, err := strconv.Atoi(btcEthLeverageStr); err == nil && val > 0 { btcEthLeverage = val } @@ -503,7 +508,7 @@ func (s *Server) handleCreateTrader(c *gin.Context) { altcoinLeverage = req.AltcoinLeverage } else { // 从系统配置获取默认值 - if altcoinLeverageStr, _ := s.database.GetSystemConfig("altcoin_leverage"); altcoinLeverageStr != "" { + if altcoinLeverageStr, _ := s.store.SystemConfig().Get("altcoin_leverage"); altcoinLeverageStr != "" { if val, err := strconv.Atoi(altcoinLeverageStr); err == nil && val > 0 { altcoinLeverage = val } @@ -524,13 +529,13 @@ func (s *Server) handleCreateTrader(c *gin.Context) { // ✨ 查询交易所实际余额,覆盖用户输入 actualBalance := req.InitialBalance // 默认使用用户输入 - exchanges, err := s.database.GetExchanges(userID) + exchanges, err := s.store.Exchange().List(userID) if err != nil { - log.Printf("⚠️ 获取交易所配置失败,使用用户输入的初始资金: %v", err) + logger.Infof("⚠️ 获取交易所配置失败,使用用户输入的初始资金: %v", err) } // 查找匹配的交易所配置 - var exchangeCfg *config.ExchangeConfig + var exchangeCfg *store.Exchange for _, ex := range exchanges { if ex.ID == req.ExchangeID { exchangeCfg = ex @@ -539,9 +544,9 @@ func (s *Server) handleCreateTrader(c *gin.Context) { } if exchangeCfg == nil { - log.Printf("⚠️ 未找到交易所 %s 的配置,使用用户输入的初始资金", req.ExchangeID) + logger.Infof("⚠️ 未找到交易所 %s 的配置,使用用户输入的初始资金", req.ExchangeID) } else if !exchangeCfg.Enabled { - log.Printf("⚠️ 交易所 %s 未启用,使用用户输入的初始资金", req.ExchangeID) + logger.Infof("⚠️ 交易所 %s 未启用,使用用户输入的初始资金", req.ExchangeID) } else { // 根据交易所类型创建临时 trader 查询余额 var tempTrader trader.Trader @@ -568,44 +573,44 @@ func (s *Server) handleCreateTrader(c *gin.Context) { exchangeCfg.SecretKey, ) default: - log.Printf("⚠️ 不支持的交易所类型: %s,使用用户输入的初始资金", req.ExchangeID) + logger.Infof("⚠️ 不支持的交易所类型: %s,使用用户输入的初始资金", req.ExchangeID) } if createErr != nil { - log.Printf("⚠️ 创建临时 trader 失败,使用用户输入的初始资金: %v", createErr) + logger.Infof("⚠️ 创建临时 trader 失败,使用用户输入的初始资金: %v", createErr) } else if tempTrader != nil { // 查询实际余额 balanceInfo, balanceErr := tempTrader.GetBalance() if balanceErr != nil { - log.Printf("⚠️ 查询交易所余额失败,使用用户输入的初始资金: %v", balanceErr) + logger.Infof("⚠️ 查询交易所余额失败,使用用户输入的初始资金: %v", balanceErr) } else { // 提取可用余额 - 支持多种字段名格式 if availableBalance, ok := balanceInfo["availableBalance"].(float64); ok && availableBalance > 0 { // Binance 格式: availableBalance (camelCase) actualBalance = availableBalance - log.Printf("✓ 查询到交易所实际余额: %.2f USDT (用户输入: %.2f USDT)", actualBalance, req.InitialBalance) + logger.Infof("✓ 查询到交易所实际余额: %.2f USDT (用户输入: %.2f USDT)", actualBalance, req.InitialBalance) } else if availableBalance, ok := balanceInfo["available_balance"].(float64); ok && availableBalance > 0 { // 其他格式: available_balance (snake_case) actualBalance = availableBalance - log.Printf("✓ 查询到交易所实际余额: %.2f USDT (用户输入: %.2f USDT)", actualBalance, req.InitialBalance) + logger.Infof("✓ 查询到交易所实际余额: %.2f USDT (用户输入: %.2f USDT)", actualBalance, req.InitialBalance) } else if totalBalance, ok := balanceInfo["totalWalletBalance"].(float64); ok && totalBalance > 0 { // Binance 格式: totalWalletBalance (camelCase) actualBalance = totalBalance - log.Printf("✓ 查询到交易所总余额: %.2f USDT (用户输入: %.2f USDT)", actualBalance, req.InitialBalance) + logger.Infof("✓ 查询到交易所总余额: %.2f USDT (用户输入: %.2f USDT)", actualBalance, req.InitialBalance) } else if totalBalance, ok := balanceInfo["balance"].(float64); ok && totalBalance > 0 { // 其他格式: balance actualBalance = totalBalance - log.Printf("✓ 查询到交易所实际余额: %.2f USDT (用户输入: %.2f USDT)", actualBalance, req.InitialBalance) + logger.Infof("✓ 查询到交易所实际余额: %.2f USDT (用户输入: %.2f USDT)", actualBalance, req.InitialBalance) } else { - log.Printf("⚠️ 无法从余额信息中提取可用余额,balanceInfo=%v,使用用户输入的初始资金", balanceInfo) + logger.Infof("⚠️ 无法从余额信息中提取可用余额,balanceInfo=%v,使用用户输入的初始资金", balanceInfo) } } } } // 创建交易员配置(数据库实体) - log.Printf("🔧 DEBUG: 开始创建交易员配置, ID=%s, Name=%s, AIModel=%s, Exchange=%s", traderID, req.Name, req.AIModelID, req.ExchangeID) - trader := &config.TraderRecord{ + logger.Infof("🔧 DEBUG: 开始创建交易员配置, ID=%s, Name=%s, AIModel=%s, Exchange=%s", traderID, req.Name, req.AIModelID, req.ExchangeID) + traderRecord := &store.Trader{ ID: traderID, UserID: userID, Name: req.Name, @@ -626,25 +631,25 @@ func (s *Server) handleCreateTrader(c *gin.Context) { } // 保存到数据库 - log.Printf("🔧 DEBUG: 准备调用 CreateTrader") - err = s.database.CreateTrader(trader) + logger.Infof("🔧 DEBUG: 准备调用 CreateTrader") + err = s.store.Trader().Create(traderRecord) if err != nil { - log.Printf("❌ 创建交易员失败: %v", err) + logger.Infof("❌ 创建交易员失败: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("创建交易员失败: %v", err)}) return } - log.Printf("🔧 DEBUG: CreateTrader 成功") + logger.Infof("🔧 DEBUG: CreateTrader 成功") // 立即将新交易员加载到TraderManager中 - log.Printf("🔧 DEBUG: 准备调用 LoadUserTraders") - err = s.traderManager.LoadUserTraders(s.database, userID) + logger.Infof("🔧 DEBUG: 准备调用 LoadUserTraders") + err = s.traderManager.LoadUserTradersFromStore(s.store, userID) if err != nil { - log.Printf("⚠️ 加载用户交易员到内存失败: %v", err) + logger.Infof("⚠️ 加载用户交易员到内存失败: %v", err) // 这里不返回错误,因为交易员已经成功创建到数据库 } - log.Printf("🔧 DEBUG: LoadUserTraders 完成") + logger.Infof("🔧 DEBUG: LoadUserTraders 完成") - log.Printf("✓ 创建交易员成功: %s (模型: %s, 交易所: %s)", req.Name, req.AIModelID, req.ExchangeID) + logger.Infof("✓ 创建交易员成功: %s (模型: %s, 交易所: %s)", req.Name, req.AIModelID, req.ExchangeID) c.JSON(http.StatusCreated, gin.H{ "trader_id": traderID, @@ -656,17 +661,18 @@ func (s *Server) handleCreateTrader(c *gin.Context) { // UpdateTraderRequest 更新交易员请求 type UpdateTraderRequest struct { - Name string `json:"name" binding:"required"` - AIModelID string `json:"ai_model_id" binding:"required"` - ExchangeID string `json:"exchange_id" binding:"required"` - InitialBalance float64 `json:"initial_balance"` - ScanIntervalMinutes int `json:"scan_interval_minutes"` - BTCETHLeverage int `json:"btc_eth_leverage"` - AltcoinLeverage int `json:"altcoin_leverage"` - TradingSymbols string `json:"trading_symbols"` - CustomPrompt string `json:"custom_prompt"` - OverrideBasePrompt bool `json:"override_base_prompt"` - IsCrossMargin *bool `json:"is_cross_margin"` + Name string `json:"name" binding:"required"` + AIModelID string `json:"ai_model_id" binding:"required"` + ExchangeID string `json:"exchange_id" binding:"required"` + InitialBalance float64 `json:"initial_balance"` + ScanIntervalMinutes int `json:"scan_interval_minutes"` + BTCETHLeverage int `json:"btc_eth_leverage"` + AltcoinLeverage int `json:"altcoin_leverage"` + TradingSymbols string `json:"trading_symbols"` + CustomPrompt string `json:"custom_prompt"` + OverrideBasePrompt bool `json:"override_base_prompt"` + SystemPromptTemplate string `json:"system_prompt_template"` + IsCrossMargin *bool `json:"is_cross_margin"` } // handleUpdateTrader 更新交易员配置 @@ -681,16 +687,16 @@ func (s *Server) handleUpdateTrader(c *gin.Context) { } // 检查交易员是否存在且属于当前用户 - traders, err := s.database.GetTraders(userID) + traders, err := s.store.Trader().List(userID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "获取交易员列表失败"}) return } - var existingTrader *config.TraderRecord - for _, trader := range traders { - if trader.ID == traderID { - existingTrader = trader + var existingTrader *store.Trader + for _, t := range traders { + if t.ID == traderID { + existingTrader = t break } } @@ -724,8 +730,14 @@ func (s *Server) handleUpdateTrader(c *gin.Context) { scanIntervalMinutes = 3 } + // 设置系统提示词模板 + systemPromptTemplate := req.SystemPromptTemplate + if systemPromptTemplate == "" { + systemPromptTemplate = existingTrader.SystemPromptTemplate // 保持原值 + } + // 更新交易员配置 - trader := &config.TraderRecord{ + traderRecord := &store.Trader{ ID: traderID, UserID: userID, Name: req.Name, @@ -737,26 +749,26 @@ func (s *Server) handleUpdateTrader(c *gin.Context) { TradingSymbols: req.TradingSymbols, CustomPrompt: req.CustomPrompt, OverrideBasePrompt: req.OverrideBasePrompt, - SystemPromptTemplate: existingTrader.SystemPromptTemplate, // 保持原值 + SystemPromptTemplate: systemPromptTemplate, IsCrossMargin: isCrossMargin, ScanIntervalMinutes: scanIntervalMinutes, IsRunning: existingTrader.IsRunning, // 保持原值 } // 更新数据库 - err = s.database.UpdateTrader(trader) + err = s.store.Trader().Update(traderRecord) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新交易员失败: %v", err)}) return } // 重新加载交易员到内存 - err = s.traderManager.LoadUserTraders(s.database, userID) + err = s.traderManager.LoadUserTradersFromStore(s.store, userID) if err != nil { - log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err) + logger.Infof("⚠️ 重新加载用户交易员到内存失败: %v", err) } - log.Printf("✓ 更新交易员成功: %s (模型: %s, 交易所: %s)", req.Name, req.AIModelID, req.ExchangeID) + logger.Infof("✓ 更新交易员成功: %s (模型: %s, 交易所: %s)", req.Name, req.AIModelID, req.ExchangeID) c.JSON(http.StatusOK, gin.H{ "trader_id": traderID, @@ -772,7 +784,7 @@ func (s *Server) handleDeleteTrader(c *gin.Context) { traderID := c.Param("id") // 从数据库删除 - err := s.database.DeleteTrader(userID, traderID) + err := s.store.Trader().Delete(userID, traderID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("删除交易员失败: %v", err)}) return @@ -783,11 +795,11 @@ func (s *Server) handleDeleteTrader(c *gin.Context) { status := trader.GetStatus() if isRunning, ok := status["is_running"].(bool); ok && isRunning { trader.Stop() - log.Printf("⏹ 已停止运行中的交易员: %s", traderID) + logger.Infof("⏹ 已停止运行中的交易员: %s", traderID) } } - log.Printf("✓ 交易员已删除: %s", traderID) + logger.Infof("✓ 交易员已删除: %s", traderID) c.JSON(http.StatusOK, gin.H{"message": "交易员已删除"}) } @@ -797,7 +809,7 @@ func (s *Server) handleStartTrader(c *gin.Context) { traderID := c.Param("id") // 校验交易员是否属于当前用户 - _, _, _, err := s.database.GetTraderConfig(userID, traderID) + _, err := s.store.Trader().GetFullConfig(userID, traderID) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "交易员不存在或无访问权限"}) return @@ -818,19 +830,19 @@ func (s *Server) handleStartTrader(c *gin.Context) { // 启动交易员 go func() { - log.Printf("▶️ 启动交易员 %s (%s)", traderID, trader.GetName()) + logger.Infof("▶️ 启动交易员 %s (%s)", traderID, trader.GetName()) if err := trader.Run(); err != nil { - log.Printf("❌ 交易员 %s 运行错误: %v", trader.GetName(), err) + logger.Infof("❌ 交易员 %s 运行错误: %v", trader.GetName(), err) } }() // 更新数据库中的运行状态 - err = s.database.UpdateTraderStatus(userID, traderID, true) + err = s.store.Trader().UpdateStatus(userID, traderID, true) if err != nil { - log.Printf("⚠️ 更新交易员状态失败: %v", err) + logger.Infof("⚠️ 更新交易员状态失败: %v", err) } - log.Printf("✓ 交易员 %s 已启动", trader.GetName()) + logger.Infof("✓ 交易员 %s 已启动", trader.GetName()) c.JSON(http.StatusOK, gin.H{"message": "交易员已启动"}) } @@ -840,7 +852,7 @@ func (s *Server) handleStopTrader(c *gin.Context) { traderID := c.Param("id") // 校验交易员是否属于当前用户 - _, _, _, err := s.database.GetTraderConfig(userID, traderID) + _, err := s.store.Trader().GetFullConfig(userID, traderID) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "交易员不存在或无访问权限"}) return @@ -863,12 +875,12 @@ func (s *Server) handleStopTrader(c *gin.Context) { trader.Stop() // 更新数据库中的运行状态 - err = s.database.UpdateTraderStatus(userID, traderID, false) + err = s.store.Trader().UpdateStatus(userID, traderID, false) if err != nil { - log.Printf("⚠️ 更新交易员状态失败: %v", err) + logger.Infof("⚠️ 更新交易员状态失败: %v", err) } - log.Printf("⏹ 交易员 %s 已停止", trader.GetName()) + logger.Infof("⏹ 交易员 %s 已停止", trader.GetName()) c.JSON(http.StatusOK, gin.H{"message": "交易员已停止"}) } @@ -888,7 +900,7 @@ func (s *Server) handleUpdateTraderPrompt(c *gin.Context) { } // 更新数据库 - err := s.database.UpdateTraderCustomPrompt(userID, traderID, req.CustomPrompt, req.OverrideBasePrompt) + err := s.store.Trader().UpdateCustomPrompt(userID, traderID, req.CustomPrompt, req.OverrideBasePrompt) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新自定义prompt失败: %v", err)}) return @@ -899,7 +911,7 @@ func (s *Server) handleUpdateTraderPrompt(c *gin.Context) { if err == nil { trader.SetCustomPrompt(req.CustomPrompt) trader.SetOverrideBasePrompt(req.OverrideBasePrompt) - log.Printf("✓ 已更新交易员 %s 的自定义prompt (覆盖基础=%v)", trader.GetName(), req.OverrideBasePrompt) + logger.Infof("✓ 已更新交易员 %s 的自定义prompt (覆盖基础=%v)", trader.GetName(), req.OverrideBasePrompt) } c.JSON(http.StatusOK, gin.H{"message": "自定义prompt已更新"}) @@ -910,15 +922,18 @@ func (s *Server) handleSyncBalance(c *gin.Context) { userID := c.GetString("user_id") traderID := c.Param("id") - log.Printf("🔄 用户 %s 请求同步交易员 %s 的余额", userID, traderID) + logger.Infof("🔄 用户 %s 请求同步交易员 %s 的余额", userID, traderID) // 从数据库获取交易员配置(包含交易所信息) - traderConfig, _, exchangeCfg, err := s.database.GetTraderConfig(userID, traderID) + fullConfig, err := s.store.Trader().GetFullConfig(userID, traderID) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "交易员不存在"}) return } + traderConfig := fullConfig.Trader + exchangeCfg := fullConfig.Exchange + if exchangeCfg == nil || !exchangeCfg.Enabled { c.JSON(http.StatusBadRequest, gin.H{"error": "交易所未配置或未启用"}) return @@ -954,7 +969,7 @@ func (s *Server) handleSyncBalance(c *gin.Context) { } if createErr != nil { - log.Printf("⚠️ 创建临时 trader 失败: %v", createErr) + logger.Infof("⚠️ 创建临时 trader 失败: %v", createErr) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("连接交易所失败: %v", createErr)}) return } @@ -962,7 +977,7 @@ func (s *Server) handleSyncBalance(c *gin.Context) { // 查询实际余额 balanceInfo, balanceErr := tempTrader.GetBalance() if balanceErr != nil { - log.Printf("⚠️ 查询交易所余额失败: %v", balanceErr) + logger.Infof("⚠️ 查询交易所余额失败: %v", balanceErr) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("查询余额失败: %v", balanceErr)}) return } @@ -989,24 +1004,24 @@ func (s *Server) handleSyncBalance(c *gin.Context) { changeType = "减少" } - log.Printf("✓ 查询到交易所实际余额: %.2f USDT (当前配置: %.2f USDT, 变化: %.2f%%)", + logger.Infof("✓ 查询到交易所实际余额: %.2f USDT (当前配置: %.2f USDT, 变化: %.2f%%)", actualBalance, oldBalance, changePercent) // 更新数据库中的 initial_balance - err = s.database.UpdateTraderInitialBalance(userID, traderID, actualBalance) + err = s.store.Trader().UpdateInitialBalance(userID, traderID, actualBalance) if err != nil { - log.Printf("❌ 更新initial_balance失败: %v", err) + logger.Infof("❌ 更新initial_balance失败: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "更新余额失败"}) return } // 重新加载交易员到内存 - err = s.traderManager.LoadUserTraders(s.database, userID) + err = s.traderManager.LoadUserTradersFromStore(s.store, userID) if err != nil { - log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err) + logger.Infof("⚠️ 重新加载用户交易员到内存失败: %v", err) } - log.Printf("✅ 已同步余额: %.2f → %.2f USDT (%s %.2f%%)", oldBalance, actualBalance, changeType, changePercent) + logger.Infof("✅ 已同步余额: %.2f → %.2f USDT (%s %.2f%%)", oldBalance, actualBalance, changeType, changePercent) c.JSON(http.StatusOK, gin.H{ "message": "余额同步成功", @@ -1020,14 +1035,14 @@ func (s *Server) handleSyncBalance(c *gin.Context) { // handleGetModelConfigs 获取AI模型配置 func (s *Server) handleGetModelConfigs(c *gin.Context) { userID := c.GetString("user_id") - log.Printf("🔍 查询用户 %s 的AI模型配置", userID) - models, err := s.database.GetAIModels(userID) + logger.Infof("🔍 查询用户 %s 的AI模型配置", userID) + models, err := s.store.AIModel().List(userID) if err != nil { - log.Printf("❌ 获取AI模型配置失败: %v", err) + logger.Infof("❌ 获取AI模型配置失败: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("获取AI模型配置失败: %v", err)}) return } - log.Printf("✅ 找到 %d 个AI模型配置", len(models)) + logger.Infof("✅ 找到 %d 个AI模型配置", len(models)) // 转换为安全的响应结构,移除敏感信息 safeModels := make([]SafeModelConfig, len(models)) @@ -1059,14 +1074,14 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) { // 解析加密的 payload var encryptedPayload crypto.EncryptedPayload if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil { - log.Printf("❌ 解析加密载荷失败: %v", err) + logger.Infof("❌ 解析加密载荷失败: %v", err) c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误,必须使用加密传输"}) return } // 验证是否为加密数据 if encryptedPayload.WrappedKey == "" { - log.Printf("❌ 检测到非加密请求 (UserID: %s)", userID) + logger.Infof("❌ 检测到非加密请求 (UserID: %s)", userID) c.JSON(http.StatusBadRequest, gin.H{ "error": "此接口仅支持加密传输,请使用加密客户端", "code": "ENCRYPTION_REQUIRED", @@ -1078,7 +1093,7 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) { // 解密数据 decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload) if err != nil { - log.Printf("❌ 解密模型配置失败 (UserID: %s): %v", userID, err) + logger.Infof("❌ 解密模型配置失败 (UserID: %s): %v", userID, err) c.JSON(http.StatusBadRequest, gin.H{"error": "解密数据失败"}) return } @@ -1086,15 +1101,15 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) { // 解析解密后的数据 var req UpdateModelConfigRequest if err := json.Unmarshal([]byte(decrypted), &req); err != nil { - log.Printf("❌ 解析解密数据失败: %v", err) + logger.Infof("❌ 解析解密数据失败: %v", err) c.JSON(http.StatusBadRequest, gin.H{"error": "解析解密数据失败"}) return } - log.Printf("🔓 已解密模型配置数据 (UserID: %s)", userID) + logger.Infof("🔓 已解密模型配置数据 (UserID: %s)", userID) // 更新每个模型的配置 for modelID, modelData := range req.Models { - err := s.database.UpdateAIModel(userID, modelID, modelData.Enabled, modelData.APIKey, modelData.CustomAPIURL, modelData.CustomModelName) + err := s.store.AIModel().Update(userID, modelID, modelData.Enabled, modelData.APIKey, modelData.CustomAPIURL, modelData.CustomModelName) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新模型 %s 失败: %v", modelID, err)}) return @@ -1102,27 +1117,27 @@ func (s *Server) handleUpdateModelConfigs(c *gin.Context) { } // 重新加载该用户的所有交易员,使新配置立即生效 - err = s.traderManager.LoadUserTraders(s.database, userID) + err = s.traderManager.LoadUserTradersFromStore(s.store, userID) if err != nil { - log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err) + logger.Infof("⚠️ 重新加载用户交易员到内存失败: %v", err) // 这里不返回错误,因为模型配置已经成功更新到数据库 } - log.Printf("✓ AI模型配置已更新: %+v", req.Models) + logger.Infof("✓ AI模型配置已更新: %+v", req.Models) c.JSON(http.StatusOK, gin.H{"message": "模型配置已更新"}) } // handleGetExchangeConfigs 获取交易所配置 func (s *Server) handleGetExchangeConfigs(c *gin.Context) { userID := c.GetString("user_id") - log.Printf("🔍 查询用户 %s 的交易所配置", userID) - exchanges, err := s.database.GetExchanges(userID) + logger.Infof("🔍 查询用户 %s 的交易所配置", userID) + exchanges, err := s.store.Exchange().List(userID) if err != nil { - log.Printf("❌ 获取交易所配置失败: %v", err) + logger.Infof("❌ 获取交易所配置失败: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("获取交易所配置失败: %v", err)}) return } - log.Printf("✅ 找到 %d 个交易所配置", len(exchanges)) + logger.Infof("✅ 找到 %d 个交易所配置", len(exchanges)) // 调试:输出配置详情(脱敏) for _, ex := range exchanges { @@ -1134,12 +1149,12 @@ func (s *Server) handleGetExchangeConfigs(c *gin.Context) { if len(ex.SecretKey) > 8 { secretKeyMasked = ex.SecretKey[:8] + "..." } - log.Printf(" └─ 交易所: %s, APIKey: %s, SecretKey: %s", ex.ID, apiKeyMasked, secretKeyMasked) + logger.Infof(" └─ 交易所: %s, APIKey: %s, SecretKey: %s", ex.ID, apiKeyMasked, secretKeyMasked) } // 打印完整JSON响应用于调试 jsonData, _ := json.Marshal(exchanges) - log.Printf("📤 完整JSON响应: %s", string(jsonData)) + logger.Infof("📤 完整JSON响应: %s", string(jsonData)) // 转换为安全的响应结构,移除敏感信息 safeExchanges := make([]SafeExchangeConfig, len(exchanges)) @@ -1173,14 +1188,14 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) { // 解析加密的 payload var encryptedPayload crypto.EncryptedPayload if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil { - log.Printf("❌ 解析加密载荷失败: %v", err) + logger.Infof("❌ 解析加密载荷失败: %v", err) c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误,必须使用加密传输"}) return } // 验证是否为加密数据 if encryptedPayload.WrappedKey == "" { - log.Printf("❌ 检测到非加密请求 (UserID: %s)", userID) + logger.Infof("❌ 检测到非加密请求 (UserID: %s)", userID) c.JSON(http.StatusBadRequest, gin.H{ "error": "此接口仅支持加密传输,请使用加密客户端", "code": "ENCRYPTION_REQUIRED", @@ -1192,7 +1207,7 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) { // 解密数据 decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload) if err != nil { - log.Printf("❌ 解密交易所配置失败 (UserID: %s): %v", userID, err) + logger.Infof("❌ 解密交易所配置失败 (UserID: %s): %v", userID, err) c.JSON(http.StatusBadRequest, gin.H{"error": "解密数据失败"}) return } @@ -1200,15 +1215,15 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) { // 解析解密后的数据 var req UpdateExchangeConfigRequest if err := json.Unmarshal([]byte(decrypted), &req); err != nil { - log.Printf("❌ 解析解密数据失败: %v", err) + logger.Infof("❌ 解析解密数据失败: %v", err) c.JSON(http.StatusBadRequest, gin.H{"error": "解析解密数据失败"}) return } - log.Printf("🔓 已解密交易所配置数据 (UserID: %s)", userID) + logger.Infof("🔓 已解密交易所配置数据 (UserID: %s)", userID) // 更新每个交易所的配置 for exchangeID, exchangeData := range req.Exchanges { - err := s.database.UpdateExchange(userID, exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Testnet, exchangeData.HyperliquidWalletAddr, exchangeData.AsterUser, exchangeData.AsterSigner, exchangeData.AsterPrivateKey, exchangeData.LighterWalletAddr, exchangeData.LighterPrivateKey) + err := s.store.Exchange().Update(userID, exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Testnet, exchangeData.HyperliquidWalletAddr, exchangeData.AsterUser, exchangeData.AsterSigner, exchangeData.AsterPrivateKey, exchangeData.LighterWalletAddr, exchangeData.LighterPrivateKey) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新交易所 %s 失败: %v", exchangeID, err)}) return @@ -1216,20 +1231,20 @@ func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) { } // 重新加载该用户的所有交易员,使新配置立即生效 - err = s.traderManager.LoadUserTraders(s.database, userID) + err = s.traderManager.LoadUserTradersFromStore(s.store, userID) if err != nil { - log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err) + logger.Infof("⚠️ 重新加载用户交易员到内存失败: %v", err) // 这里不返回错误,因为交易所配置已经成功更新到数据库 } - log.Printf("✓ 交易所配置已更新: %+v", req.Exchanges) + logger.Infof("✓ 交易所配置已更新: %+v", req.Exchanges) c.JSON(http.StatusOK, gin.H{"message": "交易所配置已更新"}) } // handleGetUserSignalSource 获取用户信号源配置 func (s *Server) handleGetUserSignalSource(c *gin.Context) { userID := c.GetString("user_id") - source, err := s.database.GetUserSignalSource(userID) + source, err := s.store.SignalSource().Get(userID) if err != nil { // 如果配置不存在,返回空配置而不是404错误 c.JSON(http.StatusOK, gin.H{ @@ -1258,20 +1273,20 @@ func (s *Server) handleSaveUserSignalSource(c *gin.Context) { return } - err := s.database.CreateUserSignalSource(userID, req.CoinPoolURL, req.OITopURL) + err := s.store.SignalSource().Create(userID, req.CoinPoolURL, req.OITopURL) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("保存用户信号源配置失败: %v", err)}) return } - log.Printf("✓ 用户信号源配置已保存: user=%s, coin_pool=%s, oi_top=%s", userID, req.CoinPoolURL, req.OITopURL) + logger.Infof("✓ 用户信号源配置已保存: user=%s, coin_pool=%s, oi_top=%s", userID, req.CoinPoolURL, req.OITopURL) c.JSON(http.StatusOK, gin.H{"message": "用户信号源配置已保存"}) } // handleTraderList trader列表 func (s *Server) handleTraderList(c *gin.Context) { userID := c.GetString("user_id") - traders, err := s.database.GetTraders(userID) + traders, err := s.store.Trader().List(userID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("获取交易员列表失败: %v", err)}) return @@ -1313,11 +1328,12 @@ func (s *Server) handleGetTraderConfig(c *gin.Context) { return } - traderConfig, _, _, err := s.database.GetTraderConfig(userID, traderID) + fullCfg, err := s.store.Trader().GetFullConfig(userID, traderID) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("获取交易员配置失败: %v", err)}) return } + traderConfig := fullCfg.Trader // 获取实时运行状态 isRunning := traderConfig.IsRunning @@ -1384,17 +1400,17 @@ func (s *Server) handleAccount(c *gin.Context) { return } - log.Printf("📊 收到账户信息请求 [%s]", trader.GetName()) + logger.Infof("📊 收到账户信息请求 [%s]", trader.GetName()) account, err := trader.GetAccountInfo() if err != nil { - log.Printf("❌ 获取账户信息失败 [%s]: %v", trader.GetName(), err) + logger.Infof("❌ 获取账户信息失败 [%s]: %v", trader.GetName(), err) c.JSON(http.StatusInternalServerError, gin.H{ "error": fmt.Sprintf("获取账户信息失败: %v", err), }) return } - log.Printf("✓ 返回账户信息 [%s]: 净值=%.2f, 可用=%.2f, 盈亏=%.2f (%.2f%%)", + logger.Infof("✓ 返回账户信息 [%s]: 净值=%.2f, 可用=%.2f, 盈亏=%.2f (%.2f%%)", trader.GetName(), account["total_equity"], account["available_balance"], @@ -1443,7 +1459,7 @@ func (s *Server) handleDecisions(c *gin.Context) { } // 获取所有历史决策记录(无限制) - records, err := trader.GetDecisionLogger().GetLatestRecords(10000) + records, err := trader.GetStore().Decision().GetLatestRecords(trader.GetID(), 10000) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": fmt.Sprintf("获取决策日志失败: %v", err), @@ -1468,7 +1484,7 @@ func (s *Server) handleLatestDecisions(c *gin.Context) { return } - records, err := trader.GetDecisionLogger().GetLatestRecords(5) + records, err := trader.GetStore().Decision().GetLatestRecords(trader.GetID(), 5) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": fmt.Sprintf("获取决策日志失败: %v", err), @@ -1499,7 +1515,7 @@ func (s *Server) handleStatistics(c *gin.Context) { return } - stats, err := trader.GetDecisionLogger().GetStatistics() + stats, err := trader.GetStore().Decision().GetStatistics(trader.GetID()) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": fmt.Sprintf("获取统计信息失败: %v", err), @@ -1515,9 +1531,9 @@ func (s *Server) handleCompetition(c *gin.Context) { userID := c.GetString("user_id") // 确保用户的交易员已加载到内存中 - err := s.traderManager.LoadUserTraders(s.database, userID) + err := s.traderManager.LoadUserTradersFromStore(s.store, userID) if err != nil { - log.Printf("⚠️ 加载用户 %s 的交易员失败: %v", userID, err) + logger.Infof("⚠️ 加载用户 %s 的交易员失败: %v", userID, err) } competition, err := s.traderManager.GetCompetitionData() @@ -1547,7 +1563,7 @@ func (s *Server) handleEquityHistory(c *gin.Context) { // 获取尽可能多的历史数据(几天的数据) // 每3分钟一个周期:10000条 = 约20天的数据 - records, err := trader.GetDecisionLogger().GetLatestRecords(10000) + records, err := trader.GetStore().Decision().GetLatestRecords(trader.GetID(), 10000) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": fmt.Sprintf("获取历史数据失败: %v", err), @@ -1617,33 +1633,6 @@ func (s *Server) handleEquityHistory(c *gin.Context) { c.JSON(http.StatusOK, history) } -// handlePerformance AI历史表现分析(用于展示AI学习和反思) -func (s *Server) handlePerformance(c *gin.Context) { - _, traderID, err := s.getTraderFromQuery(c) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) - return - } - - // 分析最近100个周期的交易表现(避免长期持仓的交易记录丢失) - // 假设每3分钟一个周期,100个周期 = 5小时,足够覆盖大部分交易 - performance, err := trader.GetDecisionLogger().AnalyzePerformance(100) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": fmt.Sprintf("分析历史表现失败: %v", err), - }) - return - } - - c.JSON(http.StatusOK, performance) -} - // authMiddleware JWT认证中间件 func (s *Server) authMiddleware() gin.HandlerFunc { return func(c *gin.Context) { @@ -1730,7 +1719,7 @@ func (s *Server) handleRegister(c *gin.Context) { } // 检查是否开启了内测模式 - betaModeStr, _ := s.database.GetSystemConfig("beta_mode") + betaModeStr, _ := s.store.SystemConfig().Get("beta_mode") if betaModeStr == "true" { // 内测模式下必须提供有效的内测码 if req.BetaCode == "" { @@ -1739,7 +1728,7 @@ func (s *Server) handleRegister(c *gin.Context) { } // 验证内测码 - isValid, err := s.database.ValidateBetaCode(req.BetaCode) + isValid, err := s.store.BetaCode().Validate(req.BetaCode) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "验证内测码失败"}) return @@ -1751,7 +1740,7 @@ func (s *Server) handleRegister(c *gin.Context) { } // 检查邮箱是否已存在 - _, err := s.database.GetUserByEmail(req.Email) + _, err := s.store.User().GetByEmail(req.Email) if err == nil { c.JSON(http.StatusConflict, gin.H{"error": "邮箱已被注册"}) return @@ -1773,7 +1762,7 @@ func (s *Server) handleRegister(c *gin.Context) { // 创建用户(未验证OTP状态) userID := uuid.New().String() - user := &config.User{ + user := &store.User{ ID: userID, Email: req.Email, PasswordHash: passwordHash, @@ -1781,21 +1770,21 @@ func (s *Server) handleRegister(c *gin.Context) { OTPVerified: false, } - err = s.database.CreateUser(user) + err = s.store.User().Create(user) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "创建用户失败: " + err.Error()}) return } // 如果是内测模式,标记内测码为已使用 - betaModeStr2, _ := s.database.GetSystemConfig("beta_mode") + betaModeStr2, _ := s.store.SystemConfig().Get("beta_mode") if betaModeStr2 == "true" && req.BetaCode != "" { - err := s.database.UseBetaCode(req.BetaCode, req.Email) + err := s.store.BetaCode().Use(req.BetaCode, req.Email) if err != nil { - log.Printf("⚠️ 标记内测码为已使用失败: %v", err) + logger.Infof("⚠️ 标记内测码为已使用失败: %v", err) // 这里不返回错误,因为用户已经创建成功 } else { - log.Printf("✓ 内测码 %s 已被用户 %s 使用", req.BetaCode, req.Email) + logger.Infof("✓ 内测码 %s 已被用户 %s 使用", req.BetaCode, req.Email) } } @@ -1823,7 +1812,7 @@ func (s *Server) handleCompleteRegistration(c *gin.Context) { } // 获取用户信息 - user, err := s.database.GetUserByID(req.UserID) + user, err := s.store.User().GetByID(req.UserID) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "用户不存在"}) return @@ -1836,7 +1825,7 @@ func (s *Server) handleCompleteRegistration(c *gin.Context) { } // 更新用户OTP验证状态 - err = s.database.UpdateUserOTPVerified(req.UserID, true) + err = s.store.User().UpdateOTPVerified(req.UserID, true) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "更新用户状态失败"}) return @@ -1852,7 +1841,7 @@ func (s *Server) handleCompleteRegistration(c *gin.Context) { // 初始化用户的默认模型和交易所配置 err = s.initUserDefaultConfigs(user.ID) if err != nil { - log.Printf("初始化用户默认配置失败: %v", err) + logger.Infof("初始化用户默认配置失败: %v", err) } c.JSON(http.StatusOK, gin.H{ @@ -1876,7 +1865,7 @@ func (s *Server) handleLogin(c *gin.Context) { } // 获取用户信息 - user, err := s.database.GetUserByEmail(req.Email) + user, err := s.store.User().GetByEmail(req.Email) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "邮箱或密码错误"}) return @@ -1920,7 +1909,7 @@ func (s *Server) handleVerifyOTP(c *gin.Context) { } // 获取用户信息 - user, err := s.database.GetUserByID(req.UserID) + user, err := s.store.User().GetByID(req.UserID) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "用户不存在"}) return @@ -1961,7 +1950,7 @@ func (s *Server) handleResetPassword(c *gin.Context) { } // 查询用户 - user, err := s.database.GetUserByEmail(req.Email) + user, err := s.store.User().GetByEmail(req.Email) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "邮箱不存在"}) return @@ -1981,13 +1970,13 @@ func (s *Server) handleResetPassword(c *gin.Context) { } // 更新密码 - err = s.database.UpdateUserPassword(user.ID, newPasswordHash) + err = s.store.User().UpdatePassword(user.ID, newPasswordHash) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "密码更新失败"}) return } - log.Printf("✓ 用户 %s 密码已重置", user.Email) + logger.Infof("✓ 用户 %s 密码已重置", user.Email) c.JSON(http.StatusOK, gin.H{"message": "密码重置成功,请使用新密码登录"}) } @@ -1995,16 +1984,16 @@ func (s *Server) handleResetPassword(c *gin.Context) { func (s *Server) initUserDefaultConfigs(userID string) error { // 注释掉自动创建默认配置,让用户手动添加 // 这样新用户注册后不会自动有配置项 - log.Printf("用户 %s 注册完成,等待手动配置AI模型和交易所", userID) + logger.Infof("用户 %s 注册完成,等待手动配置AI模型和交易所", userID) return nil } // handleGetSupportedModels 获取系统支持的AI模型列表 func (s *Server) handleGetSupportedModels(c *gin.Context) { // 返回系统支持的AI模型(从default用户获取) - models, err := s.database.GetAIModels("default") + models, err := s.store.AIModel().List("default") if err != nil { - log.Printf("❌ 获取支持的AI模型失败: %v", err) + logger.Infof("❌ 获取支持的AI模型失败: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "获取支持的AI模型失败"}) return } @@ -2015,9 +2004,9 @@ func (s *Server) handleGetSupportedModels(c *gin.Context) { // handleGetSupportedExchanges 获取系统支持的交易所列表 func (s *Server) handleGetSupportedExchanges(c *gin.Context) { // 返回系统支持的交易所(从default用户获取) - exchanges, err := s.database.GetExchanges("default") + exchanges, err := s.store.Exchange().List("default") if err != nil { - log.Printf("❌ 获取支持的交易所失败: %v", err) + logger.Infof("❌ 获取支持的交易所失败: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "获取支持的交易所失败"}) return } @@ -2043,31 +2032,31 @@ func (s *Server) handleGetSupportedExchanges(c *gin.Context) { // Start 启动服务器 func (s *Server) Start() error { addr := fmt.Sprintf(":%d", s.port) - log.Printf("🌐 API服务器启动在 http://localhost%s", addr) - log.Printf("📊 API文档:") - log.Printf(" • GET /api/health - 健康检查") - log.Printf(" • GET /api/traders - 公开的AI交易员排行榜前50名(无需认证)") - log.Printf(" • GET /api/competition - 公开的竞赛数据(无需认证)") - log.Printf(" • GET /api/top-traders - 前5名交易员数据(无需认证,表现对比用)") - log.Printf(" • GET /api/equity-history?trader_id=xxx - 公开的收益率历史数据(无需认证,竞赛用)") - log.Printf(" • GET /api/equity-history-batch?trader_ids=a,b,c - 批量获取历史数据(无需认证,表现对比优化)") - log.Printf(" • GET /api/traders/:id/public-config - 公开的交易员配置(无需认证,不含敏感信息)") - log.Printf(" • POST /api/traders - 创建新的AI交易员") - log.Printf(" • DELETE /api/traders/:id - 删除AI交易员") - log.Printf(" • POST /api/traders/:id/start - 启动AI交易员") - log.Printf(" • POST /api/traders/:id/stop - 停止AI交易员") - log.Printf(" • GET /api/models - 获取AI模型配置") - log.Printf(" • PUT /api/models - 更新AI模型配置") - log.Printf(" • GET /api/exchanges - 获取交易所配置") - log.Printf(" • PUT /api/exchanges - 更新交易所配置") - log.Printf(" • GET /api/status?trader_id=xxx - 指定trader的系统状态") - log.Printf(" • GET /api/account?trader_id=xxx - 指定trader的账户信息") - log.Printf(" • GET /api/positions?trader_id=xxx - 指定trader的持仓列表") - log.Printf(" • GET /api/decisions?trader_id=xxx - 指定trader的决策日志") - log.Printf(" • GET /api/decisions/latest?trader_id=xxx - 指定trader的最新决策") - log.Printf(" • GET /api/statistics?trader_id=xxx - 指定trader的统计信息") - log.Printf(" • GET /api/performance?trader_id=xxx - 指定trader的AI学习表现分析") - log.Println() + logger.Infof("🌐 API服务器启动在 http://localhost%s", addr) + logger.Infof("📊 API文档:") + logger.Infof(" • GET /api/health - 健康检查") + logger.Infof(" • GET /api/traders - 公开的AI交易员排行榜前50名(无需认证)") + logger.Infof(" • GET /api/competition - 公开的竞赛数据(无需认证)") + logger.Infof(" • GET /api/top-traders - 前5名交易员数据(无需认证,表现对比用)") + logger.Infof(" • GET /api/equity-history?trader_id=xxx - 公开的收益率历史数据(无需认证,竞赛用)") + logger.Infof(" • GET /api/equity-history-batch?trader_ids=a,b,c - 批量获取历史数据(无需认证,表现对比优化)") + logger.Infof(" • GET /api/traders/:id/public-config - 公开的交易员配置(无需认证,不含敏感信息)") + logger.Infof(" • POST /api/traders - 创建新的AI交易员") + logger.Infof(" • DELETE /api/traders/:id - 删除AI交易员") + logger.Infof(" • POST /api/traders/:id/start - 启动AI交易员") + logger.Infof(" • POST /api/traders/:id/stop - 停止AI交易员") + logger.Infof(" • GET /api/models - 获取AI模型配置") + logger.Infof(" • PUT /api/models - 更新AI模型配置") + logger.Infof(" • GET /api/exchanges - 获取交易所配置") + logger.Infof(" • PUT /api/exchanges - 更新交易所配置") + logger.Infof(" • GET /api/status?trader_id=xxx - 指定trader的系统状态") + logger.Infof(" • GET /api/account?trader_id=xxx - 指定trader的账户信息") + logger.Infof(" • GET /api/positions?trader_id=xxx - 指定trader的持仓列表") + logger.Infof(" • GET /api/decisions?trader_id=xxx - 指定trader的决策日志") + logger.Infof(" • GET /api/decisions/latest?trader_id=xxx - 指定trader的最新决策") + logger.Infof(" • GET /api/statistics?trader_id=xxx - 指定trader的统计信息") + logger.Infof(" • GET /api/performance?trader_id=xxx - 指定trader的AI学习表现分析") + logger.Info() s.httpServer = &http.Server{ Addr: addr, @@ -2265,7 +2254,7 @@ func (s *Server) getEquityHistoryForTraders(traderIDs []string) map[string]inter } // 获取历史数据(用于对比展示,限制数据量) - records, err := trader.GetDecisionLogger().GetLatestRecords(500) + records, err := trader.GetStore().Decision().GetLatestRecords(trader.GetID(), 500) if err != nil { errors[traderID] = fmt.Sprintf("获取历史数据失败: %v", err) continue diff --git a/api/server_test.go b/api/server_test.go index f59817bd..0b9997a2 100644 --- a/api/server_test.go +++ b/api/server_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "testing" - "nofx/config" + "nofx/store" ) // TestUpdateTraderRequest_SystemPromptTemplate 测试更新交易员时 SystemPromptTemplate 字段是否存在 @@ -100,12 +100,12 @@ func TestUpdateTraderRequest_SystemPromptTemplate(t *testing.T) { func TestGetTraderConfigResponse_SystemPromptTemplate(t *testing.T) { tests := []struct { name string - traderConfig *config.TraderRecord + traderConfig *store.Trader expectedTemplate string }{ { name: "获取配置应该返回 system_prompt_template=nof1", - traderConfig: &config.TraderRecord{ + traderConfig: &store.Trader{ ID: "trader-123", UserID: "user-1", Name: "Test Trader", @@ -126,7 +126,7 @@ func TestGetTraderConfigResponse_SystemPromptTemplate(t *testing.T) { }, { name: "获取配置应该返回 system_prompt_template=default", - traderConfig: &config.TraderRecord{ + traderConfig: &store.Trader{ ID: "trader-456", UserID: "user-1", Name: "Test Trader 2", @@ -229,7 +229,7 @@ func TestUpdateTraderRequest_CompleteFields(t *testing.T) { // TestTraderListResponse_SystemPromptTemplate 测试 handleTraderList API 返回的 trader 对象是否包含 system_prompt_template 字段 func TestTraderListResponse_SystemPromptTemplate(t *testing.T) { // 模拟 handleTraderList 中的 trader 对象构造 - trader := &config.TraderRecord{ + trader := &store.Trader{ ID: "trader-001", UserID: "user-1", Name: "My Trader", diff --git a/backtest/manager.go b/backtest/manager.go index 6a0a4199..e0359cac 100644 --- a/backtest/manager.go +++ b/backtest/manager.go @@ -4,14 +4,14 @@ import ( "context" "errors" "fmt" - "log" + "nofx/logger" "os" "sort" "strings" "sync" - "nofx/logger" "nofx/mcp" + "nofx/store" ) type Manager struct { @@ -377,7 +377,7 @@ func (m *Manager) Status(runID string) *StatusPayload { func (m *Manager) launchWatcher(runID string, runner *Runner) { go func() { if err := runner.Wait(); err != nil { - log.Printf("backtest run %s finished with error: %v", runID, err) + logger.Infof("backtest run %s finished with error: %v", runID, err) } runner.PersistMetadata() meta := runner.CurrentMetadata() @@ -419,7 +419,7 @@ func (m *Manager) storeMetadata(runID string, meta *RunMetadata) { m.mu.Unlock() _ = SaveRunMetadata(meta) if err := updateRunIndex(meta, nil); err != nil { - log.Printf("failed to update run index for %s: %v", runID, err) + logger.Infof("failed to update run index for %s: %v", runID, err) } } @@ -445,7 +445,7 @@ func (m *Manager) resolveAIConfig(cfg *BacktestConfig) error { return resolver(cfg) } -func (m *Manager) GetTrace(runID string, cycle int) (*logger.DecisionRecord, error) { +func (m *Manager) GetTrace(runID string, cycle int) (*store.DecisionRecord, error) { return LoadDecisionTrace(runID, cycle) } @@ -462,18 +462,18 @@ func (m *Manager) RestoreRuns() error { for _, runID := range runIDs { meta, err := LoadRunMetadata(runID) if err != nil { - log.Printf("skip run %s: %v", runID, err) + logger.Infof("skip run %s: %v", runID, err) continue } if meta.State == RunStateRunning { lock, err := loadRunLock(runID) if err != nil || lockIsStale(lock) { if err := deleteRunLock(runID); err != nil { - log.Printf("failed to cleanup lock for %s: %v", runID, err) + logger.Infof("failed to cleanup lock for %s: %v", runID, err) } meta.State = RunStatePaused if err := SaveRunMetadata(meta); err != nil { - log.Printf("failed to mark %s paused: %v", runID, err) + logger.Infof("failed to mark %s paused: %v", runID, err) } } } @@ -481,7 +481,7 @@ func (m *Manager) RestoreRuns() error { m.metadata[runID] = meta m.mu.Unlock() if err := updateRunIndex(meta, nil); err != nil { - log.Printf("failed to sync index for %s: %v", runID, err) + logger.Infof("failed to sync index for %s: %v", runID, err) } } return nil diff --git a/backtest/retention.go b/backtest/retention.go index 3201bdce..55395c97 100644 --- a/backtest/retention.go +++ b/backtest/retention.go @@ -1,7 +1,7 @@ package backtest import ( - "log" + "nofx/logger" "os" "sort" "time" @@ -56,13 +56,13 @@ func enforceRetention(maxRuns int) { for i := 0; i < toRemove; i++ { runID := candidates[i].entry.RunID if err := os.RemoveAll(runDir(runID)); err != nil { - log.Printf("failed to prune run %s: %v", runID, err) + logger.Infof("failed to prune run %s: %v", runID, err) continue } delete(idx.Runs, runID) } if err := saveRunIndex(idx); err != nil { - log.Printf("failed to save index after pruning: %v", err) + logger.Infof("failed to save index after pruning: %v", err) } } @@ -91,11 +91,11 @@ func enforceRetentionDB(maxRuns int) { continue } if err := deleteRunDB(runID); err != nil { - log.Printf("failed to remove run %s: %v", runID, err) + logger.Infof("failed to remove run %s: %v", runID, err) continue } if err := os.RemoveAll(runDir(runID)); err != nil { - log.Printf("failed to remove run dir %s: %v", runID, err) + logger.Infof("failed to remove run dir %s: %v", runID, err) } } } diff --git a/backtest/runner.go b/backtest/runner.go index fafcd676..2c94954b 100644 --- a/backtest/runner.go +++ b/backtest/runner.go @@ -5,7 +5,7 @@ import ( "encoding/json" "errors" "fmt" - "log" + "nofx/logger" "os" "path/filepath" "sort" @@ -14,9 +14,9 @@ import ( "time" "nofx/decision" - "nofx/logger" "nofx/market" "nofx/mcp" + "nofx/store" ) var ( @@ -35,7 +35,7 @@ type Runner struct { feed *DataFeed account *BacktestAccount - decisionLogger logger.IDecisionLogger + decisionLogDir string mcpClient mcp.AIClient statusMu sync.RWMutex @@ -83,7 +83,7 @@ func NewRunner(cfg BacktestConfig, mcpClient mcp.AIClient) (*Runner, error) { return nil, err } - dLog := logger.NewDecisionLogger(decisionLogDir(cfg.RunID)) + dLogDir := decisionLogDir(cfg.RunID) account := NewBacktestAccount(cfg.InitialBalance, cfg.FeeBps, cfg.SlippageBps) createdAt := time.Now().UTC() @@ -119,7 +119,7 @@ func NewRunner(cfg BacktestConfig, mcpClient mcp.AIClient) (*Runner, error) { cfg: cfg, feed: feed, account: account, - decisionLogger: dLog, + decisionLogDir: dLogDir, mcpClient: client, status: RunStateCreated, state: state, @@ -160,7 +160,7 @@ func (r *Runner) lockHeartbeatLoop() { select { case <-ticker.C: if err := updateRunLockHeartbeat(r.lockInfo); err != nil { - log.Printf("failed to update lock heartbeat for %s: %v", r.cfg.RunID, err) + logger.Infof("failed to update lock heartbeat for %s: %v", r.cfg.RunID, err) } case <-r.lockStop: return @@ -174,7 +174,7 @@ func (r *Runner) releaseLock() { r.lockStop = nil } if err := deleteRunLock(r.cfg.RunID); err != nil { - log.Printf("failed to release lock for %s: %v", r.cfg.RunID, err) + logger.Infof("failed to release lock for %s: %v", r.cfg.RunID, err) } r.lockInfo = nil } @@ -279,8 +279,8 @@ func (r *Runner) stepOnce() error { shouldDecide := r.shouldTriggerDecision(state.BarIndex) var ( - record *logger.DecisionRecord - decisionActions []logger.DecisionAction + record *store.DecisionRecord + decisionActions []store.DecisionAction tradeEvents = make([]TradeEvent, 0) execLog []string hadError bool @@ -317,7 +317,7 @@ func (r *Runner) stepOnce() error { return decisionErr } } else { - log.Printf("failed to compute ai cache key: %v", err) + logger.Infof("failed to compute ai cache key: %v", err) } } @@ -334,7 +334,7 @@ func (r *Runner) stepOnce() error { fullDecision = fd if r.cfg.CacheAI && r.aiCache != nil && cacheKey != "" { if err := r.aiCache.Put(cacheKey, r.cfg.PromptVariant, ts, fullDecision); err != nil { - log.Printf("failed to persist ai cache for %s: %v", r.cfg.RunID, err) + logger.Infof("failed to persist ai cache for %s: %v", r.cfg.RunID, err) } } } @@ -346,7 +346,7 @@ func (r *Runner) stepOnce() error { sorted := sortDecisionsByPriority(fullDecision.Decisions) prevLogs := execLog - decisionActions = make([]logger.DecisionAction, 0, len(sorted)) + decisionActions = make([]store.DecisionAction, 0, len(sorted)) execLog = make([]string, 0, len(sorted)+len(prevLogs)) if len(prevLogs) > 0 { execLog = append(execLog, prevLogs...) @@ -464,7 +464,7 @@ func (r *Runner) stepOnce() error { return nil } -func (r *Runner) buildDecisionContext(ts int64, marketData map[string]*market.Data, multiTF map[string]map[string]*market.Data, priceMap map[string]float64, callCount int) (*decision.Context, *logger.DecisionRecord, error) { +func (r *Runner) buildDecisionContext(ts int64, marketData map[string]*market.Data, multiTF map[string]map[string]*market.Data, priceMap map[string]float64, callCount int) (*decision.Context, *store.DecisionRecord, error) { equity, unrealized, _ := r.account.TotalEquity(priceMap) available := r.account.Cash() marginUsed := r.totalMarginUsed() @@ -505,8 +505,8 @@ func (r *Runner) buildDecisionContext(ts int64, marketData map[string]*market.Da AltcoinLeverage: r.cfg.Leverage.AltcoinLeverage, } - record := &logger.DecisionRecord{ - AccountState: logger.AccountSnapshot{ + record := &store.DecisionRecord{ + AccountState: store.AccountSnapshot{ TotalBalance: accountInfo.TotalEquity, AvailableBalance: accountInfo.AvailableBalance, TotalUnrealizedProfit: unrealized, @@ -524,7 +524,7 @@ func (r *Runner) buildDecisionContext(ts int64, marketData map[string]*market.Da return ctx, record, nil } -func (r *Runner) fillDecisionRecord(record *logger.DecisionRecord, full *decision.FullDecision) { +func (r *Runner) fillDecisionRecord(record *store.DecisionRecord, full *decision.FullDecision) { record.InputPrompt = full.UserPrompt record.CoTTrace = full.CoTTrace if len(full.Decisions) > 0 { @@ -554,10 +554,10 @@ func (r *Runner) invokeAIWithRetry(ctx *decision.Context) (*decision.FullDecisio return nil, lastErr } -func (r *Runner) executeDecision(dec decision.Decision, priceMap map[string]float64, ts int64, cycle int) (logger.DecisionAction, []TradeEvent, string, error) { +func (r *Runner) executeDecision(dec decision.Decision, priceMap map[string]float64, ts int64, cycle int) (store.DecisionAction, []TradeEvent, string, error) { symbol := dec.Symbol usedLeverage := r.resolveLeverage(dec.Leverage, symbol) - actionRecord := logger.DecisionAction{ + actionRecord := store.DecisionAction{ Action: dec.Action, Symbol: symbol, Leverage: usedLeverage, @@ -748,12 +748,12 @@ func (r *Runner) remainingPosition(symbol, side string) float64 { return 0 } -func (r *Runner) snapshotPositions(priceMap map[string]float64) []logger.PositionSnapshot { +func (r *Runner) snapshotPositions(priceMap map[string]float64) []store.PositionSnapshot { positions := r.account.Positions() - list := make([]logger.PositionSnapshot, 0, len(positions)) + list := make([]store.PositionSnapshot, 0, len(positions)) for _, pos := range positions { price := priceMap[pos.Symbol] - list = append(list, logger.PositionSnapshot{ + list = append(list, store.PositionSnapshot{ Symbol: pos.Symbol, Side: pos.Side, PositionAmt: pos.Quantity, @@ -1124,21 +1124,18 @@ func (r *Runner) persistMetadata() { meta := r.buildMetadata(state, r.Status()) meta.CreatedAt = r.createdAt if err := SaveRunMetadata(meta); err != nil { - log.Printf("failed to save run metadata for %s: %v", r.cfg.RunID, err) + logger.Infof("failed to save run metadata for %s: %v", r.cfg.RunID, err) } else { if err := updateRunIndex(meta, &r.cfg); err != nil { - log.Printf("failed to update index for %s: %v", r.cfg.RunID, err) + logger.Infof("failed to update index for %s: %v", r.cfg.RunID, err) } } } -func (r *Runner) logDecision(record *logger.DecisionRecord) error { +func (r *Runner) logDecision(record *store.DecisionRecord) error { if record == nil { return nil } - if err := r.decisionLogger.LogDecision(record); err != nil { - return err - } persistDecisionRecord(r.cfg.RunID, record) return nil } @@ -1157,14 +1154,14 @@ func (r *Runner) persistMetrics(force bool) { state := r.snapshotState() metrics, err := CalculateMetrics(r.cfg.RunID, &r.cfg, &state) if err != nil { - log.Printf("failed to compute metrics for %s: %v", r.cfg.RunID, err) + logger.Infof("failed to compute metrics for %s: %v", r.cfg.RunID, err) return } if metrics == nil { return } if err := PersistMetrics(r.cfg.RunID, metrics); err != nil { - log.Printf("failed to persist metrics for %s: %v", r.cfg.RunID, err) + logger.Infof("failed to persist metrics for %s: %v", r.cfg.RunID, err) return } r.lastMetricsWrite = time.Now() @@ -1264,7 +1261,7 @@ func (r *Runner) saveCheckpoint(state BacktestState) error { func (r *Runner) forceCheckpoint() { state := r.snapshotState() if err := r.saveCheckpoint(state); err != nil { - log.Printf("failed to save checkpoint for %s: %v", r.cfg.RunID, err) + logger.Infof("failed to save checkpoint for %s: %v", r.cfg.RunID, err) } } @@ -1281,7 +1278,6 @@ func (r *Runner) applyCheckpoint(ckpt *Checkpoint) error { return fmt.Errorf("checkpoint is nil") } r.account.RestoreFromSnapshots(ckpt.Cash, ckpt.RealizedPnL, ckpt.Positions) - r.decisionLogger.SetCycleNumber(ckpt.DecisionCycle) r.stateMu.Lock() defer r.stateMu.Unlock() r.state.BarIndex = ckpt.BarIndex diff --git a/backtest/storage.go b/backtest/storage.go index 7949655d..c5bf1405 100644 --- a/backtest/storage.go +++ b/backtest/storage.go @@ -13,7 +13,7 @@ import ( "strings" "time" - "nofx/logger" + "nofx/store" ) const ( @@ -380,7 +380,7 @@ func PersistMetrics(runID string, metrics *Metrics) error { return saveMetrics(runID, metrics) } -func LoadDecisionTrace(runID string, cycle int) (*logger.DecisionRecord, error) { +func LoadDecisionTrace(runID string, cycle int) (*store.DecisionRecord, error) { if usingDB() { return loadDecisionTraceDB(runID, cycle) } @@ -418,7 +418,7 @@ func LoadDecisionTrace(runID string, cycle int) (*logger.DecisionRecord, error) if err != nil { continue } - var record logger.DecisionRecord + var record store.DecisionRecord if err := json.Unmarshal(data, &record); err != nil { continue } @@ -429,7 +429,7 @@ func LoadDecisionTrace(runID string, cycle int) (*logger.DecisionRecord, error) return nil, fmt.Errorf("decision trace not found for run %s cycle %d", runID, cycle) } -func LoadDecisionRecords(runID string, limit, offset int) ([]*logger.DecisionRecord, error) { +func LoadDecisionRecords(runID string, limit, offset int) ([]*store.DecisionRecord, error) { if limit <= 0 { limit = 20 } @@ -443,7 +443,7 @@ func LoadDecisionRecords(runID string, limit, offset int) ([]*logger.DecisionRec entries, err := os.ReadDir(dir) if err != nil { if errors.Is(err, os.ErrNotExist) { - return []*logger.DecisionRecord{}, nil + return []*store.DecisionRecord{}, nil } return nil, err } @@ -471,19 +471,19 @@ func LoadDecisionRecords(runID string, limit, offset int) ([]*logger.DecisionRec return infoI.ModTime().After(infoJ.ModTime()) }) if offset >= len(files) { - return []*logger.DecisionRecord{}, nil + return []*store.DecisionRecord{}, nil } end := offset + limit if end > len(files) { end = len(files) } - records := make([]*logger.DecisionRecord, 0, end-offset) + records := make([]*store.DecisionRecord, 0, end-offset) for _, file := range files[offset:end] { data, err := os.ReadFile(file.path) if err != nil { continue } - var record logger.DecisionRecord + var record store.DecisionRecord if err := json.Unmarshal(data, &record); err != nil { continue } @@ -553,7 +553,7 @@ func CreateRunExport(runID string) (string, error) { return tmpFile.Name(), nil } -func persistDecisionRecord(runID string, record *logger.DecisionRecord) { +func persistDecisionRecord(runID string, record *store.DecisionRecord) { if !usingDB() || record == nil { return } diff --git a/backtest/storage_db_impl.go b/backtest/storage_db_impl.go index 3f7eb508..67cc0831 100644 --- a/backtest/storage_db_impl.go +++ b/backtest/storage_db_impl.go @@ -9,7 +9,7 @@ import ( "os" "time" - "nofx/logger" + "nofx/store" ) func saveCheckpointDB(runID string, ckpt *Checkpoint) error { @@ -273,7 +273,7 @@ func saveProgressDB(runID string, payload progressPayload) error { return err } -func loadDecisionTraceDB(runID string, cycle int) (*logger.DecisionRecord, error) { +func loadDecisionTraceDB(runID string, cycle int) (*store.DecisionRecord, error) { query := `SELECT payload FROM backtest_decisions WHERE run_id = ?` var rows *sql.Rows var err error @@ -293,14 +293,14 @@ func loadDecisionTraceDB(runID string, cycle int) (*logger.DecisionRecord, error if err := rows.Scan(&payload); err != nil { return nil, err } - var record logger.DecisionRecord + var record store.DecisionRecord if err := json.Unmarshal(payload, &record); err != nil { return nil, err } return &record, nil } -func saveDecisionRecordDB(runID string, record *logger.DecisionRecord) error { +func saveDecisionRecordDB(runID string, record *store.DecisionRecord) error { if record == nil { return nil } @@ -315,7 +315,7 @@ func saveDecisionRecordDB(runID string, record *logger.DecisionRecord) error { return err } -func loadDecisionRecordsDB(runID string, limit, offset int) ([]*logger.DecisionRecord, error) { +func loadDecisionRecordsDB(runID string, limit, offset int) ([]*store.DecisionRecord, error) { rows, err := persistenceDB.Query(` SELECT payload FROM backtest_decisions WHERE run_id = ? @@ -326,13 +326,13 @@ func loadDecisionRecordsDB(runID string, limit, offset int) ([]*logger.DecisionR return nil, err } defer rows.Close() - records := make([]*logger.DecisionRecord, 0, limit) + records := make([]*store.DecisionRecord, 0, limit) for rows.Next() { var payload []byte if err := rows.Scan(&payload); err != nil { return nil, err } - var record logger.DecisionRecord + var record store.DecisionRecord if err := json.Unmarshal(payload, &record); err != nil { return nil, err } diff --git a/bootstrap/README.md b/bootstrap/README.md deleted file mode 100644 index 4db4b260..00000000 --- a/bootstrap/README.md +++ /dev/null @@ -1,455 +0,0 @@ -# Bootstrap 模块初始化框架 - -## 概述 - -Bootstrap 是一个模块化的初始化框架,允许各个模块通过注册钩子的方式自动完成初始化,支持优先级控制、条件初始化、错误策略等高级特性。 - -## 核心特性 - -- ✅ **优先级排序** - 保证模块按正确的顺序初始化 -- ✅ **钩子命名** - 每个钩子都有清晰的名称,便于日志追踪和错误定位 -- ✅ **上下文传递** - 模块之间可以共享数据(如数据库实例) -- ✅ **条件初始化** - 根据配置动态决定是否初始化某个模块 -- ✅ **灵活的错误处理** - 支持快速失败、继续执行、警告三种策略 -- ✅ **详细日志** - 显示初始化进度、耗时统计 -- ✅ **线程安全** - 使用互斥锁保护全局状态 -- ✅ **测试友好** - 提供 Clear() 方法清除钩子 - -## 快速开始 - -### 1. 在模块中注册初始化钩子 - -在你的模块包中创建 `init.go` 文件: - -```go -// proxy/init.go -package proxy - -import ( - "nofx/bootstrap" - "nofx/config" -) - -func init() { - // 注册初始化钩子 - bootstrap.Register("Proxy模块", bootstrap.PriorityCore, initProxyModule) -} - -func initProxyModule(ctx *bootstrap.Context) error { - // 从配置中读取 proxy 配置 - proxyConfig := ctx.Config.Proxy - - // 初始化代理管理器 - if err := InitGlobalProxyManager(proxyConfig); err != nil { - return err - } - - // 将实例存储到上下文,供其他模块使用 - ctx.Set("proxy_manager", GetGlobalProxyManager()) - - return nil -} -``` - -### 2. 在 main.go 中运行初始化 - -```go -package main - -import ( - "log" - "nofx/bootstrap" - "nofx/config" - - // 导入需要初始化的模块(触发 init() 注册) - _ "nofx/proxy" - _ "nofx/market" - _ "nofx/trader" -) - -func main() { - // 加载配置 - cfg, err := config.LoadConfig("config.json") - if err != nil { - log.Fatalf("加载配置失败: %v", err) - } - - // 创建初始化上下文 - ctx := bootstrap.NewContext(cfg) - - // 执行所有初始化钩子 - if err := bootstrap.Run(ctx); err != nil { - log.Fatalf("初始化失败: %v", err) - } - - // 启动业务逻辑... -} -``` - -### 3. 运行效果 - -``` -🔄 开始初始化 3 个模块... - [1/3] 初始化: Database模块 (优先级: 20) - ✓ 完成: Database模块 (耗时: 120ms) - [2/3] 初始化: Proxy模块 (优先级: 50) - ↳ 代理自动刷新已启动 (间隔: 30m0s) - ↳ 代理池状态: 总计=5, 黑名单=0, 可用=5 - ✓ 完成: Proxy模块 (耗时: 35ms) - [3/3] 初始化: Market模块 (优先级: 100) - ✓ 完成: Market模块 (耗时: 200ms) -✅ 所有模块初始化完成 (总耗时: 355ms) -📊 统计: 成功=3, 跳过=0 -``` - -## 优先级常量 - -系统预定义了以下优先级常量(数值越小越先执行): - -| 常量 | 值 | 用途 | 示例 | -|------|-----|------|------| -| `PriorityInfrastructure` | 10 | 基础设施 | 日志系统、配置加载 | -| `PriorityDatabase` | 20 | 数据库连接 | SQLite、Redis | -| `PriorityCore` | 50 | 核心模块 | Proxy、Market Monitor | -| `PriorityBusiness` | 100 | 业务模块 | Trader、API Server | -| `PriorityBackground` | 200 | 后台任务 | 定时任务、监控 | - -### 使用示例 - -```go -// 数据库模块(最先初始化) -bootstrap.Register("Database", bootstrap.PriorityDatabase, initDatabase) - -// 代理模块(核心模块) -bootstrap.Register("Proxy", bootstrap.PriorityCore, initProxy) - -// Trader模块(依赖数据库和代理) -bootstrap.Register("Trader", bootstrap.PriorityBusiness, initTrader) -``` - -## 高级特性 - -### 1. 条件初始化 - -某些模块只在特定条件下才需要初始化: - -```go -bootstrap.Register("Proxy模块", bootstrap.PriorityCore, initProxy). - EnabledIf(func(ctx *bootstrap.Context) bool { - // 只在配置中启用 proxy 时才初始化 - return ctx.Config.Proxy != nil && ctx.Config.Proxy.Enabled - }) -``` - -**输出**: -``` - [2/5] 跳过: Proxy模块 (条件未满足) -``` - -### 2. 错误处理策略 - -支持三种错误处理策略: - -#### FailFast(默认)- 遇到错误立即停止 - -```go -bootstrap.Register("Database", bootstrap.PriorityDatabase, initDatabase) -// 默认就是 FailFast,无需显式设置 -``` - -**效果**:Database 初始化失败,整个系统停止启动 - -#### ContinueOnError - 继续执行,收集所有错误 - -```go -bootstrap.Register("Proxy", bootstrap.PriorityCore, initProxy). - OnError(bootstrap.ContinueOnError) -``` - -**效果**:Proxy 失败不影响其他模块,最后汇总所有错误 - -#### WarnOnError - 继续执行,只打印警告 - -```go -bootstrap.Register("Proxy", bootstrap.PriorityCore, initProxy). - OnError(bootstrap.WarnOnError) -``` - -**效果**:Proxy 失败只打印警告,不影响系统运行 - -**输出**: -``` - [2/5] 初始化: Proxy模块 (优先级: 50) - ⚠️ 警告: Proxy模块 (耗时: 15ms) - 连接代理服务器超时 -``` - -### 3. 上下文数据共享 - -模块之间可以通过 Context 共享数据: - -```go -// database/init.go - 存储数据库实例 -func initDatabase(ctx *bootstrap.Context) error { - db, err := sql.Open("sqlite", "config.db") - if err != nil { - return err - } - - // 存储到上下文 - ctx.Set("database", db) - return nil -} - -// trader/init.go - 获取数据库实例 -func initTrader(ctx *bootstrap.Context) error { - // 从上下文获取数据库实例 - db, ok := ctx.Get("database") - if !ok { - return fmt.Errorf("database 未初始化") - } - - database := db.(*sql.DB) - // 使用 database 初始化 trader... - return nil -} -``` - -**安全获取**: -```go -// 使用 MustGet,不存在会 panic(适合必需的依赖) -db := ctx.MustGet("database").(*sql.DB) -``` - -### 4. 链式调用 - -支持流畅的链式调用: - -```go -bootstrap.Register("Proxy", bootstrap.PriorityCore, initProxy). - EnabledIf(func(ctx *bootstrap.Context) bool { - return ctx.Config.Proxy != nil && ctx.Config.Proxy.Enabled - }). - OnError(bootstrap.WarnOnError) -``` - -### 5. 自定义错误策略 - -在 Run 时可以指定全局默认错误策略: - -```go -// 所有钩子默认使用 ContinueOnError,除非钩子自己指定了 FailFast -err := bootstrap.RunWithPolicy(ctx, bootstrap.ContinueOnError) -``` - -## 完整示例 - -### 示例1:Database 模块 - -```go -// database/init.go -package database - -import ( - "database/sql" - "nofx/bootstrap" -) - -func init() { - bootstrap.Register("Database", bootstrap.PriorityDatabase, initDatabase) -} - -func initDatabase(ctx *bootstrap.Context) error { - db, err := sql.Open("sqlite", "config.db") - if err != nil { - return err - } - - // 测试连接 - if err := db.Ping(); err != nil { - return err - } - - // 存储到上下文 - ctx.Set("database", db) - return nil -} -``` - -### 示例2:Proxy 模块(条件初始化 + 警告策略) - -```go -// proxy/init.go -package proxy - -import ( - "nofx/bootstrap" - "nofx/config" -) - -func init() { - bootstrap.Register("Proxy", bootstrap.PriorityCore, initProxy). - EnabledIf(func(ctx *bootstrap.Context) bool { - return ctx.Config.Proxy != nil && ctx.Config.Proxy.Enabled - }). - OnError(bootstrap.WarnOnError) // Proxy 失败不影响系统 -} - -func initProxy(ctx *bootstrap.Context) error { - proxyConfig := convertConfig(ctx.Config.Proxy) - - if err := InitGlobalProxyManager(proxyConfig); err != nil { - return err - } - - ctx.Set("proxy_manager", GetGlobalProxyManager()) - return nil -} -``` - -### 示例3:Trader 模块(依赖其他模块) - -```go -// trader/init.go -package trader - -import ( - "nofx/bootstrap" -) - -func init() { - bootstrap.Register("Trader", bootstrap.PriorityBusiness, initTrader) -} - -func initTrader(ctx *bootstrap.Context) error { - // 获取依赖 - db := ctx.MustGet("database").(*sql.DB) - - // 可选依赖 - var proxyMgr *proxy.ProxyManager - if pm, ok := ctx.Get("proxy_manager"); ok { - proxyMgr = pm.(*proxy.ProxyManager) - } - - // 使用依赖初始化 trader... - return nil -} -``` - -## 调试和测试 - -### 查看已注册的钩子 - -```go -hooks := bootstrap.GetRegistered() -for _, hook := range hooks { - fmt.Printf("钩子: %s, 优先级: %d\n", hook.Name, hook.Priority) -} -``` - -### 清除钩子(用于测试) - -```go -func TestMyModule(t *testing.T) { - // 清除之前注册的钩子 - bootstrap.Clear() - - // 注册测试钩子 - bootstrap.Register("Test", 10, func(ctx *bootstrap.Context) error { - return nil - }) - - // 运行测试... -} -``` - -### 统计钩子数量 - -```go -count := bootstrap.Count() -fmt.Printf("已注册 %d 个初始化钩子\n", count) -``` - -## 错误处理最佳实践 - -### 1. 关键模块使用 FailFast - -```go -// 数据库是关键依赖,失败必须停止 -bootstrap.Register("Database", bootstrap.PriorityDatabase, initDatabase) -// 默认是 FailFast,无需显式设置 -``` - -### 2. 可选模块使用 WarnOnError - -```go -// Proxy 是可选的,失败可以使用直连 -bootstrap.Register("Proxy", bootstrap.PriorityCore, initProxy). - OnError(bootstrap.WarnOnError) -``` - -### 3. 批量初始化使用 ContinueOnError - -```go -// 批量加载插件,希望看到所有失败的插件 -for _, plugin := range plugins { - bootstrap.Register(plugin.Name, 150, plugin.Init). - OnError(bootstrap.ContinueOnError) -} -``` - -## 常见问题 - -### Q1: 如何保证模块A在模块B之前初始化? - -使用优先级控制: -```go -bootstrap.Register("ModuleA", 50, initA) // 先执行 -bootstrap.Register("ModuleB", 100, initB) // 后执行 -``` - -### Q2: 如何在初始化失败时获取详细信息? - -钩子名称会自动包含在错误信息中: -``` -Error: [Proxy模块] 初始化失败: 连接代理服务器超时 -``` - -### Q3: 可以动态注册钩子吗? - -可以,但建议在 `init()` 函数中注册: -```go -// 推荐:在 init() 中注册(包加载时自动执行) -func init() { - bootstrap.Register("MyModule", 100, initModule) -} - -// 不推荐:在运行时注册(可能导致顺序问题) -func main() { - bootstrap.Register("MyModule", 100, initModule) -} -``` - -### Q4: 如何在钩子中访问命令行参数? - -通过 Context 的 Data 字段传递: -```go -// main.go -ctx := bootstrap.NewContext(cfg) -ctx.Set("args", os.Args) - -// module/init.go -func initModule(ctx *bootstrap.Context) error { - args := ctx.MustGet("args").([]string) - // 使用 args... -} -``` -## 性能考虑 - -- 钩子注册是线程安全的,但注册本身有轻微的锁开销 -- 建议在 `init()` 函数中注册,避免运行时动态注册 -- 钩子执行是顺序的,不会并发执行 -- 每个钩子的耗时会被记录并显示 - -## 许可证 - -本模块为 NOFX 项目内部模块,遵循项目整体许可证。 diff --git a/bootstrap/bootstrap.go b/bootstrap/bootstrap.go deleted file mode 100644 index ee756113..00000000 --- a/bootstrap/bootstrap.go +++ /dev/null @@ -1,169 +0,0 @@ -package bootstrap - -import ( - "fmt" - "log" - "nofx/logger" - "sort" - "sync" - "time" -) - -// Priority 初始化优先级常量 -const ( - PriorityInfrastructure = 10 // 基础设施(日志、配置等) - PriorityDatabase = 20 // 数据库连接 - PriorityCore = 50 // 核心模块(Proxy、Market等) - PriorityBusiness = 100 // 业务模块(Trader、API等) - PriorityBackground = 200 // 后台任务 -) - -// ErrorPolicy 错误处理策略 -type ErrorPolicy int - -const ( - // FailFast 遇到错误立即停止(默认) - FailFast ErrorPolicy = iota - // ContinueOnError 继续执行,收集所有错误 - ContinueOnError - // WarnOnError 继续执行,只打印警告 - WarnOnError -) - -var ( - hooks []Hook - hooksMu sync.Mutex -) - -// Register 注册初始化钩子 -// name: 模块名称(如 "Proxy", "Database") -// priority: 优先级(建议使用常量:PriorityCore、PriorityBusiness等) -// fn: 初始化函数 -func Register(name string, priority int, fn func(*Context) error) *HookBuilder { - hooksMu.Lock() - defer hooksMu.Unlock() - - hook := Hook{ - Name: name, - Priority: priority, - Func: fn, - Enabled: nil, // 默认启用 - ErrorPolicy: FailFast, - } - - hooks = append(hooks, hook) - - return &HookBuilder{hook: &hooks[len(hooks)-1]} -} - -// Run 执行所有已注册的钩子 -func Run(ctx *Context) error { - return RunWithPolicy(ctx, FailFast) -} - -// RunWithPolicy 使用指定的默认错误策略执行所有钩子 -func RunWithPolicy(ctx *Context, defaultPolicy ErrorPolicy) error { - hooksMu.Lock() - hooksCopy := make([]Hook, len(hooks)) - copy(hooksCopy, hooks) - hooksMu.Unlock() - - if len(hooksCopy) == 0 { - log.Printf("⚠️ 没有注册任何初始化钩子") - return nil - } - - // 按优先级排序 - sort.Slice(hooksCopy, func(i, j int) bool { - return hooksCopy[i].Priority < hooksCopy[j].Priority - }) - - log.Printf("🔄 开始初始化 %d 个模块...", len(hooksCopy)) - startTime := time.Now() - - var errors []error - successCount := 0 - skippedCount := 0 - - for i, hook := range hooksCopy { - // 检查是否启用 - if hook.Enabled != nil && !hook.Enabled(ctx) { - log.Printf(" [%d/%d] 跳过: %s (条件未满足)", - i+1, len(hooksCopy), hook.Name) - skippedCount++ - continue - } - - log.Printf(" [%d/%d] 初始化: %s (优先级: %d)", - i+1, len(hooksCopy), hook.Name, hook.Priority) - - hookStart := time.Now() - err := hook.Func(ctx) - elapsed := time.Since(hookStart) - - if err != nil { - errMsg := fmt.Errorf("[%s] 初始化失败: %w", hook.Name, err) - - // 根据错误策略处理 - policy := hook.ErrorPolicy - if policy == FailFast && defaultPolicy != FailFast { - policy = defaultPolicy - } - - switch policy { - case FailFast: - log.Printf(" ❌ 失败: %s (耗时: %v)", hook.Name, elapsed) - return errMsg - case ContinueOnError: - log.Printf(" ❌ 失败: %s (耗时: %v) - 继续执行", hook.Name, elapsed) - errors = append(errors, errMsg) - case WarnOnError: - log.Printf(" ⚠️ 警告: %s (耗时: %v) - %v", hook.Name, elapsed, err) - } - } else { - log.Printf(" ✓ 完成: %s (耗时: %v)", hook.Name, elapsed) - successCount++ - } - } - - totalElapsed := time.Since(startTime) - - // 汇总结果 - if len(errors) > 0 { - logger.Log.Warnf("⚠️ 初始化完成,但有 %d 个模块失败 (总耗时: %v)", - len(errors), totalElapsed) - log.Printf("📊 统计: 成功=%d, 失败=%d, 跳过=%d", - successCount, len(errors), skippedCount) - - // 返回合并的错误 - return fmt.Errorf("以下模块初始化失败: %v", errors) - } - - log.Printf("✅ 所有模块初始化完成 (总耗时: %v)", totalElapsed) - log.Printf("📊 统计: 成功=%d, 跳过=%d", successCount, skippedCount) - return nil -} - -// GetRegistered 获取已注册的钩子列表(用于调试) -func GetRegistered() []Hook { - hooksMu.Lock() - defer hooksMu.Unlock() - - hooksCopy := make([]Hook, len(hooks)) - copy(hooksCopy, hooks) - return hooksCopy -} - -// Clear 清除所有钩子(用于测试) -func Clear() { - hooksMu.Lock() - defer hooksMu.Unlock() - hooks = nil -} - -// Count 返回已注册的钩子数量 -func Count() int { - hooksMu.Lock() - defer hooksMu.Unlock() - return len(hooks) -} diff --git a/bootstrap/context.go b/bootstrap/context.go deleted file mode 100644 index 3616d004..00000000 --- a/bootstrap/context.go +++ /dev/null @@ -1,49 +0,0 @@ -package bootstrap - -import ( - "context" - "fmt" - "nofx/config" - "sync" -) - -// Context 初始化上下文,用于在钩子之间传递数据 -type Context struct { - Config *config.Config - Data map[string]interface{} // 存储模块之间共享的数据(如数据库实例) - ctx context.Context - mu sync.RWMutex -} - -// NewContext 创建新的初始化上下文 -func NewContext(cfg *config.Config) *Context { - return &Context{ - Config: cfg, - Data: make(map[string]interface{}), - ctx: context.Background(), - } -} - -// Set 存储数据到上下文 -func (c *Context) Set(key string, value interface{}) { - c.mu.Lock() - defer c.mu.Unlock() - c.Data[key] = value -} - -// Get 从上下文获取数据 -func (c *Context) Get(key string) (interface{}, bool) { - c.mu.RLock() - defer c.mu.RUnlock() - val, ok := c.Data[key] - return val, ok -} - -// MustGet 从上下文获取数据,不存在则 panic -func (c *Context) MustGet(key string) interface{} { - val, ok := c.Get(key) - if !ok { - panic(fmt.Sprintf("context key '%s' not found", key)) - } - return val -} diff --git a/bootstrap/hook_builder.go b/bootstrap/hook_builder.go deleted file mode 100644 index 5d88d175..00000000 --- a/bootstrap/hook_builder.go +++ /dev/null @@ -1,27 +0,0 @@ -package bootstrap - -// Hook 初始化钩子 -type Hook struct { - Name string // 钩子名称(模块名) - Priority int // 优先级(越小越先执行) - Func func(*Context) error // 初始化函数 - Enabled func(*Context) bool // 条件函数,返回 false 则跳过 - ErrorPolicy ErrorPolicy // 错误处理策略 -} - -// HookBuilder 钩子构建器(用于链式调用) -type HookBuilder struct { - hook *Hook -} - -// EnabledIf 设置条件函数(链式调用) -func (b *HookBuilder) EnabledIf(fn func(*Context) bool) *HookBuilder { - b.hook.Enabled = fn - return b -} - -// OnError 设置错误处理策略(链式调用) -func (b *HookBuilder) OnError(policy ErrorPolicy) *HookBuilder { - b.hook.ErrorPolicy = policy - return b -} diff --git a/bootstrap/init_hook.go b/bootstrap/init_hook.go deleted file mode 100644 index d31283c5..00000000 --- a/bootstrap/init_hook.go +++ /dev/null @@ -1,22 +0,0 @@ -package bootstrap - -import "nofx/config" - -type InitHook func(config *config.Config) error - -var InitHooks []InitHook - -// RegisterInitHook 注册初始化钩子 -func RegisterInitHook(hook InitHook) { - InitHooks = append(InitHooks, hook) -} - -// RunInitHooks 运行所有注册的初始化钩子 -func RunInitHooks(c *config.Config) error { - for _, hookF := range InitHooks { - if err := hookF(c); err != nil { - return err - } - } - return nil -} diff --git a/config/config.go b/config/config.go index 81ff3cea..26d89cd1 100644 --- a/config/config.go +++ b/config/config.go @@ -3,7 +3,7 @@ package config import ( "encoding/json" "fmt" - "log" + "nofx/logger" "os" ) @@ -15,16 +15,7 @@ type LeverageConfig struct { // LogConfig 日志配置 type LogConfig struct { - Level string `json:"level"` // 日志级别: debug, info, warn, error (默认: info) - Telegram *TelegramConfig `json:"telegram"` // Telegram推送配置(可选) -} - -// TelegramConfig Telegram推送配置(简化版,只保留必需字段) -type TelegramConfig struct { - Enabled bool `json:"enabled"` // 是否启用(默认: false) - BotToken string `json:"bot_token"` // Bot Token - ChatID int64 `json:"chat_id"` // Chat ID - MinLevel string `json:"min_level"` // 最低日志级别,该级别及以上的日志会推送到Telegram(可选,默认: error) + Level string `json:"level"` // 日志级别: debug, info, warn, error (默认: info) } // Config 总配置 @@ -41,14 +32,14 @@ type Config struct { Leverage LeverageConfig `json:"leverage"` JWTSecret string `json:"jwt_secret"` DataKLineTime string `json:"data_k_line_time"` - Log *LogConfig `json:"log"` // 日志配置 + Log *LogConfig `json:"nofx/logger"` // 日志配置 } // LoadConfig 从文件加载配置 func LoadConfig(filename string) (*Config, error) { // 检查filename是否存在 if _, err := os.Stat(filename); os.IsNotExist(err) { - log.Printf("📄 %s不存在,使用默认配置", filename) + logger.Infof("📄 %s不存在,使用默认配置", filename) return &Config{}, nil } diff --git a/config/database.go b/config/database.go deleted file mode 100644 index 466550f3..00000000 --- a/config/database.go +++ /dev/null @@ -1,1735 +0,0 @@ -package config - -import ( - "crypto/rand" - "database/sql" - "encoding/base32" - "encoding/json" - "errors" - "fmt" - "log" - "nofx/crypto" - "nofx/market" - "os" - "slices" - "strings" - "time" - - _ "modernc.org/sqlite" -) - -// DatabaseInterface 定义了数据库实现需要提供的方法集合 -type DatabaseInterface interface { - SetCryptoService(cs *crypto.CryptoService) - CreateUser(user *User) error - GetUserByEmail(email string) (*User, error) - GetUserByID(userID string) (*User, error) - GetAllUsers() ([]string, error) - UpdateUserOTPVerified(userID string, verified bool) error - GetAIModels(userID string) ([]*AIModelConfig, error) - UpdateAIModel(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error - GetExchanges(userID string) ([]*ExchangeConfig, error) - UpdateExchange(userID, id string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, lighterWalletAddr, lighterPrivateKey string) error - CreateAIModel(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error - CreateExchange(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error - CreateTrader(trader *TraderRecord) error - GetTraders(userID string) ([]*TraderRecord, error) - UpdateTraderStatus(userID, id string, isRunning bool) error - UpdateTrader(trader *TraderRecord) error - UpdateTraderInitialBalance(userID, id string, newBalance float64) error - UpdateTraderCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error - DeleteTrader(userID, id string) error - GetTraderConfig(userID, traderID string) (*TraderRecord, *AIModelConfig, *ExchangeConfig, error) - GetSystemConfig(key string) (string, error) - SetSystemConfig(key, value string) error - CreateUserSignalSource(userID, coinPoolURL, oiTopURL string) error - GetUserSignalSource(userID string) (*UserSignalSource, error) - UpdateUserSignalSource(userID, coinPoolURL, oiTopURL string) error - GetCustomCoins() []string - LoadBetaCodesFromFile(filePath string) error - ValidateBetaCode(code string) (bool, error) - UseBetaCode(code, userEmail string) error - GetBetaCodeStats() (total, used int, err error) - Close() error -} - -// Database 配置数据库 -type Database struct { - db *sql.DB - cryptoService *crypto.CryptoService -} - -// NewDatabase 创建配置数据库 -func NewDatabase(dbPath string) (*Database, error) { - db, err := sql.Open("sqlite", dbPath) - if err != nil { - return nil, fmt.Errorf("打开数据库失败: %w", err) - } - db.SetMaxOpenConns(1) - db.SetMaxIdleConns(1) - if _, err := db.Exec(`PRAGMA foreign_keys = ON`); err != nil { - return nil, fmt.Errorf("启用外键失败: %w", err) - } - if err := tuneSQLiteConnection(db); err != nil { - return nil, err - } - - // 🔒 启用 WAL 模式,提高并发性能和崩溃恢复能力 - // WAL (Write-Ahead Logging) 模式的优势: - // 1. 更好的并发性能:读操作不会被写操作阻塞 - // 2. 崩溃安全:即使在断电或强制终止时也能保证数据完整性 - // 3. 更快的写入:不需要每次都写入主数据库文件 - if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { - db.Close() - return nil, fmt.Errorf("启用WAL模式失败: %w", err) - } - - // 🔒 设置 synchronous=FULL 确保数据持久性 - // FULL (2) 模式: 确保数据在关键时刻完全写入磁盘 - // 配合 WAL 模式,在保证数据安全的同时获得良好性能 - if _, err := db.Exec("PRAGMA synchronous=FULL"); err != nil { - db.Close() - return nil, fmt.Errorf("设置synchronous失败: %w", err) - } - - database := &Database{db: db} - if err := database.createTables(); err != nil { - return nil, fmt.Errorf("创建表失败: %w", err) - } - if err := database.ensureBacktestRunColumns(); err != nil { - return nil, fmt.Errorf("初始化回测表结构失败: %w", err) - } - - // 确保存在默认用户(用于外键约束和默认配置种子) - if _, err := db.Exec(` - INSERT OR IGNORE INTO users (id, email, password_hash, otp_secret, otp_verified) - VALUES ('default', 'default@local', '__default__', '', 1) - `); err != nil { - return nil, fmt.Errorf("创建默认用户失败: %w", err) - } - - if err := database.initDefaultData(); err != nil { - return nil, fmt.Errorf("初始化默认数据失败: %w", err) - } - - log.Printf("✅ 数据库已启用 WAL 模式和 FULL 同步,数据持久性得到保证") - return database, nil -} - -// createTables 创建数据库表 -func (d *Database) createTables() error { - queries := []string{ - // AI模型配置表 - `CREATE TABLE IF NOT EXISTS ai_models ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL DEFAULT 'default', - name TEXT NOT NULL, - provider TEXT NOT NULL, - enabled BOOLEAN DEFAULT 0, - api_key TEXT DEFAULT '', - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE - )`, - - // 交易所配置表 - `CREATE TABLE IF NOT EXISTS exchanges ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL DEFAULT 'default', - name TEXT NOT NULL, - type TEXT NOT NULL, -- 'cex' or 'dex' - enabled BOOLEAN DEFAULT 0, - api_key TEXT DEFAULT '', - secret_key TEXT DEFAULT '', - testnet BOOLEAN DEFAULT 0, - -- Hyperliquid 特定字段 - hyperliquid_wallet_addr TEXT DEFAULT '', - -- Aster 特定字段 - aster_user TEXT DEFAULT '', - aster_signer TEXT DEFAULT '', - aster_private_key TEXT DEFAULT '', - -- LIGHTER 特定字段 - lighter_wallet_addr TEXT DEFAULT '', - lighter_private_key TEXT DEFAULT '', - lighter_api_key_private_key TEXT DEFAULT '', - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE - )`, - - // 用户信号源配置表 - `CREATE TABLE IF NOT EXISTS user_signal_sources ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id TEXT NOT NULL, - coin_pool_url TEXT DEFAULT '', - oi_top_url TEXT DEFAULT '', - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, - UNIQUE(user_id) - )`, - - // 交易员配置表 - `CREATE TABLE IF NOT EXISTS traders ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL DEFAULT 'default', - name TEXT NOT NULL, - ai_model_id TEXT NOT NULL, - exchange_id TEXT NOT NULL, - initial_balance REAL NOT NULL, - scan_interval_minutes INTEGER DEFAULT 3, - is_running BOOLEAN DEFAULT 0, - btc_eth_leverage INTEGER DEFAULT 5, - altcoin_leverage INTEGER DEFAULT 5, - trading_symbols TEXT DEFAULT '', - use_coin_pool BOOLEAN DEFAULT 0, - use_oi_top BOOLEAN DEFAULT 0, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE - )`, - - // 用户表 - `CREATE TABLE IF NOT EXISTS users ( - id TEXT PRIMARY KEY, - email TEXT UNIQUE NOT NULL, - password_hash TEXT NOT NULL, - otp_secret TEXT, - otp_verified BOOLEAN DEFAULT 0, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP - )`, - - // 系统配置表 - `CREATE TABLE IF NOT EXISTS system_config ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP - )`, - - // 回测运行主表 - `CREATE TABLE IF NOT EXISTS backtest_runs ( - run_id TEXT PRIMARY KEY, - user_id TEXT NOT NULL DEFAULT 'default', - config_json TEXT NOT NULL DEFAULT '', - state TEXT NOT NULL DEFAULT 'created', - label TEXT DEFAULT '', - symbol_count INTEGER DEFAULT 0, - decision_tf TEXT DEFAULT '', - processed_bars INTEGER DEFAULT 0, - progress_pct REAL DEFAULT 0, - equity_last REAL DEFAULT 0, - max_drawdown_pct REAL DEFAULT 0, - liquidated BOOLEAN DEFAULT 0, - liquidation_note TEXT DEFAULT '', - prompt_template TEXT DEFAULT '', - custom_prompt TEXT DEFAULT '', - override_prompt BOOLEAN DEFAULT 0, - ai_provider TEXT DEFAULT '', - ai_model TEXT DEFAULT '', - last_error TEXT DEFAULT '', - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP - )`, - - // 回测检查点 - `CREATE TABLE IF NOT EXISTS backtest_checkpoints ( - run_id TEXT PRIMARY KEY, - payload BLOB NOT NULL, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE - )`, - - // 回测权益曲线 - `CREATE TABLE IF NOT EXISTS backtest_equity ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - run_id TEXT NOT NULL, - ts INTEGER NOT NULL, - equity REAL NOT NULL, - available REAL NOT NULL, - pnl REAL NOT NULL, - pnl_pct REAL NOT NULL, - dd_pct REAL NOT NULL, - cycle INTEGER NOT NULL, - FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE - )`, - - // 回测交易记录 - `CREATE TABLE IF NOT EXISTS backtest_trades ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - run_id TEXT NOT NULL, - ts INTEGER NOT NULL, - symbol TEXT NOT NULL, - action TEXT NOT NULL, - side TEXT DEFAULT '', - qty REAL DEFAULT 0, - price REAL DEFAULT 0, - fee REAL DEFAULT 0, - slippage REAL DEFAULT 0, - order_value REAL DEFAULT 0, - realized_pnl REAL DEFAULT 0, - leverage INTEGER DEFAULT 0, - cycle INTEGER DEFAULT 0, - position_after REAL DEFAULT 0, - liquidation BOOLEAN DEFAULT 0, - note TEXT DEFAULT '', - FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE - )`, - - // 回测指标 - `CREATE TABLE IF NOT EXISTS backtest_metrics ( - run_id TEXT PRIMARY KEY, - payload BLOB NOT NULL, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE - )`, - - // 回测决策日志 - `CREATE TABLE IF NOT EXISTS backtest_decisions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - run_id TEXT NOT NULL, - cycle INTEGER NOT NULL, - payload BLOB NOT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE - )`, - - // 索引 - `CREATE INDEX IF NOT EXISTS idx_backtest_runs_state ON backtest_runs(state, updated_at)`, - `CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`, - `CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`, - `CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`, - - // 内测码表 - `CREATE TABLE IF NOT EXISTS beta_codes ( - code TEXT PRIMARY KEY, - used BOOLEAN DEFAULT 0, - used_by TEXT DEFAULT '', - used_at DATETIME DEFAULT NULL, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - )`, - - // 触发器:自动更新 updated_at - `CREATE TRIGGER IF NOT EXISTS update_users_updated_at - AFTER UPDATE ON users - BEGIN - UPDATE users SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; - END`, - - `CREATE TRIGGER IF NOT EXISTS update_ai_models_updated_at - AFTER UPDATE ON ai_models - BEGIN - UPDATE ai_models SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; - END`, - - `CREATE TRIGGER IF NOT EXISTS update_exchanges_updated_at - AFTER UPDATE ON exchanges - BEGIN - UPDATE exchanges SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; - END`, - - `CREATE TRIGGER IF NOT EXISTS update_traders_updated_at - AFTER UPDATE ON traders - BEGIN - UPDATE traders SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; - END`, - - `CREATE TRIGGER IF NOT EXISTS update_user_signal_sources_updated_at - AFTER UPDATE ON user_signal_sources - BEGIN - UPDATE user_signal_sources SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; - END`, - - `CREATE TRIGGER IF NOT EXISTS update_system_config_updated_at - AFTER UPDATE ON system_config - BEGIN - UPDATE system_config SET updated_at = CURRENT_TIMESTAMP WHERE key = NEW.key; - END`, - } - - for _, query := range queries { - if _, err := d.db.Exec(query); err != nil { - return fmt.Errorf("执行SQL失败 [%s]: %w", query, err) - } - } - - // 为现有数据库添加新字段(向后兼容) - alterQueries := []string{ - `ALTER TABLE exchanges ADD COLUMN hyperliquid_wallet_addr TEXT DEFAULT ''`, - `ALTER TABLE exchanges ADD COLUMN aster_user TEXT DEFAULT ''`, - `ALTER TABLE exchanges ADD COLUMN aster_signer TEXT DEFAULT ''`, - `ALTER TABLE exchanges ADD COLUMN aster_private_key TEXT DEFAULT ''`, - `ALTER TABLE exchanges ADD COLUMN lighter_wallet_addr TEXT DEFAULT ''`, - `ALTER TABLE exchanges ADD COLUMN lighter_private_key TEXT DEFAULT ''`, - `ALTER TABLE exchanges ADD COLUMN lighter_api_key_private_key TEXT DEFAULT ''`, - `ALTER TABLE traders ADD COLUMN custom_prompt TEXT DEFAULT ''`, - `ALTER TABLE traders ADD COLUMN override_base_prompt BOOLEAN DEFAULT 0`, - `ALTER TABLE traders ADD COLUMN is_cross_margin BOOLEAN DEFAULT 1`, // 默认为全仓模式 - `ALTER TABLE traders ADD COLUMN use_default_coins BOOLEAN DEFAULT 1`, // 默认使用默认币种 - `ALTER TABLE traders ADD COLUMN custom_coins TEXT DEFAULT ''`, // 自定义币种列表(JSON格式) - `ALTER TABLE traders ADD COLUMN btc_eth_leverage INTEGER DEFAULT 5`, // BTC/ETH杠杆倍数 - `ALTER TABLE traders ADD COLUMN altcoin_leverage INTEGER DEFAULT 5`, // 山寨币杠杆倍数 - `ALTER TABLE traders ADD COLUMN trading_symbols TEXT DEFAULT ''`, // 交易币种,逗号分隔 - `ALTER TABLE traders ADD COLUMN use_coin_pool BOOLEAN DEFAULT 0`, // 是否使用COIN POOL信号源 - `ALTER TABLE traders ADD COLUMN use_oi_top BOOLEAN DEFAULT 0`, // 是否使用OI TOP信号源 - `ALTER TABLE traders ADD COLUMN system_prompt_template TEXT DEFAULT 'default'`, // 系统提示词模板名称 - `ALTER TABLE ai_models ADD COLUMN custom_api_url TEXT DEFAULT ''`, // 自定义API地址 - `ALTER TABLE ai_models ADD COLUMN custom_model_name TEXT DEFAULT ''`, // 自定义模型名称 - } - - for _, query := range alterQueries { - // 忽略已存在字段的错误 - d.db.Exec(query) - } - - // 检查是否需要迁移exchanges表的主键结构 - err := d.migrateExchangesTable() - if err != nil { - log.Printf("⚠️ 迁移exchanges表失败: %v", err) - } - - // 修复traders表的外键约束问题 - err = d.migrateTradersTable() - if err != nil { - log.Printf("⚠️ 迁移traders表失败: %v", err) - } - - return nil -} - -func (d *Database) ensureBacktestRunColumns() error { - addColumn := func(table, column, definition string) error { - exists, err := columnExists(d.db, table, column) - if err != nil { - return err - } - if exists { - return nil - } - _, err = d.db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, definition)) - return err - } - if err := addColumn("backtest_runs", "label", "TEXT DEFAULT ''"); err != nil { - return err - } - if err := addColumn("backtest_runs", "last_error", "TEXT DEFAULT ''"); err != nil { - return err - } - if err := addColumn("backtest_trades", "leverage", "INTEGER DEFAULT 0"); err != nil { - return err - } - return nil -} - -func columnExists(db *sql.DB, table, column string) (bool, error) { - rows, err := db.Query(fmt.Sprintf("PRAGMA table_info(%s)", table)) - if err != nil { - return false, err - } - defer rows.Close() - for rows.Next() { - var ( - cid int - name string - ctype string - notnull int - dfltValue any - primaryKey int - ) - if err := rows.Scan(&cid, &name, &ctype, ¬null, &dfltValue, &primaryKey); err != nil { - return false, err - } - if name == column { - return true, nil - } - } - return false, rows.Err() -} - -func tuneSQLiteConnection(db *sql.DB) error { - if db == nil { - return fmt.Errorf("db is nil") - } - statements := []string{ - `PRAGMA busy_timeout = 5000`, - `PRAGMA journal_mode = WAL`, - `PRAGMA synchronous = NORMAL`, - } - for _, stmt := range statements { - if _, err := db.Exec(stmt); err != nil { - return fmt.Errorf("执行 %s 失败: %w", stmt, err) - } - } - return nil -} - -// initDefaultData 初始化默认数据 -func (d *Database) initDefaultData() error { - // 初始化AI模型(使用default用户) - aiModels := []struct { - id, name, provider string - }{ - {"deepseek", "DeepSeek", "deepseek"}, - {"qwen", "Qwen", "qwen"}, - } - - for _, model := range aiModels { - _, err := d.db.Exec(` - INSERT OR IGNORE INTO ai_models (id, user_id, name, provider, enabled) - VALUES (?, 'default', ?, ?, 0) - `, model.id, model.name, model.provider) - if err != nil { - return fmt.Errorf("初始化AI模型失败: %w", err) - } - } - - // 初始化交易所(使用default用户) - exchanges := []struct { - id, name, typ string - }{ - {"binance", "Binance Futures", "binance"}, - {"bybit", "Bybit Futures", "bybit"}, - {"hyperliquid", "Hyperliquid", "hyperliquid"}, - {"aster", "Aster DEX", "aster"}, - {"lighter", "LIGHTER DEX", "lighter"}, - } - - for _, exchange := range exchanges { - _, err := d.db.Exec(` - INSERT OR IGNORE INTO exchanges (id, user_id, name, type, enabled) - VALUES (?, 'default', ?, ?, 0) - `, exchange.id, exchange.name, exchange.typ) - if err != nil { - return fmt.Errorf("初始化交易所失败: %w", err) - } - } - - // 初始化系统配置 - 创建所有字段,设置默认值,后续由config.json同步更新 - systemConfigs := map[string]string{ - "beta_mode": "false", // 默认关闭内测模式 - "api_server_port": "8080", // 默认API端口 - "use_default_coins": "true", // 默认使用内置币种列表 - "default_coins": `["BTCUSDT","ETHUSDT","SOLUSDT","BNBUSDT","XRPUSDT","DOGEUSDT","ADAUSDT","HYPEUSDT"]`, // 默认币种列表(JSON格式) - "max_daily_loss": "10.0", // 最大日损失百分比 - "max_drawdown": "20.0", // 最大回撤百分比 - "stop_trading_minutes": "60", // 停止交易时间(分钟) - "btc_eth_leverage": "5", // BTC/ETH杠杆倍数 - "altcoin_leverage": "5", // 山寨币杠杆倍数 - "jwt_secret": "", // JWT密钥,默认为空,由config.json或系统生成 - "registration_enabled": "true", // 默认允许注册 - } - - for key, value := range systemConfigs { - _, err := d.db.Exec(` - INSERT OR IGNORE INTO system_config (key, value) - VALUES (?, ?) - `, key, value) - if err != nil { - return fmt.Errorf("初始化系统配置失败: %w", err) - } - } - - return nil -} - -// migrateExchangesTable 迁移exchanges表支持多用户 -func (d *Database) migrateExchangesTable() error { - // 检查是否已经迁移过 - var count int - err := d.db.QueryRow(` - SELECT COUNT(*) FROM sqlite_master - WHERE type='table' AND name='exchanges_new' - `).Scan(&count) - if err != nil { - return err - } - - // 如果已经迁移过,直接返回 - if count > 0 { - return nil - } - - log.Printf("🔄 开始迁移exchanges表...") - - // 创建新的exchanges表,使用复合主键 - _, err = d.db.Exec(` - CREATE TABLE exchanges_new ( - id TEXT NOT NULL, - user_id TEXT NOT NULL DEFAULT 'default', - name TEXT NOT NULL, - type TEXT NOT NULL, - enabled BOOLEAN DEFAULT 0, - api_key TEXT DEFAULT '', - secret_key TEXT DEFAULT '', - testnet BOOLEAN DEFAULT 0, - hyperliquid_wallet_addr TEXT DEFAULT '', - aster_user TEXT DEFAULT '', - aster_signer TEXT DEFAULT '', - aster_private_key TEXT DEFAULT '', - lighter_wallet_addr TEXT DEFAULT '', - lighter_private_key TEXT DEFAULT '', - lighter_api_key_private_key TEXT DEFAULT '', - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - PRIMARY KEY (id, user_id), - FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE - ) - `) - if err != nil { - return fmt.Errorf("创建新exchanges表失败: %w", err) - } - - // 复制数据到新表 - _, err = d.db.Exec(` - INSERT INTO exchanges_new - SELECT * FROM exchanges - `) - if err != nil { - return fmt.Errorf("复制数据失败: %w", err) - } - - // 删除旧表 - _, err = d.db.Exec(`DROP TABLE exchanges`) - if err != nil { - return fmt.Errorf("删除旧表失败: %w", err) - } - - // 重命名新表 - _, err = d.db.Exec(`ALTER TABLE exchanges_new RENAME TO exchanges`) - if err != nil { - return fmt.Errorf("重命名表失败: %w", err) - } - - // 重新创建触发器 - _, err = d.db.Exec(` - CREATE TRIGGER IF NOT EXISTS update_exchanges_updated_at - AFTER UPDATE ON exchanges - BEGIN - UPDATE exchanges SET updated_at = CURRENT_TIMESTAMP - WHERE id = NEW.id AND user_id = NEW.user_id; - END - `) - if err != nil { - return fmt.Errorf("创建触发器失败: %w", err) - } - - log.Printf("✅ exchanges表迁移完成") - return nil -} - -// migrateTradersTable 迁移traders表,移除外键约束 -func (d *Database) migrateTradersTable() error { - // 检查traders表是否存在外键约束(通过尝试创建一个测试记录来判断) - // 如果表已经没有外键约束,则跳过迁移 - var tableSQL string - err := d.db.QueryRow(`SELECT sql FROM sqlite_master WHERE type='table' AND name='traders'`).Scan(&tableSQL) - if err != nil { - // 表不存在,无需迁移 - return nil - } - - // 检查是否包含 FOREIGN KEY (exchange_id) 或 FOREIGN KEY (ai_model_id) - if !strings.Contains(tableSQL, "FOREIGN KEY (exchange_id)") && !strings.Contains(tableSQL, "FOREIGN KEY (ai_model_id)") { - // 已经没有这些外键约束,无需迁移 - return nil - } - - log.Printf("🔄 开始迁移traders表,移除外键约束...") - - // 创建新的traders表,不包含exchange_id和ai_model_id的外键约束 - _, err = d.db.Exec(` - CREATE TABLE traders_new ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL DEFAULT 'default', - name TEXT NOT NULL, - ai_model_id TEXT NOT NULL, - exchange_id TEXT NOT NULL, - initial_balance REAL NOT NULL, - scan_interval_minutes INTEGER DEFAULT 3, - is_running BOOLEAN DEFAULT 0, - btc_eth_leverage INTEGER DEFAULT 5, - altcoin_leverage INTEGER DEFAULT 5, - trading_symbols TEXT DEFAULT '', - use_coin_pool BOOLEAN DEFAULT 0, - use_oi_top BOOLEAN DEFAULT 0, - custom_prompt TEXT DEFAULT '', - override_base_prompt BOOLEAN DEFAULT 0, - system_prompt_template TEXT DEFAULT 'default', - is_cross_margin BOOLEAN DEFAULT 1, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE - ) - `) - if err != nil { - return fmt.Errorf("创建新traders表失败: %w", err) - } - - // 复制数据到新表 - _, err = d.db.Exec(` - INSERT INTO traders_new (id, user_id, name, ai_model_id, exchange_id, initial_balance, - scan_interval_minutes, is_running, btc_eth_leverage, altcoin_leverage, trading_symbols, - use_coin_pool, use_oi_top, custom_prompt, override_base_prompt, system_prompt_template, - is_cross_margin, created_at, updated_at) - SELECT id, user_id, name, ai_model_id, exchange_id, initial_balance, - scan_interval_minutes, is_running, - COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), - COALESCE(trading_symbols, ''), COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), - COALESCE(custom_prompt, ''), COALESCE(override_base_prompt, 0), - COALESCE(system_prompt_template, 'default'), COALESCE(is_cross_margin, 1), - created_at, updated_at - FROM traders - `) - if err != nil { - // 如果复制失败,删除新表 - d.db.Exec(`DROP TABLE traders_new`) - return fmt.Errorf("复制traders数据失败: %w", err) - } - - // 删除旧表 - _, err = d.db.Exec(`DROP TABLE traders`) - if err != nil { - return fmt.Errorf("删除旧traders表失败: %w", err) - } - - // 重命名新表 - _, err = d.db.Exec(`ALTER TABLE traders_new RENAME TO traders`) - if err != nil { - return fmt.Errorf("重命名traders表失败: %w", err) - } - - log.Printf("✅ traders表迁移完成,已移除外键约束") - return nil -} - -// User 用户配置 -type User struct { - ID string `json:"id"` - Email string `json:"email"` - PasswordHash string `json:"-"` // 不返回到前端 - OTPSecret string `json:"-"` // 不返回到前端 - OTPVerified bool `json:"otp_verified"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} - -// AIModelConfig AI模型配置 -type AIModelConfig struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - Provider string `json:"provider"` - Enabled bool `json:"enabled"` - APIKey string `json:"apiKey"` - CustomAPIURL string `json:"customApiUrl"` - CustomModelName string `json:"customModelName"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} - -// ExchangeConfig 交易所配置 -type ExchangeConfig struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - Type string `json:"type"` - Enabled bool `json:"enabled"` - APIKey string `json:"apiKey"` // For Binance: API Key; For Hyperliquid: Agent Private Key (should have ~0 balance) - SecretKey string `json:"secretKey"` // For Binance: Secret Key; Not used for Hyperliquid - Testnet bool `json:"testnet"` - // Hyperliquid Agent Wallet configuration (following official best practices) - // Reference: https://hyperliquid.gitbook.io/hyperliquid-docs/for-developers/api/nonces-and-api-wallets - HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` // Main Wallet Address (holds funds, never expose private key) - // Aster 特定字段 - AsterUser string `json:"asterUser"` - AsterSigner string `json:"asterSigner"` - AsterPrivateKey string `json:"asterPrivateKey"` - // LIGHTER 特定字段 - LighterWalletAddr string `json:"lighterWalletAddr"` // Ethereum 钱包地址 (L1) - LighterPrivateKey string `json:"lighterPrivateKey"` // L1私钥(用于识别账户) - LighterAPIKeyPrivateKey string `json:"lighterAPIKeyPrivateKey"` // API Key私钥(40字节,用于签名交易) - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} - -// TraderRecord 交易员配置(数据库实体) -type TraderRecord struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - AIModelID string `json:"ai_model_id"` - ExchangeID string `json:"exchange_id"` - InitialBalance float64 `json:"initial_balance"` - ScanIntervalMinutes int `json:"scan_interval_minutes"` - IsRunning bool `json:"is_running"` - BTCETHLeverage int `json:"btc_eth_leverage"` // BTC/ETH杠杆倍数 - AltcoinLeverage int `json:"altcoin_leverage"` // 山寨币杠杆倍数 - TradingSymbols string `json:"trading_symbols"` // 交易币种,逗号分隔 - UseCoinPool bool `json:"use_coin_pool"` // 是否使用COIN POOL信号源 - UseOITop bool `json:"use_oi_top"` // 是否使用OI TOP信号源 - CustomPrompt string `json:"custom_prompt"` // 自定义交易策略prompt - OverrideBasePrompt bool `json:"override_base_prompt"` // 是否覆盖基础prompt - SystemPromptTemplate string `json:"system_prompt_template"` // 系统提示词模板名称 - IsCrossMargin bool `json:"is_cross_margin"` // 是否为全仓模式(true=全仓,false=逐仓) - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} - -// UserSignalSource 用户信号源配置 -type UserSignalSource struct { - ID int `json:"id"` - UserID string `json:"user_id"` - CoinPoolURL string `json:"coin_pool_url"` - OITopURL string `json:"oi_top_url"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} - -// GenerateOTPSecret 生成OTP密钥 -func GenerateOTPSecret() (string, error) { - secret := make([]byte, 20) - _, err := rand.Read(secret) - if err != nil { - return "", err - } - return base32.StdEncoding.EncodeToString(secret), nil -} - -// CreateUser 创建用户 -func (d *Database) CreateUser(user *User) error { - _, err := d.db.Exec(` - INSERT INTO users (id, email, password_hash, otp_secret, otp_verified) - VALUES (?, ?, ?, ?, ?) - `, user.ID, user.Email, user.PasswordHash, user.OTPSecret, user.OTPVerified) - return err -} - -// EnsureAdminUser 确保admin用户存在(用于管理员模式) -func (d *Database) EnsureAdminUser() error { - // 检查admin用户是否已存在 - var count int - err := d.db.QueryRow(`SELECT COUNT(*) FROM users WHERE id = 'admin'`).Scan(&count) - if err != nil { - return err - } - - // 如果已存在,直接返回 - if count > 0 { - return nil - } - - // 创建admin用户(密码为空,因为管理员模式下不需要密码) - adminUser := &User{ - ID: "admin", - Email: "admin@localhost", - PasswordHash: "", // 管理员模式下不使用密码 - OTPSecret: "", - OTPVerified: true, - } - - return d.CreateUser(adminUser) -} - -// GetUserByEmail 通过邮箱获取用户 -func (d *Database) GetUserByEmail(email string) (*User, error) { - var user User - var createdAt, updatedAt string - err := d.db.QueryRow(` - SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at - FROM users WHERE email = ? - `, email).Scan( - &user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret, - &user.OTPVerified, &createdAt, &updatedAt, - ) - if err != nil { - return nil, err - } - user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - return &user, nil -} - -// GetUserByID 通过ID获取用户 -func (d *Database) GetUserByID(userID string) (*User, error) { - var user User - var createdAt, updatedAt string - err := d.db.QueryRow(` - SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at - FROM users WHERE id = ? - `, userID).Scan( - &user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret, - &user.OTPVerified, &createdAt, &updatedAt, - ) - if err != nil { - return nil, err - } - user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - return &user, nil -} - -// GetAllUsers 获取所有用户ID列表 -func (d *Database) GetAllUsers() ([]string, error) { - rows, err := d.db.Query(`SELECT id FROM users ORDER BY id`) - if err != nil { - return nil, err - } - defer rows.Close() - - var userIDs []string - for rows.Next() { - var userID string - if err := rows.Scan(&userID); err != nil { - return nil, err - } - userIDs = append(userIDs, userID) - } - return userIDs, nil -} - -// UpdateUserOTPVerified 更新用户OTP验证状态 -func (d *Database) UpdateUserOTPVerified(userID string, verified bool) error { - _, err := d.db.Exec(`UPDATE users SET otp_verified = ? WHERE id = ?`, verified, userID) - return err -} - -// UpdateUserPassword 更新用户密码 -func (d *Database) UpdateUserPassword(userID, passwordHash string) error { - _, err := d.db.Exec(` - UPDATE users - SET password_hash = ?, updated_at = CURRENT_TIMESTAMP - WHERE id = ? - `, passwordHash, userID) - return err -} - -// GetAIModels 获取用户的AI模型配置 -func (d *Database) GetAIModels(userID string) ([]*AIModelConfig, error) { - rows, err := d.db.Query(` - SELECT id, user_id, name, provider, enabled, api_key, - COALESCE(custom_api_url, '') as custom_api_url, - COALESCE(custom_model_name, '') as custom_model_name, - created_at, updated_at - FROM ai_models WHERE user_id = ? ORDER BY id - `, userID) - if err != nil { - return nil, err - } - defer rows.Close() - - // 初始化为空切片而不是nil,确保JSON序列化为[]而不是null - models := make([]*AIModelConfig, 0) - for rows.Next() { - var model AIModelConfig - var createdAt, updatedAt string - err := rows.Scan( - &model.ID, &model.UserID, &model.Name, &model.Provider, - &model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName, - &createdAt, &updatedAt, - ) - if err != nil { - return nil, err - } - // 解析时间字符串 - model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - // 解密API Key - model.APIKey = d.decryptSensitiveData(model.APIKey) - models = append(models, &model) - } - - return models, nil -} - -// GetAIModel 根据模型ID和用户ID获取单个AI模型配置,若用户下不存在则回退到default用户。 -func (d *Database) GetAIModel(userID, modelID string) (*AIModelConfig, error) { - if modelID == "" { - return nil, fmt.Errorf("模型ID不能为空") - } - - candidates := []string{} - if userID != "" { - candidates = append(candidates, userID) - } - if userID != "default" { - candidates = append(candidates, "default") - } - if len(candidates) == 0 { - candidates = append(candidates, "default") - } - - for _, uid := range candidates { - var model AIModelConfig - var createdAt, updatedAt string - err := d.db.QueryRow(` - SELECT id, user_id, name, provider, enabled, api_key, - COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at - FROM ai_models - WHERE user_id = ? AND id = ? - LIMIT 1 - `, uid, modelID).Scan( - &model.ID, - &model.UserID, - &model.Name, - &model.Provider, - &model.Enabled, - &model.APIKey, - &model.CustomAPIURL, - &model.CustomModelName, - &createdAt, - &updatedAt, - ) - if err == nil { - // 解析时间字符串 - model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - // 解密API Key(与 GetAIModels 行为保持一致) - model.APIKey = d.decryptSensitiveData(model.APIKey) - return &model, nil - } - if !errors.Is(err, sql.ErrNoRows) { - return nil, err - } - } - - return nil, sql.ErrNoRows -} - -// GetDefaultAIModel 获取指定用户(或默认用户)的首个启用的AI模型。 -func (d *Database) GetDefaultAIModel(userID string) (*AIModelConfig, error) { - if userID == "" { - userID = "default" - } - model, err := d.firstEnabledAIModel(userID) - if err == nil { - return model, nil - } - if !errors.Is(err, sql.ErrNoRows) { - return nil, err - } - if userID != "default" { - return d.firstEnabledAIModel("default") - } - return nil, fmt.Errorf("请先在系统中配置可用的AI模型") -} - -func (d *Database) firstEnabledAIModel(userID string) (*AIModelConfig, error) { - var model AIModelConfig - var createdAt, updatedAt string - err := d.db.QueryRow(` - SELECT id, user_id, name, provider, enabled, api_key, - COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at - FROM ai_models - WHERE user_id = ? AND enabled = 1 - ORDER BY datetime(updated_at) DESC, id ASC - LIMIT 1 - `, userID).Scan( - &model.ID, - &model.UserID, - &model.Name, - &model.Provider, - &model.Enabled, - &model.APIKey, - &model.CustomAPIURL, - &model.CustomModelName, - &createdAt, - &updatedAt, - ) - if err != nil { - return nil, err - } - // 解析时间字符串 - model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - // 解密API Key,避免上层拿到加密串导致下游认证失败 - model.APIKey = d.decryptSensitiveData(model.APIKey) - return &model, nil -} - -// UpdateAIModel 更新AI模型配置,如果不存在则创建用户特定配置 -func (d *Database) UpdateAIModel(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error { - // 先尝试精确匹配 ID(新版逻辑,支持多个相同 provider 的模型) - var existingID string - err := d.db.QueryRow(` - SELECT id FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1 - `, userID, id).Scan(&existingID) - - if err == nil { - // 找到了现有配置(精确匹配 ID),更新它 - encryptedAPIKey := d.encryptSensitiveData(apiKey) - _, err = d.db.Exec(` - UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now') - WHERE id = ? AND user_id = ? - `, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID) - return err - } - - // ID 不存在,尝试兼容旧逻辑:将 id 作为 provider 查找 - provider := id - err = d.db.QueryRow(` - SELECT id FROM ai_models WHERE user_id = ? AND provider = ? LIMIT 1 - `, userID, provider).Scan(&existingID) - - if err == nil { - // 找到了现有配置(通过 provider 匹配,兼容旧版),更新它 - log.Printf("⚠️ 使用旧版 provider 匹配更新模型: %s -> %s", provider, existingID) - encryptedAPIKey := d.encryptSensitiveData(apiKey) - _, err = d.db.Exec(` - UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now') - WHERE id = ? AND user_id = ? - `, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID) - return err - } - - // 没有找到任何现有配置,创建新的 - // 推断 provider(从 id 中提取,或者直接使用 id) - if provider == id && (provider == "deepseek" || provider == "qwen") { - // id 本身就是 provider - provider = id - } else { - // 从 id 中提取 provider(假设格式是 userID_provider 或 timestamp_userID_provider) - parts := strings.Split(id, "_") - if len(parts) >= 2 { - provider = parts[len(parts)-1] // 取最后一部分作为 provider - } else { - provider = id - } - } - - // 获取模型的基本信息 - var name string - err = d.db.QueryRow(` - SELECT name FROM ai_models WHERE provider = ? LIMIT 1 - `, provider).Scan(&name) - if err != nil { - // 如果找不到基本信息,使用默认值 - if provider == "deepseek" { - name = "DeepSeek AI" - } else if provider == "qwen" { - name = "Qwen AI" - } else { - name = provider + " AI" - } - } - - // 如果传入的 ID 已经是完整格式(如 "admin_deepseek_custom1"),直接使用 - // 否则生成新的 ID - newModelID := id - if id == provider { - // id 就是 provider,生成新的用户特定 ID - newModelID = fmt.Sprintf("%s_%s", userID, provider) - } - - log.Printf("✓ 创建新的 AI 模型配置: ID=%s, Provider=%s, Name=%s", newModelID, provider, name) - encryptedAPIKey := d.encryptSensitiveData(apiKey) - _, err = d.db.Exec(` - INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, custom_model_name, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now')) - `, newModelID, userID, name, provider, enabled, encryptedAPIKey, customAPIURL, customModelName) - - return err -} - -// GetExchanges 获取用户的交易所配置 -func (d *Database) GetExchanges(userID string) ([]*ExchangeConfig, error) { - rows, err := d.db.Query(` - SELECT id, user_id, name, type, enabled, api_key, secret_key, testnet, - COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr, - COALESCE(aster_user, '') as aster_user, - COALESCE(aster_signer, '') as aster_signer, - COALESCE(aster_private_key, '') as aster_private_key, - COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr, - COALESCE(lighter_private_key, '') as lighter_private_key, - COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key, - created_at, updated_at - FROM exchanges WHERE user_id = ? ORDER BY id - `, userID) - if err != nil { - return nil, err - } - defer rows.Close() - - // 初始化为空切片而不是nil,确保JSON序列化为[]而不是null - exchanges := make([]*ExchangeConfig, 0) - for rows.Next() { - var exchange ExchangeConfig - var createdAt, updatedAt string - err := rows.Scan( - &exchange.ID, &exchange.UserID, &exchange.Name, &exchange.Type, - &exchange.Enabled, &exchange.APIKey, &exchange.SecretKey, &exchange.Testnet, - &exchange.HyperliquidWalletAddr, &exchange.AsterUser, - &exchange.AsterSigner, &exchange.AsterPrivateKey, - &exchange.LighterWalletAddr, &exchange.LighterPrivateKey, - &exchange.LighterAPIKeyPrivateKey, - &createdAt, &updatedAt, - ) - if err != nil { - return nil, err - } - - // 解析时间字符串 - exchange.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - exchange.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - - // 解密敏感字段 - exchange.APIKey = d.decryptSensitiveData(exchange.APIKey) - exchange.SecretKey = d.decryptSensitiveData(exchange.SecretKey) - exchange.AsterPrivateKey = d.decryptSensitiveData(exchange.AsterPrivateKey) - exchange.LighterPrivateKey = d.decryptSensitiveData(exchange.LighterPrivateKey) - exchange.LighterAPIKeyPrivateKey = d.decryptSensitiveData(exchange.LighterAPIKeyPrivateKey) - - exchanges = append(exchanges, &exchange) - } - - return exchanges, nil -} - -// UpdateExchange 更新交易所配置,如果不存在则创建用户特定配置 -// 🔒 安全特性:空值不会覆盖现有的敏感字段(api_key, secret_key, aster_private_key, lighter_private_key) -func (d *Database) UpdateExchange(userID, id string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, lighterWalletAddr, lighterPrivateKey string) error { - log.Printf("🔧 UpdateExchange: userID=%s, id=%s, enabled=%v", userID, id, enabled) - - // 构建动态 UPDATE SET 子句 - // 基础字段:总是更新 - setClauses := []string{ - "enabled = ?", - "testnet = ?", - "hyperliquid_wallet_addr = ?", - "aster_user = ?", - "aster_signer = ?", - "lighter_wallet_addr = ?", - "updated_at = datetime('now')", - } - args := []interface{}{enabled, testnet, hyperliquidWalletAddr, asterUser, asterSigner, lighterWalletAddr} - - // 🔒 敏感字段:只在非空时更新(保护现有数据) - if apiKey != "" { - encryptedAPIKey := d.encryptSensitiveData(apiKey) - setClauses = append(setClauses, "api_key = ?") - args = append(args, encryptedAPIKey) - } - - if secretKey != "" { - encryptedSecretKey := d.encryptSensitiveData(secretKey) - setClauses = append(setClauses, "secret_key = ?") - args = append(args, encryptedSecretKey) - } - - if asterPrivateKey != "" { - encryptedAsterPrivateKey := d.encryptSensitiveData(asterPrivateKey) - setClauses = append(setClauses, "aster_private_key = ?") - args = append(args, encryptedAsterPrivateKey) - } - - if lighterPrivateKey != "" { - encryptedLighterPrivateKey := d.encryptSensitiveData(lighterPrivateKey) - setClauses = append(setClauses, "lighter_private_key = ?") - args = append(args, encryptedLighterPrivateKey) - } - - // WHERE 条件 - args = append(args, id, userID) - - // 构建完整的 UPDATE 语句 - query := fmt.Sprintf(` - UPDATE exchanges SET %s - WHERE id = ? AND user_id = ? - `, strings.Join(setClauses, ", ")) - - // 执行更新 - result, err := d.db.Exec(query, args...) - if err != nil { - log.Printf("❌ UpdateExchange: 更新失败: %v", err) - return err - } - - // 检查是否有行被更新 - rowsAffected, err := result.RowsAffected() - if err != nil { - log.Printf("❌ UpdateExchange: 获取影响行数失败: %v", err) - return err - } - - log.Printf("📊 UpdateExchange: 影响行数 = %d", rowsAffected) - - // 如果没有行被更新,说明用户没有这个交易所的配置,需要创建 - if rowsAffected == 0 { - log.Printf("💡 UpdateExchange: 没有现有记录,创建新记录") - - // 根据交易所ID确定基本信息 - var name, typ string - if id == "binance" { - name = "Binance Futures" - typ = "cex" - } else if id == "bybit" { - name = "Bybit Futures" - typ = "cex" - } else if id == "hyperliquid" { - name = "Hyperliquid" - typ = "dex" - } else if id == "aster" { - name = "Aster DEX" - typ = "dex" - } else if id == "lighter" { - name = "LIGHTER DEX" - typ = "dex" - } else { - name = id + " Exchange" - typ = "cex" - } - - log.Printf("🆕 UpdateExchange: 创建新记录 ID=%s, name=%s, type=%s", id, name, typ) - - // 加密敏感字段 - encryptedAPIKey := d.encryptSensitiveData(apiKey) - encryptedSecretKey := d.encryptSensitiveData(secretKey) - encryptedAsterPrivateKey := d.encryptSensitiveData(asterPrivateKey) - encryptedLighterPrivateKey := d.encryptSensitiveData(lighterPrivateKey) - - // 创建用户特定的配置,使用原始的交易所ID - _, err = d.db.Exec(` - INSERT INTO exchanges (id, user_id, name, type, enabled, api_key, secret_key, testnet, - hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key, - lighter_wallet_addr, lighter_private_key, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now')) - `, id, userID, name, typ, enabled, encryptedAPIKey, encryptedSecretKey, testnet, hyperliquidWalletAddr, asterUser, asterSigner, encryptedAsterPrivateKey, lighterWalletAddr, encryptedLighterPrivateKey) - - if err != nil { - log.Printf("❌ UpdateExchange: 创建记录失败: %v", err) - } else { - log.Printf("✅ UpdateExchange: 创建记录成功") - } - return err - } - - log.Printf("✅ UpdateExchange: 更新现有记录成功") - return nil -} - -// CreateAIModel 创建AI模型配置 -func (d *Database) CreateAIModel(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error { - _, err := d.db.Exec(` - INSERT OR IGNORE INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url) - VALUES (?, ?, ?, ?, ?, ?, ?) - `, id, userID, name, provider, enabled, apiKey, customAPIURL) - return err -} - -// CreateExchange 创建交易所配置 -func (d *Database) CreateExchange(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error { - // 加密敏感字段 - encryptedAPIKey := d.encryptSensitiveData(apiKey) - encryptedSecretKey := d.encryptSensitiveData(secretKey) - encryptedAsterPrivateKey := d.encryptSensitiveData(asterPrivateKey) - - _, err := d.db.Exec(` - INSERT OR IGNORE INTO exchanges (id, user_id, name, type, enabled, api_key, secret_key, testnet, hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key, lighter_wallet_addr, lighter_private_key) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, '', '') - `, id, userID, name, typ, enabled, encryptedAPIKey, encryptedSecretKey, testnet, hyperliquidWalletAddr, asterUser, asterSigner, encryptedAsterPrivateKey) - return err -} - -// CreateTrader 创建交易员 -func (d *Database) CreateTrader(trader *TraderRecord) error { - _, err := d.db.Exec(` - INSERT INTO traders (id, user_id, name, ai_model_id, exchange_id, initial_balance, scan_interval_minutes, is_running, btc_eth_leverage, altcoin_leverage, trading_symbols, use_coin_pool, use_oi_top, custom_prompt, override_base_prompt, system_prompt_template, is_cross_margin) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - `, trader.ID, trader.UserID, trader.Name, trader.AIModelID, trader.ExchangeID, trader.InitialBalance, trader.ScanIntervalMinutes, trader.IsRunning, trader.BTCETHLeverage, trader.AltcoinLeverage, trader.TradingSymbols, trader.UseCoinPool, trader.UseOITop, trader.CustomPrompt, trader.OverrideBasePrompt, trader.SystemPromptTemplate, trader.IsCrossMargin) - return err -} - -// GetTraders 获取用户的交易员 -func (d *Database) GetTraders(userID string) ([]*TraderRecord, error) { - rows, err := d.db.Query(` - SELECT id, user_id, name, ai_model_id, exchange_id, initial_balance, scan_interval_minutes, is_running, - COALESCE(btc_eth_leverage, 5) as btc_eth_leverage, COALESCE(altcoin_leverage, 5) as altcoin_leverage, - COALESCE(trading_symbols, '') as trading_symbols, - COALESCE(use_coin_pool, 0) as use_coin_pool, COALESCE(use_oi_top, 0) as use_oi_top, - COALESCE(custom_prompt, '') as custom_prompt, COALESCE(override_base_prompt, 0) as override_base_prompt, - COALESCE(system_prompt_template, 'default') as system_prompt_template, - COALESCE(is_cross_margin, 1) as is_cross_margin, created_at, updated_at - FROM traders WHERE user_id = ? ORDER BY created_at DESC - `, userID) - if err != nil { - return nil, err - } - defer rows.Close() - - var traders []*TraderRecord - for rows.Next() { - var trader TraderRecord - var createdAt, updatedAt string - err := rows.Scan( - &trader.ID, &trader.UserID, &trader.Name, &trader.AIModelID, &trader.ExchangeID, - &trader.InitialBalance, &trader.ScanIntervalMinutes, &trader.IsRunning, - &trader.BTCETHLeverage, &trader.AltcoinLeverage, &trader.TradingSymbols, - &trader.UseCoinPool, &trader.UseOITop, - &trader.CustomPrompt, &trader.OverrideBasePrompt, &trader.SystemPromptTemplate, - &trader.IsCrossMargin, - &createdAt, &updatedAt, - ) - if err != nil { - return nil, err - } - // 解析时间字符串 - trader.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - trader.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - traders = append(traders, &trader) - } - - return traders, nil -} - -// UpdateTraderStatus 更新交易员状态 -func (d *Database) UpdateTraderStatus(userID, id string, isRunning bool) error { - _, err := d.db.Exec(`UPDATE traders SET is_running = ? WHERE id = ? AND user_id = ?`, isRunning, id, userID) - return err -} - -// UpdateTrader 更新交易员配置 -func (d *Database) UpdateTrader(trader *TraderRecord) error { - _, err := d.db.Exec(` - UPDATE traders SET - name = ?, ai_model_id = ?, exchange_id = ?, - scan_interval_minutes = ?, btc_eth_leverage = ?, altcoin_leverage = ?, - trading_symbols = ?, custom_prompt = ?, override_base_prompt = ?, - system_prompt_template = ?, is_cross_margin = ?, updated_at = CURRENT_TIMESTAMP - WHERE id = ? AND user_id = ? - `, trader.Name, trader.AIModelID, trader.ExchangeID, - trader.ScanIntervalMinutes, trader.BTCETHLeverage, trader.AltcoinLeverage, - trader.TradingSymbols, trader.CustomPrompt, trader.OverrideBasePrompt, - trader.SystemPromptTemplate, trader.IsCrossMargin, trader.ID, trader.UserID) - return err -} - -// UpdateTraderCustomPrompt 更新交易员自定义Prompt -func (d *Database) UpdateTraderCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error { - _, err := d.db.Exec(`UPDATE traders SET custom_prompt = ?, override_base_prompt = ? WHERE id = ? AND user_id = ?`, customPrompt, overrideBase, id, userID) - return err -} - -// UpdateTraderInitialBalance 更新交易员初始余额(仅支持手动更新) -// ⚠️ 注意:系统不会自动调用此方法,仅供用户在充值/提现后手动同步使用 -func (d *Database) UpdateTraderInitialBalance(userID, id string, newBalance float64) error { - _, err := d.db.Exec(`UPDATE traders SET initial_balance = ? WHERE id = ? AND user_id = ?`, newBalance, id, userID) - return err -} - -// DeleteTrader 删除交易员 -func (d *Database) DeleteTrader(userID, id string) error { - _, err := d.db.Exec(`DELETE FROM traders WHERE id = ? AND user_id = ?`, id, userID) - return err -} - -// GetTraderConfig 获取交易员完整配置(包含AI模型和交易所信息) -func (d *Database) GetTraderConfig(userID, traderID string) (*TraderRecord, *AIModelConfig, *ExchangeConfig, error) { - var trader TraderRecord - var aiModel AIModelConfig - var exchange ExchangeConfig - var traderCreatedAt, traderUpdatedAt string - var aiModelCreatedAt, aiModelUpdatedAt string - var exchangeCreatedAt, exchangeUpdatedAt string - - err := d.db.QueryRow(` - SELECT - t.id, t.user_id, t.name, t.ai_model_id, t.exchange_id, t.initial_balance, t.scan_interval_minutes, t.is_running, - COALESCE(t.btc_eth_leverage, 5) as btc_eth_leverage, - COALESCE(t.altcoin_leverage, 5) as altcoin_leverage, - COALESCE(t.trading_symbols, '') as trading_symbols, - COALESCE(t.use_coin_pool, 0) as use_coin_pool, - COALESCE(t.use_oi_top, 0) as use_oi_top, - COALESCE(t.custom_prompt, '') as custom_prompt, - COALESCE(t.override_base_prompt, 0) as override_base_prompt, - COALESCE(t.system_prompt_template, 'default') as system_prompt_template, - COALESCE(t.is_cross_margin, 1) as is_cross_margin, - t.created_at, t.updated_at, - a.id, a.user_id, a.name, a.provider, a.enabled, a.api_key, - COALESCE(a.custom_api_url, '') as custom_api_url, - COALESCE(a.custom_model_name, '') as custom_model_name, - a.created_at, a.updated_at, - e.id, e.user_id, e.name, e.type, e.enabled, e.api_key, e.secret_key, e.testnet, - COALESCE(e.hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr, - COALESCE(e.aster_user, '') as aster_user, - COALESCE(e.aster_signer, '') as aster_signer, - COALESCE(e.aster_private_key, '') as aster_private_key, - COALESCE(e.lighter_wallet_addr, '') as lighter_wallet_addr, - COALESCE(e.lighter_private_key, '') as lighter_private_key, - COALESCE(e.lighter_api_key_private_key, '') as lighter_api_key_private_key, - e.created_at, e.updated_at - FROM traders t - JOIN ai_models a ON t.ai_model_id = a.id AND t.user_id = a.user_id - JOIN exchanges e ON t.exchange_id = e.id AND t.user_id = e.user_id - WHERE t.id = ? AND t.user_id = ? - `, traderID, userID).Scan( - &trader.ID, &trader.UserID, &trader.Name, &trader.AIModelID, &trader.ExchangeID, - &trader.InitialBalance, &trader.ScanIntervalMinutes, &trader.IsRunning, - &trader.BTCETHLeverage, &trader.AltcoinLeverage, &trader.TradingSymbols, - &trader.UseCoinPool, &trader.UseOITop, - &trader.CustomPrompt, &trader.OverrideBasePrompt, &trader.SystemPromptTemplate, - &trader.IsCrossMargin, - &traderCreatedAt, &traderUpdatedAt, - &aiModel.ID, &aiModel.UserID, &aiModel.Name, &aiModel.Provider, &aiModel.Enabled, &aiModel.APIKey, - &aiModel.CustomAPIURL, &aiModel.CustomModelName, - &aiModelCreatedAt, &aiModelUpdatedAt, - &exchange.ID, &exchange.UserID, &exchange.Name, &exchange.Type, &exchange.Enabled, - &exchange.APIKey, &exchange.SecretKey, &exchange.Testnet, - &exchange.HyperliquidWalletAddr, &exchange.AsterUser, &exchange.AsterSigner, &exchange.AsterPrivateKey, - &exchange.LighterWalletAddr, &exchange.LighterPrivateKey, &exchange.LighterAPIKeyPrivateKey, - &exchangeCreatedAt, &exchangeUpdatedAt, - ) - - if err != nil { - return nil, nil, nil, err - } - - // 解析时间字符串 - trader.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", traderCreatedAt) - trader.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", traderUpdatedAt) - aiModel.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelCreatedAt) - aiModel.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelUpdatedAt) - exchange.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeCreatedAt) - exchange.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeUpdatedAt) - - // 解密敏感数据 - aiModel.APIKey = d.decryptSensitiveData(aiModel.APIKey) - exchange.APIKey = d.decryptSensitiveData(exchange.APIKey) - exchange.SecretKey = d.decryptSensitiveData(exchange.SecretKey) - exchange.AsterPrivateKey = d.decryptSensitiveData(exchange.AsterPrivateKey) - exchange.LighterPrivateKey = d.decryptSensitiveData(exchange.LighterPrivateKey) - exchange.LighterAPIKeyPrivateKey = d.decryptSensitiveData(exchange.LighterAPIKeyPrivateKey) - - return &trader, &aiModel, &exchange, nil -} - -// GetSystemConfig 获取系统配置 -func (d *Database) GetSystemConfig(key string) (string, error) { - var value string - err := d.db.QueryRow(`SELECT value FROM system_config WHERE key = ?`, key).Scan(&value) - return value, err -} - -// SetSystemConfig 设置系统配置 -func (d *Database) SetSystemConfig(key, value string) error { - _, err := d.db.Exec(` - INSERT OR REPLACE INTO system_config (key, value) VALUES (?, ?) - `, key, value) - return err -} - -// CreateUserSignalSource 创建用户信号源配置 -func (d *Database) CreateUserSignalSource(userID, coinPoolURL, oiTopURL string) error { - _, err := d.db.Exec(` - INSERT OR REPLACE INTO user_signal_sources (user_id, coin_pool_url, oi_top_url, updated_at) - VALUES (?, ?, ?, CURRENT_TIMESTAMP) - `, userID, coinPoolURL, oiTopURL) - return err -} - -// GetUserSignalSource 获取用户信号源配置 -func (d *Database) GetUserSignalSource(userID string) (*UserSignalSource, error) { - var source UserSignalSource - var createdAt, updatedAt string - err := d.db.QueryRow(` - SELECT id, user_id, coin_pool_url, oi_top_url, created_at, updated_at - FROM user_signal_sources WHERE user_id = ? - `, userID).Scan( - &source.ID, &source.UserID, &source.CoinPoolURL, &source.OITopURL, - &createdAt, &updatedAt, - ) - if err != nil { - return nil, err - } - source.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) - source.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) - return &source, nil -} - -// UpdateUserSignalSource 更新用户信号源配置 -func (d *Database) UpdateUserSignalSource(userID, coinPoolURL, oiTopURL string) error { - _, err := d.db.Exec(` - UPDATE user_signal_sources SET coin_pool_url = ?, oi_top_url = ?, updated_at = CURRENT_TIMESTAMP - WHERE user_id = ? - `, coinPoolURL, oiTopURL, userID) - return err -} - -// GetCustomCoins 获取所有交易员自定义币种 / Get all trader-customized currencies -func (d *Database) GetCustomCoins() []string { - var symbol string - var symbols []string - _ = d.db.QueryRow(` - SELECT GROUP_CONCAT(custom_coins , ',') as symbol - FROM main.traders where custom_coins != '' - `).Scan(&symbol) - // 检测用户是否未配置币种 - 兼容性 - if symbol == "" { - symbolJSON, _ := d.GetSystemConfig("default_coins") - if err := json.Unmarshal([]byte(symbolJSON), &symbols); err != nil { - log.Printf("⚠️ 解析default_coins配置失败: %v,使用硬编码默认值", err) - symbols = []string{"BTCUSDT", "ETHUSDT", "SOLUSDT", "BNBUSDT"} - } - } - // filter Symbol - for _, s := range strings.Split(symbol, ",") { - if s == "" { - continue - } - coin := market.Normalize(s) - if !slices.Contains(symbols, coin) { - symbols = append(symbols, coin) - } - } - return symbols -} - -// Close 关闭数据库连接 -// Conn 返回底层 *sql.DB,供需要执行自定义查询的模块使用。 -func (d *Database) Conn() *sql.DB { - return d.db -} - -func (d *Database) Close() error { - return d.db.Close() -} - -// LoadBetaCodesFromFile 从文件加载内测码到数据库 -func (d *Database) LoadBetaCodesFromFile(filePath string) error { - // 读取文件内容 - content, err := os.ReadFile(filePath) - if err != nil { - return fmt.Errorf("读取内测码文件失败: %w", err) - } - - // 按行分割内测码 - lines := strings.Split(string(content), "\n") - var codes []string - for _, line := range lines { - code := strings.TrimSpace(line) - if code != "" && !strings.HasPrefix(code, "#") { - codes = append(codes, code) - } - } - - // 批量插入内测码 - tx, err := d.db.Begin() - if err != nil { - return fmt.Errorf("开始事务失败: %w", err) - } - defer tx.Rollback() - - stmt, err := tx.Prepare(`INSERT OR IGNORE INTO beta_codes (code) VALUES (?)`) - if err != nil { - return fmt.Errorf("准备语句失败: %w", err) - } - defer stmt.Close() - - insertedCount := 0 - for _, code := range codes { - result, err := stmt.Exec(code) - if err != nil { - log.Printf("插入内测码 %s 失败: %v", code, err) - continue - } - - if rowsAffected, _ := result.RowsAffected(); rowsAffected > 0 { - insertedCount++ - } - } - - if err := tx.Commit(); err != nil { - return fmt.Errorf("提交事务失败: %w", err) - } - - log.Printf("✅ 成功加载 %d 个内测码到数据库 (总计 %d 个)", insertedCount, len(codes)) - return nil -} - -// ValidateBetaCode 验证内测码是否有效且未使用 -func (d *Database) ValidateBetaCode(code string) (bool, error) { - var used bool - err := d.db.QueryRow(`SELECT used FROM beta_codes WHERE code = ?`, code).Scan(&used) - if err != nil { - if err == sql.ErrNoRows { - return false, nil // 内测码不存在 - } - return false, err - } - return !used, nil // 内测码存在且未使用 -} - -// UseBetaCode 使用内测码(标记为已使用) -func (d *Database) UseBetaCode(code, userEmail string) error { - result, err := d.db.Exec(` - UPDATE beta_codes SET used = 1, used_by = ?, used_at = CURRENT_TIMESTAMP - WHERE code = ? AND used = 0 - `, userEmail, code) - if err != nil { - return err - } - - rowsAffected, err := result.RowsAffected() - if err != nil { - return err - } - - if rowsAffected == 0 { - return fmt.Errorf("内测码无效或已被使用") - } - - return nil -} - -// GetBetaCodeStats 获取内测码统计信息 -func (d *Database) GetBetaCodeStats() (total, used int, err error) { - err = d.db.QueryRow(`SELECT COUNT(*) FROM beta_codes`).Scan(&total) - if err != nil { - return 0, 0, err - } - - err = d.db.QueryRow(`SELECT COUNT(*) FROM beta_codes WHERE used = 1`).Scan(&used) - if err != nil { - return 0, 0, err - } - - return total, used, nil -} - -// SetCryptoService 设置加密服务 -func (d *Database) SetCryptoService(cs *crypto.CryptoService) { - d.cryptoService = cs -} - -// encryptSensitiveData 加密敏感数据用于存储 -func (d *Database) encryptSensitiveData(plaintext string) string { - if d.cryptoService == nil || plaintext == "" { - return plaintext - } - - encrypted, err := d.cryptoService.EncryptForStorage(plaintext) - if err != nil { - log.Printf("⚠️ 加密失败: %v", err) - return plaintext // 返回明文作为降级处理 - } - - return encrypted -} - -// decryptSensitiveData 解密敏感数据 -func (d *Database) decryptSensitiveData(encrypted string) string { - if d.cryptoService == nil || encrypted == "" { - return encrypted - } - - // 如果不是加密格式,直接返回 - if !d.cryptoService.IsEncryptedStorageValue(encrypted) { - return encrypted - } - - decrypted, err := d.cryptoService.DecryptFromStorage(encrypted) - if err != nil { - log.Printf("⚠️ 解密失败: %v", err) - return encrypted // 返回加密文本作为降级处理 - } - - return decrypted -} diff --git a/config/database_test.go b/config/database_test.go deleted file mode 100644 index b3a009d8..00000000 --- a/config/database_test.go +++ /dev/null @@ -1,850 +0,0 @@ -package config - -import ( - "nofx/crypto" - "os" - "testing" - "time" -) - -// TestUpdateExchange_EmptyValuesShouldNotOverwrite 测试空值不应覆盖现有数据 -// 这是 Bug 的核心:当前实现会用空字符串覆盖现有的私钥 -func TestUpdateExchange_EmptyValuesShouldNotOverwrite(t *testing.T) { - // 准备测试数据库 - db, cleanup := setupTestDB(t) - defer cleanup() - - userID := "test-user-001" - - // 步骤 1: 创建初始配置(包含私钥) - initialAPIKey := "initial-api-key-12345" - initialSecretKey := "initial-secret-key-67890" - - err := db.UpdateExchange( - userID, - "hyperliquid", - true, // enabled - initialAPIKey, - initialSecretKey, - false, // testnet - "0xWalletAddress", - "", - "", - "", - "", // lighter_wallet_addr - "", // lighter_private_key - ) - if err != nil { - t.Fatalf("初始化失败: %v", err) - } - - // 步骤 2: 验证初始数据已保存 - exchanges, err := db.GetExchanges(userID) - if err != nil { - t.Fatalf("获取配置失败: %v", err) - } - if len(exchanges) == 0 { - t.Fatal("未找到配置") - } - - // 解密后应该能看到原始值 - if exchanges[0].APIKey != initialAPIKey { - t.Errorf("初始 APIKey 不正确,期望 %s,实际 %s", initialAPIKey, exchanges[0].APIKey) - } - - // 步骤 3: 用空值更新(模拟前端发送空值的场景) - // 🐛 Bug 重现:这应该 NOT 覆盖现有的私钥,但当前实现会覆盖 - err = db.UpdateExchange( - userID, - "hyperliquid", - false, // 只改变 enabled 状态 - "", // 空 apiKey - 不应该覆盖 - "", // 空 secretKey - 不应该覆盖 - true, // 改变 testnet 状态 - "0xWalletAddress", - "", - "", - "", // 空 aster_private_key - 不应该覆盖 - "", - "", - ) - if err != nil { - t.Fatalf("更新失败: %v", err) - } - - // 步骤 4: 验证私钥没有被空值覆盖 - exchanges, err = db.GetExchanges(userID) - if err != nil { - t.Fatalf("获取更新后配置失败: %v", err) - } - - // 🎯 关键断言:私钥应该保持不变 - if exchanges[0].APIKey != initialAPIKey { - t.Errorf("❌ Bug 确认:APIKey 被空值覆盖了!期望 %s,实际 %s", initialAPIKey, exchanges[0].APIKey) - } - if exchanges[0].SecretKey != initialSecretKey { - t.Errorf("❌ Bug 确认:SecretKey 被空值覆盖了!期望 %s,实际 %s", initialSecretKey, exchanges[0].SecretKey) - } - - // 验证非敏感字段正常更新 - if exchanges[0].Enabled { - t.Error("enabled 应该被更新为 false") - } - if !exchanges[0].Testnet { - t.Error("testnet 应该被更新为 true") - } -} - -// TestUpdateExchange_AsterEmptyValuesShouldNotOverwrite 测试 Aster 私钥不被空值覆盖 -func TestUpdateExchange_AsterEmptyValuesShouldNotOverwrite(t *testing.T) { - db, cleanup := setupTestDB(t) - defer cleanup() - - userID := "test-user-002" - - // 步骤 1: 创建 Aster 配置 - initialAsterKey := "aster-private-key-xyz123" - - err := db.UpdateExchange( - userID, - "aster", - true, - "", - "", - false, - "", - "0xAsterUser", - "0xAsterSigner", - initialAsterKey, - "", - "", - ) - if err != nil { - t.Fatalf("初始化 Aster 失败: %v", err) - } - - // 步骤 2: 用空值更新 - err = db.UpdateExchange( - userID, - "aster", - false, // 只改 enabled - "", - "", - false, - "", - "0xAsterUser", - "0xAsterSigner", - "", // 空 aster_private_key - "", - "", - ) - if err != nil { - t.Fatalf("更新失败: %v", err) - } - - // 步骤 3: 验证 aster_private_key 没有被覆盖 - exchanges, err := db.GetExchanges(userID) - if err != nil { - t.Fatalf("获取配置失败: %v", err) - } - - if exchanges[0].AsterPrivateKey != initialAsterKey { - t.Errorf("❌ Bug 确认:AsterPrivateKey 被空值覆盖了!期望 %s,实际 %s", initialAsterKey, exchanges[0].AsterPrivateKey) - } -} - -// TestUpdateExchange_NonEmptyValuesShouldUpdate 测试非空值应该正常更新 -func TestUpdateExchange_NonEmptyValuesShouldUpdate(t *testing.T) { - db, cleanup := setupTestDB(t) - defer cleanup() - - userID := "test-user-003" - - // 步骤 1: 创建初始配置 - err := db.UpdateExchange( - userID, - "hyperliquid", - true, - "old-api-key", - "old-secret-key", - false, - "0xOldWallet", - "", - "", - "", - "", - "", - ) - if err != nil { - t.Fatalf("初始化失败: %v", err) - } - - // 步骤 2: 用非空值更新 - newAPIKey := "new-api-key-456" - newSecretKey := "new-secret-key-789" - - err = db.UpdateExchange( - userID, - "hyperliquid", - true, - newAPIKey, - newSecretKey, - false, - "0xNewWallet", - "", - "", - "", - "", - "", - ) - if err != nil { - t.Fatalf("更新失败: %v", err) - } - - // 步骤 3: 验证新值已更新 - exchanges, err := db.GetExchanges(userID) - if err != nil { - t.Fatalf("获取配置失败: %v", err) - } - - if exchanges[0].APIKey != newAPIKey { - t.Errorf("APIKey 未更新,期望 %s,实际 %s", newAPIKey, exchanges[0].APIKey) - } - if exchanges[0].SecretKey != newSecretKey { - t.Errorf("SecretKey 未更新,期望 %s,实际 %s", newSecretKey, exchanges[0].SecretKey) - } - if exchanges[0].HyperliquidWalletAddr != "0xNewWallet" { - t.Errorf("WalletAddr 未更新") - } -} - -// TestUpdateExchange_PartialUpdateShouldWork 测试部分字段更新 -func TestUpdateExchange_PartialUpdateShouldWork(t *testing.T) { - db, cleanup := setupTestDB(t) - defer cleanup() - - userID := "test-user-005" - - // 创建初始配置 - err := db.UpdateExchange( - userID, - "hyperliquid", - true, - "api-key-123", - "secret-key-456", - false, - "0xWallet1", - "", - "", - "", - "", - "", - ) - if err != nil { - t.Fatalf("初始化失败: %v", err) - } - - // 只更新 enabled 和 testnet,私钥留空 - err = db.UpdateExchange( - userID, - "hyperliquid", - false, - "", // 留空 - "", // 留空 - true, - "0xWallet2", - "", - "", - "", - "", - "", - ) - if err != nil { - t.Fatalf("部分更新失败: %v", err) - } - - // 验证 - exchanges, err := db.GetExchanges(userID) - if err != nil { - t.Fatalf("获取配置失败: %v", err) - } - - // 私钥应该保持不变 - if exchanges[0].APIKey != "api-key-123" { - t.Errorf("APIKey 不应改变,期望 api-key-123,实际 %s", exchanges[0].APIKey) - } - if exchanges[0].SecretKey != "secret-key-456" { - t.Errorf("SecretKey 不应改变,期望 secret-key-456,实际 %s", exchanges[0].SecretKey) - } - - // 其他字段应该更新 - if exchanges[0].Enabled { - t.Error("enabled 应该更新为 false") - } - if !exchanges[0].Testnet { - t.Error("testnet 应该更新为 true") - } - if exchanges[0].HyperliquidWalletAddr != "0xWallet2" { - t.Error("wallet 地址应该更新") - } -} - -// TestUpdateExchange_MultipleExchangeTypes 测试不同交易所类型 -func TestUpdateExchange_MultipleExchangeTypes(t *testing.T) { - db, cleanup := setupTestDB(t) - defer cleanup() - - userID := "test-user-006" - - testCases := []struct { - exchangeID string - name string - typ string - }{ - {"binance", "Binance Futures", "cex"}, - {"hyperliquid", "Hyperliquid", "dex"}, - {"aster", "Aster DEX", "dex"}, - {"unknown-exchange", "unknown-exchange Exchange", "cex"}, - } - - for _, tc := range testCases { - t.Run(tc.exchangeID, func(t *testing.T) { - err := db.UpdateExchange( - userID, - tc.exchangeID, - true, - "api-key-"+tc.exchangeID, - "secret-key-"+tc.exchangeID, - false, - "", - "", - "", - "", - "", - "", - ) - if err != nil { - t.Fatalf("创建 %s 失败: %v", tc.exchangeID, err) - } - - // 验证创建成功 - exchanges, err := db.GetExchanges(userID) - if err != nil { - t.Fatalf("获取配置失败: %v", err) - } - - found := false - for _, ex := range exchanges { - if ex.ID == tc.exchangeID { - found = true - if ex.Name != tc.name { - t.Errorf("交易所名称不正确,期望 %s,实际 %s", tc.name, ex.Name) - } - if ex.Type != tc.typ { - t.Errorf("交易所类型不正确,期望 %s,实际 %s", tc.typ, ex.Type) - } - if ex.APIKey != "api-key-"+tc.exchangeID { - t.Errorf("APIKey 不正确") - } - break - } - } - - if !found { - t.Errorf("未找到交易所 %s", tc.exchangeID) - } - }) - } -} - -// TestUpdateExchange_MixedSensitiveFields 测试混合更新敏感和非敏感字段 -func TestUpdateExchange_MixedSensitiveFields(t *testing.T) { - db, cleanup := setupTestDB(t) - defer cleanup() - - userID := "test-user-007" - - // 创建初始配置 - err := db.UpdateExchange( - userID, - "hyperliquid", - true, - "old-api-key", - "old-secret-key", - false, - "0xOldWallet", - "", - "", - "", - "", - "", - ) - if err != nil { - t.Fatalf("初始化失败: %v", err) - } - - // 场景1: 只更新 apiKey,secretKey 留空 - err = db.UpdateExchange( - userID, - "hyperliquid", - false, - "new-api-key", - "", // 留空 - true, - "0xNewWallet", - "", - "", - "", - "", - "", - ) - if err != nil { - t.Fatalf("更新1失败: %v", err) - } - - exchanges, _ := db.GetExchanges(userID) - if exchanges[0].APIKey != "new-api-key" { - t.Error("APIKey 应该更新") - } - if exchanges[0].SecretKey != "old-secret-key" { - t.Error("SecretKey 应该保持不变") - } - - // 场景2: 只更新 secretKey,apiKey 留空 - err = db.UpdateExchange( - userID, - "hyperliquid", - true, - "", // 留空 - "new-secret-key", - false, - "0xFinalWallet", - "", - "", - "", - "", - "", - ) - if err != nil { - t.Fatalf("更新2失败: %v", err) - } - - exchanges, _ = db.GetExchanges(userID) - if exchanges[0].APIKey != "new-api-key" { - t.Error("APIKey 应该保持不变") - } - if exchanges[0].SecretKey != "new-secret-key" { - t.Error("SecretKey 应该更新") - } - if exchanges[0].Enabled != true { - t.Error("Enabled 应该更新为 true") - } - if exchanges[0].HyperliquidWalletAddr != "0xFinalWallet" { - t.Error("WalletAddr 应该更新") - } -} - -// TestUpdateExchange_OnlyNonSensitiveFields 测试只更新非敏感字段 -func TestUpdateExchange_OnlyNonSensitiveFields(t *testing.T) { - db, cleanup := setupTestDB(t) - defer cleanup() - - userID := "test-user-008" - - // 创建初始配置(包含所有私钥) - err := db.UpdateExchange( - userID, - "aster", - true, - "binance-api", - "binance-secret", - false, - "", - "0xUser1", - "0xSigner1", - "aster-private-key-1", - "", - "", - ) - if err != nil { - t.Fatalf("初始化失败: %v", err) - } - - // 只更新非敏感字段(所有私钥字段留空) - err = db.UpdateExchange( - userID, - "aster", - false, - "", - "", - true, - "", - "0xUser2", - "0xSigner2", - "", - "", - "", - ) - if err != nil { - t.Fatalf("更新失败: %v", err) - } - - // 验证所有私钥保持不变 - exchanges, _ := db.GetExchanges(userID) - if exchanges[0].APIKey != "binance-api" { - t.Errorf("APIKey 应该保持不变,实际 %s", exchanges[0].APIKey) - } - if exchanges[0].SecretKey != "binance-secret" { - t.Errorf("SecretKey 应该保持不变,实际 %s", exchanges[0].SecretKey) - } - if exchanges[0].AsterPrivateKey != "aster-private-key-1" { - t.Errorf("AsterPrivateKey 应该保持不变,实际 %s", exchanges[0].AsterPrivateKey) - } - - // 验证非敏感字段已更新 - if exchanges[0].Enabled != false { - t.Error("Enabled 应该更新为 false") - } - if exchanges[0].Testnet != true { - t.Error("Testnet 应该更新为 true") - } - if exchanges[0].AsterUser != "0xUser2" { - t.Error("AsterUser 应该更新") - } - if exchanges[0].AsterSigner != "0xSigner2" { - t.Error("AsterSigner 应该更新") - } -} - -// TestUpdateExchange_AllSensitiveFieldsUpdate 测试同时更新所有敏感字段 -func TestUpdateExchange_AllSensitiveFieldsUpdate(t *testing.T) { - db, cleanup := setupTestDB(t) - defer cleanup() - - userID := "test-user-009" - - // 创建初始配置 - err := db.UpdateExchange( - userID, - "binance", - true, - "old-api", - "old-secret", - false, - "", - "", - "", - "old-aster-key", - "", - "", - ) - if err != nil { - t.Fatalf("初始化失败: %v", err) - } - - // 同时更新所有敏感字段 - err = db.UpdateExchange( - userID, - "binance", - false, - "new-api", - "new-secret", - true, - "0xWallet", - "0xUser", - "0xSigner", - "new-aster-key", - "", - "", - ) - if err != nil { - t.Fatalf("更新失败: %v", err) - } - - // 验证所有字段都更新了 - exchanges, _ := db.GetExchanges(userID) - if exchanges[0].APIKey != "new-api" { - t.Error("APIKey 应该更新") - } - if exchanges[0].SecretKey != "new-secret" { - t.Error("SecretKey 应该更新") - } - if exchanges[0].AsterPrivateKey != "new-aster-key" { - t.Error("AsterPrivateKey 应该更新") - } - if !exchanges[0].Testnet { - t.Error("Testnet 应该更新为 true") - } -} - -// setupTestDB 创建测试数据库 -func setupTestDB(t *testing.T) (*Database, func()) { - // 创建临时数据库文件 - tmpFile := t.TempDir() + "/test.db" - - db, err := NewDatabase(tmpFile) - if err != nil { - t.Fatalf("创建测试数据库失败: %v", err) - } - - // 创建测试用户 - testUsers := []string{ - "test-user-001", "test-user-002", "test-user-003", "test-user-004", "test-user-005", - "test-user-006", "test-user-007", "test-user-008", "test-user-009", - "test-user-persistence", "user1", "user2", - } - for _, userID := range testUsers { - user := &User{ - ID: userID, - Email: userID + "@test.com", - PasswordHash: "hash", - OTPSecret: "", - OTPVerified: false, - } - _ = db.CreateUser(user) - } - - // 设置加密服务(用于测试加密功能) - // 创建临时 RSA 密钥 - rsaKeyPath := t.TempDir() + "/test_rsa_key" - cryptoService, err := crypto.NewCryptoService(rsaKeyPath) - if err != nil { - // 如果创建失败,继续测试但不使用加密 - t.Logf("警告:无法创建加密服务,将在无加密模式下测试: %v", err) - } else { - db.SetCryptoService(cryptoService) - } - - cleanup := func() { - db.Close() - os.RemoveAll(tmpFile) - os.RemoveAll(rsaKeyPath) - } - - return db, cleanup -} - -// TestWALModeEnabled 测试 WAL 模式是否启用 -// TDD: 这个测试应该失败,因为当前代码没有启用 WAL 模式 -func TestWALModeEnabled(t *testing.T) { - db, cleanup := setupTestDB(t) - defer cleanup() - - // 查询当前的 journal_mode - var journalMode string - err := db.db.QueryRow("PRAGMA journal_mode").Scan(&journalMode) - if err != nil { - t.Fatalf("查询 journal_mode 失败: %v", err) - } - - // 期望是 WAL 模式 - if journalMode != "wal" { - t.Errorf("期望 journal_mode=wal,实际是 %s", journalMode) - } -} - -// TestSynchronousMode 测试 synchronous 模式设置 -// TDD: 验证数据持久性设置 -func TestSynchronousMode(t *testing.T) { - db, cleanup := setupTestDB(t) - defer cleanup() - - // 查询 synchronous 设置 - var synchronous int - err := db.db.QueryRow("PRAGMA synchronous").Scan(&synchronous) - if err != nil { - t.Fatalf("查询 synchronous 失败: %v", err) - } - - // 期望是 FULL (2) 以确保数据持久性 - if synchronous != 2 { - t.Errorf("期望 synchronous=2 (FULL),实际是 %d", synchronous) - } -} - -// TestDataPersistenceAcrossReopen 测试数据在数据库关闭并重新打开后是否持久化 -// TDD: 模拟 Docker restart 场景 -func TestDataPersistenceAcrossReopen(t *testing.T) { - // 创建临时数据库文件 - tmpFile, err := os.CreateTemp("", "test_persistence_*.db") - if err != nil { - t.Fatalf("创建临时文件失败: %v", err) - } - tmpFile.Close() - dbPath := tmpFile.Name() - defer os.Remove(dbPath) - - // 设置加密服务 - rsaKeyPath := "test_rsa_key.pem" - cryptoService, err := crypto.NewCryptoService(rsaKeyPath) - if err != nil { - t.Fatalf("初始化加密服务失败: %v", err) - } - defer os.RemoveAll(rsaKeyPath) - - userID := "test-user-persistence" - testAPIKey := "test-api-key-should-persist" - testSecretKey := "test-secret-key-should-persist" - - // 第一次打开数据库并写入数据 - { - db, err := NewDatabase(dbPath) - if err != nil { - t.Fatalf("第一次创建数据库失败: %v", err) - } - db.SetCryptoService(cryptoService) - - // 创建持久化测试用户,避免外键约束失败 - _ = db.CreateUser(&User{ - ID: userID, - Email: userID + "@test.com", - PasswordHash: "hash", - OTPSecret: "", - OTPVerified: true, - }) - - // 写入交易所配置 - err = db.UpdateExchange( - userID, - "binance", - true, - testAPIKey, - testSecretKey, - false, - "", - "", - "", - "", - "", - "", - ) - if err != nil { - t.Fatalf("写入数据失败: %v", err) - } - - // 模拟正常关闭 - if err := db.Close(); err != nil { - t.Fatalf("关闭数据库失败: %v", err) - } - } - - // 第二次打开数据库并验证数据是否还在 - { - db, err := NewDatabase(dbPath) - if err != nil { - t.Fatalf("第二次打开数据库失败: %v", err) - } - db.SetCryptoService(cryptoService) - defer db.Close() - - // 读取数据 - exchanges, err := db.GetExchanges(userID) - if err != nil { - t.Fatalf("读取数据失败: %v", err) - } - - if len(exchanges) == 0 { - t.Fatal("数据丢失:没有找到任何交易所配置") - } - - // 验证数据完整性 - found := false - for _, ex := range exchanges { - if ex.ID == "binance" { - found = true - if ex.APIKey != testAPIKey { - t.Errorf("API Key 丢失或损坏,期望 %s,实际 %s", testAPIKey, ex.APIKey) - } - if ex.SecretKey != testSecretKey { - t.Errorf("Secret Key 丢失或损坏,期望 %s,实际 %s", testSecretKey, ex.SecretKey) - } - } - } - - if !found { - t.Error("数据丢失:找不到 binance 配置") - } - } -} - -// TestConcurrentWritesWithWAL 测试 WAL 模式下的并发写入 -// TDD: WAL 模式应该支持更好的并发性能 -func TestConcurrentWritesWithWAL(t *testing.T) { - db, cleanup := setupTestDB(t) - defer cleanup() - - // 这个测试验证多个并发写入可以成功 - // WAL 模式下并发性能更好,但 SQLite 仍然可能出现短暂的锁 - done := make(chan bool, 2) - errors := make(chan error, 10) - - // 并发写入1 - go func() { - for i := 0; i < 3; i++ { - err := db.UpdateExchange( - "user1", - "binance", - true, - "key1", - "secret1", - false, - "", - "", - "", - "", - "", - "", - ) - if err != nil { - errors <- err - } - // 小延迟减少锁冲突 - time.Sleep(10 * time.Millisecond) - } - done <- true - }() - - // 并发写入2 - go func() { - for i := 0; i < 3; i++ { - err := db.UpdateExchange( - "user2", - "hyperliquid", - true, - "key2", - "secret2", - false, - "0xWallet", - "", - "", - "", - "", - "", - ) - if err != nil { - errors <- err - } - // 小延迟减少锁冲突 - time.Sleep(10 * time.Millisecond) - } - done <- true - }() - - // 等待两个 goroutine 完成 - <-done - <-done - close(errors) - - // 检查是否有错误 - errorCount := 0 - for err := range errors { - t.Logf("并发写入错误: %v", err) - errorCount++ - } - - // WAL 模式下应该能处理并发,但可能有少量锁错误 - // 我们允许最多 2 个错误 - if errorCount > 2 { - t.Errorf("并发写入失败次数过多: %d", errorCount) - } -} diff --git a/config/test_rsa_key.pem.pub b/config/test_rsa_key.pem.pub deleted file mode 100644 index a9f89eeb..00000000 --- a/config/test_rsa_key.pem.pub +++ /dev/null @@ -1,9 +0,0 @@ ------BEGIN PUBLIC KEY----- -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4Y666RzY5LLi6PiYL+vC -7+fcr122Fd8BC7IdqUSYKQ33Nsi9J7J5fDgcMf7ZAnIBpxMV7+e1KEoiwtGmxwHj -mYo0ZV0E6JXdiK26S052+Shquri0IXkwGFraDuNKqmGrj6vZuXtq2L2gdSyZCxrI -veN9g6LxBvLBP1Rx7UEmZeyokRYvChcxAQXuS/0br44BOHGtwAElk6AGLISz55AG -oM40b3ktiza+8THKMz3GiylQQYpBltbM3yAXPlnXJ2MtUZiaHNhEQI4++PMvEErN -Izm8cIgcvUAXJ5vBfa4kD0kSgBJFuEQ2im3qcWTuEPRKztEeJDY7XAVHc1Xy6d4N -vQIDAQAB ------END PUBLIC KEY----- diff --git a/crypto/crypto.go b/crypto/crypto.go index df543efb..9c33cfe6 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -13,10 +13,7 @@ import ( "encoding/pem" "errors" "fmt" - "io/ioutil" - "log" "os" - "path/filepath" "strings" "time" ) @@ -24,8 +21,12 @@ import ( const ( storagePrefix = "ENC:v1:" storageDelimiter = ":" - dataKeyEnvName = "DATA_ENCRYPTION_KEY" - dataKeyFilePath = "secrets/data_key" +) + +// 环境变量名称 +const ( + EnvDataEncryptionKey = "DATA_ENCRYPTION_KEY" // AES 数据加密密钥 (Base64) + EnvRSAPrivateKey = "RSA_PRIVATE_KEY" // RSA 私钥 (PEM 格式,换行用 \n) ) type EncryptedPayload struct { @@ -50,29 +51,18 @@ type CryptoService struct { dataKey []byte } -func NewCryptoService(privateKeyPath string) (*CryptoService, error) { - // 读取私钥文件 - privateKeyPEM, err := ioutil.ReadFile(privateKeyPath) +// NewCryptoService 创建加密服务(从环境变量加载密钥) +func NewCryptoService() (*CryptoService, error) { + // 1. 加载 RSA 私钥 + privateKey, err := loadRSAPrivateKeyFromEnv() if err != nil { - // 如果私钥文件不存在,生成新的密钥对 - if err := GenerateRSAKeyPair(privateKeyPath); err != nil { - return nil, fmt.Errorf("failed to generate RSA key pair: %w", err) - } - privateKeyPEM, err = ioutil.ReadFile(privateKeyPath) - if err != nil { - return nil, fmt.Errorf("failed to read generated private key: %w", err) - } + return nil, fmt.Errorf("RSA 私钥加载失败: %w", err) } - // 解析私钥 - privateKey, err := ParseRSAPrivateKeyFromPEM(privateKeyPEM) + // 2. 加载 AES 数据加密密钥 + dataKey, err := loadDataKeyFromEnv() if err != nil { - return nil, fmt.Errorf("failed to parse private key: %w", err) - } - - dataKey, err := resolveDataKey() - if err != nil { - return nil, fmt.Errorf("failed to load data encryption key: %w", err) + return nil, fmt.Errorf("数据加密密钥加载失败: %w", err) } return &CryptoService{ @@ -82,56 +72,43 @@ func NewCryptoService(privateKeyPath string) (*CryptoService, error) { }, nil } -func GenerateRSAKeyPair(privateKeyPath string) error { - // 确保目录存在 - dir := filepath.Dir(privateKeyPath) - if dir != "." { - if err := os.MkdirAll(dir, 0700); err != nil { - return fmt.Errorf("failed to create directory %s: %w", dir, err) - } +// loadRSAPrivateKeyFromEnv 从环境变量加载 RSA 私钥 +func loadRSAPrivateKeyFromEnv() (*rsa.PrivateKey, error) { + keyPEM := os.Getenv(EnvRSAPrivateKey) + if keyPEM == "" { + return nil, fmt.Errorf("环境变量 %s 未设置,请在 .env 中配置 RSA 私钥", EnvRSAPrivateKey) } - // 生成 RSA 密钥对 - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return err - } + // 处理环境变量中的换行符(\n -> 实际换行) + keyPEM = strings.ReplaceAll(keyPEM, "\\n", "\n") - // 编码私钥 - privateKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(privateKey), - }) - - // 保存私钥 - if err := ioutil.WriteFile(privateKeyPath, privateKeyPEM, 0600); err != nil { - return err - } - - // 编码公钥 - publicKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) - if err != nil { - return err - } - - publicKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "PUBLIC KEY", - Bytes: publicKeyDER, - }) - - // 保存公钥 - publicKeyPath := privateKeyPath + ".pub" - if err := ioutil.WriteFile(publicKeyPath, publicKeyPEM, 0644); err != nil { - return err - } - - return nil + return ParseRSAPrivateKeyFromPEM([]byte(keyPEM)) } +// loadDataKeyFromEnv 从环境变量加载 AES 数据加密密钥 +func loadDataKeyFromEnv() ([]byte, error) { + keyStr := strings.TrimSpace(os.Getenv(EnvDataEncryptionKey)) + if keyStr == "" { + return nil, fmt.Errorf("环境变量 %s 未设置,请在 .env 中配置数据加密密钥", EnvDataEncryptionKey) + } + + // 尝试解码 + if key, ok := decodePossibleKey(keyStr); ok { + return key, nil + } + + // 如果无法解码,使用 SHA256 哈希作为密钥 + sum := sha256.Sum256([]byte(keyStr)) + key := make([]byte, len(sum)) + copy(key, sum[:]) + return key, nil +} + +// ParseRSAPrivateKeyFromPEM 解析 PEM 格式的 RSA 私钥 func ParseRSAPrivateKeyFromPEM(pemBytes []byte) (*rsa.PrivateKey, error) { block, _ := pem.Decode(pemBytes) if block == nil { - return nil, errors.New("no PEM block found") + return nil, errors.New("无效的 PEM 格式") } switch block.Type { @@ -144,100 +121,15 @@ func ParseRSAPrivateKeyFromPEM(pemBytes []byte) (*rsa.PrivateKey, error) { } rsaKey, ok := key.(*rsa.PrivateKey) if !ok { - return nil, errors.New("not an RSA key") + return nil, errors.New("不是 RSA 密钥") } return rsaKey, nil default: - return nil, errors.New("unsupported key type: " + block.Type) + return nil, errors.New("不支持的密钥类型: " + block.Type) } } -func resolveDataKey() ([]byte, error) { - if key, ok := loadDataKeyFromEnv(); ok { - return key, nil - } - - key, _, err := loadOrCreateDataKeyFile(dataKeyFilePath) - return key, err -} - -func loadDataKeyFromEnv() ([]byte, bool) { - keyStr := strings.TrimSpace(os.Getenv(dataKeyEnvName)) - if keyStr == "" { - return nil, false - } - - if key, ok := decodePossibleKey(keyStr); ok { - return key, true - } - - sum := sha256.Sum256([]byte(keyStr)) - key := make([]byte, len(sum)) - copy(key, sum[:]) - return key, true -} - -var errInvalidDataKeyMaterial = errors.New("invalid data encryption key material") - -func loadOrCreateDataKeyFile(path string) ([]byte, bool, error) { - key, err := readDataKeyFromFile(path) - if err == nil { - log.Printf("🔐 使用本地数据加密密钥: %s", path) - return key, false, nil - } - - if !errors.Is(err, os.ErrNotExist) && !errors.Is(err, errInvalidDataKeyMaterial) { - log.Printf("⚠️ 无法读取数据加密密钥文件 (%s): %v,尝试重新生成", path, err) - } - - key, err = generateAndPersistDataKey(path) - if err != nil { - return nil, false, err - } - return key, true, nil -} - -func readDataKeyFromFile(path string) ([]byte, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, err - } - - encoded := strings.TrimSpace(string(data)) - if encoded == "" { - return nil, errInvalidDataKeyMaterial - } - - if key, ok := decodePossibleKey(encoded); ok { - return key, nil - } - - return nil, errInvalidDataKeyMaterial -} - -func generateAndPersistDataKey(path string) ([]byte, error) { - raw := make([]byte, 32) - if _, err := rand.Read(raw); err != nil { - return nil, err - } - - dir := filepath.Dir(path) - if dir != "" && dir != "." { - if err := os.MkdirAll(dir, 0700); err != nil { - return nil, err - } - } - - encoded := base64.StdEncoding.EncodeToString(raw) - if err := os.WriteFile(path, []byte(encoded+"\n"), 0600); err != nil { - return nil, err - } - - log.Printf("🆕 已生成新的数据加密密钥并保存到 %s", path) - log.Printf(" 若需在生产或容器环境复用,请设置 %s 为该值", dataKeyEnvName) - return raw, nil -} - +// decodePossibleKey 尝试用多种编码方式解码密钥 func decodePossibleKey(value string) ([]byte, bool) { decoders := []func(string) ([]byte, error){ base64.StdEncoding.DecodeString, @@ -256,6 +148,7 @@ func decodePossibleKey(value string) ([]byte, bool) { return nil, false } +// normalizeAESKey 标准化 AES 密钥长度 func normalizeAESKey(raw []byte) ([]byte, bool) { switch len(raw) { case 16, 24, 32: @@ -293,7 +186,7 @@ func (cs *CryptoService) EncryptForStorage(plaintext string, aadParts ...string) return "", nil } if !cs.HasDataKey() { - return "", errors.New("data encryption key not configured") + return "", errors.New("数据加密密钥未配置") } if isEncryptedStorageValue(plaintext) { return plaintext, nil @@ -327,26 +220,26 @@ func (cs *CryptoService) DecryptFromStorage(value string, aadParts ...string) (s return "", nil } if !cs.HasDataKey() { - return "", errors.New("data encryption key not configured") + return "", errors.New("数据加密密钥未配置") } if !isEncryptedStorageValue(value) { - return "", errors.New("value is not encrypted") + return "", errors.New("数据未加密") } payload := strings.TrimPrefix(value, storagePrefix) parts := strings.SplitN(payload, storageDelimiter, 2) if len(parts) != 2 { - return "", errors.New("invalid encrypted payload format") + return "", errors.New("无效的加密数据格式") } nonce, err := base64.StdEncoding.DecodeString(parts[0]) if err != nil { - return "", fmt.Errorf("decode nonce failed: %w", err) + return "", fmt.Errorf("解码 nonce 失败: %w", err) } ciphertext, err := base64.StdEncoding.DecodeString(parts[1]) if err != nil { - return "", fmt.Errorf("decode ciphertext failed: %w", err) + return "", fmt.Errorf("解码密文失败: %w", err) } block, err := aes.NewCipher(cs.dataKey) @@ -360,13 +253,13 @@ func (cs *CryptoService) DecryptFromStorage(value string, aadParts ...string) (s } if len(nonce) != gcm.NonceSize() { - return "", fmt.Errorf("invalid nonce size: expected %d, got %d", gcm.NonceSize(), len(nonce)) + return "", fmt.Errorf("无效的 nonce 长度: 期望 %d, 实际 %d", gcm.NonceSize(), len(nonce)) } aad := composeAAD(aadParts) plaintext, err := gcm.Open(nil, nonce, ciphertext, aad) if err != nil { - return "", fmt.Errorf("decryption failed: %w", err) + return "", fmt.Errorf("解密失败: %w", err) } return string(plaintext), nil @@ -392,66 +285,63 @@ func (cs *CryptoService) DecryptPayload(payload *EncryptedPayload) ([]byte, erro if payload.TS != 0 { elapsed := time.Since(time.Unix(payload.TS, 0)) if elapsed > 5*time.Minute || elapsed < -1*time.Minute { - return nil, errors.New("timestamp invalid or expired") + return nil, errors.New("时间戳无效或已过期") } } // 2. 解码 base64url wrappedKey, err := base64.RawURLEncoding.DecodeString(payload.WrappedKey) if err != nil { - return nil, fmt.Errorf("failed to decode wrapped key: %w", err) + return nil, fmt.Errorf("解码 wrapped key 失败: %w", err) } iv, err := base64.RawURLEncoding.DecodeString(payload.IV) if err != nil { - return nil, fmt.Errorf("failed to decode IV: %w", err) + return nil, fmt.Errorf("解码 IV 失败: %w", err) } ciphertext, err := base64.RawURLEncoding.DecodeString(payload.Ciphertext) if err != nil { - return nil, fmt.Errorf("failed to decode ciphertext: %w", err) + return nil, fmt.Errorf("解码密文失败: %w", err) } var aad []byte if payload.AAD != "" { aad, err = base64.RawURLEncoding.DecodeString(payload.AAD) if err != nil { - return nil, fmt.Errorf("failed to decode AAD: %w", err) + return nil, fmt.Errorf("解码 AAD 失败: %w", err) } - // 验证 AAD var aadData AADData if err := json.Unmarshal(aad, &aadData); err == nil { // 可以在这里添加额外的验证逻辑 - // 例如:验证 sessionID、userID 等 } } // 3. 使用 RSA-OAEP 解密 AES 密钥 aesKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, cs.privateKey, wrappedKey, nil) if err != nil { - return nil, fmt.Errorf("failed to unwrap AES key: %w", err) + return nil, fmt.Errorf("RSA 解密失败: %w", err) } // 4. 使用 AES-GCM 解密数据 block, err := aes.NewCipher(aesKey) if err != nil { - return nil, fmt.Errorf("failed to create AES cipher: %w", err) + return nil, fmt.Errorf("创建 AES cipher 失败: %w", err) } gcm, err := cipher.NewGCM(block) if err != nil { - return nil, fmt.Errorf("failed to create GCM: %w", err) + return nil, fmt.Errorf("创建 GCM 失败: %w", err) } if len(iv) != gcm.NonceSize() { - return nil, fmt.Errorf("invalid IV size: expected %d, got %d", gcm.NonceSize(), len(iv)) + return nil, fmt.Errorf("无效的 IV 长度: 期望 %d, 实际 %d", gcm.NonceSize(), len(iv)) } - // 解密并验证认证标签 plaintext, err := gcm.Open(nil, iv, ciphertext, aad) if err != nil { - return nil, fmt.Errorf("authentication/decryption failed: %w", err) + return nil, fmt.Errorf("解密验证失败: %w", err) } return plaintext, nil @@ -464,3 +354,41 @@ func (cs *CryptoService) DecryptSensitiveData(payload *EncryptedPayload) (string } return string(plaintext), nil } + +// GenerateKeyPair 生成 RSA 密钥对(用于初始化时生成密钥) +// 返回 PEM 格式的私钥和公钥 +func GenerateKeyPair() (privateKeyPEM, publicKeyPEM string, err error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", "", err + } + + // 编码私钥 + privPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + // 编码公钥 + publicKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) + if err != nil { + return "", "", err + } + + pubPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyDER, + }) + + return string(privPEM), string(pubPEM), nil +} + +// GenerateDataKey 生成 AES 数据加密密钥 +// 返回 Base64 编码的 32 字节密钥 +func GenerateDataKey() (string, error) { + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(key), nil +} diff --git a/crypto/encryption.go b/crypto/encryption.go deleted file mode 100644 index 73d1b5ba..00000000 --- a/crypto/encryption.go +++ /dev/null @@ -1,373 +0,0 @@ -package crypto - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "crypto/x509" - "encoding/base64" - "encoding/binary" - "encoding/pem" - "errors" - "fmt" - "io" - "log" - "os" - "sync" -) - -// EncryptionManager 加密管理器(單例模式) -type EncryptionManager struct { - privateKey *rsa.PrivateKey - publicKeyPEM string - masterKey []byte // 用於數據庫加密的主密鑰 - mu sync.RWMutex -} - -var ( - instance *EncryptionManager - once sync.Once -) - -// GetEncryptionManager 獲取加密管理器實例 -func GetEncryptionManager() (*EncryptionManager, error) { - var initErr error - once.Do(func() { - instance, initErr = newEncryptionManager() - }) - return instance, initErr -} - -// newEncryptionManager 初始化加密管理器 -func newEncryptionManager() (*EncryptionManager, error) { - em := &EncryptionManager{} - - // 1. 加載或生成 RSA 密鑰對 - if err := em.loadOrGenerateRSAKeyPair(); err != nil { - return nil, fmt.Errorf("初始化 RSA 密鑰失敗: %w", err) - } - - // 2. 加載或生成數據庫主密鑰 - if err := em.loadOrGenerateMasterKey(); err != nil { - return nil, fmt.Errorf("初始化主密鑰失敗: %w", err) - } - - log.Println("🔐 加密管理器初始化成功") - return em, nil -} - -// ==================== RSA 密鑰管理 ==================== - -const ( - rsaKeySize = 4096 - rsaPrivateKeyFile = ".secrets/rsa_private.pem" - rsaPublicKeyFile = ".secrets/rsa_public.pem" - masterKeyFile = ".secrets/master.key" -) - -// loadOrGenerateRSAKeyPair 加載或生成 RSA 密鑰對 -func (em *EncryptionManager) loadOrGenerateRSAKeyPair() error { - // 確保 .secrets 目錄存在 - if err := os.MkdirAll(".secrets", 0700); err != nil { - return err - } - - // 嘗試加載現有密鑰 - if _, err := os.Stat(rsaPrivateKeyFile); err == nil { - return em.loadRSAKeyPair() - } - - // 生成新密鑰對 - log.Println("🔑 生成新的 RSA-4096 密鑰對...") - privateKey, err := rsa.GenerateKey(rand.Reader, rsaKeySize) - if err != nil { - return err - } - - em.privateKey = privateKey - - // 保存私鑰 - privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey) - privateKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: privateKeyBytes, - }) - if err := os.WriteFile(rsaPrivateKeyFile, privateKeyPEM, 0600); err != nil { - return err - } - - // 保存公鑰 - publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) - if err != nil { - return err - } - publicKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "PUBLIC KEY", - Bytes: publicKeyBytes, - }) - if err := os.WriteFile(rsaPublicKeyFile, publicKeyPEM, 0644); err != nil { - return err - } - - em.publicKeyPEM = string(publicKeyPEM) - log.Println("✅ RSA 密鑰對已生成並保存") - return nil -} - -// loadRSAKeyPair 加載 RSA 密鑰對 -func (em *EncryptionManager) loadRSAKeyPair() error { - // 加載私鑰 - privateKeyPEM, err := os.ReadFile(rsaPrivateKeyFile) - if err != nil { - return err - } - - block, _ := pem.Decode(privateKeyPEM) - if block == nil || block.Type != "RSA PRIVATE KEY" { - return errors.New("無效的私鑰 PEM 格式") - } - - privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - return err - } - em.privateKey = privateKey - - // 加載公鑰 - publicKeyPEM, err := os.ReadFile(rsaPublicKeyFile) - if err != nil { - return err - } - em.publicKeyPEM = string(publicKeyPEM) - - log.Println("✅ RSA 密鑰對已加載") - return nil -} - -// GetPublicKeyPEM 獲取公鑰 (PEM 格式) -func (em *EncryptionManager) GetPublicKeyPEM() string { - em.mu.RLock() - defer em.mu.RUnlock() - return em.publicKeyPEM -} - -// ==================== 混合解密 (RSA + AES) ==================== - -// DecryptWithPrivateKey 使用私鑰解密數據 -// 數據格式: [加密的 AES 密鑰長度(4字節)] + [加密的 AES 密鑰] + [IV(12字節)] + [加密數據] -func (em *EncryptionManager) DecryptWithPrivateKey(encryptedBase64 string) (string, error) { - em.mu.RLock() - defer em.mu.RUnlock() - - // Base64 解碼 - encryptedData, err := base64.StdEncoding.DecodeString(encryptedBase64) - if err != nil { - return "", fmt.Errorf("Base64 解碼失敗: %w", err) - } - - if len(encryptedData) < 4+256+12 { // 最小長度檢查 - return "", errors.New("加密數據長度不足") - } - - // 1. 讀取加密的 AES 密鑰長度 - aesKeyLen := binary.BigEndian.Uint32(encryptedData[:4]) - if aesKeyLen > 1024 { // 防止過大的長度值 - return "", errors.New("無效的 AES 密鑰長度") - } - - offset := 4 - // 2. 提取加密的 AES 密鑰 - encryptedAESKey := encryptedData[offset : offset+int(aesKeyLen)] - offset += int(aesKeyLen) - - // 3. 使用 RSA 私鑰解密 AES 密鑰 - aesKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, em.privateKey, encryptedAESKey, nil) - if err != nil { - return "", fmt.Errorf("RSA 解密失敗: %w", err) - } - - // 4. 提取 IV - iv := encryptedData[offset : offset+12] - offset += 12 - - // 5. 提取加密數據 - ciphertext := encryptedData[offset:] - - // 6. 使用 AES-GCM 解密 - block, err := aes.NewCipher(aesKey) - if err != nil { - return "", err - } - - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - - plaintext, err := aesGCM.Open(nil, iv, ciphertext, nil) - if err != nil { - return "", fmt.Errorf("AES 解密失敗: %w", err) - } - - // 清除敏感數據 - for i := range aesKey { - aesKey[i] = 0 - } - - return string(plaintext), nil -} - -// ==================== 數據庫加密 (AES-256-GCM) ==================== - -// loadOrGenerateMasterKey 加載或生成數據庫主密鑰 -func (em *EncryptionManager) loadOrGenerateMasterKey() error { - // 優先從環境變數加載 - if envKey := os.Getenv("NOFX_MASTER_KEY"); envKey != "" { - decoded, err := base64.StdEncoding.DecodeString(envKey) - if err == nil && len(decoded) == 32 { - em.masterKey = decoded - log.Println("✅ 從環境變數加載主密鑰") - return nil - } - log.Println("⚠️ 環境變數中的主密鑰無效,使用文件密鑰") - } - - // 嘗試從文件加載 - if _, err := os.Stat(masterKeyFile); err == nil { - keyBytes, err := os.ReadFile(masterKeyFile) - if err != nil { - return err - } - decoded, err := base64.StdEncoding.DecodeString(string(keyBytes)) - if err != nil || len(decoded) != 32 { - return errors.New("主密鑰文件損壞") - } - em.masterKey = decoded - log.Println("✅ 從文件加載主密鑰") - return nil - } - - // 生成新主密鑰 - log.Println("🔑 生成新的數據庫主密鑰 (AES-256)...") - masterKey := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, masterKey); err != nil { - return err - } - - em.masterKey = masterKey - - // 保存到文件 - encoded := base64.StdEncoding.EncodeToString(masterKey) - if err := os.WriteFile(masterKeyFile, []byte(encoded), 0600); err != nil { - return err - } - - log.Println("✅ 主密鑰已生成並保存") - log.Printf("📁 主密鑰文件位置: %s (權限: 0600)", masterKeyFile) - log.Println("🔐 生產環境請設置環境變數: NOFX_MASTER_KEY=<從文件讀取>") - log.Println("⚠️ 請妥善保管 .secrets 目錄,切勿將密鑰提交到版本控制系統") - return nil -} - -// EncryptForDatabase 使用主密鑰加密數據(用於數據庫存儲) -func (em *EncryptionManager) EncryptForDatabase(plaintext string) (string, error) { - em.mu.RLock() - defer em.mu.RUnlock() - - block, err := aes.NewCipher(em.masterKey) - if err != nil { - return "", err - } - - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - - nonce := make([]byte, aesGCM.NonceSize()) - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - return "", err - } - - ciphertext := aesGCM.Seal(nonce, nonce, []byte(plaintext), nil) - return base64.StdEncoding.EncodeToString(ciphertext), nil -} - -// DecryptFromDatabase 使用主密鑰解密數據(從數據庫讀取) -func (em *EncryptionManager) DecryptFromDatabase(encryptedBase64 string) (string, error) { - em.mu.RLock() - defer em.mu.RUnlock() - - // 處理空字符串(未加密的舊數據) - if encryptedBase64 == "" { - return "", nil - } - - ciphertext, err := base64.StdEncoding.DecodeString(encryptedBase64) - if err != nil { - return "", err - } - - block, err := aes.NewCipher(em.masterKey) - if err != nil { - return "", err - } - - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - - nonceSize := aesGCM.NonceSize() - if len(ciphertext) < nonceSize { - return "", errors.New("加密數據過短") - } - - nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] - plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil) - if err != nil { - return "", err - } - - return string(plaintext), nil -} - -// ==================== 密鑰輪換 ==================== - -// RotateMasterKey 輪換主密鑰(需要重新加密所有數據) -func (em *EncryptionManager) RotateMasterKey() error { - em.mu.Lock() - defer em.mu.Unlock() - - log.Println("🔄 開始輪換主密鑰...") - - // 生成新主密鑰 - newMasterKey := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, newMasterKey); err != nil { - return err - } - - // 備份舊密鑰 - oldMasterKey := em.masterKey - - // 更新密鑰 - em.masterKey = newMasterKey - - // 保存新密鑰 - encoded := base64.StdEncoding.EncodeToString(newMasterKey) - backupFile := fmt.Sprintf("%s.backup.%d", masterKeyFile, os.Getpid()) - if err := os.WriteFile(backupFile, []byte(base64.StdEncoding.EncodeToString(oldMasterKey)), 0600); err != nil { - return err - } - if err := os.WriteFile(masterKeyFile, []byte(encoded), 0600); err != nil { - return err - } - - log.Println("✅ 主密鑰已輪換") - log.Printf("⚠️ 舊密鑰已備份到: %s", backupFile) - log.Printf("🔐 新主密鑰: %s", encoded) - - return nil -} diff --git a/crypto/encryption_test.go b/crypto/encryption_test.go deleted file mode 100644 index 1e65a962..00000000 --- a/crypto/encryption_test.go +++ /dev/null @@ -1,159 +0,0 @@ -package crypto - -import ( - "testing" -) - -// TestRSAKeyPairGeneration 測試 RSA 密鑰對生成 -func TestRSAKeyPairGeneration(t *testing.T) { - em, err := GetEncryptionManager() - if err != nil { - t.Fatalf("初始化加密管理器失敗: %v", err) - } - - publicKey := em.GetPublicKeyPEM() - if publicKey == "" { - t.Fatal("公鑰為空") - } - - if len(publicKey) < 100 { - t.Fatal("公鑰長度異常") - } - - t.Logf("✅ RSA 密鑰對生成成功,公鑰長度: %d", len(publicKey)) -} - -// TestDatabaseEncryption 測試數據庫加密/解密 -func TestDatabaseEncryption(t *testing.T) { - em, err := GetEncryptionManager() - if err != nil { - t.Fatalf("初始化加密管理器失敗: %v", err) - } - - testCases := []string{ - "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", - "test_api_key_12345", - "very_secret_password", - "", - } - - for _, plaintext := range testCases { - // 加密 - encrypted, err := em.EncryptForDatabase(plaintext) - if err != nil { - t.Fatalf("加密失敗: %v (明文: %s)", err, plaintext) - } - - // 驗證加密後不等於明文 - if encrypted == plaintext && plaintext != "" { - t.Fatalf("加密失敗:加密後仍為明文") - } - - // 解密 - decrypted, err := em.DecryptFromDatabase(encrypted) - if err != nil { - t.Fatalf("解密失敗: %v (密文: %s)", err, encrypted) - } - - // 驗證解密後等於明文 - if decrypted != plaintext { - t.Fatalf("解密結果不匹配: 期望 %s, 得到 %s", plaintext, decrypted) - } - - t.Logf("✅ 加密/解密測試通過: %s", plaintext[:min(len(plaintext), 20)]) - } -} - -// TestHybridEncryption 測試混合加密(前端 → 後端場景) -func TestHybridEncryption(t *testing.T) { - _, err := GetEncryptionManager() - if err != nil { - t.Fatalf("初始化加密管理器失敗: %v", err) - } - // 模擬前端加密私鑰 - // plaintext := "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" - // 注意:這裡需要前端的 encryptWithServerPublicKey 實現 - // 為了測試,我們直接使用後端的加密函數(實際前端使用 Web Crypto API) - - // 由於前端加密邏輯較複雜,這裡僅測試解密流程 - // 實際測試需要端到端測試 - t.Log("⚠️ 混合加密測試需要完整的前後端環境,請執行端到端測試") -} - -// TestEmptyString 測試空字串處理 -func TestEmptyString(t *testing.T) { - em, err := GetEncryptionManager() - if err != nil { - t.Fatalf("初始化加密管理器失敗: %v", err) - } - - encrypted, err := em.EncryptForDatabase("") - if err != nil { - t.Fatalf("加密空字串失敗: %v", err) - } - - decrypted, err := em.DecryptFromDatabase(encrypted) - if err != nil { - t.Fatalf("解密空字串失敗: %v", err) - } - - if decrypted != "" { - t.Fatalf("空字串處理錯誤: 期望空字串, 得到 %s", decrypted) - } - - t.Log("✅ 空字串處理正確") -} - -// TestInvalidCiphertext 測試無效密文處理 -func TestInvalidCiphertext(t *testing.T) { - em, err := GetEncryptionManager() - if err != nil { - t.Fatalf("初始化加密管理器失敗: %v", err) - } - - invalidCiphertexts := []string{ - "not_base64!@#$%", - "dGVzdA==", // 有效 Base64,但內容太短 - "", - } - - for _, ciphertext := range invalidCiphertexts { - _, err := em.DecryptFromDatabase(ciphertext) - if err == nil && ciphertext != "" { - t.Fatalf("應該拒絕無效密文: %s", ciphertext) - } - } - - t.Log("✅ 無效密文處理正確") -} - -// BenchmarkEncryption 性能測試:加密 -func BenchmarkEncryption(b *testing.B) { - em, _ := GetEncryptionManager() - plaintext := "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = em.EncryptForDatabase(plaintext) - } -} - -// BenchmarkDecryption 性能測試:解密 -func BenchmarkDecryption(b *testing.B) { - em, _ := GetEncryptionManager() - plaintext := "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" - encrypted, _ := em.EncryptForDatabase(plaintext) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = em.DecryptFromDatabase(encrypted) - } -} - -// min 工具函數 -func min(a, b int) int { - if a < b { - return a - } - return b -} diff --git a/crypto/secure_storage.go b/crypto/secure_storage.go deleted file mode 100644 index b168f9f8..00000000 --- a/crypto/secure_storage.go +++ /dev/null @@ -1,302 +0,0 @@ -package crypto - -import ( - "database/sql" - "fmt" - "log" - "time" -) - -// SecureStorage 安全存儲層(自動加密/解密數據庫中的敏感字段) -type SecureStorage struct { - db *sql.DB - em *EncryptionManager -} - -// NewSecureStorage 創建安全存儲實例 -func NewSecureStorage(db *sql.DB) (*SecureStorage, error) { - em, err := GetEncryptionManager() - if err != nil { - return nil, err - } - - ss := &SecureStorage{ - db: db, - em: em, - } - - // 初始化審計日誌表 - if err := ss.initAuditLog(); err != nil { - return nil, fmt.Errorf("初始化審計日誌失敗: %w", err) - } - - return ss, nil -} - -// ==================== 交易所配置加密存儲 ==================== - -// SaveEncryptedExchangeConfig 保存加密的交易所配置 -func (ss *SecureStorage) SaveEncryptedExchangeConfig(userID, exchangeID, apiKey, secretKey, asterPrivateKey string) error { - // 加密敏感字段 - encryptedAPIKey, err := ss.em.EncryptForDatabase(apiKey) - if err != nil { - return fmt.Errorf("加密 API Key 失敗: %w", err) - } - - encryptedSecretKey, err := ss.em.EncryptForDatabase(secretKey) - if err != nil { - return fmt.Errorf("加密 Secret Key 失敗: %w", err) - } - - encryptedPrivateKey := "" - if asterPrivateKey != "" { - encryptedPrivateKey, err = ss.em.EncryptForDatabase(asterPrivateKey) - if err != nil { - return fmt.Errorf("加密 Private Key 失敗: %w", err) - } - } - - // 更新數據庫 - _, err = ss.db.Exec(` - UPDATE exchanges - SET api_key = ?, secret_key = ?, aster_private_key = ?, updated_at = datetime('now') - WHERE user_id = ? AND id = ? - `, encryptedAPIKey, encryptedSecretKey, encryptedPrivateKey, userID, exchangeID) - - if err != nil { - return err - } - - // 記錄審計日誌 - ss.logAudit(userID, "exchange_config_update", exchangeID, "密鑰已更新") - - log.Printf("🔐 [%s] 交易所 %s 的密鑰已加密保存", userID, exchangeID) - return nil -} - -// LoadDecryptedExchangeConfig 加載並解密交易所配置 -func (ss *SecureStorage) LoadDecryptedExchangeConfig(userID, exchangeID string) (apiKey, secretKey, asterPrivateKey string, err error) { - var encryptedAPIKey, encryptedSecretKey, encryptedPrivateKey sql.NullString - - err = ss.db.QueryRow(` - SELECT api_key, secret_key, aster_private_key - FROM exchanges - WHERE user_id = ? AND id = ? - `, userID, exchangeID).Scan(&encryptedAPIKey, &encryptedSecretKey, &encryptedPrivateKey) - - if err != nil { - return "", "", "", err - } - - // 解密 API Key - if encryptedAPIKey.Valid && encryptedAPIKey.String != "" { - apiKey, err = ss.em.DecryptFromDatabase(encryptedAPIKey.String) - if err != nil { - return "", "", "", fmt.Errorf("解密 API Key 失敗: %w", err) - } - } - - // 解密 Secret Key - if encryptedSecretKey.Valid && encryptedSecretKey.String != "" { - secretKey, err = ss.em.DecryptFromDatabase(encryptedSecretKey.String) - if err != nil { - return "", "", "", fmt.Errorf("解密 Secret Key 失敗: %w", err) - } - } - - // 解密 Private Key - if encryptedPrivateKey.Valid && encryptedPrivateKey.String != "" { - asterPrivateKey, err = ss.em.DecryptFromDatabase(encryptedPrivateKey.String) - if err != nil { - return "", "", "", fmt.Errorf("解密 Private Key 失敗: %w", err) - } - } - - // 記錄審計日誌 - ss.logAudit(userID, "exchange_config_read", exchangeID, "密鑰已讀取") - - return apiKey, secretKey, asterPrivateKey, nil -} - -// ==================== AI 模型配置加密存儲 ==================== - -// SaveEncryptedAIModelConfig 保存加密的 AI 模型 API Key -func (ss *SecureStorage) SaveEncryptedAIModelConfig(userID, modelID, apiKey string) error { - encryptedAPIKey, err := ss.em.EncryptForDatabase(apiKey) - if err != nil { - return fmt.Errorf("加密 API Key 失敗: %w", err) - } - - _, err = ss.db.Exec(` - UPDATE ai_models - SET api_key = ?, updated_at = datetime('now') - WHERE user_id = ? AND id = ? - `, encryptedAPIKey, userID, modelID) - - if err != nil { - return err - } - - ss.logAudit(userID, "ai_model_config_update", modelID, "API Key 已更新") - log.Printf("🔐 [%s] AI 模型 %s 的 API Key 已加密保存", userID, modelID) - return nil -} - -// LoadDecryptedAIModelConfig 加載並解密 AI 模型配置 -func (ss *SecureStorage) LoadDecryptedAIModelConfig(userID, modelID string) (string, error) { - var encryptedAPIKey sql.NullString - - err := ss.db.QueryRow(` - SELECT api_key FROM ai_models WHERE user_id = ? AND id = ? - `, userID, modelID).Scan(&encryptedAPIKey) - - if err != nil { - return "", err - } - - if !encryptedAPIKey.Valid || encryptedAPIKey.String == "" { - return "", nil - } - - apiKey, err := ss.em.DecryptFromDatabase(encryptedAPIKey.String) - if err != nil { - return "", fmt.Errorf("解密 API Key 失敗: %w", err) - } - - ss.logAudit(userID, "ai_model_config_read", modelID, "API Key 已讀取") - return apiKey, nil -} - -// ==================== 審計日誌 ==================== - -// initAuditLog 初始化審計日誌表 -func (ss *SecureStorage) initAuditLog() error { - _, err := ss.db.Exec(` - CREATE TABLE IF NOT EXISTS audit_logs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id TEXT NOT NULL, - action TEXT NOT NULL, - resource TEXT NOT NULL, - details TEXT, - ip_address TEXT, - user_agent TEXT, - timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, - INDEX idx_user_time (user_id, timestamp), - INDEX idx_action (action) - ) - `) - return err -} - -// logAudit 記錄審計日誌 -func (ss *SecureStorage) logAudit(userID, action, resource, details string) { - _, err := ss.db.Exec(` - INSERT INTO audit_logs (user_id, action, resource, details) - VALUES (?, ?, ?, ?) - `, userID, action, resource, details) - - if err != nil { - log.Printf("⚠️ 審計日誌記錄失敗: %v", err) - } -} - -// GetAuditLogs 查詢審計日誌 -func (ss *SecureStorage) GetAuditLogs(userID string, limit int) ([]AuditLog, error) { - rows, err := ss.db.Query(` - SELECT id, user_id, action, resource, details, timestamp - FROM audit_logs - WHERE user_id = ? - ORDER BY timestamp DESC - LIMIT ? - `, userID, limit) - - if err != nil { - return nil, err - } - defer rows.Close() - - var logs []AuditLog - for rows.Next() { - var log AuditLog - err := rows.Scan(&log.ID, &log.UserID, &log.Action, &log.Resource, &log.Details, &log.Timestamp) - if err != nil { - return nil, err - } - logs = append(logs, log) - } - - return logs, nil -} - -// AuditLog 審計日誌結構 -type AuditLog struct { - ID int64 `json:"id"` - UserID string `json:"user_id"` - Action string `json:"action"` - Resource string `json:"resource"` - Details string `json:"details"` - Timestamp time.Time `json:"timestamp"` -} - -// ==================== 數據遷移工具 ==================== - -// MigrateToEncrypted 將舊的明文數據遷移到加密格式 -func (ss *SecureStorage) MigrateToEncrypted() error { - log.Println("🔄 開始遷移明文數據到加密格式...") - - tx, err := ss.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - // 遷移交易所配置 - rows, err := tx.Query(` - SELECT user_id, id, api_key, secret_key, aster_private_key - FROM exchanges - WHERE api_key != '' AND api_key NOT LIKE '%==%' -- 過濾已加密數據 - `) - if err != nil { - return err - } - - var count int - for rows.Next() { - var userID, exchangeID, apiKey, secretKey string - var asterPrivateKey sql.NullString - if err := rows.Scan(&userID, &exchangeID, &apiKey, &secretKey, &asterPrivateKey); err != nil { - rows.Close() - return err - } - - // 加密 - encAPIKey, _ := ss.em.EncryptForDatabase(apiKey) - encSecretKey, _ := ss.em.EncryptForDatabase(secretKey) - encPrivateKey := "" - if asterPrivateKey.Valid && asterPrivateKey.String != "" { - encPrivateKey, _ = ss.em.EncryptForDatabase(asterPrivateKey.String) - } - - // 更新 - _, err = tx.Exec(` - UPDATE exchanges - SET api_key = ?, secret_key = ?, aster_private_key = ? - WHERE user_id = ? AND id = ? - `, encAPIKey, encSecretKey, encPrivateKey, userID, exchangeID) - - if err != nil { - rows.Close() - return err - } - - count++ - } - rows.Close() - - if err := tx.Commit(); err != nil { - return err - } - - log.Printf("✅ 已遷移 %d 個交易所配置到加密格式", count) - return nil -} diff --git a/decision/engine.go b/decision/engine.go index 2f96a15e..470d43a9 100644 --- a/decision/engine.go +++ b/decision/engine.go @@ -3,7 +3,7 @@ package decision import ( "encoding/json" "fmt" - "log" + "nofx/logger" "math" "nofx/market" "nofx/mcp" @@ -72,6 +72,29 @@ type OITopData struct { NetShort float64 // 净空仓 } +// TradingStats 交易统计(用于AI输入) +type TradingStats struct { + TotalTrades int `json:"total_trades"` // 总交易数(已平仓) + WinRate float64 `json:"win_rate"` // 胜率 (%) + ProfitFactor float64 `json:"profit_factor"` // 盈亏比 + SharpeRatio float64 `json:"sharpe_ratio"` // 夏普比 + TotalPnL float64 `json:"total_pnl"` // 总盈亏 + AvgWin float64 `json:"avg_win"` // 平均盈利 + AvgLoss float64 `json:"avg_loss"` // 平均亏损 + MaxDrawdownPct float64 `json:"max_drawdown_pct"` // 最大回撤 (%) +} + +// RecentOrder 最近完成的订单(用于AI输入) +type RecentOrder struct { + Symbol string `json:"symbol"` // 交易对 + Side string `json:"side"` // long/short + EntryPrice float64 `json:"entry_price"` // 开仓价 + ExitPrice float64 `json:"exit_price"` // 平仓价 + RealizedPnL float64 `json:"realized_pnl"` // 已实现盈亏 + PnLPct float64 `json:"pnl_pct"` // 盈亏百分比 + FilledAt string `json:"filled_at"` // 成交时间 +} + // Context 交易上下文(传递给AI的完整信息) type Context struct { CurrentTime string `json:"current_time"` @@ -81,10 +104,11 @@ type Context struct { Positions []PositionInfo `json:"positions"` CandidateCoins []CandidateCoin `json:"candidate_coins"` PromptVariant string `json:"prompt_variant,omitempty"` - MarketDataMap map[string]*market.Data `json:"-"` // 不序列化,但内部使用 + TradingStats *TradingStats `json:"trading_stats,omitempty"` // 交易统计指标 + RecentOrders []RecentOrder `json:"recent_orders,omitempty"` // 最近完成的订单(10条) + MarketDataMap map[string]*market.Data `json:"-"` // 不序列化,但内部使用 MultiTFMarket map[string]map[string]*market.Data `json:"-"` OITopDataMap map[string]*OITopData `json:"-"` // OI Top数据映射 - Performance interface{} `json:"-"` // 历史表现分析(logger.PerformanceAnalysis) BTCETHLeverage int `json:"-"` // BTC/ETH杠杆倍数(从配置读取) AltcoinLeverage int `json:"-"` // 山寨币杠杆倍数(从配置读取) } @@ -92,7 +116,7 @@ type Context struct { // Decision AI的交易决策 type Decision struct { Symbol string `json:"symbol"` - Action string `json:"action"` // "open_long", "open_short", "close_long", "close_short", "update_stop_loss", "update_take_profit", "partial_close", "hold", "wait" + Action string `json:"action"` // "open_long", "open_short", "close_long", "close_short", "hold", "wait" // 开仓参数 Leverage int `json:"leverage,omitempty"` @@ -100,11 +124,6 @@ type Decision struct { StopLoss float64 `json:"stop_loss,omitempty"` TakeProfit float64 `json:"take_profit,omitempty"` - // 调整参数(新增) - NewStopLoss float64 `json:"new_stop_loss,omitempty"` // 用于 update_stop_loss - NewTakeProfit float64 `json:"new_take_profit,omitempty"` // 用于 update_take_profit - ClosePercentage float64 `json:"close_percentage,omitempty"` // 用于 partial_close (0-100) - // 通用参数 Confidence int `json:"confidence,omitempty"` // 信心度 (0-100) RiskUSD float64 `json:"risk_usd,omitempty"` // 最大美元风险 @@ -232,7 +251,7 @@ func fetchMarketDataForContext(ctx *Context) error { oiValue := data.OpenInterest.Latest * data.CurrentPrice oiValueInMillions := oiValue / 1_000_000 // 转换为百万美元单位 if oiValueInMillions < minOIThresholdMillions { - log.Printf("⚠️ %s 持仓价值过低(%.2fM USD < %.1fM),跳过此币种 [持仓量:%.0f × 价格:%.4f]", + logger.Infof("⚠️ %s 持仓价值过低(%.2fM USD < %.1fM),跳过此币种 [持仓量:%.0f × 价格:%.4f]", symbol, oiValueInMillions, minOIThresholdMillions, data.OpenInterest.Latest, data.CurrentPrice) continue } @@ -329,11 +348,11 @@ func buildSystemPrompt(accountEquity float64, btcEthLeverage, altcoinLeverage in template, err := GetPromptTemplate(templateName) if err != nil { // 如果模板不存在,记录错误并使用 default - log.Printf("⚠️ 提示词模板 '%s' 不存在,使用 default: %v", templateName, err) + logger.Infof("⚠️ 提示词模板 '%s' 不存在,使用 default: %v", templateName, err) template, err = GetPromptTemplate("default") if err != nil { // 如果连 default 都不存在,使用内置的简化版本 - log.Printf("❌ 无法加载任何提示词模板,使用内置简化版本") + logger.Infof("❌ 无法加载任何提示词模板,使用内置简化版本") sb.WriteString("你是专业的加密货币交易AI。请根据市场数据做出交易决策。\n\n") } else { sb.WriteString(template.Content) @@ -379,19 +398,11 @@ func buildSystemPrompt(accountEquity float64, btcEthLeverage, altcoinLeverage in sb.WriteString("- AI500 / OI_Top 筛选标签(若有)\n\n") sb.WriteString("自由运用任何有效的分析方法,但**信心度 ≥75** 才能开仓;避免单一指标、信号矛盾、横盘震荡、刚平仓即重启等低质量行为。\n\n") - // 5. 夏普比率驱动的自适应 - sb.WriteString("# 🧬 夏普比率自我进化\n\n") - sb.WriteString("- Sharpe < -0.5:立即停止交易,至少观望6个周期并深度复盘\n") - sb.WriteString("- -0.5 ~ 0:只做信心度>80的交易,并降低频率\n") - sb.WriteString("- 0 ~ 0.7:保持当前策略\n") - sb.WriteString("- >0.7:允许适度加仓,但仍遵守风控\n\n") - - // 6. 决策流程提示 + // 5. 决策流程提示 sb.WriteString("# 📋 决策流程\n\n") - sb.WriteString("1. 回顾夏普比率/盈亏 → 是否需要降频或暂停\n") - sb.WriteString("2. 检查持仓 → 是否该止盈/止损/调整\n") - sb.WriteString("3. 扫描候选币 + 多时间框 → 是否存在强信号\n") - sb.WriteString("4. 先写思维链,再输出结构化JSON\n\n") + sb.WriteString("1. 检查持仓 → 是否该止盈/止损\n") + sb.WriteString("2. 扫描候选币 + 多时间框 → 是否存在强信号\n") + sb.WriteString("3. 先写思维链,再输出结构化JSON\n\n") // 7. 输出格式 - 动态生成 sb.WriteString("# 输出格式 (严格遵守)\n\n") @@ -405,17 +416,13 @@ func buildSystemPrompt(accountEquity float64, btcEthLeverage, altcoinLeverage in sb.WriteString("第二步: JSON决策数组\n\n") sb.WriteString("```json\n[\n") sb.WriteString(fmt.Sprintf(" {\"symbol\": \"BTCUSDT\", \"action\": \"open_short\", \"leverage\": %d, \"position_size_usd\": %.0f, \"stop_loss\": 97000, \"take_profit\": 91000, \"confidence\": 85, \"risk_usd\": 300},\n", btcEthLeverage, accountEquity*5)) - sb.WriteString(" {\"symbol\": \"SOLUSDT\", \"action\": \"update_stop_loss\", \"new_stop_loss\": 155},\n") sb.WriteString(" {\"symbol\": \"ETHUSDT\", \"action\": \"close_long\"}\n") sb.WriteString("]\n```\n") sb.WriteString("\n\n") sb.WriteString("## 字段说明\n\n") - sb.WriteString("- `action`: open_long | open_short | close_long | close_short | update_stop_loss | update_take_profit | partial_close | hold | wait\n") + sb.WriteString("- `action`: open_long | open_short | close_long | close_short | hold | wait\n") sb.WriteString("- `confidence`: 0-100(开仓建议≥75)\n") - sb.WriteString("- 开仓时必填: leverage, position_size_usd, stop_loss, take_profit, confidence, risk_usd\n") - sb.WriteString("- update_stop_loss 时必填: new_stop_loss (注意是 new_stop_loss,不是 stop_loss)\n") - sb.WriteString("- update_take_profit 时必填: new_take_profit (注意是 new_take_profit,不是 take_profit)\n") - sb.WriteString("- partial_close 时必填: close_percentage (0-100)\n\n") + sb.WriteString("- 开仓时必填: leverage, position_size_usd, stop_loss, take_profit, confidence, risk_usd\n\n") return sb.String() } @@ -462,7 +469,7 @@ func buildUserPrompt(ctx *Context) string { } } - // 计算仓位价值(用于 partial_close 检查) + // 计算仓位价值 positionValue := math.Abs(pos.Quantity) * pos.MarkPrice sb.WriteString(fmt.Sprintf("%d. %s %s | 入场价%.4f 当前价%.4f | 数量%.4f | 仓位价值%.2f USDT | 盈亏%+.2f%% | 盈亏金额%+.2f USDT | 最高收益率%.2f%% | 杠杆%dx | 保证金%.0f | 强平价%.4f%s\n\n", @@ -480,6 +487,38 @@ func buildUserPrompt(ctx *Context) string { sb.WriteString("当前持仓: 无\n\n") } + // 交易统计(如果有) + if ctx.TradingStats != nil && ctx.TradingStats.TotalTrades > 0 { + sb.WriteString("## 历史交易统计\n") + sb.WriteString(fmt.Sprintf("总交易数: %d | 胜率: %.1f%% | 盈亏比: %.2f | 夏普比: %.2f\n", + ctx.TradingStats.TotalTrades, + ctx.TradingStats.WinRate, + ctx.TradingStats.ProfitFactor, + ctx.TradingStats.SharpeRatio)) + sb.WriteString(fmt.Sprintf("总盈亏: %.2f USDT | 平均盈利: %.2f | 平均亏损: %.2f | 最大回撤: %.1f%%\n\n", + ctx.TradingStats.TotalPnL, + ctx.TradingStats.AvgWin, + ctx.TradingStats.AvgLoss, + ctx.TradingStats.MaxDrawdownPct)) + } + + // 最近完成的订单(如果有) + if len(ctx.RecentOrders) > 0 { + sb.WriteString("## 最近完成的交易\n") + for i, order := range ctx.RecentOrders { + resultStr := "盈利" + if order.RealizedPnL < 0 { + resultStr = "亏损" + } + sb.WriteString(fmt.Sprintf("%d. %s %s | 入场%.4f 出场%.4f | %s: %+.2f USDT (%+.2f%%) | %s\n", + i+1, order.Symbol, order.Side, + order.EntryPrice, order.ExitPrice, + resultStr, order.RealizedPnL, order.PnLPct, + order.FilledAt)) + } + sb.WriteString("\n") + } + // 候选币种(完整市场数据) sb.WriteString(fmt.Sprintf("## 候选币种 (%d个)\n\n", len(ctx.MarketDataMap))) displayedCount := 0 @@ -504,20 +543,6 @@ func buildUserPrompt(ctx *Context) string { } sb.WriteString("\n") - // 夏普比率(直接传值,不要复杂格式化) - if ctx.Performance != nil { - // 直接从interface{}中提取SharpeRatio - type PerformanceData struct { - SharpeRatio float64 `json:"sharpe_ratio"` - } - var perfData PerformanceData - if jsonData, err := json.Marshal(ctx.Performance); err == nil { - if err := json.Unmarshal(jsonData, &perfData); err == nil { - sb.WriteString(fmt.Sprintf("## 📊 夏普比率: %.2f\n\n", perfData.SharpeRatio)) - } - } - } - sb.WriteString("---\n\n") sb.WriteString("现在请分析并输出决策(思维链 + JSON)\n") @@ -556,20 +581,20 @@ func parseFullDecisionResponse(aiResponse string, accountEquity float64, btcEthL func extractCoTTrace(response string) string { // 方法1: 优先尝试提取 标签内容 if match := reReasoningTag.FindStringSubmatch(response); match != nil && len(match) > 1 { - log.Printf("✓ 使用 标签提取思维链") + logger.Infof("✓ 使用 标签提取思维链") return strings.TrimSpace(match[1]) } // 方法2: 如果没有 标签,但有 标签,提取 之前的内容 if decisionIdx := strings.Index(response, ""); decisionIdx > 0 { - log.Printf("✓ 提取 标签之前的内容作为思维链") + logger.Infof("✓ 提取 标签之前的内容作为思维链") return strings.TrimSpace(response[:decisionIdx]) } // 方法3: 后备方案 - 查找JSON数组的开始位置 jsonStart := strings.Index(response, "[") if jsonStart > 0 { - log.Printf("⚠️ 使用旧版格式([ 字符分离)提取思维链") + logger.Infof("⚠️ 使用旧版格式([ 字符分离)提取思维链") return strings.TrimSpace(response[:jsonStart]) } @@ -591,11 +616,11 @@ func extractDecisions(response string) ([]Decision, error) { var jsonPart string if match := reDecisionTag.FindStringSubmatch(s); match != nil && len(match) > 1 { jsonPart = strings.TrimSpace(match[1]) - log.Printf("✓ 使用 标签提取JSON") + logger.Infof("✓ 使用 标签提取JSON") } else { // 后备方案:使用整个响应 jsonPart = s - log.Printf("⚠️ 未找到 标签,使用全文搜索JSON") + logger.Infof("⚠️ 未找到 标签,使用全文搜索JSON") } // 修复 jsonPart 中的全角字符 @@ -621,7 +646,7 @@ func extractDecisions(response string) ([]Decision, error) { jsonContent := strings.TrimSpace(reJSONArray.FindString(jsonPart)) if jsonContent == "" { // 🔧 安全回退 (Safe Fallback):当AI只输出思维链没有JSON时,生成保底决策(避免系统崩溃) - log.Printf("⚠️ [SafeFallback] AI未输出JSON决策,进入安全等待模式 (AI response without JSON, entering safe wait mode)") + logger.Infof("⚠️ [SafeFallback] AI未输出JSON决策,进入安全等待模式 (AI response without JSON, entering safe wait mode)") // 提取思维链摘要(最多 240 字符) cotSummary := jsonPart @@ -773,15 +798,12 @@ func findMatchingBracket(s string, start int) int { func validateDecision(d *Decision, accountEquity float64, btcEthLeverage, altcoinLeverage int) error { // 验证action validActions := map[string]bool{ - "open_long": true, - "open_short": true, - "close_long": true, - "close_short": true, - "update_stop_loss": true, - "update_take_profit": true, - "partial_close": true, - "hold": true, - "wait": true, + "open_long": true, + "open_short": true, + "close_long": true, + "close_short": true, + "hold": true, + "wait": true, } if !validActions[d.Action] { @@ -803,7 +825,7 @@ func validateDecision(d *Decision, accountEquity float64, btcEthLeverage, altcoi return fmt.Errorf("杠杆必须大于0: %d", d.Leverage) } if d.Leverage > maxLeverage { - log.Printf("⚠️ [Leverage Fallback] %s 杠杆超限 (%dx > %dx),自动调整为上限值 %dx", + logger.Infof("⚠️ [Leverage Fallback] %s 杠杆超限 (%dx > %dx),自动调整为上限值 %dx", d.Symbol, d.Leverage, maxLeverage, maxLeverage) d.Leverage = maxLeverage // 自动修正为上限值 } @@ -883,26 +905,5 @@ func validateDecision(d *Decision, accountEquity float64, btcEthLeverage, altcoi } } - // 动态调整止损验证 - if d.Action == "update_stop_loss" { - if d.NewStopLoss <= 0 { - return fmt.Errorf("新止损价格必须大于0: %.2f", d.NewStopLoss) - } - } - - // 动态调整止盈验证 - if d.Action == "update_take_profit" { - if d.NewTakeProfit <= 0 { - return fmt.Errorf("新止盈价格必须大于0: %.2f", d.NewTakeProfit) - } - } - - // 部分平仓验证 - if d.Action == "partial_close" { - if d.ClosePercentage <= 0 || d.ClosePercentage > 100 { - return fmt.Errorf("平仓百分比必须在0-100之间: %.1f", d.ClosePercentage) - } - } - return nil } diff --git a/decision/prompt_test.go b/decision/prompt_test.go index 21c64830..69bec67f 100644 --- a/decision/prompt_test.go +++ b/decision/prompt_test.go @@ -13,9 +13,6 @@ func TestBuildSystemPrompt_ContainsAllValidActions(t *testing.T) { "open_short", "close_long", "close_short", - "update_stop_loss", - "update_take_profit", - "partial_close", "hold", "wait", } @@ -30,21 +27,3 @@ func TestBuildSystemPrompt_ContainsAllValidActions(t *testing.T) { } } } - -// TestBuildSystemPrompt_ActionListCompleteness 测试 action 列表的完整性 -func TestBuildSystemPrompt_ActionListCompleteness(t *testing.T) { - prompt := buildSystemPrompt(1000.0, 10, 5, "default", "") - - // 检查是否包含关键的缺失 action - missingActions := []string{ - "update_stop_loss", - "update_take_profit", - "partial_close", - } - - for _, action := range missingActions { - if !strings.Contains(prompt, action) { - t.Errorf("Prompt 缺少关键 action: %s(这会导致 AI 返回无效决策)", action) - } - } -} diff --git a/decision/validate_test.go b/decision/validate_test.go index d7e89229..468f9778 100644 --- a/decision/validate_test.go +++ b/decision/validate_test.go @@ -99,185 +99,6 @@ func TestLeverageFallback(t *testing.T) { } } -// TestUpdateStopLossValidation 测试 update_stop_loss 动作的字段验证 -func TestUpdateStopLossValidation(t *testing.T) { - tests := []struct { - name string - decision Decision - wantError bool - errorMsg string - }{ - { - name: "正确使用new_stop_loss字段", - decision: Decision{ - Symbol: "SOLUSDT", - Action: "update_stop_loss", - NewStopLoss: 155.5, - Reasoning: "移动止损至保本位", - }, - wantError: false, - }, - { - name: "new_stop_loss为0应该报错", - decision: Decision{ - Symbol: "SOLUSDT", - Action: "update_stop_loss", - NewStopLoss: 0, - Reasoning: "测试错误情况", - }, - wantError: true, - errorMsg: "新止损价格必须大于0", - }, - { - name: "new_stop_loss为负数应该报错", - decision: Decision{ - Symbol: "SOLUSDT", - Action: "update_stop_loss", - NewStopLoss: -100, - Reasoning: "测试错误情况", - }, - wantError: true, - errorMsg: "新止损价格必须大于0", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateDecision(&tt.decision, 1000.0, 10, 5) - - if (err != nil) != tt.wantError { - t.Errorf("validateDecision() error = %v, wantError %v", err, tt.wantError) - return - } - - if tt.wantError && err != nil { - if tt.errorMsg != "" && !contains(err.Error(), tt.errorMsg) { - t.Errorf("错误信息不匹配: got %q, want to contain %q", err.Error(), tt.errorMsg) - } - } - }) - } -} - -// TestUpdateTakeProfitValidation 测试 update_take_profit 动作的字段验证 -func TestUpdateTakeProfitValidation(t *testing.T) { - tests := []struct { - name string - decision Decision - wantError bool - errorMsg string - }{ - { - name: "正确使用new_take_profit字段", - decision: Decision{ - Symbol: "BTCUSDT", - Action: "update_take_profit", - NewTakeProfit: 98000, - Reasoning: "调整止盈至关键阻力位", - }, - wantError: false, - }, - { - name: "new_take_profit为0应该报错", - decision: Decision{ - Symbol: "BTCUSDT", - Action: "update_take_profit", - NewTakeProfit: 0, - Reasoning: "测试错误情况", - }, - wantError: true, - errorMsg: "新止盈价格必须大于0", - }, - { - name: "new_take_profit为负数应该报错", - decision: Decision{ - Symbol: "BTCUSDT", - Action: "update_take_profit", - NewTakeProfit: -1000, - Reasoning: "测试错误情况", - }, - wantError: true, - errorMsg: "新止盈价格必须大于0", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateDecision(&tt.decision, 1000.0, 10, 5) - - if (err != nil) != tt.wantError { - t.Errorf("validateDecision() error = %v, wantError %v", err, tt.wantError) - return - } - - if tt.wantError && err != nil { - if tt.errorMsg != "" && !contains(err.Error(), tt.errorMsg) { - t.Errorf("错误信息不匹配: got %q, want to contain %q", err.Error(), tt.errorMsg) - } - } - }) - } -} - -// TestPartialCloseValidation 测试 partial_close 动作的字段验证 -func TestPartialCloseValidation(t *testing.T) { - tests := []struct { - name string - decision Decision - wantError bool - errorMsg string - }{ - { - name: "正确使用close_percentage字段", - decision: Decision{ - Symbol: "ETHUSDT", - Action: "partial_close", - ClosePercentage: 50.0, - Reasoning: "锁定一半利润", - }, - wantError: false, - }, - { - name: "close_percentage为0应该报错", - decision: Decision{ - Symbol: "ETHUSDT", - Action: "partial_close", - ClosePercentage: 0, - Reasoning: "测试错误情况", - }, - wantError: true, - errorMsg: "平仓百分比必须在0-100之间", - }, - { - name: "close_percentage超过100应该报错", - decision: Decision{ - Symbol: "ETHUSDT", - Action: "partial_close", - ClosePercentage: 150, - Reasoning: "测试错误情况", - }, - wantError: true, - errorMsg: "平仓百分比必须在0-100之间", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateDecision(&tt.decision, 1000.0, 10, 5) - - if (err != nil) != tt.wantError { - t.Errorf("validateDecision() error = %v, wantError %v", err, tt.wantError) - return - } - - if tt.wantError && err != nil { - if tt.errorMsg != "" && !contains(err.Error(), tt.errorMsg) { - t.Errorf("错误信息不匹配: got %q, want to contain %q", err.Error(), tt.errorMsg) - } - } - }) - } -} // contains 检查字符串是否包含子串(辅助函数) func contains(s, substr string) bool { diff --git a/deploy_encryption.sh b/deploy_encryption.sh deleted file mode 100755 index 93633c1a..00000000 --- a/deploy_encryption.sh +++ /dev/null @@ -1,286 +0,0 @@ -#!/bin/bash -# NOFX 加密系統一鍵部署腳本 -# 使用方式: chmod +x deploy_encryption.sh && ./deploy_encryption.sh - -set -e # 遇到錯誤立即退出 - -# 顏色定義 -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -NC='\033[0m' # No Color - -# 輔助函數 -log_info() { - echo -e "${BLUE}ℹ️ $1${NC}" -} - -log_success() { - echo -e "${GREEN}✅ $1${NC}" -} - -log_warning() { - echo -e "${YELLOW}⚠️ $1${NC}" -} - -log_error() { - echo -e "${RED}❌ $1${NC}" -} - -# 檢查必要工具 -check_dependencies() { - log_info "檢查依賴工具..." - - if ! command -v go &> /dev/null; then - log_error "Go 未安裝,請先安裝 Go 1.21+" - exit 1 - fi - - if ! command -v npm &> /dev/null; then - log_error "npm 未安裝,請先安裝 Node.js 18+" - exit 1 - fi - - if ! command -v sqlite3 &> /dev/null; then - log_warning "sqlite3 未安裝,部分驗證功能不可用" - fi - - log_success "依賴檢查通過" -} - -# 備份數據庫 -backup_database() { - log_info "備份現有數據庫..." - - if [ -f "config.db" ]; then - BACKUP_FILE="config.db.pre_encryption.$(date +%Y%m%d_%H%M%S).backup" - cp config.db "$BACKUP_FILE" - log_success "數據庫已備份到: $BACKUP_FILE" - else - log_warning "未找到 config.db,跳過備份(首次安裝)" - fi -} - -# 創建密鑰目錄 -setup_secrets_dir() { - log_info "設置密鑰目錄..." - - if [ ! -d ".secrets" ]; then - mkdir -p .secrets - chmod 700 .secrets - log_success "密鑰目錄已創建: .secrets/" - else - log_warning "密鑰目錄已存在,跳過創建" - fi -} - -# 更新 .gitignore -update_gitignore() { - log_info "更新 .gitignore..." - - if ! grep -q ".secrets/" .gitignore 2>/dev/null; then - echo ".secrets/" >> .gitignore - log_success "已添加 .secrets/ 到 .gitignore" - fi - - if ! grep -q "config.db.backup" .gitignore 2>/dev/null; then - echo "config.db.*.backup" >> .gitignore - log_success "已添加備份檔案規則到 .gitignore" - fi -} - -# 安裝依賴 -install_dependencies() { - log_info "安裝 Go 依賴..." - go mod tidy - log_success "Go 依賴已更新" - - log_info "安裝前端依賴..." - cd web - if [ ! -d "node_modules" ]; then - npm install - fi - npm install tweetnacl tweetnacl-util @noble/secp256k1 --save - cd .. - log_success "前端依賴已安裝" -} - -# 運行測試 -run_tests() { - log_info "運行加密系統測試..." - - if go test ./crypto -v > /tmp/nofx_test.log 2>&1; then - log_success "加密系統測試通過" - cat /tmp/nofx_test.log | grep "✅" - else - log_error "加密系統測試失敗,詳情:" - cat /tmp/nofx_test.log - exit 1 - fi -} - -# 遷移數據 -migrate_data() { - log_info "遷移現有數據到加密格式..." - - if [ -f "config.db" ]; then - # 檢查是否已經加密過 - if sqlite3 config.db "SELECT api_key FROM exchanges LIMIT 1;" 2>/dev/null | grep -q "=="; then - log_warning "數據庫似乎已經加密過,跳過遷移" - read -p "是否強制重新遷移?(y/N): " -n 1 -r - echo - if [[ ! $REPLY =~ ^[Yy]$ ]]; then - return - fi - fi - - if go run scripts/migrate_encryption.go; then - log_success "數據遷移完成" - else - log_error "數據遷移失敗" - exit 1 - fi - else - log_warning "未找到數據庫,跳過遷移" - fi -} - -# 設置環境變數 -setup_env_vars() { - log_info "設置環境變數..." - - if [ -f ".secrets/master.key" ]; then - MASTER_KEY=$(cat .secrets/master.key) - - # 添加到當前 shell 配置 - SHELL_RC="$HOME/.bashrc" - if [ -f "$HOME/.zshrc" ]; then - SHELL_RC="$HOME/.zshrc" - fi - - if ! grep -q "NOFX_MASTER_KEY" "$SHELL_RC" 2>/dev/null; then - echo "" >> "$SHELL_RC" - echo "# NOFX 加密系統主密鑰" >> "$SHELL_RC" - echo "export NOFX_MASTER_KEY='$MASTER_KEY'" >> "$SHELL_RC" - log_success "主密鑰已添加到 $SHELL_RC" - else - log_warning "主密鑰已存在於 $SHELL_RC" - fi - - # 導出到當前 session - export NOFX_MASTER_KEY="$MASTER_KEY" - log_success "主密鑰已導出到當前 session" - else - log_warning "主密鑰文件未生成,請先運行應用初始化" - fi -} - -# 驗證部署 -verify_deployment() { - log_info "驗證部署結果..." - - # 1. 檢查密鑰檔案 - if [ -f ".secrets/rsa_private.pem" ] && [ -f ".secrets/rsa_public.pem" ] && [ -f ".secrets/master.key" ]; then - log_success "密鑰檔案完整" - else - log_error "密鑰檔案缺失,請檢查日誌" - return 1 - fi - - # 2. 檢查檔案權限 - PERM=$(stat -f "%Lp" .secrets 2>/dev/null || stat -c "%a" .secrets 2>/dev/null) - if [ "$PERM" = "700" ]; then - log_success "密鑰目錄權限正確 (700)" - else - log_warning "密鑰目錄權限為 $PERM,建議修改為 700" - chmod 700 .secrets - fi - - # 3. 檢查資料庫加密 - if [ -f "config.db" ] && command -v sqlite3 &> /dev/null; then - SAMPLE=$(sqlite3 config.db "SELECT api_key FROM exchanges WHERE api_key != '' LIMIT 1;" 2>/dev/null || echo "") - if echo "$SAMPLE" | grep -q "=="; then - log_success "數據庫密鑰已加密(Base64 格式)" - else - log_warning "數據庫可能未加密或無數據" - fi - fi - - log_success "部署驗證通過" -} - -# 打印後續步驟 -print_next_steps() { - echo "" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo -e "${GREEN}🎉 加密系統部署成功!${NC}" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "" - echo "📝 後續步驟:" - echo "" - echo " 1️⃣ 啟動後端服務:" - echo " $ go run main.go" - echo "" - echo " 2️⃣ 啟動前端服務:" - echo " $ cd web && npm run dev" - echo "" - echo " 3️⃣ 驗證加密功能:" - echo " $ curl http://localhost:8080/api/crypto/public-key" - echo "" - echo " 4️⃣ 查看審計日誌:" - echo " $ sqlite3 config.db 'SELECT * FROM audit_logs ORDER BY timestamp DESC LIMIT 10;'" - echo "" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "" - echo "⚠️ 重要提醒:" - echo "" - echo " • 請妥善保管 .secrets/ 目錄(已設置為 700 權限)" - echo " • 生產環境務必使用環境變數管理主密鑰" - echo " • 定期執行密鑰輪換(建議每季度一次)" - echo " • 數據庫備份已保存,驗證無誤後可手動刪除" - echo "" - echo "📚 詳細文檔:" - echo " - 快速開始: cat SECURITY_QUICKSTART.md" - echo " - 完整指南: cat ENCRYPTION_DEPLOYMENT.md" - echo "" -} - -# 主函數 -main() { - echo "" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo -e "${BLUE}🔐 NOFX 加密系統部署腳本${NC}" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "" - - # 確認執行 - log_warning "此腳本將:" - echo " 1. 備份現有數據庫" - echo " 2. 生成 RSA-4096 密鑰對" - echo " 3. 生成 AES-256 主密鑰" - echo " 4. 遷移現有數據到加密格式" - echo " 5. 設置環境變數" - echo "" - read -p "是否繼續?(y/N): " -n 1 -r - echo - if [[ ! $REPLY =~ ^[Yy]$ ]]; then - log_info "已取消部署" - exit 0 - fi - - # 執行部署步驟 - check_dependencies - backup_database - setup_secrets_dir - update_gitignore - install_dependencies - run_tests - migrate_data - setup_env_vars - verify_deployment - print_next_steps -} - -# 執行主函數 -main diff --git a/docker-compose.yml b/docker-compose.yml index 0fe50998..72be2ed1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,17 +11,17 @@ services: - "${NOFX_BACKEND_PORT:-8080}:8080" volumes: - ./config.json:/app/config.json:ro - - ./config.db:/app/config.db + - ./data.db:/app/data.db - ./beta_codes.txt:/app/beta_codes.txt:ro - ./decision_logs:/app/decision_logs - ./prompts:/app/prompts - - ./secrets:/app/secrets:ro # RSA密钥文件 - /etc/localtime:/etc/localtime:ro # Sync host time environment: - TZ=${NOFX_TIMEZONE:-Asia/Shanghai} # Set timezone - AI_MAX_TOKENS=4000 # AI响应的最大token数(默认2000,建议4000-8000) - DATA_ENCRYPTION_KEY=${DATA_ENCRYPTION_KEY} # 数据库加密密钥 - JWT_SECRET=${JWT_SECRET} # JWT认证密钥 + - RSA_PRIVATE_KEY=${RSA_PRIVATE_KEY} # RSA私钥(客户端加密) networks: - nofx-network healthcheck: diff --git a/logger/config.go b/logger/config.go index 32774558..f18eb041 100644 --- a/logger/config.go +++ b/logger/config.go @@ -1,21 +1,8 @@ package logger -import ( - "github.com/sirupsen/logrus" -) - // Config 日志配置(简化版) type Config struct { - Level string `json:"level"` // 日志级别: debug, info, warn, error (默认: info) - Telegram *TelegramConfig `json:"telegram"` // Telegram推送配置(可选) -} - -// TelegramConfig Telegram推送配置(简化版,高级参数使用默认值) -type TelegramConfig struct { - Enabled bool `json:"enabled"` // 是否启用(默认: false) - BotToken string `json:"bot_token"` // Bot Token - ChatID int64 `json:"chat_id"` // Chat ID - MinLevel string `json:"min_level"` // 最低日志级别,该级别及以上的日志会推送到Telegram(可选,默认: error) + Level string `json:"level"` // 日志级别: debug, info, warn, error (默认: info) } // SetDefaults 设置默认值 @@ -24,41 +11,3 @@ func (c *Config) SetDefaults() { c.Level = "info" } } - -// GetLogrusLevels 返回要推送到Telegram的日志级别 -// 根据配置的MinLevel返回该级别及以上的所有日志级别 -// 如果未配置或配置无效,默认返回error, fatal, panic(向后兼容) -func (tc *TelegramConfig) GetLogrusLevels() []logrus.Level { - // 如果未配置,使用默认值error(向后兼容) - minLevelStr := tc.MinLevel - if minLevelStr == "" { - minLevelStr = "error" - } - - // 解析配置的日志级别 - minLevel, err := logrus.ParseLevel(minLevelStr) - if err != nil { - // 如果解析失败,使用默认值error(向后兼容) - minLevel = logrus.ErrorLevel - } - - // 定义所有日志级别(从高到低:panic, fatal, error, warn, info, debug) - allLevels := []logrus.Level{ - logrus.PanicLevel, - logrus.FatalLevel, - logrus.ErrorLevel, - logrus.WarnLevel, - logrus.InfoLevel, - logrus.DebugLevel, - } - - // 返回所有大于等于minLevel的日志级别 - var result []logrus.Level - for _, level := range allLevels { - if level <= minLevel { - result = append(result, level) - } - } - - return result -} diff --git a/logger/config.telegram.json b/logger/config.telegram.json deleted file mode 100644 index 197c0802..00000000 --- a/logger/config.telegram.json +++ /dev/null @@ -1,33 +0,0 @@ -{ - "traders": [ - { - "id": "trader1", - "name": "AI Trader 1", - "enabled": true, - "ai_model": "deepseek", - "exchange": "binance", - "binance_api_key": "your_api_key", - "binance_secret_key": "your_secret_key", - "deepseek_key": "your_deepseek_key", - "initial_balance": 1000, - "scan_interval_minutes": 3 - } - ], - "use_default_coins": true, - "default_coins": ["BTCUSDT", "ETHUSDT", "SOLUSDT"], - "api_server_port": 8080, - "leverage": { - "btc_eth_leverage": 5, - "altcoin_leverage": 5 - }, - "log": { - "level": "info", - "telegram": { - "enabled": true, - "bot_token": "79472419:feafe231414", - "chat_id": -100323252626, - "min_level": "error" - } - }, - "_comment": "日志配置说明:level 可选值为 debug/info/warn/error,默认 info。telegram 部分作为可选配置, Telegram 推送默认为 error/fatal/panic 级别,min_level 如果设置为warn,则推送warn级别及以上的日志" -} diff --git a/logger/decision_logger.go b/logger/decision_logger.go deleted file mode 100644 index 2ac77c88..00000000 --- a/logger/decision_logger.go +++ /dev/null @@ -1,768 +0,0 @@ -package logger - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "math" - "os" - "path/filepath" - "time" -) - -// DecisionRecord 决策记录 -type DecisionRecord struct { - Timestamp time.Time `json:"timestamp"` // 决策时间 - CycleNumber int `json:"cycle_number"` // 周期编号 - SystemPrompt string `json:"system_prompt"` // 系统提示词(发送给AI的系统prompt) - InputPrompt string `json:"input_prompt"` // 发送给AI的输入prompt - CoTTrace string `json:"cot_trace"` // AI思维链(输出) - DecisionJSON string `json:"decision_json"` // 决策JSON - AccountState AccountSnapshot `json:"account_state"` // 账户状态快照 - Positions []PositionSnapshot `json:"positions"` // 持仓快照 - CandidateCoins []string `json:"candidate_coins"` // 候选币种列表 - Decisions []DecisionAction `json:"decisions"` // 执行的决策 - ExecutionLog []string `json:"execution_log"` // 执行日志 - Success bool `json:"success"` // 是否成功 - ErrorMessage string `json:"error_message"` // 错误信息(如果有) - // AIRequestDurationMs 记录 AI API 调用耗时(毫秒),方便评估调用性能 - AIRequestDurationMs int64 `json:"ai_request_duration_ms,omitempty"` -} - -// AccountSnapshot 账户状态快照 -type AccountSnapshot struct { - TotalBalance float64 `json:"total_balance"` - AvailableBalance float64 `json:"available_balance"` - TotalUnrealizedProfit float64 `json:"total_unrealized_profit"` - PositionCount int `json:"position_count"` - MarginUsedPct float64 `json:"margin_used_pct"` - InitialBalance float64 `json:"initial_balance"` // 记录当时的初始余额基准 -} - -// PositionSnapshot 持仓快照 -type PositionSnapshot struct { - Symbol string `json:"symbol"` - Side string `json:"side"` - PositionAmt float64 `json:"position_amt"` - EntryPrice float64 `json:"entry_price"` - MarkPrice float64 `json:"mark_price"` - UnrealizedProfit float64 `json:"unrealized_profit"` - Leverage float64 `json:"leverage"` - LiquidationPrice float64 `json:"liquidation_price"` -} - -// DecisionAction 决策动作 -type DecisionAction struct { - Action string `json:"action"` // open_long, open_short, close_long, close_short, update_stop_loss, update_take_profit, partial_close - Symbol string `json:"symbol"` // 币种 - Quantity float64 `json:"quantity"` // 数量(部分平仓时使用) - Leverage int `json:"leverage"` // 杠杆(开仓时) - Price float64 `json:"price"` // 执行价格 - OrderID int64 `json:"order_id"` // 订单ID - Timestamp time.Time `json:"timestamp"` // 执行时间 - Success bool `json:"success"` // 是否成功 - Error string `json:"error"` // 错误信息 -} - -// IDecisionLogger 决策日志记录器接口 -type IDecisionLogger interface { - // LogDecision 记录决策 - LogDecision(record *DecisionRecord) error - // GetLatestRecords 获取最近N条记录(按时间正序:从旧到新) - GetLatestRecords(n int) ([]*DecisionRecord, error) - // GetRecordByDate 获取指定日期的所有记录 - GetRecordByDate(date time.Time) ([]*DecisionRecord, error) - // CleanOldRecords 清理N天前的旧记录 - CleanOldRecords(days int) error - // GetStatistics 获取统计信息 - GetStatistics() (*Statistics, error) - // AnalyzePerformance 分析最近N个周期的交易表现 - AnalyzePerformance(lookbackCycles int) (*PerformanceAnalysis, error) - // SetCycleNumber 允许恢复内部计数(用于回测恢复) - SetCycleNumber(n int) -} - -// DecisionLogger 决策日志记录器 -type DecisionLogger struct { - logDir string - cycleNumber int -} - -// NewDecisionLogger 创建决策日志记录器 -func NewDecisionLogger(logDir string) IDecisionLogger { - if logDir == "" { - logDir = "decision_logs" - } - - // 确保日志目录存在(使用安全权限:只有所有者可访问) - if err := os.MkdirAll(logDir, 0700); err != nil { - fmt.Printf("⚠ 创建日志目录失败: %v\n", err) - } - - // 强制设置目录权限(即使目录已存在)- 确保安全 - if err := os.Chmod(logDir, 0700); err != nil { - fmt.Printf("⚠ 设置日志目录权限失败: %v\n", err) - } - - return &DecisionLogger{ - logDir: logDir, - cycleNumber: 0, - } -} - -// SetCycleNumber 允许外部恢复内部的周期计数(用于回测恢复)。 -func (l *DecisionLogger) SetCycleNumber(n int) { - if n > 0 { - l.cycleNumber = n - } -} - -// LogDecision 记录决策 -func (l *DecisionLogger) LogDecision(record *DecisionRecord) error { - l.cycleNumber++ - record.CycleNumber = l.cycleNumber - if record.Timestamp.IsZero() { - record.Timestamp = time.Now().UTC() - } else { - record.Timestamp = record.Timestamp.UTC() - } - - // 生成文件名:decision_YYYYMMDD_HHMMSS_cycleN.json - filename := fmt.Sprintf("decision_%s_cycle%d.json", - record.Timestamp.Format("20060102_150405"), - record.CycleNumber) - - filepath := filepath.Join(l.logDir, filename) - - // 序列化为JSON(带缩进,方便阅读) - data, err := json.MarshalIndent(record, "", " ") - if err != nil { - return fmt.Errorf("序列化决策记录失败: %w", err) - } - - // 写入文件(使用安全权限:只有所有者可读写) - if err := ioutil.WriteFile(filepath, data, 0600); err != nil { - return fmt.Errorf("写入决策记录失败: %w", err) - } - - fmt.Printf("📝 决策记录已保存: %s\n", filename) - return nil -} - -// GetLatestRecords 获取最近N条记录(按时间正序:从旧到新) -func (l *DecisionLogger) GetLatestRecords(n int) ([]*DecisionRecord, error) { - files, err := ioutil.ReadDir(l.logDir) - if err != nil { - return nil, fmt.Errorf("读取日志目录失败: %w", err) - } - - // 先按修改时间倒序收集(最新的在前) - var records []*DecisionRecord - count := 0 - for i := len(files) - 1; i >= 0 && count < n; i-- { - file := files[i] - if file.IsDir() { - continue - } - - filepath := filepath.Join(l.logDir, file.Name()) - data, err := ioutil.ReadFile(filepath) - if err != nil { - continue - } - - var record DecisionRecord - if err := json.Unmarshal(data, &record); err != nil { - continue - } - - records = append(records, &record) - count++ - } - - // 反转数组,让时间从旧到新排列(用于图表显示) - for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 { - records[i], records[j] = records[j], records[i] - } - - return records, nil -} - -// GetRecordByDate 获取指定日期的所有记录 -func (l *DecisionLogger) GetRecordByDate(date time.Time) ([]*DecisionRecord, error) { - dateStr := date.Format("20060102") - pattern := filepath.Join(l.logDir, fmt.Sprintf("decision_%s_*.json", dateStr)) - - files, err := filepath.Glob(pattern) - if err != nil { - return nil, fmt.Errorf("查找日志文件失败: %w", err) - } - - var records []*DecisionRecord - for _, filepath := range files { - data, err := ioutil.ReadFile(filepath) - if err != nil { - continue - } - - var record DecisionRecord - if err := json.Unmarshal(data, &record); err != nil { - continue - } - - records = append(records, &record) - } - - return records, nil -} - -// CleanOldRecords 清理N天前的旧记录 -func (l *DecisionLogger) CleanOldRecords(days int) error { - cutoffTime := time.Now().AddDate(0, 0, -days) - - files, err := ioutil.ReadDir(l.logDir) - if err != nil { - return fmt.Errorf("读取日志目录失败: %w", err) - } - - removedCount := 0 - for _, file := range files { - if file.IsDir() { - continue - } - - if file.ModTime().Before(cutoffTime) { - filepath := filepath.Join(l.logDir, file.Name()) - if err := os.Remove(filepath); err != nil { - fmt.Printf("⚠ 删除旧记录失败 %s: %v\n", file.Name(), err) - continue - } - removedCount++ - } - } - - if removedCount > 0 { - fmt.Printf("🗑️ 已清理 %d 条旧记录(%d天前)\n", removedCount, days) - } - - return nil -} - -// GetStatistics 获取统计信息 -func (l *DecisionLogger) GetStatistics() (*Statistics, error) { - files, err := ioutil.ReadDir(l.logDir) - if err != nil { - return nil, fmt.Errorf("读取日志目录失败: %w", err) - } - - stats := &Statistics{} - - for _, file := range files { - if file.IsDir() { - continue - } - - filepath := filepath.Join(l.logDir, file.Name()) - data, err := ioutil.ReadFile(filepath) - if err != nil { - continue - } - - var record DecisionRecord - if err := json.Unmarshal(data, &record); err != nil { - continue - } - - stats.TotalCycles++ - - for _, action := range record.Decisions { - if action.Success { - switch action.Action { - case "open_long", "open_short": - stats.TotalOpenPositions++ - case "close_long", "close_short", "auto_close_long", "auto_close_short": - stats.TotalClosePositions++ - // 🔧 BUG FIX:partial_close 不計入 TotalClosePositions,避免重複計數 - // case "partial_close": // 不計數,因為只有完全平倉才算一次 - // update_stop_loss 和 update_take_profit 不計入統計 - } - } - } - - if record.Success { - stats.SuccessfulCycles++ - } else { - stats.FailedCycles++ - } - } - - return stats, nil -} - -// Statistics 统计信息 -type Statistics struct { - TotalCycles int `json:"total_cycles"` - SuccessfulCycles int `json:"successful_cycles"` - FailedCycles int `json:"failed_cycles"` - TotalOpenPositions int `json:"total_open_positions"` - TotalClosePositions int `json:"total_close_positions"` -} - -// TradeOutcome 单笔交易结果 -type TradeOutcome struct { - Symbol string `json:"symbol"` // 币种 - Side string `json:"side"` // long/short - Quantity float64 `json:"quantity"` // 仓位数量 - Leverage int `json:"leverage"` // 杠杆倍数 - OpenPrice float64 `json:"open_price"` // 开仓价 - ClosePrice float64 `json:"close_price"` // 平仓价 - PositionValue float64 `json:"position_value"` // 仓位价值(quantity × openPrice) - MarginUsed float64 `json:"margin_used"` // 保证金使用(positionValue / leverage) - PnL float64 `json:"pn_l"` // 盈亏(USDT) - PnLPct float64 `json:"pn_l_pct"` // 盈亏百分比(相对保证金) - Duration string `json:"duration"` // 持仓时长 - OpenTime time.Time `json:"open_time"` // 开仓时间 - CloseTime time.Time `json:"close_time"` // 平仓时间 - WasStopLoss bool `json:"was_stop_loss"` // 是否止损 -} - -// PerformanceAnalysis 交易表现分析 -type PerformanceAnalysis struct { - TotalTrades int `json:"total_trades"` // 总交易数 - WinningTrades int `json:"winning_trades"` // 盈利交易数 - LosingTrades int `json:"losing_trades"` // 亏损交易数 - WinRate float64 `json:"win_rate"` // 胜率 - AvgWin float64 `json:"avg_win"` // 平均盈利 - AvgLoss float64 `json:"avg_loss"` // 平均亏损 - ProfitFactor float64 `json:"profit_factor"` // 盈亏比 - SharpeRatio float64 `json:"sharpe_ratio"` // 夏普比率(风险调整后收益) - RecentTrades []TradeOutcome `json:"recent_trades"` // 最近N笔交易 - SymbolStats map[string]*SymbolPerformance `json:"symbol_stats"` // 各币种表现 - BestSymbol string `json:"best_symbol"` // 表现最好的币种 - WorstSymbol string `json:"worst_symbol"` // 表现最差的币种 -} - -// SymbolPerformance 币种表现统计 -type SymbolPerformance struct { - Symbol string `json:"symbol"` // 币种 - TotalTrades int `json:"total_trades"` // 交易次数 - WinningTrades int `json:"winning_trades"` // 盈利次数 - LosingTrades int `json:"losing_trades"` // 亏损次数 - WinRate float64 `json:"win_rate"` // 胜率 - TotalPnL float64 `json:"total_pn_l"` // 总盈亏 - AvgPnL float64 `json:"avg_pn_l"` // 平均盈亏 -} - -// AnalyzePerformance 分析最近N个周期的交易表现 -func (l *DecisionLogger) AnalyzePerformance(lookbackCycles int) (*PerformanceAnalysis, error) { - records, err := l.GetLatestRecords(lookbackCycles) - if err != nil { - return nil, fmt.Errorf("读取历史记录失败: %w", err) - } - - if len(records) == 0 { - return &PerformanceAnalysis{ - RecentTrades: []TradeOutcome{}, - SymbolStats: make(map[string]*SymbolPerformance), - }, nil - } - - analysis := &PerformanceAnalysis{ - RecentTrades: []TradeOutcome{}, - SymbolStats: make(map[string]*SymbolPerformance), - } - - // 追踪持仓状态:symbol_side -> {side, openPrice, openTime, quantity, leverage} - openPositions := make(map[string]map[string]interface{}) - - // 为了避免开仓记录在窗口外导致匹配失败,需要先从所有历史记录中找出未平仓的持仓 - // 获取更多历史记录来构建完整的持仓状态(使用更大的窗口) - allRecords, err := l.GetLatestRecords(lookbackCycles * 3) // 扩大3倍窗口 - if err == nil && len(allRecords) > len(records) { - // 先从扩大的窗口中收集所有开仓记录 - for _, record := range allRecords { - for _, action := range record.Decisions { - if !action.Success { - continue - } - - symbol := action.Symbol - side := "" - if action.Action == "open_long" || action.Action == "close_long" || action.Action == "partial_close" || action.Action == "auto_close_long" { - side = "long" - } else if action.Action == "open_short" || action.Action == "close_short" || action.Action == "auto_close_short" { - side = "short" - } - - // partial_close 需要根據持倉判斷方向 - if action.Action == "partial_close" && side == "" { - for key, pos := range openPositions { - if posSymbol, _ := pos["side"].(string); key == symbol+"_"+posSymbol { - side = posSymbol - break - } - } - } - - posKey := symbol + "_" + side - - switch action.Action { - case "open_long", "open_short": - // 记录开仓 - openPositions[posKey] = map[string]interface{}{ - "side": side, - "openPrice": action.Price, - "openTime": action.Timestamp, - "quantity": action.Quantity, - "leverage": action.Leverage, - } - case "close_long", "close_short", "auto_close_long", "auto_close_short": - // 移除已平仓记录 - delete(openPositions, posKey) - // partial_close 不處理,保留持倉記錄 - } - } - } - } - - // 遍历分析窗口内的记录,生成交易结果 - for _, record := range records { - for _, action := range record.Decisions { - if !action.Success { - continue - } - - symbol := action.Symbol - side := "" - if action.Action == "open_long" || action.Action == "close_long" || action.Action == "partial_close" || action.Action == "auto_close_long" { - side = "long" - } else if action.Action == "open_short" || action.Action == "close_short" || action.Action == "auto_close_short" { - side = "short" - } - - // partial_close 需要根據持倉判斷方向 - if action.Action == "partial_close" { - // 從 openPositions 中查找持倉方向 - for key, pos := range openPositions { - if posSymbol, _ := pos["side"].(string); key == symbol+"_"+posSymbol { - side = posSymbol - break - } - } - } - - posKey := symbol + "_" + side // 使用symbol_side作为key,区分多空持仓 - - switch action.Action { - case "open_long", "open_short": - // 更新开仓记录(可能已经在预填充时记录过了) - openPositions[posKey] = map[string]interface{}{ - "side": side, - "openPrice": action.Price, - "openTime": action.Timestamp, - "quantity": action.Quantity, - "leverage": action.Leverage, - "remainingQuantity": action.Quantity, // 🔧 BUG FIX:追蹤剩餘數量 - "accumulatedPnL": 0.0, // 🔧 BUG FIX:累積部分平倉盈虧 - "partialCloseCount": 0, // 🔧 BUG FIX:部分平倉次數 - "partialCloseVolume": 0.0, // 🔧 BUG FIX:部分平倉總量 - } - - case "close_long", "close_short", "partial_close", "auto_close_long", "auto_close_short": - // 查找对应的开仓记录(可能来自预填充或当前窗口) - if openPos, exists := openPositions[posKey]; exists { - openPrice := openPos["openPrice"].(float64) - openTime := openPos["openTime"].(time.Time) - side := openPos["side"].(string) - quantity := openPos["quantity"].(float64) - leverage := openPos["leverage"].(int) - - // 🔧 BUG FIX:取得追蹤字段(若不存在則初始化) - remainingQty, _ := openPos["remainingQuantity"].(float64) - if remainingQty == 0 { - remainingQty = quantity // 兼容舊數據(沒有 remainingQuantity 字段) - } - accumulatedPnL, _ := openPos["accumulatedPnL"].(float64) - partialCloseCount, _ := openPos["partialCloseCount"].(int) - partialCloseVolume, _ := openPos["partialCloseVolume"].(float64) - - // 对于 partial_close,使用实际平仓数量;否则使用剩余仓位数量 - actualQuantity := remainingQty - if action.Action == "partial_close" { - actualQuantity = action.Quantity - } - - // 计算本次平仓的盈亏(USDT) - var pnl float64 - if side == "long" { - pnl = actualQuantity * (action.Price - openPrice) - } else { - pnl = actualQuantity * (openPrice - action.Price) - } - - // 🔧 BUG FIX:處理 partial_close 聚合邏輯 - if action.Action == "partial_close" { - // 累積盈虧和數量 - accumulatedPnL += pnl - remainingQty -= actualQuantity - partialCloseCount++ - partialCloseVolume += actualQuantity - - // 更新 openPositions(保留持倉記錄,但更新追蹤數據) - openPos["remainingQuantity"] = remainingQty - openPos["accumulatedPnL"] = accumulatedPnL - openPos["partialCloseCount"] = partialCloseCount - openPos["partialCloseVolume"] = partialCloseVolume - - // 判斷是否已完全平倉 - if remainingQty <= 0.0001 { // 使用小閾值避免浮點誤差 - // ✅ 完全平倉:記錄為一筆完整交易 - positionValue := quantity * openPrice - marginUsed := positionValue / float64(leverage) - pnlPct := 0.0 - if marginUsed > 0 { - pnlPct = (accumulatedPnL / marginUsed) * 100 - } - - outcome := TradeOutcome{ - Symbol: symbol, - Side: side, - Quantity: quantity, // 使用原始總量 - Leverage: leverage, - OpenPrice: openPrice, - ClosePrice: action.Price, // 最後一次平倉價格 - PositionValue: positionValue, - MarginUsed: marginUsed, - PnL: accumulatedPnL, // 🔧 使用累積盈虧 - PnLPct: pnlPct, - Duration: action.Timestamp.Sub(openTime).String(), - OpenTime: openTime, - CloseTime: action.Timestamp, - } - - analysis.RecentTrades = append(analysis.RecentTrades, outcome) - analysis.TotalTrades++ // 🔧 只在完全平倉時計數 - - // 分类交易 - if accumulatedPnL > 0 { - analysis.WinningTrades++ - analysis.AvgWin += accumulatedPnL - } else if accumulatedPnL < 0 { - analysis.LosingTrades++ - analysis.AvgLoss += accumulatedPnL - } - - // 更新币种统计 - if _, exists := analysis.SymbolStats[symbol]; !exists { - analysis.SymbolStats[symbol] = &SymbolPerformance{ - Symbol: symbol, - } - } - stats := analysis.SymbolStats[symbol] - stats.TotalTrades++ - stats.TotalPnL += accumulatedPnL - if accumulatedPnL > 0 { - stats.WinningTrades++ - } else if accumulatedPnL < 0 { - stats.LosingTrades++ - } - - // 刪除持倉記錄 - delete(openPositions, posKey) - } - // ⚠️ 否則不做任何操作(等待後續 partial_close 或 full close) - - } else { - // 🔧 完全平倉(close_long/close_short/auto_close) - // 如果之前有部分平倉,需要加上累積的 PnL - totalPnL := accumulatedPnL + pnl - - positionValue := quantity * openPrice - marginUsed := positionValue / float64(leverage) - pnlPct := 0.0 - if marginUsed > 0 { - pnlPct = (totalPnL / marginUsed) * 100 - } - - outcome := TradeOutcome{ - Symbol: symbol, - Side: side, - Quantity: quantity, // 使用原始總量 - Leverage: leverage, - OpenPrice: openPrice, - ClosePrice: action.Price, - PositionValue: positionValue, - MarginUsed: marginUsed, - PnL: totalPnL, // 🔧 包含之前部分平倉的 PnL - PnLPct: pnlPct, - Duration: action.Timestamp.Sub(openTime).String(), - OpenTime: openTime, - CloseTime: action.Timestamp, - } - - analysis.RecentTrades = append(analysis.RecentTrades, outcome) - analysis.TotalTrades++ - - // 分类交易 - if totalPnL > 0 { - analysis.WinningTrades++ - analysis.AvgWin += totalPnL - } else if totalPnL < 0 { - analysis.LosingTrades++ - analysis.AvgLoss += totalPnL - } - - // 更新币种统计 - if _, exists := analysis.SymbolStats[symbol]; !exists { - analysis.SymbolStats[symbol] = &SymbolPerformance{ - Symbol: symbol, - } - } - stats := analysis.SymbolStats[symbol] - stats.TotalTrades++ - stats.TotalPnL += totalPnL - if totalPnL > 0 { - stats.WinningTrades++ - } else if totalPnL < 0 { - stats.LosingTrades++ - } - - // 刪除持倉記錄 - delete(openPositions, posKey) - } - } - } - } - } - - // 计算统计指标 - if analysis.TotalTrades > 0 { - analysis.WinRate = (float64(analysis.WinningTrades) / float64(analysis.TotalTrades)) * 100 - - // 计算总盈利和总亏损 - totalWinAmount := analysis.AvgWin // 当前是累加的总和 - totalLossAmount := analysis.AvgLoss // 当前是累加的总和(负数) - - if analysis.WinningTrades > 0 { - analysis.AvgWin /= float64(analysis.WinningTrades) - } - if analysis.LosingTrades > 0 { - analysis.AvgLoss /= float64(analysis.LosingTrades) - } - - // Profit Factor = 总盈利 / 总亏损(绝对值) - // 注意:totalLossAmount 是负数,所以取负号得到绝对值 - if totalLossAmount != 0 { - analysis.ProfitFactor = totalWinAmount / (-totalLossAmount) - } else if totalWinAmount > 0 { - // 只有盈利没有亏损的情况,设置为一个很大的值表示完美策略 - analysis.ProfitFactor = 999.0 - } - } - - // 计算各币种胜率和平均盈亏 - bestPnL := -999999.0 - worstPnL := 999999.0 - for symbol, stats := range analysis.SymbolStats { - if stats.TotalTrades > 0 { - stats.WinRate = (float64(stats.WinningTrades) / float64(stats.TotalTrades)) * 100 - stats.AvgPnL = stats.TotalPnL / float64(stats.TotalTrades) - - if stats.TotalPnL > bestPnL { - bestPnL = stats.TotalPnL - analysis.BestSymbol = symbol - } - if stats.TotalPnL < worstPnL { - worstPnL = stats.TotalPnL - analysis.WorstSymbol = symbol - } - } - } - - // 只保留最近的交易(倒序:最新的在前) - if len(analysis.RecentTrades) > 10 { - // 反转数组,让最新的在前 - for i, j := 0, len(analysis.RecentTrades)-1; i < j; i, j = i+1, j-1 { - analysis.RecentTrades[i], analysis.RecentTrades[j] = analysis.RecentTrades[j], analysis.RecentTrades[i] - } - analysis.RecentTrades = analysis.RecentTrades[:10] - } else if len(analysis.RecentTrades) > 0 { - // 反转数组 - for i, j := 0, len(analysis.RecentTrades)-1; i < j; i, j = i+1, j-1 { - analysis.RecentTrades[i], analysis.RecentTrades[j] = analysis.RecentTrades[j], analysis.RecentTrades[i] - } - } - - // 计算夏普比率(需要至少2个数据点) - analysis.SharpeRatio = l.calculateSharpeRatio(records) - - return analysis, nil -} - -// calculateSharpeRatio 计算夏普比率 -// 基于账户净值的变化计算风险调整后收益 -func (l *DecisionLogger) calculateSharpeRatio(records []*DecisionRecord) float64 { - if len(records) < 2 { - return 0.0 - } - - // 提取每个周期的账户净值 - // 注意:TotalBalance字段实际存储的是TotalEquity(账户总净值) - // TotalUnrealizedProfit字段实际存储的是TotalPnL(相对初始余额的盈亏) - var equities []float64 - for _, record := range records { - // 直接使用TotalBalance,因为它已经是完整的账户净值 - equity := record.AccountState.TotalBalance - if equity > 0 { - equities = append(equities, equity) - } - } - - if len(equities) < 2 { - return 0.0 - } - - // 计算周期收益率(period returns) - var returns []float64 - for i := 1; i < len(equities); i++ { - if equities[i-1] > 0 { - periodReturn := (equities[i] - equities[i-1]) / equities[i-1] - returns = append(returns, periodReturn) - } - } - - if len(returns) == 0 { - return 0.0 - } - - // 计算平均收益率 - sumReturns := 0.0 - for _, r := range returns { - sumReturns += r - } - meanReturn := sumReturns / float64(len(returns)) - - // 计算收益率标准差 - sumSquaredDiff := 0.0 - for _, r := range returns { - diff := r - meanReturn - sumSquaredDiff += diff * diff - } - variance := sumSquaredDiff / float64(len(returns)) - stdDev := math.Sqrt(variance) - - // 避免除以零 - if stdDev == 0 { - if meanReturn > 0 { - return 999.0 // 无波动的正收益 - } else if meanReturn < 0 { - return -999.0 // 无波动的负收益 - } - return 0.0 - } - - // 计算夏普比率(假设无风险利率为0) - // 注:直接返回周期级别的夏普比率(非年化),正常范围 -2 到 +2 - sharpeRatio := meanReturn / stdDev - return sharpeRatio -} diff --git a/logger/logger.go b/logger/logger.go index 527c46e2..fd5b87b7 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,7 +1,6 @@ package logger import ( - "nofx/config" "os" "github.com/sirupsen/logrus" @@ -10,11 +9,20 @@ import ( var ( // Log 全局logger实例 Log *logrus.Logger - - // telegramHook 保存hook引用,用于优雅关闭 - telegramHook *TelegramHook ) +func init() { + // 自动初始化默认 logger,确保在 Init 被调用前也能使用 + Log = logrus.New() + Log.SetLevel(logrus.InfoLevel) + Log.SetFormatter(&logrus.TextFormatter{ + FullTimestamp: true, + TimestampFormat: "2006-01-02 15:04:05", + ForceColors: true, + }) + Log.SetOutput(os.Stdout) +} + // ============================================================================ // 初始化函数 // ============================================================================ @@ -52,26 +60,6 @@ func Init(cfg *Config) error { // 启用调用位置信息 Log.SetReportCaller(true) - // 添加Telegram Hook(可选) - if cfg.Telegram != nil && cfg.Telegram.Enabled { - if err := setupTelegramHook(cfg.Telegram); err != nil { - Log.Warnf("初始化Telegram推送失败,将继续使用普通日志: %v", err) - } - } - - return nil -} - -// setupTelegramHook 设置Telegram Hook -func setupTelegramHook(telegramCfg *TelegramConfig) error { - hook, err := NewTelegramHook(telegramCfg) - if err != nil { - return err - } - - Log.AddHook(hook) - telegramHook = hook - Log.Info("✅ Telegram日志推送已启用") return nil } @@ -81,69 +69,9 @@ func InitWithSimpleConfig(level string) error { return Init(&Config{Level: level}) } -// InitWithTelegram 使用Telegram配置初始化logger -func InitWithTelegram(botToken string, chatID int64) error { - return Init(&Config{ - Level: "info", - Telegram: &TelegramConfig{ - Enabled: true, - BotToken: botToken, - ChatID: chatID, - }, - }) -} - -// InitFromLogConfig 从config.LogConfig初始化logger -func InitFromLogConfig(logConfig *config.LogConfig) error { - if logConfig == nil { - return InitWithSimpleConfig("info") - } - - cfg := &Config{ - Level: logConfig.Level, - } - - if cfg.Level == "" { - cfg.Level = "info" - } - - // 如果启用了Telegram,添加配置 - if logConfig.Telegram != nil && logConfig.Telegram.Enabled { - if botToken := logConfig.Telegram.BotToken; botToken != "" && logConfig.Telegram.ChatID != 0 { - cfg.Telegram = &TelegramConfig{ - Enabled: true, - BotToken: botToken, - ChatID: logConfig.Telegram.ChatID, - MinLevel: logConfig.Telegram.MinLevel, - } - } - } - - return Init(cfg) -} - -// InitFromParams 从参数初始化logger -// 适用于不依赖config包的场景 -func InitFromParams(level string, telegramEnabled bool, botToken string, chatID int64) error { - cfg := &Config{Level: level} - - if telegramEnabled && botToken != "" && chatID != 0 { - cfg.Telegram = &TelegramConfig{ - Enabled: true, - BotToken: botToken, - ChatID: chatID, - } - } - - return Init(cfg) -} - -// Shutdown 优雅关闭logger(主要用于关闭Telegram发送器) +// Shutdown 优雅关闭logger func Shutdown() { - if telegramHook != nil { - telegramHook.Stop() - telegramHook = nil - } + // 预留用于未来扩展 } // ============================================================================ @@ -208,3 +136,32 @@ func Panic(args ...interface{}) { func Panicf(format string, args ...interface{}) { Log.Panicf(format, args...) } + +// ============================================================================ +// MCP Logger 适配器 +// ============================================================================ + +// MCPLogger 适配器,使 MCP 包使用全局 logger +// 实现 mcp.Logger 接口 +type MCPLogger struct{} + +// NewMCPLogger 创建 MCP 日志适配器 +func NewMCPLogger() *MCPLogger { + return &MCPLogger{} +} + +func (l *MCPLogger) Debugf(format string, args ...any) { + Log.Debugf(format, args...) +} + +func (l *MCPLogger) Infof(format string, args ...any) { + Log.Infof(format, args...) +} + +func (l *MCPLogger) Warnf(format string, args ...any) { + Log.Warnf(format, args...) +} + +func (l *MCPLogger) Errorf(format string, args ...any) { + Log.Errorf(format, args...) +} diff --git a/logger/telegram_hook.go b/logger/telegram_hook.go deleted file mode 100644 index e8477f47..00000000 --- a/logger/telegram_hook.go +++ /dev/null @@ -1,158 +0,0 @@ -package logger - -import ( - "fmt" - "runtime" - "strings" - - "github.com/sirupsen/logrus" -) - -// TelegramHook 实现logrus.Hook接口,将日志推送到Telegram -type TelegramHook struct { - sender *TelegramSender - levels []logrus.Level - enabled bool -} - -// NewTelegramHook 创建Telegram Hook -func NewTelegramHook(config *TelegramConfig) (*TelegramHook, error) { - if !config.Enabled { - return &TelegramHook{enabled: false}, nil - } - - if config.BotToken == "" || config.ChatID == 0 { - return nil, fmt.Errorf("telegram配置不完整: bot_token和chat_id不能为空") - } - - // 创建发送器(使用默认参数) - sender, err := NewTelegramSender(config.BotToken, config.ChatID) - if err != nil { - return nil, fmt.Errorf("创建telegram发送器失败: %w", err) - } - - hook := &TelegramHook{ - sender: sender, - levels: config.GetLogrusLevels(), - enabled: true, - } - - return hook, nil -} - -// Levels 返回需要触发的日志级别 -func (h *TelegramHook) Levels() []logrus.Level { - if !h.enabled { - return []logrus.Level{} - } - return h.levels -} - -// Fire 当日志触发时调用 -func (h *TelegramHook) Fire(entry *logrus.Entry) error { - if !h.enabled { - return nil - } - - // 格式化消息 - message := h.formatMessage(entry) - - // 异步发送(非阻塞) - h.sender.SendAsync(message) - - return nil -} - -// formatMessage 格式化日志消息为Telegram格式 -func (h *TelegramHook) formatMessage(entry *logrus.Entry) string { - // 级别emoji - levelEmoji := h.getLevelEmoji(entry.Level) - - // 基本信息 - var builder strings.Builder - builder.WriteString(fmt.Sprintf("%s *%s*: 系统日志警报\n", levelEmoji, strings.ToUpper(entry.Level.String()))) - builder.WriteString(fmt.Sprintf("📝 消息: `%s`\n", escapeMarkdown(entry.Message))) - - // 字段信息 - if len(entry.Data) > 0 { - builder.WriteString("📊 字段:\n") - for key, value := range entry.Data { - builder.WriteString(fmt.Sprintf(" • %s: `%v`\n", key, value)) - } - } - - // 调用位置 - if entry.HasCaller() { - file := entry.Caller.File - // 只保留相对路径 - if idx := strings.Index(file, "nofx/"); idx >= 0 { - file = file[idx:] - } - builder.WriteString(fmt.Sprintf("📍 位置: `%s:%d`\n", file, entry.Caller.Line)) - } else { - // 如果entry没有caller,手动获取 - if _, file, line, ok := runtime.Caller(8); ok { - if idx := strings.Index(file, "nofx/"); idx >= 0 { - file = file[idx:] - } - builder.WriteString(fmt.Sprintf("📍 位置: `%s:%d`\n", file, line)) - } - } - - // 时间戳 - builder.WriteString(fmt.Sprintf("🕐 时间: `%s`", entry.Time.Format("2006-01-02 15:04:05"))) - - return builder.String() -} - -// getLevelEmoji 获取日志级别对应的emoji -func (h *TelegramHook) getLevelEmoji(level logrus.Level) string { - switch level { - case logrus.PanicLevel: - return "🔴" - case logrus.FatalLevel: - return "🔴" - case logrus.ErrorLevel: - return "🟠" - case logrus.WarnLevel: - return "🟡" - case logrus.InfoLevel: - return "🟢" - case logrus.DebugLevel: - return "🔵" - default: - return "⚪" - } -} - -// escapeMarkdown 转义Markdown特殊字符 -func escapeMarkdown(text string) string { - replacer := strings.NewReplacer( - "_", "\\_", - "*", "\\*", - "[", "\\[", - "]", "\\]", - "(", "\\(", - ")", "\\)", - "~", "\\~", - "`", "\\`", - ">", "\\>", - "#", "\\#", - "+", "\\+", - "-", "\\-", - "=", "\\=", - "|", "\\|", - "{", "\\{", - "}", "\\}", - ".", "\\.", - "!", "\\!", - ) - return replacer.Replace(text) -} - -// Stop 停止Hook(优雅关闭) -func (h *TelegramHook) Stop() { - if h.enabled && h.sender != nil { - h.sender.Stop() - } -} diff --git a/logger/telegram_sender.go b/logger/telegram_sender.go deleted file mode 100644 index 6658d9f2..00000000 --- a/logger/telegram_sender.go +++ /dev/null @@ -1,120 +0,0 @@ -package logger - -import ( - "fmt" - "sync" - "time" - - tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" -) - -// TelegramSender Telegram消息发送器(异步) -type TelegramSender struct { - bot *tgbotapi.BotAPI - chatID int64 - msgChan chan string - retryCount int - retryInterval time.Duration - wg sync.WaitGroup - stopChan chan struct{} - once sync.Once -} - -// NewTelegramSender 创建Telegram发送器(使用默认参数) -func NewTelegramSender(botToken string, chatID int64) (*TelegramSender, error) { - bot, err := tgbotapi.NewBotAPI(botToken) - if err != nil { - return nil, fmt.Errorf("创建telegram bot失败: %w", err) - } - - // 设置为静默模式(不打印bot信息) - bot.Debug = false - - sender := &TelegramSender{ - bot: bot, - chatID: chatID, - msgChan: make(chan string, 20), // 固定缓冲区大小: 20 - retryCount: 3, // 固定重试次数: 3 - retryInterval: 3 * time.Second, // 固定重试间隔: 3秒 - stopChan: make(chan struct{}), - } - - // 启动异步发送协程 - sender.Start() - - return sender, nil -} - -// Start 启动异步发送协程 -func (s *TelegramSender) Start() { - s.wg.Add(1) - go s.listenAndSend() -} - -// SendAsync 异步发送消息(非阻塞) -func (s *TelegramSender) SendAsync(message string) { - select { - case s.msgChan <- message: - // 成功写入缓冲区 - default: - // 缓冲区满,丢弃消息(不阻塞主流程) - fmt.Printf("[Telegram] 消息缓冲区已满,消息被丢弃\n") - } -} - -// listenAndSend 监听channel并发送消息 -func (s *TelegramSender) listenAndSend() { - defer s.wg.Done() - - for { - select { - case msg := <-s.msgChan: - s.sendWithRetry(msg) - case <-s.stopChan: - // 清空缓冲区后退出 - for len(s.msgChan) > 0 { - msg := <-s.msgChan - s.sendWithRetry(msg) - } - return - } - } -} - -// sendWithRetry 发送消息(带重试) -func (s *TelegramSender) sendWithRetry(message string) { - var err error - for i := 0; i < s.retryCount; i++ { - err = s.send(message) - if err == nil { - return // 发送成功 - } - - // 重试前等待 - if i < s.retryCount-1 { - time.Sleep(s.retryInterval) - } - } - - // 所有重试都失败 - if err != nil { - fmt.Printf("[Telegram] 发送消息失败(已重试%d次): %v\n", s.retryCount, err) - } -} - -// send 发送单条消息 -func (s *TelegramSender) send(message string) error { - msg := tgbotapi.NewMessage(s.chatID, message) - msg.ParseMode = tgbotapi.ModeMarkdown - - _, err := s.bot.Send(msg) - return err -} - -// Stop 停止发送器(优雅关闭) -func (s *TelegramSender) Stop() { - s.once.Do(func() { - close(s.stopChan) - s.wg.Wait() - }) -} diff --git a/main.go b/main.go index f456684d..89027e24 100644 --- a/main.go +++ b/main.go @@ -3,21 +3,24 @@ package main import ( "encoding/json" "fmt" - "log" "nofx/api" "nofx/auth" "nofx/backtest" "nofx/config" "nofx/crypto" + "nofx/logger" "nofx/manager" "nofx/market" "nofx/mcp" "nofx/pool" + "nofx/store" + "nofx/trader" "os" "os/signal" "strconv" "strings" "syscall" + "time" "github.com/joho/godotenv" ) @@ -44,7 +47,7 @@ type ConfigFile struct { func loadConfigFile() (*ConfigFile, error) { // 检查config.json是否存在 if _, err := os.Stat("config.json"); os.IsNotExist(err) { - log.Printf("📄 config.json不存在,使用默认配置") + logger.Info("📄 config.json不存在,使用默认配置") return &ConfigFile{}, nil } @@ -64,12 +67,12 @@ func loadConfigFile() (*ConfigFile, error) { } // syncConfigToDatabase 将配置同步到数据库 -func syncConfigToDatabase(database *config.Database, configFile *ConfigFile) error { +func syncConfigToDatabase(st *store.Store, configFile *ConfigFile) error { if configFile == nil { return nil } - log.Printf("🔄 开始同步config.json到数据库...") + logger.Info("🔄 开始同步config.json到数据库...") // 同步各配置项到数据库 configs := map[string]string{ @@ -106,24 +109,24 @@ func syncConfigToDatabase(database *config.Database, configFile *ConfigFile) err // 更新数据库配置 for key, value := range configs { - if err := database.SetSystemConfig(key, value); err != nil { - log.Printf("⚠️ 更新配置 %s 失败: %v", key, err) + if err := st.SystemConfig().Set(key, value); err != nil { + logger.Warnf("⚠️ 更新配置 %s 失败: %v", key, err) } else { - log.Printf("✓ 同步配置: %s = %s", key, value) + logger.Infof("✓ 同步配置: %s = %s", key, value) } } - log.Printf("✅ config.json同步完成") + logger.Info("✅ config.json同步完成") return nil } // loadBetaCodesToDatabase 加载内测码文件到数据库 -func loadBetaCodesToDatabase(database *config.Database) error { +func loadBetaCodesToDatabase(st *store.Store) error { betaCodeFile := "beta_codes.txt" // 检查内测码文件是否存在 if _, err := os.Stat(betaCodeFile); os.IsNotExist(err) { - log.Printf("📄 内测码文件 %s 不存在,跳过加载", betaCodeFile) + logger.Infof("📄 内测码文件 %s 不存在,跳过加载", betaCodeFile) return nil } @@ -133,37 +136,39 @@ func loadBetaCodesToDatabase(database *config.Database) error { return fmt.Errorf("获取内测码文件信息失败: %w", err) } - log.Printf("🔄 发现内测码文件 %s (%.1f KB),开始加载...", betaCodeFile, float64(fileInfo.Size())/1024) + logger.Infof("🔄 发现内测码文件 %s (%.1f KB),开始加载...", betaCodeFile, float64(fileInfo.Size())/1024) // 加载内测码到数据库 - err = database.LoadBetaCodesFromFile(betaCodeFile) + err = st.BetaCode().LoadFromFile(betaCodeFile) if err != nil { return fmt.Errorf("加载内测码失败: %w", err) } // 显示统计信息 - total, used, err := database.GetBetaCodeStats() + total, used, err := st.BetaCode().GetStats() if err != nil { - log.Printf("⚠️ 获取内测码统计失败: %v", err) + logger.Warnf("⚠️ 获取内测码统计失败: %v", err) } else { - log.Printf("✅ 内测码加载完成: 总计 %d 个,已使用 %d 个,剩余 %d 个", total, used, total-used) + logger.Infof("✅ 内测码加载完成: 总计 %d 个,已使用 %d 个,剩余 %d 个", total, used, total-used) } return nil } func main() { - fmt.Println("╔════════════════════════════════════════════════════════════╗") - fmt.Println("║ 🤖 AI多模型交易系统 - 支持 DeepSeek & Qwen ║") - fmt.Println("╚════════════════════════════════════════════════════════════╝") - fmt.Println() - // Load environment variables from .env file if present (for local/dev runs) // In Docker Compose, variables are injected by the runtime and this is harmless. _ = godotenv.Load() + // 初始化日志 + logger.Init(nil) + + logger.Info("╔════════════════════════════════════════════════════════════╗") + logger.Info("║ 🤖 AI多模型交易系统 - 支持 DeepSeek & Qwen ║") + logger.Info("╚════════════════════════════════════════════════════════════╝") + // 初始化数据库配置 - dbPath := "config.db" + dbPath := "data.db" if len(os.Args) > 1 { dbPath = os.Args[1] } @@ -171,163 +176,174 @@ func main() { // 读取配置文件 configFile, err := loadConfigFile() if err != nil { - log.Fatalf("❌ 读取config.json失败: %v", err) + logger.Fatalf("❌ 读取config.json失败: %v", err) } - log.Printf("📋 初始化配置数据库: %s", dbPath) - database, err := config.NewDatabase(dbPath) + logger.Infof("📋 初始化配置数据库: %s", dbPath) + st, err := store.New(dbPath) if err != nil { - log.Fatalf("❌ 初始化数据库失败: %v", err) + logger.Fatalf("❌ 初始化数据库失败: %v", err) } - defer database.Close() - backtest.UseDatabase(database.Conn()) + defer st.Close() + backtest.UseDatabase(st.DB()) // 初始化加密服务 - log.Printf("🔐 初始化加密服务...") - cryptoService, err := crypto.NewCryptoService("secrets/rsa_key") + logger.Info("🔐 初始化加密服务...") + cryptoService, err := crypto.NewCryptoService() if err != nil { - log.Fatalf("❌ 初始化加密服务失败: %v", err) + logger.Fatalf("❌ 初始化加密服务失败: %v", err) } - database.SetCryptoService(cryptoService) - log.Printf("✅ 加密服务初始化成功") + // 创建加密/解密包装函数 + encryptFunc := func(plaintext string) string { + if plaintext == "" { + return plaintext + } + encrypted, err := cryptoService.EncryptForStorage(plaintext) + if err != nil { + logger.Warnf("⚠️ 加密失败: %v", err) + return plaintext + } + return encrypted + } + decryptFunc := func(encrypted string) string { + if encrypted == "" { + return encrypted + } + if !cryptoService.IsEncryptedStorageValue(encrypted) { + return encrypted + } + decrypted, err := cryptoService.DecryptFromStorage(encrypted) + if err != nil { + logger.Warnf("⚠️ 解密失败: %v", err) + return encrypted + } + return decrypted + } + st.SetCryptoFuncs(encryptFunc, decryptFunc) + logger.Info("✅ 加密服务初始化成功") // 同步config.json到数据库 - if err := syncConfigToDatabase(database, configFile); err != nil { - log.Printf("⚠️ 同步config.json到数据库失败: %v", err) + if err := syncConfigToDatabase(st, configFile); err != nil { + logger.Warnf("⚠️ 同步config.json到数据库失败: %v", err) } // 加载内测码到数据库 - if err := loadBetaCodesToDatabase(database); err != nil { - log.Printf("⚠️ 加载内测码到数据库失败: %v", err) + if err := loadBetaCodesToDatabase(st); err != nil { + logger.Warnf("⚠️ 加载内测码到数据库失败: %v", err) } // 获取系统配置 - useDefaultCoinsStr, _ := database.GetSystemConfig("use_default_coins") + useDefaultCoinsStr, _ := st.SystemConfig().Get("use_default_coins") useDefaultCoins := useDefaultCoinsStr == "true" - apiPortStr, _ := database.GetSystemConfig("api_server_port") + apiPortStr, _ := st.SystemConfig().Get("api_server_port") // 设置JWT密钥(优先使用环境变量) jwtSecret := strings.TrimSpace(os.Getenv("JWT_SECRET")) if jwtSecret == "" { // 回退到数据库配置 - jwtSecret, _ = database.GetSystemConfig("jwt_secret") + jwtSecret, _ = st.SystemConfig().Get("jwt_secret") if jwtSecret == "" { jwtSecret = "your-jwt-secret-key-change-in-production-make-it-long-and-random" - log.Printf("⚠️ 使用默认JWT密钥,建议使用加密设置脚本生成安全密钥") + logger.Warn("⚠️ 使用默认JWT密钥,建议使用加密设置脚本生成安全密钥") } else { - log.Printf("🔑 使用数据库中JWT密钥") + logger.Info("🔑 使用数据库中JWT密钥") } } else { - log.Printf("🔑 使用环境变量JWT密钥") + logger.Info("🔑 使用环境变量JWT密钥") } auth.SetJWTSecret(jwtSecret) // 管理员模式下需要管理员密码,缺失则退出 - log.Printf("✓ 配置数据库初始化成功") - fmt.Println() + logger.Info("✓ 配置数据库初始化成功") // 从数据库读取默认主流币种列表 - defaultCoinsJSON, _ := database.GetSystemConfig("default_coins") + defaultCoinsJSON, _ := st.SystemConfig().Get("default_coins") var defaultCoins []string if defaultCoinsJSON != "" { // 尝试从JSON解析 if err := json.Unmarshal([]byte(defaultCoinsJSON), &defaultCoins); err != nil { - log.Printf("⚠️ 解析default_coins配置失败: %v,使用硬编码默认值", err) + logger.Warnf("⚠️ 解析default_coins配置失败: %v,使用硬编码默认值", err) defaultCoins = []string{"BTCUSDT", "ETHUSDT", "SOLUSDT", "BNBUSDT", "XRPUSDT", "DOGEUSDT", "ADAUSDT", "HYPEUSDT"} } else { - log.Printf("✓ 从数据库加载默认币种列表(共%d个): %v", len(defaultCoins), defaultCoins) + logger.Infof("✓ 从数据库加载默认币种列表(共%d个): %v", len(defaultCoins), defaultCoins) } } else { // 如果数据库中没有配置,使用硬编码默认值 defaultCoins = []string{"BTCUSDT", "ETHUSDT", "SOLUSDT", "BNBUSDT", "XRPUSDT", "DOGEUSDT", "ADAUSDT", "HYPEUSDT"} - log.Printf("⚠️ 数据库中未配置default_coins,使用硬编码默认值") + logger.Warn("⚠️ 数据库中未配置default_coins,使用硬编码默认值") } pool.SetDefaultCoins(defaultCoins) // 设置是否使用默认主流币种 pool.SetUseDefaultCoins(useDefaultCoins) if useDefaultCoins { - log.Printf("✓ 已启用默认主流币种列表") + logger.Info("✓ 已启用默认主流币种列表") } // 设置币种池API URL - coinPoolAPIURL, _ := database.GetSystemConfig("coin_pool_api_url") + coinPoolAPIURL, _ := st.SystemConfig().Get("coin_pool_api_url") if coinPoolAPIURL != "" { pool.SetCoinPoolAPI(coinPoolAPIURL) - log.Printf("✓ 已配置AI500币种池API") + logger.Info("✓ 已配置AI500币种池API") } - oiTopAPIURL, _ := database.GetSystemConfig("oi_top_api_url") + oiTopAPIURL, _ := st.SystemConfig().Get("oi_top_api_url") if oiTopAPIURL != "" { pool.SetOITopAPI(oiTopAPIURL) - log.Printf("✓ 已配置OI Top API") + logger.Info("✓ 已配置OI Top API") } // 创建TraderManager 与 BacktestManager cfgForAI, cfgErr := config.LoadConfig("config.json") if cfgErr != nil { - log.Printf("⚠️ 加载config.json用于AI客户端失败: %v", cfgErr) + logger.Warnf("⚠️ 加载config.json用于AI客户端失败: %v", cfgErr) } traderManager := manager.NewTraderManager() mcpClient := newSharedMCPClient(cfgForAI) backtestManager := backtest.NewManager(mcpClient) if err := backtestManager.RestoreRuns(); err != nil { - log.Printf("⚠️ 恢复历史回测失败: %v", err) + logger.Warnf("⚠️ 恢复历史回测失败: %v", err) } // 从数据库加载所有交易员到内存 - err = traderManager.LoadTradersFromDatabase(database) + err = traderManager.LoadTradersFromStore(st) if err != nil { - log.Fatalf("❌ 加载交易员失败: %v", err) + logger.Fatalf("❌ 加载交易员失败: %v", err) } // 获取数据库中的所有交易员配置(用于显示,使用default用户) - traders, err := database.GetTraders("default") + traders, err := st.Trader().List("default") if err != nil { - log.Fatalf("❌ 获取交易员列表失败: %v", err) + logger.Fatalf("❌ 获取交易员列表失败: %v", err) } // 显示加载的交易员信息 - fmt.Println() - fmt.Println("🤖 数据库中的AI交易员配置:") + logger.Info("🤖 数据库中的AI交易员配置:") if len(traders) == 0 { - fmt.Println(" • 暂无配置的交易员,请通过Web界面创建") + logger.Info(" • 暂无配置的交易员,请通过Web界面创建") } else { for _, trader := range traders { status := "停止" if trader.IsRunning { status = "运行中" } - fmt.Printf(" • %s (%s + %s) - 初始资金: %.0f USDT [%s]\n", + logger.Infof(" • %s (%s + %s) - 初始资金: %.0f USDT [%s]", trader.Name, strings.ToUpper(trader.AIModelID), strings.ToUpper(trader.ExchangeID), trader.InitialBalance, status) } } - // 创建初始化上下文 - // TODO : 传入实际配置, 现在并未实际使用,未来所有模块初始化都将通过上下文传递配置 - // ctx := bootstrap.NewContext(&config.Config{}) - - // // 执行所有初始化钩子 - // if err := bootstrap.Run(ctx); err != nil { - // log.Fatalf("初始化失败: %v", err) - // } - - fmt.Println() - fmt.Println("🤖 AI全权决策模式:") - fmt.Printf(" • AI将自主决定每笔交易的杠杆倍数(山寨币最高5倍,BTC/ETH最高5倍)\n") - fmt.Println(" • AI将自主决定每笔交易的仓位大小") - fmt.Println(" • AI将自主设置止损和止盈价格") - fmt.Println(" • AI将基于市场数据、技术指标、账户状态做出全面分析") - fmt.Println() - fmt.Println("⚠️ 风险提示: AI自动交易有风险,建议小额资金测试!") - fmt.Println() - fmt.Println("按 Ctrl+C 停止运行") - fmt.Println(strings.Repeat("=", 60)) - fmt.Println() + logger.Info("🤖 AI全权决策模式:") + logger.Info(" • AI将自主决定每笔交易的杠杆倍数(山寨币最高5倍,BTC/ETH最高5倍)") + logger.Info(" • AI将自主决定每笔交易的仓位大小") + logger.Info(" • AI将自主设置止损和止盈价格") + logger.Info(" • AI将基于市场数据、技术指标、账户状态做出全面分析") + logger.Warn("⚠️ 风险提示: AI自动交易有风险,建议小额资金测试!") + logger.Info("按 Ctrl+C 停止运行") + logger.Info(strings.Repeat("=", 60)) // 获取API服务器端口(优先级:环境变量 > 数据库配置 > 默认值) apiPort := 8080 // 默认端口 @@ -336,30 +352,38 @@ func main() { if envPort := strings.TrimSpace(os.Getenv("NOFX_BACKEND_PORT")); envPort != "" { if port, err := strconv.Atoi(envPort); err == nil && port > 0 { apiPort = port - log.Printf("🔌 使用环境变量端口: %d (NOFX_BACKEND_PORT)", apiPort) + logger.Infof("🔌 使用环境变量端口: %d (NOFX_BACKEND_PORT)", apiPort) } else { - log.Printf("⚠️ 环境变量 NOFX_BACKEND_PORT 无效: %s", envPort) + logger.Warnf("⚠️ 环境变量 NOFX_BACKEND_PORT 无效: %s", envPort) } } else if apiPortStr != "" { // 2. 从数据库配置读取(config.json 同步过来的) if port, err := strconv.Atoi(apiPortStr); err == nil && port > 0 { apiPort = port - log.Printf("🔌 使用数据库配置端口: %d (api_server_port)", apiPort) + logger.Infof("🔌 使用数据库配置端口: %d (api_server_port)", apiPort) } } else { - log.Printf("🔌 使用默认端口: %d", apiPort) + logger.Infof("🔌 使用默认端口: %d", apiPort) } + // 启动订单同步管理器 + orderSyncManager := trader.NewOrderSyncManager(st, 10*time.Second) + orderSyncManager.Start() + + // 启动仓位同步管理器(检测手动平仓等变化) + positionSyncManager := trader.NewPositionSyncManager(st, 10*time.Second) + positionSyncManager.Start() + // 创建并启动API服务器 - apiServer := api.NewServer(traderManager, database, cryptoService, backtestManager, apiPort) + apiServer := api.NewServer(traderManager, st, cryptoService, backtestManager, apiPort) go func() { if err := apiServer.Start(); err != nil { - log.Printf("❌ API服务器错误: %v", err) + logger.Errorf("❌ API服务器错误: %v", err) } }() // 启动流行情数据 - 默认使用所有交易员设置的币种 如果没有设置币种 则优先使用系统默认 - go market.NewWSMonitor(150).Start(database.GetCustomCoins()) + go market.NewWSMonitor(150).Start(st.Trader().GetCustomCoins()) //go market.NewWSMonitor(150).Start([]string{}) //这里是一个使用方式 传入空的话 则使用market市场的所有币种 // 设置优雅退出 sigChan := make(chan os.Signal, 1) @@ -370,33 +394,36 @@ func main() { // 等待退出信号 <-sigChan - fmt.Println() - fmt.Println() - log.Println("📛 收到退出信号,正在优雅关闭...") + logger.Info("📛 收到退出信号,正在优雅关闭...") // 步骤 1: 停止所有交易员 - log.Println("⏸️ 停止所有交易员...") + logger.Info("⏸️ 停止所有交易员...") traderManager.StopAll() - log.Println("✅ 所有交易员已停止") + logger.Info("✅ 所有交易员已停止") - // 步骤 2: 关闭 API 服务器 - log.Println("🛑 停止 API 服务器...") + // 步骤 2: 停止订单同步管理器和仓位同步管理器 + logger.Info("📦 停止订单同步管理器...") + orderSyncManager.Stop() + logger.Info("📊 停止仓位同步管理器...") + positionSyncManager.Stop() + + // 步骤 3: 关闭 API 服务器 + logger.Info("🛑 停止 API 服务器...") if err := apiServer.Shutdown(); err != nil { - log.Printf("⚠️ 关闭 API 服务器时出错: %v", err) + logger.Warnf("⚠️ 关闭 API 服务器时出错: %v", err) } else { - log.Println("✅ API 服务器已安全关闭") + logger.Info("✅ API 服务器已安全关闭") } - // 步骤 3: 关闭数据库连接 (确保所有写入完成) - log.Println("💾 关闭数据库连接...") - if err := database.Close(); err != nil { - log.Printf("❌ 关闭数据库失败: %v", err) + // 步骤 4: 关闭数据库连接 (确保所有写入完成) + logger.Info("💾 关闭数据库连接...") + if err := st.Close(); err != nil { + logger.Errorf("❌ 关闭数据库失败: %v", err) } else { - log.Println("✅ 数据库已安全关闭,所有数据已持久化") + logger.Info("✅ 数据库已安全关闭,所有数据已持久化") } - fmt.Println() - fmt.Println("👋 感谢使用AI交易系统!") + logger.Info("👋 感谢使用AI交易系统!") } func newSharedMCPClient(cfg *config.Config) mcp.AIClient { diff --git a/manager/trader_manager.go b/manager/trader_manager.go index d6c93f61..ec510ba6 100644 --- a/manager/trader_manager.go +++ b/manager/trader_manager.go @@ -4,8 +4,8 @@ import ( "context" "encoding/json" "fmt" - "log" - "nofx/config" + "nofx/logger" + "nofx/store" "nofx/trader" "sort" "strconv" @@ -38,371 +38,6 @@ func NewTraderManager() *TraderManager { } } -// LoadTradersFromDatabase 从数据库加载所有交易员到内存 -func (tm *TraderManager) LoadTradersFromDatabase(database *config.Database) error { - tm.mu.Lock() - defer tm.mu.Unlock() - - // 获取所有用户 - userIDs, err := database.GetAllUsers() - if err != nil { - return fmt.Errorf("获取用户列表失败: %w", err) - } - - log.Printf("📋 发现 %d 个用户,开始加载所有交易员配置...", len(userIDs)) - - var allTraders []*config.TraderRecord - for _, userID := range userIDs { - // 获取每个用户的交易员 - traders, err := database.GetTraders(userID) - if err != nil { - log.Printf("⚠️ 获取用户 %s 的交易员失败: %v", userID, err) - continue - } - log.Printf("📋 用户 %s: %d 个交易员", userID, len(traders)) - allTraders = append(allTraders, traders...) - } - - log.Printf("📋 总共加载 %d 个交易员配置", len(allTraders)) - - // 获取系统配置(不包含信号源,信号源现在为用户级别) - maxDailyLossStr, _ := database.GetSystemConfig("max_daily_loss") - maxDrawdownStr, _ := database.GetSystemConfig("max_drawdown") - stopTradingMinutesStr, _ := database.GetSystemConfig("stop_trading_minutes") - defaultCoinsStr, _ := database.GetSystemConfig("default_coins") - - // 解析配置 - maxDailyLoss := 10.0 // 默认值 - if val, err := strconv.ParseFloat(maxDailyLossStr, 64); err == nil { - maxDailyLoss = val - } - - maxDrawdown := 20.0 // 默认值 - if val, err := strconv.ParseFloat(maxDrawdownStr, 64); err == nil { - maxDrawdown = val - } - - stopTradingMinutes := 60 // 默认值 - if val, err := strconv.Atoi(stopTradingMinutesStr); err == nil { - stopTradingMinutes = val - } - - // 解析默认币种列表 - var defaultCoins []string - if defaultCoinsStr != "" { - if err := json.Unmarshal([]byte(defaultCoinsStr), &defaultCoins); err != nil { - log.Printf("⚠️ 解析默认币种配置失败: %v,使用空列表", err) - defaultCoins = []string{} - } - } - - // 为每个交易员获取AI模型和交易所配置 - for _, traderCfg := range allTraders { - // 获取AI模型配置(使用交易员所属的用户ID) - aiModels, err := database.GetAIModels(traderCfg.UserID) - if err != nil { - log.Printf("⚠️ 获取AI模型配置失败: %v", err) - continue - } - - var aiModelCfg *config.AIModelConfig - // 优先精确匹配 model.ID(新版逻辑) - for _, model := range aiModels { - if model.ID == traderCfg.AIModelID { - aiModelCfg = model - break - } - } - // 如果没有精确匹配,尝试匹配 provider(兼容旧数据) - if aiModelCfg == nil { - for _, model := range aiModels { - if model.Provider == traderCfg.AIModelID { - aiModelCfg = model - log.Printf("⚠️ 交易员 %s 使用旧版 provider 匹配: %s -> %s", traderCfg.Name, traderCfg.AIModelID, model.ID) - break - } - } - } - - if aiModelCfg == nil { - log.Printf("⚠️ 交易员 %s 的AI模型 %s 不存在,跳过", traderCfg.Name, traderCfg.AIModelID) - continue - } - - if !aiModelCfg.Enabled { - log.Printf("⚠️ 交易员 %s 的AI模型 %s 未启用,跳过", traderCfg.Name, traderCfg.AIModelID) - continue - } - - // 获取交易所配置(使用交易员所属的用户ID) - exchanges, err := database.GetExchanges(traderCfg.UserID) - if err != nil { - log.Printf("⚠️ 获取交易所配置失败: %v", err) - continue - } - - var exchangeCfg *config.ExchangeConfig - for _, exchange := range exchanges { - if exchange.ID == traderCfg.ExchangeID { - exchangeCfg = exchange - break - } - } - - if exchangeCfg == nil { - log.Printf("⚠️ 交易员 %s 的交易所 %s 不存在,跳过", traderCfg.Name, traderCfg.ExchangeID) - continue - } - - if !exchangeCfg.Enabled { - log.Printf("⚠️ 交易员 %s 的交易所 %s 未启用,跳过", traderCfg.Name, traderCfg.ExchangeID) - continue - } - - // 获取用户信号源配置 - var coinPoolURL, oiTopURL string - if userSignalSource, err := database.GetUserSignalSource(traderCfg.UserID); err == nil { - coinPoolURL = userSignalSource.CoinPoolURL - oiTopURL = userSignalSource.OITopURL - } else { - // 如果用户没有配置信号源,使用空字符串 - log.Printf("🔍 用户 %s 暂未配置信号源", traderCfg.UserID) - } - - // 添加到TraderManager - err = tm.addTraderFromDB(traderCfg, aiModelCfg, exchangeCfg, coinPoolURL, oiTopURL, maxDailyLoss, maxDrawdown, stopTradingMinutes, defaultCoins, database, traderCfg.UserID) - if err != nil { - log.Printf("❌ 添加交易员 %s 失败: %v", traderCfg.Name, err) - continue - } - } - - log.Printf("✓ 成功加载 %d 个交易员到内存", len(tm.traders)) - return nil -} - -// addTraderFromConfig 内部方法:从配置添加交易员(不加锁,因为调用方已加锁) -func (tm *TraderManager) addTraderFromDB(traderCfg *config.TraderRecord, aiModelCfg *config.AIModelConfig, exchangeCfg *config.ExchangeConfig, coinPoolURL, oiTopURL string, maxDailyLoss, maxDrawdown float64, stopTradingMinutes int, defaultCoins []string, database *config.Database, userID string) error { - if _, exists := tm.traders[traderCfg.ID]; exists { - return fmt.Errorf("trader ID '%s' 已存在", traderCfg.ID) - } - - // 处理交易币种列表 - var tradingCoins []string - if traderCfg.TradingSymbols != "" { - // 解析逗号分隔的交易币种列表 - symbols := strings.Split(traderCfg.TradingSymbols, ",") - for _, symbol := range symbols { - symbol = strings.TrimSpace(symbol) - if symbol != "" { - tradingCoins = append(tradingCoins, symbol) - } - } - } - - // 如果没有指定交易币种,使用默认币种 - if len(tradingCoins) == 0 { - tradingCoins = defaultCoins - } - - // 根据交易员配置决定是否使用信号源 - var effectiveCoinPoolURL string - if traderCfg.UseCoinPool && coinPoolURL != "" { - effectiveCoinPoolURL = coinPoolURL - log.Printf("✓ 交易员 %s 启用 COIN POOL 信号源: %s", traderCfg.Name, coinPoolURL) - } - - // 构建AutoTraderConfig - traderConfig := trader.AutoTraderConfig{ - ID: traderCfg.ID, - Name: traderCfg.Name, - AIModel: aiModelCfg.Provider, // 使用provider作为模型标识 - Exchange: exchangeCfg.ID, // 使用exchange ID - BinanceAPIKey: "", - BinanceSecretKey: "", - HyperliquidPrivateKey: "", - HyperliquidTestnet: exchangeCfg.Testnet, - CoinPoolAPIURL: effectiveCoinPoolURL, - UseQwen: aiModelCfg.Provider == "qwen", - DeepSeekKey: "", - QwenKey: "", - CustomAPIURL: aiModelCfg.CustomAPIURL, // 自定义API URL - CustomModelName: aiModelCfg.CustomModelName, // 自定义模型名称 - ScanInterval: time.Duration(traderCfg.ScanIntervalMinutes) * time.Minute, - InitialBalance: traderCfg.InitialBalance, - BTCETHLeverage: traderCfg.BTCETHLeverage, - AltcoinLeverage: traderCfg.AltcoinLeverage, - MaxDailyLoss: maxDailyLoss, - MaxDrawdown: maxDrawdown, - StopTradingTime: time.Duration(stopTradingMinutes) * time.Minute, - IsCrossMargin: traderCfg.IsCrossMargin, - DefaultCoins: defaultCoins, - TradingCoins: tradingCoins, - SystemPromptTemplate: traderCfg.SystemPromptTemplate, // 系统提示词模板 - } - - // 根据交易所类型设置API密钥 - if exchangeCfg.ID == "binance" { - traderConfig.BinanceAPIKey = exchangeCfg.APIKey - traderConfig.BinanceSecretKey = exchangeCfg.SecretKey - } else if exchangeCfg.ID == "bybit" { - traderConfig.BybitAPIKey = exchangeCfg.APIKey - traderConfig.BybitSecretKey = exchangeCfg.SecretKey - } else if exchangeCfg.ID == "hyperliquid" { - traderConfig.HyperliquidPrivateKey = exchangeCfg.APIKey // hyperliquid用APIKey存储private key - traderConfig.HyperliquidWalletAddr = exchangeCfg.HyperliquidWalletAddr - } else if exchangeCfg.ID == "aster" { - traderConfig.AsterUser = exchangeCfg.AsterUser - traderConfig.AsterSigner = exchangeCfg.AsterSigner - traderConfig.AsterPrivateKey = exchangeCfg.AsterPrivateKey - } else if exchangeCfg.ID == "lighter" { - traderConfig.LighterPrivateKey = exchangeCfg.LighterPrivateKey - traderConfig.LighterWalletAddr = exchangeCfg.LighterWalletAddr - traderConfig.LighterTestnet = exchangeCfg.Testnet - } - - // 根据AI模型设置API密钥 - if aiModelCfg.Provider == "qwen" { - traderConfig.QwenKey = aiModelCfg.APIKey - } else if aiModelCfg.Provider == "deepseek" { - traderConfig.DeepSeekKey = aiModelCfg.APIKey - } - - // 创建trader实例 - at, err := trader.NewAutoTrader(traderConfig, database, userID) - if err != nil { - return fmt.Errorf("创建trader失败: %w", err) - } - - // 设置自定义prompt(如果有) - if traderCfg.CustomPrompt != "" { - at.SetCustomPrompt(traderCfg.CustomPrompt) - at.SetOverrideBasePrompt(traderCfg.OverrideBasePrompt) - if traderCfg.OverrideBasePrompt { - log.Printf("✓ 已设置自定义交易策略prompt (覆盖基础prompt)") - } else { - log.Printf("✓ 已设置自定义交易策略prompt (补充基础prompt)") - } - } - - tm.traders[traderCfg.ID] = at - log.Printf("✓ Trader '%s' (%s + %s) 已加载到内存", traderCfg.Name, aiModelCfg.Provider, exchangeCfg.ID) - return nil -} - -// AddTrader 从数据库配置添加trader (移除旧版兼容性) - -// AddTraderFromDB 从数据库配置添加trader -func (tm *TraderManager) AddTraderFromDB(traderCfg *config.TraderRecord, aiModelCfg *config.AIModelConfig, exchangeCfg *config.ExchangeConfig, coinPoolURL, oiTopURL string, maxDailyLoss, maxDrawdown float64, stopTradingMinutes int, defaultCoins []string, database *config.Database, userID string) error { - tm.mu.Lock() - defer tm.mu.Unlock() - - if _, exists := tm.traders[traderCfg.ID]; exists { - return fmt.Errorf("trader ID '%s' 已存在", traderCfg.ID) - } - - // 处理交易币种列表 - var tradingCoins []string - if traderCfg.TradingSymbols != "" { - // 解析逗号分隔的交易币种列表 - symbols := strings.Split(traderCfg.TradingSymbols, ",") - for _, symbol := range symbols { - symbol = strings.TrimSpace(symbol) - if symbol != "" { - tradingCoins = append(tradingCoins, symbol) - } - } - } - - // 如果没有指定交易币种,使用默认币种 - if len(tradingCoins) == 0 { - tradingCoins = defaultCoins - } - - // 根据交易员配置决定是否使用信号源 - var effectiveCoinPoolURL string - if traderCfg.UseCoinPool && coinPoolURL != "" { - effectiveCoinPoolURL = coinPoolURL - log.Printf("✓ 交易员 %s 启用 COIN POOL 信号源: %s", traderCfg.Name, coinPoolURL) - } - - // 构建AutoTraderConfig - traderConfig := trader.AutoTraderConfig{ - ID: traderCfg.ID, - Name: traderCfg.Name, - AIModel: aiModelCfg.Provider, // 使用provider作为模型标识 - Exchange: exchangeCfg.ID, // 使用exchange ID - BinanceAPIKey: "", - BinanceSecretKey: "", - HyperliquidPrivateKey: "", - HyperliquidTestnet: exchangeCfg.Testnet, - CoinPoolAPIURL: effectiveCoinPoolURL, - UseQwen: aiModelCfg.Provider == "qwen", - DeepSeekKey: "", - QwenKey: "", - CustomAPIURL: aiModelCfg.CustomAPIURL, // 自定义API URL - CustomModelName: aiModelCfg.CustomModelName, // 自定义模型名称 - ScanInterval: time.Duration(traderCfg.ScanIntervalMinutes) * time.Minute, - InitialBalance: traderCfg.InitialBalance, - BTCETHLeverage: traderCfg.BTCETHLeverage, - AltcoinLeverage: traderCfg.AltcoinLeverage, - MaxDailyLoss: maxDailyLoss, - MaxDrawdown: maxDrawdown, - StopTradingTime: time.Duration(stopTradingMinutes) * time.Minute, - IsCrossMargin: traderCfg.IsCrossMargin, - DefaultCoins: defaultCoins, - TradingCoins: tradingCoins, - } - - // 根据交易所类型设置API密钥 - if exchangeCfg.ID == "binance" { - traderConfig.BinanceAPIKey = exchangeCfg.APIKey - traderConfig.BinanceSecretKey = exchangeCfg.SecretKey - } else if exchangeCfg.ID == "bybit" { - traderConfig.BybitAPIKey = exchangeCfg.APIKey - traderConfig.BybitSecretKey = exchangeCfg.SecretKey - } else if exchangeCfg.ID == "hyperliquid" { - traderConfig.HyperliquidPrivateKey = exchangeCfg.APIKey // hyperliquid用APIKey存储private key - traderConfig.HyperliquidWalletAddr = exchangeCfg.HyperliquidWalletAddr - } else if exchangeCfg.ID == "aster" { - traderConfig.AsterUser = exchangeCfg.AsterUser - traderConfig.AsterSigner = exchangeCfg.AsterSigner - traderConfig.AsterPrivateKey = exchangeCfg.AsterPrivateKey - } else if exchangeCfg.ID == "lighter" { - traderConfig.LighterPrivateKey = exchangeCfg.LighterPrivateKey - traderConfig.LighterWalletAddr = exchangeCfg.LighterWalletAddr - traderConfig.LighterTestnet = exchangeCfg.Testnet - } - - // 根据AI模型设置API密钥 - if aiModelCfg.Provider == "qwen" { - traderConfig.QwenKey = aiModelCfg.APIKey - } else if aiModelCfg.Provider == "deepseek" { - traderConfig.DeepSeekKey = aiModelCfg.APIKey - } - - // 创建trader实例 - at, err := trader.NewAutoTrader(traderConfig, database, userID) - if err != nil { - return fmt.Errorf("创建trader失败: %w", err) - } - - // 设置自定义prompt(如果有) - if traderCfg.CustomPrompt != "" { - at.SetCustomPrompt(traderCfg.CustomPrompt) - at.SetOverrideBasePrompt(traderCfg.OverrideBasePrompt) - if traderCfg.OverrideBasePrompt { - log.Printf("✓ 已设置自定义交易策略prompt (覆盖基础prompt)") - } else { - log.Printf("✓ 已设置自定义交易策略prompt (补充基础prompt)") - } - } - - tm.traders[traderCfg.ID] = at - log.Printf("✓ Trader '%s' (%s + %s) 已添加", traderCfg.Name, aiModelCfg.Provider, exchangeCfg.ID) - return nil -} - // GetTrader 获取指定ID的trader func (tm *TraderManager) GetTrader(id string) (*trader.AutoTrader, error) { tm.mu.RLock() @@ -444,12 +79,12 @@ func (tm *TraderManager) StartAll() { tm.mu.RLock() defer tm.mu.RUnlock() - log.Println("🚀 启动所有Trader...") + logger.Info("🚀 启动所有Trader...") for id, t := range tm.traders { go func(traderID string, at *trader.AutoTrader) { - log.Printf("▶️ 启动 %s...", at.GetName()) + logger.Infof("▶️ 启动 %s...", at.GetName()) if err := at.Run(); err != nil { - log.Printf("❌ %s 运行错误: %v", at.GetName(), err) + logger.Infof("❌ %s 运行错误: %v", at.GetName(), err) } }(id, t) } @@ -460,7 +95,7 @@ func (tm *TraderManager) StopAll() { tm.mu.RLock() defer tm.mu.RUnlock() - log.Println("⏹ 停止所有Trader...") + logger.Info("⏹ 停止所有Trader...") for _, t := range tm.traders { t.Stop() } @@ -514,7 +149,7 @@ func (tm *TraderManager) GetCompetitionData() (map[string]interface{}, error) { cachedData[k] = v } tm.competitionCache.mu.RUnlock() - log.Printf("📋 返回竞赛数据缓存 (缓存时间: %.1fs)", time.Since(tm.competitionCache.timestamp).Seconds()) + logger.Infof("📋 返回竞赛数据缓存 (缓存时间: %.1fs)", time.Since(tm.competitionCache.timestamp).Seconds()) return cachedData, nil } tm.competitionCache.mu.RUnlock() @@ -528,7 +163,7 @@ func (tm *TraderManager) GetCompetitionData() (map[string]interface{}, error) { } tm.mu.RUnlock() - log.Printf("🔄 重新获取竞赛数据,交易员数量: %d", len(allTraders)) + logger.Infof("🔄 重新获取竞赛数据,交易员数量: %d", len(allTraders)) // 并发获取交易员数据 traders := tm.getConcurrentTraderData(allTraders) @@ -618,7 +253,7 @@ func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) [ } case err := <-errorChan: // 获取账户信息失败 - log.Printf("⚠️ 获取交易员 %s 账户信息失败: %v", trader.GetID(), err) + logger.Infof("⚠️ 获取交易员 %s 账户信息失败: %v", trader.GetID(), err) traderData = map[string]interface{}{ "trader_id": trader.GetID(), "trader_name": trader.GetName(), @@ -635,7 +270,7 @@ func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) [ } case <-ctx.Done(): // 超时 - log.Printf("⏰ 获取交易员 %s 账户信息超时", trader.GetID()) + logger.Infof("⏰ 获取交易员 %s 账户信息超时", trader.GetID()) traderData = map[string]interface{}{ "trader_id": trader.GetID(), "trader_name": trader.GetName(), @@ -695,63 +330,46 @@ func (tm *TraderManager) GetTopTradersData() (map[string]interface{}, error) { return result, nil } -// isUserTrader 检查trader是否属于指定用户 -func isUserTrader(traderID, userID string) bool { - // trader ID格式: userID_traderName 或 randomUUID_modelName - // 为了兼容性,我们检查前缀 - if len(traderID) >= len(userID) && traderID[:len(userID)] == userID { - return true + +// RemoveTrader 从内存中移除指定的trader(不影响数据库) +// 用于更新trader配置时强制重新加载 +func (tm *TraderManager) RemoveTrader(traderID string) { + tm.mu.Lock() + defer tm.mu.Unlock() + + if _, exists := tm.traders[traderID]; exists { + delete(tm.traders, traderID) + logger.Infof("✓ Trader %s 已从内存中移除", traderID) } - // 对于老的default用户,所有没有明确用户前缀的都属于default - if userID == "default" && !containsUserPrefix(traderID) { - return true - } - return false } -// containsUserPrefix 检查trader ID是否包含用户前缀 -func containsUserPrefix(traderID string) bool { - // 检查是否包含邮箱格式的前缀(user@example.com_traderName) - for i, ch := range traderID { - if ch == '@' { - // 找到@符号,说明可能是email前缀 - return true - } - if ch == '_' && i > 0 { - // 找到下划线但前面没有@,可能是UUID或其他格式 - break - } - } - return false -} - -// LoadUserTraders 为特定用户加载交易员到内存 -func (tm *TraderManager) LoadUserTraders(database *config.Database, userID string) error { +// LoadUserTradersFromStore 为特定用户从store加载交易员到内存 +func (tm *TraderManager) LoadUserTradersFromStore(st *store.Store, userID string) error { tm.mu.Lock() defer tm.mu.Unlock() // 获取指定用户的所有交易员 - traders, err := database.GetTraders(userID) + traders, err := st.Trader().List(userID) if err != nil { return fmt.Errorf("获取用户 %s 的交易员列表失败: %w", userID, err) } - log.Printf("📋 为用户 %s 加载交易员配置: %d 个", userID, len(traders)) + logger.Infof("📋 为用户 %s 加载交易员配置: %d 个", userID, len(traders)) - // 获取系统配置(不包含信号源,信号源现在为用户级别) - maxDailyLossStr, _ := database.GetSystemConfig("max_daily_loss") - maxDrawdownStr, _ := database.GetSystemConfig("max_drawdown") - stopTradingMinutesStr, _ := database.GetSystemConfig("stop_trading_minutes") - defaultCoinsStr, _ := database.GetSystemConfig("default_coins") + // 获取系统配置 + maxDailyLossStr, _ := st.SystemConfig().Get("max_daily_loss") + maxDrawdownStr, _ := st.SystemConfig().Get("max_drawdown") + stopTradingMinutesStr, _ := st.SystemConfig().Get("stop_trading_minutes") + defaultCoinsStr, _ := st.SystemConfig().Get("default_coins") // 获取用户信号源配置 var coinPoolURL, oiTopURL string - if userSignalSource, err := database.GetUserSignalSource(userID); err == nil { - coinPoolURL = userSignalSource.CoinPoolURL - oiTopURL = userSignalSource.OITopURL - log.Printf("📡 加载用户 %s 的信号源配置: COIN POOL=%s, OI TOP=%s", userID, coinPoolURL, oiTopURL) + if signalSource, err := st.SignalSource().Get(userID); err == nil { + coinPoolURL = signalSource.CoinPoolURL + oiTopURL = signalSource.OITopURL + logger.Infof("📡 加载用户 %s 的信号源配置: COIN POOL=%s, OI TOP=%s", userID, coinPoolURL, oiTopURL) } else { - log.Printf("🔍 用户 %s 暂未配置信号源", userID) + logger.Infof("🔍 用户 %s 暂未配置信号源", userID) } // 解析配置 @@ -774,22 +392,21 @@ func (tm *TraderManager) LoadUserTraders(database *config.Database, userID strin var defaultCoins []string if defaultCoinsStr != "" { if err := json.Unmarshal([]byte(defaultCoinsStr), &defaultCoins); err != nil { - log.Printf("⚠️ 解析默认币种配置失败: %v,使用空列表", err) + logger.Infof("⚠️ 解析默认币种配置失败: %v,使用空列表", err) defaultCoins = []string{} } } - // 🔧 性能优化:在循环外只查询一次AI模型和交易所配置 - // 避免在循环中重复查询相同的数据,减少数据库压力和锁持有时间 - aiModels, err := database.GetAIModels(userID) + // 获取AI模型和交易所列表(在循环外只查询一次) + aiModels, err := st.AIModel().List(userID) if err != nil { - log.Printf("⚠️ 获取用户 %s 的AI模型配置失败: %v", userID, err) + logger.Infof("⚠️ 获取用户 %s 的AI模型配置失败: %v", userID, err) return fmt.Errorf("获取AI模型配置失败: %w", err) } - exchanges, err := database.GetExchanges(userID) + exchanges, err := st.Exchange().List(userID) if err != nil { - log.Printf("⚠️ 获取用户 %s 的交易所配置失败: %v", userID, err) + logger.Infof("⚠️ 获取用户 %s 的交易所配置失败: %v", userID, err) return fmt.Errorf("获取交易所配置失败: %w", err) } @@ -797,43 +414,39 @@ func (tm *TraderManager) LoadUserTraders(database *config.Database, userID strin for _, traderCfg := range traders { // 检查是否已经加载过这个交易员 if _, exists := tm.traders[traderCfg.ID]; exists { - log.Printf("⚠️ 交易员 %s 已经加载,跳过", traderCfg.Name) + logger.Infof("⚠️ 交易员 %s 已经加载,跳过", traderCfg.Name) continue } // 从已查询的列表中查找AI模型配置 - - var aiModelCfg *config.AIModelConfig - // 优先精确匹配 model.ID(新版逻辑) + var aiModelCfg *store.AIModel for _, model := range aiModels { if model.ID == traderCfg.AIModelID { aiModelCfg = model break } } - // 如果没有精确匹配,尝试匹配 provider(兼容旧数据) if aiModelCfg == nil { for _, model := range aiModels { if model.Provider == traderCfg.AIModelID { aiModelCfg = model - log.Printf("⚠️ 交易员 %s 使用旧版 provider 匹配: %s -> %s", traderCfg.Name, traderCfg.AIModelID, model.ID) break } } } if aiModelCfg == nil { - log.Printf("⚠️ 交易员 %s 的AI模型 %s 不存在,跳过", traderCfg.Name, traderCfg.AIModelID) + logger.Infof("⚠️ 交易员 %s 的AI模型 %s 不存在,跳过", traderCfg.Name, traderCfg.AIModelID) continue } if !aiModelCfg.Enabled { - log.Printf("⚠️ 交易员 %s 的AI模型 %s 未启用,跳过", traderCfg.Name, traderCfg.AIModelID) + logger.Infof("⚠️ 交易员 %s 的AI模型 %s 未启用,跳过", traderCfg.Name, traderCfg.AIModelID) continue } // 从已查询的列表中查找交易所配置 - var exchangeCfg *config.ExchangeConfig + var exchangeCfg *store.Exchange for _, exchange := range exchanges { if exchange.ID == traderCfg.ExchangeID { exchangeCfg = exchange @@ -842,134 +455,59 @@ func (tm *TraderManager) LoadUserTraders(database *config.Database, userID strin } if exchangeCfg == nil { - log.Printf("⚠️ 交易员 %s 的交易所 %s 不存在,跳过", traderCfg.Name, traderCfg.ExchangeID) + logger.Infof("⚠️ 交易员 %s 的交易所 %s 不存在,跳过", traderCfg.Name, traderCfg.ExchangeID) continue } if !exchangeCfg.Enabled { - log.Printf("⚠️ 交易员 %s 的交易所 %s 未启用,跳过", traderCfg.Name, traderCfg.ExchangeID) + logger.Infof("⚠️ 交易员 %s 的交易所 %s 未启用,跳过", traderCfg.Name, traderCfg.ExchangeID) continue } // 使用现有的方法加载交易员 - err = tm.loadSingleTrader(traderCfg, aiModelCfg, exchangeCfg, coinPoolURL, oiTopURL, maxDailyLoss, maxDrawdown, stopTradingMinutes, defaultCoins, database, userID) + err = tm.addTraderFromStore(traderCfg, aiModelCfg, exchangeCfg, coinPoolURL, oiTopURL, maxDailyLoss, maxDrawdown, stopTradingMinutes, defaultCoins, st) if err != nil { - log.Printf("⚠️ 加载交易员 %s 失败: %v", traderCfg.Name, err) + logger.Infof("⚠️ 加载交易员 %s 失败: %v", traderCfg.Name, err) } } return nil } -// LoadTraderByID 加载指定ID的单个交易员到内存 -// 此方法会自动查询所需的所有配置(AI模型、交易所、系统配置等) -// 参数: -// - database: 数据库实例 -// - userID: 用户ID -// - traderID: 交易员ID -// -// 返回: -// - error: 如果交易员不存在、配置无效或加载失败则返回错误 -func (tm *TraderManager) LoadTraderByID(database *config.Database, userID, traderID string) error { +// LoadTradersFromStore 从store加载所有交易员到内存(新版API) +func (tm *TraderManager) LoadTradersFromStore(st *store.Store) error { tm.mu.Lock() defer tm.mu.Unlock() - // 1. 检查是否已加载 - if _, exists := tm.traders[traderID]; exists { - log.Printf("⚠️ 交易员 %s 已经加载,跳过", traderID) - return nil - } - - // 2. 查询交易员配置 - traders, err := database.GetTraders(userID) + // 获取所有用户 + userIDs, err := st.User().GetAllIDs() if err != nil { - return fmt.Errorf("获取交易员列表失败: %w", err) + return fmt.Errorf("获取用户列表失败: %w", err) } - var traderCfg *config.TraderRecord - for _, t := range traders { - if t.ID == traderID { - traderCfg = t - break + logger.Infof("📋 发现 %d 个用户,开始加载所有交易员配置...", len(userIDs)) + + var allTraders []*store.Trader + for _, userID := range userIDs { + // 获取每个用户的交易员 + traders, err := st.Trader().List(userID) + if err != nil { + logger.Infof("⚠️ 获取用户 %s 的交易员失败: %v", userID, err) + continue } + logger.Infof("📋 用户 %s: %d 个交易员", userID, len(traders)) + allTraders = append(allTraders, traders...) } - if traderCfg == nil { - return fmt.Errorf("交易员 %s 不存在", traderID) - } + logger.Infof("📋 总共加载 %d 个交易员配置", len(allTraders)) - // 3. 查询AI模型配置 - aiModels, err := database.GetAIModels(userID) - if err != nil { - return fmt.Errorf("获取AI模型配置失败: %w", err) - } + // 获取系统配置 + maxDailyLossStr, _ := st.SystemConfig().Get("max_daily_loss") + maxDrawdownStr, _ := st.SystemConfig().Get("max_drawdown") + stopTradingMinutesStr, _ := st.SystemConfig().Get("stop_trading_minutes") + defaultCoinsStr, _ := st.SystemConfig().Get("default_coins") - var aiModelCfg *config.AIModelConfig - // 优先精确匹配 model.ID - for _, model := range aiModels { - if model.ID == traderCfg.AIModelID { - aiModelCfg = model - break - } - } - // 如果没有精确匹配,尝试匹配 provider(兼容旧数据) - if aiModelCfg == nil { - for _, model := range aiModels { - if model.Provider == traderCfg.AIModelID { - aiModelCfg = model - log.Printf("⚠️ 交易员 %s 使用旧版 provider 匹配: %s -> %s", traderCfg.Name, traderCfg.AIModelID, model.ID) - break - } - } - } - - if aiModelCfg == nil { - return fmt.Errorf("AI模型 %s 不存在", traderCfg.AIModelID) - } - - if !aiModelCfg.Enabled { - return fmt.Errorf("AI模型 %s 未启用", traderCfg.AIModelID) - } - - // 4. 查询交易所配置 - exchanges, err := database.GetExchanges(userID) - if err != nil { - return fmt.Errorf("获取交易所配置失败: %w", err) - } - - var exchangeCfg *config.ExchangeConfig - for _, exchange := range exchanges { - if exchange.ID == traderCfg.ExchangeID { - exchangeCfg = exchange - break - } - } - - if exchangeCfg == nil { - return fmt.Errorf("交易所 %s 不存在", traderCfg.ExchangeID) - } - - if !exchangeCfg.Enabled { - return fmt.Errorf("交易所 %s 未启用", traderCfg.ExchangeID) - } - - // 5. 查询系统配置 - maxDailyLossStr, _ := database.GetSystemConfig("max_daily_loss") - maxDrawdownStr, _ := database.GetSystemConfig("max_drawdown") - stopTradingMinutesStr, _ := database.GetSystemConfig("stop_trading_minutes") - defaultCoinsStr, _ := database.GetSystemConfig("default_coins") - - // 6. 查询用户信号源配置 - var coinPoolURL, oiTopURL string - if userSignalSource, err := database.GetUserSignalSource(userID); err == nil { - coinPoolURL = userSignalSource.CoinPoolURL - oiTopURL = userSignalSource.OITopURL - log.Printf("📡 加载用户 %s 的信号源配置: COIN POOL=%s, OI TOP=%s", userID, coinPoolURL, oiTopURL) - } else { - log.Printf("🔍 用户 %s 暂未配置信号源", userID) - } - - // 7. 解析系统配置 + // 解析配置 maxDailyLoss := 10.0 // 默认值 if val, err := strconv.ParseFloat(maxDailyLossStr, 64); err == nil { maxDailyLoss = val @@ -989,34 +527,104 @@ func (tm *TraderManager) LoadTraderByID(database *config.Database, userID, trade var defaultCoins []string if defaultCoinsStr != "" { if err := json.Unmarshal([]byte(defaultCoinsStr), &defaultCoins); err != nil { - log.Printf("⚠️ 解析默认币种配置失败: %v,使用空列表", err) + logger.Infof("⚠️ 解析默认币种配置失败: %v,使用空列表", err) defaultCoins = []string{} } } - // 8. 调用私有方法加载交易员 - log.Printf("📋 加载单个交易员: %s (%s)", traderCfg.Name, traderID) - return tm.loadSingleTrader( - traderCfg, - aiModelCfg, - exchangeCfg, - coinPoolURL, - oiTopURL, - maxDailyLoss, - maxDrawdown, - stopTradingMinutes, - defaultCoins, - database, - userID, - ) + // 为每个交易员获取AI模型和交易所配置 + for _, traderCfg := range allTraders { + // 获取AI模型配置 + aiModels, err := st.AIModel().List(traderCfg.UserID) + if err != nil { + logger.Infof("⚠️ 获取AI模型配置失败: %v", err) + continue + } + + var aiModelCfg *store.AIModel + // 优先精确匹配 model.ID + for _, model := range aiModels { + if model.ID == traderCfg.AIModelID { + aiModelCfg = model + break + } + } + // 如果没有精确匹配,尝试匹配 provider(兼容旧数据) + if aiModelCfg == nil { + for _, model := range aiModels { + if model.Provider == traderCfg.AIModelID { + aiModelCfg = model + logger.Infof("⚠️ 交易员 %s 使用旧版 provider 匹配: %s -> %s", traderCfg.Name, traderCfg.AIModelID, model.ID) + break + } + } + } + + if aiModelCfg == nil { + logger.Infof("⚠️ 交易员 %s 的AI模型 %s 不存在,跳过", traderCfg.Name, traderCfg.AIModelID) + continue + } + + if !aiModelCfg.Enabled { + logger.Infof("⚠️ 交易员 %s 的AI模型 %s 未启用,跳过", traderCfg.Name, traderCfg.AIModelID) + continue + } + + // 获取交易所配置 + exchanges, err := st.Exchange().List(traderCfg.UserID) + if err != nil { + logger.Infof("⚠️ 获取交易所配置失败: %v", err) + continue + } + + var exchangeCfg *store.Exchange + for _, exchange := range exchanges { + if exchange.ID == traderCfg.ExchangeID { + exchangeCfg = exchange + break + } + } + + if exchangeCfg == nil { + logger.Infof("⚠️ 交易员 %s 的交易所 %s 不存在,跳过", traderCfg.Name, traderCfg.ExchangeID) + continue + } + + if !exchangeCfg.Enabled { + logger.Infof("⚠️ 交易员 %s 的交易所 %s 未启用,跳过", traderCfg.Name, traderCfg.ExchangeID) + continue + } + + // 获取用户信号源配置 + var coinPoolURL, oiTopURL string + if signalSource, err := st.SignalSource().Get(traderCfg.UserID); err == nil { + coinPoolURL = signalSource.CoinPoolURL + oiTopURL = signalSource.OITopURL + } else { + logger.Infof("🔍 用户 %s 暂未配置信号源", traderCfg.UserID) + } + + // 添加到TraderManager + err = tm.addTraderFromStore(traderCfg, aiModelCfg, exchangeCfg, coinPoolURL, oiTopURL, maxDailyLoss, maxDrawdown, stopTradingMinutes, defaultCoins, st) + if err != nil { + logger.Infof("❌ 添加交易员 %s 失败: %v", traderCfg.Name, err) + continue + } + } + + logger.Infof("✓ 成功加载 %d 个交易员到内存", len(tm.traders)) + return nil } -// loadSingleTrader 加载单个交易员(从现有代码提取的公共逻辑) -func (tm *TraderManager) loadSingleTrader(traderCfg *config.TraderRecord, aiModelCfg *config.AIModelConfig, exchangeCfg *config.ExchangeConfig, coinPoolURL, oiTopURL string, maxDailyLoss, maxDrawdown float64, stopTradingMinutes int, defaultCoins []string, database *config.Database, userID string) error { +// addTraderFromStore 内部方法:从store配置添加交易员 +func (tm *TraderManager) addTraderFromStore(traderCfg *store.Trader, aiModelCfg *store.AIModel, exchangeCfg *store.Exchange, coinPoolURL, oiTopURL string, maxDailyLoss, maxDrawdown float64, stopTradingMinutes int, defaultCoins []string, st *store.Store) error { + if _, exists := tm.traders[traderCfg.ID]; exists { + return fmt.Errorf("trader ID '%s' 已存在", traderCfg.ID) + } + // 处理交易币种列表 var tradingCoins []string if traderCfg.TradingSymbols != "" { - // 解析逗号分隔的交易币种列表 symbols := strings.Split(traderCfg.TradingSymbols, ",") for _, symbol := range symbols { symbol = strings.TrimSpace(symbol) @@ -1035,48 +643,54 @@ func (tm *TraderManager) loadSingleTrader(traderCfg *config.TraderRecord, aiMode var effectiveCoinPoolURL string if traderCfg.UseCoinPool && coinPoolURL != "" { effectiveCoinPoolURL = coinPoolURL - log.Printf("✓ 交易员 %s 启用 COIN POOL 信号源: %s", traderCfg.Name, coinPoolURL) + logger.Infof("✓ 交易员 %s 启用 COIN POOL 信号源: %s", traderCfg.Name, coinPoolURL) } // 构建AutoTraderConfig traderConfig := trader.AutoTraderConfig{ - ID: traderCfg.ID, - Name: traderCfg.Name, - AIModel: aiModelCfg.Provider, // 使用provider作为模型标识 - Exchange: exchangeCfg.ID, // 使用exchange ID - InitialBalance: traderCfg.InitialBalance, - BTCETHLeverage: traderCfg.BTCETHLeverage, - AltcoinLeverage: traderCfg.AltcoinLeverage, - ScanInterval: time.Duration(traderCfg.ScanIntervalMinutes) * time.Minute, - CoinPoolAPIURL: effectiveCoinPoolURL, - CustomAPIURL: aiModelCfg.CustomAPIURL, // 自定义API URL - CustomModelName: aiModelCfg.CustomModelName, // 自定义模型名称 - UseQwen: aiModelCfg.Provider == "qwen", - MaxDailyLoss: maxDailyLoss, - MaxDrawdown: maxDrawdown, - StopTradingTime: time.Duration(stopTradingMinutes) * time.Minute, - IsCrossMargin: traderCfg.IsCrossMargin, - DefaultCoins: defaultCoins, - TradingCoins: tradingCoins, - SystemPromptTemplate: traderCfg.SystemPromptTemplate, // 系统提示词模板 - HyperliquidTestnet: exchangeCfg.Testnet, // Hyperliquid测试网 + ID: traderCfg.ID, + Name: traderCfg.Name, + AIModel: aiModelCfg.Provider, + Exchange: exchangeCfg.ID, + BinanceAPIKey: "", + BinanceSecretKey: "", + HyperliquidPrivateKey: "", + HyperliquidTestnet: exchangeCfg.Testnet, + CoinPoolAPIURL: effectiveCoinPoolURL, + UseQwen: aiModelCfg.Provider == "qwen", + DeepSeekKey: "", + QwenKey: "", + CustomAPIURL: aiModelCfg.CustomAPIURL, + CustomModelName: aiModelCfg.CustomModelName, + ScanInterval: time.Duration(traderCfg.ScanIntervalMinutes) * time.Minute, + InitialBalance: traderCfg.InitialBalance, + BTCETHLeverage: traderCfg.BTCETHLeverage, + AltcoinLeverage: traderCfg.AltcoinLeverage, + MaxDailyLoss: maxDailyLoss, + MaxDrawdown: maxDrawdown, + StopTradingTime: time.Duration(stopTradingMinutes) * time.Minute, + IsCrossMargin: traderCfg.IsCrossMargin, + DefaultCoins: defaultCoins, + TradingCoins: tradingCoins, + SystemPromptTemplate: traderCfg.SystemPromptTemplate, } // 根据交易所类型设置API密钥 - if exchangeCfg.ID == "binance" { + switch exchangeCfg.ID { + case "binance": traderConfig.BinanceAPIKey = exchangeCfg.APIKey traderConfig.BinanceSecretKey = exchangeCfg.SecretKey - } else if exchangeCfg.ID == "bybit" { + case "bybit": traderConfig.BybitAPIKey = exchangeCfg.APIKey traderConfig.BybitSecretKey = exchangeCfg.SecretKey - } else if exchangeCfg.ID == "hyperliquid" { - traderConfig.HyperliquidPrivateKey = exchangeCfg.APIKey // hyperliquid用APIKey存储private key + case "hyperliquid": + traderConfig.HyperliquidPrivateKey = exchangeCfg.APIKey traderConfig.HyperliquidWalletAddr = exchangeCfg.HyperliquidWalletAddr - } else if exchangeCfg.ID == "aster" { + case "aster": traderConfig.AsterUser = exchangeCfg.AsterUser traderConfig.AsterSigner = exchangeCfg.AsterSigner traderConfig.AsterPrivateKey = exchangeCfg.AsterPrivateKey - } else if exchangeCfg.ID == "lighter" { + case "lighter": traderConfig.LighterPrivateKey = exchangeCfg.LighterPrivateKey traderConfig.LighterWalletAddr = exchangeCfg.LighterWalletAddr traderConfig.LighterTestnet = exchangeCfg.Testnet @@ -1090,7 +704,7 @@ func (tm *TraderManager) loadSingleTrader(traderCfg *config.TraderRecord, aiMode } // 创建trader实例 - at, err := trader.NewAutoTrader(traderConfig, database, userID) + at, err := trader.NewAutoTrader(traderConfig, st, traderCfg.UserID) if err != nil { return fmt.Errorf("创建trader失败: %w", err) } @@ -1100,25 +714,13 @@ func (tm *TraderManager) loadSingleTrader(traderCfg *config.TraderRecord, aiMode at.SetCustomPrompt(traderCfg.CustomPrompt) at.SetOverrideBasePrompt(traderCfg.OverrideBasePrompt) if traderCfg.OverrideBasePrompt { - log.Printf("✓ 已设置自定义交易策略prompt (覆盖基础prompt)") + logger.Infof("✓ 已设置自定义交易策略prompt (覆盖基础prompt)") } else { - log.Printf("✓ 已设置自定义交易策略prompt (补充基础prompt)") + logger.Infof("✓ 已设置自定义交易策略prompt (补充基础prompt)") } } tm.traders[traderCfg.ID] = at - log.Printf("✓ Trader '%s' (%s + %s) 已为用户加载到内存", traderCfg.Name, aiModelCfg.Provider, exchangeCfg.ID) + logger.Infof("✓ Trader '%s' (%s + %s) 已加载到内存", traderCfg.Name, aiModelCfg.Provider, exchangeCfg.ID) return nil } - -// RemoveTrader 从内存中移除指定的trader(不影响数据库) -// 用于更新trader配置时强制重新加载 -func (tm *TraderManager) RemoveTrader(traderID string) { - tm.mu.Lock() - defer tm.mu.Unlock() - - if _, exists := tm.traders[traderID]; exists { - delete(tm.traders, traderID) - log.Printf("✓ Trader %s 已从内存中移除", traderID) - } -} diff --git a/market/data.go b/market/data.go index 32a9f8c4..6a151391 100644 --- a/market/data.go +++ b/market/data.go @@ -4,7 +4,7 @@ import ( "encoding/json" "fmt" "io" - "log" + "nofx/logger" "math" "strconv" "strings" @@ -38,7 +38,7 @@ func Get(symbol string) (*Data, error) { // Data staleness detection: Prevent DOGEUSDT-style price freeze issues if isStaleData(klines3m, symbol) { - log.Printf("⚠️ WARNING: %s detected stale data (consecutive price freeze), skipping symbol", symbol) + logger.Infof("⚠️ WARNING: %s detected stale data (consecutive price freeze), skipping symbol", symbol) return nil, fmt.Errorf("%s data is stale, possible cache failure", symbol) } @@ -633,11 +633,11 @@ func isStaleData(klines []Kline, symbol string) bool { } if allVolumeZero { - log.Printf("⚠️ %s stale data confirmed: price freeze + zero volume", symbol) + logger.Infof("⚠️ %s stale data confirmed: price freeze + zero volume", symbol) return true } // Price frozen but has volume: might be extremely low volatility market, allow but log warning - log.Printf("⚠️ %s detected extreme price stability (no fluctuation for %d consecutive periods), but volume is normal", symbol, stalePriceThreshold) + logger.Infof("⚠️ %s detected extreme price stability (no fluctuation for %d consecutive periods), but volume is normal", symbol, stalePriceThreshold) return false } diff --git a/mcp/config.go b/mcp/config.go index a32686a5..d235a28d 100644 --- a/mcp/config.go +++ b/mcp/config.go @@ -5,6 +5,8 @@ import ( "os" "strconv" "time" + + "nofx/logger" ) // Config 客户端配置(集中管理所有配置) @@ -44,8 +46,8 @@ func DefaultConfig() *Config { Timeout: DefaultTimeout, RetryableErrors: retryableErrors, - // 默认依赖 - Logger: &defaultLogger{}, + // 默认依赖(使用全局 logger) + Logger: logger.NewMCPLogger(), HTTPClient: &http.Client{Timeout: DefaultTimeout}, } } diff --git a/mcp/logger.go b/mcp/logger.go index 863310db..e12aa206 100644 --- a/mcp/logger.go +++ b/mcp/logger.go @@ -1,9 +1,8 @@ package mcp -import "log" - // Logger 日志接口(抽象依赖) // 使用 Printf 风格的方法名,方便集成 logrus、zap 等主流日志库 +// 默认使用全局 logger 包(见 mcp/config.go) type Logger interface { Debugf(format string, args ...any) Infof(format string, args ...any) @@ -11,25 +10,6 @@ type Logger interface { Errorf(format string, args ...any) } -// defaultLogger 默认日志实现(包装标准库 log) -type defaultLogger struct{} - -func (l *defaultLogger) Debugf(format string, args ...any) { - log.Printf("[DEBUG] "+format, args...) -} - -func (l *defaultLogger) Infof(format string, args ...any) { - log.Printf("[INFO] "+format, args...) -} - -func (l *defaultLogger) Warnf(format string, args ...any) { - log.Printf("[WARN] "+format, args...) -} - -func (l *defaultLogger) Errorf(format string, args ...any) { - log.Printf("[ERROR] "+format, args...) -} - // noopLogger 空日志实现(测试时使用) type noopLogger struct{} @@ -42,27 +22,3 @@ func (l *noopLogger) Errorf(format string, args ...any) {} func NewNoopLogger() Logger { return &noopLogger{} } - -// ============================================================ -// 适配第三方日志库示例 -// ============================================================ - -// Logrus 适配示例: -// type LogrusLogger struct { -// logger *logrus.Logger -// } -// -// func (l *LogrusLogger) Infof(format string, args ...any) { -// l.logger.Infof(format, args...) -// } -// -// Zap 适配示例: -// type ZapLogger struct { -// logger *zap.Logger -// } -// -// func (l *ZapLogger) Infof(format string, args ...any) { -// l.logger.Sugar().Infof(format, args...) -// } -// -// 然后通过 WithLogger(logger) 注入 diff --git a/screenshots/competition-page.png b/screenshots/competition-page.png deleted file mode 100644 index dad13d4e..00000000 Binary files a/screenshots/competition-page.png and /dev/null differ diff --git a/scripts/ENCRYPTION_README.md b/scripts/ENCRYPTION_README.md index f72a410e..672ad0d8 100644 --- a/scripts/ENCRYPTION_README.md +++ b/scripts/ENCRYPTION_README.md @@ -203,7 +203,7 @@ spec: ./scripts/generate_data_key.sh # 2. 备份旧数据库 -cp config.db config.db.backup +cp data.db data.db.backup # 3. 重启服务 (会自动处理密钥迁移) source .env && ./mars diff --git a/scripts/generate_data_key.sh b/scripts/generate_data_key.sh deleted file mode 100755 index 2e739162..00000000 --- a/scripts/generate_data_key.sh +++ /dev/null @@ -1,143 +0,0 @@ -#!/bin/bash - -# 数据加密密钥生成脚本 - 用于Mars AI交易系统数据库加密 -# 生成用于AES-256-GCM数据库加密的随机密钥 - -set -e # 遇到错误立即退出 - -# 颜色定义 -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -PURPLE='\033[0;35m' -NC='\033[0m' # No Color - -echo -e "${BLUE}╔══════════════════════════════════════════════════════════════════╗${NC}" -echo -e "${BLUE}║ Mars AI交易系统 安全密钥生成器 ║${NC}" -echo -e "${BLUE}║ AES-256-GCM数据密钥 + JWT认证密钥 ║${NC}" -echo -e "${BLUE}╚══════════════════════════════════════════════════════════════════╝${NC}" -echo - -# 检查是否安装了 OpenSSL -if ! command -v openssl &> /dev/null; then - echo -e "${RED}❌ 错误: 系统中未安装 OpenSSL${NC}" - echo -e "请安装 OpenSSL:" - echo -e " macOS: ${YELLOW}brew install openssl${NC}" - echo -e " Ubuntu/Debian: ${YELLOW}sudo apt-get install openssl${NC}" - echo -e " CentOS/RHEL: ${YELLOW}sudo yum install openssl${NC}" - exit 1 -fi - -echo -e "${GREEN}✓ OpenSSL 已安装: $(openssl version)${NC}" - -# 生成安全密钥 -echo -e "${BLUE}🔐 生成安全密钥...${NC}" -echo - -# 生成 AES-256 数据加密密钥 -echo -e "${YELLOW}1/2: 生成 AES-256 数据加密密钥...${NC}" -DATA_KEY=$(openssl rand -base64 32) -if [ $? -eq 0 ]; then - echo -e "${GREEN} ✓ 数据加密密钥生成成功${NC}" -else - echo -e "${RED} ❌ 数据加密密钥生成失败${NC}" - exit 1 -fi - -# 生成 JWT 认证密钥 -echo -e "${YELLOW}2/2: 生成 JWT 认证密钥...${NC}" -JWT_KEY=$(openssl rand -base64 64) -if [ $? -eq 0 ]; then - echo -e "${GREEN} ✓ JWT认证密钥生成成功${NC}" -else - echo -e "${RED} ❌ JWT认证密钥生成失败${NC}" - exit 1 -fi - -# 显示密钥 -echo -echo -e "${GREEN}🎉 安全密钥生成完成!${NC}" -echo -echo -e "${BLUE}📋 生成的密钥:${NC}" -echo -e "${PURPLE}1. 数据加密密钥 (AES-256):${NC}" -echo -e "${YELLOW}$DATA_KEY${NC}" -echo -echo -e "${PURPLE}2. JWT认证密钥 (512-bit):${NC}" -echo -e "${YELLOW}$JWT_KEY${NC}" -echo - -# 显示使用方法 -echo -e "${YELLOW}📋 使用方法:${NC}" -echo -echo -e "${BLUE}1. 环境变量设置:${NC}" -echo -e " export DATA_ENCRYPTION_KEY=\"$DATA_KEY\"" -echo -e " export JWT_SECRET=\"$JWT_KEY\"" -echo -echo -e "${BLUE}2. .env 文件设置:${NC}" -echo -e " DATA_ENCRYPTION_KEY=$DATA_KEY" -echo -e " JWT_SECRET=$JWT_KEY" -echo -echo -e "${BLUE}3. Docker环境设置:${NC}" -echo -e " docker run -e DATA_ENCRYPTION_KEY=\"$DATA_KEY\" -e JWT_SECRET=\"$JWT_KEY\" ..." -echo -echo -e "${BLUE}4. Kubernetes Secret:${NC}" -echo -e " kubectl create secret generic mars-crypto-key \\" -echo -e " --from-literal=DATA_ENCRYPTION_KEY=\"$DATA_KEY\" \\" -echo -e " --from-literal=JWT_SECRET=\"$JWT_KEY\"" -echo - -# 显示密钥特性 -echo -e "${BLUE}🔍 密钥特性:${NC}" -echo -e " • 数据加密: ${YELLOW}AES-256-GCM (256 bits)${NC}" -echo -e " • JWT认证: ${YELLOW}HS256 (512 bits)${NC}" -echo -e " • 格式: ${YELLOW}Base64 编码${NC}" -echo -e " • 用途: ${YELLOW}数据库加密 + 用户认证${NC}" - -# 安全提醒 -echo -echo -e "${RED}⚠️ 安全提醒:${NC}" -echo -e " • 请妥善保管此密钥,丢失后无法恢复加密的数据" -echo -e " • 不要将密钥提交到版本控制系统" -echo -e " • 建议在不同环境使用不同的密钥" -echo -e " • 定期更换密钥并重新加密数据" -echo -e " • 在生产环境中,建议使用密钥管理服务" - -echo -echo -e "${GREEN}✅ 数据加密密钥生成完成!${NC}" - -# 可选:保存到 .env 文件 -echo -read -p "是否将密钥保存到 .env 文件? [y/N]: " -n 1 -r -echo -if [[ $REPLY =~ ^[Yy]$ ]]; then - if [ -f ".env" ]; then - # 检查是否已存在 DATA_ENCRYPTION_KEY - if grep -q "^DATA_ENCRYPTION_KEY=" .env; then - echo -e "${YELLOW}⚠️ .env 文件中已存在 DATA_ENCRYPTION_KEY${NC}" - read -p "是否覆盖现有密钥? [y/N]: " -n 1 -r - echo - if [[ $REPLY =~ ^[Yy]$ ]]; then - # 替换现有密钥 - if [[ "$OSTYPE" == "darwin"* ]]; then - # macOS - sed -i '' "s/^DATA_ENCRYPTION_KEY=.*/DATA_ENCRYPTION_KEY=$RAW_KEY/" .env - else - # Linux - sed -i "s/^DATA_ENCRYPTION_KEY=.*/DATA_ENCRYPTION_KEY=$RAW_KEY/" .env - fi - echo -e "${GREEN}✓ .env 文件中的密钥已更新${NC}" - else - echo -e "${BLUE}ℹ️ 保持现有密钥不变${NC}" - fi - else - # 追加新密钥 - echo "DATA_ENCRYPTION_KEY=$RAW_KEY" >> .env - echo -e "${GREEN}✓ 密钥已保存到 .env 文件${NC}" - fi - else - # 创建新的 .env 文件 - echo "DATA_ENCRYPTION_KEY=$RAW_KEY" > .env - echo -e "${GREEN}✓ 密钥已保存到 .env 文件${NC}" - fi -fi \ No newline at end of file diff --git a/scripts/generate_rsa_keys.sh b/scripts/generate_rsa_keys.sh deleted file mode 100755 index 021a7cce..00000000 --- a/scripts/generate_rsa_keys.sh +++ /dev/null @@ -1,149 +0,0 @@ -#!/bin/bash - -# RSA密钥对生成脚本 - 用于Mars AI交易系统加密服务 -# 生成用于混合加密的RSA-2048密钥对 - -set -e # 遇到错误立即退出 - -# 颜色定义 -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -NC='\033[0m' # No Color - -# 配置 -RSA_KEY_SIZE=2048 -SECRETS_DIR="secrets" -PRIVATE_KEY_FILE="$SECRETS_DIR/rsa_key" -PUBLIC_KEY_FILE="$SECRETS_DIR/rsa_key.pub" - -echo -e "${BLUE}╔══════════════════════════════════════════════════════════════════╗${NC}" -echo -e "${BLUE}║ Mars AI交易系统 RSA密钥生成器 ║${NC}" -echo -e "${BLUE}║ RSA-2048 混合加密密钥对 ║${NC}" -echo -e "${BLUE}╚══════════════════════════════════════════════════════════════════╝${NC}" -echo - -# 检查是否安装了 OpenSSL -if ! command -v openssl &> /dev/null; then - echo -e "${RED}❌ 错误: 系统中未安装 OpenSSL${NC}" - echo -e "请安装 OpenSSL:" - echo -e " macOS: ${YELLOW}brew install openssl${NC}" - echo -e " Ubuntu/Debian: ${YELLOW}sudo apt-get install openssl${NC}" - echo -e " CentOS/RHEL: ${YELLOW}sudo yum install openssl${NC}" - exit 1 -fi - -echo -e "${GREEN}✓ OpenSSL 已安装: $(openssl version)${NC}" - -# 创建 secrets 目录 -if [ ! -d "$SECRETS_DIR" ]; then - echo -e "${YELLOW}📁 创建 $SECRETS_DIR 目录...${NC}" - mkdir -p "$SECRETS_DIR" - chmod 700 "$SECRETS_DIR" - echo -e "${GREEN}✓ 目录创建成功${NC}" -else - echo -e "${GREEN}✓ $SECRETS_DIR 目录已存在${NC}" -fi - -# 检查现有密钥 -if [ -f "$PRIVATE_KEY_FILE" ] || [ -f "$PUBLIC_KEY_FILE" ]; then - echo - echo -e "${YELLOW}⚠️ 检测到现有的RSA密钥文件:${NC}" - [ -f "$PRIVATE_KEY_FILE" ] && echo -e " • $PRIVATE_KEY_FILE" - [ -f "$PUBLIC_KEY_FILE" ] && echo -e " • $PUBLIC_KEY_FILE" - echo - read -p "是否覆盖现有密钥? [y/N]: " -n 1 -r - echo - if [[ ! $REPLY =~ ^[Yy]$ ]]; then - echo -e "${BLUE}ℹ️ 操作已取消${NC}" - exit 0 - fi - echo -e "${YELLOW}🗑️ 删除现有密钥文件...${NC}" - rm -f "$PRIVATE_KEY_FILE" "$PUBLIC_KEY_FILE" -fi - -echo -echo -e "${BLUE}🔐 开始生成 RSA-$RSA_KEY_SIZE 密钥对...${NC}" - -# 生成私钥 -echo -e "${YELLOW}📝 步骤 1/3: 生成 RSA 私钥 ($RSA_KEY_SIZE bits)...${NC}" -if openssl genrsa -out "$PRIVATE_KEY_FILE" $RSA_KEY_SIZE 2>/dev/null; then - echo -e "${GREEN}✓ 私钥生成成功${NC}" -else - echo -e "${RED}❌ 私钥生成失败${NC}" - exit 1 -fi - -# 设置私钥权限 -chmod 600 "$PRIVATE_KEY_FILE" -echo -e "${GREEN}✓ 私钥权限设置为 600${NC}" - -# 生成公钥 -echo -e "${YELLOW}📝 步骤 2/3: 从私钥提取公钥...${NC}" -if openssl rsa -in "$PRIVATE_KEY_FILE" -pubout -out "$PUBLIC_KEY_FILE" 2>/dev/null; then - echo -e "${GREEN}✓ 公钥生成成功${NC}" -else - echo -e "${RED}❌ 公钥生成失败${NC}" - exit 1 -fi - -# 设置公钥权限 -chmod 644 "$PUBLIC_KEY_FILE" -echo -e "${GREEN}✓ 公钥权限设置为 644${NC}" - -# 验证密钥 -echo -e "${YELLOW}📝 步骤 3/3: 验证密钥对...${NC}" -if openssl rsa -in "$PRIVATE_KEY_FILE" -check -noout 2>/dev/null; then - echo -e "${GREEN}✓ 私钥验证通过${NC}" -else - echo -e "${RED}❌ 私钥验证失败${NC}" - exit 1 -fi - -if openssl rsa -in "$PUBLIC_KEY_FILE" -pubin -text -noout &>/dev/null; then - echo -e "${GREEN}✓ 公钥验证通过${NC}" -else - echo -e "${RED}❌ 公钥验证失败${NC}" - exit 1 -fi - -# 显示密钥信息 -echo -echo -e "${GREEN}🎉 RSA密钥对生成成功!${NC}" -echo -echo -e "${BLUE}📋 密钥信息:${NC}" -echo -e " 私钥文件: ${YELLOW}$PRIVATE_KEY_FILE${NC}" -echo -e " 公钥文件: ${YELLOW}$PUBLIC_KEY_FILE${NC}" -echo -e " 密钥大小: ${YELLOW}$RSA_KEY_SIZE bits${NC}" -echo - -# 显示文件大小 -PRIVATE_SIZE=$(stat -f%z "$PRIVATE_KEY_FILE" 2>/dev/null || stat -c%s "$PRIVATE_KEY_FILE" 2>/dev/null || echo "未知") -PUBLIC_SIZE=$(stat -f%z "$PUBLIC_KEY_FILE" 2>/dev/null || stat -c%s "$PUBLIC_KEY_FILE" 2>/dev/null || echo "未知") - -echo -e "${BLUE}📏 文件大小:${NC}" -echo -e " 私钥: ${YELLOW}$PRIVATE_SIZE bytes${NC}" -echo -e " 公钥: ${YELLOW}$PUBLIC_SIZE bytes${NC}" - -# 显示公钥内容预览 -echo -echo -e "${BLUE}🔍 公钥内容预览:${NC}" -head -n 5 "$PUBLIC_KEY_FILE" | sed 's/^/ /' -echo -e " ${YELLOW}...${NC}" -tail -n 2 "$PUBLIC_KEY_FILE" | sed 's/^/ /' - -echo -echo -e "${GREEN}✅ RSA密钥对生成完成!${NC}" -echo -echo -e "${YELLOW}📋 使用说明:${NC}" -echo -e " 1. 私钥文件 ($PRIVATE_KEY_FILE) 用于服务器端解密" -echo -e " 2. 公钥文件 ($PUBLIC_KEY_FILE) 可以分发给客户端用于加密" -echo -e " 3. 确保私钥文件的安全性,不要泄露给第三方" -echo -e " 4. 在生产环境中,建议将私钥存储在安全的密钥管理服务中" -echo -echo -e "${RED}⚠️ 安全提醒:${NC}" -echo -e " • 私钥文件权限已设置为 600 (仅所有者可读写)" -echo -e " • 请定期备份密钥文件" -echo -e " • 建议在不同环境使用不同的密钥对" -echo \ No newline at end of file diff --git a/scripts/migrate_encryption.go b/scripts/migrate_encryption.go index f17fbe7e..2c4b9d38 100644 --- a/scripts/migrate_encryption.go +++ b/scripts/migrate_encryption.go @@ -12,71 +12,71 @@ import ( ) func main() { - log.Println("🔄 開始遷移數據庫到加密格式...") + log.Println("🔄 开始迁移数据库到加密格式...") - // 1. 檢查數據庫檔案 - dbPath := "config.db" + // 1. 检查数据库文件 + dbPath := "data.db" if len(os.Args) > 1 { dbPath = os.Args[1] } if _, err := os.Stat(dbPath); os.IsNotExist(err) { - log.Fatalf("❌ 數據庫檔案不存在: %s", dbPath) + log.Fatalf("❌ 数据库文件不存在: %s", dbPath) } - // 2. 備份數據庫 + // 2. 备份数据库 backupPath := fmt.Sprintf("%s.pre_encryption_backup", dbPath) - log.Printf("📦 備份數據庫到: %s", backupPath) + log.Printf("📦 备份数据库到: %s", backupPath) input, err := os.ReadFile(dbPath) if err != nil { - log.Fatalf("❌ 讀取數據庫失敗: %v", err) + log.Fatalf("❌ 读取数据库失败: %v", err) } if err := os.WriteFile(backupPath, input, 0600); err != nil { - log.Fatalf("❌ 備份失敗: %v", err) + log.Fatalf("❌ 备份失败: %v", err) } - // 3. 打開數據庫 + // 3. 打开数据库 db, err := sql.Open("sqlite", dbPath) if err != nil { - log.Fatalf("❌ 打開數據庫失敗: %v", err) + log.Fatalf("❌ 打开数据库失败: %v", err) } defer db.Close() - // 4. 初始化加密管理器 - em, err := crypto.GetEncryptionManager() + // 4. 初始化 CryptoService(从环境变量加载密钥) + cs, err := crypto.NewCryptoService() if err != nil { - log.Fatalf("❌ 初始化加密管理器失敗: %v", err) + log.Fatalf("❌ 初始化加密服务失败: %v", err) } - // 5. 遷移交易所配置 - if err := migrateExchanges(db, em); err != nil { - log.Fatalf("❌ 遷移交易所配置失敗: %v", err) + // 5. 迁移交易所配置 + if err := migrateExchanges(db, cs); err != nil { + log.Fatalf("❌ 迁移交易所配置失败: %v", err) } - // 6. 遷移 AI 模型配置 - if err := migrateAIModels(db, em); err != nil { - log.Fatalf("❌ 遷移 AI 模型配置失敗: %v", err) + // 6. 迁移 AI 模型配置 + if err := migrateAIModels(db, cs); err != nil { + log.Fatalf("❌ 迁移 AI 模型配置失败: %v", err) } - log.Println("✅ 數據遷移完成!") - log.Printf("📝 原始數據備份位於: %s", backupPath) - log.Println("⚠️ 請驗證系統功能正常後,手動刪除備份檔案") + log.Println("✅ 数据迁移完成!") + log.Printf("📝 原始数据备份位于: %s", backupPath) + log.Println("⚠️ 请验证系统功能正常后,手动删除备份文件") } -// migrateExchanges 遷移交易所配置 -func migrateExchanges(db *sql.DB, em *crypto.EncryptionManager) error { - log.Println("🔄 遷移交易所配置...") +// migrateExchanges 迁移交易所配置 +func migrateExchanges(db *sql.DB, cs *crypto.CryptoService) error { + log.Println("🔄 迁移交易所配置...") - // 查詢所有未加密的記錄(假設加密數據都包含 '==' Base64 特徵) + // 查询所有未加密的记录(加密数据以 ENC:v1: 开头) rows, err := db.Query(` SELECT user_id, id, api_key, secret_key, COALESCE(hyperliquid_private_key, ''), COALESCE(aster_private_key, '') FROM exchanges - WHERE (api_key != '' AND api_key NOT LIKE '%==%') - OR (secret_key != '' AND secret_key NOT LIKE '%==%') + WHERE (api_key != '' AND api_key NOT LIKE 'ENC:v1:%') + OR (secret_key != '' AND secret_key NOT LIKE 'ENC:v1:%') `) if err != nil { return err @@ -96,34 +96,34 @@ func migrateExchanges(db *sql.DB, em *crypto.EncryptionManager) error { return err } - // 加密每個字段 - encAPIKey, err := em.EncryptForDatabase(apiKey) + // 加密每个字段 + encAPIKey, err := cs.EncryptForStorage(apiKey) if err != nil { - return fmt.Errorf("加密 API Key 失敗: %w", err) + return fmt.Errorf("加密 API Key 失败: %w", err) } - encSecretKey, err := em.EncryptForDatabase(secretKey) + encSecretKey, err := cs.EncryptForStorage(secretKey) if err != nil { - return fmt.Errorf("加密 Secret Key 失敗: %w", err) + return fmt.Errorf("加密 Secret Key 失败: %w", err) } encHLPrivateKey := "" if hlPrivateKey != "" { - encHLPrivateKey, err = em.EncryptForDatabase(hlPrivateKey) + encHLPrivateKey, err = cs.EncryptForStorage(hlPrivateKey) if err != nil { - return fmt.Errorf("加密 Hyperliquid Private Key 失敗: %w", err) + return fmt.Errorf("加密 Hyperliquid Private Key 失败: %w", err) } } encAsterPrivateKey := "" if asterPrivateKey != "" { - encAsterPrivateKey, err = em.EncryptForDatabase(asterPrivateKey) + encAsterPrivateKey, err = cs.EncryptForStorage(asterPrivateKey) if err != nil { - return fmt.Errorf("加密 Aster Private Key 失敗: %w", err) + return fmt.Errorf("加密 Aster Private Key 失败: %w", err) } } - // 更新數據庫 + // 更新数据库 _, err = tx.Exec(` UPDATE exchanges SET api_key = ?, secret_key = ?, @@ -132,7 +132,7 @@ func migrateExchanges(db *sql.DB, em *crypto.EncryptionManager) error { `, encAPIKey, encSecretKey, encHLPrivateKey, encAsterPrivateKey, userID, exchangeID) if err != nil { - return fmt.Errorf("更新數據庫失敗: %w", err) + return fmt.Errorf("更新数据库失败: %w", err) } log.Printf(" ✓ 已加密: [%s] %s", userID, exchangeID) @@ -143,18 +143,18 @@ func migrateExchanges(db *sql.DB, em *crypto.EncryptionManager) error { return err } - log.Printf("✅ 已遷移 %d 個交易所配置", count) + log.Printf("✅ 已迁移 %d 个交易所配置", count) return nil } -// migrateAIModels 遷移 AI 模型配置 -func migrateAIModels(db *sql.DB, em *crypto.EncryptionManager) error { - log.Println("🔄 遷移 AI 模型配置...") +// migrateAIModels 迁移 AI 模型配置 +func migrateAIModels(db *sql.DB, cs *crypto.CryptoService) error { + log.Println("🔄 迁移 AI 模型配置...") rows, err := db.Query(` SELECT user_id, id, api_key FROM ai_models - WHERE api_key != '' AND api_key NOT LIKE '%==%' + WHERE api_key != '' AND api_key NOT LIKE 'ENC:v1:%' `) if err != nil { return err @@ -174,9 +174,9 @@ func migrateAIModels(db *sql.DB, em *crypto.EncryptionManager) error { return err } - encAPIKey, err := em.EncryptForDatabase(apiKey) + encAPIKey, err := cs.EncryptForStorage(apiKey) if err != nil { - return fmt.Errorf("加密 API Key 失敗: %w", err) + return fmt.Errorf("加密 API Key 失败: %w", err) } _, err = tx.Exec(` @@ -184,7 +184,7 @@ func migrateAIModels(db *sql.DB, em *crypto.EncryptionManager) error { `, encAPIKey, userID, modelID) if err != nil { - return fmt.Errorf("更新數據庫失敗: %w", err) + return fmt.Errorf("更新数据库失败: %w", err) } log.Printf(" ✓ 已加密: [%s] %s", userID, modelID) @@ -195,6 +195,6 @@ func migrateAIModels(db *sql.DB, em *crypto.EncryptionManager) error { return err } - log.Printf("✅ 已遷移 %d 個 AI 模型配置", count) + log.Printf("✅ 已迁移 %d 个 AI 模型配置", count) return nil } diff --git a/scripts/setup_encryption.sh b/scripts/setup_encryption.sh deleted file mode 100755 index ec371063..00000000 --- a/scripts/setup_encryption.sh +++ /dev/null @@ -1,319 +0,0 @@ -#!/bin/bash - -# Mars AI交易系统加密环境设置脚本 -# 一键生成RSA密钥对和数据加密密钥,完整设置加密环境 - -set -e # 遇到错误立即退出 - -# 颜色定义 -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -PURPLE='\033[0;35m' -CYAN='\033[0;36m' -NC='\033[0m' # No Color - -# 获取脚本所在目录 -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" - -echo -e "${PURPLE}╔════════════════════════════════════════════════════════════════════════╗${NC}" -echo -e "${PURPLE}║ Mars AI交易系统 ║${NC}" -echo -e "${PURPLE}║ 🔐 加密环境一键设置工具 ║${NC}" -echo -e "${PURPLE}║ ║${NC}" -echo -e "${PURPLE}║ 功能: 生成RSA密钥对 + 数据加密密钥 + 配置环境变量 ║${NC}" -echo -e "${PURPLE}╚════════════════════════════════════════════════════════════════════════╝${NC}" -echo - -# 检查依赖 -echo -e "${CYAN}🔍 检查系统依赖...${NC}" - -# 检查 OpenSSL -if ! command -v openssl &> /dev/null; then - echo -e "${RED}❌ 错误: 系统中未安装 OpenSSL${NC}" - echo -e "请安装 OpenSSL:" - echo -e " macOS: ${YELLOW}brew install openssl${NC}" - echo -e " Ubuntu/Debian: ${YELLOW}sudo apt-get install openssl${NC}" - echo -e " CentOS/RHEL: ${YELLOW}sudo yum install openssl${NC}" - exit 1 -fi - -echo -e "${GREEN}✓ OpenSSL: $(openssl version)${NC}" - -# 进入项目根目录 -cd "$PROJECT_ROOT" -echo -e "${GREEN}✓ 工作目录: $(pwd)${NC}" - -# 配置参数 -RSA_KEY_SIZE=2048 -SECRETS_DIR="secrets" -PRIVATE_KEY_FILE="$SECRETS_DIR/rsa_key" -PUBLIC_KEY_FILE="$SECRETS_DIR/rsa_key.pub" - -echo -echo -e "${BLUE}📋 配置参数:${NC}" -echo -e " • RSA密钥大小: ${YELLOW}$RSA_KEY_SIZE bits${NC}" -echo -e " • 私钥文件: ${YELLOW}$PRIVATE_KEY_FILE${NC}" -echo -e " • 公钥文件: ${YELLOW}$PUBLIC_KEY_FILE${NC}" -echo -e " • AES密钥: ${YELLOW}256 bits (自动生成)${NC}" - -# 询问用户确认 -echo -read -p "是否继续设置加密环境? [Y/n]: " -n 1 -r -echo -if [[ $REPLY =~ ^[Nn]$ ]]; then - echo -e "${BLUE}ℹ️ 操作已取消${NC}" - exit 0 -fi - -echo -echo -e "${CYAN}🚀 开始设置加密环境...${NC}" - -# ============= 步骤1: 创建目录 ============= -echo -echo -e "${YELLOW}📁 步骤 1/4: 创建必要目录...${NC}" - -if [ ! -d "$SECRETS_DIR" ]; then - mkdir -p "$SECRETS_DIR" - chmod 700 "$SECRETS_DIR" - echo -e "${GREEN}✓ 创建 $SECRETS_DIR 目录${NC}" -else - echo -e "${GREEN}✓ $SECRETS_DIR 目录已存在${NC}" -fi - -if [ ! -d "scripts" ]; then - mkdir -p "scripts" - echo -e "${GREEN}✓ 创建 scripts 目录${NC}" -else - echo -e "${GREEN}✓ scripts 目录已存在${NC}" -fi - -# ============= 步骤2: 生成RSA密钥对 ============= -echo -echo -e "${YELLOW}🔐 步骤 2/4: 生成 RSA-$RSA_KEY_SIZE 密钥对...${NC}" - -# 检查现有RSA密钥 -if [ -f "$PRIVATE_KEY_FILE" ] || [ -f "$PUBLIC_KEY_FILE" ]; then - echo -e "${YELLOW}⚠️ 检测到现有的RSA密钥文件${NC}" - read -p "是否重新生成RSA密钥? [y/N]: " -n 1 -r - echo - if [[ $REPLY =~ ^[Yy]$ ]]; then - rm -f "$PRIVATE_KEY_FILE" "$PUBLIC_KEY_FILE" - echo -e "${YELLOW}🗑️ 删除旧密钥${NC}" - else - echo -e "${BLUE}ℹ️ 保持现有RSA密钥${NC}" - RSA_SKIPPED=true - fi -fi - -if [ "$RSA_SKIPPED" != "true" ]; then - # 生成私钥 - echo -e " ${CYAN}生成RSA私钥...${NC}" - openssl genrsa -out "$PRIVATE_KEY_FILE" $RSA_KEY_SIZE 2>/dev/null - chmod 600 "$PRIVATE_KEY_FILE" - echo -e "${GREEN} ✓ 私钥生成完成${NC}" - - # 生成公钥 - echo -e " ${CYAN}提取RSA公钥...${NC}" - openssl rsa -in "$PRIVATE_KEY_FILE" -pubout -out "$PUBLIC_KEY_FILE" 2>/dev/null - chmod 644 "$PUBLIC_KEY_FILE" - echo -e "${GREEN} ✓ 公钥生成完成${NC}" - - # 验证密钥 - echo -e " ${CYAN}验证密钥对...${NC}" - openssl rsa -in "$PRIVATE_KEY_FILE" -check -noout 2>/dev/null - echo -e "${GREEN} ✓ 密钥验证通过${NC}" -fi - -# ============= 步骤3: 生成数据加密密钥和JWT密钥 ============= -echo -echo -e "${YELLOW}🔑 步骤 3/4: 生成 AES-256 数据加密密钥和JWT认证密钥...${NC}" - -# 检查现有密钥 -DATA_KEY_EXISTS=false -JWT_KEY_EXISTS=false - -if [ -f ".env" ]; then - if grep -q "^DATA_ENCRYPTION_KEY=" .env; then - DATA_KEY_EXISTS=true - fi - if grep -q "^JWT_SECRET=" .env; then - JWT_KEY_EXISTS=true - fi -fi - -if [ "$DATA_KEY_EXISTS" = "true" ] || [ "$JWT_KEY_EXISTS" = "true" ]; then - echo -e "${YELLOW}⚠️ 检测到现有的密钥配置${NC}" - if [ "$DATA_KEY_EXISTS" = "true" ]; then - echo -e " • 数据加密密钥已存在" - fi - if [ "$JWT_KEY_EXISTS" = "true" ]; then - echo -e " • JWT认证密钥已存在" - fi - read -p "是否重新生成所有密钥? [y/N]: " -n 1 -r - echo - if [[ ! $REPLY =~ ^[Yy]$ ]]; then - echo -e "${BLUE}ℹ️ 保持现有密钥${NC}" - KEY_SKIPPED=true - # 读取现有密钥 - if [ "$DATA_KEY_EXISTS" = "true" ]; then - DATA_KEY=$(grep "^DATA_ENCRYPTION_KEY=" .env | cut -d'=' -f2) - fi - if [ "$JWT_KEY_EXISTS" = "true" ]; then - JWT_KEY=$(grep "^JWT_SECRET=" .env | cut -d'=' -f2) - fi - fi -fi - -if [ "$KEY_SKIPPED" != "true" ]; then - # 生成新的密钥 - echo -e " ${CYAN}生成AES-256数据加密密钥...${NC}" - DATA_KEY=$(openssl rand -base64 32) - echo -e "${GREEN} ✓ 数据加密密钥生成完成${NC}" - - echo -e " ${CYAN}生成JWT认证密钥...${NC}" - JWT_KEY=$(openssl rand -base64 64) - echo -e "${GREEN} ✓ JWT认证密钥生成完成${NC}" - - # 保存到.env文件 - if [ -f ".env" ]; then - # 更新现有文件 - if grep -q "^DATA_ENCRYPTION_KEY=" .env; then - if [[ "$OSTYPE" == "darwin"* ]]; then - sed -i '' "s/^DATA_ENCRYPTION_KEY=.*/DATA_ENCRYPTION_KEY=$DATA_KEY/" .env - else - sed -i "s/^DATA_ENCRYPTION_KEY=.*/DATA_ENCRYPTION_KEY=$DATA_KEY/" .env - fi - else - echo "DATA_ENCRYPTION_KEY=$DATA_KEY" >> .env - fi - - if grep -q "^JWT_SECRET=" .env; then - # 使用替代分隔符避免 / 字符冲突,并用引号保护值 - if [[ "$OSTYPE" == "darwin"* ]]; then - sed -i '' "s|^JWT_SECRET=.*|JWT_SECRET=\"$JWT_KEY\"|" .env - else - sed -i "s|^JWT_SECRET=.*|JWT_SECRET=\"$JWT_KEY\"|" .env - fi - else - # 使用引号确保值在同一行 - printf "JWT_SECRET=\"%s\"\n" "$JWT_KEY" >> .env - fi - else - # 创建新文件 - echo "DATA_ENCRYPTION_KEY=$DATA_KEY" > .env - printf "JWT_SECRET=\"%s\"\n" "$JWT_KEY" >> .env - fi - chmod 600 .env - echo -e "${GREEN} ✓ 密钥已保存到 .env 文件${NC}" -elif [ "$DATA_KEY_EXISTS" != "true" ] || [ "$JWT_KEY_EXISTS" != "true" ]; then - # 生成缺失的密钥 - if [ "$DATA_KEY_EXISTS" != "true" ]; then - echo -e " ${CYAN}生成缺失的AES-256数据加密密钥...${NC}" - DATA_KEY=$(openssl rand -base64 32) - echo "DATA_ENCRYPTION_KEY=$DATA_KEY" >> .env - echo -e "${GREEN} ✓ 数据加密密钥生成完成${NC}" - fi - - if [ "$JWT_KEY_EXISTS" != "true" ]; then - echo -e " ${CYAN}生成缺失的JWT认证密钥...${NC}" - JWT_KEY=$(openssl rand -base64 64) - printf "JWT_SECRET=\"%s\"\n" "$JWT_KEY" >> .env - echo -e "${GREEN} ✓ JWT认证密钥生成完成${NC}" - fi - - chmod 600 .env - echo -e "${GREEN} ✓ 密钥已保存到 .env 文件${NC}" -fi - -# ============= 步骤4: 验证和总结 ============= -echo -echo -e "${YELLOW}✅ 步骤 4/4: 环境验证和总结...${NC}" - -# 验证文件存在性和权限 -echo -e " ${CYAN}验证文件和权限...${NC}" - -if [ -f "$PRIVATE_KEY_FILE" ]; then - PRIVATE_PERM=$(stat -f "%A" "$PRIVATE_KEY_FILE" 2>/dev/null || stat -c "%a" "$PRIVATE_KEY_FILE" 2>/dev/null) - echo -e "${GREEN} ✓ 私钥文件: $PRIVATE_KEY_FILE (权限: $PRIVATE_PERM)${NC}" -else - echo -e "${RED} ❌ 私钥文件不存在${NC}" - exit 1 -fi - -if [ -f "$PUBLIC_KEY_FILE" ]; then - PUBLIC_PERM=$(stat -f "%A" "$PUBLIC_KEY_FILE" 2>/dev/null || stat -c "%a" "$PUBLIC_KEY_FILE" 2>/dev/null) - echo -e "${GREEN} ✓ 公钥文件: $PUBLIC_KEY_FILE (权限: $PUBLIC_PERM)${NC}" -else - echo -e "${RED} ❌ 公钥文件不存在${NC}" - exit 1 -fi - -if [ -f ".env" ] && grep -q "^DATA_ENCRYPTION_KEY=" .env && grep -q "^JWT_SECRET=" .env; then - ENV_PERM=$(stat -f "%A" ".env" 2>/dev/null || stat -c "%a" ".env" 2>/dev/null) - echo -e "${GREEN} ✓ 环境文件: .env (权限: $ENV_PERM)${NC}" - echo -e "${GREEN} 包含: DATA_ENCRYPTION_KEY, JWT_SECRET${NC}" -else - echo -e "${RED} ❌ 环境文件不存在或缺少必要密钥${NC}" - exit 1 -fi - -# 测试密钥功能 -echo -e " ${CYAN}测试密钥功能...${NC}" -TEST_DATA="Hello Mars AI Trading System" -ENCRYPTED=$(echo "$TEST_DATA" | openssl rsautl -encrypt -pubin -inkey "$PUBLIC_KEY_FILE" | base64) -DECRYPTED=$(echo "$ENCRYPTED" | base64 -d | openssl rsautl -decrypt -inkey "$PRIVATE_KEY_FILE") - -if [ "$DECRYPTED" = "$TEST_DATA" ]; then - echo -e "${GREEN} ✓ RSA加密/解密测试通过${NC}" -else - echo -e "${RED} ❌ RSA加密/解密测试失败${NC}" - exit 1 -fi - -# 显示最终结果 -echo -echo -e "${GREEN}🎉 加密环境设置完成!${NC}" -echo -echo -e "${PURPLE}╔════════════════════════════════════════════════════════════════════════╗${NC}" -echo -e "${PURPLE}║ 设置完成摘要 ║${NC}" -echo -e "${PURPLE}╠════════════════════════════════════════════════════════════════════════╣${NC}" -echo -e "${PURPLE}║${NC} ${BLUE}RSA密钥对:${NC} ${PURPLE}║${NC}" -echo -e "${PURPLE}║${NC} 私钥: ${YELLOW}$PRIVATE_KEY_FILE${NC} ${PURPLE}║${NC}" -echo -e "${PURPLE}║${NC} 公钥: ${YELLOW}$PUBLIC_KEY_FILE${NC} ${PURPLE}║${NC}" -echo -e "${PURPLE}║${NC} 大小: ${YELLOW}$RSA_KEY_SIZE bits${NC} ${PURPLE}║${NC}" -echo -e "${PURPLE}║${NC} ${PURPLE}║${NC}" -echo -e "${PURPLE}║${NC} ${BLUE}安全密钥配置:${NC} ${PURPLE}║${NC}" -echo -e "${PURPLE}║${NC} 文件: ${YELLOW}.env${NC} ${PURPLE}║${NC}" -echo -e "${PURPLE}║${NC} 数据加密: ${YELLOW}DATA_ENCRYPTION_KEY (AES-256-GCM)${NC} ${PURPLE}║${NC}" -echo -e "${PURPLE}║${NC} JWT认证: ${YELLOW}JWT_SECRET (HS256)${NC} ${PURPLE}║${NC}" -echo -e "${PURPLE}╚════════════════════════════════════════════════════════════════════════╝${NC}" - -# 使用指南 -echo -echo -e "${BLUE}📋 使用指南:${NC}" -echo -echo -e "${YELLOW}1. 启动Mars AI交易系统:${NC}" -echo -e " source .env && ./mars" -echo -echo -e "${YELLOW}2. Docker部署:${NC}" -echo -e " docker run --env-file .env mars-ai-trading" -echo -echo -e "${YELLOW}3. 查看公钥内容:${NC}" -echo -e " cat $PUBLIC_KEY_FILE" -echo -echo -e "${YELLOW}4. 测试加密API:${NC}" -echo -e " curl http://localhost:8080/api/crypto/public-key" - -# 安全提醒 -echo -echo -e "${RED}🔒 安全提醒:${NC}" -echo -e " • 私钥文件 ($PRIVATE_KEY_FILE) 权限已设置为 600" -echo -e " • 环境文件 (.env) 权限已设置为 600" -echo -e " • 请勿将私钥和数据密钥提交到版本控制系统" -echo -e " • 建议在生产环境中使用密钥管理服务" -echo -e " • 定期备份密钥文件" - -echo -echo -e "${GREEN}✅ Mars AI交易系统加密环境设置完成!${NC}" \ No newline at end of file diff --git a/start.sh b/start.sh index b9a84bac..c6a860e9 100755 --- a/start.sh +++ b/start.sh @@ -14,6 +14,7 @@ RED='\033[0;31m' GREEN='\033[0;32m' YELLOW='\033[1;33m' BLUE='\033[0;34m' +CYAN='\033[0;36m' NC='\033[0m' # No Color # ------------------------------------------------------------------------ @@ -70,95 +71,109 @@ check_env() { if [ ! -f ".env" ]; then print_warning ".env 不存在,从模板复制..." cp .env.example .env - print_info "✓ 已使用默认环境变量创建 .env" - print_info "💡 如需修改端口等设置,可编辑 .env 文件" + print_info "已创建 .env 文件" fi print_success "环境变量文件存在" } # ------------------------------------------------------------------------ -# Validation: Encryption Environment (RSA Keys + Data Encryption Key) +# Helper: Check if env var is set and not placeholder # ------------------------------------------------------------------------ -check_encryption() { - local need_setup=false - - print_info "检查加密环境..." - - # 检查RSA密钥对 - if [ ! -f "secrets/rsa_key" ] || [ ! -f "secrets/rsa_key.pub" ]; then - print_warning "RSA密钥对不存在" - need_setup=true - fi - - # 检查数据加密密钥 - if [ ! -f ".env" ] || ! grep -q "^DATA_ENCRYPTION_KEY=" .env; then - print_warning "数据加密密钥未配置" - need_setup=true - fi - - # 检查JWT认证密钥 - if [ ! -f ".env" ] || ! grep -q "^JWT_SECRET=" .env; then - print_warning "JWT认证密钥未配置" - need_setup=true - fi - - # 如果需要设置加密环境,直接自动设置 - if [ "$need_setup" = "true" ]; then - print_info "🔐 检测到加密环境未配置,正在自动设置..." - print_info "加密环境用于保护敏感数据(API密钥、私钥等)" - echo "" +is_env_configured() { + local var_name="$1" + local value=$(grep "^${var_name}=" .env 2>/dev/null | cut -d'=' -f2-) - # 检查加密设置脚本是否存在 - if [ -f "scripts/setup_encryption.sh" ]; then - print_info "加密系统将保护: API密钥、私钥、Hyperliquid代理钱包" - echo "" + # 去除引号 + value=$(echo "$value" | tr -d '"'"'") - # 自动运行加密设置脚本 - echo -e "Y\nn\nn" | bash scripts/setup_encryption.sh - if [ $? -eq 0 ]; then - echo "" - print_success "🔐 加密环境设置完成!" - print_info " • RSA-2048密钥对已生成" - print_info " • AES-256数据加密密钥已配置" - print_info " • JWT认证密钥已配置" - print_info " • 所有敏感数据现在都受加密保护" - echo "" - else - print_error "加密环境设置失败" - exit 1 - fi + # 检查是否为空或占位符 + if [ -z "$value" ]; then + return 1 + fi + + # 检查是否是示例值 + case "$value" in + *your-*|*YOUR_*|*change-this*|*CHANGE_THIS*|*example*|*EXAMPLE*) + return 1 + ;; + esac + + return 0 +} + +# ------------------------------------------------------------------------ +# Helper: Generate and set env var in .env file +# ------------------------------------------------------------------------ +set_env_var() { + local var_name="$1" + local var_value="$2" + + # 如果变量已存在(即使是占位符),替换它 + if grep -q "^${var_name}=" .env 2>/dev/null; then + # macOS 和 Linux 兼容的 sed + if [[ "$OSTYPE" == "darwin"* ]]; then + sed -i '' "s|^${var_name}=.*|${var_name}=${var_value}|" .env else - print_error "加密设置脚本不存在: scripts/setup_encryption.sh" - print_info "请手动运行: ./scripts/setup_encryption.sh" - exit 1 + sed -i "s|^${var_name}=.*|${var_name}=${var_value}|" .env fi else - print_success "🔐 加密环境已配置" - print_info " • RSA密钥对: secrets/rsa_key + secrets/rsa_key.pub" - print_info " • 数据加密密钥: .env (DATA_ENCRYPTION_KEY)" - print_info " • JWT认证密钥: .env (JWT_SECRET)" - print_info " • 加密算法: RSA-OAEP-2048 + AES-256-GCM + HS256" - print_info " • 保护数据: API密钥、私钥、Hyperliquid代理钱包、用户认证" - - # 验证密钥文件权限 - if [ -f "secrets/rsa_key" ]; then - local perm=$(stat -f "%A" "secrets/rsa_key" 2>/dev/null || stat -c "%a" "secrets/rsa_key" 2>/dev/null) - if [ "$perm" != "600" ]; then - print_warning "修复RSA私钥权限..." - chmod 600 secrets/rsa_key - fi - fi - - if [ -f ".env" ]; then - local perm=$(stat -f "%A" ".env" 2>/dev/null || stat -c "%a" ".env" 2>/dev/null) - if [ "$perm" != "600" ]; then - print_warning "修复环境文件权限..." - chmod 600 .env - fi - fi + # 变量不存在,追加 + echo "${var_name}=${var_value}" >> .env fi } +# ------------------------------------------------------------------------ +# Validation: Encryption Keys in .env +# ------------------------------------------------------------------------ +check_encryption() { + print_info "检查加密密钥配置..." + + local generated=false + + # 检查并生成 JWT_SECRET + if ! is_env_configured "JWT_SECRET"; then + print_warning "JWT_SECRET 未配置,正在生成..." + local jwt_secret=$(openssl rand -base64 32) + set_env_var "JWT_SECRET" "$jwt_secret" + print_success "JWT_SECRET 已生成" + generated=true + fi + + # 检查并生成 DATA_ENCRYPTION_KEY + if ! is_env_configured "DATA_ENCRYPTION_KEY"; then + print_warning "DATA_ENCRYPTION_KEY 未配置,正在生成..." + local data_key=$(openssl rand -base64 32) + set_env_var "DATA_ENCRYPTION_KEY" "$data_key" + print_success "DATA_ENCRYPTION_KEY 已生成" + generated=true + fi + + # 检查并生成 RSA_PRIVATE_KEY + if ! is_env_configured "RSA_PRIVATE_KEY"; then + print_warning "RSA_PRIVATE_KEY 未配置,正在生成..." + # 生成 RSA 密钥并转换为单行格式(\n 替换为 \\n) + local rsa_key=$(openssl genrsa 2048 2>/dev/null | awk '{printf "%s\\n", $0}') + set_env_var "RSA_PRIVATE_KEY" "\"$rsa_key\"" + print_success "RSA_PRIVATE_KEY 已生成" + generated=true + fi + + if [ "$generated" = true ]; then + echo "" + print_success "所有缺失的密钥已自动生成并保存到 .env" + print_warning "请妥善保管 .env 文件,不要提交到版本控制系统" + echo "" + fi + + print_success "加密密钥检查完成" + print_info " • JWT_SECRET: OK" + print_info " • DATA_ENCRYPTION_KEY: OK" + print_info " • RSA_PRIVATE_KEY: OK" + + # 修复 .env 文件权限 + chmod 600 .env 2>/dev/null || true +} + # ------------------------------------------------------------------------ # Validation: Configuration File (config.json) - BASIC SETTINGS ONLY # ------------------------------------------------------------------------ @@ -166,9 +181,7 @@ check_config() { if [ ! -f "config.json" ]; then print_warning "config.json 不存在,从模板复制..." cp config.json.example config.json - print_info "✓ 已使用默认配置创建 config.json" - print_info "💡 如需修改基础设置(杠杆大小、开仓币种、管理员模式、JWT密钥等),可编辑 config.json" - print_info "💡 模型/交易所/交易员配置请使用Web界面" + print_info "已使用默认配置创建 config.json" fi print_success "配置文件存在" } @@ -178,101 +191,55 @@ check_config() { # ------------------------------------------------------------------------ read_env_vars() { if [ -f ".env" ]; then - # 读取端口配置,设置默认值 NOFX_FRONTEND_PORT=$(grep "^NOFX_FRONTEND_PORT=" .env 2>/dev/null | cut -d'=' -f2 || echo "3000") NOFX_BACKEND_PORT=$(grep "^NOFX_BACKEND_PORT=" .env 2>/dev/null | cut -d'=' -f2 || echo "8080") - - # 去除可能的引号和空格 + NOFX_FRONTEND_PORT=$(echo "$NOFX_FRONTEND_PORT" | tr -d '"'"'" | tr -d ' ') NOFX_BACKEND_PORT=$(echo "$NOFX_BACKEND_PORT" | tr -d '"'"'" | tr -d ' ') - - # 如果为空则使用默认值 + NOFX_FRONTEND_PORT=${NOFX_FRONTEND_PORT:-3000} NOFX_BACKEND_PORT=${NOFX_BACKEND_PORT:-8080} else - # 如果.env不存在,使用默认端口 NOFX_FRONTEND_PORT=3000 NOFX_BACKEND_PORT=8080 fi } # ------------------------------------------------------------------------ -# Validation: Database File (config.db) +# Validation: Database File (data.db) # ------------------------------------------------------------------------ check_database() { - if [ -d "config.db" ]; then - # 如果存在的是目录,删除它 - print_warning "config.db 是目录而非文件,正在删除目录..." - rm -rf config.db - print_info "✓ 已删除目录,现在创建文件..." - install -m 600 /dev/null config.db - print_success "✓ 已创建空数据库文件(权限: 600),系统将在启动时初始化" - elif [ ! -f "config.db" ]; then - # 如果不存在文件,创建它 + if [ -d "data.db" ]; then + print_warning "data.db 是目录而非文件,正在删除目录..." + rm -rf data.db + install -m 600 /dev/null data.db + print_success "已创建空数据库文件" + elif [ ! -f "data.db" ]; then print_warning "数据库文件不存在,创建空数据库文件..." - # 创建空文件以避免Docker创建目录(使用安全权限600) - install -m 600 /dev/null config.db - print_info "✓ 已创建空数据库文件(权限: 600),系统将在启动时初始化" + install -m 600 /dev/null data.db + print_info "已创建空数据库文件,系统将在启动时初始化" else - # 文件存在 print_success "数据库文件存在" fi } -# ------------------------------------------------------------------------ -# Build: Frontend (Node.js Based) -# ------------------------------------------------------------------------ -# build_frontend() { -# print_info "检查前端构建环境..." - -# if ! command -v node &> /dev/null; then -# print_error "Node.js 未安装!请先安装 Node.js" -# exit 1 -# fi - -# if ! command -v npm &> /dev/null; then -# print_error "npm 未安装!请先安装 npm" -# exit 1 -# fi - -# print_info "正在构建前端..." -# cd web - -# print_info "安装 Node.js 依赖..." -# npm install - -# print_info "构建前端应用..." -# npm run build - -# cd .. -# print_success "前端构建完成" -# } - # ------------------------------------------------------------------------ # Service Management: Start # ------------------------------------------------------------------------ start() { print_info "正在启动 NOFX AI Trading System..." - # 读取环境变量 read_env_vars - # 确保必要的文件和目录存在(修复 Docker volume 挂载问题) - if [ ! -f "config.db" ]; then + if [ ! -f "data.db" ]; then print_info "创建数据库文件..." - install -m 600 /dev/null config.db + install -m 600 /dev/null data.db fi if [ ! -d "decision_logs" ]; then print_info "创建日志目录..." install -m 700 -d decision_logs fi - # Auto-build frontend if missing or forced - # if [ ! -d "web/dist" ] || [ "$1" == "--build" ]; then - # build_frontend - # fi - - # Rebuild images if flag set if [ "$1" == "--build" ]; then print_info "重新构建镜像..." $COMPOSE_CMD up -d --build @@ -322,9 +289,8 @@ logs() { # Monitoring: Status # ------------------------------------------------------------------------ status() { - # 读取环境变量 read_env_vars - + print_info "服务状态:" $COMPOSE_CMD ps echo "" @@ -358,18 +324,42 @@ update() { } # ------------------------------------------------------------------------ -# Encryption: Manual Setup +# Command: Regenerate all keys (force) # ------------------------------------------------------------------------ -setup_encryption_manual() { - print_info "🔐 手动设置加密环境" - - if [ -f "scripts/setup_encryption.sh" ]; then - bash scripts/setup_encryption.sh - else - print_error "加密设置脚本不存在: scripts/setup_encryption.sh" - print_info "请确保项目文件完整" - exit 1 +regenerate_keys() { + print_warning "这将重新生成所有加密密钥!" + print_warning "如果已有加密数据,重新生成后将无法解密!" + echo "" + read -p "确认重新生成?(yes/no): " confirm + if [ "$confirm" != "yes" ]; then + print_info "已取消" + return fi + + check_env + + print_info "正在生成新的密钥..." + + # 生成 JWT_SECRET + local jwt_secret=$(openssl rand -base64 32) + set_env_var "JWT_SECRET" "$jwt_secret" + print_success "JWT_SECRET 已生成" + + # 生成 DATA_ENCRYPTION_KEY + local data_key=$(openssl rand -base64 32) + set_env_var "DATA_ENCRYPTION_KEY" "$data_key" + print_success "DATA_ENCRYPTION_KEY 已生成" + + # 生成 RSA_PRIVATE_KEY + local rsa_key=$(openssl genrsa 2048 2>/dev/null | awk '{printf "%s\\n", $0}') + set_env_var "RSA_PRIVATE_KEY" "\"$rsa_key\"" + print_success "RSA_PRIVATE_KEY 已生成" + + chmod 600 .env 2>/dev/null || true + + echo "" + print_success "所有密钥已重新生成并保存到 .env" + print_warning "请妥善保管 .env 文件" } # ------------------------------------------------------------------------ @@ -388,18 +378,16 @@ show_help() { echo " status 查看服务状态" echo " clean 清理所有容器和数据" echo " update 更新代码并重启" - echo " setup-encryption 设置加密环境(RSA密钥+数据加密)" + echo " regenerate-keys 重新生成所有加密密钥(慎用)" echo " help 显示此帮助信息" echo "" echo "示例:" echo " ./start.sh start --build # 构建并启动" echo " ./start.sh logs backend # 查看后端日志" echo " ./start.sh status # 查看状态" - echo " ./start.sh setup-encryption # 手动设置加密环境" echo "" - echo "🔐 关于加密:" - echo " 系统自动检测加密环境,首次运行时会自动设置" - echo " 手动设置: ./scripts/setup_encryption.sh" + echo "首次使用:" + echo " 直接运行 ./start.sh 即可,缺失的密钥会自动生成" } # ------------------------------------------------------------------------ @@ -434,8 +422,8 @@ main() { update) update ;; - setup-encryption) - setup_encryption_manual + regenerate-keys) + regenerate_keys ;; help|--help|-h) show_help @@ -449,4 +437,4 @@ main() { } # Execute Main -main "$@" \ No newline at end of file +main "$@" diff --git a/store/ai_model.go b/store/ai_model.go new file mode 100644 index 00000000..d8f0594c --- /dev/null +++ b/store/ai_model.go @@ -0,0 +1,294 @@ +package store + +import ( + "database/sql" + "errors" + "fmt" + "nofx/logger" + "strings" + "time" +) + +// AIModelStore AI模型存储 +type AIModelStore struct { + db *sql.DB + encryptFunc func(string) string + decryptFunc func(string) string +} + +// AIModel AI模型配置 +type AIModel struct { + ID string `json:"id"` + UserID string `json:"user_id"` + Name string `json:"name"` + Provider string `json:"provider"` + Enabled bool `json:"enabled"` + APIKey string `json:"apiKey"` + CustomAPIURL string `json:"customApiUrl"` + CustomModelName string `json:"customModelName"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (s *AIModelStore) initTables() error { + _, err := s.db.Exec(` + CREATE TABLE IF NOT EXISTS ai_models ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL DEFAULT 'default', + name TEXT NOT NULL, + provider TEXT NOT NULL, + enabled BOOLEAN DEFAULT 0, + api_key TEXT DEFAULT '', + custom_api_url TEXT DEFAULT '', + custom_model_name TEXT DEFAULT '', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE + ) + `) + if err != nil { + return err + } + + // 触发器 + _, err = s.db.Exec(` + CREATE TRIGGER IF NOT EXISTS update_ai_models_updated_at + AFTER UPDATE ON ai_models + BEGIN + UPDATE ai_models SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; + END + `) + if err != nil { + return err + } + + // 向后兼容:添加可能缺失的列 + s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_api_url TEXT DEFAULT ''`) + s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_model_name TEXT DEFAULT ''`) + + return nil +} + +func (s *AIModelStore) initDefaultData() error { + models := []struct { + id, name, provider string + }{ + {"deepseek", "DeepSeek", "deepseek"}, + {"qwen", "Qwen", "qwen"}, + } + + for _, model := range models { + _, err := s.db.Exec(` + INSERT OR IGNORE INTO ai_models (id, user_id, name, provider, enabled) + VALUES (?, 'default', ?, ?, 0) + `, model.id, model.name, model.provider) + if err != nil { + return fmt.Errorf("初始化AI模型失败: %w", err) + } + } + return nil +} + +func (s *AIModelStore) encrypt(plaintext string) string { + if s.encryptFunc != nil { + return s.encryptFunc(plaintext) + } + return plaintext +} + +func (s *AIModelStore) decrypt(encrypted string) string { + if s.decryptFunc != nil { + return s.decryptFunc(encrypted) + } + return encrypted +} + +// List 获取用户的AI模型列表 +func (s *AIModelStore) List(userID string) ([]*AIModel, error) { + rows, err := s.db.Query(` + SELECT id, user_id, name, provider, enabled, api_key, + COALESCE(custom_api_url, '') as custom_api_url, + COALESCE(custom_model_name, '') as custom_model_name, + created_at, updated_at + FROM ai_models WHERE user_id = ? ORDER BY id + `, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + models := make([]*AIModel, 0) + for rows.Next() { + var model AIModel + var createdAt, updatedAt string + err := rows.Scan( + &model.ID, &model.UserID, &model.Name, &model.Provider, + &model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName, + &createdAt, &updatedAt, + ) + if err != nil { + return nil, err + } + model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) + model.APIKey = s.decrypt(model.APIKey) + models = append(models, &model) + } + return models, nil +} + +// Get 获取单个AI模型 +func (s *AIModelStore) Get(userID, modelID string) (*AIModel, error) { + if modelID == "" { + return nil, fmt.Errorf("模型ID不能为空") + } + + candidates := []string{} + if userID != "" { + candidates = append(candidates, userID) + } + if userID != "default" { + candidates = append(candidates, "default") + } + if len(candidates) == 0 { + candidates = append(candidates, "default") + } + + for _, uid := range candidates { + var model AIModel + var createdAt, updatedAt string + err := s.db.QueryRow(` + SELECT id, user_id, name, provider, enabled, api_key, + COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at + FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1 + `, uid, modelID).Scan( + &model.ID, &model.UserID, &model.Name, &model.Provider, + &model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName, + &createdAt, &updatedAt, + ) + if err == nil { + model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) + model.APIKey = s.decrypt(model.APIKey) + return &model, nil + } + if !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + } + return nil, sql.ErrNoRows +} + +// GetDefault 获取默认启用的AI模型 +func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) { + if userID == "" { + userID = "default" + } + model, err := s.firstEnabled(userID) + if err == nil { + return model, nil + } + if !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + if userID != "default" { + return s.firstEnabled("default") + } + return nil, fmt.Errorf("请先在系统中配置可用的AI模型") +} + +func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) { + var model AIModel + var createdAt, updatedAt string + err := s.db.QueryRow(` + SELECT id, user_id, name, provider, enabled, api_key, + COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at + FROM ai_models WHERE user_id = ? AND enabled = 1 + ORDER BY datetime(updated_at) DESC, id ASC LIMIT 1 + `, userID).Scan( + &model.ID, &model.UserID, &model.Name, &model.Provider, + &model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName, + &createdAt, &updatedAt, + ) + if err != nil { + return nil, err + } + model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) + model.APIKey = s.decrypt(model.APIKey) + return &model, nil +} + +// Update 更新AI模型,不存在则创建 +func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error { + // 先尝试精确匹配ID + var existingID string + err := s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1`, userID, id).Scan(&existingID) + if err == nil { + encryptedAPIKey := s.encrypt(apiKey) + _, err = s.db.Exec(` + UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now') + WHERE id = ? AND user_id = ? + `, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID) + return err + } + + // 尝试兼容旧逻辑:将id作为provider查找 + provider := id + err = s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND provider = ? LIMIT 1`, userID, provider).Scan(&existingID) + if err == nil { + logger.Warnf("⚠️ 使用旧版 provider 匹配更新模型: %s -> %s", provider, existingID) + encryptedAPIKey := s.encrypt(apiKey) + _, err = s.db.Exec(` + UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now') + WHERE id = ? AND user_id = ? + `, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID) + return err + } + + // 创建新记录 + if provider == id && (provider == "deepseek" || provider == "qwen") { + provider = id + } else { + parts := strings.Split(id, "_") + if len(parts) >= 2 { + provider = parts[len(parts)-1] + } else { + provider = id + } + } + + var name string + err = s.db.QueryRow(`SELECT name FROM ai_models WHERE provider = ? LIMIT 1`, provider).Scan(&name) + if err != nil { + if provider == "deepseek" { + name = "DeepSeek AI" + } else if provider == "qwen" { + name = "Qwen AI" + } else { + name = provider + " AI" + } + } + + newModelID := id + if id == provider { + newModelID = fmt.Sprintf("%s_%s", userID, provider) + } + + logger.Infof("✓ 创建新的 AI 模型配置: ID=%s, Provider=%s, Name=%s", newModelID, provider, name) + encryptedAPIKey := s.encrypt(apiKey) + _, err = s.db.Exec(` + INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, custom_model_name, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now')) + `, newModelID, userID, name, provider, enabled, encryptedAPIKey, customAPIURL, customModelName) + return err +} + +// Create 创建AI模型 +func (s *AIModelStore) Create(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error { + _, err := s.db.Exec(` + INSERT OR IGNORE INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url) + VALUES (?, ?, ?, ?, ?, ?, ?) + `, id, userID, name, provider, enabled, apiKey, customAPIURL) + return err +} diff --git a/store/backtest.go b/store/backtest.go new file mode 100644 index 00000000..89ecb14d --- /dev/null +++ b/store/backtest.go @@ -0,0 +1,583 @@ +package store + +import ( + "database/sql" + "encoding/json" + "fmt" + "time" +) + +// BacktestStore 回测数据存储 +type BacktestStore struct { + db *sql.DB +} + +// RunState 回测状态 +type RunState string + +const ( + RunStateCreated RunState = "created" + RunStateRunning RunState = "running" + RunStatePaused RunState = "paused" + RunStateCompleted RunState = "completed" + RunStateFailed RunState = "failed" +) + +// RunMetadata 回测元数据 +type RunMetadata struct { + RunID string `json:"run_id"` + UserID string `json:"user_id"` + Version int `json:"version"` + State RunState `json:"state"` + Label string `json:"label"` + LastError string `json:"last_error"` + Summary RunSummary `json:"summary"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// RunSummary 回测摘要 +type RunSummary struct { + SymbolCount int `json:"symbol_count"` + DecisionTF string `json:"decision_tf"` + ProcessedBars int `json:"processed_bars"` + ProgressPct float64 `json:"progress_pct"` + EquityLast float64 `json:"equity_last"` + MaxDrawdownPct float64 `json:"max_drawdown_pct"` + Liquidated bool `json:"liquidated"` + LiquidationNote string `json:"liquidation_note"` +} + +// EquityPoint 权益点 +type EquityPoint struct { + Timestamp int64 `json:"timestamp"` + Equity float64 `json:"equity"` + Available float64 `json:"available"` + PnL float64 `json:"pnl"` + PnLPct float64 `json:"pnl_pct"` + DrawdownPct float64 `json:"drawdown_pct"` + Cycle int `json:"cycle"` +} + +// TradeEvent 交易事件 +type TradeEvent struct { + Timestamp int64 `json:"timestamp"` + Symbol string `json:"symbol"` + Action string `json:"action"` + Side string `json:"side"` + Quantity float64 `json:"quantity"` + Price float64 `json:"price"` + Fee float64 `json:"fee"` + Slippage float64 `json:"slippage"` + OrderValue float64 `json:"order_value"` + RealizedPnL float64 `json:"realized_pnl"` + Leverage int `json:"leverage"` + Cycle int `json:"cycle"` + PositionAfter float64 `json:"position_after"` + LiquidationFlag bool `json:"liquidation_flag"` + Note string `json:"note"` +} + +// RunIndexEntry 回测索引条目 +type RunIndexEntry struct { + RunID string `json:"run_id"` + State string `json:"state"` + Symbols []string `json:"symbols"` + DecisionTF string `json:"decision_tf"` + EquityLast float64 `json:"equity_last"` + MaxDrawdownPct float64 `json:"max_drawdown_pct"` + StartTS int64 `json:"start_ts"` + EndTS int64 `json:"end_ts"` + CreatedAtISO string `json:"created_at"` + UpdatedAtISO string `json:"updated_at"` +} + +// initTables 初始化回测相关表 +func (s *BacktestStore) initTables() error { + queries := []string{ + // 回测运行主表 + `CREATE TABLE IF NOT EXISTS backtest_runs ( + run_id TEXT PRIMARY KEY, + user_id TEXT NOT NULL DEFAULT '', + config_json TEXT NOT NULL DEFAULT '', + state TEXT NOT NULL DEFAULT 'created', + label TEXT DEFAULT '', + symbol_count INTEGER DEFAULT 0, + decision_tf TEXT DEFAULT '', + processed_bars INTEGER DEFAULT 0, + progress_pct REAL DEFAULT 0, + equity_last REAL DEFAULT 0, + max_drawdown_pct REAL DEFAULT 0, + liquidated BOOLEAN DEFAULT 0, + liquidation_note TEXT DEFAULT '', + prompt_template TEXT DEFAULT '', + custom_prompt TEXT DEFAULT '', + override_prompt BOOLEAN DEFAULT 0, + ai_provider TEXT DEFAULT '', + ai_model TEXT DEFAULT '', + last_error TEXT DEFAULT '', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + )`, + + // 回测检查点 + `CREATE TABLE IF NOT EXISTS backtest_checkpoints ( + run_id TEXT PRIMARY KEY, + payload BLOB NOT NULL, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE + )`, + + // 回测权益曲线 + `CREATE TABLE IF NOT EXISTS backtest_equity ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_id TEXT NOT NULL, + ts INTEGER NOT NULL, + equity REAL NOT NULL, + available REAL NOT NULL, + pnl REAL NOT NULL, + pnl_pct REAL NOT NULL, + dd_pct REAL NOT NULL, + cycle INTEGER NOT NULL, + FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE + )`, + + // 回测交易记录 + `CREATE TABLE IF NOT EXISTS backtest_trades ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_id TEXT NOT NULL, + ts INTEGER NOT NULL, + symbol TEXT NOT NULL, + action TEXT NOT NULL, + side TEXT DEFAULT '', + qty REAL DEFAULT 0, + price REAL DEFAULT 0, + fee REAL DEFAULT 0, + slippage REAL DEFAULT 0, + order_value REAL DEFAULT 0, + realized_pnl REAL DEFAULT 0, + leverage INTEGER DEFAULT 0, + cycle INTEGER DEFAULT 0, + position_after REAL DEFAULT 0, + liquidation BOOLEAN DEFAULT 0, + note TEXT DEFAULT '', + FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE + )`, + + // 回测指标 + `CREATE TABLE IF NOT EXISTS backtest_metrics ( + run_id TEXT PRIMARY KEY, + payload BLOB NOT NULL, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE + )`, + + // 回测决策日志 + `CREATE TABLE IF NOT EXISTS backtest_decisions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_id TEXT NOT NULL, + cycle INTEGER NOT NULL, + payload BLOB NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE + )`, + + // 索引 + `CREATE INDEX IF NOT EXISTS idx_backtest_runs_state ON backtest_runs(state, updated_at)`, + `CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`, + `CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`, + `CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`, + } + + for _, query := range queries { + if _, err := s.db.Exec(query); err != nil { + return fmt.Errorf("执行SQL失败: %w", err) + } + } + + // 添加可能缺失的列(向后兼容) + s.addColumnIfNotExists("backtest_runs", "label", "TEXT DEFAULT ''") + s.addColumnIfNotExists("backtest_runs", "last_error", "TEXT DEFAULT ''") + s.addColumnIfNotExists("backtest_trades", "leverage", "INTEGER DEFAULT 0") + + return nil +} + +func (s *BacktestStore) addColumnIfNotExists(table, column, definition string) { + rows, err := s.db.Query(fmt.Sprintf("PRAGMA table_info(%s)", table)) + if err != nil { + return + } + defer rows.Close() + + for rows.Next() { + var cid int + var name, ctype string + var notnull, pk int + var dflt interface{} + if err := rows.Scan(&cid, &name, &ctype, ¬null, &dflt, &pk); err != nil { + continue + } + if name == column { + return // 列已存在 + } + } + + s.db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, definition)) +} + +// SaveCheckpoint 保存检查点 +func (s *BacktestStore) SaveCheckpoint(runID string, payload []byte) error { + _, err := s.db.Exec(` + INSERT INTO backtest_checkpoints (run_id, payload, updated_at) + VALUES (?, ?, CURRENT_TIMESTAMP) + ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP + `, runID, payload) + return err +} + +// LoadCheckpoint 加载检查点 +func (s *BacktestStore) LoadCheckpoint(runID string) ([]byte, error) { + var payload []byte + err := s.db.QueryRow(`SELECT payload FROM backtest_checkpoints WHERE run_id = ?`, runID).Scan(&payload) + return payload, err +} + +// SaveRunMetadata 保存运行元数据 +func (s *BacktestStore) SaveRunMetadata(meta *RunMetadata) error { + created := meta.CreatedAt.UTC().Format(time.RFC3339) + updated := meta.UpdatedAt.UTC().Format(time.RFC3339) + userID := meta.UserID + + if _, err := s.db.Exec(` + INSERT INTO backtest_runs (run_id, user_id, label, last_error, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(run_id) DO NOTHING + `, meta.RunID, userID, meta.Label, meta.LastError, created, updated); err != nil { + return err + } + + _, err := s.db.Exec(` + UPDATE backtest_runs + SET user_id = ?, state = ?, symbol_count = ?, decision_tf = ?, processed_bars = ?, + progress_pct = ?, equity_last = ?, max_drawdown_pct = ?, liquidated = ?, + liquidation_note = ?, label = ?, last_error = ?, updated_at = ? + WHERE run_id = ? + `, userID, string(meta.State), meta.Summary.SymbolCount, meta.Summary.DecisionTF, + meta.Summary.ProcessedBars, meta.Summary.ProgressPct, meta.Summary.EquityLast, + meta.Summary.MaxDrawdownPct, meta.Summary.Liquidated, meta.Summary.LiquidationNote, + meta.Label, meta.LastError, updated, meta.RunID) + return err +} + +// LoadRunMetadata 加载运行元数据 +func (s *BacktestStore) LoadRunMetadata(runID string) (*RunMetadata, error) { + var ( + userID string + state string + label string + lastErr string + symbolCount int + decisionTF string + processedBars int + progressPct float64 + equityLast float64 + maxDD float64 + liquidated bool + liquidationNote string + createdISO string + updatedISO string + ) + + err := s.db.QueryRow(` + SELECT user_id, state, label, last_error, symbol_count, decision_tf, processed_bars, + progress_pct, equity_last, max_drawdown_pct, liquidated, liquidation_note, + created_at, updated_at + FROM backtest_runs WHERE run_id = ? + `, runID).Scan(&userID, &state, &label, &lastErr, &symbolCount, &decisionTF, + &processedBars, &progressPct, &equityLast, &maxDD, &liquidated, &liquidationNote, + &createdISO, &updatedISO) + if err != nil { + return nil, err + } + + meta := &RunMetadata{ + RunID: runID, + UserID: userID, + Version: 1, + State: RunState(state), + Label: label, + LastError: lastErr, + Summary: RunSummary{ + SymbolCount: symbolCount, + DecisionTF: decisionTF, + ProcessedBars: processedBars, + ProgressPct: progressPct, + EquityLast: equityLast, + MaxDrawdownPct: maxDD, + Liquidated: liquidated, + LiquidationNote: liquidationNote, + }, + } + + meta.CreatedAt, _ = time.Parse(time.RFC3339, createdISO) + meta.UpdatedAt, _ = time.Parse(time.RFC3339, updatedISO) + + return meta, nil +} + +// ListRunIDs 列出所有运行ID +func (s *BacktestStore) ListRunIDs() ([]string, error) { + rows, err := s.db.Query(`SELECT run_id FROM backtest_runs ORDER BY datetime(updated_at) DESC`) + if err != nil { + return nil, err + } + defer rows.Close() + + var ids []string + for rows.Next() { + var runID string + if err := rows.Scan(&runID); err != nil { + return nil, err + } + ids = append(ids, runID) + } + return ids, rows.Err() +} + +// AppendEquityPoint 添加权益点 +func (s *BacktestStore) AppendEquityPoint(runID string, point EquityPoint) error { + _, err := s.db.Exec(` + INSERT INTO backtest_equity (run_id, ts, equity, available, pnl, pnl_pct, dd_pct, cycle) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + `, runID, point.Timestamp, point.Equity, point.Available, point.PnL, + point.PnLPct, point.DrawdownPct, point.Cycle) + return err +} + +// LoadEquityPoints 加载权益点 +func (s *BacktestStore) LoadEquityPoints(runID string) ([]EquityPoint, error) { + rows, err := s.db.Query(` + SELECT ts, equity, available, pnl, pnl_pct, dd_pct, cycle + FROM backtest_equity WHERE run_id = ? ORDER BY ts ASC + `, runID) + if err != nil { + return nil, err + } + defer rows.Close() + + points := make([]EquityPoint, 0) + for rows.Next() { + var point EquityPoint + if err := rows.Scan(&point.Timestamp, &point.Equity, &point.Available, + &point.PnL, &point.PnLPct, &point.DrawdownPct, &point.Cycle); err != nil { + return nil, err + } + points = append(points, point) + } + return points, rows.Err() +} + +// AppendTradeEvent 添加交易事件 +func (s *BacktestStore) AppendTradeEvent(runID string, event TradeEvent) error { + _, err := s.db.Exec(` + INSERT INTO backtest_trades (run_id, ts, symbol, action, side, qty, price, fee, + slippage, order_value, realized_pnl, leverage, cycle, + position_after, liquidation, note) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + `, runID, event.Timestamp, event.Symbol, event.Action, event.Side, event.Quantity, + event.Price, event.Fee, event.Slippage, event.OrderValue, event.RealizedPnL, + event.Leverage, event.Cycle, event.PositionAfter, event.LiquidationFlag, event.Note) + return err +} + +// LoadTradeEvents 加载交易事件 +func (s *BacktestStore) LoadTradeEvents(runID string) ([]TradeEvent, error) { + rows, err := s.db.Query(` + SELECT ts, symbol, action, side, qty, price, fee, slippage, order_value, + realized_pnl, leverage, cycle, position_after, liquidation, note + FROM backtest_trades WHERE run_id = ? ORDER BY ts ASC + `, runID) + if err != nil { + return nil, err + } + defer rows.Close() + + events := make([]TradeEvent, 0) + for rows.Next() { + var event TradeEvent + if err := rows.Scan(&event.Timestamp, &event.Symbol, &event.Action, &event.Side, + &event.Quantity, &event.Price, &event.Fee, &event.Slippage, &event.OrderValue, + &event.RealizedPnL, &event.Leverage, &event.Cycle, &event.PositionAfter, + &event.LiquidationFlag, &event.Note); err != nil { + return nil, err + } + events = append(events, event) + } + return events, rows.Err() +} + +// SaveMetrics 保存指标 +func (s *BacktestStore) SaveMetrics(runID string, payload []byte) error { + _, err := s.db.Exec(` + INSERT INTO backtest_metrics (run_id, payload, updated_at) + VALUES (?, ?, CURRENT_TIMESTAMP) + ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP + `, runID, payload) + return err +} + +// LoadMetrics 加载指标 +func (s *BacktestStore) LoadMetrics(runID string) ([]byte, error) { + var payload []byte + err := s.db.QueryRow(`SELECT payload FROM backtest_metrics WHERE run_id = ?`, runID).Scan(&payload) + return payload, err +} + +// SaveDecisionRecord 保存决策记录 +func (s *BacktestStore) SaveDecisionRecord(runID string, cycle int, payload []byte) error { + _, err := s.db.Exec(` + INSERT INTO backtest_decisions (run_id, cycle, payload) + VALUES (?, ?, ?) + `, runID, cycle, payload) + return err +} + +// LoadDecisionRecords 加载决策记录 +func (s *BacktestStore) LoadDecisionRecords(runID string, limit, offset int) ([]json.RawMessage, error) { + rows, err := s.db.Query(` + SELECT payload FROM backtest_decisions + WHERE run_id = ? + ORDER BY id DESC + LIMIT ? OFFSET ? + `, runID, limit, offset) + if err != nil { + return nil, err + } + defer rows.Close() + + records := make([]json.RawMessage, 0, limit) + for rows.Next() { + var payload []byte + if err := rows.Scan(&payload); err != nil { + return nil, err + } + records = append(records, json.RawMessage(payload)) + } + return records, rows.Err() +} + +// LoadLatestDecision 加载最新决策 +func (s *BacktestStore) LoadLatestDecision(runID string, cycle int) ([]byte, error) { + var query string + var args []interface{} + + if cycle > 0 { + query = `SELECT payload FROM backtest_decisions WHERE run_id = ? AND cycle = ? ORDER BY datetime(created_at) DESC LIMIT 1` + args = []interface{}{runID, cycle} + } else { + query = `SELECT payload FROM backtest_decisions WHERE run_id = ? ORDER BY datetime(created_at) DESC LIMIT 1` + args = []interface{}{runID} + } + + var payload []byte + err := s.db.QueryRow(query, args...).Scan(&payload) + return payload, err +} + +// UpdateProgress 更新进度 +func (s *BacktestStore) UpdateProgress(runID string, progressPct, equity float64, barIndex int, liquidated bool) error { + _, err := s.db.Exec(` + UPDATE backtest_runs + SET progress_pct = ?, equity_last = ?, processed_bars = ?, liquidated = ?, updated_at = CURRENT_TIMESTAMP + WHERE run_id = ? + `, progressPct, equity, barIndex, liquidated, runID) + return err +} + +// ListIndexEntries 列出索引条目 +func (s *BacktestStore) ListIndexEntries() ([]RunIndexEntry, error) { + rows, err := s.db.Query(` + SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct, + created_at, updated_at, config_json + FROM backtest_runs + ORDER BY datetime(updated_at) DESC + `) + if err != nil { + return nil, err + } + defer rows.Close() + + var entries []RunIndexEntry + for rows.Next() { + var entry RunIndexEntry + var symbolCnt int + var cfgJSON []byte + var createdISO, updatedISO string + + if err := rows.Scan(&entry.RunID, &entry.State, &symbolCnt, &entry.DecisionTF, + &entry.EquityLast, &entry.MaxDrawdownPct, &createdISO, &updatedISO, &cfgJSON); err != nil { + return nil, err + } + + entry.CreatedAtISO = createdISO + entry.UpdatedAtISO = updatedISO + entry.Symbols = make([]string, 0, symbolCnt) + + // 尝试从配置中提取更多信息 + if len(cfgJSON) > 0 { + var cfg struct { + Symbols []string `json:"symbols"` + StartTS int64 `json:"start_ts"` + EndTS int64 `json:"end_ts"` + } + if json.Unmarshal(cfgJSON, &cfg) == nil { + entry.Symbols = cfg.Symbols + entry.StartTS = cfg.StartTS + entry.EndTS = cfg.EndTS + } + } + + entries = append(entries, entry) + } + return entries, rows.Err() +} + +// DeleteRun 删除运行 +func (s *BacktestStore) DeleteRun(runID string) error { + _, err := s.db.Exec(`DELETE FROM backtest_runs WHERE run_id = ?`, runID) + return err +} + +// SaveConfig 保存配置 +func (s *BacktestStore) SaveConfig(runID, userID, template, customPrompt, provider, model string, override bool, configJSON []byte) error { + now := time.Now().UTC().Format(time.RFC3339) + if userID == "" { + userID = "default" + } + + _, err := s.db.Exec(` + INSERT INTO backtest_runs (run_id, user_id, config_json, prompt_template, custom_prompt, + override_prompt, ai_provider, ai_model, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(run_id) DO NOTHING + `, runID, userID, configJSON, template, customPrompt, override, provider, model, now, now) + if err != nil { + return err + } + + _, err = s.db.Exec(` + UPDATE backtest_runs + SET user_id = ?, config_json = ?, prompt_template = ?, custom_prompt = ?, + override_prompt = ?, ai_provider = ?, ai_model = ?, updated_at = CURRENT_TIMESTAMP + WHERE run_id = ? + `, userID, configJSON, template, customPrompt, override, provider, model, runID) + return err +} + +// LoadConfig 加载配置 +func (s *BacktestStore) LoadConfig(runID string) ([]byte, error) { + var payload []byte + err := s.db.QueryRow(`SELECT config_json FROM backtest_runs WHERE run_id = ?`, runID).Scan(&payload) + return payload, err +} diff --git a/store/beta_code.go b/store/beta_code.go new file mode 100644 index 00000000..dc4f3658 --- /dev/null +++ b/store/beta_code.go @@ -0,0 +1,121 @@ +package store + +import ( + "database/sql" + "fmt" + "nofx/logger" + "os" + "strings" +) + +// BetaCodeStore 内测码存储 +type BetaCodeStore struct { + db *sql.DB +} + +func (s *BetaCodeStore) initTables() error { + _, err := s.db.Exec(` + CREATE TABLE IF NOT EXISTS beta_codes ( + code TEXT PRIMARY KEY, + used BOOLEAN DEFAULT 0, + used_by TEXT DEFAULT '', + used_at DATETIME DEFAULT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + `) + return err +} + +// LoadFromFile 从文件加载内测码 +func (s *BetaCodeStore) LoadFromFile(filePath string) error { + content, err := os.ReadFile(filePath) + if err != nil { + return fmt.Errorf("读取内测码文件失败: %w", err) + } + + lines := strings.Split(string(content), "\n") + var codes []string + for _, line := range lines { + code := strings.TrimSpace(line) + if code != "" && !strings.HasPrefix(code, "#") { + codes = append(codes, code) + } + } + + tx, err := s.db.Begin() + if err != nil { + return fmt.Errorf("开始事务失败: %w", err) + } + defer tx.Rollback() + + stmt, err := tx.Prepare(`INSERT OR IGNORE INTO beta_codes (code) VALUES (?)`) + if err != nil { + return fmt.Errorf("准备语句失败: %w", err) + } + defer stmt.Close() + + insertedCount := 0 + for _, code := range codes { + result, err := stmt.Exec(code) + if err != nil { + logger.Warnf("插入内测码 %s 失败: %v", code, err) + continue + } + if rowsAffected, _ := result.RowsAffected(); rowsAffected > 0 { + insertedCount++ + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("提交事务失败: %w", err) + } + + logger.Infof("✅ 成功加载 %d 个内测码到数据库 (总计 %d 个)", insertedCount, len(codes)) + return nil +} + +// Validate 验证内测码是否有效 +func (s *BetaCodeStore) Validate(code string) (bool, error) { + var used bool + err := s.db.QueryRow(`SELECT used FROM beta_codes WHERE code = ?`, code).Scan(&used) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return !used, nil +} + +// Use 使用内测码 +func (s *BetaCodeStore) Use(code, userEmail string) error { + result, err := s.db.Exec(` + UPDATE beta_codes SET used = 1, used_by = ?, used_at = CURRENT_TIMESTAMP + WHERE code = ? AND used = 0 + `, userEmail, code) + if err != nil { + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + return fmt.Errorf("内测码无效或已被使用") + } + return nil +} + +// GetStats 获取内测码统计 +func (s *BetaCodeStore) GetStats() (total, used int, err error) { + err = s.db.QueryRow(`SELECT COUNT(*) FROM beta_codes`).Scan(&total) + if err != nil { + return 0, 0, err + } + err = s.db.QueryRow(`SELECT COUNT(*) FROM beta_codes WHERE used = 1`).Scan(&used) + if err != nil { + return 0, 0, err + } + return total, used, nil +} diff --git a/store/decision.go b/store/decision.go new file mode 100644 index 00000000..7758deb0 --- /dev/null +++ b/store/decision.go @@ -0,0 +1,530 @@ +package store + +import ( + "database/sql" + "encoding/json" + "fmt" + "time" +) + +// DecisionStore 决策日志存储 +type DecisionStore struct { + db *sql.DB +} + +// DecisionRecord 决策记录 +type DecisionRecord struct { + ID int64 `json:"id"` + TraderID string `json:"trader_id"` + CycleNumber int `json:"cycle_number"` + Timestamp time.Time `json:"timestamp"` + SystemPrompt string `json:"system_prompt"` + InputPrompt string `json:"input_prompt"` + CoTTrace string `json:"cot_trace"` + DecisionJSON string `json:"decision_json"` + CandidateCoins []string `json:"candidate_coins"` + ExecutionLog []string `json:"execution_log"` + Success bool `json:"success"` + ErrorMessage string `json:"error_message"` + AIRequestDurationMs int64 `json:"ai_request_duration_ms"` + AccountState AccountSnapshot `json:"account_state"` + Positions []PositionSnapshot `json:"positions"` + Decisions []DecisionAction `json:"decisions"` +} + +// AccountSnapshot 账户状态快照 +type AccountSnapshot struct { + TotalBalance float64 `json:"total_balance"` + AvailableBalance float64 `json:"available_balance"` + TotalUnrealizedProfit float64 `json:"total_unrealized_profit"` + PositionCount int `json:"position_count"` + MarginUsedPct float64 `json:"margin_used_pct"` + InitialBalance float64 `json:"initial_balance"` +} + +// PositionSnapshot 持仓快照 +type PositionSnapshot struct { + Symbol string `json:"symbol"` + Side string `json:"side"` + PositionAmt float64 `json:"position_amt"` + EntryPrice float64 `json:"entry_price"` + MarkPrice float64 `json:"mark_price"` + UnrealizedProfit float64 `json:"unrealized_profit"` + Leverage float64 `json:"leverage"` + LiquidationPrice float64 `json:"liquidation_price"` +} + +// DecisionAction 决策动作 +type DecisionAction struct { + Action string `json:"action"` + Symbol string `json:"symbol"` + Quantity float64 `json:"quantity"` + Leverage int `json:"leverage"` + Price float64 `json:"price"` + OrderID int64 `json:"order_id"` + Timestamp time.Time `json:"timestamp"` + Success bool `json:"success"` + Error string `json:"error"` +} + +// Statistics 统计信息 +type Statistics struct { + TotalCycles int `json:"total_cycles"` + SuccessfulCycles int `json:"successful_cycles"` + FailedCycles int `json:"failed_cycles"` + TotalOpenPositions int `json:"total_open_positions"` + TotalClosePositions int `json:"total_close_positions"` +} + +// initTables 初始化决策相关表 +func (s *DecisionStore) initTables() error { + queries := []string{ + // 决策记录主表 + `CREATE TABLE IF NOT EXISTS decision_records ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + trader_id TEXT NOT NULL, + cycle_number INTEGER NOT NULL, + timestamp DATETIME NOT NULL, + system_prompt TEXT DEFAULT '', + input_prompt TEXT DEFAULT '', + cot_trace TEXT DEFAULT '', + decision_json TEXT DEFAULT '', + candidate_coins TEXT DEFAULT '', + execution_log TEXT DEFAULT '', + success BOOLEAN DEFAULT 0, + error_message TEXT DEFAULT '', + ai_request_duration_ms INTEGER DEFAULT 0, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + )`, + + // 账户状态快照表 + `CREATE TABLE IF NOT EXISTS decision_account_snapshots ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + decision_id INTEGER NOT NULL, + total_balance REAL DEFAULT 0, + available_balance REAL DEFAULT 0, + total_unrealized_profit REAL DEFAULT 0, + position_count INTEGER DEFAULT 0, + margin_used_pct REAL DEFAULT 0, + initial_balance REAL DEFAULT 0, + FOREIGN KEY (decision_id) REFERENCES decision_records(id) ON DELETE CASCADE + )`, + + // 持仓快照表 + `CREATE TABLE IF NOT EXISTS decision_position_snapshots ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + decision_id INTEGER NOT NULL, + symbol TEXT NOT NULL, + side TEXT DEFAULT '', + position_amt REAL DEFAULT 0, + entry_price REAL DEFAULT 0, + mark_price REAL DEFAULT 0, + unrealized_profit REAL DEFAULT 0, + leverage REAL DEFAULT 0, + liquidation_price REAL DEFAULT 0, + FOREIGN KEY (decision_id) REFERENCES decision_records(id) ON DELETE CASCADE + )`, + + // 决策动作表(订单详情) + `CREATE TABLE IF NOT EXISTS decision_actions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + decision_id INTEGER NOT NULL, + trader_id TEXT NOT NULL, + action TEXT NOT NULL, + symbol TEXT NOT NULL, + quantity REAL DEFAULT 0, + leverage INTEGER DEFAULT 0, + price REAL DEFAULT 0, + order_id INTEGER DEFAULT 0, + timestamp DATETIME NOT NULL, + success BOOLEAN DEFAULT 0, + error TEXT DEFAULT '', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (decision_id) REFERENCES decision_records(id) ON DELETE CASCADE + )`, + + // 索引 + `CREATE INDEX IF NOT EXISTS idx_decision_records_trader_time ON decision_records(trader_id, timestamp DESC)`, + `CREATE INDEX IF NOT EXISTS idx_decision_records_timestamp ON decision_records(timestamp DESC)`, + `CREATE INDEX IF NOT EXISTS idx_decision_actions_trader ON decision_actions(trader_id, timestamp DESC)`, + `CREATE INDEX IF NOT EXISTS idx_decision_actions_symbol ON decision_actions(symbol, timestamp DESC)`, + } + + for _, query := range queries { + if _, err := s.db.Exec(query); err != nil { + return fmt.Errorf("执行SQL失败: %w", err) + } + } + + return nil +} + +// LogDecision 记录决策 +func (s *DecisionStore) LogDecision(record *DecisionRecord) error { + if record.Timestamp.IsZero() { + record.Timestamp = time.Now().UTC() + } else { + record.Timestamp = record.Timestamp.UTC() + } + + // 开始事务 + tx, err := s.db.Begin() + if err != nil { + return fmt.Errorf("开始事务失败: %w", err) + } + defer tx.Rollback() + + // 序列化候选币种和执行日志为 JSON + candidateCoinsJSON, _ := json.Marshal(record.CandidateCoins) + executionLogJSON, _ := json.Marshal(record.ExecutionLog) + + // 插入决策记录主表 + result, err := tx.Exec(` + INSERT INTO decision_records ( + trader_id, cycle_number, timestamp, system_prompt, input_prompt, + cot_trace, decision_json, candidate_coins, execution_log, + success, error_message, ai_request_duration_ms + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + `, + record.TraderID, record.CycleNumber, record.Timestamp.Format(time.RFC3339), + record.SystemPrompt, record.InputPrompt, record.CoTTrace, record.DecisionJSON, + string(candidateCoinsJSON), string(executionLogJSON), + record.Success, record.ErrorMessage, record.AIRequestDurationMs, + ) + if err != nil { + return fmt.Errorf("插入决策记录失败: %w", err) + } + + decisionID, err := result.LastInsertId() + if err != nil { + return fmt.Errorf("获取决策ID失败: %w", err) + } + record.ID = decisionID + + // 插入账户状态快照 + _, err = tx.Exec(` + INSERT INTO decision_account_snapshots ( + decision_id, total_balance, available_balance, total_unrealized_profit, + position_count, margin_used_pct, initial_balance + ) VALUES (?, ?, ?, ?, ?, ?, ?) + `, + decisionID, record.AccountState.TotalBalance, record.AccountState.AvailableBalance, + record.AccountState.TotalUnrealizedProfit, record.AccountState.PositionCount, + record.AccountState.MarginUsedPct, record.AccountState.InitialBalance, + ) + if err != nil { + return fmt.Errorf("插入账户快照失败: %w", err) + } + + // 插入持仓快照 + for _, pos := range record.Positions { + _, err = tx.Exec(` + INSERT INTO decision_position_snapshots ( + decision_id, symbol, side, position_amt, entry_price, + mark_price, unrealized_profit, leverage, liquidation_price + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + `, + decisionID, pos.Symbol, pos.Side, pos.PositionAmt, pos.EntryPrice, + pos.MarkPrice, pos.UnrealizedProfit, pos.Leverage, pos.LiquidationPrice, + ) + if err != nil { + return fmt.Errorf("插入持仓快照失败: %w", err) + } + } + + // 插入决策动作(订单详情) + for _, action := range record.Decisions { + actionTimestamp := action.Timestamp + if actionTimestamp.IsZero() { + actionTimestamp = record.Timestamp + } + _, err = tx.Exec(` + INSERT INTO decision_actions ( + decision_id, trader_id, action, symbol, quantity, leverage, + price, order_id, timestamp, success, error + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + `, + decisionID, record.TraderID, action.Action, action.Symbol, action.Quantity, + action.Leverage, action.Price, action.OrderID, + actionTimestamp.Format(time.RFC3339), action.Success, action.Error, + ) + if err != nil { + return fmt.Errorf("插入决策动作失败: %w", err) + } + } + + // 提交事务 + if err := tx.Commit(); err != nil { + return fmt.Errorf("提交事务失败: %w", err) + } + + return nil +} + +// GetLatestRecords 获取指定交易员最近N条记录(按时间正序:从旧到新) +func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRecord, error) { + rows, err := s.db.Query(` + SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt, + cot_trace, decision_json, candidate_coins, execution_log, + success, error_message, ai_request_duration_ms + FROM decision_records + WHERE trader_id = ? + ORDER BY timestamp DESC + LIMIT ? + `, traderID, n) + if err != nil { + return nil, fmt.Errorf("查询决策记录失败: %w", err) + } + defer rows.Close() + + var records []*DecisionRecord + for rows.Next() { + record, err := s.scanDecisionRecord(rows) + if err != nil { + continue + } + records = append(records, record) + } + + // 填充关联数据 + for _, record := range records { + s.fillRecordDetails(record) + } + + // 反转数组,让时间从旧到新排列 + for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 { + records[i], records[j] = records[j], records[i] + } + + return records, nil +} + +// GetAllLatestRecords 获取所有交易员最近N条记录 +func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) { + rows, err := s.db.Query(` + SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt, + cot_trace, decision_json, candidate_coins, execution_log, + success, error_message, ai_request_duration_ms + FROM decision_records + ORDER BY timestamp DESC + LIMIT ? + `, n) + if err != nil { + return nil, fmt.Errorf("查询决策记录失败: %w", err) + } + defer rows.Close() + + var records []*DecisionRecord + for rows.Next() { + record, err := s.scanDecisionRecord(rows) + if err != nil { + continue + } + records = append(records, record) + } + + // 反转数组 + for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 { + records[i], records[j] = records[j], records[i] + } + + return records, nil +} + +// GetRecordsByDate 获取指定交易员指定日期的所有记录 +func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*DecisionRecord, error) { + dateStr := date.Format("2006-01-02") + + rows, err := s.db.Query(` + SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt, + cot_trace, decision_json, candidate_coins, execution_log, + success, error_message, ai_request_duration_ms + FROM decision_records + WHERE trader_id = ? AND DATE(timestamp) = ? + ORDER BY timestamp ASC + `, traderID, dateStr) + if err != nil { + return nil, fmt.Errorf("查询决策记录失败: %w", err) + } + defer rows.Close() + + var records []*DecisionRecord + for rows.Next() { + record, err := s.scanDecisionRecord(rows) + if err != nil { + continue + } + records = append(records, record) + } + + return records, nil +} + +// CleanOldRecords 清理N天前的旧记录 +func (s *DecisionStore) CleanOldRecords(traderID string, days int) (int64, error) { + cutoffTime := time.Now().AddDate(0, 0, -days).Format(time.RFC3339) + + result, err := s.db.Exec(` + DELETE FROM decision_records + WHERE trader_id = ? AND timestamp < ? + `, traderID, cutoffTime) + if err != nil { + return 0, fmt.Errorf("清理旧记录失败: %w", err) + } + + return result.RowsAffected() +} + +// GetStatistics 获取指定交易员的统计信息 +func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) { + stats := &Statistics{} + + err := s.db.QueryRow(` + SELECT COUNT(*) FROM decision_records WHERE trader_id = ? + `, traderID).Scan(&stats.TotalCycles) + if err != nil { + return nil, fmt.Errorf("查询总周期数失败: %w", err) + } + + err = s.db.QueryRow(` + SELECT COUNT(*) FROM decision_records WHERE trader_id = ? AND success = 1 + `, traderID).Scan(&stats.SuccessfulCycles) + if err != nil { + return nil, fmt.Errorf("查询成功周期数失败: %w", err) + } + stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles + + err = s.db.QueryRow(` + SELECT COUNT(*) FROM decision_actions + WHERE trader_id = ? AND success = 1 AND action IN ('open_long', 'open_short') + `, traderID).Scan(&stats.TotalOpenPositions) + if err != nil { + return nil, fmt.Errorf("查询开仓次数失败: %w", err) + } + + err = s.db.QueryRow(` + SELECT COUNT(*) FROM decision_actions + WHERE trader_id = ? AND success = 1 AND action IN ('close_long', 'close_short', 'auto_close_long', 'auto_close_short') + `, traderID).Scan(&stats.TotalClosePositions) + if err != nil { + return nil, fmt.Errorf("查询平仓次数失败: %w", err) + } + + return stats, nil +} + +// GetAllStatistics 获取所有交易员的统计信息 +func (s *DecisionStore) GetAllStatistics() (*Statistics, error) { + stats := &Statistics{} + + s.db.QueryRow(`SELECT COUNT(*) FROM decision_records`).Scan(&stats.TotalCycles) + s.db.QueryRow(`SELECT COUNT(*) FROM decision_records WHERE success = 1`).Scan(&stats.SuccessfulCycles) + stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles + + s.db.QueryRow(` + SELECT COUNT(*) FROM decision_actions + WHERE success = 1 AND action IN ('open_long', 'open_short') + `).Scan(&stats.TotalOpenPositions) + + s.db.QueryRow(` + SELECT COUNT(*) FROM decision_actions + WHERE success = 1 AND action IN ('close_long', 'close_short', 'auto_close_long', 'auto_close_short') + `).Scan(&stats.TotalClosePositions) + + return stats, nil +} + +// GetLastCycleNumber 获取指定交易员的最后周期编号 +func (s *DecisionStore) GetLastCycleNumber(traderID string) (int, error) { + var cycleNumber int + err := s.db.QueryRow(` + SELECT COALESCE(MAX(cycle_number), 0) FROM decision_records WHERE trader_id = ? + `, traderID).Scan(&cycleNumber) + if err != nil { + return 0, err + } + return cycleNumber, nil +} + +// scanDecisionRecord 从行中扫描决策记录 +func (s *DecisionStore) scanDecisionRecord(rows *sql.Rows) (*DecisionRecord, error) { + var record DecisionRecord + var timestampStr string + var candidateCoinsJSON, executionLogJSON string + + err := rows.Scan( + &record.ID, &record.TraderID, &record.CycleNumber, ×tampStr, + &record.SystemPrompt, &record.InputPrompt, &record.CoTTrace, + &record.DecisionJSON, &candidateCoinsJSON, &executionLogJSON, + &record.Success, &record.ErrorMessage, &record.AIRequestDurationMs, + ) + if err != nil { + return nil, err + } + + record.Timestamp, _ = time.Parse(time.RFC3339, timestampStr) + json.Unmarshal([]byte(candidateCoinsJSON), &record.CandidateCoins) + json.Unmarshal([]byte(executionLogJSON), &record.ExecutionLog) + + return &record, nil +} + +// fillRecordDetails 填充决策记录的关联数据 +func (s *DecisionStore) fillRecordDetails(record *DecisionRecord) { + // 查询账户状态 + s.db.QueryRow(` + SELECT total_balance, available_balance, total_unrealized_profit, + position_count, margin_used_pct, initial_balance + FROM decision_account_snapshots + WHERE decision_id = ? + `, record.ID).Scan( + &record.AccountState.TotalBalance, + &record.AccountState.AvailableBalance, + &record.AccountState.TotalUnrealizedProfit, + &record.AccountState.PositionCount, + &record.AccountState.MarginUsedPct, + &record.AccountState.InitialBalance, + ) + + // 查询持仓快照 + posRows, err := s.db.Query(` + SELECT symbol, side, position_amt, entry_price, mark_price, + unrealized_profit, leverage, liquidation_price + FROM decision_position_snapshots + WHERE decision_id = ? + `, record.ID) + if err == nil { + defer posRows.Close() + for posRows.Next() { + var pos PositionSnapshot + posRows.Scan( + &pos.Symbol, &pos.Side, &pos.PositionAmt, &pos.EntryPrice, + &pos.MarkPrice, &pos.UnrealizedProfit, &pos.Leverage, + &pos.LiquidationPrice, + ) + record.Positions = append(record.Positions, pos) + } + } + + // 查询决策动作 + actionRows, err := s.db.Query(` + SELECT action, symbol, quantity, leverage, price, order_id, + timestamp, success, error + FROM decision_actions + WHERE decision_id = ? + `, record.ID) + if err == nil { + defer actionRows.Close() + for actionRows.Next() { + var action DecisionAction + var timestampStr string + actionRows.Scan( + &action.Action, &action.Symbol, &action.Quantity, + &action.Leverage, &action.Price, &action.OrderID, + ×tampStr, &action.Success, &action.Error, + ) + action.Timestamp, _ = time.Parse(time.RFC3339, timestampStr) + record.Decisions = append(record.Decisions, action) + } + } +} diff --git a/store/exchange.go b/store/exchange.go new file mode 100644 index 00000000..ee532c1b --- /dev/null +++ b/store/exchange.go @@ -0,0 +1,245 @@ +package store + +import ( + "database/sql" + "fmt" + "nofx/logger" + "strings" + "time" +) + +// ExchangeStore 交易所存储 +type ExchangeStore struct { + db *sql.DB + encryptFunc func(string) string + decryptFunc func(string) string +} + +// Exchange 交易所配置 +type Exchange struct { + ID string `json:"id"` + UserID string `json:"user_id"` + Name string `json:"name"` + Type string `json:"type"` + Enabled bool `json:"enabled"` + APIKey string `json:"apiKey"` + SecretKey string `json:"secretKey"` + Testnet bool `json:"testnet"` + HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` + AsterUser string `json:"asterUser"` + AsterSigner string `json:"asterSigner"` + AsterPrivateKey string `json:"asterPrivateKey"` + LighterWalletAddr string `json:"lighterWalletAddr"` + LighterPrivateKey string `json:"lighterPrivateKey"` + LighterAPIKeyPrivateKey string `json:"lighterAPIKeyPrivateKey"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (s *ExchangeStore) initTables() error { + _, err := s.db.Exec(` + CREATE TABLE IF NOT EXISTS exchanges ( + id TEXT NOT NULL, + user_id TEXT NOT NULL DEFAULT 'default', + name TEXT NOT NULL, + type TEXT NOT NULL, + enabled BOOLEAN DEFAULT 0, + api_key TEXT DEFAULT '', + secret_key TEXT DEFAULT '', + testnet BOOLEAN DEFAULT 0, + hyperliquid_wallet_addr TEXT DEFAULT '', + aster_user TEXT DEFAULT '', + aster_signer TEXT DEFAULT '', + aster_private_key TEXT DEFAULT '', + lighter_wallet_addr TEXT DEFAULT '', + lighter_private_key TEXT DEFAULT '', + lighter_api_key_private_key TEXT DEFAULT '', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id, user_id), + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE + ) + `) + if err != nil { + return err + } + + // 触发器 + _, err = s.db.Exec(` + CREATE TRIGGER IF NOT EXISTS update_exchanges_updated_at + AFTER UPDATE ON exchanges + BEGIN + UPDATE exchanges SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id AND user_id = NEW.user_id; + END + `) + return err +} + +func (s *ExchangeStore) initDefaultData() error { + exchanges := []struct { + id, name, typ string + }{ + {"binance", "Binance Futures", "binance"}, + {"bybit", "Bybit Futures", "bybit"}, + {"hyperliquid", "Hyperliquid", "hyperliquid"}, + {"aster", "Aster DEX", "aster"}, + {"lighter", "LIGHTER DEX", "lighter"}, + } + + for _, exchange := range exchanges { + _, err := s.db.Exec(` + INSERT OR IGNORE INTO exchanges (id, user_id, name, type, enabled) + VALUES (?, 'default', ?, ?, 0) + `, exchange.id, exchange.name, exchange.typ) + if err != nil { + return fmt.Errorf("初始化交易所失败: %w", err) + } + } + return nil +} + +func (s *ExchangeStore) encrypt(plaintext string) string { + if s.encryptFunc != nil { + return s.encryptFunc(plaintext) + } + return plaintext +} + +func (s *ExchangeStore) decrypt(encrypted string) string { + if s.decryptFunc != nil { + return s.decryptFunc(encrypted) + } + return encrypted +} + +// List 获取用户的交易所列表 +func (s *ExchangeStore) List(userID string) ([]*Exchange, error) { + rows, err := s.db.Query(` + SELECT id, user_id, name, type, enabled, api_key, secret_key, testnet, + COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr, + COALESCE(aster_user, '') as aster_user, + COALESCE(aster_signer, '') as aster_signer, + COALESCE(aster_private_key, '') as aster_private_key, + COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr, + COALESCE(lighter_private_key, '') as lighter_private_key, + COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key, + created_at, updated_at + FROM exchanges WHERE user_id = ? ORDER BY id + `, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + exchanges := make([]*Exchange, 0) + for rows.Next() { + var e Exchange + var createdAt, updatedAt string + err := rows.Scan( + &e.ID, &e.UserID, &e.Name, &e.Type, + &e.Enabled, &e.APIKey, &e.SecretKey, &e.Testnet, + &e.HyperliquidWalletAddr, &e.AsterUser, &e.AsterSigner, &e.AsterPrivateKey, + &e.LighterWalletAddr, &e.LighterPrivateKey, &e.LighterAPIKeyPrivateKey, + &createdAt, &updatedAt, + ) + if err != nil { + return nil, err + } + e.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + e.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) + e.APIKey = s.decrypt(e.APIKey) + e.SecretKey = s.decrypt(e.SecretKey) + e.AsterPrivateKey = s.decrypt(e.AsterPrivateKey) + e.LighterPrivateKey = s.decrypt(e.LighterPrivateKey) + e.LighterAPIKeyPrivateKey = s.decrypt(e.LighterAPIKeyPrivateKey) + exchanges = append(exchanges, &e) + } + return exchanges, nil +} + +// Update 更新交易所配置 +func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKey string, testnet bool, + hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, lighterWalletAddr, lighterPrivateKey string) error { + + logger.Debugf("🔧 ExchangeStore.Update: userID=%s, id=%s, enabled=%v", userID, id, enabled) + + setClauses := []string{ + "enabled = ?", + "testnet = ?", + "hyperliquid_wallet_addr = ?", + "aster_user = ?", + "aster_signer = ?", + "lighter_wallet_addr = ?", + "updated_at = datetime('now')", + } + args := []interface{}{enabled, testnet, hyperliquidWalletAddr, asterUser, asterSigner, lighterWalletAddr} + + if apiKey != "" { + setClauses = append(setClauses, "api_key = ?") + args = append(args, s.encrypt(apiKey)) + } + if secretKey != "" { + setClauses = append(setClauses, "secret_key = ?") + args = append(args, s.encrypt(secretKey)) + } + if asterPrivateKey != "" { + setClauses = append(setClauses, "aster_private_key = ?") + args = append(args, s.encrypt(asterPrivateKey)) + } + if lighterPrivateKey != "" { + setClauses = append(setClauses, "lighter_private_key = ?") + args = append(args, s.encrypt(lighterPrivateKey)) + } + + args = append(args, id, userID) + query := fmt.Sprintf(`UPDATE exchanges SET %s WHERE id = ? AND user_id = ?`, strings.Join(setClauses, ", ")) + + result, err := s.db.Exec(query, args...) + if err != nil { + return err + } + + rowsAffected, _ := result.RowsAffected() + if rowsAffected == 0 { + // 创建新记录 + var name, typ string + switch id { + case "binance": + name, typ = "Binance Futures", "cex" + case "bybit": + name, typ = "Bybit Futures", "cex" + case "hyperliquid": + name, typ = "Hyperliquid", "dex" + case "aster": + name, typ = "Aster DEX", "dex" + case "lighter": + name, typ = "LIGHTER DEX", "dex" + default: + name, typ = id+" Exchange", "cex" + } + + _, err = s.db.Exec(` + INSERT INTO exchanges (id, user_id, name, type, enabled, api_key, secret_key, testnet, + hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key, + lighter_wallet_addr, lighter_private_key, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now')) + `, id, userID, name, typ, enabled, s.encrypt(apiKey), s.encrypt(secretKey), testnet, + hyperliquidWalletAddr, asterUser, asterSigner, s.encrypt(asterPrivateKey), + lighterWalletAddr, s.encrypt(lighterPrivateKey)) + return err + } + return nil +} + +// Create 创建交易所配置 +func (s *ExchangeStore) Create(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool, + hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error { + _, err := s.db.Exec(` + INSERT OR IGNORE INTO exchanges (id, user_id, name, type, enabled, api_key, secret_key, testnet, + hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key, + lighter_wallet_addr, lighter_private_key) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, '', '') + `, id, userID, name, typ, enabled, s.encrypt(apiKey), s.encrypt(secretKey), testnet, + hyperliquidWalletAddr, asterUser, asterSigner, s.encrypt(asterPrivateKey)) + return err +} diff --git a/store/order.go b/store/order.go new file mode 100644 index 00000000..68649b47 --- /dev/null +++ b/store/order.go @@ -0,0 +1,511 @@ +package store + +import ( + "database/sql" + "fmt" + "math" + "time" +) + +// TraderOrder 交易员订单记录 +type TraderOrder struct { + ID int64 `json:"id"` + TraderID string `json:"trader_id"` // 交易员ID + OrderID string `json:"order_id"` // 交易所订单ID + ClientOrderID string `json:"client_order_id"` // 客户端订单ID + Symbol string `json:"symbol"` // 交易对 + Side string `json:"side"` // BUY/SELL + PositionSide string `json:"position_side"` // LONG/SHORT/BOTH + Action string `json:"action"` // open_long/close_long/open_short/close_short + OrderType string `json:"order_type"` // MARKET/LIMIT + Quantity float64 `json:"quantity"` // 订单数量 + Price float64 `json:"price"` // 订单价格 + AvgPrice float64 `json:"avg_price"` // 实际成交均价 + ExecutedQty float64 `json:"executed_qty"` // 已成交数量 + Leverage int `json:"leverage"` // 杠杆倍数 + Status string `json:"status"` // NEW/FILLED/CANCELED/EXPIRED + Fee float64 `json:"fee"` // 手续费 + FeeAsset string `json:"fee_asset"` // 手续费资产 + RealizedPnL float64 `json:"realized_pnl"` // 已实现盈亏(平仓时) + EntryPrice float64 `json:"entry_price"` // 开仓价(平仓时记录) + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + FilledAt time.Time `json:"filled_at"` // 成交时间 +} + +// TraderStats 交易统计指标 +type TraderStats struct { + TotalTrades int `json:"total_trades"` // 总交易数(已平仓) + WinTrades int `json:"win_trades"` // 盈利交易数 + LossTrades int `json:"loss_trades"` // 亏损交易数 + WinRate float64 `json:"win_rate"` // 胜率 (%) + ProfitFactor float64 `json:"profit_factor"` // 盈亏比 + SharpeRatio float64 `json:"sharpe_ratio"` // 夏普比 + TotalPnL float64 `json:"total_pnl"` // 总盈亏 + TotalFee float64 `json:"total_fee"` // 总手续费 + AvgWin float64 `json:"avg_win"` // 平均盈利 + AvgLoss float64 `json:"avg_loss"` // 平均亏损 + MaxDrawdownPct float64 `json:"max_drawdown_pct"` // 最大回撤 (%) +} + +// CompletedOrder 已完成订单(用于AI输入) +type CompletedOrder struct { + Symbol string `json:"symbol"` // 交易对 + Action string `json:"action"` // close_long/close_short + Side string `json:"side"` // long/short + Quantity float64 `json:"quantity"` // 数量 + EntryPrice float64 `json:"entry_price"` // 开仓价 + ExitPrice float64 `json:"exit_price"` // 平仓价 + RealizedPnL float64 `json:"realized_pnl"` // 已实现盈亏 + PnLPct float64 `json:"pnl_pct"` // 盈亏百分比 + Fee float64 `json:"fee"` // 手续费 + Leverage int `json:"leverage"` // 杠杆 + FilledAt time.Time `json:"filled_at"` // 成交时间 +} + +// OrderStore 订单存储 +type OrderStore struct { + db *sql.DB +} + +// NewOrderStore 创建订单存储实例 +func NewOrderStore(db *sql.DB) *OrderStore { + return &OrderStore{db: db} +} + +// InitTables 初始化订单表 +func (s *OrderStore) InitTables() error { + _, err := s.db.Exec(` + CREATE TABLE IF NOT EXISTS trader_orders ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + trader_id TEXT NOT NULL, + order_id TEXT NOT NULL, + client_order_id TEXT DEFAULT '', + symbol TEXT NOT NULL, + side TEXT NOT NULL, + position_side TEXT DEFAULT '', + action TEXT NOT NULL, + order_type TEXT DEFAULT 'MARKET', + quantity REAL NOT NULL, + price REAL DEFAULT 0, + avg_price REAL DEFAULT 0, + executed_qty REAL DEFAULT 0, + leverage INTEGER DEFAULT 1, + status TEXT DEFAULT 'NEW', + fee REAL DEFAULT 0, + fee_asset TEXT DEFAULT 'USDT', + realized_pnl REAL DEFAULT 0, + entry_price REAL DEFAULT 0, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + filled_at DATETIME, + UNIQUE(trader_id, order_id) + ) + `) + if err != nil { + return fmt.Errorf("创建trader_orders表失败: %w", err) + } + + // 创建索引 + indices := []string{ + `CREATE INDEX IF NOT EXISTS idx_trader_orders_trader ON trader_orders(trader_id)`, + `CREATE INDEX IF NOT EXISTS idx_trader_orders_status ON trader_orders(trader_id, status)`, + `CREATE INDEX IF NOT EXISTS idx_trader_orders_symbol ON trader_orders(trader_id, symbol)`, + `CREATE INDEX IF NOT EXISTS idx_trader_orders_filled ON trader_orders(trader_id, filled_at DESC)`, + } + for _, idx := range indices { + if _, err := s.db.Exec(idx); err != nil { + return fmt.Errorf("创建索引失败: %w", err) + } + } + + return nil +} + +// Create 创建订单记录 +func (s *OrderStore) Create(order *TraderOrder) error { + now := time.Now().Format(time.RFC3339) + result, err := s.db.Exec(` + INSERT INTO trader_orders ( + trader_id, order_id, client_order_id, symbol, side, position_side, + action, order_type, quantity, price, avg_price, executed_qty, + leverage, status, fee, fee_asset, realized_pnl, entry_price, + created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + `, + order.TraderID, order.OrderID, order.ClientOrderID, order.Symbol, + order.Side, order.PositionSide, order.Action, order.OrderType, + order.Quantity, order.Price, order.AvgPrice, order.ExecutedQty, + order.Leverage, order.Status, order.Fee, order.FeeAsset, + order.RealizedPnL, order.EntryPrice, now, now, + ) + if err != nil { + return fmt.Errorf("创建订单记录失败: %w", err) + } + + id, _ := result.LastInsertId() + order.ID = id + return nil +} + +// Update 更新订单记录 +func (s *OrderStore) Update(order *TraderOrder) error { + now := time.Now().Format(time.RFC3339) + filledAt := "" + if !order.FilledAt.IsZero() { + filledAt = order.FilledAt.Format(time.RFC3339) + } + + _, err := s.db.Exec(` + UPDATE trader_orders SET + avg_price = ?, executed_qty = ?, status = ?, fee = ?, + realized_pnl = ?, entry_price = ?, updated_at = ?, filled_at = ? + WHERE trader_id = ? AND order_id = ? + `, + order.AvgPrice, order.ExecutedQty, order.Status, order.Fee, + order.RealizedPnL, order.EntryPrice, now, filledAt, + order.TraderID, order.OrderID, + ) + if err != nil { + return fmt.Errorf("更新订单记录失败: %w", err) + } + return nil +} + +// GetByOrderID 根据订单ID获取订单 +func (s *OrderStore) GetByOrderID(traderID, orderID string) (*TraderOrder, error) { + var order TraderOrder + var createdAt, updatedAt, filledAt sql.NullString + + err := s.db.QueryRow(` + SELECT id, trader_id, order_id, client_order_id, symbol, side, position_side, + action, order_type, quantity, price, avg_price, executed_qty, + leverage, status, fee, fee_asset, realized_pnl, entry_price, + created_at, updated_at, filled_at + FROM trader_orders WHERE trader_id = ? AND order_id = ? + `, traderID, orderID).Scan( + &order.ID, &order.TraderID, &order.OrderID, &order.ClientOrderID, + &order.Symbol, &order.Side, &order.PositionSide, &order.Action, + &order.OrderType, &order.Quantity, &order.Price, &order.AvgPrice, + &order.ExecutedQty, &order.Leverage, &order.Status, &order.Fee, + &order.FeeAsset, &order.RealizedPnL, &order.EntryPrice, + &createdAt, &updatedAt, &filledAt, + ) + if err != nil { + return nil, err + } + + if createdAt.Valid { + order.CreatedAt, _ = time.Parse(time.RFC3339, createdAt.String) + } + if updatedAt.Valid { + order.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String) + } + if filledAt.Valid { + order.FilledAt, _ = time.Parse(time.RFC3339, filledAt.String) + } + + return &order, nil +} + +// GetLatestOpenOrder 获取某币种最近的开仓订单(用于计算平仓盈亏) +func (s *OrderStore) GetLatestOpenOrder(traderID, symbol, side string) (*TraderOrder, error) { + // side: long -> 找 open_long, short -> 找 open_short + action := "open_long" + if side == "short" { + action = "open_short" + } + + var order TraderOrder + var createdAt, updatedAt, filledAt sql.NullString + + err := s.db.QueryRow(` + SELECT id, trader_id, order_id, client_order_id, symbol, side, position_side, + action, order_type, quantity, price, avg_price, executed_qty, + leverage, status, fee, fee_asset, realized_pnl, entry_price, + created_at, updated_at, filled_at + FROM trader_orders + WHERE trader_id = ? AND symbol = ? AND action = ? AND status = 'FILLED' + ORDER BY filled_at DESC LIMIT 1 + `, traderID, symbol, action).Scan( + &order.ID, &order.TraderID, &order.OrderID, &order.ClientOrderID, + &order.Symbol, &order.Side, &order.PositionSide, &order.Action, + &order.OrderType, &order.Quantity, &order.Price, &order.AvgPrice, + &order.ExecutedQty, &order.Leverage, &order.Status, &order.Fee, + &order.FeeAsset, &order.RealizedPnL, &order.EntryPrice, + &createdAt, &updatedAt, &filledAt, + ) + if err != nil { + return nil, err + } + + if createdAt.Valid { + order.CreatedAt, _ = time.Parse(time.RFC3339, createdAt.String) + } + if updatedAt.Valid { + order.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String) + } + if filledAt.Valid { + order.FilledAt, _ = time.Parse(time.RFC3339, filledAt.String) + } + + return &order, nil +} + +// GetRecentCompletedOrders 获取最近已完成的平仓订单 +func (s *OrderStore) GetRecentCompletedOrders(traderID string, limit int) ([]CompletedOrder, error) { + rows, err := s.db.Query(` + SELECT symbol, action, side, executed_qty, entry_price, avg_price, + realized_pnl, fee, leverage, filled_at + FROM trader_orders + WHERE trader_id = ? AND status = 'FILLED' + AND (action = 'close_long' OR action = 'close_short') + ORDER BY filled_at DESC + LIMIT ? + `, traderID, limit) + if err != nil { + return nil, fmt.Errorf("查询已完成订单失败: %w", err) + } + defer rows.Close() + + var orders []CompletedOrder + for rows.Next() { + var o CompletedOrder + var filledAt sql.NullString + var side sql.NullString + + err := rows.Scan( + &o.Symbol, &o.Action, &side, &o.Quantity, &o.EntryPrice, &o.ExitPrice, + &o.RealizedPnL, &o.Fee, &o.Leverage, &filledAt, + ) + if err != nil { + continue + } + + // 根据action推断side + if o.Action == "close_long" { + o.Side = "long" + } else if o.Action == "close_short" { + o.Side = "short" + } else if side.Valid { + o.Side = side.String + } + + // 计算盈亏百分比 + if o.EntryPrice > 0 { + if o.Side == "long" { + o.PnLPct = (o.ExitPrice - o.EntryPrice) / o.EntryPrice * 100 * float64(o.Leverage) + } else { + o.PnLPct = (o.EntryPrice - o.ExitPrice) / o.EntryPrice * 100 * float64(o.Leverage) + } + } + + if filledAt.Valid { + o.FilledAt, _ = time.Parse(time.RFC3339, filledAt.String) + } + + orders = append(orders, o) + } + + return orders, nil +} + +// GetTraderStats 获取交易统计指标 +func (s *OrderStore) GetTraderStats(traderID string) (*TraderStats, error) { + stats := &TraderStats{} + + // 查询所有已完成的平仓订单 + rows, err := s.db.Query(` + SELECT realized_pnl, fee, filled_at + FROM trader_orders + WHERE trader_id = ? AND status = 'FILLED' + AND (action = 'close_long' OR action = 'close_short') + ORDER BY filled_at ASC + `, traderID) + if err != nil { + return nil, fmt.Errorf("查询订单统计失败: %w", err) + } + defer rows.Close() + + var pnls []float64 + var totalWin, totalLoss float64 + + for rows.Next() { + var pnl, fee float64 + var filledAt sql.NullString + if err := rows.Scan(&pnl, &fee, &filledAt); err != nil { + continue + } + + stats.TotalTrades++ + stats.TotalPnL += pnl + stats.TotalFee += fee + pnls = append(pnls, pnl) + + if pnl > 0 { + stats.WinTrades++ + totalWin += pnl + } else if pnl < 0 { + stats.LossTrades++ + totalLoss += math.Abs(pnl) + } + } + + // 计算胜率 + if stats.TotalTrades > 0 { + stats.WinRate = float64(stats.WinTrades) / float64(stats.TotalTrades) * 100 + } + + // 计算盈亏比 + if totalLoss > 0 { + stats.ProfitFactor = totalWin / totalLoss + } + + // 计算平均盈亏 + if stats.WinTrades > 0 { + stats.AvgWin = totalWin / float64(stats.WinTrades) + } + if stats.LossTrades > 0 { + stats.AvgLoss = totalLoss / float64(stats.LossTrades) + } + + // 计算夏普比(使用盈亏序列) + if len(pnls) > 1 { + stats.SharpeRatio = calculateSharpeRatio(pnls) + } + + // 计算最大回撤 + if len(pnls) > 0 { + stats.MaxDrawdownPct = calculateMaxDrawdown(pnls) + } + + return stats, nil +} + +// calculateSharpeRatio 计算夏普比 +func calculateSharpeRatio(pnls []float64) float64 { + if len(pnls) < 2 { + return 0 + } + + // 计算平均收益 + var sum float64 + for _, pnl := range pnls { + sum += pnl + } + mean := sum / float64(len(pnls)) + + // 计算标准差 + var variance float64 + for _, pnl := range pnls { + variance += (pnl - mean) * (pnl - mean) + } + stdDev := math.Sqrt(variance / float64(len(pnls)-1)) + + if stdDev == 0 { + return 0 + } + + // 夏普比 = 平均收益 / 标准差 + return mean / stdDev +} + +// calculateMaxDrawdown 计算最大回撤 +func calculateMaxDrawdown(pnls []float64) float64 { + if len(pnls) == 0 { + return 0 + } + + // 计算累计权益曲线 + var cumulative float64 + var peak float64 + var maxDD float64 + + for _, pnl := range pnls { + cumulative += pnl + if cumulative > peak { + peak = cumulative + } + if peak > 0 { + dd := (peak - cumulative) / peak * 100 + if dd > maxDD { + maxDD = dd + } + } + } + + return maxDD +} + +// GetPendingOrders 获取未成交的订单(用于轮询) +func (s *OrderStore) GetPendingOrders(traderID string) ([]*TraderOrder, error) { + rows, err := s.db.Query(` + SELECT id, trader_id, order_id, client_order_id, symbol, side, position_side, + action, order_type, quantity, price, avg_price, executed_qty, + leverage, status, fee, fee_asset, realized_pnl, entry_price, + created_at, updated_at, filled_at + FROM trader_orders + WHERE trader_id = ? AND status = 'NEW' + ORDER BY created_at ASC + `, traderID) + if err != nil { + return nil, fmt.Errorf("查询未成交订单失败: %w", err) + } + defer rows.Close() + + return s.scanOrders(rows) +} + +// GetAllPendingOrders 获取所有未成交的订单(用于全局同步) +func (s *OrderStore) GetAllPendingOrders() ([]*TraderOrder, error) { + rows, err := s.db.Query(` + SELECT id, trader_id, order_id, client_order_id, symbol, side, position_side, + action, order_type, quantity, price, avg_price, executed_qty, + leverage, status, fee, fee_asset, realized_pnl, entry_price, + created_at, updated_at, filled_at + FROM trader_orders + WHERE status = 'NEW' + ORDER BY trader_id, created_at ASC + `) + if err != nil { + return nil, fmt.Errorf("查询未成交订单失败: %w", err) + } + defer rows.Close() + + return s.scanOrders(rows) +} + +// scanOrders 扫描订单行到结构体 +func (s *OrderStore) scanOrders(rows *sql.Rows) ([]*TraderOrder, error) { + var orders []*TraderOrder + for rows.Next() { + var order TraderOrder + var createdAt, updatedAt, filledAt sql.NullString + + err := rows.Scan( + &order.ID, &order.TraderID, &order.OrderID, &order.ClientOrderID, + &order.Symbol, &order.Side, &order.PositionSide, &order.Action, + &order.OrderType, &order.Quantity, &order.Price, &order.AvgPrice, + &order.ExecutedQty, &order.Leverage, &order.Status, &order.Fee, + &order.FeeAsset, &order.RealizedPnL, &order.EntryPrice, + &createdAt, &updatedAt, &filledAt, + ) + if err != nil { + continue + } + + if createdAt.Valid { + order.CreatedAt, _ = time.Parse(time.RFC3339, createdAt.String) + } + if updatedAt.Valid { + order.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String) + } + if filledAt.Valid { + order.FilledAt, _ = time.Parse(time.RFC3339, filledAt.String) + } + + orders = append(orders, &order) + } + + return orders, nil +} diff --git a/store/position.go b/store/position.go new file mode 100644 index 00000000..4d7310e7 --- /dev/null +++ b/store/position.go @@ -0,0 +1,473 @@ +package store + +import ( + "database/sql" + "fmt" + "math" + "time" +) + +// TraderPosition 仓位记录(完整的开平仓追踪) +type TraderPosition struct { + ID int64 `json:"id"` + TraderID string `json:"trader_id"` + Symbol string `json:"symbol"` + Side string `json:"side"` // LONG/SHORT + Quantity float64 `json:"quantity"` // 开仓数量 + EntryPrice float64 `json:"entry_price"` // 开仓均价 + EntryOrderID string `json:"entry_order_id"` // 开仓订单ID + EntryTime time.Time `json:"entry_time"` // 开仓时间 + ExitPrice float64 `json:"exit_price"` // 平仓均价 + ExitOrderID string `json:"exit_order_id"` // 平仓订单ID + ExitTime *time.Time `json:"exit_time"` // 平仓时间 + RealizedPnL float64 `json:"realized_pnl"` // 已实现盈亏 + Fee float64 `json:"fee"` // 手续费 + Leverage int `json:"leverage"` // 杠杆倍数 + Status string `json:"status"` // OPEN/CLOSED + CloseReason string `json:"close_reason"` // 平仓原因: ai_decision/manual/stop_loss/take_profit + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// PositionStore 仓位存储 +type PositionStore struct { + db *sql.DB +} + +// NewPositionStore 创建仓位存储实例 +func NewPositionStore(db *sql.DB) *PositionStore { + return &PositionStore{db: db} +} + +// InitTables 初始化仓位表 +func (s *PositionStore) InitTables() error { + _, err := s.db.Exec(` + CREATE TABLE IF NOT EXISTS trader_positions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + trader_id TEXT NOT NULL, + symbol TEXT NOT NULL, + side TEXT NOT NULL, + quantity REAL NOT NULL, + entry_price REAL NOT NULL, + entry_order_id TEXT DEFAULT '', + entry_time DATETIME NOT NULL, + exit_price REAL DEFAULT 0, + exit_order_id TEXT DEFAULT '', + exit_time DATETIME, + realized_pnl REAL DEFAULT 0, + fee REAL DEFAULT 0, + leverage INTEGER DEFAULT 1, + status TEXT DEFAULT 'OPEN', + close_reason TEXT DEFAULT '', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + `) + if err != nil { + return fmt.Errorf("创建trader_positions表失败: %w", err) + } + + // 创建索引 + indices := []string{ + `CREATE INDEX IF NOT EXISTS idx_positions_trader ON trader_positions(trader_id)`, + `CREATE INDEX IF NOT EXISTS idx_positions_status ON trader_positions(trader_id, status)`, + `CREATE INDEX IF NOT EXISTS idx_positions_symbol ON trader_positions(trader_id, symbol, side, status)`, + `CREATE INDEX IF NOT EXISTS idx_positions_entry ON trader_positions(trader_id, entry_time DESC)`, + `CREATE INDEX IF NOT EXISTS idx_positions_exit ON trader_positions(trader_id, exit_time DESC)`, + } + for _, idx := range indices { + if _, err := s.db.Exec(idx); err != nil { + return fmt.Errorf("创建索引失败: %w", err) + } + } + + return nil +} + +// Create 创建仓位记录(开仓时调用) +func (s *PositionStore) Create(pos *TraderPosition) error { + now := time.Now() + pos.CreatedAt = now + pos.UpdatedAt = now + pos.Status = "OPEN" + + result, err := s.db.Exec(` + INSERT INTO trader_positions ( + trader_id, symbol, side, quantity, entry_price, entry_order_id, + entry_time, leverage, status, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + `, + pos.TraderID, pos.Symbol, pos.Side, pos.Quantity, pos.EntryPrice, + pos.EntryOrderID, pos.EntryTime.Format(time.RFC3339), pos.Leverage, + pos.Status, now.Format(time.RFC3339), now.Format(time.RFC3339), + ) + if err != nil { + return fmt.Errorf("创建仓位记录失败: %w", err) + } + + id, _ := result.LastInsertId() + pos.ID = id + return nil +} + +// ClosePosition 平仓(更新仓位记录) +func (s *PositionStore) ClosePosition(id int64, exitPrice float64, exitOrderID string, realizedPnL float64, fee float64, closeReason string) error { + now := time.Now() + _, err := s.db.Exec(` + UPDATE trader_positions SET + exit_price = ?, exit_order_id = ?, exit_time = ?, + realized_pnl = ?, fee = ?, status = 'CLOSED', + close_reason = ?, updated_at = ? + WHERE id = ? + `, + exitPrice, exitOrderID, now.Format(time.RFC3339), + realizedPnL, fee, closeReason, now.Format(time.RFC3339), id, + ) + if err != nil { + return fmt.Errorf("更新仓位记录失败: %w", err) + } + return nil +} + +// GetOpenPositions 获取所有未平仓位 +func (s *PositionStore) GetOpenPositions(traderID string) ([]*TraderPosition, error) { + rows, err := s.db.Query(` + SELECT id, trader_id, symbol, side, quantity, entry_price, entry_order_id, + entry_time, exit_price, exit_order_id, exit_time, realized_pnl, fee, + leverage, status, close_reason, created_at, updated_at + FROM trader_positions + WHERE trader_id = ? AND status = 'OPEN' + ORDER BY entry_time DESC + `, traderID) + if err != nil { + return nil, fmt.Errorf("查询未平仓位失败: %w", err) + } + defer rows.Close() + + return s.scanPositions(rows) +} + +// GetOpenPositionBySymbol 获取指定币种方向的未平仓位 +func (s *PositionStore) GetOpenPositionBySymbol(traderID, symbol, side string) (*TraderPosition, error) { + var pos TraderPosition + var entryTime, exitTime, createdAt, updatedAt sql.NullString + + err := s.db.QueryRow(` + SELECT id, trader_id, symbol, side, quantity, entry_price, entry_order_id, + entry_time, exit_price, exit_order_id, exit_time, realized_pnl, fee, + leverage, status, close_reason, created_at, updated_at + FROM trader_positions + WHERE trader_id = ? AND symbol = ? AND side = ? AND status = 'OPEN' + ORDER BY entry_time DESC LIMIT 1 + `, traderID, symbol, side).Scan( + &pos.ID, &pos.TraderID, &pos.Symbol, &pos.Side, &pos.Quantity, + &pos.EntryPrice, &pos.EntryOrderID, &entryTime, &pos.ExitPrice, + &pos.ExitOrderID, &exitTime, &pos.RealizedPnL, &pos.Fee, + &pos.Leverage, &pos.Status, &pos.CloseReason, &createdAt, &updatedAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + + s.parsePositionTimes(&pos, entryTime, exitTime, createdAt, updatedAt) + return &pos, nil +} + +// GetClosedPositions 获取已平仓位(历史记录) +func (s *PositionStore) GetClosedPositions(traderID string, limit int) ([]*TraderPosition, error) { + rows, err := s.db.Query(` + SELECT id, trader_id, symbol, side, quantity, entry_price, entry_order_id, + entry_time, exit_price, exit_order_id, exit_time, realized_pnl, fee, + leverage, status, close_reason, created_at, updated_at + FROM trader_positions + WHERE trader_id = ? AND status = 'CLOSED' + ORDER BY exit_time DESC + LIMIT ? + `, traderID, limit) + if err != nil { + return nil, fmt.Errorf("查询已平仓位失败: %w", err) + } + defer rows.Close() + + return s.scanPositions(rows) +} + +// GetAllOpenPositions 获取所有trader的未平仓位(用于全局同步) +func (s *PositionStore) GetAllOpenPositions() ([]*TraderPosition, error) { + rows, err := s.db.Query(` + SELECT id, trader_id, symbol, side, quantity, entry_price, entry_order_id, + entry_time, exit_price, exit_order_id, exit_time, realized_pnl, fee, + leverage, status, close_reason, created_at, updated_at + FROM trader_positions + WHERE status = 'OPEN' + ORDER BY trader_id, entry_time DESC + `) + if err != nil { + return nil, fmt.Errorf("查询所有未平仓位失败: %w", err) + } + defer rows.Close() + + return s.scanPositions(rows) +} + +// GetPositionStats 获取仓位统计(简单版) +func (s *PositionStore) GetPositionStats(traderID string) (map[string]interface{}, error) { + stats := make(map[string]interface{}) + + // 总交易数 + var totalTrades, winTrades int + var totalPnL, totalFee float64 + + err := s.db.QueryRow(` + SELECT + COUNT(*) as total, + SUM(CASE WHEN realized_pnl > 0 THEN 1 ELSE 0 END) as wins, + COALESCE(SUM(realized_pnl), 0) as total_pnl, + COALESCE(SUM(fee), 0) as total_fee + FROM trader_positions + WHERE trader_id = ? AND status = 'CLOSED' + `, traderID).Scan(&totalTrades, &winTrades, &totalPnL, &totalFee) + if err != nil { + return nil, err + } + + stats["total_trades"] = totalTrades + stats["win_trades"] = winTrades + stats["total_pnl"] = totalPnL + stats["total_fee"] = totalFee + if totalTrades > 0 { + stats["win_rate"] = float64(winTrades) / float64(totalTrades) * 100 + } else { + stats["win_rate"] = 0.0 + } + + return stats, nil +} + +// GetFullStats 获取完整的交易统计(与 TraderStats 兼容) +func (s *PositionStore) GetFullStats(traderID string) (*TraderStats, error) { + stats := &TraderStats{} + + // 查询所有已平仓位 + rows, err := s.db.Query(` + SELECT realized_pnl, fee, exit_time + FROM trader_positions + WHERE trader_id = ? AND status = 'CLOSED' + ORDER BY exit_time ASC + `, traderID) + if err != nil { + return nil, fmt.Errorf("查询仓位统计失败: %w", err) + } + defer rows.Close() + + var pnls []float64 + var totalWin, totalLoss float64 + + for rows.Next() { + var pnl, fee float64 + var exitTime sql.NullString + if err := rows.Scan(&pnl, &fee, &exitTime); err != nil { + continue + } + + stats.TotalTrades++ + stats.TotalPnL += pnl + stats.TotalFee += fee + pnls = append(pnls, pnl) + + if pnl > 0 { + stats.WinTrades++ + totalWin += pnl + } else if pnl < 0 { + stats.LossTrades++ + totalLoss += -pnl // 转为正数 + } + } + + // 计算胜率 + if stats.TotalTrades > 0 { + stats.WinRate = float64(stats.WinTrades) / float64(stats.TotalTrades) * 100 + } + + // 计算盈亏比 + if totalLoss > 0 { + stats.ProfitFactor = totalWin / totalLoss + } + + // 计算平均盈亏 + if stats.WinTrades > 0 { + stats.AvgWin = totalWin / float64(stats.WinTrades) + } + if stats.LossTrades > 0 { + stats.AvgLoss = totalLoss / float64(stats.LossTrades) + } + + // 计算夏普比 + if len(pnls) > 1 { + stats.SharpeRatio = calculateSharpeRatioFromPnls(pnls) + } + + // 计算最大回撤 + if len(pnls) > 0 { + stats.MaxDrawdownPct = calculateMaxDrawdownFromPnls(pnls) + } + + return stats, nil +} + +// RecentTrade 最近的交易记录(用于AI输入) +type RecentTrade struct { + Symbol string `json:"symbol"` + Side string `json:"side"` // long/short + EntryPrice float64 `json:"entry_price"` + ExitPrice float64 `json:"exit_price"` + RealizedPnL float64 `json:"realized_pnl"` + PnLPct float64 `json:"pnl_pct"` + ExitTime string `json:"exit_time"` +} + +// GetRecentTrades 获取最近的已平仓交易 +func (s *PositionStore) GetRecentTrades(traderID string, limit int) ([]RecentTrade, error) { + rows, err := s.db.Query(` + SELECT symbol, side, entry_price, exit_price, realized_pnl, leverage, exit_time + FROM trader_positions + WHERE trader_id = ? AND status = 'CLOSED' + ORDER BY exit_time DESC + LIMIT ? + `, traderID, limit) + if err != nil { + return nil, fmt.Errorf("查询最近交易失败: %w", err) + } + defer rows.Close() + + var trades []RecentTrade + for rows.Next() { + var t RecentTrade + var leverage int + var exitTime sql.NullString + + err := rows.Scan(&t.Symbol, &t.Side, &t.EntryPrice, &t.ExitPrice, &t.RealizedPnL, &leverage, &exitTime) + if err != nil { + continue + } + + // 转换 side 格式 + if t.Side == "LONG" { + t.Side = "long" + } else if t.Side == "SHORT" { + t.Side = "short" + } + + // 计算盈亏百分比 + if t.EntryPrice > 0 { + if t.Side == "long" { + t.PnLPct = (t.ExitPrice - t.EntryPrice) / t.EntryPrice * 100 * float64(leverage) + } else { + t.PnLPct = (t.EntryPrice - t.ExitPrice) / t.EntryPrice * 100 * float64(leverage) + } + } + + // 格式化时间 + if exitTime.Valid { + if parsed, err := time.Parse(time.RFC3339, exitTime.String); err == nil { + t.ExitTime = parsed.Format("01-02 15:04") + } + } + + trades = append(trades, t) + } + + return trades, nil +} + +// calculateSharpeRatioFromPnls 计算夏普比 +func calculateSharpeRatioFromPnls(pnls []float64) float64 { + if len(pnls) < 2 { + return 0 + } + + var sum float64 + for _, pnl := range pnls { + sum += pnl + } + mean := sum / float64(len(pnls)) + + var variance float64 + for _, pnl := range pnls { + variance += (pnl - mean) * (pnl - mean) + } + stdDev := math.Sqrt(variance / float64(len(pnls)-1)) + + if stdDev == 0 { + return 0 + } + + return mean / stdDev +} + +// calculateMaxDrawdownFromPnls 计算最大回撤 +func calculateMaxDrawdownFromPnls(pnls []float64) float64 { + if len(pnls) == 0 { + return 0 + } + + var cumulative, peak, maxDD float64 + for _, pnl := range pnls { + cumulative += pnl + if cumulative > peak { + peak = cumulative + } + if peak > 0 { + dd := (peak - cumulative) / peak * 100 + if dd > maxDD { + maxDD = dd + } + } + } + + return maxDD +} + +// scanPositions 扫描仓位行到结构体 +func (s *PositionStore) scanPositions(rows *sql.Rows) ([]*TraderPosition, error) { + var positions []*TraderPosition + for rows.Next() { + var pos TraderPosition + var entryTime, exitTime, createdAt, updatedAt sql.NullString + + err := rows.Scan( + &pos.ID, &pos.TraderID, &pos.Symbol, &pos.Side, &pos.Quantity, + &pos.EntryPrice, &pos.EntryOrderID, &entryTime, &pos.ExitPrice, + &pos.ExitOrderID, &exitTime, &pos.RealizedPnL, &pos.Fee, + &pos.Leverage, &pos.Status, &pos.CloseReason, &createdAt, &updatedAt, + ) + if err != nil { + continue + } + + s.parsePositionTimes(&pos, entryTime, exitTime, createdAt, updatedAt) + positions = append(positions, &pos) + } + + return positions, nil +} + +// parsePositionTimes 解析时间字段 +func (s *PositionStore) parsePositionTimes(pos *TraderPosition, entryTime, exitTime, createdAt, updatedAt sql.NullString) { + if entryTime.Valid { + pos.EntryTime, _ = time.Parse(time.RFC3339, entryTime.String) + } + if exitTime.Valid { + t, _ := time.Parse(time.RFC3339, exitTime.String) + pos.ExitTime = &t + } + if createdAt.Valid { + pos.CreatedAt, _ = time.Parse(time.RFC3339, createdAt.String) + } + if updatedAt.Valid { + pos.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String) + } +} diff --git a/store/signal_source.go b/store/signal_source.go new file mode 100644 index 00000000..6f0cc0e5 --- /dev/null +++ b/store/signal_source.go @@ -0,0 +1,86 @@ +package store + +import ( + "database/sql" + "time" +) + +// SignalSourceStore 信号源存储 +type SignalSourceStore struct { + db *sql.DB +} + +// SignalSource 用户信号源配置 +type SignalSource struct { + ID int `json:"id"` + UserID string `json:"user_id"` + CoinPoolURL string `json:"coin_pool_url"` + OITopURL string `json:"oi_top_url"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (s *SignalSourceStore) initTables() error { + _, err := s.db.Exec(` + CREATE TABLE IF NOT EXISTS user_signal_sources ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, + coin_pool_url TEXT DEFAULT '', + oi_top_url TEXT DEFAULT '', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + UNIQUE(user_id) + ) + `) + if err != nil { + return err + } + + // 触发器 + _, err = s.db.Exec(` + CREATE TRIGGER IF NOT EXISTS update_user_signal_sources_updated_at + AFTER UPDATE ON user_signal_sources + BEGIN + UPDATE user_signal_sources SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; + END + `) + return err +} + +// Create 创建信号源配置 +func (s *SignalSourceStore) Create(userID, coinPoolURL, oiTopURL string) error { + _, err := s.db.Exec(` + INSERT OR REPLACE INTO user_signal_sources (user_id, coin_pool_url, oi_top_url, updated_at) + VALUES (?, ?, ?, CURRENT_TIMESTAMP) + `, userID, coinPoolURL, oiTopURL) + return err +} + +// Get 获取信号源配置 +func (s *SignalSourceStore) Get(userID string) (*SignalSource, error) { + var source SignalSource + var createdAt, updatedAt string + err := s.db.QueryRow(` + SELECT id, user_id, coin_pool_url, oi_top_url, created_at, updated_at + FROM user_signal_sources WHERE user_id = ? + `, userID).Scan( + &source.ID, &source.UserID, &source.CoinPoolURL, &source.OITopURL, + &createdAt, &updatedAt, + ) + if err != nil { + return nil, err + } + source.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + source.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) + return &source, nil +} + +// Update 更新信号源配置 +func (s *SignalSourceStore) Update(userID, coinPoolURL, oiTopURL string) error { + _, err := s.db.Exec(` + UPDATE user_signal_sources SET coin_pool_url = ?, oi_top_url = ?, updated_at = CURRENT_TIMESTAMP + WHERE user_id = ? + `, coinPoolURL, oiTopURL, userID) + return err +} diff --git a/store/store.go b/store/store.go new file mode 100644 index 00000000..4c327ead --- /dev/null +++ b/store/store.go @@ -0,0 +1,319 @@ +// Package store 提供统一的数据库存储层 +// 所有数据库操作都应该通过这个包进行 +package store + +import ( + "database/sql" + "fmt" + "nofx/logger" + "sync" + + _ "modernc.org/sqlite" +) + +// Store 统一的数据存储接口 +type Store struct { + db *sql.DB + + // 子存储(延迟初始化) + user *UserStore + aiModel *AIModelStore + exchange *ExchangeStore + trader *TraderStore + systemConfig *SystemConfigStore + betaCode *BetaCodeStore + signalSource *SignalSourceStore + decision *DecisionStore + backtest *BacktestStore + order *OrderStore + position *PositionStore + + // 加密函数 + encryptFunc func(string) string + decryptFunc func(string) string + + mu sync.RWMutex +} + +// New 创建新的 Store 实例 +func New(dbPath string) (*Store, error) { + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, fmt.Errorf("打开数据库失败: %w", err) + } + + // SQLite 配置 + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + + // 启用外键约束 + if _, err := db.Exec(`PRAGMA foreign_keys = ON`); err != nil { + db.Close() + return nil, fmt.Errorf("启用外键失败: %w", err) + } + + // 使用 DELETE 模式(传统模式)以确保 Docker bind mount 兼容性 + // 注意:WAL 模式在 macOS Docker 下会导致数据同步问题 + if _, err := db.Exec("PRAGMA journal_mode=DELETE"); err != nil { + db.Close() + return nil, fmt.Errorf("设置journal_mode失败: %w", err) + } + + // 设置 synchronous=FULL + if _, err := db.Exec("PRAGMA synchronous=FULL"); err != nil { + db.Close() + return nil, fmt.Errorf("设置synchronous失败: %w", err) + } + + // 设置 busy_timeout + if _, err := db.Exec("PRAGMA busy_timeout = 5000"); err != nil { + db.Close() + return nil, fmt.Errorf("设置busy_timeout失败: %w", err) + } + + s := &Store{db: db} + + // 初始化所有表结构 + if err := s.initTables(); err != nil { + db.Close() + return nil, fmt.Errorf("初始化表结构失败: %w", err) + } + + // 初始化默认数据 + if err := s.initDefaultData(); err != nil { + db.Close() + return nil, fmt.Errorf("初始化默认数据失败: %w", err) + } + + logger.Info("✅ 数据库已启用 DELETE 模式和 FULL 同步") + return s, nil +} + +// NewFromDB 从现有数据库连接创建 Store +func NewFromDB(db *sql.DB) *Store { + return &Store{db: db} +} + +// SetCryptoFuncs 设置加密解密函数 +func (s *Store) SetCryptoFuncs(encrypt, decrypt func(string) string) { + s.mu.Lock() + defer s.mu.Unlock() + s.encryptFunc = encrypt + s.decryptFunc = decrypt + + // 更新已初始化的子存储 + if s.aiModel != nil { + s.aiModel.encryptFunc = encrypt + s.aiModel.decryptFunc = decrypt + } + if s.exchange != nil { + s.exchange.encryptFunc = encrypt + s.exchange.decryptFunc = decrypt + } + if s.trader != nil { + s.trader.decryptFunc = decrypt + } +} + +// initTables 初始化所有数据库表 +func (s *Store) initTables() error { + // 按依赖顺序初始化 + if err := s.User().initTables(); err != nil { + return fmt.Errorf("初始化用户表失败: %w", err) + } + if err := s.AIModel().initTables(); err != nil { + return fmt.Errorf("初始化AI模型表失败: %w", err) + } + if err := s.Exchange().initTables(); err != nil { + return fmt.Errorf("初始化交易所表失败: %w", err) + } + if err := s.Trader().initTables(); err != nil { + return fmt.Errorf("初始化交易员表失败: %w", err) + } + if err := s.SystemConfig().initTables(); err != nil { + return fmt.Errorf("初始化系统配置表失败: %w", err) + } + if err := s.BetaCode().initTables(); err != nil { + return fmt.Errorf("初始化内测码表失败: %w", err) + } + if err := s.SignalSource().initTables(); err != nil { + return fmt.Errorf("初始化信号源表失败: %w", err) + } + if err := s.Decision().initTables(); err != nil { + return fmt.Errorf("初始化决策日志表失败: %w", err) + } + if err := s.Backtest().initTables(); err != nil { + return fmt.Errorf("初始化回测表失败: %w", err) + } + if err := s.Order().InitTables(); err != nil { + return fmt.Errorf("初始化订单表失败: %w", err) + } + if err := s.Position().InitTables(); err != nil { + return fmt.Errorf("初始化仓位表失败: %w", err) + } + return nil +} + +// initDefaultData 初始化默认数据 +func (s *Store) initDefaultData() error { + if err := s.AIModel().initDefaultData(); err != nil { + return err + } + if err := s.Exchange().initDefaultData(); err != nil { + return err + } + if err := s.SystemConfig().initDefaultData(); err != nil { + return err + } + return nil +} + +// User 获取用户存储 +func (s *Store) User() *UserStore { + s.mu.Lock() + defer s.mu.Unlock() + if s.user == nil { + s.user = &UserStore{db: s.db} + } + return s.user +} + +// AIModel 获取AI模型存储 +func (s *Store) AIModel() *AIModelStore { + s.mu.Lock() + defer s.mu.Unlock() + if s.aiModel == nil { + s.aiModel = &AIModelStore{ + db: s.db, + encryptFunc: s.encryptFunc, + decryptFunc: s.decryptFunc, + } + } + return s.aiModel +} + +// Exchange 获取交易所存储 +func (s *Store) Exchange() *ExchangeStore { + s.mu.Lock() + defer s.mu.Unlock() + if s.exchange == nil { + s.exchange = &ExchangeStore{ + db: s.db, + encryptFunc: s.encryptFunc, + decryptFunc: s.decryptFunc, + } + } + return s.exchange +} + +// Trader 获取交易员存储 +func (s *Store) Trader() *TraderStore { + s.mu.Lock() + defer s.mu.Unlock() + if s.trader == nil { + s.trader = &TraderStore{ + db: s.db, + decryptFunc: s.decryptFunc, + } + } + return s.trader +} + +// SystemConfig 获取系统配置存储 +func (s *Store) SystemConfig() *SystemConfigStore { + s.mu.Lock() + defer s.mu.Unlock() + if s.systemConfig == nil { + s.systemConfig = &SystemConfigStore{db: s.db} + } + return s.systemConfig +} + +// BetaCode 获取内测码存储 +func (s *Store) BetaCode() *BetaCodeStore { + s.mu.Lock() + defer s.mu.Unlock() + if s.betaCode == nil { + s.betaCode = &BetaCodeStore{db: s.db} + } + return s.betaCode +} + +// SignalSource 获取信号源存储 +func (s *Store) SignalSource() *SignalSourceStore { + s.mu.Lock() + defer s.mu.Unlock() + if s.signalSource == nil { + s.signalSource = &SignalSourceStore{db: s.db} + } + return s.signalSource +} + +// Decision 获取决策日志存储 +func (s *Store) Decision() *DecisionStore { + s.mu.Lock() + defer s.mu.Unlock() + if s.decision == nil { + s.decision = &DecisionStore{db: s.db} + } + return s.decision +} + +// Backtest 获取回测数据存储 +func (s *Store) Backtest() *BacktestStore { + s.mu.Lock() + defer s.mu.Unlock() + if s.backtest == nil { + s.backtest = &BacktestStore{db: s.db} + } + return s.backtest +} + +// Order 获取订单存储 +func (s *Store) Order() *OrderStore { + s.mu.Lock() + defer s.mu.Unlock() + if s.order == nil { + s.order = NewOrderStore(s.db) + } + return s.order +} + +// Position 获取仓位存储 +func (s *Store) Position() *PositionStore { + s.mu.Lock() + defer s.mu.Unlock() + if s.position == nil { + s.position = NewPositionStore(s.db) + } + return s.position +} + +// Close 关闭数据库连接 +func (s *Store) Close() error { + return s.db.Close() +} + +// DB 获取底层数据库连接(仅用于兼容旧代码,逐步废弃) +// Deprecated: 使用 Store 的方法代替 +func (s *Store) DB() *sql.DB { + return s.db +} + +// Transaction 执行事务 +func (s *Store) Transaction(fn func(tx *sql.Tx) error) error { + tx, err := s.db.Begin() + if err != nil { + return fmt.Errorf("开始事务失败: %w", err) + } + + if err := fn(tx); err != nil { + tx.Rollback() + return err + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("提交事务失败: %w", err) + } + return nil +} diff --git a/store/system_config.go b/store/system_config.go new file mode 100644 index 00000000..45fd0401 --- /dev/null +++ b/store/system_config.go @@ -0,0 +1,70 @@ +package store + +import ( + "database/sql" +) + +// SystemConfigStore 系统配置存储 +type SystemConfigStore struct { + db *sql.DB +} + +func (s *SystemConfigStore) initTables() error { + _, err := s.db.Exec(` + CREATE TABLE IF NOT EXISTS system_config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + `) + if err != nil { + return err + } + + // 触发器 + _, err = s.db.Exec(` + CREATE TRIGGER IF NOT EXISTS update_system_config_updated_at + AFTER UPDATE ON system_config + BEGIN + UPDATE system_config SET updated_at = CURRENT_TIMESTAMP WHERE key = NEW.key; + END + `) + return err +} + +func (s *SystemConfigStore) initDefaultData() error { + configs := map[string]string{ + "beta_mode": "false", + "api_server_port": "8080", + "use_default_coins": "true", + "default_coins": `["BTCUSDT","ETHUSDT","SOLUSDT","BNBUSDT","XRPUSDT","DOGEUSDT","ADAUSDT","HYPEUSDT"]`, + "max_daily_loss": "10.0", + "max_drawdown": "20.0", + "stop_trading_minutes": "60", + "btc_eth_leverage": "5", + "altcoin_leverage": "5", + "jwt_secret": "", + "registration_enabled": "true", + } + + for key, value := range configs { + _, err := s.db.Exec(`INSERT OR IGNORE INTO system_config (key, value) VALUES (?, ?)`, key, value) + if err != nil { + return err + } + } + return nil +} + +// Get 获取配置值 +func (s *SystemConfigStore) Get(key string) (string, error) { + var value string + err := s.db.QueryRow(`SELECT value FROM system_config WHERE key = ?`, key).Scan(&value) + return value, err +} + +// Set 设置配置值 +func (s *SystemConfigStore) Set(key, value string) error { + _, err := s.db.Exec(`INSERT OR REPLACE INTO system_config (key, value) VALUES (?, ?)`, key, value) + return err +} diff --git a/store/trader.go b/store/trader.go new file mode 100644 index 00000000..e951640e --- /dev/null +++ b/store/trader.go @@ -0,0 +1,344 @@ +package store + +import ( + "database/sql" + "encoding/json" + "nofx/logger" + "nofx/market" + "slices" + "strings" + "time" +) + +// TraderStore 交易员存储 +type TraderStore struct { + db *sql.DB + decryptFunc func(string) string +} + +// Trader 交易员配置 +type Trader struct { + ID string `json:"id"` + UserID string `json:"user_id"` + Name string `json:"name"` + AIModelID string `json:"ai_model_id"` + ExchangeID string `json:"exchange_id"` + InitialBalance float64 `json:"initial_balance"` + ScanIntervalMinutes int `json:"scan_interval_minutes"` + IsRunning bool `json:"is_running"` + BTCETHLeverage int `json:"btc_eth_leverage"` + AltcoinLeverage int `json:"altcoin_leverage"` + TradingSymbols string `json:"trading_symbols"` + UseCoinPool bool `json:"use_coin_pool"` + UseOITop bool `json:"use_oi_top"` + CustomPrompt string `json:"custom_prompt"` + OverrideBasePrompt bool `json:"override_base_prompt"` + SystemPromptTemplate string `json:"system_prompt_template"` + IsCrossMargin bool `json:"is_cross_margin"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// TraderFullConfig 交易员完整配置(包含AI模型和交易所) +type TraderFullConfig struct { + Trader *Trader + AIModel *AIModel + Exchange *Exchange +} + +func (s *TraderStore) initTables() error { + _, err := s.db.Exec(` + CREATE TABLE IF NOT EXISTS traders ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL DEFAULT 'default', + name TEXT NOT NULL, + ai_model_id TEXT NOT NULL, + exchange_id TEXT NOT NULL, + initial_balance REAL NOT NULL, + scan_interval_minutes INTEGER DEFAULT 3, + is_running BOOLEAN DEFAULT 0, + btc_eth_leverage INTEGER DEFAULT 5, + altcoin_leverage INTEGER DEFAULT 5, + trading_symbols TEXT DEFAULT '', + use_coin_pool BOOLEAN DEFAULT 0, + use_oi_top BOOLEAN DEFAULT 0, + custom_prompt TEXT DEFAULT '', + override_base_prompt BOOLEAN DEFAULT 0, + system_prompt_template TEXT DEFAULT 'default', + is_cross_margin BOOLEAN DEFAULT 1, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE + ) + `) + if err != nil { + return err + } + + // 触发器 + _, err = s.db.Exec(` + CREATE TRIGGER IF NOT EXISTS update_traders_updated_at + AFTER UPDATE ON traders + BEGIN + UPDATE traders SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; + END + `) + if err != nil { + return err + } + + // 向后兼容 + alterQueries := []string{ + `ALTER TABLE traders ADD COLUMN custom_prompt TEXT DEFAULT ''`, + `ALTER TABLE traders ADD COLUMN override_base_prompt BOOLEAN DEFAULT 0`, + `ALTER TABLE traders ADD COLUMN is_cross_margin BOOLEAN DEFAULT 1`, + `ALTER TABLE traders ADD COLUMN btc_eth_leverage INTEGER DEFAULT 5`, + `ALTER TABLE traders ADD COLUMN altcoin_leverage INTEGER DEFAULT 5`, + `ALTER TABLE traders ADD COLUMN trading_symbols TEXT DEFAULT ''`, + `ALTER TABLE traders ADD COLUMN use_coin_pool BOOLEAN DEFAULT 0`, + `ALTER TABLE traders ADD COLUMN use_oi_top BOOLEAN DEFAULT 0`, + `ALTER TABLE traders ADD COLUMN system_prompt_template TEXT DEFAULT 'default'`, + } + for _, q := range alterQueries { + s.db.Exec(q) + } + + return nil +} + +func (s *TraderStore) decrypt(encrypted string) string { + if s.decryptFunc != nil { + return s.decryptFunc(encrypted) + } + return encrypted +} + +// Create 创建交易员 +func (s *TraderStore) Create(trader *Trader) error { + _, err := s.db.Exec(` + INSERT INTO traders (id, user_id, name, ai_model_id, exchange_id, initial_balance, scan_interval_minutes, + is_running, btc_eth_leverage, altcoin_leverage, trading_symbols, use_coin_pool, + use_oi_top, custom_prompt, override_base_prompt, system_prompt_template, is_cross_margin) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + `, trader.ID, trader.UserID, trader.Name, trader.AIModelID, trader.ExchangeID, trader.InitialBalance, + trader.ScanIntervalMinutes, trader.IsRunning, trader.BTCETHLeverage, trader.AltcoinLeverage, + trader.TradingSymbols, trader.UseCoinPool, trader.UseOITop, trader.CustomPrompt, + trader.OverrideBasePrompt, trader.SystemPromptTemplate, trader.IsCrossMargin) + return err +} + +// List 获取用户的交易员列表 +func (s *TraderStore) List(userID string) ([]*Trader, error) { + rows, err := s.db.Query(` + SELECT id, user_id, name, ai_model_id, exchange_id, initial_balance, scan_interval_minutes, is_running, + COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''), + COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''), + COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'), + COALESCE(is_cross_margin, 1), created_at, updated_at + FROM traders WHERE user_id = ? ORDER BY created_at DESC + `, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var traders []*Trader + for rows.Next() { + var t Trader + var createdAt, updatedAt string + err := rows.Scan( + &t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, + &t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, + &t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols, + &t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt, + &t.SystemPromptTemplate, &t.IsCrossMargin, &createdAt, &updatedAt, + ) + if err != nil { + return nil, err + } + t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) + traders = append(traders, &t) + } + return traders, nil +} + +// UpdateStatus 更新交易员运行状态 +func (s *TraderStore) UpdateStatus(userID, id string, isRunning bool) error { + _, err := s.db.Exec(`UPDATE traders SET is_running = ? WHERE id = ? AND user_id = ?`, isRunning, id, userID) + return err +} + +// Update 更新交易员配置 +func (s *TraderStore) Update(trader *Trader) error { + _, err := s.db.Exec(` + UPDATE traders SET + name = ?, ai_model_id = ?, exchange_id = ?, scan_interval_minutes = ?, + btc_eth_leverage = ?, altcoin_leverage = ?, trading_symbols = ?, + custom_prompt = ?, override_base_prompt = ?, system_prompt_template = ?, + is_cross_margin = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = ? AND user_id = ? + `, trader.Name, trader.AIModelID, trader.ExchangeID, trader.ScanIntervalMinutes, + trader.BTCETHLeverage, trader.AltcoinLeverage, trader.TradingSymbols, + trader.CustomPrompt, trader.OverrideBasePrompt, trader.SystemPromptTemplate, + trader.IsCrossMargin, trader.ID, trader.UserID) + return err +} + +// UpdateInitialBalance 更新初始余额 +func (s *TraderStore) UpdateInitialBalance(userID, id string, newBalance float64) error { + _, err := s.db.Exec(`UPDATE traders SET initial_balance = ? WHERE id = ? AND user_id = ?`, newBalance, id, userID) + return err +} + +// UpdateCustomPrompt 更新自定义提示词 +func (s *TraderStore) UpdateCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error { + _, err := s.db.Exec(`UPDATE traders SET custom_prompt = ?, override_base_prompt = ? WHERE id = ? AND user_id = ?`, + customPrompt, overrideBase, id, userID) + return err +} + +// Delete 删除交易员 +func (s *TraderStore) Delete(userID, id string) error { + _, err := s.db.Exec(`DELETE FROM traders WHERE id = ? AND user_id = ?`, id, userID) + return err +} + +// GetFullConfig 获取交易员完整配置 +func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig, error) { + var trader Trader + var aiModel AIModel + var exchange Exchange + var traderCreatedAt, traderUpdatedAt string + var aiModelCreatedAt, aiModelUpdatedAt string + var exchangeCreatedAt, exchangeUpdatedAt string + + err := s.db.QueryRow(` + SELECT + t.id, t.user_id, t.name, t.ai_model_id, t.exchange_id, t.initial_balance, t.scan_interval_minutes, t.is_running, + COALESCE(t.btc_eth_leverage, 5), COALESCE(t.altcoin_leverage, 5), COALESCE(t.trading_symbols, ''), + COALESCE(t.use_coin_pool, 0), COALESCE(t.use_oi_top, 0), COALESCE(t.custom_prompt, ''), + COALESCE(t.override_base_prompt, 0), COALESCE(t.system_prompt_template, 'default'), + COALESCE(t.is_cross_margin, 1), t.created_at, t.updated_at, + a.id, a.user_id, a.name, a.provider, a.enabled, a.api_key, + COALESCE(a.custom_api_url, ''), COALESCE(a.custom_model_name, ''), a.created_at, a.updated_at, + e.id, e.user_id, e.name, e.type, e.enabled, e.api_key, e.secret_key, e.testnet, + COALESCE(e.hyperliquid_wallet_addr, ''), COALESCE(e.aster_user, ''), COALESCE(e.aster_signer, ''), + COALESCE(e.aster_private_key, ''), COALESCE(e.lighter_wallet_addr, ''), COALESCE(e.lighter_private_key, ''), + COALESCE(e.lighter_api_key_private_key, ''), e.created_at, e.updated_at + FROM traders t + JOIN ai_models a ON t.ai_model_id = a.id AND t.user_id = a.user_id + JOIN exchanges e ON t.exchange_id = e.id AND t.user_id = e.user_id + WHERE t.id = ? AND t.user_id = ? + `, traderID, userID).Scan( + &trader.ID, &trader.UserID, &trader.Name, &trader.AIModelID, &trader.ExchangeID, + &trader.InitialBalance, &trader.ScanIntervalMinutes, &trader.IsRunning, + &trader.BTCETHLeverage, &trader.AltcoinLeverage, &trader.TradingSymbols, + &trader.UseCoinPool, &trader.UseOITop, &trader.CustomPrompt, &trader.OverrideBasePrompt, + &trader.SystemPromptTemplate, &trader.IsCrossMargin, &traderCreatedAt, &traderUpdatedAt, + &aiModel.ID, &aiModel.UserID, &aiModel.Name, &aiModel.Provider, &aiModel.Enabled, &aiModel.APIKey, + &aiModel.CustomAPIURL, &aiModel.CustomModelName, &aiModelCreatedAt, &aiModelUpdatedAt, + &exchange.ID, &exchange.UserID, &exchange.Name, &exchange.Type, &exchange.Enabled, + &exchange.APIKey, &exchange.SecretKey, &exchange.Testnet, &exchange.HyperliquidWalletAddr, + &exchange.AsterUser, &exchange.AsterSigner, &exchange.AsterPrivateKey, + &exchange.LighterWalletAddr, &exchange.LighterPrivateKey, &exchange.LighterAPIKeyPrivateKey, + &exchangeCreatedAt, &exchangeUpdatedAt, + ) + if err != nil { + return nil, err + } + + trader.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", traderCreatedAt) + trader.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", traderUpdatedAt) + aiModel.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelCreatedAt) + aiModel.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelUpdatedAt) + exchange.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeCreatedAt) + exchange.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeUpdatedAt) + + // 解密 + aiModel.APIKey = s.decrypt(aiModel.APIKey) + exchange.APIKey = s.decrypt(exchange.APIKey) + exchange.SecretKey = s.decrypt(exchange.SecretKey) + exchange.AsterPrivateKey = s.decrypt(exchange.AsterPrivateKey) + exchange.LighterPrivateKey = s.decrypt(exchange.LighterPrivateKey) + exchange.LighterAPIKeyPrivateKey = s.decrypt(exchange.LighterAPIKeyPrivateKey) + + return &TraderFullConfig{ + Trader: &trader, + AIModel: &aiModel, + Exchange: &exchange, + }, nil +} + +// GetCustomCoins 获取所有交易员自定义币种 +func (s *TraderStore) GetCustomCoins() []string { + var symbol string + var symbols []string + _ = s.db.QueryRow(` + SELECT GROUP_CONCAT(trading_symbols, ',') as symbol + FROM traders WHERE trading_symbols != '' + `).Scan(&symbol) + + // 如果没有自定义币种,返回默认币种 + if symbol == "" { + var symbolJSON string + _ = s.db.QueryRow(`SELECT value FROM system_config WHERE key = 'default_coins'`).Scan(&symbolJSON) + if symbolJSON != "" { + if err := json.Unmarshal([]byte(symbolJSON), &symbols); err != nil { + logger.Warnf("⚠️ 解析default_coins配置失败: %v,使用硬编码默认值", err) + symbols = []string{"BTCUSDT", "ETHUSDT", "SOLUSDT", "BNBUSDT"} + } + } else { + symbols = []string{"BTCUSDT", "ETHUSDT", "SOLUSDT", "BNBUSDT"} + } + return symbols + } + + // 处理并去重币种列表 + for _, s := range strings.Split(symbol, ",") { + if s == "" { + continue + } + coin := market.Normalize(s) + if !slices.Contains(symbols, coin) { + symbols = append(symbols, coin) + } + } + return symbols +} + +// ListAll 获取所有用户的交易员列表 +func (s *TraderStore) ListAll() ([]*Trader, error) { + rows, err := s.db.Query(` + SELECT id, user_id, name, ai_model_id, exchange_id, initial_balance, scan_interval_minutes, is_running, + COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''), + COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''), + COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'), + COALESCE(is_cross_margin, 1), created_at, updated_at + FROM traders ORDER BY created_at DESC + `) + if err != nil { + return nil, err + } + defer rows.Close() + + var traders []*Trader + for rows.Next() { + var t Trader + var createdAt, updatedAt string + err := rows.Scan( + &t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, + &t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, + &t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols, + &t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt, + &t.SystemPromptTemplate, &t.IsCrossMargin, &createdAt, &updatedAt, + ) + if err != nil { + return nil, err + } + t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) + traders = append(traders, &t) + } + return traders, nil +} diff --git a/store/user.go b/store/user.go new file mode 100644 index 00000000..6e9c993f --- /dev/null +++ b/store/user.go @@ -0,0 +1,164 @@ +package store + +import ( + "crypto/rand" + "database/sql" + "encoding/base32" + "time" +) + +// UserStore 用户存储 +type UserStore struct { + db *sql.DB +} + +// User 用户 +type User struct { + ID string `json:"id"` + Email string `json:"email"` + PasswordHash string `json:"-"` + OTPSecret string `json:"-"` + OTPVerified bool `json:"otp_verified"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// GenerateOTPSecret 生成OTP密钥 +func GenerateOTPSecret() (string, error) { + secret := make([]byte, 20) + _, err := rand.Read(secret) + if err != nil { + return "", err + } + return base32.StdEncoding.EncodeToString(secret), nil +} + +func (s *UserStore) initTables() error { + _, err := s.db.Exec(` + CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + email TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + otp_secret TEXT, + otp_verified BOOLEAN DEFAULT 0, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + `) + if err != nil { + return err + } + + // 触发器 + _, err = s.db.Exec(` + CREATE TRIGGER IF NOT EXISTS update_users_updated_at + AFTER UPDATE ON users + BEGIN + UPDATE users SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id; + END + `) + if err != nil { + return err + } + + return nil +} + +// Create 创建用户 +func (s *UserStore) Create(user *User) error { + _, err := s.db.Exec(` + INSERT INTO users (id, email, password_hash, otp_secret, otp_verified) + VALUES (?, ?, ?, ?, ?) + `, user.ID, user.Email, user.PasswordHash, user.OTPSecret, user.OTPVerified) + return err +} + +// GetByEmail 通过邮箱获取用户 +func (s *UserStore) GetByEmail(email string) (*User, error) { + var user User + var createdAt, updatedAt string + err := s.db.QueryRow(` + SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at + FROM users WHERE email = ? + `, email).Scan( + &user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret, + &user.OTPVerified, &createdAt, &updatedAt, + ) + if err != nil { + return nil, err + } + user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) + return &user, nil +} + +// GetByID 通过ID获取用户 +func (s *UserStore) GetByID(userID string) (*User, error) { + var user User + var createdAt, updatedAt string + err := s.db.QueryRow(` + SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at + FROM users WHERE id = ? + `, userID).Scan( + &user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret, + &user.OTPVerified, &createdAt, &updatedAt, + ) + if err != nil { + return nil, err + } + user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) + return &user, nil +} + +// GetAllIDs 获取所有用户ID +func (s *UserStore) GetAllIDs() ([]string, error) { + rows, err := s.db.Query(`SELECT id FROM users ORDER BY id`) + if err != nil { + return nil, err + } + defer rows.Close() + + var userIDs []string + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + userIDs = append(userIDs, userID) + } + return userIDs, nil +} + +// UpdateOTPVerified 更新OTP验证状态 +func (s *UserStore) UpdateOTPVerified(userID string, verified bool) error { + _, err := s.db.Exec(`UPDATE users SET otp_verified = ? WHERE id = ?`, verified, userID) + return err +} + +// UpdatePassword 更新密码 +func (s *UserStore) UpdatePassword(userID, passwordHash string) error { + _, err := s.db.Exec(` + UPDATE users SET password_hash = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ? + `, passwordHash, userID) + return err +} + +// EnsureAdmin 确保admin用户存在 +func (s *UserStore) EnsureAdmin() error { + var count int + err := s.db.QueryRow(`SELECT COUNT(*) FROM users WHERE id = 'admin'`).Scan(&count) + if err != nil { + return err + } + if count > 0 { + return nil + } + return s.Create(&User{ + ID: "admin", + Email: "admin@localhost", + PasswordHash: "", + OTPSecret: "", + OTPVerified: true, + }) +} diff --git a/trader/aster_trader.go b/trader/aster_trader.go index e33c1b0e..7a172739 100644 --- a/trader/aster_trader.go +++ b/trader/aster_trader.go @@ -8,7 +8,7 @@ import ( "errors" "fmt" "io" - "log" + "nofx/logger" "math" "math/big" "net/http" @@ -469,13 +469,13 @@ func (t *AsterTrader) GetBalance() (map[string]interface{}, error) { } if !foundUSDT { - log.Printf("⚠️ 未找到USDT资产记录!") + logger.Infof("⚠️ 未找到USDT资产记录!") } // 获取持仓计算保证金占用和真实未实现盈亏 positions, err := t.GetPositions() if err != nil { - log.Printf("⚠️ 获取持仓信息失败: %v", err) + logger.Infof("⚠️ 获取持仓信息失败: %v", err) // fallback: 无法获取持仓时使用简单计算 return map[string]interface{}{ "totalWalletBalance": crossWalletBalance, @@ -577,7 +577,7 @@ func (t *AsterTrader) GetPositions() ([]map[string]interface{}, error) { func (t *AsterTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { // 开仓前先取消所有挂单,防止残留挂单导致仓位叠加 if err := t.CancelAllOrders(symbol); err != nil { - log.Printf(" ⚠ 取消挂单失败(继续开仓): %v", err) + logger.Infof(" ⚠ 取消挂单失败(继续开仓): %v", err) } // 先设置杠杆 @@ -614,7 +614,7 @@ func (t *AsterTrader) OpenLong(symbol string, quantity float64, leverage int) (m priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) - log.Printf(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)", + logger.Infof(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)", limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision) params := map[string]interface{}{ @@ -644,7 +644,7 @@ func (t *AsterTrader) OpenLong(symbol string, quantity float64, leverage int) (m func (t *AsterTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { // 开仓前先取消所有挂单,防止残留挂单导致仓位叠加 if err := t.CancelAllOrders(symbol); err != nil { - log.Printf(" ⚠ 取消挂单失败(继续开仓): %v", err) + logger.Infof(" ⚠ 取消挂单失败(继续开仓): %v", err) } // 先设置杠杆 @@ -681,7 +681,7 @@ func (t *AsterTrader) OpenShort(symbol string, quantity float64, leverage int) ( priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) - log.Printf(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)", + logger.Infof(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)", limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision) params := map[string]interface{}{ @@ -726,7 +726,7 @@ func (t *AsterTrader) CloseLong(symbol string, quantity float64) (map[string]int if quantity == 0 { return nil, fmt.Errorf("没有找到 %s 的多仓", symbol) } - log.Printf(" 📊 获取到多仓数量: %.8f", quantity) + logger.Infof(" 📊 获取到多仓数量: %.8f", quantity) } price, err := t.GetMarketPrice(symbol) @@ -756,7 +756,7 @@ func (t *AsterTrader) CloseLong(symbol string, quantity float64) (map[string]int priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) - log.Printf(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)", + logger.Infof(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)", limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision) params := map[string]interface{}{ @@ -779,11 +779,11 @@ func (t *AsterTrader) CloseLong(symbol string, quantity float64) (map[string]int return nil, err } - log.Printf("✓ 平多仓成功: %s 数量: %s", symbol, qtyStr) + logger.Infof("✓ 平多仓成功: %s 数量: %s", symbol, qtyStr) // 平仓后取消该币种的所有挂单(止损止盈单) if err := t.CancelAllOrders(symbol); err != nil { - log.Printf(" ⚠ 取消挂单失败: %v", err) + logger.Infof(" ⚠ 取消挂单失败: %v", err) } return result, nil @@ -809,7 +809,7 @@ func (t *AsterTrader) CloseShort(symbol string, quantity float64) (map[string]in if quantity == 0 { return nil, fmt.Errorf("没有找到 %s 的空仓", symbol) } - log.Printf(" 📊 获取到空仓数量: %.8f", quantity) + logger.Infof(" 📊 获取到空仓数量: %.8f", quantity) } price, err := t.GetMarketPrice(symbol) @@ -839,7 +839,7 @@ func (t *AsterTrader) CloseShort(symbol string, quantity float64) (map[string]in priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision) qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision) - log.Printf(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)", + logger.Infof(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)", limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision) params := map[string]interface{}{ @@ -862,11 +862,11 @@ func (t *AsterTrader) CloseShort(symbol string, quantity float64) (map[string]in return nil, err } - log.Printf("✓ 平空仓成功: %s 数量: %s", symbol, qtyStr) + logger.Infof("✓ 平空仓成功: %s 数量: %s", symbol, qtyStr) // 平仓后取消该币种的所有挂单(止损止盈单) if err := t.CancelAllOrders(symbol); err != nil { - log.Printf(" ⚠ 取消挂单失败: %v", err) + logger.Infof(" ⚠ 取消挂单失败: %v", err) } return result, nil @@ -892,30 +892,30 @@ func (t *AsterTrader) SetMarginMode(symbol string, isCrossMargin bool) error { // 如果错误表示无需更改,忽略错误 if strings.Contains(err.Error(), "No need to change") || strings.Contains(err.Error(), "Margin type cannot be changed") { - log.Printf(" ✓ %s 仓位模式已是 %s 或有持仓无法更改", symbol, marginType) + logger.Infof(" ✓ %s 仓位模式已是 %s 或有持仓无法更改", symbol, marginType) return nil } // 检测多资产模式(错误码 -4168) if strings.Contains(err.Error(), "Multi-Assets mode") || strings.Contains(err.Error(), "-4168") || strings.Contains(err.Error(), "4168") { - log.Printf(" ⚠️ %s 检测到多资产模式,强制使用全仓模式", symbol) - log.Printf(" 💡 提示:如需使用逐仓模式,请在交易所关闭多资产模式") + logger.Infof(" ⚠️ %s 检测到多资产模式,强制使用全仓模式", symbol) + logger.Infof(" 💡 提示:如需使用逐仓模式,请在交易所关闭多资产模式") return nil } // 检测统一账户 API if strings.Contains(err.Error(), "unified") || strings.Contains(err.Error(), "portfolio") || strings.Contains(err.Error(), "Portfolio") { - log.Printf(" ❌ %s 检测到统一账户 API,无法进行合约交易", symbol) + logger.Infof(" ❌ %s 检测到统一账户 API,无法进行合约交易", symbol) return fmt.Errorf("请使用「现货与合约交易」API 权限,不要使用「统一账户 API」") } - log.Printf(" ⚠️ 设置仓位模式失败: %v", err) + logger.Infof(" ⚠️ 设置仓位模式失败: %v", err) // 不返回错误,让交易继续 return nil } - log.Printf(" ✓ %s 仓位模式已设置为 %s", symbol, marginType) + logger.Infof(" ✓ %s 仓位模式已设置为 %s", symbol, marginType) return nil } @@ -1075,19 +1075,19 @@ func (t *AsterTrader) CancelStopLossOrders(symbol string) error { if err != nil { errMsg := fmt.Sprintf("订单ID %d: %v", int64(orderID), err) cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) - log.Printf(" ⚠ 取消止损单失败: %s", errMsg) + logger.Infof(" ⚠ 取消止损单失败: %s", errMsg) continue } canceledCount++ - log.Printf(" ✓ 已取消止损单 (订单ID: %d, 类型: %s, 方向: %s)", int64(orderID), orderType, positionSide) + logger.Infof(" ✓ 已取消止损单 (订单ID: %d, 类型: %s, 方向: %s)", int64(orderID), orderType, positionSide) } } if canceledCount == 0 && len(cancelErrors) == 0 { - log.Printf(" ℹ %s 没有止损单需要取消", symbol) + logger.Infof(" ℹ %s 没有止损单需要取消", symbol) } else if canceledCount > 0 { - log.Printf(" ✓ 已取消 %s 的 %d 个止损单", symbol, canceledCount) + logger.Infof(" ✓ 已取消 %s 的 %d 个止损单", symbol, canceledCount) } // 如果所有取消都失败了,返回错误 @@ -1134,19 +1134,19 @@ func (t *AsterTrader) CancelTakeProfitOrders(symbol string) error { if err != nil { errMsg := fmt.Sprintf("订单ID %d: %v", int64(orderID), err) cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) - log.Printf(" ⚠ 取消止盈单失败: %s", errMsg) + logger.Infof(" ⚠ 取消止盈单失败: %s", errMsg) continue } canceledCount++ - log.Printf(" ✓ 已取消止盈单 (订单ID: %d, 类型: %s, 方向: %s)", int64(orderID), orderType, positionSide) + logger.Infof(" ✓ 已取消止盈单 (订单ID: %d, 类型: %s, 方向: %s)", int64(orderID), orderType, positionSide) } } if canceledCount == 0 && len(cancelErrors) == 0 { - log.Printf(" ℹ %s 没有止盈单需要取消", symbol) + logger.Infof(" ℹ %s 没有止盈单需要取消", symbol) } else if canceledCount > 0 { - log.Printf(" ✓ 已取消 %s 的 %d 个止盈单", symbol, canceledCount) + logger.Infof(" ✓ 已取消 %s 的 %d 个止盈单", symbol, canceledCount) } // 如果所有取消都失败了,返回错误 @@ -1203,20 +1203,20 @@ func (t *AsterTrader) CancelStopOrders(symbol string) error { _, err := t.request("DELETE", "/fapi/v3/order", cancelParams) if err != nil { - log.Printf(" ⚠ 取消订单 %d 失败: %v", int64(orderID), err) + logger.Infof(" ⚠ 取消订单 %d 失败: %v", int64(orderID), err) continue } canceledCount++ - log.Printf(" ✓ 已取消 %s 的止盈/止损单 (订单ID: %d, 类型: %s)", + logger.Infof(" ✓ 已取消 %s 的止盈/止损单 (订单ID: %d, 类型: %s)", symbol, int64(orderID), orderType) } } if canceledCount == 0 { - log.Printf(" ℹ %s 没有止盈/止损单需要取消", symbol) + logger.Infof(" ℹ %s 没有止盈/止损单需要取消", symbol) } else { - log.Printf(" ✓ 已取消 %s 的 %d 个止盈/止损单", symbol, canceledCount) + logger.Infof(" ✓ 已取消 %s 的 %d 个止盈/止损单", symbol, canceledCount) } return nil @@ -1230,3 +1230,52 @@ func (t *AsterTrader) FormatQuantity(symbol string, quantity float64) (string, e } return fmt.Sprintf("%v", formatted), nil } + +// GetOrderStatus 获取订单状态 +func (t *AsterTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { + params := map[string]interface{}{ + "symbol": symbol, + "orderId": orderID, + } + + body, err := t.request("GET", "/fapi/v3/order", params) + if err != nil { + return nil, fmt.Errorf("获取订单状态失败: %w", err) + } + + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("解析订单响应失败: %w", err) + } + + // 标准化返回字段 + response := map[string]interface{}{ + "orderId": result["orderId"], + "symbol": result["symbol"], + "status": result["status"], + "side": result["side"], + "type": result["type"], + "time": result["time"], + "updateTime": result["updateTime"], + "commission": 0.0, // Aster 可能需要单独查询 + } + + // 解析数值字段 + if avgPrice, ok := result["avgPrice"].(string); ok { + if v, err := strconv.ParseFloat(avgPrice, 64); err == nil { + response["avgPrice"] = v + } + } else if avgPrice, ok := result["avgPrice"].(float64); ok { + response["avgPrice"] = avgPrice + } + + if executedQty, ok := result["executedQty"].(string); ok { + if v, err := strconv.ParseFloat(executedQty, 64); err == nil { + response["executedQty"] = v + } + } else if executedQty, ok := result["executedQty"].(float64); ok { + response["executedQty"] = executedQty + } + + return response, nil +} diff --git a/trader/auto_trader.go b/trader/auto_trader.go index 86ca1cd8..af5d3c55 100644 --- a/trader/auto_trader.go +++ b/trader/auto_trader.go @@ -3,13 +3,13 @@ package trader import ( "encoding/json" "fmt" - "log" + "nofx/logger" "math" "nofx/decision" - "nofx/logger" "nofx/market" "nofx/mcp" "nofx/pool" + "nofx/store" "strings" "sync" "time" @@ -96,7 +96,8 @@ type AutoTrader struct { config AutoTraderConfig trader Trader // 使用Trader接口(支持多平台) mcpClient mcp.AIClient - decisionLogger logger.IDecisionLogger // 决策日志记录器 + store *store.Store // 数据存储(决策记录等) + cycleNumber int // 当前周期编号 initialBalance float64 dailyPnL float64 customPrompt string // 自定义交易策略prompt @@ -115,12 +116,12 @@ type AutoTrader struct { peakPnLCache map[string]float64 // 最高收益缓存 (symbol -> 峰值盈亏百分比) peakPnLCacheMutex sync.RWMutex // 缓存读写锁 lastBalanceSyncTime time.Time // 上次余额同步时间 - database interface{} // 数据库引用(用于自动更新余额) userID string // 用户ID } // NewAutoTrader 创建自动交易器 -func NewAutoTrader(config AutoTraderConfig, database interface{}, userID string) (*AutoTrader, error) { +// st 参数用于存储决策记录到数据库 +func NewAutoTrader(config AutoTraderConfig, st *store.Store, userID string) (*AutoTrader, error) { // 设置默认值 if config.ID == "" { config.ID = "default_trader" @@ -142,24 +143,24 @@ func NewAutoTrader(config AutoTraderConfig, database interface{}, userID string) if config.AIModel == "custom" { // 使用自定义API mcpClient.SetAPIKey(config.CustomAPIKey, config.CustomAPIURL, config.CustomModelName) - log.Printf("🤖 [%s] 使用自定义AI API: %s (模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName) + logger.Infof("🤖 [%s] 使用自定义AI API: %s (模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName) } else if config.UseQwen || config.AIModel == "qwen" { // 使用Qwen (支持自定义URL和Model) mcpClient = mcp.NewQwenClient() mcpClient.SetAPIKey(config.QwenKey, config.CustomAPIURL, config.CustomModelName) if config.CustomAPIURL != "" || config.CustomModelName != "" { - log.Printf("🤖 [%s] 使用阿里云Qwen AI (自定义URL: %s, 模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName) + logger.Infof("🤖 [%s] 使用阿里云Qwen AI (自定义URL: %s, 模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName) } else { - log.Printf("🤖 [%s] 使用阿里云Qwen AI", config.Name) + logger.Infof("🤖 [%s] 使用阿里云Qwen AI", config.Name) } } else { // 默认使用DeepSeek (支持自定义URL和Model) mcpClient = mcp.NewDeepSeekClient() mcpClient.SetAPIKey(config.DeepSeekKey, config.CustomAPIURL, config.CustomModelName) if config.CustomAPIURL != "" || config.CustomModelName != "" { - log.Printf("🤖 [%s] 使用DeepSeek AI (自定义URL: %s, 模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName) + logger.Infof("🤖 [%s] 使用DeepSeek AI (自定义URL: %s, 模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName) } else { - log.Printf("🤖 [%s] 使用DeepSeek AI", config.Name) + logger.Infof("🤖 [%s] 使用DeepSeek AI", config.Name) } } @@ -182,33 +183,33 @@ func NewAutoTrader(config AutoTraderConfig, database interface{}, userID string) if !config.IsCrossMargin { marginModeStr = "逐仓" } - log.Printf("📊 [%s] 仓位模式: %s", config.Name, marginModeStr) + logger.Infof("📊 [%s] 仓位模式: %s", config.Name, marginModeStr) switch config.Exchange { case "binance": - log.Printf("🏦 [%s] 使用币安合约交易", config.Name) + logger.Infof("🏦 [%s] 使用币安合约交易", config.Name) trader = NewFuturesTrader(config.BinanceAPIKey, config.BinanceSecretKey, userID) case "bybit": - log.Printf("🏦 [%s] 使用Bybit合约交易", config.Name) + logger.Infof("🏦 [%s] 使用Bybit合约交易", config.Name) trader = NewBybitTrader(config.BybitAPIKey, config.BybitSecretKey) case "hyperliquid": - log.Printf("🏦 [%s] 使用Hyperliquid交易", config.Name) + logger.Infof("🏦 [%s] 使用Hyperliquid交易", config.Name) trader, err = NewHyperliquidTrader(config.HyperliquidPrivateKey, config.HyperliquidWalletAddr, config.HyperliquidTestnet) if err != nil { return nil, fmt.Errorf("初始化Hyperliquid交易器失败: %w", err) } case "aster": - log.Printf("🏦 [%s] 使用Aster交易", config.Name) + logger.Infof("🏦 [%s] 使用Aster交易", config.Name) trader, err = NewAsterTrader(config.AsterUser, config.AsterSigner, config.AsterPrivateKey) if err != nil { return nil, fmt.Errorf("初始化Aster交易器失败: %w", err) } case "lighter": - log.Printf("🏦 [%s] 使用LIGHTER交易", config.Name) + logger.Infof("🏦 [%s] 使用LIGHTER交易", config.Name) // 優先使用 V2(需要 API Key) if config.LighterAPIKeyPrivateKey != "" { - log.Printf("✓ 使用 LIGHTER SDK (V2) - 完整簽名支持") + logger.Infof("✓ 使用 LIGHTER SDK (V2) - 完整簽名支持") trader, err = NewLighterTraderV2( config.LighterPrivateKey, config.LighterWalletAddr, @@ -220,7 +221,7 @@ func NewAutoTrader(config AutoTraderConfig, database interface{}, userID string) } } else { // 降級使用 V1(基本HTTP實現) - log.Printf("⚠️ 使用 LIGHTER 基本實現 (V1) - 功能受限,請配置 API Key") + logger.Infof("⚠️ 使用 LIGHTER 基本實現 (V1) - 功能受限,請配置 API Key") trader, err = NewLighterTrader(config.LighterPrivateKey, config.LighterWalletAddr, config.LighterTestnet) if err != nil { return nil, fmt.Errorf("初始化LIGHTER交易器(V1)失败: %w", err) @@ -235,9 +236,12 @@ func NewAutoTrader(config AutoTraderConfig, database interface{}, userID string) return nil, fmt.Errorf("初始金额必须大于0,请在配置中设置InitialBalance") } - // 初始化决策日志记录器(使用trader ID创建独立目录) - logDir := fmt.Sprintf("decision_logs/%s", config.ID) - decisionLogger := logger.NewDecisionLogger(logDir) + // 获取最后的周期编号(用于恢复) + var cycleNumber int + if st != nil { + cycleNumber, _ = st.Decision().GetLastCycleNumber(config.ID) + logger.Infof("📊 [%s] 决策记录将存储到数据库", config.Name) + } // 设置默认系统提示词模板 systemPromptTemplate := config.SystemPromptTemplate @@ -254,7 +258,8 @@ func NewAutoTrader(config AutoTraderConfig, database interface{}, userID string) config: config, trader: trader, mcpClient: mcpClient, - decisionLogger: decisionLogger, + store: st, + cycleNumber: cycleNumber, initialBalance: config.InitialBalance, systemPromptTemplate: systemPromptTemplate, defaultCoins: config.DefaultCoins, @@ -268,8 +273,7 @@ func NewAutoTrader(config AutoTraderConfig, database interface{}, userID string) monitorWg: sync.WaitGroup{}, peakPnLCache: make(map[string]float64), peakPnLCacheMutex: sync.RWMutex{}, - lastBalanceSyncTime: time.Now(), // 初始化为当前时间 - database: database, + lastBalanceSyncTime: time.Now(), userID: userID, }, nil } @@ -280,10 +284,10 @@ func (at *AutoTrader) Run() error { at.stopMonitorCh = make(chan struct{}) at.startTime = time.Now() - log.Println("🚀 AI驱动自动交易系统启动") - log.Printf("💰 初始余额: %.2f USDT", at.initialBalance) - log.Printf("⚙️ 扫描间隔: %v", at.config.ScanInterval) - log.Println("🤖 AI将全权决定杠杆、仓位大小、止损止盈等参数") + logger.Info("🚀 AI驱动自动交易系统启动") + logger.Infof("💰 初始余额: %.2f USDT", at.initialBalance) + logger.Infof("⚙️ 扫描间隔: %v", at.config.ScanInterval) + logger.Info("🤖 AI将全权决定杠杆、仓位大小、止损止盈等参数") at.monitorWg.Add(1) defer at.monitorWg.Done() @@ -295,17 +299,17 @@ func (at *AutoTrader) Run() error { // 首次立即执行 if err := at.runCycle(); err != nil { - log.Printf("❌ 执行失败: %v", err) + logger.Infof("❌ 执行失败: %v", err) } for at.isRunning { select { case <-ticker.C: if err := at.runCycle(); err != nil { - log.Printf("❌ 执行失败: %v", err) + logger.Infof("❌ 执行失败: %v", err) } case <-at.stopMonitorCh: - log.Printf("[%s] ⏹ 收到停止信号,退出自动交易主循环", at.name) + logger.Infof("[%s] ⏹ 收到停止信号,退出自动交易主循环", at.name) return nil } } @@ -321,19 +325,19 @@ func (at *AutoTrader) Stop() { at.isRunning = false close(at.stopMonitorCh) // 通知监控goroutine停止 at.monitorWg.Wait() // 等待监控goroutine结束 - log.Println("⏹ 自动交易系统停止") + logger.Info("⏹ 自动交易系统停止") } // runCycle 运行一个交易周期(使用AI全权决策) func (at *AutoTrader) runCycle() error { at.callCount++ - log.Print("\n" + strings.Repeat("=", 70) + "\n") - log.Printf("⏰ %s - AI决策周期 #%d", time.Now().Format("2006-01-02 15:04:05"), at.callCount) - log.Println(strings.Repeat("=", 70)) + logger.Info("\n" + strings.Repeat("=", 70) + "\n") + logger.Infof("⏰ %s - AI决策周期 #%d", time.Now().Format("2006-01-02 15:04:05"), at.callCount) + logger.Info(strings.Repeat("=", 70)) // 创建决策记录 - record := &logger.DecisionRecord{ + record := &store.DecisionRecord{ ExecutionLog: []string{}, Success: true, } @@ -341,10 +345,10 @@ func (at *AutoTrader) runCycle() error { // 1. 检查是否需要停止交易 if time.Now().Before(at.stopUntil) { remaining := at.stopUntil.Sub(time.Now()) - log.Printf("⏸ 风险控制:暂停交易中,剩余 %.0f 分钟", remaining.Minutes()) + logger.Infof("⏸ 风险控制:暂停交易中,剩余 %.0f 分钟", remaining.Minutes()) record.Success = false record.ErrorMessage = fmt.Sprintf("风险控制暂停中,剩余 %.0f 分钟", remaining.Minutes()) - at.decisionLogger.LogDecision(record) + at.saveDecision(record) return nil } @@ -352,7 +356,7 @@ func (at *AutoTrader) runCycle() error { if time.Since(at.lastResetTime) > 24*time.Hour { at.dailyPnL = 0 at.lastResetTime = time.Now() - log.Println("📅 日盈亏已重置") + logger.Info("📅 日盈亏已重置") } // 4. 收集交易上下文 @@ -360,12 +364,12 @@ func (at *AutoTrader) runCycle() error { if err != nil { record.Success = false record.ErrorMessage = fmt.Sprintf("构建交易上下文失败: %v", err) - at.decisionLogger.LogDecision(record) + at.saveDecision(record) return fmt.Errorf("构建交易上下文失败: %w", err) } // 保存账户状态快照 - record.AccountState = logger.AccountSnapshot{ + record.AccountState = store.AccountSnapshot{ TotalBalance: ctx.Account.TotalEquity - ctx.Account.UnrealizedPnL, AvailableBalance: ctx.Account.AvailableBalance, TotalUnrealizedProfit: ctx.Account.UnrealizedPnL, @@ -376,7 +380,7 @@ func (at *AutoTrader) runCycle() error { // 保存持仓快照 for _, pos := range ctx.Positions { - record.Positions = append(record.Positions, logger.PositionSnapshot{ + record.Positions = append(record.Positions, store.PositionSnapshot{ Symbol: pos.Symbol, Side: pos.Side, PositionAmt: pos.Quantity, @@ -388,21 +392,21 @@ func (at *AutoTrader) runCycle() error { }) } - log.Print(strings.Repeat("=", 70)) + logger.Info(strings.Repeat("=", 70)) for _, coin := range ctx.CandidateCoins { record.CandidateCoins = append(record.CandidateCoins, coin.Symbol) } - log.Printf("📊 账户净值: %.2f USDT | 可用: %.2f USDT | 持仓: %d", + logger.Infof("📊 账户净值: %.2f USDT | 可用: %.2f USDT | 持仓: %d", ctx.Account.TotalEquity, ctx.Account.AvailableBalance, ctx.Account.PositionCount) // 5. 调用AI获取完整决策 - log.Printf("🤖 正在请求AI分析并决策... [模板: %s]", at.systemPromptTemplate) + logger.Infof("🤖 正在请求AI分析并决策... [模板: %s]", at.systemPromptTemplate) decision, err := decision.GetFullDecisionWithCustomPrompt(ctx, at.mcpClient, at.customPrompt, at.overrideBasePrompt, at.systemPromptTemplate) if decision != nil && decision.AIRequestDurationMs > 0 { record.AIRequestDurationMs = decision.AIRequestDurationMs - log.Printf("⏱️ AI调用耗时: %.2f 秒", float64(record.AIRequestDurationMs)/1000) + logger.Infof("⏱️ AI调用耗时: %.2f 秒", float64(record.AIRequestDurationMs)/1000) record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("AI调用耗时: %d ms", record.AIRequestDurationMs)) } @@ -424,65 +428,65 @@ func (at *AutoTrader) runCycle() error { // 打印系统提示词和AI思维链(即使有错误,也要输出以便调试) if decision != nil { - log.Print("\n" + strings.Repeat("=", 70) + "\n") - log.Printf("📋 系统提示词 [模板: %s] (错误情况)", at.systemPromptTemplate) - log.Println(strings.Repeat("=", 70)) - log.Println(decision.SystemPrompt) - log.Println(strings.Repeat("=", 70)) + logger.Info("\n" + strings.Repeat("=", 70) + "\n") + logger.Infof("📋 系统提示词 [模板: %s] (错误情况)", at.systemPromptTemplate) + logger.Info(strings.Repeat("=", 70)) + logger.Info(decision.SystemPrompt) + logger.Info(strings.Repeat("=", 70)) if decision.CoTTrace != "" { - log.Print("\n" + strings.Repeat("-", 70) + "\n") - log.Println("💭 AI思维链分析(错误情况):") - log.Println(strings.Repeat("-", 70)) - log.Println(decision.CoTTrace) - log.Println(strings.Repeat("-", 70)) + logger.Info("\n" + strings.Repeat("-", 70) + "\n") + logger.Info("💭 AI思维链分析(错误情况):") + logger.Info(strings.Repeat("-", 70)) + logger.Info(decision.CoTTrace) + logger.Info(strings.Repeat("-", 70)) } } - at.decisionLogger.LogDecision(record) + at.saveDecision(record) return fmt.Errorf("获取AI决策失败: %w", err) } // // 5. 打印系统提示词 - // log.Printf("\n" + strings.Repeat("=", 70)) - // log.Printf("📋 系统提示词 [模板: %s]", at.systemPromptTemplate) - // log.Println(strings.Repeat("=", 70)) - // log.Println(decision.SystemPrompt) - // log.Printf(strings.Repeat("=", 70) + "\n") + // logger.Infof("\n" + strings.Repeat("=", 70)) + // logger.Infof("📋 系统提示词 [模板: %s]", at.systemPromptTemplate) + // logger.Info(strings.Repeat("=", 70)) + // logger.Info(decision.SystemPrompt) + // logger.Infof(strings.Repeat("=", 70) + "\n") // 6. 打印AI思维链 - // log.Printf("\n" + strings.Repeat("-", 70)) - // log.Println("💭 AI思维链分析:") - // log.Println(strings.Repeat("-", 70)) - // log.Println(decision.CoTTrace) - // log.Printf(strings.Repeat("-", 70) + "\n") + // logger.Infof("\n" + strings.Repeat("-", 70)) + // logger.Info("💭 AI思维链分析:") + // logger.Info(strings.Repeat("-", 70)) + // logger.Info(decision.CoTTrace) + // logger.Infof(strings.Repeat("-", 70) + "\n") // 7. 打印AI决策 - // log.Printf("📋 AI决策列表 (%d 个):\n", len(decision.Decisions)) + // logger.Infof("📋 AI决策列表 (%d 个):\n", len(decision.Decisions)) // for i, d := range decision.Decisions { - // log.Printf(" [%d] %s: %s - %s", i+1, d.Symbol, d.Action, d.Reasoning) + // logger.Infof(" [%d] %s: %s - %s", i+1, d.Symbol, d.Action, d.Reasoning) // if d.Action == "open_long" || d.Action == "open_short" { - // log.Printf(" 杠杆: %dx | 仓位: %.2f USDT | 止损: %.4f | 止盈: %.4f", + // logger.Infof(" 杠杆: %dx | 仓位: %.2f USDT | 止损: %.4f | 止盈: %.4f", // d.Leverage, d.PositionSizeUSD, d.StopLoss, d.TakeProfit) // } // } - log.Println() - log.Print(strings.Repeat("-", 70)) + logger.Info() + logger.Info(strings.Repeat("-", 70)) // 8. 对决策排序:确保先平仓后开仓(防止仓位叠加超限) - log.Print(strings.Repeat("-", 70)) + logger.Info(strings.Repeat("-", 70)) // 8. 对决策排序:确保先平仓后开仓(防止仓位叠加超限) sortedDecisions := sortDecisionsByPriority(decision.Decisions) - log.Println("🔄 执行顺序(已优化): 先平仓→后开仓") + logger.Info("🔄 执行顺序(已优化): 先平仓→后开仓") for i, d := range sortedDecisions { - log.Printf(" [%d] %s %s", i+1, d.Symbol, d.Action) + logger.Infof(" [%d] %s %s", i+1, d.Symbol, d.Action) } - log.Println() + logger.Info() // 执行决策并记录结果 for _, d := range sortedDecisions { - actionRecord := logger.DecisionAction{ + actionRecord := store.DecisionAction{ Action: d.Action, Symbol: d.Symbol, Quantity: 0, @@ -493,7 +497,7 @@ func (at *AutoTrader) runCycle() error { } if err := at.executeDecisionWithRecord(&d, &actionRecord); err != nil { - log.Printf("❌ 执行决策失败 (%s %s): %v", d.Symbol, d.Action, err) + logger.Infof("❌ 执行决策失败 (%s %s): %v", d.Symbol, d.Action, err) actionRecord.Error = err.Error() record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("❌ %s %s 失败: %v", d.Symbol, d.Action, err)) } else { @@ -507,8 +511,8 @@ func (at *AutoTrader) runCycle() error { } // 9. 保存决策记录 - if err := at.decisionLogger.LogDecision(record); err != nil { - log.Printf("⚠ 保存决策记录失败: %v", err) + if err := at.saveDecision(record); err != nil { + logger.Infof("⚠ 保存决策记录失败: %v", err) } return nil @@ -636,16 +640,7 @@ func (at *AutoTrader) buildTradingContext() (*decision.Context, error) { marginUsedPct = (totalMarginUsed / totalEquity) * 100 } - // 5. 分析历史表现(最近100个周期,避免长期持仓的交易记录丢失) - // 假设每3分钟一个周期,100个周期 = 5小时,足够覆盖大部分交易 - performance, err := at.decisionLogger.AnalyzePerformance(100) - if err != nil { - log.Printf("⚠️ 分析历史表现失败: %v", err) - // 不影响主流程,继续执行(但设置performance为nil以避免传递错误数据) - performance = nil - } - - // 6. 构建上下文 + // 5. 构建上下文 ctx := &decision.Context{ CurrentTime: time.Now().Format("2006-01-02 15:04:05"), RuntimeMinutes: int(time.Since(at.startTime).Minutes()), @@ -664,14 +659,45 @@ func (at *AutoTrader) buildTradingContext() (*decision.Context, error) { }, Positions: positionInfos, CandidateCoins: candidateCoins, - Performance: performance, // 添加历史表现分析 + } + + // 6. 添加交易统计和历史订单(如果store可用) + if at.store != nil { + // 获取交易统计(使用新的 positions 表) + if stats, err := at.store.Position().GetFullStats(at.id); err == nil { + ctx.TradingStats = &decision.TradingStats{ + TotalTrades: stats.TotalTrades, + WinRate: stats.WinRate, + ProfitFactor: stats.ProfitFactor, + SharpeRatio: stats.SharpeRatio, + TotalPnL: stats.TotalPnL, + AvgWin: stats.AvgWin, + AvgLoss: stats.AvgLoss, + MaxDrawdownPct: stats.MaxDrawdownPct, + } + } + + // 获取最近10条已平仓交易(使用新的 positions 表) + if recentTrades, err := at.store.Position().GetRecentTrades(at.id, 10); err == nil { + for _, trade := range recentTrades { + ctx.RecentOrders = append(ctx.RecentOrders, decision.RecentOrder{ + Symbol: trade.Symbol, + Side: trade.Side, + EntryPrice: trade.EntryPrice, + ExitPrice: trade.ExitPrice, + RealizedPnL: trade.RealizedPnL, + PnLPct: trade.PnLPct, + FilledAt: trade.ExitTime, + }) + } + } } return ctx, nil } // executeDecisionWithRecord 执行AI决策并记录详细信息 -func (at *AutoTrader) executeDecisionWithRecord(decision *decision.Decision, actionRecord *logger.DecisionAction) error { +func (at *AutoTrader) executeDecisionWithRecord(decision *decision.Decision, actionRecord *store.DecisionAction) error { switch decision.Action { case "open_long": return at.executeOpenLongWithRecord(decision, actionRecord) @@ -681,12 +707,6 @@ func (at *AutoTrader) executeDecisionWithRecord(decision *decision.Decision, act return at.executeCloseLongWithRecord(decision, actionRecord) case "close_short": return at.executeCloseShortWithRecord(decision, actionRecord) - case "update_stop_loss": - return at.executeUpdateStopLossWithRecord(decision, actionRecord) - case "update_take_profit": - return at.executeUpdateTakeProfitWithRecord(decision, actionRecord) - case "partial_close": - return at.executePartialCloseWithRecord(decision, actionRecord) case "hold", "wait": // 无需执行,仅记录 return nil @@ -696,8 +716,8 @@ func (at *AutoTrader) executeDecisionWithRecord(decision *decision.Decision, act } // executeOpenLongWithRecord 执行开多仓并记录详细信息 -func (at *AutoTrader) executeOpenLongWithRecord(decision *decision.Decision, actionRecord *logger.DecisionAction) error { - log.Printf(" 📈 开多仓: %s", decision.Symbol) +func (at *AutoTrader) executeOpenLongWithRecord(decision *decision.Decision, actionRecord *store.DecisionAction) error { + logger.Infof(" 📈 开多仓: %s", decision.Symbol) // ⚠️ 关键:检查是否已有同币种同方向持仓,如果有则拒绝开仓(防止仓位叠加超限) positions, err := at.trader.GetPositions() @@ -743,7 +763,7 @@ func (at *AutoTrader) executeOpenLongWithRecord(decision *decision.Decision, act // 设置仓位模式 if err := at.trader.SetMarginMode(decision.Symbol, at.config.IsCrossMargin); err != nil { - log.Printf(" ⚠️ 设置仓位模式失败: %v", err) + logger.Infof(" ⚠️ 设置仓位模式失败: %v", err) // 继续执行,不影响交易 } @@ -758,7 +778,10 @@ func (at *AutoTrader) executeOpenLongWithRecord(decision *decision.Decision, act actionRecord.OrderID = orderID } - log.Printf(" ✓ 开仓成功,订单ID: %v, 数量: %.4f", order["orderId"], quantity) + logger.Infof(" ✓ 开仓成功,订单ID: %v, 数量: %.4f", order["orderId"], quantity) + + // 记录订单到数据库并轮询确认 + at.recordAndConfirmOrder(order, decision.Symbol, "open_long", quantity, marketData.CurrentPrice, decision.Leverage, 0) // 记录开仓时间 posKey := decision.Symbol + "_long" @@ -766,18 +789,18 @@ func (at *AutoTrader) executeOpenLongWithRecord(decision *decision.Decision, act // 设置止损止盈 if err := at.trader.SetStopLoss(decision.Symbol, "LONG", quantity, decision.StopLoss); err != nil { - log.Printf(" ⚠ 设置止损失败: %v", err) + logger.Infof(" ⚠ 设置止损失败: %v", err) } if err := at.trader.SetTakeProfit(decision.Symbol, "LONG", quantity, decision.TakeProfit); err != nil { - log.Printf(" ⚠ 设置止盈失败: %v", err) + logger.Infof(" ⚠ 设置止盈失败: %v", err) } return nil } // executeOpenShortWithRecord 执行开空仓并记录详细信息 -func (at *AutoTrader) executeOpenShortWithRecord(decision *decision.Decision, actionRecord *logger.DecisionAction) error { - log.Printf(" 📉 开空仓: %s", decision.Symbol) +func (at *AutoTrader) executeOpenShortWithRecord(decision *decision.Decision, actionRecord *store.DecisionAction) error { + logger.Infof(" 📉 开空仓: %s", decision.Symbol) // ⚠️ 关键:检查是否已有同币种同方向持仓,如果有则拒绝开仓(防止仓位叠加超限) positions, err := at.trader.GetPositions() @@ -823,7 +846,7 @@ func (at *AutoTrader) executeOpenShortWithRecord(decision *decision.Decision, ac // 设置仓位模式 if err := at.trader.SetMarginMode(decision.Symbol, at.config.IsCrossMargin); err != nil { - log.Printf(" ⚠️ 设置仓位模式失败: %v", err) + logger.Infof(" ⚠️ 设置仓位模式失败: %v", err) // 继续执行,不影响交易 } @@ -838,7 +861,10 @@ func (at *AutoTrader) executeOpenShortWithRecord(decision *decision.Decision, ac actionRecord.OrderID = orderID } - log.Printf(" ✓ 开仓成功,订单ID: %v, 数量: %.4f", order["orderId"], quantity) + logger.Infof(" ✓ 开仓成功,订单ID: %v, 数量: %.4f", order["orderId"], quantity) + + // 记录订单到数据库并轮询确认 + at.recordAndConfirmOrder(order, decision.Symbol, "open_short", quantity, marketData.CurrentPrice, decision.Leverage, 0) // 记录开仓时间 posKey := decision.Symbol + "_short" @@ -846,18 +872,18 @@ func (at *AutoTrader) executeOpenShortWithRecord(decision *decision.Decision, ac // 设置止损止盈 if err := at.trader.SetStopLoss(decision.Symbol, "SHORT", quantity, decision.StopLoss); err != nil { - log.Printf(" ⚠ 设置止损失败: %v", err) + logger.Infof(" ⚠ 设置止损失败: %v", err) } if err := at.trader.SetTakeProfit(decision.Symbol, "SHORT", quantity, decision.TakeProfit); err != nil { - log.Printf(" ⚠ 设置止盈失败: %v", err) + logger.Infof(" ⚠ 设置止盈失败: %v", err) } return nil } // executeCloseLongWithRecord 执行平多仓并记录详细信息 -func (at *AutoTrader) executeCloseLongWithRecord(decision *decision.Decision, actionRecord *logger.DecisionAction) error { - log.Printf(" 🔄 平多仓: %s", decision.Symbol) +func (at *AutoTrader) executeCloseLongWithRecord(decision *decision.Decision, actionRecord *store.DecisionAction) error { + logger.Infof(" 🔄 平多仓: %s", decision.Symbol) // 获取当前价格 marketData, err := market.Get(decision.Symbol) @@ -866,6 +892,16 @@ func (at *AutoTrader) executeCloseLongWithRecord(decision *decision.Decision, ac } actionRecord.Price = marketData.CurrentPrice + // 获取开仓价格(用于计算盈亏) + var entryPrice float64 + var quantity float64 + if at.store != nil { + if openOrder, err := at.store.Order().GetLatestOpenOrder(at.id, decision.Symbol, "long"); err == nil { + entryPrice = openOrder.AvgPrice + quantity = openOrder.ExecutedQty + } + } + // 平仓 order, err := at.trader.CloseLong(decision.Symbol, 0) // 0 = 全部平仓 if err != nil { @@ -877,13 +913,16 @@ func (at *AutoTrader) executeCloseLongWithRecord(decision *decision.Decision, ac actionRecord.OrderID = orderID } - log.Printf(" ✓ 平仓成功") + // 记录订单到数据库并轮询确认 + at.recordAndConfirmOrder(order, decision.Symbol, "close_long", quantity, marketData.CurrentPrice, 0, entryPrice) + + logger.Infof(" ✓ 平仓成功") return nil } // executeCloseShortWithRecord 执行平空仓并记录详细信息 -func (at *AutoTrader) executeCloseShortWithRecord(decision *decision.Decision, actionRecord *logger.DecisionAction) error { - log.Printf(" 🔄 平空仓: %s", decision.Symbol) +func (at *AutoTrader) executeCloseShortWithRecord(decision *decision.Decision, actionRecord *store.DecisionAction) error { + logger.Infof(" 🔄 平空仓: %s", decision.Symbol) // 获取当前价格 marketData, err := market.Get(decision.Symbol) @@ -892,6 +931,16 @@ func (at *AutoTrader) executeCloseShortWithRecord(decision *decision.Decision, a } actionRecord.Price = marketData.CurrentPrice + // 获取开仓价格(用于计算盈亏) + var entryPrice float64 + var quantity float64 + if at.store != nil { + if openOrder, err := at.store.Order().GetLatestOpenOrder(at.id, decision.Symbol, "short"); err == nil { + entryPrice = openOrder.AvgPrice + quantity = openOrder.ExecutedQty + } + } + // 平仓 order, err := at.trader.CloseShort(decision.Symbol, 0) // 0 = 全部平仓 if err != nil { @@ -903,302 +952,10 @@ func (at *AutoTrader) executeCloseShortWithRecord(decision *decision.Decision, a actionRecord.OrderID = orderID } - log.Printf(" ✓ 平仓成功") - return nil -} - -// executeUpdateStopLossWithRecord 执行调整止损并记录详细信息 -func (at *AutoTrader) executeUpdateStopLossWithRecord(decision *decision.Decision, actionRecord *logger.DecisionAction) error { - log.Printf(" 🎯 调整止损: %s → %.2f", decision.Symbol, decision.NewStopLoss) - - // 获取当前价格 - marketData, err := market.Get(decision.Symbol) - if err != nil { - return err - } - actionRecord.Price = marketData.CurrentPrice - - // 获取当前持仓 - positions, err := at.trader.GetPositions() - if err != nil { - return fmt.Errorf("获取持仓失败: %w", err) - } - - // 查找目标持仓 - var targetPosition map[string]interface{} - for _, pos := range positions { - symbol, _ := pos["symbol"].(string) - posAmt, _ := pos["positionAmt"].(float64) - if symbol == decision.Symbol && posAmt != 0 { - targetPosition = pos - break - } - } - - if targetPosition == nil { - return fmt.Errorf("持仓不存在: %s", decision.Symbol) - } - - // 获取持仓方向和数量 - side, _ := targetPosition["side"].(string) - positionSide := strings.ToUpper(side) - positionAmt, _ := targetPosition["positionAmt"].(float64) - - // 验证新止损价格合理性 - if positionSide == "LONG" && decision.NewStopLoss >= marketData.CurrentPrice { - return fmt.Errorf("多单止损必须低于当前价格 (当前: %.2f, 新止损: %.2f)", marketData.CurrentPrice, decision.NewStopLoss) - } - if positionSide == "SHORT" && decision.NewStopLoss <= marketData.CurrentPrice { - return fmt.Errorf("空单止损必须高于当前价格 (当前: %.2f, 新止损: %.2f)", marketData.CurrentPrice, decision.NewStopLoss) - } - - // ⚠️ 防御性检查:检测是否存在双向持仓(不应该出现,但提供保护) - var hasOppositePosition bool - oppositeSide := "" - for _, pos := range positions { - symbol, _ := pos["symbol"].(string) - posSide, _ := pos["side"].(string) - posAmt, _ := pos["positionAmt"].(float64) - if symbol == decision.Symbol && posAmt != 0 && strings.ToUpper(posSide) != positionSide { - hasOppositePosition = true - oppositeSide = strings.ToUpper(posSide) - break - } - } - - if hasOppositePosition { - log.Printf(" 🚨 警告:检测到 %s 存在双向持仓(%s + %s),这违反了策略规则", - decision.Symbol, positionSide, oppositeSide) - log.Printf(" 🚨 取消止损单将影响两个方向的订单,请检查是否为用户手动操作导致") - log.Printf(" 🚨 建议:手动平掉其中一个方向的持仓,或检查系统是否有BUG") - } - - // 取消旧的止损单(只删除止损单,不影响止盈单) - // 注意:如果存在双向持仓,这会删除两个方向的止损单 - if err := at.trader.CancelStopLossOrders(decision.Symbol); err != nil { - log.Printf(" ⚠ 取消旧止损单失败: %v", err) - // 不中断执行,继续设置新止损 - } - - // 调用交易所 API 修改止损 - quantity := math.Abs(positionAmt) - err = at.trader.SetStopLoss(decision.Symbol, positionSide, quantity, decision.NewStopLoss) - if err != nil { - return fmt.Errorf("修改止损失败: %w", err) - } - - log.Printf(" ✓ 止损已调整: %.2f (当前价格: %.2f)", decision.NewStopLoss, marketData.CurrentPrice) - return nil -} - -// executeUpdateTakeProfitWithRecord 执行调整止盈并记录详细信息 -func (at *AutoTrader) executeUpdateTakeProfitWithRecord(decision *decision.Decision, actionRecord *logger.DecisionAction) error { - log.Printf(" 🎯 调整止盈: %s → %.2f", decision.Symbol, decision.NewTakeProfit) - - // 获取当前价格 - marketData, err := market.Get(decision.Symbol) - if err != nil { - return err - } - actionRecord.Price = marketData.CurrentPrice - - // 获取当前持仓 - positions, err := at.trader.GetPositions() - if err != nil { - return fmt.Errorf("获取持仓失败: %w", err) - } - - // 查找目标持仓 - var targetPosition map[string]interface{} - for _, pos := range positions { - symbol, _ := pos["symbol"].(string) - posAmt, _ := pos["positionAmt"].(float64) - if symbol == decision.Symbol && posAmt != 0 { - targetPosition = pos - break - } - } - - if targetPosition == nil { - return fmt.Errorf("持仓不存在: %s", decision.Symbol) - } - - // 获取持仓方向和数量 - side, _ := targetPosition["side"].(string) - positionSide := strings.ToUpper(side) - positionAmt, _ := targetPosition["positionAmt"].(float64) - - // 验证新止盈价格合理性 - if positionSide == "LONG" && decision.NewTakeProfit <= marketData.CurrentPrice { - return fmt.Errorf("多单止盈必须高于当前价格 (当前: %.2f, 新止盈: %.2f)", marketData.CurrentPrice, decision.NewTakeProfit) - } - if positionSide == "SHORT" && decision.NewTakeProfit >= marketData.CurrentPrice { - return fmt.Errorf("空单止盈必须低于当前价格 (当前: %.2f, 新止盈: %.2f)", marketData.CurrentPrice, decision.NewTakeProfit) - } - - // ⚠️ 防御性检查:检测是否存在双向持仓(不应该出现,但提供保护) - var hasOppositePosition bool - oppositeSide := "" - for _, pos := range positions { - symbol, _ := pos["symbol"].(string) - posSide, _ := pos["side"].(string) - posAmt, _ := pos["positionAmt"].(float64) - if symbol == decision.Symbol && posAmt != 0 && strings.ToUpper(posSide) != positionSide { - hasOppositePosition = true - oppositeSide = strings.ToUpper(posSide) - break - } - } - - if hasOppositePosition { - log.Printf(" 🚨 警告:检测到 %s 存在双向持仓(%s + %s),这违反了策略规则", - decision.Symbol, positionSide, oppositeSide) - log.Printf(" 🚨 取消止盈单将影响两个方向的订单,请检查是否为用户手动操作导致") - log.Printf(" 🚨 建议:手动平掉其中一个方向的持仓,或检查系统是否有BUG") - } - - // 取消旧的止盈单(只删除止盈单,不影响止损单) - // 注意:如果存在双向持仓,这会删除两个方向的止盈单 - if err := at.trader.CancelTakeProfitOrders(decision.Symbol); err != nil { - log.Printf(" ⚠ 取消旧止盈单失败: %v", err) - // 不中断执行,继续设置新止盈 - } - - // 调用交易所 API 修改止盈 - quantity := math.Abs(positionAmt) - err = at.trader.SetTakeProfit(decision.Symbol, positionSide, quantity, decision.NewTakeProfit) - if err != nil { - return fmt.Errorf("修改止盈失败: %w", err) - } - - log.Printf(" ✓ 止盈已调整: %.2f (当前价格: %.2f)", decision.NewTakeProfit, marketData.CurrentPrice) - return nil -} - -// executePartialCloseWithRecord 执行部分平仓并记录详细信息 -func (at *AutoTrader) executePartialCloseWithRecord(decision *decision.Decision, actionRecord *logger.DecisionAction) error { - log.Printf(" 📊 部分平仓: %s %.1f%%", decision.Symbol, decision.ClosePercentage) - - // 验证百分比范围 - if decision.ClosePercentage <= 0 || decision.ClosePercentage > 100 { - return fmt.Errorf("平仓百分比必须在 0-100 之间,当前: %.1f", decision.ClosePercentage) - } - - // 获取当前价格 - marketData, err := market.Get(decision.Symbol) - if err != nil { - return err - } - actionRecord.Price = marketData.CurrentPrice - - // 获取当前持仓 - positions, err := at.trader.GetPositions() - if err != nil { - return fmt.Errorf("获取持仓失败: %w", err) - } - - // 查找目标持仓 - var targetPosition map[string]interface{} - for _, pos := range positions { - symbol, _ := pos["symbol"].(string) - posAmt, _ := pos["positionAmt"].(float64) - if symbol == decision.Symbol && posAmt != 0 { - targetPosition = pos - break - } - } - - if targetPosition == nil { - return fmt.Errorf("持仓不存在: %s", decision.Symbol) - } - - // 获取持仓方向和数量 - side, _ := targetPosition["side"].(string) - positionSide := strings.ToUpper(side) - positionAmt, _ := targetPosition["positionAmt"].(float64) - - // 计算平仓数量 - totalQuantity := math.Abs(positionAmt) - closeQuantity := totalQuantity * (decision.ClosePercentage / 100.0) - actionRecord.Quantity = closeQuantity - - // ✅ Layer 2: 最小仓位检查(防止产生小额剩余) - markPrice, ok := targetPosition["markPrice"].(float64) - if !ok || markPrice <= 0 { - return fmt.Errorf("无法解析当前价格,无法执行最小仓位检查") - } - - currentPositionValue := totalQuantity * markPrice - remainingQuantity := totalQuantity - closeQuantity - remainingValue := remainingQuantity * markPrice - - const MIN_POSITION_VALUE = 10.0 // 最小持仓价值 10 USDT(對齊交易所底线,小仓位建议直接全平) - - if remainingValue > 0 && remainingValue <= MIN_POSITION_VALUE { - log.Printf("⚠️ 检测到 partial_close 后剩余仓位 %.2f USDT < %.0f USDT", - remainingValue, MIN_POSITION_VALUE) - log.Printf(" → 当前仓位价值: %.2f USDT, 平仓 %.1f%%, 剩余: %.2f USDT", - currentPositionValue, decision.ClosePercentage, remainingValue) - log.Printf(" → 自动修正为全部平仓,避免产生无法平仓的小额剩余") - - // 🔄 自动修正为全部平仓 - if positionSide == "LONG" { - decision.Action = "close_long" - log.Printf(" ✓ 已修正为: close_long") - return at.executeCloseLongWithRecord(decision, actionRecord) - } else { - decision.Action = "close_short" - log.Printf(" ✓ 已修正为: close_short") - return at.executeCloseShortWithRecord(decision, actionRecord) - } - } - - // 执行平仓 - var order map[string]interface{} - if positionSide == "LONG" { - order, err = at.trader.CloseLong(decision.Symbol, closeQuantity) - } else { - order, err = at.trader.CloseShort(decision.Symbol, closeQuantity) - } - - if err != nil { - return fmt.Errorf("部分平仓失败: %w", err) - } - - // 记录订单ID - if orderID, ok := order["orderId"].(int64); ok { - actionRecord.OrderID = orderID - } - - log.Printf(" ✓ 部分平仓成功: 平仓 %.4f (%.1f%%), 剩余 %.4f", - closeQuantity, decision.ClosePercentage, remainingQuantity) - - // ✅ Step 4: 恢复止盈止损(防止剩余仓位裸奔) - // 重要:币安等交易所在部分平仓后会自动取消原有的 TP/SL 订单(因为数量不匹配) - // 如果 AI 提供了新的止损止盈价格,则为剩余仓位重新设置保护 - if decision.NewStopLoss > 0 { - log.Printf(" → 为剩余仓位 %.4f 恢复止损单: %.2f", remainingQuantity, decision.NewStopLoss) - err = at.trader.SetStopLoss(decision.Symbol, positionSide, remainingQuantity, decision.NewStopLoss) - if err != nil { - log.Printf(" ⚠️ 恢复止损失败: %v(不影响平仓结果)", err) - } - } - - if decision.NewTakeProfit > 0 { - log.Printf(" → 为剩余仓位 %.4f 恢复止盈单: %.2f", remainingQuantity, decision.NewTakeProfit) - err = at.trader.SetTakeProfit(decision.Symbol, positionSide, remainingQuantity, decision.NewTakeProfit) - if err != nil { - log.Printf(" ⚠️ 恢复止盈失败: %v(不影响平仓结果)", err) - } - } - - // 如果 AI 没有提供新的止盈止损,记录警告 - if decision.NewStopLoss <= 0 && decision.NewTakeProfit <= 0 { - log.Printf(" ⚠️⚠️⚠️ 警告: 部分平仓后AI未提供新的止盈止损价格") - log.Printf(" → 剩余仓位 %.4f (价值 %.2f USDT) 目前没有止盈止损保护", remainingQuantity, remainingValue) - log.Printf(" → 建议: 在 partial_close 决策中包含 new_stop_loss 和 new_take_profit 字段") - } + // 记录订单到数据库并轮询确认 + at.recordAndConfirmOrder(order, decision.Symbol, "close_short", quantity, marketData.CurrentPrice, 0, entryPrice) + logger.Infof(" ✓ 平仓成功") return nil } @@ -1242,9 +999,32 @@ func (at *AutoTrader) GetSystemPromptTemplate() string { return at.systemPromptTemplate } -// GetDecisionLogger 获取决策日志记录器 -func (at *AutoTrader) GetDecisionLogger() logger.IDecisionLogger { - return at.decisionLogger +// saveDecision 保存决策记录到数据库 +func (at *AutoTrader) saveDecision(record *store.DecisionRecord) error { + if at.store == nil { + return nil // 没有 store 时静默忽略 + } + + at.cycleNumber++ + record.CycleNumber = at.cycleNumber + record.TraderID = at.id + + if record.Timestamp.IsZero() { + record.Timestamp = time.Now().UTC() + } + + if err := at.store.Decision().LogDecision(record); err != nil { + logger.Infof("⚠️ 保存决策记录失败: %v", err) + return err + } + + logger.Infof("📝 决策记录已保存: trader=%s, cycle=%d", at.id, at.cycleNumber) + return nil +} + +// GetStore 获取数据存储(用于外部访问决策记录等) +func (at *AutoTrader) GetStore() *store.Store { + return at.store } // GetStatus 获取系统状态(用于API) @@ -1324,7 +1104,7 @@ func (at *AutoTrader) GetAccountInfo() (map[string]interface{}, error) { // 验证未实现盈亏的一致性(API值 vs 从持仓计算) diff := math.Abs(totalUnrealizedProfit - totalUnrealizedPnLCalculated) if diff > 0.1 { // 允许0.01 USDT的误差 - log.Printf("⚠️ 未实现盈亏不一致: API=%.4f, 计算=%.4f, 差异=%.4f", + logger.Infof("⚠️ 未实现盈亏不一致: API=%.4f, 计算=%.4f, 差异=%.4f", totalUnrealizedProfit, totalUnrealizedPnLCalculated, diff) } @@ -1333,7 +1113,7 @@ func (at *AutoTrader) GetAccountInfo() (map[string]interface{}, error) { if at.initialBalance > 0 { totalPnLPct = (totalPnL / at.initialBalance) * 100 } else { - log.Printf("⚠️ Initial Balance异常: %.2f,无法计算PNL百分比", at.initialBalance) + logger.Infof("⚠️ Initial Balance异常: %.2f,无法计算PNL百分比", at.initialBalance) } marginUsedPct := 0.0 @@ -1428,14 +1208,12 @@ func sortDecisionsByPriority(decisions []decision.Decision) []decision.Decision // 定义优先级 getActionPriority := func(action string) int { switch action { - case "close_long", "close_short", "partial_close": - return 1 // 最高优先级:先平仓(包括部分平仓) - case "update_stop_loss", "update_take_profit": - return 2 // 调整持仓止盈止损 + case "close_long", "close_short": + return 1 // 最高优先级:先平仓 case "open_long", "open_short": - return 3 // 次优先级:后开仓 + return 2 // 次优先级:后开仓 case "hold", "wait": - return 4 // 最低优先级:观望 + return 3 // 最低优先级:观望 default: return 999 // 未知动作放最后 } @@ -1472,7 +1250,7 @@ func (at *AutoTrader) getCandidateCoins() ([]decision.CandidateCoin, error) { Sources: []string{"default"}, // 标记为数据库默认币种 }) } - log.Printf("📋 [%s] 使用数据库默认币种: %d个币种 %v", + logger.Infof("📋 [%s] 使用数据库默认币种: %d个币种 %v", at.name, len(candidateCoins), at.defaultCoins) return candidateCoins, nil } else { @@ -1493,7 +1271,7 @@ func (at *AutoTrader) getCandidateCoins() ([]decision.CandidateCoin, error) { }) } - log.Printf("📋 [%s] 数据库无默认币种配置,使用AI500+OI Top: AI500前%d + OI_Top20 = 总计%d个候选币种", + logger.Infof("📋 [%s] 数据库无默认币种配置,使用AI500+OI Top: AI500前%d + OI_Top20 = 总计%d个候选币种", at.name, ai500Limit, len(candidateCoins)) return candidateCoins, nil } @@ -1509,7 +1287,7 @@ func (at *AutoTrader) getCandidateCoins() ([]decision.CandidateCoin, error) { }) } - log.Printf("📋 [%s] 使用自定义币种: %d个币种 %v", + logger.Infof("📋 [%s] 使用自定义币种: %d个币种 %v", at.name, len(candidateCoins), at.tradingCoins) return candidateCoins, nil } @@ -1537,14 +1315,14 @@ func (at *AutoTrader) startDrawdownMonitor() { ticker := time.NewTicker(1 * time.Minute) // 每分钟检查一次 defer ticker.Stop() - log.Println("📊 启动持仓回撤监控(每分钟检查一次)") + logger.Info("📊 启动持仓回撤监控(每分钟检查一次)") for { select { case <-ticker.C: at.checkPositionDrawdown() case <-at.stopMonitorCh: - log.Println("⏹ 停止持仓回撤监控") + logger.Info("⏹ 停止持仓回撤监控") return } } @@ -1556,7 +1334,7 @@ func (at *AutoTrader) checkPositionDrawdown() { // 获取当前持仓 positions, err := at.trader.GetPositions() if err != nil { - log.Printf("❌ 回撤监控:获取持仓失败: %v", err) + logger.Infof("❌ 回撤监控:获取持仓失败: %v", err) return } @@ -1608,20 +1386,20 @@ func (at *AutoTrader) checkPositionDrawdown() { // 检查平仓条件:收益大于5%且回撤超过40% if currentPnLPct > 5.0 && drawdownPct >= 40.0 { - log.Printf("🚨 触发回撤平仓条件: %s %s | 当前收益: %.2f%% | 最高收益: %.2f%% | 回撤: %.2f%%", + logger.Infof("🚨 触发回撤平仓条件: %s %s | 当前收益: %.2f%% | 最高收益: %.2f%% | 回撤: %.2f%%", symbol, side, currentPnLPct, peakPnLPct, drawdownPct) // 执行平仓 if err := at.emergencyClosePosition(symbol, side); err != nil { - log.Printf("❌ 回撤平仓失败 (%s %s): %v", symbol, side, err) + logger.Infof("❌ 回撤平仓失败 (%s %s): %v", symbol, side, err) } else { - log.Printf("✅ 回撤平仓成功: %s %s", symbol, side) + logger.Infof("✅ 回撤平仓成功: %s %s", symbol, side) // 平仓后清理该持仓的缓存 at.ClearPeakPnLCache(symbol, side) } } else if currentPnLPct > 5.0 { // 记录接近平仓条件的情况(用于调试) - log.Printf("📊 回撤监控: %s %s | 收益: %.2f%% | 最高: %.2f%% | 回撤: %.2f%%", + logger.Infof("📊 回撤监控: %s %s | 收益: %.2f%% | 最高: %.2f%% | 回撤: %.2f%%", symbol, side, currentPnLPct, peakPnLPct, drawdownPct) } } @@ -1635,13 +1413,13 @@ func (at *AutoTrader) emergencyClosePosition(symbol, side string) error { if err != nil { return err } - log.Printf("✅ 紧急平多仓成功,订单ID: %v", order["orderId"]) + logger.Infof("✅ 紧急平多仓成功,订单ID: %v", order["orderId"]) case "short": order, err := at.trader.CloseShort(symbol, 0) // 0 = 全部平仓 if err != nil { return err } - log.Printf("✅ 紧急平空仓成功,订单ID: %v", order["orderId"]) + logger.Infof("✅ 紧急平空仓成功,订单ID: %v", order["orderId"]) default: return fmt.Errorf("未知的持仓方向: %s", side) } @@ -1687,3 +1465,135 @@ func (at *AutoTrader) ClearPeakPnLCache(symbol, side string) { posKey := symbol + "_" + side delete(at.peakPnLCache, posKey) } + +// recordAndConfirmOrder 记录订单并轮询确认状态 +// action: open_long, open_short, close_long, close_short +// entryPrice: 平仓时的开仓价(开仓时为0) +func (at *AutoTrader) recordAndConfirmOrder(orderResult map[string]interface{}, symbol, action string, quantity float64, price float64, leverage int, entryPrice float64) { + if at.store == nil { + return + } + + // 获取订单ID(支持多种类型) + var orderID string + switch v := orderResult["orderId"].(type) { + case int64: + orderID = fmt.Sprintf("%d", v) + case float64: + orderID = fmt.Sprintf("%.0f", v) + case string: + orderID = v + default: + orderID = fmt.Sprintf("%v", v) + } + + if orderID == "" || orderID == "0" { + logger.Infof(" ⚠️ 订单ID为空,跳过记录") + return + } + + // 确定 side 和 positionSide + var side, positionSide string + switch action { + case "open_long": + side = "BUY" + positionSide = "LONG" + case "close_long": + side = "SELL" + positionSide = "LONG" + case "open_short": + side = "SELL" + positionSide = "SHORT" + case "close_short": + side = "BUY" + positionSide = "SHORT" + } + + // 创建订单记录 + order := &store.TraderOrder{ + TraderID: at.id, + OrderID: orderID, + Symbol: symbol, + Side: side, + PositionSide: positionSide, + Action: action, + OrderType: "MARKET", + Quantity: quantity, + Price: price, + Leverage: leverage, + Status: "NEW", + EntryPrice: entryPrice, + } + + // 保存到数据库 + if err := at.store.Order().Create(order); err != nil { + logger.Infof(" ⚠️ 记录订单失败: %v", err) + return + } + + logger.Infof(" 📝 订单已记录 (ID: %s, action: %s)", orderID, action) + + // 记录仓位变化 + at.recordPositionChange(orderID, symbol, positionSide, action, quantity, price, leverage, entryPrice) +} + +// recordPositionChange 记录仓位变化(开仓创建记录,平仓更新记录) +func (at *AutoTrader) recordPositionChange(orderID, symbol, side, action string, quantity, price float64, leverage int, entryPrice float64) { + if at.store == nil { + return + } + + switch action { + case "open_long", "open_short": + // 开仓:创建新的仓位记录 + pos := &store.TraderPosition{ + TraderID: at.id, + Symbol: symbol, + Side: side, // LONG or SHORT + Quantity: quantity, + EntryPrice: price, + EntryOrderID: orderID, + EntryTime: time.Now(), + Leverage: leverage, + Status: "OPEN", + } + if err := at.store.Position().Create(pos); err != nil { + logger.Infof(" ⚠️ 记录仓位失败: %v", err) + } else { + logger.Infof(" 📊 仓位已记录 [%s] %s %s @ %.4f", at.id[:8], symbol, side, price) + } + + case "close_long", "close_short": + // 平仓:找到对应的开仓记录并更新 + openPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, symbol, side) + if err != nil || openPos == nil { + logger.Infof(" ⚠️ 找不到对应的开仓记录 (%s %s)", symbol, side) + return + } + + // 计算盈亏 + var realizedPnL float64 + if side == "LONG" { + realizedPnL = (price - openPos.EntryPrice) * openPos.Quantity + } else { + realizedPnL = (openPos.EntryPrice - price) * openPos.Quantity + } + + // 更新仓位记录 + err = at.store.Position().ClosePosition( + openPos.ID, + price, // exitPrice + orderID, // exitOrderID + realizedPnL, + 0, // fee (暂不计算) + "ai_decision", + ) + if err != nil { + logger.Infof(" ⚠️ 更新仓位失败: %v", err) + } else { + logger.Infof(" 📊 仓位已平仓 [%s] %s %s @ %.4f → %.4f, PnL: %.2f", + at.id[:8], symbol, side, openPos.EntryPrice, price, realizedPnL) + } + } +} + diff --git a/trader/auto_trader_test.go b/trader/auto_trader_test.go index 9316981f..8981ca81 100644 --- a/trader/auto_trader_test.go +++ b/trader/auto_trader_test.go @@ -8,9 +8,9 @@ import ( "time" "nofx/decision" - "nofx/logger" "nofx/market" "nofx/pool" + "nofx/store" "github.com/agiledragon/gomonkey/v2" "github.com/stretchr/testify/suite" @@ -30,8 +30,7 @@ type AutoTraderTestSuite struct { // Mock 依赖 mockTrader *MockTrader - mockDB *MockDatabase - mockLogger logger.IDecisionLogger + mockStore *store.Store // gomonkey patches patches *gomonkey.Patches @@ -65,10 +64,9 @@ func (s *AutoTraderTestSuite) SetupTest() { positions: []map[string]interface{}{}, } - s.mockDB = &MockDatabase{} - // 创建临时决策日志记录器 - s.mockLogger = logger.NewDecisionLogger("/tmp/test_decision_logs") + // 创建临时store(使用nil表示测试中不需要实际的store) + s.mockStore = nil // 设置默认配置 s.config = AutoTraderConfig{ @@ -93,7 +91,7 @@ func (s *AutoTraderTestSuite) SetupTest() { config: s.config, trader: s.mockTrader, mcpClient: nil, // 测试中不需要实际的 MCP Client - decisionLogger: s.mockLogger, + store: s.mockStore, initialBalance: s.config.InitialBalance, systemPromptTemplate: s.config.SystemPromptTemplate, defaultCoins: []string{"BTC", "ETH"}, @@ -106,7 +104,6 @@ func (s *AutoTraderTestSuite) SetupTest() { stopMonitorCh: make(chan struct{}), peakPnLCache: make(map[string]float64), lastBalanceSyncTime: time.Now(), - database: s.mockDB, userID: "test_user", } } @@ -134,9 +131,8 @@ func (s *AutoTraderTestSuite) TestSortDecisionsByPriority() { {Action: "open_long", Symbol: "BTCUSDT"}, {Action: "close_short", Symbol: "ETHUSDT"}, {Action: "hold", Symbol: "BNBUSDT"}, - {Action: "update_stop_loss", Symbol: "SOLUSDT"}, {Action: "open_short", Symbol: "ADAUSDT"}, - {Action: "partial_close", Symbol: "DOGEUSDT"}, + {Action: "close_long", Symbol: "DOGEUSDT"}, }, }, } @@ -150,14 +146,12 @@ func (s *AutoTraderTestSuite) TestSortDecisionsByPriority() { // 验证优先级是否递增 getActionPriority := func(action string) int { switch action { - case "close_long", "close_short", "partial_close": + case "close_long", "close_short": return 1 - case "update_stop_loss", "update_take_profit": - return 2 case "open_long", "open_short": - return 3 + return 2 case "hold", "wait": - return 4 + return 3 default: return 999 } @@ -413,14 +407,14 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() { existingSide string availBalance float64 expectedErr string - executeFn func(*decision.Decision, *logger.DecisionAction) error + executeFn func(*decision.Decision, *store.DecisionAction) error }{ { name: "成功开多仓", action: "open_long", expectedOrder: 123456, availBalance: 8000.0, - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + executeFn: func(d *decision.Decision, a *store.DecisionAction) error { return s.autoTrader.executeOpenLongWithRecord(d, a) }, }, @@ -429,7 +423,7 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() { action: "open_short", expectedOrder: 123457, availBalance: 8000.0, - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + executeFn: func(d *decision.Decision, a *store.DecisionAction) error { return s.autoTrader.executeOpenShortWithRecord(d, a) }, }, @@ -438,7 +432,7 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() { action: "open_long", availBalance: 0.0, expectedErr: "保证金不足", - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + executeFn: func(d *decision.Decision, a *store.DecisionAction) error { return s.autoTrader.executeOpenLongWithRecord(d, a) }, }, @@ -447,7 +441,7 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() { action: "open_short", availBalance: 0.0, expectedErr: "保证金不足", - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + executeFn: func(d *decision.Decision, a *store.DecisionAction) error { return s.autoTrader.executeOpenShortWithRecord(d, a) }, }, @@ -457,7 +451,7 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() { existingSide: "long", availBalance: 8000.0, expectedErr: "已有多仓", - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + executeFn: func(d *decision.Decision, a *store.DecisionAction) error { return s.autoTrader.executeOpenLongWithRecord(d, a) }, }, @@ -467,7 +461,7 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() { existingSide: "short", availBalance: 8000.0, expectedErr: "已有空仓", - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + executeFn: func(d *decision.Decision, a *store.DecisionAction) error { return s.autoTrader.executeOpenShortWithRecord(d, a) }, }, @@ -488,7 +482,7 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() { } decision := &decision.Decision{Action: tt.action, Symbol: "BTCUSDT", PositionSizeUSD: 1000.0, Leverage: 10} - actionRecord := &logger.DecisionAction{Action: tt.action, Symbol: "BTCUSDT"} + actionRecord := &store.DecisionAction{Action: tt.action, Symbol: "BTCUSDT"} err := tt.executeFn(decision, actionRecord) @@ -516,14 +510,14 @@ func (s *AutoTraderTestSuite) TestExecuteClosePosition() { action string currentPrice float64 expectedOrder int64 - executeFn func(*decision.Decision, *logger.DecisionAction) error + executeFn func(*decision.Decision, *store.DecisionAction) error }{ { name: "成功平多仓", action: "close_long", currentPrice: 51000.0, expectedOrder: 123458, - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + executeFn: func(d *decision.Decision, a *store.DecisionAction) error { return s.autoTrader.executeCloseLongWithRecord(d, a) }, }, @@ -532,7 +526,7 @@ func (s *AutoTraderTestSuite) TestExecuteClosePosition() { action: "close_short", currentPrice: 49000.0, expectedOrder: 123459, - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { + executeFn: func(d *decision.Decision, a *store.DecisionAction) error { return s.autoTrader.executeCloseShortWithRecord(d, a) }, }, @@ -546,7 +540,7 @@ func (s *AutoTraderTestSuite) TestExecuteClosePosition() { }) decision := &decision.Decision{Action: tt.action, Symbol: "BTCUSDT"} - actionRecord := &logger.DecisionAction{Action: tt.action, Symbol: "BTCUSDT"} + actionRecord := &store.DecisionAction{Action: tt.action, Symbol: "BTCUSDT"} err := tt.executeFn(decision, actionRecord) @@ -557,221 +551,6 @@ func (s *AutoTraderTestSuite) TestExecuteClosePosition() { } } -// TestExecuteUpdateStopOrTakeProfit 测试更新止损/止盈(多空通用) -func (s *AutoTraderTestSuite) TestExecuteUpdateStopOrTakeProfit() { - // 使用指针变量来控制 market.Get 的返回值 - var testPrice *float64 - s.patches.ApplyFunc(market.Get, func(symbol string) (*market.Data, error) { - price := 50000.0 - if testPrice != nil { - price = *testPrice - } - return &market.Data{Symbol: symbol, CurrentPrice: price}, nil - }) - - tests := []struct { - name string - action string - symbol string - side string - currentPrice float64 - newPrice float64 - hasPosition bool - expectedErr string - executeFn func(*decision.Decision, *logger.DecisionAction) error - }{ - { - name: "成功更新多头止损", - action: "update_stop_loss", - symbol: "BTCUSDT", - side: "long", - currentPrice: 52000.0, - newPrice: 51000.0, - hasPosition: true, - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { - return s.autoTrader.executeUpdateStopLossWithRecord(d, a) - }, - }, - { - name: "成功更新空头止损", - action: "update_stop_loss", - symbol: "ETHUSDT", - side: "short", - currentPrice: 2900.0, - newPrice: 2950.0, - hasPosition: true, - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { - return s.autoTrader.executeUpdateStopLossWithRecord(d, a) - }, - }, - { - name: "成功更新多头止盈", - action: "update_take_profit", - symbol: "BTCUSDT", - side: "long", - currentPrice: 52000.0, - newPrice: 55000.0, - hasPosition: true, - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { - return s.autoTrader.executeUpdateTakeProfitWithRecord(d, a) - }, - }, - { - name: "成功更新空头止盈", - action: "update_take_profit", - symbol: "ETHUSDT", - side: "short", - currentPrice: 2900.0, - newPrice: 2800.0, - hasPosition: true, - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { - return s.autoTrader.executeUpdateTakeProfitWithRecord(d, a) - }, - }, - { - name: "多头止损价格不合理", - action: "update_stop_loss", - symbol: "BTCUSDT", - side: "long", - currentPrice: 50000.0, - newPrice: 51000.0, - hasPosition: true, - expectedErr: "多单止损必须低于当前价格", - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { - return s.autoTrader.executeUpdateStopLossWithRecord(d, a) - }, - }, - { - name: "多头止盈价格不合理", - action: "update_take_profit", - symbol: "BTCUSDT", - side: "long", - currentPrice: 50000.0, - newPrice: 49000.0, - hasPosition: true, - expectedErr: "多单止盈必须高于当前价格", - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { - return s.autoTrader.executeUpdateTakeProfitWithRecord(d, a) - }, - }, - { - name: "止损_持仓不存在", - action: "update_stop_loss", - symbol: "BTCUSDT", - currentPrice: 50000.0, - newPrice: 49000.0, - hasPosition: false, - expectedErr: "持仓不存在", - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { - return s.autoTrader.executeUpdateStopLossWithRecord(d, a) - }, - }, - { - name: "止盈_持仓不存在", - action: "update_take_profit", - symbol: "BTCUSDT", - currentPrice: 50000.0, - newPrice: 55000.0, - hasPosition: false, - expectedErr: "持仓不存在", - executeFn: func(d *decision.Decision, a *logger.DecisionAction) error { - return s.autoTrader.executeUpdateTakeProfitWithRecord(d, a) - }, - }, - } - - for _, tt := range tests { - time.Sleep(time.Millisecond) - s.Run(tt.name, func() { - // 设置当前测试用例的价格 - testPrice = &tt.currentPrice - - if tt.hasPosition { - s.mockTrader.positions = []map[string]interface{}{ - {"symbol": tt.symbol, "side": tt.side, "positionAmt": 0.1}, - } - } else { - s.mockTrader.positions = []map[string]interface{}{} - } - - decision := &decision.Decision{Action: tt.action, Symbol: tt.symbol} - if tt.action == "update_stop_loss" { - decision.NewStopLoss = tt.newPrice - } else { - decision.NewTakeProfit = tt.newPrice - } - actionRecord := &logger.DecisionAction{Action: tt.action, Symbol: tt.symbol} - - err := tt.executeFn(decision, actionRecord) - - if tt.expectedErr != "" { - s.Error(err) - s.Contains(err.Error(), tt.expectedErr) - } else { - s.NoError(err) - s.Equal(tt.currentPrice, actionRecord.Price) - } - - // 恢复默认状态 - s.mockTrader.positions = []map[string]interface{}{} - }) - } -} - -func (s *AutoTraderTestSuite) TestExecutePartialCloseWithRecord() { - s.Run("成功部分平仓", func() { - // 设置持仓 - s.mockTrader.positions = []map[string]interface{}{ - { - "symbol": "BTCUSDT", - "side": "long", - "positionAmt": 0.1, - "entryPrice": 50000.0, - "markPrice": 52000.0, - }, - } - - // Mock market.Get - s.patches.ApplyFunc(market.Get, func(symbol string) (*market.Data, error) { - return &market.Data{ - Symbol: symbol, - CurrentPrice: 52000.0, - }, nil - }) - - decision := &decision.Decision{ - Action: "partial_close", - Symbol: "BTCUSDT", - ClosePercentage: 50.0, - } - - actionRecord := &logger.DecisionAction{ - Action: "partial_close", - Symbol: "BTCUSDT", - } - - err := s.autoTrader.executePartialCloseWithRecord(decision, actionRecord) - - s.NoError(err) - s.Equal(0.05, actionRecord.Quantity) // 50% of 0.1 - }) - - s.Run("无效的平仓百分比", func() { - decision := &decision.Decision{ - Action: "partial_close", - Symbol: "BTCUSDT", - ClosePercentage: 150.0, // 无效 - } - - actionRecord := &logger.DecisionAction{} - - err := s.autoTrader.executePartialCloseWithRecord(decision, actionRecord) - - s.Error(err) - s.Contains(err.Error(), "平仓百分比必须在 0-100 之间") - }) -} - // ============================================================ // 层次 10: executeDecisionWithRecord 路由测试 // ============================================================ @@ -792,7 +571,7 @@ func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() { PositionSizeUSD: 1000.0, Leverage: 10, } - actionRecord := &logger.DecisionAction{} + actionRecord := &store.DecisionAction{} err := s.autoTrader.executeDecisionWithRecord(decision, actionRecord) s.NoError(err) @@ -803,7 +582,7 @@ func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() { Action: "close_long", Symbol: "BTCUSDT", } - actionRecord := &logger.DecisionAction{} + actionRecord := &store.DecisionAction{} err := s.autoTrader.executeDecisionWithRecord(decision, actionRecord) s.NoError(err) @@ -814,7 +593,7 @@ func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() { Action: "hold", Symbol: "BTCUSDT", } - actionRecord := &logger.DecisionAction{} + actionRecord := &store.DecisionAction{} err := s.autoTrader.executeDecisionWithRecord(decision, actionRecord) s.NoError(err) @@ -825,7 +604,7 @@ func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() { Action: "unknown_action", Symbol: "BTCUSDT", } - actionRecord := &logger.DecisionAction{} + actionRecord := &store.DecisionAction{} err := s.autoTrader.executeDecisionWithRecord(decision, actionRecord) s.Error(err) diff --git a/trader/binance_futures.go b/trader/binance_futures.go index f2489f6b..1d5a256e 100644 --- a/trader/binance_futures.go +++ b/trader/binance_futures.go @@ -5,8 +5,8 @@ import ( "crypto/rand" "encoding/hex" "fmt" - "log" "nofx/hook" + "nofx/logger" "strconv" "strings" "sync" @@ -80,7 +80,7 @@ func NewFuturesTrader(apiKey, secretKey string, userId string) *FuturesTrader { // 设置双向持仓模式(Hedge Mode) // 这是必需的,因为代码中使用了 PositionSide (LONG/SHORT) if err := trader.setDualSidePosition(); err != nil { - log.Printf("⚠️ 设置双向持仓模式失败: %v (如果已是双向模式则忽略此警告)", err) + logger.Infof("⚠️ 设置双向持仓模式失败: %v (如果已是双向模式则忽略此警告)", err) } return trader @@ -96,15 +96,15 @@ func (t *FuturesTrader) setDualSidePosition() error { if err != nil { // 如果错误信息包含"No need to change",说明已经是双向持仓模式 if strings.Contains(err.Error(), "No need to change position side") { - log.Printf(" ✓ 账户已是双向持仓模式(Hedge Mode)") + logger.Infof(" ✓ 账户已是双向持仓模式(Hedge Mode)") return nil } // 其他错误则返回(但在调用方不会中断初始化) return err } - log.Printf(" ✓ 账户已切换为双向持仓模式(Hedge Mode)") - log.Printf(" ℹ️ 双向持仓模式允许同时持有多单和空单") + logger.Infof(" ✓ 账户已切换为双向持仓模式(Hedge Mode)") + logger.Infof(" ℹ️ 双向持仓模式允许同时持有多单和空单") return nil } @@ -112,14 +112,14 @@ func (t *FuturesTrader) setDualSidePosition() error { func syncBinanceServerTime(client *futures.Client) { serverTime, err := client.NewServerTimeService().Do(context.Background()) if err != nil { - log.Printf("⚠️ 同步币安服务器时间失败: %v", err) + logger.Infof("⚠️ 同步币安服务器时间失败: %v", err) return } now := time.Now().UnixMilli() offset := now - serverTime client.TimeOffset = offset - log.Printf("⏱ 已同步币安服务器时间,偏移 %dms", offset) + logger.Infof("⏱ 已同步币安服务器时间,偏移 %dms", offset) } // GetBalance 获取账户余额(带缓存) @@ -129,16 +129,16 @@ func (t *FuturesTrader) GetBalance() (map[string]interface{}, error) { if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration { cacheAge := time.Since(t.balanceCacheTime) t.balanceCacheMutex.RUnlock() - log.Printf("✓ 使用缓存的账户余额(缓存时间: %.1f秒前)", cacheAge.Seconds()) + logger.Infof("✓ 使用缓存的账户余额(缓存时间: %.1f秒前)", cacheAge.Seconds()) return t.cachedBalance, nil } t.balanceCacheMutex.RUnlock() // 缓存过期或不存在,调用API - log.Printf("🔄 缓存过期,正在调用币安API获取账户余额...") + logger.Infof("🔄 缓存过期,正在调用币安API获取账户余额...") account, err := t.client.NewGetAccountService().Do(context.Background()) if err != nil { - log.Printf("❌ 币安API调用失败: %v", err) + logger.Infof("❌ 币安API调用失败: %v", err) return nil, fmt.Errorf("获取账户信息失败: %w", err) } @@ -147,7 +147,7 @@ func (t *FuturesTrader) GetBalance() (map[string]interface{}, error) { result["availableBalance"], _ = strconv.ParseFloat(account.AvailableBalance, 64) result["totalUnrealizedProfit"], _ = strconv.ParseFloat(account.TotalUnrealizedProfit, 64) - log.Printf("✓ 币安API返回: 总余额=%s, 可用=%s, 未实现盈亏=%s", + logger.Infof("✓ 币安API返回: 总余额=%s, 可用=%s, 未实现盈亏=%s", account.TotalWalletBalance, account.AvailableBalance, account.TotalUnrealizedProfit) @@ -168,13 +168,13 @@ func (t *FuturesTrader) GetPositions() ([]map[string]interface{}, error) { if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration { cacheAge := time.Since(t.positionsCacheTime) t.positionsCacheMutex.RUnlock() - log.Printf("✓ 使用缓存的持仓信息(缓存时间: %.1f秒前)", cacheAge.Seconds()) + logger.Infof("✓ 使用缓存的持仓信息(缓存时间: %.1f秒前)", cacheAge.Seconds()) return t.cachedPositions, nil } t.positionsCacheMutex.RUnlock() // 缓存过期或不存在,调用API - log.Printf("🔄 缓存过期,正在调用币安API获取持仓信息...") + logger.Infof("🔄 缓存过期,正在调用币安API获取持仓信息...") positions, err := t.client.NewGetPositionRiskService().Do(context.Background()) if err != nil { return nil, fmt.Errorf("获取持仓失败: %w", err) @@ -238,31 +238,31 @@ func (t *FuturesTrader) SetMarginMode(symbol string, isCrossMargin bool) error { if err != nil { // 如果错误信息包含"No need to change",说明仓位模式已经是目标值 if contains(err.Error(), "No need to change margin type") { - log.Printf(" ✓ %s 仓位模式已是 %s", symbol, marginModeStr) + logger.Infof(" ✓ %s 仓位模式已是 %s", symbol, marginModeStr) return nil } // 如果有持仓,无法更改仓位模式,但不影响交易 if contains(err.Error(), "Margin type cannot be changed if there exists position") { - log.Printf(" ⚠️ %s 有持仓,无法更改仓位模式,继续使用当前模式", symbol) + logger.Infof(" ⚠️ %s 有持仓,无法更改仓位模式,继续使用当前模式", symbol) return nil } // 检测多资产模式(错误码 -4168) if contains(err.Error(), "Multi-Assets mode") || contains(err.Error(), "-4168") || contains(err.Error(), "4168") { - log.Printf(" ⚠️ %s 检测到多资产模式,强制使用全仓模式", symbol) - log.Printf(" 💡 提示:如需使用逐仓模式,请在币安关闭多资产模式") + logger.Infof(" ⚠️ %s 检测到多资产模式,强制使用全仓模式", symbol) + logger.Infof(" 💡 提示:如需使用逐仓模式,请在币安关闭多资产模式") return nil } // 检测统一账户 API(Portfolio Margin) if contains(err.Error(), "unified") || contains(err.Error(), "portfolio") || contains(err.Error(), "Portfolio") { - log.Printf(" ❌ %s 检测到统一账户 API,无法进行合约交易", symbol) + logger.Infof(" ❌ %s 检测到统一账户 API,无法进行合约交易", symbol) return fmt.Errorf("请使用「现货与合约交易」API 权限,不要使用「统一账户 API」") } - log.Printf(" ⚠️ 设置仓位模式失败: %v", err) + logger.Infof(" ⚠️ 设置仓位模式失败: %v", err) // 不返回错误,让交易继续 return nil } - log.Printf(" ✓ %s 仓位模式已设置为 %s", symbol, marginModeStr) + logger.Infof(" ✓ %s 仓位模式已设置为 %s", symbol, marginModeStr) return nil } @@ -284,7 +284,7 @@ func (t *FuturesTrader) SetLeverage(symbol string, leverage int) error { // 如果当前杠杆已经是目标杠杆,跳过 if currentLeverage == leverage && currentLeverage > 0 { - log.Printf(" ✓ %s 杠杆已是 %dx,无需切换", symbol, leverage) + logger.Infof(" ✓ %s 杠杆已是 %dx,无需切换", symbol, leverage) return nil } @@ -297,16 +297,16 @@ func (t *FuturesTrader) SetLeverage(symbol string, leverage int) error { if err != nil { // 如果错误信息包含"No need to change",说明杠杆已经是目标值 if contains(err.Error(), "No need to change") { - log.Printf(" ✓ %s 杠杆已是 %dx", symbol, leverage) + logger.Infof(" ✓ %s 杠杆已是 %dx", symbol, leverage) return nil } return fmt.Errorf("设置杠杆失败: %w", err) } - log.Printf(" ✓ %s 杠杆已切换为 %dx", symbol, leverage) + logger.Infof(" ✓ %s 杠杆已切换为 %dx", symbol, leverage) // 切换杠杆后等待5秒(避免冷却期错误) - log.Printf(" ⏱ 等待5秒冷却期...") + logger.Infof(" ⏱ 等待5秒冷却期...") time.Sleep(5 * time.Second) return nil @@ -316,7 +316,7 @@ func (t *FuturesTrader) SetLeverage(symbol string, leverage int) error { func (t *FuturesTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { // 先取消该币种的所有委托单(清理旧的止损止盈单) if err := t.CancelAllOrders(symbol); err != nil { - log.Printf(" ⚠ 取消旧委托单失败(可能没有委托单): %v", err) + logger.Infof(" ⚠ 取消旧委托单失败(可能没有委托单): %v", err) } // 设置杠杆 @@ -357,8 +357,8 @@ func (t *FuturesTrader) OpenLong(symbol string, quantity float64, leverage int) return nil, fmt.Errorf("开多仓失败: %w", err) } - log.Printf("✓ 开多仓成功: %s 数量: %s", symbol, quantityStr) - log.Printf(" 订单ID: %d", order.OrderID) + logger.Infof("✓ 开多仓成功: %s 数量: %s", symbol, quantityStr) + logger.Infof(" 订单ID: %d", order.OrderID) result := make(map[string]interface{}) result["orderId"] = order.OrderID @@ -371,7 +371,7 @@ func (t *FuturesTrader) OpenLong(symbol string, quantity float64, leverage int) func (t *FuturesTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { // 先取消该币种的所有委托单(清理旧的止损止盈单) if err := t.CancelAllOrders(symbol); err != nil { - log.Printf(" ⚠ 取消旧委托单失败(可能没有委托单): %v", err) + logger.Infof(" ⚠ 取消旧委托单失败(可能没有委托单): %v", err) } // 设置杠杆 @@ -412,8 +412,8 @@ func (t *FuturesTrader) OpenShort(symbol string, quantity float64, leverage int) return nil, fmt.Errorf("开空仓失败: %w", err) } - log.Printf("✓ 开空仓成功: %s 数量: %s", symbol, quantityStr) - log.Printf(" 订单ID: %d", order.OrderID) + logger.Infof("✓ 开空仓成功: %s 数量: %s", symbol, quantityStr) + logger.Infof(" 订单ID: %d", order.OrderID) result := make(map[string]interface{}) result["orderId"] = order.OrderID @@ -463,11 +463,11 @@ func (t *FuturesTrader) CloseLong(symbol string, quantity float64) (map[string]i return nil, fmt.Errorf("平多仓失败: %w", err) } - log.Printf("✓ 平多仓成功: %s 数量: %s", symbol, quantityStr) + logger.Infof("✓ 平多仓成功: %s 数量: %s", symbol, quantityStr) // 平仓后取消该币种的所有挂单(止损止盈单) if err := t.CancelAllOrders(symbol); err != nil { - log.Printf(" ⚠ 取消挂单失败: %v", err) + logger.Infof(" ⚠ 取消挂单失败: %v", err) } result := make(map[string]interface{}) @@ -518,11 +518,11 @@ func (t *FuturesTrader) CloseShort(symbol string, quantity float64) (map[string] return nil, fmt.Errorf("平空仓失败: %w", err) } - log.Printf("✓ 平空仓成功: %s 数量: %s", symbol, quantityStr) + logger.Infof("✓ 平空仓成功: %s 数量: %s", symbol, quantityStr) // 平仓后取消该币种的所有挂单(止损止盈单) if err := t.CancelAllOrders(symbol); err != nil { - log.Printf(" ⚠ 取消挂单失败: %v", err) + logger.Infof(" ⚠ 取消挂单失败: %v", err) } result := make(map[string]interface{}) @@ -559,19 +559,19 @@ func (t *FuturesTrader) CancelStopLossOrders(symbol string) error { if err != nil { errMsg := fmt.Sprintf("订单ID %d: %v", order.OrderID, err) cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) - log.Printf(" ⚠ 取消止损单失败: %s", errMsg) + logger.Infof(" ⚠ 取消止损单失败: %s", errMsg) continue } canceledCount++ - log.Printf(" ✓ 已取消止损单 (订单ID: %d, 类型: %s, 方向: %s)", order.OrderID, orderType, order.PositionSide) + logger.Infof(" ✓ 已取消止损单 (订单ID: %d, 类型: %s, 方向: %s)", order.OrderID, orderType, order.PositionSide) } } if canceledCount == 0 && len(cancelErrors) == 0 { - log.Printf(" ℹ %s 没有止损单需要取消", symbol) + logger.Infof(" ℹ %s 没有止损单需要取消", symbol) } else if canceledCount > 0 { - log.Printf(" ✓ 已取消 %s 的 %d 个止损单", symbol, canceledCount) + logger.Infof(" ✓ 已取消 %s 的 %d 个止损单", symbol, canceledCount) } // 如果所有取消都失败了,返回错误 @@ -609,19 +609,19 @@ func (t *FuturesTrader) CancelTakeProfitOrders(symbol string) error { if err != nil { errMsg := fmt.Sprintf("订单ID %d: %v", order.OrderID, err) cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg)) - log.Printf(" ⚠ 取消止盈单失败: %s", errMsg) + logger.Infof(" ⚠ 取消止盈单失败: %s", errMsg) continue } canceledCount++ - log.Printf(" ✓ 已取消止盈单 (订单ID: %d, 类型: %s, 方向: %s)", order.OrderID, orderType, order.PositionSide) + logger.Infof(" ✓ 已取消止盈单 (订单ID: %d, 类型: %s, 方向: %s)", order.OrderID, orderType, order.PositionSide) } } if canceledCount == 0 && len(cancelErrors) == 0 { - log.Printf(" ℹ %s 没有止盈单需要取消", symbol) + logger.Infof(" ℹ %s 没有止盈单需要取消", symbol) } else if canceledCount > 0 { - log.Printf(" ✓ 已取消 %s 的 %d 个止盈单", symbol, canceledCount) + logger.Infof(" ✓ 已取消 %s 的 %d 个止盈单", symbol, canceledCount) } // 如果所有取消都失败了,返回错误 @@ -642,7 +642,7 @@ func (t *FuturesTrader) CancelAllOrders(symbol string) error { return fmt.Errorf("取消挂单失败: %w", err) } - log.Printf(" ✓ 已取消 %s 的所有挂单", symbol) + logger.Infof(" ✓ 已取消 %s 的所有挂单", symbol) return nil } @@ -674,20 +674,20 @@ func (t *FuturesTrader) CancelStopOrders(symbol string) error { Do(context.Background()) if err != nil { - log.Printf(" ⚠ 取消订单 %d 失败: %v", order.OrderID, err) + logger.Infof(" ⚠ 取消订单 %d 失败: %v", order.OrderID, err) continue } canceledCount++ - log.Printf(" ✓ 已取消 %s 的止盈/止损单 (订单ID: %d, 类型: %s)", + logger.Infof(" ✓ 已取消 %s 的止盈/止损单 (订单ID: %d, 类型: %s)", symbol, order.OrderID, orderType) } } if canceledCount == 0 { - log.Printf(" ℹ %s 没有止盈/止损单需要取消", symbol) + logger.Infof(" ℹ %s 没有止盈/止损单需要取消", symbol) } else { - log.Printf(" ✓ 已取消 %s 的 %d 个止盈/止损单", symbol, canceledCount) + logger.Infof(" ✓ 已取消 %s 的 %d 个止盈/止损单", symbol, canceledCount) } return nil @@ -748,13 +748,14 @@ func (t *FuturesTrader) SetStopLoss(symbol string, positionSide string, quantity Quantity(quantityStr). WorkingType(futures.WorkingTypeContractPrice). ClosePosition(true). + NewClientOrderID(getBrOrderID()). Do(context.Background()) if err != nil { return fmt.Errorf("设置止损失败: %w", err) } - log.Printf(" 止损价设置: %.4f", stopPrice) + logger.Infof(" 止损价设置: %.4f", stopPrice) return nil } @@ -786,13 +787,14 @@ func (t *FuturesTrader) SetTakeProfit(symbol string, positionSide string, quanti Quantity(quantityStr). WorkingType(futures.WorkingTypeContractPrice). ClosePosition(true). + NewClientOrderID(getBrOrderID()). Do(context.Background()) if err != nil { return fmt.Errorf("设置止盈失败: %w", err) } - log.Printf(" 止盈价设置: %.4f", takeProfitPrice) + logger.Infof(" 止盈价设置: %.4f", takeProfitPrice) return nil } @@ -836,14 +838,14 @@ func (t *FuturesTrader) GetSymbolPrecision(symbol string) (int, error) { if filter["filterType"] == "LOT_SIZE" { stepSize := filter["stepSize"].(string) precision := calculatePrecision(stepSize) - log.Printf(" %s 数量精度: %d (stepSize: %s)", symbol, precision, stepSize) + logger.Infof(" %s 数量精度: %d (stepSize: %s)", symbol, precision, stepSize) return precision, nil } } } } - log.Printf(" ⚠ %s 未找到精度信息,使用默认精度3", symbol) + logger.Infof(" ⚠ %s 未找到精度信息,使用默认精度3", symbol) return 3, nil // 默认精度为3 } @@ -915,3 +917,42 @@ func stringContains(s, substr string) bool { } return false } + +// GetOrderStatus 获取订单状态 +func (t *FuturesTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { + // 将 orderID 转换为 int64 + orderIDInt, err := strconv.ParseInt(orderID, 10, 64) + if err != nil { + return nil, fmt.Errorf("无效的订单ID: %s", orderID) + } + + order, err := t.client.NewGetOrderService(). + Symbol(symbol). + OrderID(orderIDInt). + Do(context.Background()) + if err != nil { + return nil, fmt.Errorf("获取订单状态失败: %w", err) + } + + // 解析成交价格 + avgPrice, _ := strconv.ParseFloat(order.AvgPrice, 64) + executedQty, _ := strconv.ParseFloat(order.ExecutedQuantity, 64) + + result := map[string]interface{}{ + "orderId": order.OrderID, + "symbol": order.Symbol, + "status": string(order.Status), + "avgPrice": avgPrice, + "executedQty": executedQty, + "side": string(order.Side), + "type": string(order.Type), + "time": order.Time, + "updateTime": order.UpdateTime, + } + + // 币安合约的手续费需要通过 GetUserTrades 获取,这里暂时不获取 + // 后续可以通过 WebSocket 或单独查询获取 + result["commission"] = 0.0 + + return result, nil +} diff --git a/trader/bybit_trader.go b/trader/bybit_trader.go index 7c055d0b..cdebd65a 100644 --- a/trader/bybit_trader.go +++ b/trader/bybit_trader.go @@ -3,7 +3,7 @@ package trader import ( "context" "fmt" - "log" + "nofx/logger" "net/http" "strconv" "strings" @@ -55,7 +55,7 @@ func NewBybitTrader(apiKey, secretKey string) *BybitTrader { cacheDuration: 15 * time.Second, } - log.Printf("🔵 [Bybit] 交易器已初始化") + logger.Infof("🔵 [Bybit] 交易器已初始化") return trader } @@ -224,7 +224,7 @@ func (t *BybitTrader) GetPositions() ([]map[string]interface{}, error) { func (t *BybitTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { // 先设置杠杆 if err := t.SetLeverage(symbol, leverage); err != nil { - log.Printf("⚠️ [Bybit] 设置杠杆失败: %v", err) + logger.Infof("⚠️ [Bybit] 设置杠杆失败: %v", err) } params := map[string]interface{}{ @@ -251,7 +251,7 @@ func (t *BybitTrader) OpenLong(symbol string, quantity float64, leverage int) (m func (t *BybitTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { // 先设置杠杆 if err := t.SetLeverage(symbol, leverage); err != nil { - log.Printf("⚠️ [Bybit] 设置杠杆失败: %v", err) + logger.Infof("⚠️ [Bybit] 设置杠杆失败: %v", err) } params := map[string]interface{}{ @@ -485,7 +485,7 @@ func (t *BybitTrader) SetStopLoss(symbol string, positionSide string, quantity, return fmt.Errorf("设置止损失败: %s", result.RetMsg) } - log.Printf(" ✓ [Bybit] 止损单已设置: %s @ %.2f", symbol, stopPrice) + logger.Infof(" ✓ [Bybit] 止损单已设置: %s @ %.2f", symbol, stopPrice) return nil } @@ -528,7 +528,7 @@ func (t *BybitTrader) SetTakeProfit(symbol string, positionSide string, quantity return fmt.Errorf("设置止盈失败: %s", result.RetMsg) } - log.Printf(" ✓ [Bybit] 止盈单已设置: %s @ %.2f", symbol, takeProfitPrice) + logger.Infof(" ✓ [Bybit] 止盈单已设置: %s @ %.2f", symbol, takeProfitPrice) return nil } @@ -560,10 +560,10 @@ func (t *BybitTrader) CancelAllOrders(symbol string) error { // CancelStopOrders 取消所有止盈止损单 func (t *BybitTrader) CancelStopOrders(symbol string) error { if err := t.CancelStopLossOrders(symbol); err != nil { - log.Printf("⚠️ [Bybit] 取消止损单失败: %v", err) + logger.Infof("⚠️ [Bybit] 取消止损单失败: %v", err) } if err := t.CancelTakeProfitOrders(symbol); err != nil { - log.Printf("⚠️ [Bybit] 取消止盈单失败: %v", err) + logger.Infof("⚠️ [Bybit] 取消止盈单失败: %v", err) } return nil } @@ -604,6 +604,67 @@ func (t *BybitTrader) parseOrderResult(result *bybit.ServerResponse) (map[string }, nil } +// GetOrderStatus 获取订单状态 +func (t *BybitTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { + params := map[string]interface{}{ + "category": "linear", + "symbol": symbol, + "orderId": orderID, + } + + result, err := t.client.NewUtaBybitServiceWithParams(params).GetOrderHistory(context.Background()) + if err != nil { + return nil, fmt.Errorf("获取订单状态失败: %w", err) + } + + if result.RetCode != 0 { + return nil, fmt.Errorf("API 错误: %s", result.RetMsg) + } + + resultData, ok := result.Result.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("返回格式错误") + } + + list, _ := resultData["list"].([]interface{}) + if len(list) == 0 { + return nil, fmt.Errorf("未找到订单 %s", orderID) + } + + order, _ := list[0].(map[string]interface{}) + + // 解析订单数据 + status, _ := order["orderStatus"].(string) + avgPriceStr, _ := order["avgPrice"].(string) + cumExecQtyStr, _ := order["cumExecQty"].(string) + cumExecFeeStr, _ := order["cumExecFee"].(string) + + avgPrice, _ := strconv.ParseFloat(avgPriceStr, 64) + executedQty, _ := strconv.ParseFloat(cumExecQtyStr, 64) + commission, _ := strconv.ParseFloat(cumExecFeeStr, 64) + + // 转换状态为统一格式 + unifiedStatus := status + switch status { + case "Filled": + unifiedStatus = "FILLED" + case "New", "Created": + unifiedStatus = "NEW" + case "Cancelled", "Rejected": + unifiedStatus = "CANCELED" + case "PartiallyFilled": + unifiedStatus = "PARTIALLY_FILLED" + } + + return map[string]interface{}{ + "orderId": orderID, + "status": unifiedStatus, + "avgPrice": avgPrice, + "executedQty": executedQty, + "commission": commission, + }, nil +} + func (t *BybitTrader) cancelConditionalOrders(symbol string, orderType string) error { // 先获取所有条件单 params := map[string]interface{}{ diff --git a/trader/hyperliquid_trader.go b/trader/hyperliquid_trader.go index 885ce0d8..d075c54d 100644 --- a/trader/hyperliquid_trader.go +++ b/trader/hyperliquid_trader.go @@ -5,7 +5,7 @@ import ( "crypto/ecdsa" "encoding/json" "fmt" - "log" + "nofx/logger" "strconv" "strings" "sync" @@ -56,14 +56,14 @@ func NewHyperliquidTrader(privateKeyHex string, walletAddr string, testnet bool) // Check if user accidentally uses main wallet private key (security risk) if strings.EqualFold(walletAddr, agentAddr) { - log.Printf("⚠️⚠️⚠️ WARNING: Main wallet address (%s) matches Agent wallet address!", walletAddr) - log.Printf(" This indicates you may be using your main wallet private key, which poses extremely high security risks!") - log.Printf(" Recommendation: Immediately create a separate Agent Wallet on Hyperliquid official website") - log.Printf(" Reference: https://hyperliquid.gitbook.io/hyperliquid-docs/for-developers/api/nonces-and-api-wallets") + logger.Infof("⚠️⚠️⚠️ WARNING: Main wallet address (%s) matches Agent wallet address!", walletAddr) + logger.Infof(" This indicates you may be using your main wallet private key, which poses extremely high security risks!") + logger.Infof(" Recommendation: Immediately create a separate Agent Wallet on Hyperliquid official website") + logger.Infof(" Reference: https://hyperliquid.gitbook.io/hyperliquid-docs/for-developers/api/nonces-and-api-wallets") } else { - log.Printf("✓ Using Agent Wallet mode (secure)") - log.Printf(" └─ Agent wallet address: %s (for signing)", agentAddr) - log.Printf(" └─ Main wallet address: %s (holds funds)", walletAddr) + logger.Infof("✓ Using Agent Wallet mode (secure)") + logger.Infof(" └─ Agent wallet address: %s (for signing)", agentAddr) + logger.Infof(" └─ Main wallet address: %s (holds funds)", walletAddr) } ctx := context.Background() @@ -79,7 +79,7 @@ func NewHyperliquidTrader(privateKeyHex string, walletAddr string, testnet bool) nil, // SpotMeta will be fetched automatically ) - log.Printf("✓ Hyperliquid交易器初始化成功 (testnet=%v, wallet=%s)", testnet, walletAddr) + logger.Infof("✓ Hyperliquid交易器初始化成功 (testnet=%v, wallet=%s)", testnet, walletAddr) // 获取meta信息(包含精度等配置) meta, err := exchange.Info().Meta(ctx) @@ -97,26 +97,26 @@ func NewHyperliquidTrader(privateKeyHex string, walletAddr string, testnet bool) if agentBalance > 100 { // Critical: Agent wallet holds too much funds - log.Printf("🚨🚨🚨 CRITICAL SECURITY WARNING 🚨🚨🚨") - log.Printf(" Agent wallet balance: %.2f USDC (exceeds safe threshold of 100 USDC)", agentBalance) - log.Printf(" Agent wallet address: %s", agentAddr) - log.Printf(" ⚠️ Agent wallets should only be used for signing and hold minimal/zero balance") - log.Printf(" ⚠️ High balance in Agent wallet poses security risks") - log.Printf(" 📖 Reference: https://hyperliquid.gitbook.io/hyperliquid-docs/for-developers/api/nonces-and-api-wallets") - log.Printf(" 💡 Recommendation: Transfer funds to main wallet and keep Agent wallet balance near 0") + logger.Infof("🚨🚨🚨 CRITICAL SECURITY WARNING 🚨🚨🚨") + logger.Infof(" Agent wallet balance: %.2f USDC (exceeds safe threshold of 100 USDC)", agentBalance) + logger.Infof(" Agent wallet address: %s", agentAddr) + logger.Infof(" ⚠️ Agent wallets should only be used for signing and hold minimal/zero balance") + logger.Infof(" ⚠️ High balance in Agent wallet poses security risks") + logger.Infof(" 📖 Reference: https://hyperliquid.gitbook.io/hyperliquid-docs/for-developers/api/nonces-and-api-wallets") + logger.Infof(" 💡 Recommendation: Transfer funds to main wallet and keep Agent wallet balance near 0") return nil, fmt.Errorf("security check failed: Agent wallet balance too high (%.2f USDC), exceeds 100 USDC threshold", agentBalance) } else if agentBalance > 10 { // Warning: Agent wallet has some balance (acceptable but not ideal) - log.Printf("⚠️ Notice: Agent wallet address (%s) has some balance: %.2f USDC", agentAddr, agentBalance) - log.Printf(" While not critical, it's recommended to keep Agent wallet balance near 0 for security") + logger.Infof("⚠️ Notice: Agent wallet address (%s) has some balance: %.2f USDC", agentAddr, agentBalance) + logger.Infof(" While not critical, it's recommended to keep Agent wallet balance near 0 for security") } else { // OK: Agent wallet balance is safe - log.Printf("✓ Agent wallet balance is safe: %.2f USDC (near zero as recommended)", agentBalance) + logger.Infof("✓ Agent wallet balance is safe: %.2f USDC (near zero as recommended)", agentBalance) } } else if err != nil { // Failed to query agent balance - log warning but don't block initialization - log.Printf("⚠️ Could not verify Agent wallet balance (query failed): %v", err) - log.Printf(" Proceeding with initialization, but please manually verify Agent wallet balance is near 0") + logger.Infof("⚠️ Could not verify Agent wallet balance (query failed): %v", err) + logger.Infof(" Proceeding with initialization, but please manually verify Agent wallet balance is near 0") } } @@ -131,18 +131,18 @@ func NewHyperliquidTrader(privateKeyHex string, walletAddr string, testnet bool) // GetBalance 获取账户余额 func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) { - log.Printf("🔄 正在调用Hyperliquid API获取账户余额...") + logger.Infof("🔄 正在调用Hyperliquid API获取账户余额...") // ✅ Step 1: 查询 Spot 现货账户余额 spotState, err := t.exchange.Info().SpotUserState(t.ctx, t.walletAddr) var spotUSDCBalance float64 = 0.0 if err != nil { - log.Printf("⚠️ 查询 Spot 余额失败(可能无现货资产): %v", err) + logger.Infof("⚠️ 查询 Spot 余额失败(可能无现货资产): %v", err) } else if spotState != nil && len(spotState.Balances) > 0 { for _, balance := range spotState.Balances { if balance.Coin == "USDC" { spotUSDCBalance, _ = strconv.ParseFloat(balance.Total, 64) - log.Printf("✓ 发现 Spot 现货余额: %.2f USDC", spotUSDCBalance) + logger.Infof("✓ 发现 Spot 现货余额: %.2f USDC", spotUSDCBalance) break } } @@ -151,7 +151,7 @@ func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) { // ✅ Step 2: 查询 Perpetuals 合约账户状态 accountState, err := t.exchange.Info().UserState(t.ctx, t.walletAddr) if err != nil { - log.Printf("❌ Hyperliquid Perpetuals API调用失败: %v", err) + logger.Infof("❌ Hyperliquid Perpetuals API调用失败: %v", err) return nil, fmt.Errorf("获取账户信息失败: %w", err) } @@ -179,8 +179,8 @@ func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) { // 🔍 调试:打印API返回的完整摘要结构 summaryJSON, _ := json.MarshalIndent(summary, " ", " ") - log.Printf("🔍 [DEBUG] Hyperliquid API %s 完整数据:", summaryType) - log.Printf("%s", string(summaryJSON)) + logger.Infof("🔍 [DEBUG] Hyperliquid API %s 完整数据:", summaryType) + logger.Infof("%s", string(summaryJSON)) // ⚠️ 关键修复:从所有持仓中累加真正的未实现盈亏 totalUnrealizedPnl := 0.0 @@ -204,7 +204,7 @@ func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) { withdrawable, err := strconv.ParseFloat(accountState.Withdrawable, 64) if err == nil && withdrawable > 0 { availableBalance = withdrawable - log.Printf("✓ 使用 Withdrawable 作为可用余额: %.2f", availableBalance) + logger.Infof("✓ 使用 Withdrawable 作为可用余额: %.2f", availableBalance) } } @@ -212,7 +212,7 @@ func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) { if availableBalance == 0 && accountState.Withdrawable == "" { availableBalance = accountValue - totalMarginUsed if availableBalance < 0 { - log.Printf("⚠️ 计算出的可用余额为负数 (%.2f),重置为 0", availableBalance) + logger.Infof("⚠️ 计算出的可用余额为负数 (%.2f),重置为 0", availableBalance) availableBalance = 0 } } @@ -227,16 +227,16 @@ func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) { result["totalUnrealizedProfit"] = totalUnrealizedPnl // 未实现盈亏(仅来自 Perpetuals) result["spotBalance"] = spotUSDCBalance // Spot 现货余额(单独返回) - log.Printf("✓ Hyperliquid 完整账户:") - log.Printf(" • Spot 现货余额: %.2f USDC (需手动转账到 Perpetuals 才能开仓)", spotUSDCBalance) - log.Printf(" • Perpetuals 合约净值: %.2f USDC (钱包%.2f + 未实现%.2f)", + logger.Infof("✓ Hyperliquid 完整账户:") + logger.Infof(" • Spot 现货余额: %.2f USDC (需手动转账到 Perpetuals 才能开仓)", spotUSDCBalance) + logger.Infof(" • Perpetuals 合约净值: %.2f USDC (钱包%.2f + 未实现%.2f)", accountValue, walletBalanceWithoutUnrealized, totalUnrealizedPnl) - log.Printf(" • Perpetuals 可用余额: %.2f USDC (可直接用于开仓)", availableBalance) - log.Printf(" • 保证金占用: %.2f USDC", totalMarginUsed) - log.Printf(" • 总资产 (Perp+Spot): %.2f USDC", totalWalletBalance) - log.Printf(" ⭐ 总资产: %.2f USDC | Perp 可用: %.2f USDC | Spot 余额: %.2f USDC", + logger.Infof(" • Perpetuals 可用余额: %.2f USDC (可直接用于开仓)", availableBalance) + logger.Infof(" • 保证金占用: %.2f USDC", totalMarginUsed) + logger.Infof(" • 总资产 (Perp+Spot): %.2f USDC", totalWalletBalance) + logger.Infof(" ⭐ 总资产: %.2f USDC | Perp 可用: %.2f USDC | Spot 余额: %.2f USDC", totalWalletBalance, availableBalance, spotUSDCBalance) return result, nil @@ -316,7 +316,7 @@ func (t *HyperliquidTrader) SetMarginMode(symbol string, isCrossMargin bool) err if !isCrossMargin { marginModeStr = "逐仓" } - log.Printf(" ✓ %s 将使用 %s 模式", symbol, marginModeStr) + logger.Infof(" ✓ %s 将使用 %s 模式", symbol, marginModeStr) return nil } @@ -332,7 +332,7 @@ func (t *HyperliquidTrader) SetLeverage(symbol string, leverage int) error { return fmt.Errorf("设置杠杆失败: %w", err) } - log.Printf(" ✓ %s 杠杆已切换为 %dx", symbol, leverage) + logger.Infof(" ✓ %s 杠杆已切换为 %dx", symbol, leverage) return nil } @@ -343,7 +343,7 @@ func (t *HyperliquidTrader) refreshMetaIfNeeded(coin string) error { return nil // Meta 正常,无需刷新 } - log.Printf("⚠️ %s 的 Asset ID 为 0,尝试刷新 Meta 信息...", coin) + logger.Infof("⚠️ %s 的 Asset ID 为 0,尝试刷新 Meta 信息...", coin) // 刷新 Meta 信息 meta, err := t.exchange.Info().Meta(t.ctx) @@ -356,7 +356,7 @@ func (t *HyperliquidTrader) refreshMetaIfNeeded(coin string) error { t.meta = meta t.metaMutex.Unlock() - log.Printf("✅ Meta 信息已刷新,包含 %d 个资产", len(meta.Universe)) + logger.Infof("✅ Meta 信息已刷新,包含 %d 个资产", len(meta.Universe)) // 验证刷新后的 Asset ID assetID = t.exchange.Info().NameToAsset(coin) @@ -367,7 +367,7 @@ func (t *HyperliquidTrader) refreshMetaIfNeeded(coin string) error { " 3. API 连接问题", coin) } - log.Printf("✅ 刷新后 Asset ID 检查通过: %s -> %d", coin, assetID) + logger.Infof("✅ 刷新后 Asset ID 检查通过: %s -> %d", coin, assetID) return nil } @@ -375,7 +375,7 @@ func (t *HyperliquidTrader) refreshMetaIfNeeded(coin string) error { func (t *HyperliquidTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { // 先取消该币种的所有委托单 if err := t.CancelAllOrders(symbol); err != nil { - log.Printf(" ⚠ 取消旧委托单失败: %v", err) + logger.Infof(" ⚠ 取消旧委托单失败: %v", err) } // 设置杠杆 @@ -394,11 +394,11 @@ func (t *HyperliquidTrader) OpenLong(symbol string, quantity float64, leverage i // ⚠️ 关键:根据币种精度要求,四舍五入数量 roundedQuantity := t.roundToSzDecimals(coin, quantity) - log.Printf(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) + logger.Infof(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) // ⚠️ 关键:价格也需要处理为5位有效数字 aggressivePrice := t.roundPriceToSigfigs(price * 1.01) - log.Printf(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*1.01, aggressivePrice) + logger.Infof(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*1.01, aggressivePrice) // 创建市价买入订单(使用IOC limit order with aggressive price) order := hyperliquid.CreateOrderRequest{ @@ -419,7 +419,7 @@ func (t *HyperliquidTrader) OpenLong(symbol string, quantity float64, leverage i return nil, fmt.Errorf("开多仓失败: %w", err) } - log.Printf("✓ 开多仓成功: %s 数量: %.4f", symbol, roundedQuantity) + logger.Infof("✓ 开多仓成功: %s 数量: %.4f", symbol, roundedQuantity) result := make(map[string]interface{}) result["orderId"] = 0 // Hyperliquid没有返回order ID @@ -433,7 +433,7 @@ func (t *HyperliquidTrader) OpenLong(symbol string, quantity float64, leverage i func (t *HyperliquidTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { // 先取消该币种的所有委托单 if err := t.CancelAllOrders(symbol); err != nil { - log.Printf(" ⚠ 取消旧委托单失败: %v", err) + logger.Infof(" ⚠ 取消旧委托单失败: %v", err) } // 设置杠杆 @@ -452,11 +452,11 @@ func (t *HyperliquidTrader) OpenShort(symbol string, quantity float64, leverage // ⚠️ 关键:根据币种精度要求,四舍五入数量 roundedQuantity := t.roundToSzDecimals(coin, quantity) - log.Printf(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) + logger.Infof(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) // ⚠️ 关键:价格也需要处理为5位有效数字 aggressivePrice := t.roundPriceToSigfigs(price * 0.99) - log.Printf(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*0.99, aggressivePrice) + logger.Infof(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*0.99, aggressivePrice) // 创建市价卖出订单 order := hyperliquid.CreateOrderRequest{ @@ -477,7 +477,7 @@ func (t *HyperliquidTrader) OpenShort(symbol string, quantity float64, leverage return nil, fmt.Errorf("开空仓失败: %w", err) } - log.Printf("✓ 开空仓成功: %s 数量: %.4f", symbol, roundedQuantity) + logger.Infof("✓ 开空仓成功: %s 数量: %.4f", symbol, roundedQuantity) result := make(map[string]interface{}) result["orderId"] = 0 @@ -519,11 +519,11 @@ func (t *HyperliquidTrader) CloseLong(symbol string, quantity float64) (map[stri // ⚠️ 关键:根据币种精度要求,四舍五入数量 roundedQuantity := t.roundToSzDecimals(coin, quantity) - log.Printf(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) + logger.Infof(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) // ⚠️ 关键:价格也需要处理为5位有效数字 aggressivePrice := t.roundPriceToSigfigs(price * 0.99) - log.Printf(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*0.99, aggressivePrice) + logger.Infof(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*0.99, aggressivePrice) // 创建平仓订单(卖出 + ReduceOnly) order := hyperliquid.CreateOrderRequest{ @@ -544,11 +544,11 @@ func (t *HyperliquidTrader) CloseLong(symbol string, quantity float64) (map[stri return nil, fmt.Errorf("平多仓失败: %w", err) } - log.Printf("✓ 平多仓成功: %s 数量: %.4f", symbol, roundedQuantity) + logger.Infof("✓ 平多仓成功: %s 数量: %.4f", symbol, roundedQuantity) // 平仓后取消该币种的所有挂单 if err := t.CancelAllOrders(symbol); err != nil { - log.Printf(" ⚠ 取消挂单失败: %v", err) + logger.Infof(" ⚠ 取消挂单失败: %v", err) } result := make(map[string]interface{}) @@ -591,11 +591,11 @@ func (t *HyperliquidTrader) CloseShort(symbol string, quantity float64) (map[str // ⚠️ 关键:根据币种精度要求,四舍五入数量 roundedQuantity := t.roundToSzDecimals(coin, quantity) - log.Printf(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) + logger.Infof(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin)) // ⚠️ 关键:价格也需要处理为5位有效数字 aggressivePrice := t.roundPriceToSigfigs(price * 1.01) - log.Printf(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*1.01, aggressivePrice) + logger.Infof(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*1.01, aggressivePrice) // 创建平仓订单(买入 + ReduceOnly) order := hyperliquid.CreateOrderRequest{ @@ -616,11 +616,11 @@ func (t *HyperliquidTrader) CloseShort(symbol string, quantity float64) (map[str return nil, fmt.Errorf("平空仓失败: %w", err) } - log.Printf("✓ 平空仓成功: %s 数量: %.4f", symbol, roundedQuantity) + logger.Infof("✓ 平空仓成功: %s 数量: %.4f", symbol, roundedQuantity) // 平仓后取消该币种的所有挂单 if err := t.CancelAllOrders(symbol); err != nil { - log.Printf(" ⚠ 取消挂单失败: %v", err) + logger.Infof(" ⚠ 取消挂单失败: %v", err) } result := make(map[string]interface{}) @@ -637,7 +637,7 @@ func (t *HyperliquidTrader) CloseShort(symbol string, quantity float64) (map[str func (t *HyperliquidTrader) CancelStopLossOrders(symbol string) error { // Hyperliquid SDK 的 OpenOrder 结构不暴露 trigger 字段 // 无法区分止损和止盈单,因此取消该币种的所有挂单 - log.Printf(" ⚠️ Hyperliquid 无法区分止损/止盈单,将取消所有挂单") + logger.Infof(" ⚠️ Hyperliquid 无法区分止损/止盈单,将取消所有挂单") return t.CancelStopOrders(symbol) } @@ -645,7 +645,7 @@ func (t *HyperliquidTrader) CancelStopLossOrders(symbol string) error { func (t *HyperliquidTrader) CancelTakeProfitOrders(symbol string) error { // Hyperliquid SDK 的 OpenOrder 结构不暴露 trigger 字段 // 无法区分止损和止盈单,因此取消该币种的所有挂单 - log.Printf(" ⚠️ Hyperliquid 无法区分止损/止盈单,将取消所有挂单") + logger.Infof(" ⚠️ Hyperliquid 无法区分止损/止盈单,将取消所有挂单") return t.CancelStopOrders(symbol) } @@ -664,12 +664,12 @@ func (t *HyperliquidTrader) CancelAllOrders(symbol string) error { if order.Coin == coin { _, err := t.exchange.Cancel(t.ctx, coin, order.Oid) if err != nil { - log.Printf(" ⚠ 取消订单失败 (oid=%d): %v", order.Oid, err) + logger.Infof(" ⚠ 取消订单失败 (oid=%d): %v", order.Oid, err) } } } - log.Printf(" ✓ 已取消 %s 的所有挂单", symbol) + logger.Infof(" ✓ 已取消 %s 的所有挂单", symbol) return nil } @@ -691,7 +691,7 @@ func (t *HyperliquidTrader) CancelStopOrders(symbol string) error { if order.Coin == coin { _, err := t.exchange.Cancel(t.ctx, coin, order.Oid) if err != nil { - log.Printf(" ⚠ 取消订单失败 (oid=%d): %v", order.Oid, err) + logger.Infof(" ⚠ 取消订单失败 (oid=%d): %v", order.Oid, err) continue } canceledCount++ @@ -699,9 +699,9 @@ func (t *HyperliquidTrader) CancelStopOrders(symbol string) error { } if canceledCount == 0 { - log.Printf(" ℹ %s 没有挂单需要取消", symbol) + logger.Infof(" ℹ %s 没有挂单需要取消", symbol) } else { - log.Printf(" ✓ 已取消 %s 的 %d 个挂单(包括止盈/止损单)", symbol, canceledCount) + logger.Infof(" ✓ 已取消 %s 的 %d 个挂单(包括止盈/止损单)", symbol, canceledCount) } return nil @@ -762,7 +762,7 @@ func (t *HyperliquidTrader) SetStopLoss(symbol string, positionSide string, quan return fmt.Errorf("设置止损失败: %w", err) } - log.Printf(" 止损价设置: %.4f", roundedStopPrice) + logger.Infof(" 止损价设置: %.4f", roundedStopPrice) return nil } @@ -799,7 +799,7 @@ func (t *HyperliquidTrader) SetTakeProfit(symbol string, positionSide string, qu return fmt.Errorf("设置止盈失败: %w", err) } - log.Printf(" 止盈价设置: %.4f", roundedTakeProfitPrice) + logger.Infof(" 止盈价设置: %.4f", roundedTakeProfitPrice) return nil } @@ -820,7 +820,7 @@ func (t *HyperliquidTrader) getSzDecimals(coin string) int { defer t.metaMutex.RUnlock() if t.meta == nil { - log.Printf("⚠️ meta信息为空,使用默认精度4") + logger.Infof("⚠️ meta信息为空,使用默认精度4") return 4 // 默认精度 } @@ -831,7 +831,7 @@ func (t *HyperliquidTrader) getSzDecimals(coin string) int { } } - log.Printf("⚠️ 未找到 %s 的精度信息,使用默认精度4", coin) + logger.Infof("⚠️ 未找到 %s 的精度信息,使用默认精度4", coin) return 4 // 默认精度 } @@ -897,6 +897,53 @@ func convertSymbolToHyperliquid(symbol string) string { return symbol } +// GetOrderStatus 获取订单状态 +// Hyperliquid 使用 IOC 订单,通常立即成交或取消 +// 对于已完成的订单,需要查询历史记录 +func (t *HyperliquidTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { + // Hyperliquid 的 IOC 订单几乎立即完成 + // 如果订单是通过本系统下单的,返回的 status 都是 FILLED + // 这里尝试查询开放订单来判断是否还在等待 + coin := convertSymbolToHyperliquid(symbol) + + // 首先检查是否在开放订单中 + openOrders, err := t.exchange.Info().OpenOrders(t.ctx, t.walletAddr) + if err != nil { + // 如果查询失败,假设订单已完成 + return map[string]interface{}{ + "orderId": orderID, + "status": "FILLED", + "avgPrice": 0.0, + "executedQty": 0.0, + "commission": 0.0, + }, nil + } + + // 检查订单是否在开放订单列表中 + for _, order := range openOrders { + if order.Coin == coin && fmt.Sprintf("%d", order.Oid) == orderID { + // 订单仍在等待 + return map[string]interface{}{ + "orderId": orderID, + "status": "NEW", + "avgPrice": 0.0, + "executedQty": 0.0, + "commission": 0.0, + }, nil + } + } + + // 订单不在开放列表中,说明已完成或已取消 + // Hyperliquid IOC 订单如果不在开放列表中,通常是已成交 + return map[string]interface{}{ + "orderId": orderID, + "status": "FILLED", + "avgPrice": 0.0, // Hyperliquid 不直接返回成交价格,需要从持仓信息获取 + "executedQty": 0.0, + "commission": 0.0, + }, nil +} + // absFloat 返回浮点数的绝对值 func absFloat(x float64) float64 { if x < 0 { diff --git a/trader/interface.go b/trader/interface.go index 3d3a6e90..b1fa555f 100644 --- a/trader/interface.go +++ b/trader/interface.go @@ -50,4 +50,8 @@ type Trader interface { // FormatQuantity 格式化数量到正确的精度 FormatQuantity(symbol string, quantity float64) (string, error) + + // GetOrderStatus 获取订单状态 + // 返回: status(FILLED/NEW/CANCELED), avgPrice, executedQty, commission + GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) } diff --git a/trader/lighter_orders.go b/trader/lighter_orders.go index d16604a4..f95c67eb 100644 --- a/trader/lighter_orders.go +++ b/trader/lighter_orders.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" "io" - "log" + "nofx/logger" "net/http" ) @@ -62,7 +62,7 @@ func (t *LighterTrader) CreateOrder(symbol, side string, quantity, price float64 return "", err } - log.Printf("✓ LIGHTER订单已创建 - ID: %s, Symbol: %s, Side: %s, Qty: %.4f", + logger.Infof("✓ LIGHTER订单已创建 - ID: %s, Symbol: %s, Side: %s, Qty: %.4f", orderResp.OrderID, symbol, side, quantity) return orderResp.OrderID, nil @@ -143,7 +143,7 @@ func (t *LighterTrader) CancelOrder(symbol, orderID string) error { return fmt.Errorf("取消订单失败 (status %d): %s", resp.StatusCode, string(body)) } - log.Printf("✓ LIGHTER订单已取消 - ID: %s", orderID) + logger.Infof("✓ LIGHTER订单已取消 - ID: %s", orderID) return nil } @@ -160,18 +160,18 @@ func (t *LighterTrader) CancelAllOrders(symbol string) error { } if len(orders) == 0 { - log.Printf("✓ LIGHTER - 无需取消订单(无活跃订单)") + logger.Infof("✓ LIGHTER - 无需取消订单(无活跃订单)") return nil } // 批量取消 for _, order := range orders { if err := t.CancelOrder(symbol, order.OrderID); err != nil { - log.Printf("⚠️ 取消订单失败 (ID: %s): %v", order.OrderID, err) + logger.Infof("⚠️ 取消订单失败 (ID: %s): %v", order.OrderID, err) } } - log.Printf("✓ LIGHTER - 已取消 %d 个订单", len(orders)) + logger.Infof("✓ LIGHTER - 已取消 %d 个订单", len(orders)) return nil } @@ -223,8 +223,8 @@ func (t *LighterTrader) GetActiveOrders(symbol string) ([]OrderResponse, error) return orders, nil } -// GetOrderStatus 获取订单状态 -func (t *LighterTrader) GetOrderStatus(orderID string) (*OrderResponse, error) { +// GetOrderStatus 获取订单状态(实现 Trader 接口) +func (t *LighterTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { if err := t.ensureAuthToken(); err != nil { return nil, fmt.Errorf("认证令牌无效: %w", err) } @@ -261,20 +261,37 @@ func (t *LighterTrader) GetOrderStatus(orderID string) (*OrderResponse, error) { return nil, fmt.Errorf("解析订单响应失败: %w", err) } - return &order, nil + // 转换状态为统一格式 + unifiedStatus := order.Status + switch order.Status { + case "filled": + unifiedStatus = "FILLED" + case "open": + unifiedStatus = "NEW" + case "cancelled": + unifiedStatus = "CANCELED" + } + + return map[string]interface{}{ + "orderId": order.OrderID, + "status": unifiedStatus, + "avgPrice": order.Price, + "executedQty": order.FilledQty, + "commission": 0.0, + }, nil } // CancelStopLossOrders 仅取消止损单(LIGHTER 暂无法区分,取消所有止盈止损单) func (t *LighterTrader) CancelStopLossOrders(symbol string) error { // LIGHTER 暂时无法区分止损和止盈单,取消所有止盈止损单 - log.Printf(" ⚠️ LIGHTER 无法区分止损/止盈单,将取消所有止盈止损单") + logger.Infof(" ⚠️ LIGHTER 无法区分止损/止盈单,将取消所有止盈止损单") return t.CancelStopOrders(symbol) } // CancelTakeProfitOrders 仅取消止盈单(LIGHTER 暂无法区分,取消所有止盈止损单) func (t *LighterTrader) CancelTakeProfitOrders(symbol string) error { // LIGHTER 暂时无法区分止损和止盈单,取消所有止盈止损单 - log.Printf(" ⚠️ LIGHTER 无法区分止损/止盈单,将取消所有止盈止损单") + logger.Infof(" ⚠️ LIGHTER 无法区分止损/止盈单,将取消所有止盈止损单") return t.CancelStopOrders(symbol) } @@ -295,12 +312,12 @@ func (t *LighterTrader) CancelStopOrders(symbol string) error { // TODO: 需要检查订单类型,只取消止盈止损单 // 暂时取消所有订单 if err := t.CancelOrder(symbol, order.OrderID); err != nil { - log.Printf("⚠️ 取消订单失败 (ID: %s): %v", order.OrderID, err) + logger.Infof("⚠️ 取消订单失败 (ID: %s): %v", order.OrderID, err) } else { canceledCount++ } } - log.Printf("✓ LIGHTER - 已取消 %d 个止盈止损单", canceledCount) + logger.Infof("✓ LIGHTER - 已取消 %d 个止盈止损单", canceledCount) return nil } diff --git a/trader/lighter_trader.go b/trader/lighter_trader.go index 66c427a1..7280550d 100644 --- a/trader/lighter_trader.go +++ b/trader/lighter_trader.go @@ -7,7 +7,7 @@ import ( "encoding/json" "fmt" "io" - "log" + "nofx/logger" "net/http" "strings" "sync" @@ -59,7 +59,7 @@ func NewLighterTrader(privateKeyHex string, walletAddr string, testnet bool) (*L // 从私钥派生钱包地址(如果未提供) if walletAddr == "" { walletAddr = crypto.PubkeyToAddress(*privateKey.Public().(*ecdsa.PublicKey)).Hex() - log.Printf("✓ 从私钥派生钱包地址: %s", walletAddr) + logger.Infof("✓ 从私钥派生钱包地址: %s", walletAddr) } // 选择API URL @@ -78,7 +78,7 @@ func NewLighterTrader(privateKeyHex string, walletAddr string, testnet bool) (*L symbolPrecision: make(map[string]SymbolPrecision), } - log.Printf("✓ LIGHTER交易器初始化成功 (testnet=%v, wallet=%s)", testnet, walletAddr) + logger.Infof("✓ LIGHTER交易器初始化成功 (testnet=%v, wallet=%s)", testnet, walletAddr) // 初始化账户信息(获取账户索引和API密钥) if err := trader.initializeAccount(); err != nil { @@ -100,7 +100,7 @@ func (t *LighterTrader) initializeAccount() error { t.accountIndex = accountInfo["index"].(int) t.accountMutex.Unlock() - log.Printf("✓ LIGHTER账户索引: %d", t.accountIndex) + logger.Infof("✓ LIGHTER账户索引: %d", t.accountIndex) // 2. 生成认证令牌(有效期8小时) if err := t.refreshAuthToken(); err != nil { @@ -153,7 +153,7 @@ func (t *LighterTrader) refreshAuthToken() error { // 临时实现:设置过期时间为8小时后 t.tokenExpiry = time.Now().Add(8 * time.Hour) - log.Printf("✓ 认证令牌已生成(有效期至: %s)", t.tokenExpiry.Format(time.RFC3339)) + logger.Infof("✓ 认证令牌已生成(有效期至: %s)", t.tokenExpiry.Format(time.RFC3339)) return nil } @@ -165,7 +165,7 @@ func (t *LighterTrader) ensureAuthToken() error { t.accountMutex.RUnlock() if expired { - log.Println("🔄 认证令牌即将过期,刷新中...") + logger.Info("🔄 认证令牌即将过期,刷新中...") return t.refreshAuthToken() } @@ -204,12 +204,12 @@ func (t *LighterTrader) GetExchangeType() string { // Close 关闭交易器 func (t *LighterTrader) Close() error { - log.Println("✓ LIGHTER交易器已关闭") + logger.Info("✓ LIGHTER交易器已关闭") return nil } // Run 运行交易器(实现Trader接口) func (t *LighterTrader) Run() error { - log.Println("⚠️ LIGHTER交易器的Run方法应由AutoTrader调用") + logger.Info("⚠️ LIGHTER交易器的Run方法应由AutoTrader调用") return fmt.Errorf("请使用AutoTrader管理交易器生命周期") } diff --git a/trader/lighter_trader_v2.go b/trader/lighter_trader_v2.go index f6510a40..673c3741 100644 --- a/trader/lighter_trader_v2.go +++ b/trader/lighter_trader_v2.go @@ -6,7 +6,7 @@ import ( "encoding/json" "fmt" "io" - "log" + "nofx/logger" "net/http" "strings" "sync" @@ -76,7 +76,7 @@ func NewLighterTraderV2(l1PrivateKeyHex, walletAddr, apiKeyPrivateKeyHex string, // 2. 如果沒有提供錢包地址,從私鑰派生 if walletAddr == "" { walletAddr = crypto.PubkeyToAddress(*l1PrivateKey.Public().(*ecdsa.PublicKey)).Hex() - log.Printf("✓ 從私鑰派生錢包地址: %s", walletAddr) + logger.Infof("✓ 從私鑰派生錢包地址: %s", walletAddr) } // 3. 確定 API URL 和 Chain ID @@ -112,8 +112,8 @@ func NewLighterTraderV2(l1PrivateKeyHex, walletAddr, apiKeyPrivateKeyHex string, // 6. 如果沒有 API Key,提示用戶需要生成 if apiKeyPrivateKeyHex == "" { - log.Printf("⚠️ 未提供 API Key 私鑰,請調用 GenerateAndRegisterAPIKey() 生成") - log.Printf(" 或者從 LIGHTER 官網獲取現有的 API Key") + logger.Infof("⚠️ 未提供 API Key 私鑰,請調用 GenerateAndRegisterAPIKey() 生成") + logger.Infof(" 或者從 LIGHTER 官網獲取現有的 API Key") return trader, nil } @@ -133,12 +133,12 @@ func NewLighterTraderV2(l1PrivateKeyHex, walletAddr, apiKeyPrivateKeyHex string, // 8. 驗證 API Key 是否正確 if err := trader.checkClient(); err != nil { - log.Printf("⚠️ API Key 驗證失敗: %v", err) - log.Printf(" 您可能需要重新生成 API Key 或檢查配置") + logger.Infof("⚠️ API Key 驗證失敗: %v", err) + logger.Infof(" 您可能需要重新生成 API Key 或檢查配置") return trader, err } - log.Printf("✓ LIGHTER 交易器初始化成功 (account=%d, apiKey=%d, testnet=%v)", + logger.Infof("✓ LIGHTER 交易器初始化成功 (account=%d, apiKey=%d, testnet=%v)", trader.accountIndex, trader.apiKeyIndex, testnet) return trader, nil @@ -156,7 +156,7 @@ func (t *LighterTraderV2) initializeAccount() error { t.accountIndex = accountInfo.AccountIndex t.accountMutex.Unlock() - log.Printf("✓ 賬戶索引: %d", t.accountIndex) + logger.Infof("✓ 賬戶索引: %d", t.accountIndex) return nil } @@ -214,7 +214,7 @@ func (t *LighterTraderV2) checkClient() error { return fmt.Errorf("API Key 不匹配:本地=%s, 服務器=%s", localPubKey, publicKey) } - log.Printf("✓ API Key 驗證通過") + logger.Infof("✓ API Key 驗證通過") return nil } @@ -249,7 +249,7 @@ func (t *LighterTraderV2) refreshAuthToken() error { t.tokenExpiry = deadline t.accountMutex.Unlock() - log.Printf("✓ 認證令牌已生成(有效期至: %s)", t.tokenExpiry.Format(time.RFC3339)) + logger.Infof("✓ 認證令牌已生成(有效期至: %s)", t.tokenExpiry.Format(time.RFC3339)) return nil } @@ -260,7 +260,7 @@ func (t *LighterTraderV2) ensureAuthToken() error { t.accountMutex.RUnlock() if expired { - log.Println("🔄 認證令牌即將過期,刷新中...") + logger.Info("🔄 認證令牌即將過期,刷新中...") return t.refreshAuthToken() } @@ -274,6 +274,6 @@ func (t *LighterTraderV2) GetExchangeType() string { // Cleanup 清理資源 func (t *LighterTraderV2) Cleanup() error { - log.Println("⏹ LIGHTER 交易器清理完成") + logger.Info("⏹ LIGHTER 交易器清理完成") return nil } diff --git a/trader/lighter_trader_v2_orders.go b/trader/lighter_trader_v2_orders.go index 1c207826..8ddf2687 100644 --- a/trader/lighter_trader_v2_orders.go +++ b/trader/lighter_trader_v2_orders.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" "io" - "log" + "nofx/logger" "net/http" "strconv" @@ -18,7 +18,7 @@ func (t *LighterTraderV2) SetStopLoss(symbol string, positionSide string, quanti return fmt.Errorf("TxClient 未初始化") } - log.Printf("🛑 LIGHTER 設置止損: %s %s qty=%.4f, stop=%.2f", symbol, positionSide, quantity, stopPrice) + logger.Infof("🛑 LIGHTER 設置止損: %s %s qty=%.4f, stop=%.2f", symbol, positionSide, quantity, stopPrice) // 確定訂單方向(做空止損用買單,做多止損用賣單) isAsk := (positionSide == "LONG" || positionSide == "long") @@ -29,7 +29,7 @@ func (t *LighterTraderV2) SetStopLoss(symbol string, positionSide string, quanti return fmt.Errorf("設置止損失敗: %w", err) } - log.Printf("✓ LIGHTER 止損已設置: %.2f", stopPrice) + logger.Infof("✓ LIGHTER 止損已設置: %.2f", stopPrice) return nil } @@ -39,7 +39,7 @@ func (t *LighterTraderV2) SetTakeProfit(symbol string, positionSide string, quan return fmt.Errorf("TxClient 未初始化") } - log.Printf("🎯 LIGHTER 設置止盈: %s %s qty=%.4f, tp=%.2f", symbol, positionSide, quantity, takeProfitPrice) + logger.Infof("🎯 LIGHTER 設置止盈: %s %s qty=%.4f, tp=%.2f", symbol, positionSide, quantity, takeProfitPrice) // 確定訂單方向(做空止盈用買單,做多止盈用賣單) isAsk := (positionSide == "LONG" || positionSide == "long") @@ -50,7 +50,7 @@ func (t *LighterTraderV2) SetTakeProfit(symbol string, positionSide string, quan return fmt.Errorf("設置止盈失敗: %w", err) } - log.Printf("✓ LIGHTER 止盈已設置: %.2f", takeProfitPrice) + logger.Infof("✓ LIGHTER 止盈已設置: %.2f", takeProfitPrice) return nil } @@ -71,7 +71,7 @@ func (t *LighterTraderV2) CancelAllOrders(symbol string) error { } if len(orders) == 0 { - log.Printf("✓ LIGHTER - 無需取消訂單(無活躍訂單)") + logger.Infof("✓ LIGHTER - 無需取消訂單(無活躍訂單)") return nil } @@ -79,27 +79,101 @@ func (t *LighterTraderV2) CancelAllOrders(symbol string) error { canceledCount := 0 for _, order := range orders { if err := t.CancelOrder(symbol, order.OrderID); err != nil { - log.Printf("⚠️ 取消訂單失敗 (ID: %s): %v", order.OrderID, err) + logger.Infof("⚠️ 取消訂單失敗 (ID: %s): %v", order.OrderID, err) } else { canceledCount++ } } - log.Printf("✓ LIGHTER - 已取消 %d 個訂單", canceledCount) + logger.Infof("✓ LIGHTER - 已取消 %d 個訂單", canceledCount) return nil } +// GetOrderStatus 獲取訂單狀態(實現 Trader 接口) +func (t *LighterTraderV2) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) { + // LIGHTER 使用市價單通常立即成交 + // 嘗試查詢訂單狀態 + if err := t.ensureAuthToken(); err != nil { + return nil, fmt.Errorf("認證令牌無效: %w", err) + } + + // 構建請求 URL + endpoint := fmt.Sprintf("%s/api/v1/order/%s", t.baseURL, orderID) + + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", t.authToken) + req.Header.Set("Content-Type", "application/json") + + resp, err := t.client.Do(req) + if err != nil { + // 如果查詢失敗,假設訂單已完成 + return map[string]interface{}{ + "orderId": orderID, + "status": "FILLED", + "avgPrice": 0.0, + "executedQty": 0.0, + "commission": 0.0, + }, nil + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return map[string]interface{}{ + "orderId": orderID, + "status": "FILLED", + "avgPrice": 0.0, + "executedQty": 0.0, + "commission": 0.0, + }, nil + } + + var order OrderResponse + if err := json.Unmarshal(body, &order); err != nil { + return map[string]interface{}{ + "orderId": orderID, + "status": "FILLED", + "avgPrice": 0.0, + "executedQty": 0.0, + "commission": 0.0, + }, nil + } + + // 轉換狀態為統一格式 + unifiedStatus := order.Status + switch order.Status { + case "filled": + unifiedStatus = "FILLED" + case "open": + unifiedStatus = "NEW" + case "cancelled": + unifiedStatus = "CANCELED" + } + + return map[string]interface{}{ + "orderId": order.OrderID, + "status": unifiedStatus, + "avgPrice": order.Price, + "executedQty": order.FilledQty, + "commission": 0.0, + }, nil +} + // CancelStopLossOrders 僅取消止損單(實現 Trader 接口) func (t *LighterTraderV2) CancelStopLossOrders(symbol string) error { // LIGHTER 暫時無法區分止損和止盈單,取消所有止盈止損單 - log.Printf("⚠️ LIGHTER 無法區分止損/止盈單,將取消所有止盈止損單") + logger.Infof("⚠️ LIGHTER 無法區分止損/止盈單,將取消所有止盈止損單") return t.CancelStopOrders(symbol) } // CancelTakeProfitOrders 僅取消止盈單(實現 Trader 接口) func (t *LighterTraderV2) CancelTakeProfitOrders(symbol string) error { // LIGHTER 暫時無法區分止損和止盈單,取消所有止盈止損單 - log.Printf("⚠️ LIGHTER 無法區分止損/止盈單,將取消所有止盈止損單") + logger.Infof("⚠️ LIGHTER 無法區分止損/止盈單,將取消所有止盈止損單") return t.CancelStopOrders(symbol) } @@ -124,13 +198,13 @@ func (t *LighterTraderV2) CancelStopOrders(symbol string) error { // TODO: 檢查訂單類型,只取消止盈止損單 // 暫時取消所有訂單 if err := t.CancelOrder(symbol, order.OrderID); err != nil { - log.Printf("⚠️ 取消訂單失敗 (ID: %s): %v", order.OrderID, err) + logger.Infof("⚠️ 取消訂單失敗 (ID: %s): %v", order.OrderID, err) } else { canceledCount++ } } - log.Printf("✓ LIGHTER - 已取消 %d 個止盈止損單", canceledCount) + logger.Infof("✓ LIGHTER - 已取消 %d 個止盈止損單", canceledCount) return nil } @@ -186,7 +260,7 @@ func (t *LighterTraderV2) GetActiveOrders(symbol string) ([]OrderResponse, error return nil, fmt.Errorf("獲取活躍訂單失敗 (code %d): %s", apiResp.Code, apiResp.Message) } - log.Printf("✓ LIGHTER - 獲取到 %d 個活躍訂單", len(apiResp.Data)) + logger.Infof("✓ LIGHTER - 獲取到 %d 個活躍訂單", len(apiResp.Data)) return apiResp.Data, nil } @@ -235,7 +309,7 @@ func (t *LighterTraderV2) CancelOrder(symbol, orderID string) error { return fmt.Errorf("提交取消訂單失敗: %w", err) } - log.Printf("✓ LIGHTER訂單已取消 - ID: %s", orderID) + logger.Infof("✓ LIGHTER訂單已取消 - ID: %s", orderID) return nil } @@ -291,6 +365,6 @@ func (t *LighterTraderV2) submitCancelOrder(signedTx []byte) (map[string]interfa "status": "cancelled", } - log.Printf("✓ 取消訂單已提交到 LIGHTER - tx_hash: %v", sendResp.Data["tx_hash"]) + logger.Infof("✓ 取消訂單已提交到 LIGHTER - tx_hash: %v", sendResp.Data["tx_hash"]) return result, nil } diff --git a/trader/lighter_trader_v2_trading.go b/trader/lighter_trader_v2_trading.go index 36b13f55..00fe61c2 100644 --- a/trader/lighter_trader_v2_trading.go +++ b/trader/lighter_trader_v2_trading.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" "io" - "log" + "nofx/logger" "net/http" "time" @@ -18,11 +18,11 @@ func (t *LighterTraderV2) OpenLong(symbol string, quantity float64, leverage int return nil, fmt.Errorf("TxClient 未初始化,請先設置 API Key") } - log.Printf("📈 LIGHTER 開多倉: %s, qty=%.4f, leverage=%dx", symbol, quantity, leverage) + logger.Infof("📈 LIGHTER 開多倉: %s, qty=%.4f, leverage=%dx", symbol, quantity, leverage) // 1. 設置杠杆(如果需要) if err := t.SetLeverage(symbol, leverage); err != nil { - log.Printf("⚠️ 設置杠杆失敗: %v", err) + logger.Infof("⚠️ 設置杠杆失敗: %v", err) } // 2. 獲取市場價格 @@ -37,7 +37,7 @@ func (t *LighterTraderV2) OpenLong(symbol string, quantity float64, leverage int return nil, fmt.Errorf("開多倉失敗: %w", err) } - log.Printf("✓ LIGHTER 開多倉成功: %s @ %.2f", symbol, marketPrice) + logger.Infof("✓ LIGHTER 開多倉成功: %s @ %.2f", symbol, marketPrice) return map[string]interface{}{ "orderId": orderResult["orderId"], @@ -54,11 +54,11 @@ func (t *LighterTraderV2) OpenShort(symbol string, quantity float64, leverage in return nil, fmt.Errorf("TxClient 未初始化,請先設置 API Key") } - log.Printf("📉 LIGHTER 開空倉: %s, qty=%.4f, leverage=%dx", symbol, quantity, leverage) + logger.Infof("📉 LIGHTER 開空倉: %s, qty=%.4f, leverage=%dx", symbol, quantity, leverage) // 1. 設置杠杆 if err := t.SetLeverage(symbol, leverage); err != nil { - log.Printf("⚠️ 設置杠杆失敗: %v", err) + logger.Infof("⚠️ 設置杠杆失敗: %v", err) } // 2. 獲取市場價格 @@ -73,7 +73,7 @@ func (t *LighterTraderV2) OpenShort(symbol string, quantity float64, leverage in return nil, fmt.Errorf("開空倉失敗: %w", err) } - log.Printf("✓ LIGHTER 開空倉成功: %s @ %.2f", symbol, marketPrice) + logger.Infof("✓ LIGHTER 開空倉成功: %s @ %.2f", symbol, marketPrice) return map[string]interface{}{ "orderId": orderResult["orderId"], @@ -105,7 +105,7 @@ func (t *LighterTraderV2) CloseLong(symbol string, quantity float64) (map[string quantity = pos.Size } - log.Printf("🔻 LIGHTER 平多倉: %s, qty=%.4f", symbol, quantity) + logger.Infof("🔻 LIGHTER 平多倉: %s, qty=%.4f", symbol, quantity) // 創建市價賣出單平倉(reduceOnly=true) orderResult, err := t.CreateOrder(symbol, true, quantity, 0, "market") @@ -115,10 +115,10 @@ func (t *LighterTraderV2) CloseLong(symbol string, quantity float64) (map[string // 平倉後取消所有掛單 if err := t.CancelAllOrders(symbol); err != nil { - log.Printf("⚠️ 取消掛單失敗: %v", err) + logger.Infof("⚠️ 取消掛單失敗: %v", err) } - log.Printf("✓ LIGHTER 平多倉成功: %s", symbol) + logger.Infof("✓ LIGHTER 平多倉成功: %s", symbol) return map[string]interface{}{ "orderId": orderResult["orderId"], @@ -148,7 +148,7 @@ func (t *LighterTraderV2) CloseShort(symbol string, quantity float64) (map[strin quantity = pos.Size } - log.Printf("🔺 LIGHTER 平空倉: %s, qty=%.4f", symbol, quantity) + logger.Infof("🔺 LIGHTER 平空倉: %s, qty=%.4f", symbol, quantity) // 創建市價買入單平倉(reduceOnly=true) orderResult, err := t.CreateOrder(symbol, false, quantity, 0, "market") @@ -158,10 +158,10 @@ func (t *LighterTraderV2) CloseShort(symbol string, quantity float64) (map[strin // 平倉後取消所有掛單 if err := t.CancelAllOrders(symbol); err != nil { - log.Printf("⚠️ 取消掛單失敗: %v", err) + logger.Infof("⚠️ 取消掛單失敗: %v", err) } - log.Printf("✓ LIGHTER 平空倉成功: %s", symbol) + logger.Infof("✓ LIGHTER 平空倉成功: %s", symbol) return map[string]interface{}{ "orderId": orderResult["orderId"], @@ -235,7 +235,7 @@ func (t *LighterTraderV2) CreateOrder(symbol string, isAsk bool, quantity float6 if isAsk { side = "sell" } - log.Printf("✓ LIGHTER訂單已創建: %s %s qty=%.4f", symbol, side, quantity) + logger.Infof("✓ LIGHTER訂單已創建: %s %s qty=%.4f", symbol, side, quantity) return orderResp, nil } @@ -315,7 +315,7 @@ func (t *LighterTraderV2) submitOrder(signedTx []byte) (map[string]interface{}, result["orderId"] = txHash } - log.Printf("✓ 訂單已提交到 LIGHTER - tx_hash: %v", sendResp.Data["tx_hash"]) + logger.Infof("✓ 訂單已提交到 LIGHTER - tx_hash: %v", sendResp.Data["tx_hash"]) return result, nil } @@ -334,7 +334,7 @@ func (t *LighterTraderV2) getMarketIndex(symbol string) (uint8, error) { markets, err := t.fetchMarketList() if err != nil { // 如果 API 失敗,回退到硬編碼映射 - log.Printf("⚠️ 從 API 獲取市場列表失敗,使用硬編碼映射: %v", err) + logger.Infof("⚠️ 從 API 獲取市場列表失敗,使用硬編碼映射: %v", err) return t.getFallbackMarketIndex(symbol) } @@ -412,7 +412,7 @@ func (t *LighterTraderV2) fetchMarketList() ([]MarketInfo, error) { } } - log.Printf("✓ 獲取到 %d 個市場", len(markets)) + logger.Infof("✓ 獲取到 %d 個市場", len(markets)) return markets, nil } @@ -428,7 +428,7 @@ func (t *LighterTraderV2) getFallbackMarketIndex(symbol string) (uint8, error) { } if index, ok := fallbackMap[symbol]; ok { - log.Printf("✓ 使用硬編碼市場索引: %s -> %d", symbol, index) + logger.Infof("✓ 使用硬編碼市場索引: %s -> %d", symbol, index) return index, nil } @@ -442,7 +442,7 @@ func (t *LighterTraderV2) SetLeverage(symbol string, leverage int) error { } // TODO: 使用SDK簽名並提交SetLeverage交易 - log.Printf("⚙️ 設置杠杆: %s = %dx", symbol, leverage) + logger.Infof("⚙️ 設置杠杆: %s = %dx", symbol, leverage) return nil // 暫時返回成功 } @@ -458,7 +458,7 @@ func (t *LighterTraderV2) SetMarginMode(symbol string, isCrossMargin bool) error modeStr = "全倉" } - log.Printf("⚙️ 設置倉位模式: %s = %s", symbol, modeStr) + logger.Infof("⚙️ 設置倉位模式: %s = %s", symbol, modeStr) // TODO: 使用SDK簽名並提交SetMarginMode交易 return nil diff --git a/trader/lighter_trading.go b/trader/lighter_trading.go index 26fab466..ee1a21cf 100644 --- a/trader/lighter_trading.go +++ b/trader/lighter_trading.go @@ -2,13 +2,13 @@ package trader import ( "fmt" - "log" + "nofx/logger" ) // OpenLong 开多仓 func (t *LighterTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { // TODO: 实现完整的开多仓逻辑 - log.Printf("🚧 LIGHTER OpenLong 暂未完全实现 (symbol=%s, qty=%.4f, leverage=%d)", symbol, quantity, leverage) + logger.Infof("🚧 LIGHTER OpenLong 暂未完全实现 (symbol=%s, qty=%.4f, leverage=%d)", symbol, quantity, leverage) // 使用市价买入单 orderID, err := t.CreateOrder(symbol, "buy", quantity, 0, "market") @@ -26,7 +26,7 @@ func (t *LighterTrader) OpenLong(symbol string, quantity float64, leverage int) // OpenShort 开空仓 func (t *LighterTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) { // TODO: 实现完整的开空仓逻辑 - log.Printf("🚧 LIGHTER OpenShort 暂未完全实现 (symbol=%s, qty=%.4f, leverage=%d)", symbol, quantity, leverage) + logger.Infof("🚧 LIGHTER OpenShort 暂未完全实现 (symbol=%s, qty=%.4f, leverage=%d)", symbol, quantity, leverage) // 使用市价卖出单 orderID, err := t.CreateOrder(symbol, "sell", quantity, 0, "market") @@ -66,7 +66,7 @@ func (t *LighterTrader) CloseLong(symbol string, quantity float64) (map[string]i // 平仓后取消所有挂单 if err := t.CancelAllOrders(symbol); err != nil { - log.Printf(" ⚠ 取消挂单失败: %v", err) + logger.Infof(" ⚠ 取消挂单失败: %v", err) } return map[string]interface{}{ @@ -101,7 +101,7 @@ func (t *LighterTrader) CloseShort(symbol string, quantity float64) (map[string] // 平仓后取消所有挂单 if err := t.CancelAllOrders(symbol); err != nil { - log.Printf(" ⚠ 取消挂单失败: %v", err) + logger.Infof(" ⚠ 取消挂单失败: %v", err) } return map[string]interface{}{ @@ -114,7 +114,7 @@ func (t *LighterTrader) CloseShort(symbol string, quantity float64) (map[string] // SetStopLoss 设置止损单 func (t *LighterTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error { // TODO: 实现完整的止损单逻辑 - log.Printf("🚧 LIGHTER SetStopLoss 暂未完全实现 (symbol=%s, side=%s, qty=%.4f, stop=%.2f)", symbol, positionSide, quantity, stopPrice) + logger.Infof("🚧 LIGHTER SetStopLoss 暂未完全实现 (symbol=%s, side=%s, qty=%.4f, stop=%.2f)", symbol, positionSide, quantity, stopPrice) // 确定订单方向(做空止损用买单,做多止损用卖单) side := "sell" @@ -128,14 +128,14 @@ func (t *LighterTrader) SetStopLoss(symbol string, positionSide string, quantity return fmt.Errorf("设置止损失败: %w", err) } - log.Printf("✓ LIGHTER - 止损已设置: %.2f (side: %s)", stopPrice, side) + logger.Infof("✓ LIGHTER - 止损已设置: %.2f (side: %s)", stopPrice, side) return nil } // SetTakeProfit 设置止盈单 func (t *LighterTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error { // TODO: 实现完整的止盈单逻辑 - log.Printf("🚧 LIGHTER SetTakeProfit 暂未完全实现 (symbol=%s, side=%s, qty=%.4f, tp=%.2f)", symbol, positionSide, quantity, takeProfitPrice) + logger.Infof("🚧 LIGHTER SetTakeProfit 暂未完全实现 (symbol=%s, side=%s, qty=%.4f, tp=%.2f)", symbol, positionSide, quantity, takeProfitPrice) // 确定订单方向(做空止盈用买单,做多止盈用卖单) side := "sell" @@ -149,7 +149,7 @@ func (t *LighterTrader) SetTakeProfit(symbol string, positionSide string, quanti return fmt.Errorf("设置止盈失败: %w", err) } - log.Printf("✓ LIGHTER - 止盈已设置: %.2f (side: %s)", takeProfitPrice, side) + logger.Infof("✓ LIGHTER - 止盈已设置: %.2f (side: %s)", takeProfitPrice, side) return nil } @@ -160,7 +160,7 @@ func (t *LighterTrader) SetMarginMode(symbol string, isCrossMargin bool) error { if isCrossMargin { modeStr = "全仓" } - log.Printf("🚧 LIGHTER SetMarginMode 暂未实现 (symbol=%s, mode=%s)", symbol, modeStr) + logger.Infof("🚧 LIGHTER SetMarginMode 暂未实现 (symbol=%s, mode=%s)", symbol, modeStr) return nil } diff --git a/trader/order_sync.go b/trader/order_sync.go new file mode 100644 index 00000000..2562d0ed --- /dev/null +++ b/trader/order_sync.go @@ -0,0 +1,309 @@ +package trader + +import ( + "fmt" + "nofx/logger" + "nofx/store" + "sync" + "time" +) + +// OrderSyncManager 订单状态同步管理器 +// 负责定期扫描所有 NEW 状态的订单,并更新其状态 +type OrderSyncManager struct { + store *store.Store + interval time.Duration + stopCh chan struct{} + wg sync.WaitGroup + traderCache map[string]Trader // trader_id -> Trader 实例缓存 + configCache map[string]*store.TraderFullConfig // trader_id -> 配置缓存 + cacheMutex sync.RWMutex +} + +// NewOrderSyncManager 创建订单同步管理器 +func NewOrderSyncManager(st *store.Store, interval time.Duration) *OrderSyncManager { + if interval == 0 { + interval = 10 * time.Second + } + return &OrderSyncManager{ + store: st, + interval: interval, + stopCh: make(chan struct{}), + traderCache: make(map[string]Trader), + configCache: make(map[string]*store.TraderFullConfig), + } +} + +// Start 启动订单同步服务 +func (m *OrderSyncManager) Start() { + m.wg.Add(1) + go m.run() + logger.Info("📦 订单同步管理器已启动") +} + +// Stop 停止订单同步服务 +func (m *OrderSyncManager) Stop() { + close(m.stopCh) + m.wg.Wait() + + // 清理缓存 + m.cacheMutex.Lock() + m.traderCache = make(map[string]Trader) + m.configCache = make(map[string]*store.TraderFullConfig) + m.cacheMutex.Unlock() + + logger.Info("📦 订单同步管理器已停止") +} + +// run 主循环 +func (m *OrderSyncManager) run() { + defer m.wg.Done() + + // 启动时立即执行一次 + m.syncOrders() + + ticker := time.NewTicker(m.interval) + defer ticker.Stop() + + for { + select { + case <-m.stopCh: + return + case <-ticker.C: + m.syncOrders() + } + } +} + +// syncOrders 同步所有待处理订单 +func (m *OrderSyncManager) syncOrders() { + // 获取所有 NEW 状态的订单 + orders, err := m.store.Order().GetAllPendingOrders() + if err != nil { + logger.Infof("⚠️ 获取待处理订单失败: %v", err) + return + } + + if len(orders) == 0 { + return + } + + logger.Infof("📦 开始同步 %d 个待处理订单...", len(orders)) + + // 按 trader_id 分组 + ordersByTrader := make(map[string][]*store.TraderOrder) + for _, order := range orders { + ordersByTrader[order.TraderID] = append(ordersByTrader[order.TraderID], order) + } + + // 逐个 trader 处理 + for traderID, traderOrders := range ordersByTrader { + m.syncTraderOrders(traderID, traderOrders) + } +} + +// syncTraderOrders 同步单个 trader 的订单 +func (m *OrderSyncManager) syncTraderOrders(traderID string, orders []*store.TraderOrder) { + // 获取或创建 trader 实例 + trader, err := m.getOrCreateTrader(traderID) + if err != nil { + logger.Infof("⚠️ 获取 trader 实例失败 (ID: %s): %v", traderID, err) + return + } + + for _, order := range orders { + m.syncSingleOrder(trader, order) + } +} + +// syncSingleOrder 同步单个订单状态 +func (m *OrderSyncManager) syncSingleOrder(trader Trader, order *store.TraderOrder) { + status, err := trader.GetOrderStatus(order.Symbol, order.OrderID) + if err != nil { + // 查询失败,检查订单创建时间,超过一定时间假设已成交 + if time.Since(order.CreatedAt) > 5*time.Minute { + logger.Infof("⚠️ 订单查询超时,假设已成交 (ID: %s)", order.OrderID) + m.markOrderFilled(order, 0, 0, 0) + } + return + } + + statusStr, _ := status["status"].(string) + + switch statusStr { + case "FILLED": + avgPrice, _ := status["avgPrice"].(float64) + executedQty, _ := status["executedQty"].(float64) + commission, _ := status["commission"].(float64) + + // 如果 API 未返回数量,使用原始数量 + if executedQty == 0 { + executedQty = order.Quantity + } + + m.markOrderFilled(order, avgPrice, executedQty, commission) + + case "CANCELED", "EXPIRED": + order.Status = statusStr + if err := m.store.Order().Update(order); err != nil { + logger.Infof("⚠️ 更新订单状态失败: %v", err) + } else { + logger.Infof("📦 订单状态更新: %s (ID: %s)", statusStr, order.OrderID) + } + } +} + +// markOrderFilled 标记订单已成交 +func (m *OrderSyncManager) markOrderFilled(order *store.TraderOrder, avgPrice, executedQty, commission float64) { + // 如果 avgPrice 为 0,使用订单价格 + if avgPrice == 0 { + avgPrice = order.Price + } + if executedQty == 0 { + executedQty = order.Quantity + } + + // 计算已实现盈亏(仅平仓订单) + var realizedPnL float64 + if (order.Action == "close_long" || order.Action == "close_short") && order.EntryPrice > 0 && avgPrice > 0 { + if order.Action == "close_long" { + // 平多盈亏 = (平仓价 - 开仓价) * 数量 + realizedPnL = (avgPrice - order.EntryPrice) * executedQty + } else { + // 平空盈亏 = (开仓价 - 平仓价) * 数量 + realizedPnL = (order.EntryPrice - avgPrice) * executedQty + } + } + + order.AvgPrice = avgPrice + order.ExecutedQty = executedQty + order.Status = "FILLED" + order.Fee = commission + order.RealizedPnL = realizedPnL + order.FilledAt = time.Now() + + if err := m.store.Order().Update(order); err != nil { + logger.Infof("⚠️ 更新订单状态失败: %v", err) + } else { + if realizedPnL != 0 { + logger.Infof("✅ 订单已成交 (ID: %s, avgPrice: %.4f, qty: %.4f, PnL: %.2f)", + order.OrderID, avgPrice, executedQty, realizedPnL) + } else { + logger.Infof("✅ 订单已成交 (ID: %s, avgPrice: %.4f, qty: %.4f)", + order.OrderID, avgPrice, executedQty) + } + } +} + +// getOrCreateTrader 获取或创建 trader 实例 +func (m *OrderSyncManager) getOrCreateTrader(traderID string) (Trader, error) { + m.cacheMutex.RLock() + trader, exists := m.traderCache[traderID] + m.cacheMutex.RUnlock() + + if exists && trader != nil { + return trader, nil + } + + // 需要创建新的 trader 实例 + // 首先获取 trader 配置 + config, err := m.getTraderConfig(traderID) + if err != nil { + return nil, fmt.Errorf("获取 trader 配置失败: %w", err) + } + + // 根据交易所类型创建 trader + trader, err = m.createTrader(config) + if err != nil { + return nil, fmt.Errorf("创建 trader 实例失败: %w", err) + } + + m.cacheMutex.Lock() + m.traderCache[traderID] = trader + m.cacheMutex.Unlock() + + return trader, nil +} + +// getTraderConfig 获取 trader 配置 +func (m *OrderSyncManager) getTraderConfig(traderID string) (*store.TraderFullConfig, error) { + m.cacheMutex.RLock() + config, exists := m.configCache[traderID] + m.cacheMutex.RUnlock() + + if exists { + return config, nil + } + + // 从数据库获取 - 需要找到 trader 对应的 userID + // 首先查询所有 traders 找到对应的 userID + traders, err := m.store.Trader().ListAll() + if err != nil { + return nil, fmt.Errorf("获取 trader 列表失败: %w", err) + } + + var userID string + for _, t := range traders { + if t.ID == traderID { + userID = t.UserID + break + } + } + + if userID == "" { + return nil, fmt.Errorf("找不到 trader: %s", traderID) + } + + config, err = m.store.Trader().GetFullConfig(userID, traderID) + if err != nil { + return nil, err + } + + m.cacheMutex.Lock() + m.configCache[traderID] = config + m.cacheMutex.Unlock() + + return config, nil +} + +// createTrader 根据配置创建 trader 实例 +func (m *OrderSyncManager) createTrader(config *store.TraderFullConfig) (Trader, error) { + exchange := config.Exchange + + switch exchange.Type { + case "binance": + return NewFuturesTrader(exchange.APIKey, exchange.SecretKey, config.Trader.UserID), nil + + case "bybit": + return NewBybitTrader(exchange.APIKey, exchange.SecretKey), nil + + case "hyperliquid": + return NewHyperliquidTrader(exchange.SecretKey, exchange.HyperliquidWalletAddr, exchange.Testnet) + + case "aster": + return NewAsterTrader(exchange.AsterUser, exchange.AsterSigner, exchange.AsterPrivateKey) + + case "lighter": + if exchange.LighterAPIKeyPrivateKey != "" { + return NewLighterTraderV2( + exchange.LighterPrivateKey, + exchange.LighterWalletAddr, + exchange.LighterAPIKeyPrivateKey, + exchange.Testnet, + ) + } + return NewLighterTrader(exchange.LighterPrivateKey, exchange.LighterWalletAddr, exchange.Testnet) + + default: + return nil, fmt.Errorf("不支持的交易所类型: %s", exchange.Type) + } +} + +// InvalidateCache 使缓存失效(当配置变更时调用) +func (m *OrderSyncManager) InvalidateCache(traderID string) { + m.cacheMutex.Lock() + defer m.cacheMutex.Unlock() + + delete(m.traderCache, traderID) + delete(m.configCache, traderID) +} diff --git a/trader/partial_close_test.go b/trader/partial_close_test.go deleted file mode 100644 index 5b4b50be..00000000 --- a/trader/partial_close_test.go +++ /dev/null @@ -1,393 +0,0 @@ -package trader - -import ( - "fmt" - "nofx/decision" - "nofx/logger" - "testing" -) - -// MockPartialCloseTrader 用於測試 partial close 邏輯 -type MockPartialCloseTrader struct { - positions []map[string]interface{} - closePartialCalled bool - closeLongCalled bool - closeShortCalled bool - stopLossCalled bool - takeProfitCalled bool - lastStopLoss float64 - lastTakeProfit float64 -} - -func (m *MockPartialCloseTrader) GetPositions() ([]map[string]interface{}, error) { - return m.positions, nil -} - -func (m *MockPartialCloseTrader) ClosePartialLong(symbol string, quantity float64) (map[string]interface{}, error) { - m.closePartialCalled = true - return map[string]interface{}{"orderId": "12345"}, nil -} - -func (m *MockPartialCloseTrader) ClosePartialShort(symbol string, quantity float64) (map[string]interface{}, error) { - m.closePartialCalled = true - return map[string]interface{}{"orderId": "12345"}, nil -} - -func (m *MockPartialCloseTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) { - m.closeLongCalled = true - return map[string]interface{}{"orderId": "12346"}, nil -} - -func (m *MockPartialCloseTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) { - m.closeShortCalled = true - return map[string]interface{}{"orderId": "12346"}, nil -} - -func (m *MockPartialCloseTrader) SetStopLoss(symbol, side string, quantity, price float64) error { - m.stopLossCalled = true - m.lastStopLoss = price - return nil -} - -func (m *MockPartialCloseTrader) SetTakeProfit(symbol, side string, quantity, price float64) error { - m.takeProfitCalled = true - m.lastTakeProfit = price - return nil -} - -// TestPartialCloseMinPositionCheck 測試最小倉位檢查邏輯 -func TestPartialCloseMinPositionCheck(t *testing.T) { - tests := []struct { - name string - totalQuantity float64 - markPrice float64 - closePercentage float64 - expectFullClose bool // 是否應該觸發全平邏輯 - expectRemainValue float64 - }{ - { - name: "正常部分平倉_剩餘價值充足", - totalQuantity: 1.0, - markPrice: 100.0, - closePercentage: 50.0, - expectFullClose: false, - expectRemainValue: 50.0, // 剩餘 0.5 * 100 = 50 USDT - }, - { - name: "部分平倉_剩餘價值小於10USDT_應該全平", - totalQuantity: 0.2, - markPrice: 100.0, - closePercentage: 95.0, // 平倉 95%,剩餘 1 USDT (0.2 * 5% * 100) - expectFullClose: true, - expectRemainValue: 1.0, - }, - { - name: "部分平倉_剩餘價值剛好10USDT_應該全平", - totalQuantity: 1.0, - markPrice: 100.0, - closePercentage: 90.0, // 剩餘 10 USDT (1.0 * 10% * 100),邊界測試 (<=) - expectFullClose: true, - expectRemainValue: 10.0, - }, - { - name: "部分平倉_剩餘價值11USDT_不應全平", - totalQuantity: 1.1, - markPrice: 100.0, - closePercentage: 90.0, // 剩餘 11 USDT (1.1 * 10% * 100) - expectFullClose: false, - expectRemainValue: 11.0, - }, - { - name: "大倉位部分平倉_剩餘價值遠大於10USDT", - totalQuantity: 10.0, - markPrice: 1000.0, - closePercentage: 80.0, - expectFullClose: false, - expectRemainValue: 2000.0, // 剩餘 2 * 1000 = 2000 USDT - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 計算剩餘價值 - closeQuantity := tt.totalQuantity * (tt.closePercentage / 100.0) - remainingQuantity := tt.totalQuantity - closeQuantity - remainingValue := remainingQuantity * tt.markPrice - - // 驗證計算(使用浮點數比較允許微小誤差) - const epsilon = 0.001 - if remainingValue-tt.expectRemainValue > epsilon || tt.expectRemainValue-remainingValue > epsilon { - t.Errorf("計算錯誤: 剩餘價值 = %.2f, 期望 = %.2f", - remainingValue, tt.expectRemainValue) - } - - // 驗證最小倉位檢查邏輯 - const MIN_POSITION_VALUE = 10.0 - shouldFullClose := remainingValue > 0 && remainingValue <= MIN_POSITION_VALUE - - if shouldFullClose != tt.expectFullClose { - t.Errorf("最小倉位檢查失敗: shouldFullClose = %v, 期望 = %v (剩餘價值 = %.2f USDT)", - shouldFullClose, tt.expectFullClose, remainingValue) - } - }) - } -} - -// TestPartialCloseWithStopLossTakeProfitRecovery 測試止盈止損恢復邏輯 -func TestPartialCloseWithStopLossTakeProfitRecovery(t *testing.T) { - tests := []struct { - name string - newStopLoss float64 - newTakeProfit float64 - expectStopLoss bool - expectTakeProfit bool - }{ - { - name: "有新止損和止盈_應該恢復兩者", - newStopLoss: 95.0, - newTakeProfit: 110.0, - expectStopLoss: true, - expectTakeProfit: true, - }, - { - name: "只有新止損_僅恢復止損", - newStopLoss: 95.0, - newTakeProfit: 0, - expectStopLoss: true, - expectTakeProfit: false, - }, - { - name: "只有新止盈_僅恢復止盈", - newStopLoss: 0, - newTakeProfit: 110.0, - expectStopLoss: false, - expectTakeProfit: true, - }, - { - name: "沒有新止損止盈_不恢復", - newStopLoss: 0, - newTakeProfit: 0, - expectStopLoss: false, - expectTakeProfit: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 模擬止盈止損恢復邏輯 - stopLossRecovered := tt.newStopLoss > 0 - takeProfitRecovered := tt.newTakeProfit > 0 - - if stopLossRecovered != tt.expectStopLoss { - t.Errorf("止損恢復邏輯錯誤: recovered = %v, 期望 = %v", - stopLossRecovered, tt.expectStopLoss) - } - - if takeProfitRecovered != tt.expectTakeProfit { - t.Errorf("止盈恢復邏輯錯誤: recovered = %v, 期望 = %v", - takeProfitRecovered, tt.expectTakeProfit) - } - }) - } -} - -// TestPartialCloseEdgeCases 測試邊界情況 -func TestPartialCloseEdgeCases(t *testing.T) { - tests := []struct { - name string - closePercentage float64 - totalQuantity float64 - markPrice float64 - expectError bool - errorContains string - }{ - { - name: "平倉百分比為0_應該報錯", - closePercentage: 0, - totalQuantity: 1.0, - markPrice: 100.0, - expectError: true, - errorContains: "0-100", - }, - { - name: "平倉百分比超過100_應該報錯", - closePercentage: 101.0, - totalQuantity: 1.0, - markPrice: 100.0, - expectError: true, - errorContains: "0-100", - }, - { - name: "平倉百分比為負數_應該報錯", - closePercentage: -10.0, - totalQuantity: 1.0, - markPrice: 100.0, - expectError: true, - errorContains: "0-100", - }, - { - name: "正常範圍_不應報錯", - closePercentage: 50.0, - totalQuantity: 1.0, - markPrice: 100.0, - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 模擬百分比驗證邏輯 - var err error - if tt.closePercentage <= 0 || tt.closePercentage > 100 { - err = fmt.Errorf("平仓百分比必须在 0-100 之间,当前: %.1f", tt.closePercentage) - } - - if tt.expectError { - if err == nil { - t.Errorf("期望報錯但沒有報錯") - } - } else { - if err != nil { - t.Errorf("不應報錯但報錯了: %v", err) - } - } - }) - } -} - -// TestPartialCloseIntegration 整合測試(使用 mock trader) -func TestPartialCloseIntegration(t *testing.T) { - tests := []struct { - name string - symbol string - side string - totalQuantity float64 - markPrice float64 - closePercentage float64 - newStopLoss float64 - newTakeProfit float64 - expectFullClose bool - expectStopLossCall bool - expectTakeProfitCall bool - }{ - { - name: "LONG倉_正常部分平倉_有止盈止損", - symbol: "BTCUSDT", - side: "LONG", - totalQuantity: 1.0, - markPrice: 50000.0, - closePercentage: 50.0, - newStopLoss: 48000.0, - newTakeProfit: 52000.0, - expectFullClose: false, - expectStopLossCall: true, - expectTakeProfitCall: true, - }, - { - name: "SHORT倉_剩餘價值過小_應自動全平", - symbol: "ETHUSDT", - side: "SHORT", - totalQuantity: 0.02, - markPrice: 3000.0, // 總價值 60 USDT - closePercentage: 95.0, // 剩餘 3 USDT < 10 USDT - newStopLoss: 3100.0, - newTakeProfit: 2900.0, - expectFullClose: true, - expectStopLossCall: false, // 全平不需要恢復止盈止損 - expectTakeProfitCall: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 創建 mock trader - mockTrader := &MockPartialCloseTrader{ - positions: []map[string]interface{}{ - { - "symbol": tt.symbol, - "side": tt.side, - "quantity": tt.totalQuantity, - "markPrice": tt.markPrice, - }, - }, - } - - // 創建決策 - dec := &decision.Decision{ - Symbol: tt.symbol, - Action: "partial_close", - ClosePercentage: tt.closePercentage, - NewStopLoss: tt.newStopLoss, - NewTakeProfit: tt.newTakeProfit, - } - - // 創建 actionRecord - actionRecord := &logger.DecisionAction{} - - // 計算剩餘價值 - closeQuantity := tt.totalQuantity * (tt.closePercentage / 100.0) - remainingQuantity := tt.totalQuantity - closeQuantity - remainingValue := remainingQuantity * tt.markPrice - - // 驗證最小倉位檢查 - const MIN_POSITION_VALUE = 10.0 - shouldFullClose := remainingValue > 0 && remainingValue <= MIN_POSITION_VALUE - - if shouldFullClose != tt.expectFullClose { - t.Errorf("最小倉位檢查不符: shouldFullClose = %v, 期望 = %v (剩餘 %.2f USDT)", - shouldFullClose, tt.expectFullClose, remainingValue) - } - - // 模擬執行邏輯 - if shouldFullClose { - // 應該轉為全平 - if tt.side == "LONG" { - mockTrader.CloseLong(tt.symbol, tt.totalQuantity) - } else { - mockTrader.CloseShort(tt.symbol, tt.totalQuantity) - } - } else { - // 正常部分平倉 - if tt.side == "LONG" { - mockTrader.ClosePartialLong(tt.symbol, closeQuantity) - } else { - mockTrader.ClosePartialShort(tt.symbol, closeQuantity) - } - - // 恢復止盈止損 - if dec.NewStopLoss > 0 { - mockTrader.SetStopLoss(tt.symbol, tt.side, remainingQuantity, dec.NewStopLoss) - } - if dec.NewTakeProfit > 0 { - mockTrader.SetTakeProfit(tt.symbol, tt.side, remainingQuantity, dec.NewTakeProfit) - } - } - - // 驗證調用 - if tt.expectFullClose { - if !mockTrader.closeLongCalled && !mockTrader.closeShortCalled { - t.Error("期望調用全平但沒有調用") - } - if mockTrader.closePartialCalled { - t.Error("不應該調用部分平倉") - } - } else { - if !mockTrader.closePartialCalled { - t.Error("期望調用部分平倉但沒有調用") - } - } - - if mockTrader.stopLossCalled != tt.expectStopLossCall { - t.Errorf("止損調用不符: called = %v, 期望 = %v", - mockTrader.stopLossCalled, tt.expectStopLossCall) - } - - if mockTrader.takeProfitCalled != tt.expectTakeProfitCall { - t.Errorf("止盈調用不符: called = %v, 期望 = %v", - mockTrader.takeProfitCalled, tt.expectTakeProfitCall) - } - - _ = actionRecord // 避免未使用警告 - }) - } -} diff --git a/trader/position_sync.go b/trader/position_sync.go new file mode 100644 index 00000000..30c2ca56 --- /dev/null +++ b/trader/position_sync.go @@ -0,0 +1,318 @@ +package trader + +import ( + "fmt" + "nofx/logger" + "nofx/store" + "sync" + "time" +) + +// PositionSyncManager 仓位状态同步管理器 +// 负责定期同步交易所仓位,检测手动平仓等变化 +type PositionSyncManager struct { + store *store.Store + interval time.Duration + stopCh chan struct{} + wg sync.WaitGroup + traderCache map[string]Trader // trader_id -> Trader 实例缓存 + configCache map[string]*store.TraderFullConfig // trader_id -> 配置缓存 + cacheMutex sync.RWMutex +} + +// NewPositionSyncManager 创建仓位同步管理器 +func NewPositionSyncManager(st *store.Store, interval time.Duration) *PositionSyncManager { + if interval == 0 { + interval = 10 * time.Second + } + return &PositionSyncManager{ + store: st, + interval: interval, + stopCh: make(chan struct{}), + traderCache: make(map[string]Trader), + configCache: make(map[string]*store.TraderFullConfig), + } +} + +// Start 启动仓位同步服务 +func (m *PositionSyncManager) Start() { + m.wg.Add(1) + go m.run() + logger.Info("📊 仓位同步管理器已启动") +} + +// Stop 停止仓位同步服务 +func (m *PositionSyncManager) Stop() { + close(m.stopCh) + m.wg.Wait() + + // 清理缓存 + m.cacheMutex.Lock() + m.traderCache = make(map[string]Trader) + m.configCache = make(map[string]*store.TraderFullConfig) + m.cacheMutex.Unlock() + + logger.Info("📊 仓位同步管理器已停止") +} + +// run 主循环 +func (m *PositionSyncManager) run() { + defer m.wg.Done() + + // 启动时立即执行一次 + m.syncPositions() + + ticker := time.NewTicker(m.interval) + defer ticker.Stop() + + for { + select { + case <-m.stopCh: + return + case <-ticker.C: + m.syncPositions() + } + } +} + +// syncPositions 同步所有仓位状态 +func (m *PositionSyncManager) syncPositions() { + // 获取所有 OPEN 状态的仓位 + localPositions, err := m.store.Position().GetAllOpenPositions() + if err != nil { + logger.Infof("⚠️ 获取本地仓位失败: %v", err) + return + } + + if len(localPositions) == 0 { + return + } + + // 按 trader_id 分组 + positionsByTrader := make(map[string][]*store.TraderPosition) + for _, pos := range localPositions { + positionsByTrader[pos.TraderID] = append(positionsByTrader[pos.TraderID], pos) + } + + // 逐个 trader 处理 + for traderID, traderPositions := range positionsByTrader { + m.syncTraderPositions(traderID, traderPositions) + } +} + +// syncTraderPositions 同步单个 trader 的仓位 +func (m *PositionSyncManager) syncTraderPositions(traderID string, localPositions []*store.TraderPosition) { + // 获取或创建 trader 实例 + trader, err := m.getOrCreateTrader(traderID) + if err != nil { + logger.Infof("⚠️ 获取 trader 实例失败 (ID: %s): %v", traderID, err) + return + } + + // 获取交易所当前仓位 + exchangePositions, err := trader.GetPositions() + if err != nil { + logger.Infof("⚠️ 获取交易所仓位失败 (ID: %s): %v", traderID, err) + return + } + + // 构建交易所仓位 map: symbol_side -> position + exchangeMap := make(map[string]map[string]interface{}) + for _, pos := range exchangePositions { + symbol, _ := pos["symbol"].(string) + side, _ := pos["positionSide"].(string) + if symbol == "" || side == "" { + continue + } + key := fmt.Sprintf("%s_%s", symbol, side) + exchangeMap[key] = pos + } + + // 对比本地和交易所仓位 + for _, localPos := range localPositions { + key := fmt.Sprintf("%s_%s", localPos.Symbol, localPos.Side) + exchangePos, exists := exchangeMap[key] + + if !exists { + // 交易所没有这个仓位了 → 已被平仓 + m.closeLocalPosition(localPos, trader, "manual") + continue + } + + // 检查数量是否为0或很小 + qty := getFloatFromMap(exchangePos, "positionAmt") + if qty < 0 { + qty = -qty // 空仓数量是负的 + } + + if qty < 0.0000001 { + // 数量为0,仓位已平 + m.closeLocalPosition(localPos, trader, "manual") + } + } +} + +// closeLocalPosition 标记本地仓位为已平仓 +func (m *PositionSyncManager) closeLocalPosition(pos *store.TraderPosition, trader Trader, reason string) { + // 尝试获取最后成交价作为平仓价 + exitPrice := pos.EntryPrice // 默认用开仓价 + + // 尝试从交易所获取最新价格 + if price, err := trader.GetMarketPrice(pos.Symbol); err == nil && price > 0 { + exitPrice = price + } + + // 计算盈亏 + var realizedPnL float64 + if pos.Side == "LONG" { + realizedPnL = (exitPrice - pos.EntryPrice) * pos.Quantity + } else { + realizedPnL = (pos.EntryPrice - exitPrice) * pos.Quantity + } + + // 更新数据库 + err := m.store.Position().ClosePosition( + pos.ID, + exitPrice, + "", // 手动平仓没有订单ID + realizedPnL, + 0, // 手动平仓无法获取手续费 + reason, + ) + + if err != nil { + logger.Infof("⚠️ 更新仓位状态失败: %v", err) + } else { + logger.Infof("📊 仓位已平仓 [%s] %s %s @ %.4f → %.4f, PnL: %.2f (%s)", + pos.TraderID[:8], pos.Symbol, pos.Side, pos.EntryPrice, exitPrice, realizedPnL, reason) + } +} + +// getOrCreateTrader 获取或创建 trader 实例 +func (m *PositionSyncManager) getOrCreateTrader(traderID string) (Trader, error) { + m.cacheMutex.RLock() + trader, exists := m.traderCache[traderID] + m.cacheMutex.RUnlock() + + if exists && trader != nil { + return trader, nil + } + + // 需要创建新的 trader 实例 + config, err := m.getTraderConfig(traderID) + if err != nil { + return nil, fmt.Errorf("获取 trader 配置失败: %w", err) + } + + trader, err = m.createTrader(config) + if err != nil { + return nil, fmt.Errorf("创建 trader 实例失败: %w", err) + } + + m.cacheMutex.Lock() + m.traderCache[traderID] = trader + m.cacheMutex.Unlock() + + return trader, nil +} + +// getTraderConfig 获取 trader 配置 +func (m *PositionSyncManager) getTraderConfig(traderID string) (*store.TraderFullConfig, error) { + m.cacheMutex.RLock() + config, exists := m.configCache[traderID] + m.cacheMutex.RUnlock() + + if exists { + return config, nil + } + + // 从数据库获取 + traders, err := m.store.Trader().ListAll() + if err != nil { + return nil, fmt.Errorf("获取 trader 列表失败: %w", err) + } + + var userID string + for _, t := range traders { + if t.ID == traderID { + userID = t.UserID + break + } + } + + if userID == "" { + return nil, fmt.Errorf("找不到 trader: %s", traderID) + } + + config, err = m.store.Trader().GetFullConfig(userID, traderID) + if err != nil { + return nil, err + } + + m.cacheMutex.Lock() + m.configCache[traderID] = config + m.cacheMutex.Unlock() + + return config, nil +} + +// createTrader 根据配置创建 trader 实例 +func (m *PositionSyncManager) createTrader(config *store.TraderFullConfig) (Trader, error) { + exchange := config.Exchange + + switch exchange.Type { + case "binance": + return NewFuturesTrader(exchange.APIKey, exchange.SecretKey, config.Trader.UserID), nil + + case "bybit": + return NewBybitTrader(exchange.APIKey, exchange.SecretKey), nil + + case "hyperliquid": + return NewHyperliquidTrader(exchange.SecretKey, exchange.HyperliquidWalletAddr, exchange.Testnet) + + case "aster": + return NewAsterTrader(exchange.AsterUser, exchange.AsterSigner, exchange.AsterPrivateKey) + + case "lighter": + if exchange.LighterAPIKeyPrivateKey != "" { + return NewLighterTraderV2( + exchange.LighterPrivateKey, + exchange.LighterWalletAddr, + exchange.LighterAPIKeyPrivateKey, + exchange.Testnet, + ) + } + return NewLighterTrader(exchange.LighterPrivateKey, exchange.LighterWalletAddr, exchange.Testnet) + + default: + return nil, fmt.Errorf("不支持的交易所类型: %s", exchange.Type) + } +} + +// InvalidateCache 使缓存失效 +func (m *PositionSyncManager) InvalidateCache(traderID string) { + m.cacheMutex.Lock() + defer m.cacheMutex.Unlock() + + delete(m.traderCache, traderID) + delete(m.configCache, traderID) +} + +// getFloatFromMap 从 map 中获取 float64 值 +func getFloatFromMap(m map[string]interface{}, key string) float64 { + if v, ok := m[key]; ok { + switch val := v.(type) { + case float64: + return val + case int64: + return float64(val) + case int: + return float64(val) + case string: + var f float64 + fmt.Sscanf(val, "%f", &f) + return f + } + } + return 0 +} diff --git a/web/package-lock.json b/web/package-lock.json index 857e30ce..72290fd2 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -121,7 +121,6 @@ "resolved": "https://registry.npmjs.org/@babel/core/-/core-7.28.5.tgz", "integrity": "sha512-e7jT4DxYvIDLk1ZHmU/m/mB19rex9sv0c2ftBtjSBv+kVM/902eh0fINUzD7UwLLNR+jU585GxUJ8/EBfAM5fw==", "dev": true, - "peer": true, "dependencies": { "@babel/code-frame": "^7.27.1", "@babel/generator": "^7.28.5", @@ -453,7 +452,6 @@ } ], "license": "MIT", - "peer": true, "engines": { "node": ">=18" }, @@ -477,7 +475,6 @@ } ], "license": "MIT", - "peer": true, "engines": { "node": ">=18" } @@ -2037,7 +2034,8 @@ "resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.4.tgz", "integrity": "sha512-rfT93uj5s0PRL7EzccGMs3brplhcrghnDoV26NqKhCAS1hVo+WdNsPvE/yb6ilfr5hi2MEk6d5EWJTKdxg8jVw==", "dev": true, - "license": "MIT" + "license": "MIT", + "peer": true }, "node_modules/@types/babel__core": { "version": "7.20.5", @@ -2158,7 +2156,6 @@ "resolved": "https://registry.npmjs.org/@types/react/-/react-18.3.26.tgz", "integrity": "sha512-RFA/bURkcKzx/X9oumPG9Vp3D3JUgus/d0b67KB0t5S/raciymilkOa66olh78MUI92QLbEJevO7rvqU/kjwKA==", "devOptional": true, - "peer": true, "dependencies": { "@types/prop-types": "*", "csstype": "^3.0.2" @@ -2169,7 +2166,6 @@ "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.3.7.tgz", "integrity": "sha512-MEe3UeoENYVFXzoXEWsvcpg6ZvlrFNlOQ7EOsvhI3CfAXwzPfO8Qwuxd40nepsYKqyyVQnTdEfv68q91yLcKrQ==", "devOptional": true, - "peer": true, "peerDependencies": { "@types/react": "^18.0.0" } @@ -2210,7 +2206,6 @@ "integrity": "sha512-6m1I5RmHBGTnUGS113G04DMu3CpSdxCAU/UvtjNWL4Nuf3MW9tQhiJqRlHzChIkhy6kZSAQmc+I1bcGjE3yNKg==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "8.46.3", "@typescript-eslint/types": "8.46.3", @@ -2535,7 +2530,6 @@ "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true, "license": "MIT", - "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -2969,7 +2963,6 @@ "url": "https://github.com/sponsors/ai" } ], - "peer": true, "dependencies": { "baseline-browser-mapping": "^2.8.19", "caniuse-lite": "^1.0.30001751", @@ -3697,7 +3690,8 @@ "resolved": "https://registry.npmjs.org/dom-accessibility-api/-/dom-accessibility-api-0.5.16.tgz", "integrity": "sha512-X7BJ2yElsnOJ30pZF4uIIDfBEVgF4XEBxL9Bxhy6dnrm5hkzqmsWHGTiHqRiITNhMyFLyAiWndIJP7Z1NTteDg==", "dev": true, - "license": "MIT" + "license": "MIT", + "peer": true }, "node_modules/dom-helpers": { "version": "5.2.1", @@ -4015,7 +4009,6 @@ "integrity": "sha512-BhHmn2yNOFA9H9JmmIVKJmd288g9hrVRDkdoIgRCRuSySRUHH7r/DI6aAXW9T1WwUuY3DFgrcaqB+deURBLR5g==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.8.0", "@eslint-community/regexpp": "^4.12.1", @@ -4076,7 +4069,6 @@ "integrity": "sha512-82GZUjRS0p/jganf6q1rEO25VSoHH0hKPCTrgillPjdI/3bgBhAE1QzHrHTizjpRvy6pGAvKjDJtk2pF9NDq8w==", "dev": true, "license": "MIT", - "peer": true, "bin": { "eslint-config-prettier": "bin/cli.js" }, @@ -5590,7 +5582,6 @@ "resolved": "https://registry.npmjs.org/jiti/-/jiti-1.21.7.tgz", "integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==", "dev": true, - "peer": true, "bin": { "jiti": "bin/jiti.js" } @@ -5619,7 +5610,6 @@ "integrity": "sha512-8i7LzZj7BF8uplX+ZyOlIz86V6TAsSs+np6m1kpW9u0JWi4z/1t+FzcK1aek+ybTnAC4KhBL4uXCNT0wcUIeCw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "cssstyle": "^4.1.0", "data-urls": "^5.0.0", @@ -5994,6 +5984,7 @@ "integrity": "sha512-h5bgJWpxJNswbU7qCrV0tIKQCaS3blPDrqKWx+QxzuzL1zGUzij9XCWLrSLsJPu5t+eWA/ycetzYAO5IOMcWAQ==", "dev": true, "license": "MIT", + "peer": true, "bin": { "lz-string": "bin/bin.js" } @@ -6581,7 +6572,6 @@ "url": "https://github.com/sponsors/ai" } ], - "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -6735,7 +6725,6 @@ "integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==", "dev": true, "license": "MIT", - "peer": true, "bin": { "prettier": "bin/prettier.cjs" }, @@ -6765,6 +6754,7 @@ "integrity": "sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "ansi-regex": "^5.0.1", "ansi-styles": "^5.0.0", @@ -6780,6 +6770,7 @@ "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", "dev": true, "license": "MIT", + "peer": true, "engines": { "node": ">=8" } @@ -6790,6 +6781,7 @@ "integrity": "sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==", "dev": true, "license": "MIT", + "peer": true, "engines": { "node": ">=10" }, @@ -6802,7 +6794,8 @@ "resolved": "https://registry.npmjs.org/react-is/-/react-is-17.0.2.tgz", "integrity": "sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==", "dev": true, - "license": "MIT" + "license": "MIT", + "peer": true }, "node_modules/prop-types": { "version": "15.8.1", @@ -6859,7 +6852,6 @@ "version": "18.3.1", "resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz", "integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==", - "peer": true, "dependencies": { "loose-envify": "^1.1.0" }, @@ -6871,7 +6863,6 @@ "version": "18.3.1", "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz", "integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==", - "peer": true, "dependencies": { "loose-envify": "^1.1.0", "scheduler": "^0.23.2" @@ -8063,7 +8054,6 @@ "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, - "peer": true, "engines": { "node": ">=12" }, @@ -8280,7 +8270,6 @@ "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.9.3.tgz", "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", "dev": true, - "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -8431,7 +8420,6 @@ "resolved": "https://registry.npmjs.org/vite/-/vite-6.4.1.tgz", "integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==", "dev": true, - "peer": true, "dependencies": { "esbuild": "^0.25.0", "fdir": "^6.4.4", @@ -9036,7 +9024,6 @@ "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, - "peer": true, "engines": { "node": ">=12" }, @@ -9573,7 +9560,6 @@ "integrity": "sha512-o5a9xKjbtuhY6Bi5S3+HvbRERmouabWbyUcpXXUA1u+GNUKoROi9byOJ8M0nHbHYHkYICiMlqxkg1KkYmm25Sw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "esbuild": "^0.21.3", "postcss": "^8.4.43", @@ -9987,7 +9973,6 @@ "integrity": "sha512-JInaHOamG8pt5+Ey8kGmdcAcg3OL9reK8ltczgHTAwNhMys/6ThXHityHxVV2p3fkw/c+MAvBHFVYHFZDmjMCQ==", "dev": true, "license": "MIT", - "peer": true, "funding": { "url": "https://github.com/sponsors/colinhacks" } diff --git a/web/src/App.tsx b/web/src/App.tsx index 81453d5c..2347082a 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -1,7 +1,7 @@ import { useEffect, useState } from 'react' import useSWR from 'swr' import { api } from './lib/api' -import { EquityChart } from './components/EquityChart' +import { ChartTabs } from './components/ChartTabs' import { AITradersPage } from './components/AITradersPage' import { LoginPage } from './components/LoginPage' import { RegisterPage } from './components/RegisterPage' @@ -10,7 +10,6 @@ import { CompetitionPage } from './components/CompetitionPage' import { LandingPage } from './pages/LandingPage' import { FAQPage } from './pages/FAQPage' import HeaderBar from './components/HeaderBar' -import AILearning from './components/AILearning' import { LanguageProvider, useLanguage } from './contexts/LanguageContext' import { AuthProvider, useAuth } from './contexts/AuthContext' import { ConfirmDialogProvider } from './components/ConfirmDialog' @@ -780,9 +779,9 @@ function TraderDetailsPage({
{/* 左侧:图表 + 持仓 */}
- {/* Equity Chart */} + {/* Chart Tabs (Equity / K-line) */}
- +
{/* Current Positions */} @@ -1002,10 +1001,6 @@ function TraderDetailsPage({ {/* 右侧结束 */}
- {/* AI Learning & Performance Analysis */} -
- -
) } diff --git a/web/src/components/AILearning.tsx b/web/src/components/AILearning.tsx deleted file mode 100644 index a10f8f14..00000000 --- a/web/src/components/AILearning.tsx +++ /dev/null @@ -1,1142 +0,0 @@ -import useSWR from 'swr' -import { useLanguage } from '../contexts/LanguageContext' -import { t } from '../i18n/translations' -import { stripLeadingIcons } from '../lib/text' -import { api } from '../lib/api' -import { - Brain, - BarChart3, - TrendingUp, - TrendingDown, - Sparkles, - Coins, - Trophy, - ScrollText, - Lightbulb, -} from 'lucide-react' - -interface TradeOutcome { - symbol: string - side: string - quantity: number - leverage: number - open_price: number - close_price: number - position_value: number - margin_used: number - pn_l: number - pn_l_pct: number - duration: string - open_time: string - close_time: string - was_stop_loss: boolean -} - -interface SymbolPerformance { - symbol: string - total_trades: number - winning_trades: number - losing_trades: number - win_rate: number - total_pn_l: number - avg_pn_l: number -} - -interface PerformanceAnalysis { - total_trades: number - winning_trades: number - losing_trades: number - win_rate: number - avg_win: number - avg_loss: number - profit_factor: number - sharpe_ratio: number - recent_trades: TradeOutcome[] - symbol_stats: { [key: string]: SymbolPerformance } - best_symbol: string - worst_symbol: string -} - -interface AILearningProps { - traderId: string -} - -export default function AILearning({ traderId }: AILearningProps) { - const { language } = useLanguage() - const { data: performance, error } = useSWR( - traderId ? `performance-${traderId}` : 'performance', - () => api.getPerformance(traderId), - { - refreshInterval: 30000, // 30秒刷新(AI学习分析数据更新频率较低) - revalidateOnFocus: false, - dedupingInterval: 20000, - } - ) - - if (error) { - return ( -
-
- {stripLeadingIcons(t('loadingError', language))} -
-
- ) - } - - if (!performance) { - return ( -
-
- {t('loading', language)} -
-
- ) - } - - if (!performance || performance.total_trades === 0) { - return ( -
-
- -

- {t('aiLearning', language)} -

-
-
{t('noCompleteData', language)}
-
- ) - } - - const symbolStats = performance.symbol_stats || {} - const symbolStatsList = Object.values(symbolStats) - .filter((stat) => stat != null) - .sort((a, b) => (b.total_pn_l || 0) - (a.total_pn_l || 0)) - - return ( -
- {/* 标题区 - 优化设计 */} -
-
-
-
- -
-
-

- {t('aiLearning', language)} -

-

- {t('tradesAnalyzed', language, { - count: performance.total_trades, - })} -

-
-
-
- - {/* 核心指标卡片 - 4列网格 */} -
- {/* 总交易数 */} -
-
-
-
- {t('totalTrades', language)} -
-
- {performance.total_trades} -
-
- Trades -
-
-
- - {/* 胜率 */} -
= 50 - ? 'linear-gradient(135deg, rgba(16, 185, 129, 0.2) 0%, rgba(30, 35, 41, 0.8) 100%)' - : 'linear-gradient(135deg, rgba(248, 113, 113, 0.2) 0%, rgba(30, 35, 41, 0.8) 100%)', - border: `1px solid ${(performance.win_rate || 0) >= 50 ? 'rgba(16, 185, 129, 0.4)' : 'rgba(248, 113, 113, 0.4)'}`, - boxShadow: `0 4px 16px ${(performance.win_rate || 0) >= 50 ? 'rgba(16, 185, 129, 0.2)' : 'rgba(248, 113, 113, 0.2)'}`, - }} - > -
= 50 ? '#10B981' : '#F87171'} 0%, transparent 70%)`, - filter: 'blur(20px)', - }} - /> -
-
= 50 ? '#6EE7B7' : '#FCA5A5', - }} - > - {t('winRate', language)} -
-
= 50 ? '#10B981' : '#F87171', - }} - > - {(performance.win_rate || 0).toFixed(1)}% -
-
- {performance.winning_trades || 0}W /{' '} - {performance.losing_trades || 0}L -
-
-
- - {/* 平均盈利 */} -
-
-
-
- {t('avgWin', language)} -
-
- +{(performance.avg_win || 0).toFixed(2)} -
-
- USDT Average -
-
-
- - {/* 平均亏损 */} -
-
-
-
- {t('avgLoss', language)} -
-
- {(performance.avg_loss || 0).toFixed(2)} -
-
- USDT Average -
-
-
-
- - {/* 关键指标:夏普比率 & 盈亏比 - 2列网格 */} -
- {/* 夏普比率 */} -
-
-
-
-
- -
-
-
- 夏普比率 -
-
- 风险调整后收益 · AI自我进化指标 -
-
-
- -
-
= 2 - ? '#10B981' - : (performance.sharpe_ratio || 0) >= 1 - ? '#22D3EE' - : (performance.sharpe_ratio || 0) >= 0 - ? '#F0B90B' - : '#F87171', - textShadow: '0 4px 12px rgba(0, 0, 0, 0.3)', - }} - > - {performance.sharpe_ratio - ? performance.sharpe_ratio.toFixed(2) - : 'N/A'} -
- - {performance.sharpe_ratio !== undefined && ( -
-
= 2 - ? '#10B981' - : (performance.sharpe_ratio || 0) >= 1 - ? '#22D3EE' - : (performance.sharpe_ratio || 0) >= 0 - ? '#F0B90B' - : '#F87171', - background: - (performance.sharpe_ratio || 0) >= 2 - ? 'rgba(16, 185, 129, 0.2)' - : (performance.sharpe_ratio || 0) >= 1 - ? 'rgba(34, 211, 238, 0.2)' - : (performance.sharpe_ratio || 0) >= 0 - ? 'rgba(240, 185, 11, 0.2)' - : 'rgba(248, 113, 113, 0.2)', - }} - > - {performance.sharpe_ratio >= 2 - ? '🟢 卓越表现' - : performance.sharpe_ratio >= 1 - ? '🟢 良好表现' - : performance.sharpe_ratio >= 0 - ? '🟡 波动较大' - : '🔴 需要调整'} -
-
- )} -
- - {performance.sharpe_ratio !== undefined && ( -
-
- {performance.sharpe_ratio >= 2 && - '✨ AI策略非常有效!风险调整后收益优异,可适度扩大仓位但保持纪律。'} - {performance.sharpe_ratio >= 1 && - performance.sharpe_ratio < 2 && - '✅ 策略表现稳健,风险收益平衡良好,继续保持当前策略。'} - {performance.sharpe_ratio >= 0 && - performance.sharpe_ratio < 1 && - '⚠️ 收益为正但波动较大,AI正在优化策略,降低风险。'} - {performance.sharpe_ratio < 0 && - '🚨 当前策略需要调整!AI已自动进入保守模式,减少仓位和交易频率。'} -
-
- )} -
-
- - {/* 盈亏比 */} -
-
-
-
-
- -
-
-
- {t('profitFactor', language)} -
-
- {t('avgWinDivLoss', language)} -
-
-
- -
-
= 2.0 - ? '#10B981' - : (performance.profit_factor || 0) >= 1.5 - ? '#F0B90B' - : (performance.profit_factor || 0) >= 1.0 - ? '#FB923C' - : '#F87171', - textShadow: '0 4px 12px rgba(0, 0, 0, 0.3)', - }} - > - {(performance.profit_factor || 0) > 0 - ? (performance.profit_factor || 0).toFixed(2) - : 'N/A'} -
- -
-
= 2.0 - ? '#10B981' - : (performance.profit_factor || 0) >= 1.5 - ? '#F0B90B' - : '#94A3B8', - background: - (performance.profit_factor || 0) >= 2.0 - ? 'rgba(16, 185, 129, 0.2)' - : (performance.profit_factor || 0) >= 1.5 - ? 'rgba(240, 185, 11, 0.2)' - : 'rgba(148, 163, 184, 0.2)', - }} - > - {(performance.profit_factor || 0) >= 2.0 && - t('excellent', language)} - {(performance.profit_factor || 0) >= 1.5 && - (performance.profit_factor || 0) < 2.0 && - t('good', language)} - {(performance.profit_factor || 0) >= 1.0 && - (performance.profit_factor || 0) < 1.5 && - t('fair', language)} - {(performance.profit_factor || 0) > 0 && - (performance.profit_factor || 0) < 1.0 && - t('poor', language)} -
-
-
- -
-
- {(performance.profit_factor || 0) >= 2.0 && - '🔥 盈利能力出色!每亏1元能赚' + - (performance.profit_factor || 0).toFixed(1) + - '元,AI策略表现优异。'} - {(performance.profit_factor || 0) >= 1.5 && - (performance.profit_factor || 0) < 2.0 && - '✓ 策略稳定盈利,盈亏比健康,继续保持纪律性交易。'} - {(performance.profit_factor || 0) >= 1.0 && - (performance.profit_factor || 0) < 1.5 && - '⚠️ 策略略有盈利但需优化,AI正在调整仓位和止损策略。'} - {(performance.profit_factor || 0) > 0 && - (performance.profit_factor || 0) < 1.0 && - '❌ 平均亏损大于盈利,需要调整策略或降低交易频率。'} -
-
-
-
-
- - {/* 最佳/最差币种 - 独立行 */} - {(performance.best_symbol || performance.worst_symbol) && ( -
- {performance.best_symbol && ( -
-
- - - {t('bestPerformer', language)} - -
-
- {performance.best_symbol} -
- {symbolStats[performance.best_symbol] && ( -
- {symbolStats[performance.best_symbol].total_pn_l > 0 - ? '+' - : ''} - {symbolStats[performance.best_symbol].total_pn_l.toFixed(2)}{' '} - USDT {t('pnl', language)} -
- )} -
- )} - - {performance.worst_symbol && ( -
-
- - - {t('worstPerformer', language)} - -
-
- {performance.worst_symbol} -
- {symbolStats[performance.worst_symbol] && ( -
- {symbolStats[performance.worst_symbol].total_pn_l > 0 - ? '+' - : ''} - {symbolStats[performance.worst_symbol].total_pn_l.toFixed(2)}{' '} - USDT {t('pnl', language)} -
- )} -
- )} -
- )} - - {/* 币种表现 & 历史成交 - 左右分屏 2列布局 */} -
- {/* 左侧:币种表现统计表格 */} - {symbolStatsList.length > 0 && ( -
-
-

- {' '} - {stripLeadingIcons(t('symbolPerformance', language))} -

-
-
- - - - - - - - - - - - {symbolStatsList.map((stat, idx) => ( - 0 - ? '1px solid rgba(99, 102, 241, 0.1)' - : 'none', - }} - > - - - - - - - ))} - -
- Symbol - - Trades - - Win Rate - - Total P&L (USDT) - - Avg P&L (USDT) -
- - {stat.symbol} - - - {stat.total_trades} - = 50 ? '#10B981' : '#F87171', - }} - > - {(stat.win_rate || 0).toFixed(1)}% - 0 ? '#10B981' : '#F87171', - }} - > - {(stat.total_pn_l || 0) > 0 ? '+' : ''} - {(stat.total_pn_l || 0).toFixed(2)} - 0 ? '#10B981' : '#F87171', - }} - > - {(stat.avg_pn_l || 0) > 0 ? '+' : ''} - {(stat.avg_pn_l || 0).toFixed(2)} -
-
-
- )} - - {/* 右侧:历史成交记录 */} -
-
-
- -
-

- {t('tradeHistory', language)} -

-

- {performance?.recent_trades && - performance.recent_trades.length > 0 - ? t('completedTrades', language, { - count: performance.recent_trades.length, - }) - : t('completedTradesWillAppear', language)} -

-
-
-
- -
- {performance?.recent_trades && - performance.recent_trades.length > 0 ? ( - performance.recent_trades.map( - (trade: TradeOutcome, idx: number) => { - const isProfitable = trade.pn_l >= 0 - const isRecent = idx === 0 - - return ( -
-
-
- - {trade.symbol} - - - {trade.side.toUpperCase()} - - {isRecent && ( - - {t('latest', language)} - - )} -
-
- {isProfitable ? '+' : ''} - {trade.pn_l_pct.toFixed(2)}% -
-
- -
-
-
- {t('entry', language)} -
-
- {trade.open_price.toFixed(4)} -
-
-
-
- {t('exit', language)} -
-
- {trade.close_price.toFixed(4)} -
-
-
- - {/* Position Details */} -
-
-
Quantity
-
- {trade.quantity ? trade.quantity.toFixed(4) : '-'} -
-
-
-
Leverage
-
- {trade.leverage ? `${trade.leverage}x` : '-'} -
-
-
-
Position Value
-
- {trade.position_value - ? `$${trade.position_value.toFixed(2)}` - : '-'} -
-
-
-
Margin Used
-
- {trade.margin_used - ? `$${trade.margin_used.toFixed(2)}` - : '-'} -
-
-
- -
-
- P&L - - {isProfitable ? '+' : ''} - {trade.pn_l.toFixed(2)} USDT - -
-
- -
- ⏱️ {formatDuration(trade.duration)} - {trade.was_stop_loss && ( - - {t('stopLoss', language)} - - )} -
- -
- {new Date(trade.close_time).toLocaleString('en-US', { - month: 'short', - day: '2-digit', - hour: '2-digit', - minute: '2-digit', - })} -
-
- ) - } - ) - ) : ( -
-
- -
-
- {t('noCompletedTrades', language)} -
-
- )} -
-
-
- - {/* AI学习说明 - 现代化设计 */} -
-
-
- -
-
-

- {stripLeadingIcons(t('howAILearns', language))} -

-
-
- - - {t('aiLearningPoint1', language)} - -
-
- - - {t('aiLearningPoint2', language)} - -
-
- - - {t('aiLearningPoint3', language)} - -
-
- - - {t('aiLearningPoint4', language)} - -
-
-
-
-
-
- ) -} - -// 格式化持仓时长 -function formatDuration(duration: string | undefined): string { - if (!duration) return '-' - - const match = duration.match(/(\d+h)?(\d+m)?(\d+\.?\d*s)?/) - if (!match) return duration - - const hours = match[1] || '' - const minutes = match[2] || '' - const seconds = match[3] || '' - - let result = '' - if (hours) result += hours.replace('h', '小时') - if (minutes) result += minutes.replace('m', '分') - if (!hours && seconds) result += seconds.replace(/(\d+)\.?\d*s/, '$1秒') - - return result || duration -} diff --git a/web/src/components/ChartTabs.tsx b/web/src/components/ChartTabs.tsx new file mode 100644 index 00000000..29f8d98b --- /dev/null +++ b/web/src/components/ChartTabs.tsx @@ -0,0 +1,89 @@ +import { useState } from 'react' +import { EquityChart } from './EquityChart' +import { TradingViewChart } from './TradingViewChart' +import { useLanguage } from '../contexts/LanguageContext' +import { t } from '../i18n/translations' +import { BarChart3, CandlestickChart } from 'lucide-react' + +interface ChartTabsProps { + traderId: string +} + +type ChartTab = 'equity' | 'kline' + +export function ChartTabs({ traderId }: ChartTabsProps) { + const { language } = useLanguage() + const [activeTab, setActiveTab] = useState('equity') + + console.log('[ChartTabs] rendering, activeTab:', activeTab) + + return ( +
+ {/* Tab Headers - 这是Tab切换按钮区域 */} +
+ + + +
+ + {/* Tab Content */} +
+ {activeTab === 'equity' ? ( + + ) : ( + + )} +
+
+ ) +} diff --git a/web/src/components/DecisionCard.tsx b/web/src/components/DecisionCard.tsx index 96d713d1..9b4b74cf 100644 --- a/web/src/components/DecisionCard.tsx +++ b/web/src/components/DecisionCard.tsx @@ -126,6 +126,11 @@ export function DecisionCard({ decision, language }: DecisionCardProps) { background: 'rgba(14, 203, 129, 0.1)', color: '#0ECB81', } + : action.action === 'wait' || action.action === 'hold' + ? { + background: 'rgba(132, 142, 156, 0.1)', + color: '#848E9C', + } : { background: 'rgba(248, 113, 113, 0.1)', color: '#F87171', diff --git a/web/src/components/EquityChart.tsx b/web/src/components/EquityChart.tsx index f8beb1d5..13de9921 100644 --- a/web/src/components/EquityChart.tsx +++ b/web/src/components/EquityChart.tsx @@ -33,9 +33,10 @@ interface EquityPoint { interface EquityChartProps { traderId?: string + embedded?: boolean // 嵌入模式(不显示外层卡片) } -export function EquityChart({ traderId }: EquityChartProps) { +export function EquityChart({ traderId, embedded = false }: EquityChartProps) { const { language } = useLanguage() const { user, token } = useAuth() const [displayMode, setDisplayMode] = useState<'dollar' | 'percent'>('dollar') @@ -62,7 +63,7 @@ export function EquityChart({ traderId }: EquityChartProps) { if (error) { return ( -
+
-

- {t('accountEquityCurve', language)} -

+
+ {!embedded && ( +

+ {t('accountEquityCurve', language)} +

+ )}
@@ -193,16 +196,18 @@ export function EquityChart({ traderId }: EquityChartProps) { } return ( -
+
{/* Header */}
-

- {t('accountEquityCurve', language)} -

+ {!embedded && ( +

+ {t('accountEquityCurve', language)} +

+ )}
('login') const [email, setEmail] = useState('') const [password, setPassword] = useState('') @@ -236,7 +234,9 @@ export function LoginPage() {
+ + {showExchangeDropdown && ( +
+ {EXCHANGES.map((ex) => ( + + ))} +
+ )} +
+ + {/* Symbol Selector */} +
+ + + {showSymbolDropdown && ( +
+ {/* Custom Input */} +
+
+ setCustomSymbol(e.target.value.toUpperCase())} + onKeyDown={(e) => e.key === 'Enter' && handleCustomSymbolSubmit()} + placeholder={t('enterSymbol', language)} + className="flex-1 px-3 py-1.5 rounded text-sm" + style={{ + background: '#0B0E11', + border: '1px solid #2B3139', + color: '#EAECEF', + }} + /> + +
+
+ + {/* Popular Symbols */} +
+
+ {t('popularSymbols', language)} +
+
+ {POPULAR_SYMBOLS.map((sym) => ( + + ))} +
+
+
+ )} +
+ + {/* Interval Selector */} +
+ {INTERVALS.map((int) => ( + + ))} +
+ + {/* Fullscreen Toggle */} + +
+
+ + {/* Chart Container */} +
+ + {/* Click outside to close dropdowns */} + {(showExchangeDropdown || showSymbolDropdown) && ( +
{ + setShowExchangeDropdown(false) + setShowSymbolDropdown(false) + }} + /> + )} +
+ ) +} + +// 使用 memo 避免不必要的重渲染 +export const TradingViewChart = memo(TradingViewChartComponent) diff --git a/web/src/i18n/translations.ts b/web/src/i18n/translations.ts index f83cf052..e9110a9a 100644 --- a/web/src/i18n/translations.ts +++ b/web/src/i18n/translations.ts @@ -83,6 +83,13 @@ export const translations = { currentGap: 'Current Gap', count: '{count} pts', + // TradingView Chart + marketChart: 'Market Chart', + enterSymbol: 'Enter symbol...', + popularSymbols: 'Popular Symbols', + fullscreen: 'Fullscreen', + exitFullscreen: 'Exit Fullscreen', + // Backtest Page backtestPage: { title: 'Backtest Lab', @@ -264,40 +271,6 @@ export const translations = { pnl: 'P&L', pos: 'Pos', - // AI Learning - aiLearning: 'AI Learning & Reflection', - tradesAnalyzed: '{count} trades analyzed · Real-time evolution', - latestReflection: 'Latest Reflection', - fullCoT: 'Full Chain of Thought', - totalTrades: 'Total Trades', - winRate: 'Win Rate', - avgWin: 'Avg Win', - avgLoss: 'Avg Loss', - profitFactor: 'Profit Factor', - avgWinDivLoss: 'Avg Win ÷ Avg Loss', - excellent: '🔥 Excellent - Strong profitability', - good: '✓ Good - Stable profits', - fair: '⚠️ Fair - Needs optimization', - poor: '❌ Poor - Losses exceed gains', - bestPerformer: 'Best Performer', - worstPerformer: 'Worst Performer', - symbolPerformance: 'Symbol Performance', - tradeHistory: 'Trade History', - completedTrades: 'Recent {count} completed trades', - noCompletedTrades: 'No completed trades yet', - completedTradesWillAppear: 'Completed trades will appear here', - entry: 'Entry', - exit: 'Exit', - stopLoss: 'Stop Loss', - latest: 'Latest', - - // AI Learning Description - howAILearns: 'How AI Learns & Evolves', - aiLearningPoint1: 'Analyzes last 20 trading cycles before each decision', - aiLearningPoint2: 'Identifies best & worst performing symbols', - aiLearningPoint3: 'Optimizes position sizing based on win rate', - aiLearningPoint4: 'Avoids repeating past mistakes', - // AI Traders Management manageAITraders: 'Manage your AI trading bots', aiModels: 'AI Models', @@ -499,9 +472,6 @@ export const translations = { // Loading & Error loading: 'Loading...', - loadingError: '⚠️ Failed to load AI learning data', - noCompleteData: - 'No complete trading data (needs to complete open → close cycle)', // AI Traders Page - Additional inUse: 'In Use', @@ -954,7 +924,7 @@ export const translations = { // Data & Privacy faqDataStorage: 'Where is my data stored?', faqDataStorageAnswer: - 'All data is stored locally on your machine in SQLite databases: config.db (trader configurations), trading.db (trade history), and decision_logs/ (AI decision records).', + 'All data is stored locally on your machine in SQLite databases: data.db (all configurations and trade history), and decision_logs/ (AI decision records).', faqApiKeySecurity: 'Is my API key secure?', faqApiKeySecurityAnswer: @@ -1109,6 +1079,13 @@ export const translations = { currentGap: '当前差距', count: '{count} 个', + // TradingView Chart + marketChart: '行情图表', + enterSymbol: '输入币种...', + popularSymbols: '热门币种', + fullscreen: '全屏', + exitFullscreen: '退出全屏', + // Backtest Page backtestPage: { title: '回测实验室', @@ -1288,40 +1265,6 @@ export const translations = { pnl: '收益', pos: '持仓', - // AI Learning - aiLearning: 'AI学习与反思', - tradesAnalyzed: '已分析 {count} 笔交易 · 实时演化', - latestReflection: '最新反思', - fullCoT: '📋 完整思维链', - totalTrades: '总交易数', - winRate: '胜率', - avgWin: '平均盈利', - avgLoss: '平均亏损', - profitFactor: '盈亏比', - avgWinDivLoss: '平均盈利 ÷ 平均亏损', - excellent: '🔥 优秀 - 盈利能力强', - good: '✓ 良好 - 稳定盈利', - fair: '⚠️ 一般 - 需要优化', - poor: '❌ 较差 - 亏损超过盈利', - bestPerformer: '最佳表现', - worstPerformer: '最差表现', - symbolPerformance: '📊 币种表现', - tradeHistory: '历史成交', - completedTrades: '最近 {count} 笔已完成交易', - noCompletedTrades: '暂无完成的交易', - completedTradesWillAppear: '已完成的交易将显示在这里', - entry: '入场', - exit: '出场', - stopLoss: '止损', - latest: '最新', - - // AI Learning Description - howAILearns: '💡 AI如何学习和进化', - aiLearningPoint1: '每次决策前分析最近20个交易周期', - aiLearningPoint2: '识别表现最好和最差的币种', - aiLearningPoint3: '根据胜率优化仓位大小', - aiLearningPoint4: '避免重复过去的错误', - // AI Traders Management manageAITraders: '管理您的AI交易机器人', aiModels: 'AI模型', @@ -1512,8 +1455,6 @@ export const translations = { // Loading & Error loading: '加载中...', - loadingError: '⚠️ 加载AI学习数据失败', - noCompleteData: '暂无完整交易数据(需要完成开仓→平仓的完整周期)', // AI Traders Page - Additional inUse: '正在使用', @@ -1927,7 +1868,7 @@ export const translations = { // Data & Privacy faqDataStorage: '我的数据存储在哪里?', faqDataStorageAnswer: - '所有数据都本地存储在您的机器上,使用 SQLite 数据库:config.db(交易员配置)、trading.db(交易历史)、decision_logs/(AI 决策记录)。', + '所有数据都本地存储在您的机器上,使用 SQLite 数据库:data.db(所有配置和交易历史)、decision_logs/(AI 决策记录)。', faqApiKeySecurity: 'API 密钥安全吗?', faqApiKeySecurityAnswer: diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts index d2d9d8bf..29bad0ee 100644 --- a/web/src/lib/api.ts +++ b/web/src/lib/api.ts @@ -337,16 +337,6 @@ export const api = { return result.data! }, - // 获取AI学习表现分析(支持trader_id) - async getPerformance(traderId?: string): Promise { - const url = traderId - ? `${API_BASE}/performance?trader_id=${traderId}` - : `${API_BASE}/performance` - const result = await httpClient.get(url) - if (!result.success) throw new Error('获取AI学习数据失败') - return result.data! - }, - // 获取竞赛数据(无需认证) async getCompetition(): Promise { const result = await httpClient.get( diff --git a/web/src/pages/TraderDashboard.tsx b/web/src/pages/TraderDashboard.tsx index d66a141b..83d5a187 100644 --- a/web/src/pages/TraderDashboard.tsx +++ b/web/src/pages/TraderDashboard.tsx @@ -2,8 +2,7 @@ import { useEffect, useState } from 'react' import { useNavigate, useSearchParams } from 'react-router-dom' import useSWR from 'swr' import { api } from '../lib/api' -import { EquityChart } from '../components/EquityChart' -import AILearning from '../components/AILearning' +import { ChartTabs } from '../components/ChartTabs' import { useLanguage } from '../contexts/LanguageContext' import { useAuth } from '../contexts/AuthContext' import { t, type Language } from '../i18n/translations' @@ -419,9 +418,9 @@ export default function TraderDashboard() {
{/* 左侧:图表 + 持仓 */}
- {/* Equity Chart */} + {/* Chart Tabs (Equity / K-line) */}
- +
{/* Current Positions */} @@ -669,10 +668,6 @@ export default function TraderDashboard() {
- {/* AI Learning & Performance Analysis */} -
- -
) }