diff --git a/.env.example b/.env.example index 50ad92dd..dc269f1b 100644 --- a/.env.example +++ b/.env.example @@ -13,6 +13,9 @@ REDIS_HOST=redis REDIS_PORT=6379 REDIS_PASSWORD=redis123456 +# 数据加密密钥 +DATA_ENCRYPTION_KEY=my_secret_encryption_key + # Ports Configuration # Backend API server port (internal: 8080, external: configurable) NOFX_BACKEND_PORT=8080 diff --git a/.gitignore b/.gitignore index 9f3bdd5d..a5c1c3c3 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,11 @@ config.db certs/ beta_codes.txt +# 密钥文件 +keys/ +*.key +*.pem + # 决策日志 decision_logs/ coin_pool_cache/ diff --git a/api/server.go b/api/server.go index a10a39f6..ca06904c 100644 --- a/api/server.go +++ b/api/server.go @@ -8,6 +8,7 @@ import ( "net/http" "nofx/auth" "nofx/config" + "nofx/crypto" "nofx/decision" "nofx/manager" "nofx/trader" @@ -24,11 +25,12 @@ type Server struct { router *gin.Engine traderManager *manager.TraderManager database config.DatabaseInterface + cryptoService *crypto.CryptoService port int } // NewServer 创建API服务器 -func NewServer(traderManager *manager.TraderManager, database config.DatabaseInterface, port int) *Server { +func NewServer(traderManager *manager.TraderManager, database config.DatabaseInterface, cryptoService *crypto.CryptoService, port int) *Server { // 设置为Release模式(减少日志输出) gin.SetMode(gin.ReleaseMode) @@ -37,10 +39,17 @@ func NewServer(traderManager *manager.TraderManager, database config.DatabaseInt // 启用CORS router.Use(corsMiddleware()) + if cryptoService == nil { + log.Printf("⚠️ 加密服务未初始化,敏感数据加解密功能不可用") + } else { + database.SetCryptoService(cryptoService) + } + s := &Server{ router: router, traderManager: traderManager, database: database, + cryptoService: cryptoService, port: port, } @@ -123,6 +132,7 @@ func (s *Server) setupRoutes() { // 交易所配置 protected.GET("/exchanges", s.handleGetExchangeConfigs) protected.PUT("/exchanges", s.handleUpdateExchangeConfigs) + protected.PUT("/exchanges/encrypted", s.handleUpdateExchangeConfigsEncrypted) // 用户信号源配置 protected.GET("/user/signal-sources", s.handleGetUserSignalSource) @@ -179,11 +189,19 @@ func (s *Server) handleGetSystemConfig(c *gin.Context) { betaModeStr, _ := s.database.GetSystemConfig("beta_mode") betaMode := betaModeStr == "true" + // 获取RSA公钥 + var rsaPublicKey string + if s.cryptoService != nil { + rsaPublicKey = s.cryptoService.GetPublicKeyPEM() + } + c.JSON(http.StatusOK, gin.H{ "beta_mode": betaMode, "default_coins": defaultCoins, "btc_eth_leverage": btcEthLeverage, "altcoin_leverage": altcoinLeverage, + "rsa_public_key": rsaPublicKey, + "rsa_key_id": "rsa-key-2025-11-05", }) } @@ -1638,8 +1656,10 @@ func (s *Server) handleCompleteRegistration(c *gin.Context) { // handleLogin 处理用户登录请求 func (s *Server) handleLogin(c *gin.Context) { var req struct { - Email string `json:"email" binding:"required,email"` - Password string `json:"password" binding:"required"` + Email string `json:"email"` + EmailEncrypted *crypto.EncryptedPayload `json:"email_encrypted"` + Password string `json:"password"` + PasswordEncrypted *crypto.EncryptedPayload `json:"password_encrypted"` } if err := c.ShouldBindJSON(&req); err != nil { @@ -1647,6 +1667,51 @@ func (s *Server) handleLogin(c *gin.Context) { return } + if req.EmailEncrypted != nil { + if s.cryptoService == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "加密服务不可用"}) + return + } + + decryptedEmail, err := s.cryptoService.DecryptSensitiveData(req.EmailEncrypted) + if err != nil { + log.Printf("❌ 登录邮箱解密失败: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "邮箱解密失败"}) + return + } + req.Email = decryptedEmail + } + + if req.PasswordEncrypted != nil { + if s.cryptoService == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "加密服务不可用"}) + return + } + + decryptedPassword, err := s.cryptoService.DecryptSensitiveData(req.PasswordEncrypted) + if err != nil { + log.Printf("❌ 登录密码解密失败: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "密码解密失败"}) + return + } + req.Password = decryptedPassword + } + + req.Email = strings.TrimSpace(req.Email) + if req.Email == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "邮箱不能为空"}) + return + } + if !strings.Contains(req.Email, "@") { + c.JSON(http.StatusBadRequest, gin.H{"error": "邮箱格式错误"}) + return + } + + if strings.TrimSpace(req.Password) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "密码不能为空"}) + return + } + // 获取用户信息 user, err := s.database.GetUserByEmail(req.Email) if err != nil { @@ -2026,3 +2091,64 @@ func (s *Server) handleGetPublicTraderConfig(c *gin.Context) { c.JSON(http.StatusOK, result) } + +// handleUpdateExchangeConfigsEncrypted 更新交易所配置(加密传输) +func (s *Server) handleUpdateExchangeConfigsEncrypted(c *gin.Context) { + if s.cryptoService == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "加密服务不可用"}) + return + } + + userID := c.GetString("user_id") + + // 接收加密载荷 + var payload crypto.EncryptedPayload + if err := c.ShouldBindJSON(&payload); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 解密数据 + decryptedData, err := s.cryptoService.DecryptSensitiveData(&payload) + if err != nil { + log.Printf("❌ 解密失败: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "解密失败"}) + return + } + + // 解析解密后的数据 + var req UpdateExchangeConfigRequest + if err := json.Unmarshal([]byte(decryptedData), &req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "数据格式错误"}) + return + } + + // 更新每个交易所的配置 + for exchangeID, exchangeData := range req.Exchanges { + err := s.database.UpdateExchange( + userID, + exchangeID, + exchangeData.Enabled, + exchangeData.APIKey, + exchangeData.SecretKey, + exchangeData.Testnet, + exchangeData.HyperliquidWalletAddr, + exchangeData.AsterUser, + exchangeData.AsterSigner, + exchangeData.AsterPrivateKey, + ) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新交易所 %s 失败: %v", exchangeID, err)}) + return + } + } + + // 重新加载该用户的所有交易员,使新配置立即生效 + err = s.traderManager.LoadUserTraders(s.database, userID) + if err != nil { + log.Printf("⚠️ 重新加载用户交易员到内存失败: %v", err) + } + + log.Printf("✓ 交易所配置已通过加密方式更新") + c.JSON(http.StatusOK, gin.H{"message": "交易所配置已更新"}) +} diff --git a/config/database.go b/config/database.go index 51876587..1e6e1504 100644 --- a/config/database.go +++ b/config/database.go @@ -1,126 +1,130 @@ package config import ( - "fmt" - "time" + "fmt" + "time" + + "nofx/crypto" ) // DatabaseInterface 定义了数据库实现需要提供的方法集合 type DatabaseInterface interface { - CreateUser(user *User) error - GetUserByEmail(email string) (*User, error) - GetUserByID(userID string) (*User, error) - GetAllUsers() ([]string, error) - UpdateUserOTPVerified(userID string, verified bool) error - GetAIModels(userID string) ([]*AIModelConfig, error) - UpdateAIModel(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error - GetExchanges(userID string) ([]*ExchangeConfig, error) - UpdateExchange(userID, id string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error - CreateAIModel(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error - CreateExchange(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error - CreateTrader(trader *TraderRecord) error - GetTraders(userID string) ([]*TraderRecord, error) - UpdateTraderStatus(userID, id string, isRunning bool) error - UpdateTrader(trader *TraderRecord) error - UpdateTraderInitialBalance(userID, id string, newBalance float64) error - UpdateTraderCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error - DeleteTrader(userID, id string) error - GetTraderConfig(userID, traderID string) (*TraderRecord, *AIModelConfig, *ExchangeConfig, error) - GetSystemConfig(key string) (string, error) - SetSystemConfig(key, value string) error - CreateUserSignalSource(userID, coinPoolURL, oiTopURL string) error - GetUserSignalSource(userID string) (*UserSignalSource, error) - UpdateUserSignalSource(userID, coinPoolURL, oiTopURL string) error - GetCustomCoins() []string - LoadBetaCodesFromFile(filePath string) error - ValidateBetaCode(code string) (bool, error) - UseBetaCode(code, userEmail string) error - GetBetaCodeStats() (total, used int, err error) - Close() error + SetCryptoService(cs *crypto.CryptoService) + CreateUser(user *User) error + GetUserByEmail(email string) (*User, error) + GetUserByID(userID string) (*User, error) + GetAllUsers() ([]string, error) + UpdateUserOTPVerified(userID string, verified bool) error + GetAIModels(userID string) ([]*AIModelConfig, error) + UpdateAIModel(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error + GetExchanges(userID string) ([]*ExchangeConfig, error) + UpdateExchange(userID, id string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error + CreateAIModel(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error + CreateExchange(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error + CreateTrader(trader *TraderRecord) error + GetTraders(userID string) ([]*TraderRecord, error) + UpdateTraderStatus(userID, id string, isRunning bool) error + UpdateTrader(trader *TraderRecord) error + UpdateTraderInitialBalance(userID, id string, newBalance float64) error + UpdateTraderCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error + DeleteTrader(userID, id string) error + GetTraderConfig(userID, traderID string) (*TraderRecord, *AIModelConfig, *ExchangeConfig, error) + GetSystemConfig(key string) (string, error) + SetSystemConfig(key, value string) error + CreateUserSignalSource(userID, coinPoolURL, oiTopURL string) error + GetUserSignalSource(userID string) (*UserSignalSource, error) + UpdateUserSignalSource(userID, coinPoolURL, oiTopURL string) error + GetCustomCoins() []string + LoadBetaCodesFromFile(filePath string) error + ValidateBetaCode(code string) (bool, error) + UseBetaCode(code, userEmail string) error + GetBetaCodeStats() (total, used int, err error) + Close() error } // User 用户配置 type User struct { - ID string `json:"id"` - Email string `json:"email"` - PasswordHash string `json:"-"` // 不返回到前端 - OTPSecret string `json:"-"` // 不返回到前端 - OTPVerified bool `json:"otp_verified"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + Email string `json:"email"` + PasswordHash string `json:"-"` + OTPSecret string `json:"-"` + OTPVerified bool `json:"otp_verified"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // AIModelConfig AI模型配置 type AIModelConfig struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - Provider string `json:"provider"` - Enabled bool `json:"enabled"` - APIKey string `json:"apiKey"` - CustomAPIURL string `json:"customApiUrl"` - CustomModelName string `json:"customModelName"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + UserID string `json:"user_id"` + Name string `json:"name"` + Provider string `json:"provider"` + Enabled bool `json:"enabled"` + APIKey string `json:"apiKey"` + CustomAPIURL string `json:"customApiUrl"` + CustomModelName string `json:"customModelName"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // ExchangeConfig 交易所配置 type ExchangeConfig struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - Type string `json:"type"` - Enabled bool `json:"enabled"` - APIKey string `json:"apiKey"` - SecretKey string `json:"secretKey"` - Testnet bool `json:"testnet"` - HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` - AsterUser string `json:"asterUser"` - AsterSigner string `json:"asterSigner"` - AsterPrivateKey string `json:"asterPrivateKey"` - Deleted bool `json:"deleted"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + UserID string `json:"user_id"` + Name string `json:"name"` + Type string `json:"type"` + Enabled bool `json:"enabled"` + APIKey string `json:"apiKey"` + SecretKey string `json:"secretKey"` + Testnet bool `json:"testnet"` + HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` + AsterUser string `json:"asterUser"` + AsterSigner string `json:"asterSigner"` + AsterPrivateKey string `json:"asterPrivateKey"` + DEXWalletPrivateKey string `json:"dexWalletPrivateKey"` // 统一的DEX私钥字段 + Deleted bool `json:"deleted"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } -// TraderRecord 交易员配置(数据库实体) +// TraderRecord 交易员配置 type TraderRecord struct { - ID string `json:"id"` - UserID string `json:"user_id"` - Name string `json:"name"` - AIModelID string `json:"ai_model_id"` - ExchangeID string `json:"exchange_id"` - InitialBalance float64 `json:"initial_balance"` - ScanIntervalMinutes int `json:"scan_interval_minutes"` - IsRunning bool `json:"is_running"` - BTCETHLeverage int `json:"btc_eth_leverage"` - AltcoinLeverage int `json:"altcoin_leverage"` - TradingSymbols string `json:"trading_symbols"` - UseCoinPool bool `json:"use_coin_pool"` - UseOITop bool `json:"use_oi_top"` - CustomPrompt string `json:"custom_prompt"` - OverrideBasePrompt bool `json:"override_base_prompt"` - SystemPromptTemplate string `json:"system_prompt_template"` - IsCrossMargin bool `json:"is_cross_margin"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID string `json:"id"` + UserID string `json:"user_id"` + Name string `json:"name"` + AIModelID string `json:"ai_model_id"` + ExchangeID string `json:"exchange_id"` + InitialBalance float64 `json:"initial_balance"` + ScanIntervalMinutes int `json:"scan_interval_minutes"` + IsRunning bool `json:"is_running"` + BTCETHLeverage int `json:"btc_eth_leverage"` + AltcoinLeverage int `json:"altcoin_leverage"` + TradingSymbols string `json:"trading_symbols"` + UseCoinPool bool `json:"use_coin_pool"` + UseOITop bool `json:"use_oi_top"` + CustomPrompt string `json:"custom_prompt"` + OverrideBasePrompt bool `json:"override_base_prompt"` + SystemPromptTemplate string `json:"system_prompt_template"` + IsCrossMargin bool `json:"is_cross_margin"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // UserSignalSource 用户信号源配置 type UserSignalSource struct { - ID int `json:"id"` - UserID string `json:"user_id"` - CoinPoolURL string `json:"coin_pool_url"` - OITopURL string `json:"oi_top_url"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID int `json:"id"` + UserID string `json:"user_id"` + CoinPoolURL string `json:"coin_pool_url"` + OITopURL string `json:"oi_top_url"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // NewDatabase 创建数据库连接(仅支持 PostgreSQL) func NewDatabase() (DatabaseInterface, error) { - pgDB, err := NewPostgreSQLDatabase() - if err != nil { - return nil, fmt.Errorf("创建PostgreSQL数据库失败: %w", err) - } - return pgDB, nil + pgDB, err := NewPostgreSQLDatabase() + if err != nil { + return nil, fmt.Errorf("创建PostgreSQL数据库失败: %w", err) + } + return pgDB, nil } diff --git a/config/database_pg.go b/config/database_pg.go index a7da471e..1acee98f 100644 --- a/config/database_pg.go +++ b/config/database_pg.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "log" + "nofx/crypto" "nofx/market" "os" "slices" @@ -16,7 +17,8 @@ import ( // PostgreSQLDatabase PostgreSQL数据库配置 type PostgreSQLDatabase struct { - db *sql.DB + db *sql.DB + cryptoService *crypto.CryptoService } // NewPostgreSQLDatabase 创建PostgreSQL数据库连接 @@ -60,6 +62,42 @@ func NewPostgreSQLDatabase() (*PostgreSQLDatabase, error) { return database, nil } +func (d *PostgreSQLDatabase) SetCryptoService(cs *crypto.CryptoService) { + d.cryptoService = cs +} + +func (d *PostgreSQLDatabase) encryptValue(value string, aadParts ...string) (string, error) { + if value == "" { + return "", nil + } + if d.cryptoService == nil { + return "", fmt.Errorf("crypto service not initialized") + } + if !d.cryptoService.HasDataKey() { + return "", fmt.Errorf("data encryption key not configured") + } + if d.cryptoService.IsEncryptedStorageValue(value) { + return value, nil + } + return d.cryptoService.EncryptForStorage(value, aadParts...) +} + +func (d *PostgreSQLDatabase) decryptValue(value string, aadParts ...string) (string, error) { + if value == "" { + return "", nil + } + if d.cryptoService == nil { + return "", fmt.Errorf("crypto service not initialized") + } + if !d.cryptoService.HasDataKey() { + return "", fmt.Errorf("data encryption key not configured") + } + if !d.cryptoService.IsEncryptedStorageValue(value) { + return "", fmt.Errorf("value is not encrypted") + } + return d.cryptoService.DecryptFromStorage(value, aadParts...) +} + // getEnv 获取环境变量,如果不存在返回默认值 func getEnv(key, defaultValue string) string { if value := os.Getenv(key); value != "" { @@ -162,6 +200,15 @@ func (d *PostgreSQLDatabase) GetAIModels(userID string) ([]*AIModelConfig, error if err != nil { return nil, err } + + if model.APIKey != "" { + decrypted, err := d.decryptValue(model.APIKey, model.UserID, model.ID, "api_key") + if err != nil { + return nil, err + } + model.APIKey = decrypted + } + models = append(models, &model) } @@ -216,7 +263,7 @@ func (d *PostgreSQLDatabase) UpdateAIModel(userID, id string, enabled bool, apiK log.Printf("🗑️ UpdateAIModel: 已标记删除用户 %s 的模型配置 %s (通过provider匹配)", userID, existingID) return nil } - + // 没有找到配置,返回成功(幂等性) log.Printf("ℹ️ UpdateAIModel: 模型配置不存在,跳过删除: %s", id) return nil @@ -229,11 +276,18 @@ func (d *PostgreSQLDatabase) UpdateAIModel(userID, id string, enabled bool, apiK `, userID, id).Scan(&existingID) if err == nil { + apiKeyEnc, err := d.encryptValue(apiKey, userID, existingID, "api_key") + if err != nil { + return err + } // 找到了现有配置(精确匹配 ID),更新它 _, err = d.db.Exec(` UPDATE ai_models SET enabled = $1, api_key = $2, custom_api_url = $3, custom_model_name = $4, deleted = FALSE, updated_at = CURRENT_TIMESTAMP WHERE id = $5 AND user_id = $6 - `, enabled, apiKey, customAPIURL, customModelName, existingID, userID) + `, enabled, apiKeyEnc, customAPIURL, customModelName, existingID, userID) + return err + } + if err != sql.ErrNoRows { return err } @@ -244,12 +298,19 @@ func (d *PostgreSQLDatabase) UpdateAIModel(userID, id string, enabled bool, apiK `, userID, provider).Scan(&existingID) if err == nil { + apiKeyEnc, err := d.encryptValue(apiKey, userID, existingID, "api_key") + if err != nil { + return err + } // 找到了现有配置(通过 provider 匹配,兼容旧版),更新它 log.Printf("⚠️ 使用旧版 provider 匹配更新模型: %s -> %s", provider, existingID) _, err = d.db.Exec(` UPDATE ai_models SET enabled = $1, api_key = $2, custom_api_url = $3, custom_model_name = $4, deleted = FALSE, updated_at = CURRENT_TIMESTAMP WHERE id = $5 AND user_id = $6 - `, enabled, apiKey, customAPIURL, customModelName, existingID, userID) + `, enabled, apiKeyEnc, customAPIURL, customModelName, existingID, userID) + return err + } + if err != sql.ErrNoRows { return err } @@ -292,11 +353,16 @@ func (d *PostgreSQLDatabase) UpdateAIModel(userID, id string, enabled bool, apiK newModelID = fmt.Sprintf("%s_%s", userID, provider) } + apiKeyEnc, err := d.encryptValue(apiKey, userID, newModelID, "api_key") + if err != nil { + return err + } + log.Printf("✓ 创建新的 AI 模型配置: ID=%s, Provider=%s, Name=%s", newModelID, provider, name) _, err = d.db.Exec(` INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, custom_model_name, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - `, newModelID, userID, name, provider, enabled, apiKey, customAPIURL, customModelName) + `, newModelID, userID, name, provider, enabled, apiKeyEnc, customAPIURL, customModelName) return err } @@ -309,6 +375,7 @@ func (d *PostgreSQLDatabase) GetExchanges(userID string) ([]*ExchangeConfig, err COALESCE(aster_user, '') AS aster_user, COALESCE(aster_signer, '') AS aster_signer, COALESCE(aster_private_key, '') AS aster_private_key, + COALESCE(dex_wallet_private_key, '') AS dex_wallet_private_key, COALESCE(deleted, FALSE) AS deleted, created_at, updated_at FROM exchanges @@ -329,12 +396,50 @@ func (d *PostgreSQLDatabase) GetExchanges(userID string) ([]*ExchangeConfig, err &exchange.Enabled, &exchange.APIKey, &exchange.SecretKey, &exchange.Testnet, &exchange.HyperliquidWalletAddr, &exchange.AsterUser, &exchange.AsterSigner, &exchange.AsterPrivateKey, + &exchange.DEXWalletPrivateKey, &exchange.Deleted, &exchange.CreatedAt, &exchange.UpdatedAt, ) if err != nil { return nil, err } + + if decrypted, err := d.decryptValue(exchange.APIKey, exchange.UserID, exchange.ID, "api_key"); err == nil { + exchange.APIKey = decrypted + } else { + return nil, err + } + if decrypted, err := d.decryptValue(exchange.SecretKey, exchange.UserID, exchange.ID, "secret_key"); err == nil { + exchange.SecretKey = decrypted + } else { + return nil, err + } + if decrypted, err := d.decryptValue(exchange.HyperliquidWalletAddr, exchange.UserID, exchange.ID, "hyperliquid_wallet_addr"); err == nil { + exchange.HyperliquidWalletAddr = decrypted + } else { + return nil, err + } + if decrypted, err := d.decryptValue(exchange.AsterUser, exchange.UserID, exchange.ID, "aster_user"); err == nil { + exchange.AsterUser = decrypted + } else { + return nil, err + } + if decrypted, err := d.decryptValue(exchange.AsterSigner, exchange.UserID, exchange.ID, "aster_signer"); err == nil { + exchange.AsterSigner = decrypted + } else { + return nil, err + } + if decrypted, err := d.decryptValue(exchange.AsterPrivateKey, exchange.UserID, exchange.ID, "aster_private_key"); err == nil { + exchange.AsterPrivateKey = decrypted + } else { + return nil, err + } + if decrypted, err := d.decryptValue(exchange.DEXWalletPrivateKey, exchange.UserID, exchange.ID, "dex_wallet_private_key"); err == nil { + exchange.DEXWalletPrivateKey = decrypted + } else { + return nil, err + } + exchanges = append(exchanges, &exchange) } @@ -345,7 +450,7 @@ func (d *PostgreSQLDatabase) GetExchanges(userID string) ([]*ExchangeConfig, err func (d *PostgreSQLDatabase) UpdateExchange(userID, id string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error { log.Printf("🔧 UpdateExchange: userID=%s, id=%s, enabled=%v", userID, id, enabled) - // 如果请求禁用该交易所,标记为已删除 + // 如果请求禁用该交易所,执行软删除 if !enabled { _, err := d.db.Exec(` UPDATE exchanges @@ -369,13 +474,38 @@ func (d *PostgreSQLDatabase) UpdateExchange(userID, id string, enabled bool, api return nil } + apiKeyEnc, err := d.encryptValue(apiKey, userID, id, "api_key") + if err != nil { + return fmt.Errorf("encrypt api_key failed: %w", err) + } + secretKeyEnc, err := d.encryptValue(secretKey, userID, id, "secret_key") + if err != nil { + return fmt.Errorf("encrypt secret_key failed: %w", err) + } + hyperAddrEnc, err := d.encryptValue(hyperliquidWalletAddr, userID, id, "hyperliquid_wallet_addr") + if err != nil { + return fmt.Errorf("encrypt hyperliquid_wallet_addr failed: %w", err) + } + asterUserEnc, err := d.encryptValue(asterUser, userID, id, "aster_user") + if err != nil { + return fmt.Errorf("encrypt aster_user failed: %w", err) + } + asterSignerEnc, err := d.encryptValue(asterSigner, userID, id, "aster_signer") + if err != nil { + return fmt.Errorf("encrypt aster_signer failed: %w", err) + } + asterPrivateKeyEnc, err := d.encryptValue(asterPrivateKey, userID, id, "aster_private_key") + if err != nil { + return fmt.Errorf("encrypt aster_private_key failed: %w", err) + } + // 首先尝试更新现有的用户配置 result, err := d.db.Exec(` UPDATE exchanges SET enabled = $1, api_key = $2, secret_key = $3, testnet = $4, hyperliquid_wallet_addr = $5, aster_user = $6, aster_signer = $7, aster_private_key = $8, deleted = FALSE, updated_at = CURRENT_TIMESTAMP WHERE id = $9 AND user_id = $10 - `, enabled, apiKey, secretKey, testnet, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, id, userID) + `, enabled, apiKeyEnc, secretKeyEnc, testnet, hyperAddrEnc, asterUserEnc, asterSignerEnc, asterPrivateKeyEnc, id, userID) if err != nil { log.Printf("❌ UpdateExchange: 更新失败: %v", err) return err @@ -418,7 +548,7 @@ func (d *PostgreSQLDatabase) UpdateExchange(userID, id string, enabled bool, api hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key, deleted, created_at, updated_at) VALUES ($1, $2, $3, $4, TRUE, $5, $6, $7, $8, $9, $10, $11, FALSE, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - `, id, userID, name, typ, apiKey, secretKey, testnet, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey) + `, id, userID, name, typ, apiKeyEnc, secretKeyEnc, testnet, hyperAddrEnc, asterUserEnc, asterSignerEnc, asterPrivateKeyEnc) if err != nil { log.Printf("❌ UpdateExchange: 创建记录失败: %v", err) @@ -434,21 +564,51 @@ func (d *PostgreSQLDatabase) UpdateExchange(userID, id string, enabled bool, api // CreateAIModel 创建AI模型配置 func (d *PostgreSQLDatabase) CreateAIModel(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error { - _, err := d.db.Exec(` + apiKeyEnc, err := d.encryptValue(apiKey, userID, id, "api_key") + if err != nil { + return err + } + + _, err = d.db.Exec(` INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (id) DO NOTHING - `, id, userID, name, provider, enabled, apiKey, customAPIURL) + `, id, userID, name, provider, enabled, apiKeyEnc, customAPIURL) return err } // CreateExchange 创建交易所配置 func (d *PostgreSQLDatabase) CreateExchange(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error { - _, err := d.db.Exec(` + apiKeyEnc, err := d.encryptValue(apiKey, userID, id, "api_key") + if err != nil { + return fmt.Errorf("encrypt api_key failed: %w", err) + } + secretKeyEnc, err := d.encryptValue(secretKey, userID, id, "secret_key") + if err != nil { + return fmt.Errorf("encrypt secret_key failed: %w", err) + } + hyperAddrEnc, err := d.encryptValue(hyperliquidWalletAddr, userID, id, "hyperliquid_wallet_addr") + if err != nil { + return fmt.Errorf("encrypt hyperliquid_wallet_addr failed: %w", err) + } + asterUserEnc, err := d.encryptValue(asterUser, userID, id, "aster_user") + if err != nil { + return fmt.Errorf("encrypt aster_user failed: %w", err) + } + asterSignerEnc, err := d.encryptValue(asterSigner, userID, id, "aster_signer") + if err != nil { + return fmt.Errorf("encrypt aster_signer failed: %w", err) + } + asterPrivateKeyEnc, err := d.encryptValue(asterPrivateKey, userID, id, "aster_private_key") + if err != nil { + return fmt.Errorf("encrypt aster_private_key failed: %w", err) + } + + _, err = d.db.Exec(` INSERT INTO exchanges (id, user_id, name, type, enabled, api_key, secret_key, testnet, hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) ON CONFLICT (id, user_id) DO NOTHING - `, id, userID, name, typ, enabled, apiKey, secretKey, testnet, hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey) + `, id, userID, name, typ, enabled, apiKeyEnc, secretKeyEnc, testnet, hyperAddrEnc, asterUserEnc, asterSignerEnc, asterPrivateKeyEnc) return err } @@ -575,6 +735,57 @@ func (d *PostgreSQLDatabase) GetTraderConfig(userID, traderID string) (*TraderRe return nil, nil, nil, err } + if aiModel.APIKey != "" { + decrypted, err := d.decryptValue(aiModel.APIKey, aiModel.UserID, aiModel.ID, "api_key") + if err != nil { + return nil, nil, nil, err + } + aiModel.APIKey = decrypted + } + + if exchange.APIKey != "" { + decrypted, err := d.decryptValue(exchange.APIKey, exchange.UserID, exchange.ID, "api_key") + if err != nil { + return nil, nil, nil, err + } + exchange.APIKey = decrypted + } + if exchange.SecretKey != "" { + decrypted, err := d.decryptValue(exchange.SecretKey, exchange.UserID, exchange.ID, "secret_key") + if err != nil { + return nil, nil, nil, err + } + exchange.SecretKey = decrypted + } + if exchange.HyperliquidWalletAddr != "" { + decrypted, err := d.decryptValue(exchange.HyperliquidWalletAddr, exchange.UserID, exchange.ID, "hyperliquid_wallet_addr") + if err != nil { + return nil, nil, nil, err + } + exchange.HyperliquidWalletAddr = decrypted + } + if exchange.AsterUser != "" { + decrypted, err := d.decryptValue(exchange.AsterUser, exchange.UserID, exchange.ID, "aster_user") + if err != nil { + return nil, nil, nil, err + } + exchange.AsterUser = decrypted + } + if exchange.AsterSigner != "" { + decrypted, err := d.decryptValue(exchange.AsterSigner, exchange.UserID, exchange.ID, "aster_signer") + if err != nil { + return nil, nil, nil, err + } + exchange.AsterSigner = decrypted + } + if exchange.AsterPrivateKey != "" { + decrypted, err := d.decryptValue(exchange.AsterPrivateKey, exchange.UserID, exchange.ID, "aster_private_key") + if err != nil { + return nil, nil, nil, err + } + exchange.AsterPrivateKey = decrypted + } + return &trader, &aiModel, &exchange, nil } diff --git a/crypto/crypto.go b/crypto/crypto.go new file mode 100644 index 00000000..9a29480f --- /dev/null +++ b/crypto/crypto.go @@ -0,0 +1,394 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + "time" +) + +const ( + storagePrefix = "ENC:v1:" + storageDelimiter = ":" + dataKeyEnvName = "DATA_ENCRYPTION_KEY" +) + +type EncryptedPayload struct { + WrappedKey string `json:"wrappedKey"` + IV string `json:"iv"` + Ciphertext string `json:"ciphertext"` + AAD string `json:"aad,omitempty"` + KID string `json:"kid,omitempty"` + TS int64 `json:"ts,omitempty"` +} + +type AADData struct { + UserID string `json:"userId"` + SessionID string `json:"sessionId"` + TS int64 `json:"ts"` + Purpose string `json:"purpose"` +} + +type CryptoService struct { + privateKey *rsa.PrivateKey + publicKey *rsa.PublicKey + dataKey []byte +} + +func NewCryptoService(privateKeyPath string) (*CryptoService, error) { + // 读取私钥文件 + privateKeyPEM, err := ioutil.ReadFile(privateKeyPath) + if err != nil { + // 如果私钥文件不存在,生成新的密钥对 + if err := GenerateRSAKeyPair(privateKeyPath); err != nil { + return nil, fmt.Errorf("failed to generate RSA key pair: %w", err) + } + privateKeyPEM, err = ioutil.ReadFile(privateKeyPath) + if err != nil { + return nil, fmt.Errorf("failed to read generated private key: %w", err) + } + } + + // 解析私钥 + privateKey, err := ParseRSAPrivateKeyFromPEM(privateKeyPEM) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + dataKey, err := loadDataKeyFromEnv() + if err != nil { + return nil, fmt.Errorf("failed to load data encryption key: %w", err) + } + + return &CryptoService{ + privateKey: privateKey, + publicKey: &privateKey.PublicKey, + dataKey: dataKey, + }, nil +} + +func GenerateRSAKeyPair(privateKeyPath string) error { + // 确保目录存在 + dir := filepath.Dir(privateKeyPath) + if dir != "." { + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create directory %s: %w", dir, err) + } + } + + // 生成 RSA 密钥对 + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + + // 编码私钥 + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + // 保存私钥 + if err := ioutil.WriteFile(privateKeyPath, privateKeyPEM, 0600); err != nil { + return err + } + + // 编码公钥 + publicKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) + if err != nil { + return err + } + + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyDER, + }) + + // 保存公钥 + publicKeyPath := privateKeyPath + ".pub" + if err := ioutil.WriteFile(publicKeyPath, publicKeyPEM, 0644); err != nil { + return err + } + + return nil +} + +func ParseRSAPrivateKeyFromPEM(pemBytes []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, errors.New("no PEM block found") + } + + switch block.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(block.Bytes) + case "PRIVATE KEY": + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + rsaKey, ok := key.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("not an RSA key") + } + return rsaKey, nil + default: + return nil, errors.New("unsupported key type: " + block.Type) + } +} + +func loadDataKeyFromEnv() ([]byte, error) { + keyStr := strings.TrimSpace(os.Getenv(dataKeyEnvName)) + if keyStr == "" { + return nil, fmt.Errorf("%s not set", dataKeyEnvName) + } + + if key, ok := decodePossibleKey(keyStr); ok { + return key, nil + } + + sum := sha256.Sum256([]byte(keyStr)) + key := make([]byte, len(sum)) + copy(key, sum[:]) + return key, nil +} + +func decodePossibleKey(value string) ([]byte, bool) { + decoders := []func(string) ([]byte, error){ + base64.StdEncoding.DecodeString, + base64.RawStdEncoding.DecodeString, + func(s string) ([]byte, error) { return hex.DecodeString(s) }, + } + + for _, decoder := range decoders { + if decoded, err := decoder(value); err == nil { + if key, ok := normalizeAESKey(decoded); ok { + return key, true + } + } + } + + return nil, false +} + +func normalizeAESKey(raw []byte) ([]byte, bool) { + switch len(raw) { + case 16, 24, 32: + return raw, true + case 0: + return nil, false + default: + sum := sha256.Sum256(raw) + key := make([]byte, len(sum)) + copy(key, sum[:]) + return key, true + } +} + +func (cs *CryptoService) HasDataKey() bool { + return len(cs.dataKey) > 0 +} + +func (cs *CryptoService) GetPublicKeyPEM() string { + publicKeyDER, err := x509.MarshalPKIXPublicKey(cs.publicKey) + if err != nil { + return "" + } + + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyDER, + }) + + return string(publicKeyPEM) +} + +func (cs *CryptoService) EncryptForStorage(plaintext string, aadParts ...string) (string, error) { + if plaintext == "" { + return "", nil + } + if !cs.HasDataKey() { + return "", errors.New("data encryption key not configured") + } + if isEncryptedStorageValue(plaintext) { + return plaintext, nil + } + + block, err := aes.NewCipher(cs.dataKey) + if err != nil { + return "", err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return "", err + } + + aad := composeAAD(aadParts) + ciphertext := gcm.Seal(nil, nonce, []byte(plaintext), aad) + + return storagePrefix + + base64.StdEncoding.EncodeToString(nonce) + storageDelimiter + + base64.StdEncoding.EncodeToString(ciphertext), nil +} + +func (cs *CryptoService) DecryptFromStorage(value string, aadParts ...string) (string, error) { + if value == "" { + return "", nil + } + if !cs.HasDataKey() { + return "", errors.New("data encryption key not configured") + } + if !isEncryptedStorageValue(value) { + return "", errors.New("value is not encrypted") + } + + payload := strings.TrimPrefix(value, storagePrefix) + parts := strings.SplitN(payload, storageDelimiter, 2) + if len(parts) != 2 { + return "", errors.New("invalid encrypted payload format") + } + + nonce, err := base64.StdEncoding.DecodeString(parts[0]) + if err != nil { + return "", fmt.Errorf("decode nonce failed: %w", err) + } + + ciphertext, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("decode ciphertext failed: %w", err) + } + + block, err := aes.NewCipher(cs.dataKey) + if err != nil { + return "", err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + if len(nonce) != gcm.NonceSize() { + return "", fmt.Errorf("invalid nonce size: expected %d, got %d", gcm.NonceSize(), len(nonce)) + } + + aad := composeAAD(aadParts) + plaintext, err := gcm.Open(nil, nonce, ciphertext, aad) + if err != nil { + return "", fmt.Errorf("decryption failed: %w", err) + } + + return string(plaintext), nil +} + +func (cs *CryptoService) IsEncryptedStorageValue(value string) bool { + return isEncryptedStorageValue(value) +} + +func composeAAD(parts []string) []byte { + if len(parts) == 0 { + return nil + } + return []byte(strings.Join(parts, "|")) +} + +func isEncryptedStorageValue(value string) bool { + return strings.HasPrefix(value, storagePrefix) +} + +func (cs *CryptoService) DecryptPayload(payload *EncryptedPayload) ([]byte, error) { + // 1. 验证时间戳(防止重放攻击) + if payload.TS != 0 { + elapsed := time.Since(time.Unix(payload.TS, 0)) + if elapsed > 5*time.Minute || elapsed < -1*time.Minute { + return nil, errors.New("timestamp invalid or expired") + } + } + + // 2. 解码 base64url + wrappedKey, err := base64.RawURLEncoding.DecodeString(payload.WrappedKey) + if err != nil { + return nil, fmt.Errorf("failed to decode wrapped key: %w", err) + } + + iv, err := base64.RawURLEncoding.DecodeString(payload.IV) + if err != nil { + return nil, fmt.Errorf("failed to decode IV: %w", err) + } + + ciphertext, err := base64.RawURLEncoding.DecodeString(payload.Ciphertext) + if err != nil { + return nil, fmt.Errorf("failed to decode ciphertext: %w", err) + } + + var aad []byte + if payload.AAD != "" { + aad, err = base64.RawURLEncoding.DecodeString(payload.AAD) + if err != nil { + return nil, fmt.Errorf("failed to decode AAD: %w", err) + } + + // 验证 AAD + var aadData AADData + if err := json.Unmarshal(aad, &aadData); err == nil { + // 可以在这里添加额外的验证逻辑 + // 例如:验证 sessionID、userID 等 + } + } + + // 3. 使用 RSA-OAEP 解密 AES 密钥 + aesKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, cs.privateKey, wrappedKey, nil) + if err != nil { + return nil, fmt.Errorf("failed to unwrap AES key: %w", err) + } + + // 4. 使用 AES-GCM 解密数据 + block, err := aes.NewCipher(aesKey) + if err != nil { + return nil, fmt.Errorf("failed to create AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + if len(iv) != gcm.NonceSize() { + return nil, fmt.Errorf("invalid IV size: expected %d, got %d", gcm.NonceSize(), len(iv)) + } + + // 解密并验证认证标签 + plaintext, err := gcm.Open(nil, iv, ciphertext, aad) + if err != nil { + return nil, fmt.Errorf("authentication/decryption failed: %w", err) + } + + return plaintext, nil +} + +func (cs *CryptoService) DecryptSensitiveData(payload *EncryptedPayload) (string, error) { + plaintext, err := cs.DecryptPayload(payload) + if err != nil { + return "", err + } + return string(plaintext), nil +} diff --git a/docker-compose.yml b/docker-compose.yml index acdf459a..a15a01de 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -57,6 +57,7 @@ services: environment: - TZ=${NOFX_TIMEZONE:-Asia/Shanghai} # Set timezone - AI_MAX_TOKENS=4000 # AI响应的最大token数(默认2000,建议4000-8000) + - DATA_ENCRYPTION_KEY=${DATA_ENCRYPTION_KEY} # 数据加密密钥 - POSTGRES_HOST=postgres - POSTGRES_PORT=5432 - POSTGRES_DB=${POSTGRES_DB:-nofx} diff --git a/main.go b/main.go index 73dbab1b..dee1082e 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "nofx/api" "nofx/auth" "nofx/config" + "nofx/crypto" "nofx/manager" "nofx/market" "nofx/pool" @@ -171,6 +172,13 @@ func main() { } defer database.Close() + // 初始化加密服务(用于敏感数据加密存储与传输) + cryptoService, err := crypto.NewCryptoService("keys/rsa_private.key") + if err != nil { + log.Fatalf("❌ 初始化加密服务失败: %v", err) + } + database.SetCryptoService(cryptoService) + // 同步config.json到数据库 if err := syncConfigToDatabase(database, configFile); err != nil { log.Printf("⚠️ 同步config.json到数据库失败: %v", err) @@ -289,7 +297,7 @@ func main() { } // 创建并启动API服务器 - apiServer := api.NewServer(traderManager, database, apiPort) + apiServer := api.NewServer(traderManager, database, cryptoService, apiPort) go func() { if err := apiServer.Start(); err != nil { log.Printf("❌ API服务器错误: %v", err) diff --git a/web/src/components/AITradersPage.tsx b/web/src/components/AITradersPage.tsx index 198821bd..87833f6c 100644 --- a/web/src/components/AITradersPage.tsx +++ b/web/src/components/AITradersPage.tsx @@ -13,6 +13,7 @@ import { useAuth } from '../contexts/AuthContext' import { getExchangeIcon } from './ExchangeIcons' import { getModelIcon } from './ModelIcons' import { TraderConfigModal } from './TraderConfigModal' +import { TwoStageKeyModal } from './TwoStageKeyModal' import { Bot, Brain, @@ -46,6 +47,12 @@ function getShortName(fullName: string): string { return parts.length > 1 ? parts[parts.length - 1] : fullName } +function maskSecret(value: string): string { + if (!value) return '' + const length = Math.min(value.length, 16) + return '•'.repeat(length) +} + interface AITradersPageProps { onTraderSelect?: (traderId: string) => void } @@ -445,7 +452,7 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) { }, } - await api.updateExchangeConfigs(request) + await api.updateExchangeConfigsEncrypted(request) const refreshed = await api.getExchangeConfigs() setAllExchanges(refreshed) @@ -494,7 +501,7 @@ export function AITradersPage({ onTraderSelect }: AITradersPageProps) { }, } - await api.updateExchangeConfigs(request) + await api.updateExchangeConfigsEncrypted(request) const refreshedExchanges = await api.getExchangeConfigs() setAllExchanges(refreshedExchanges) @@ -1666,6 +1673,9 @@ function ExchangeConfigModal({ const [asterUser, setAsterUser] = useState('') const [asterSigner, setAsterSigner] = useState('') const [asterPrivateKey, setAsterPrivateKey] = useState('') + const [secureInputTarget, setSecureInputTarget] = useState< + null | 'hyperliquid' | 'aster' + >(null) // 获取当前选择的交易所信息 // 编辑模式:从 configuredExchanges 查找(包含用户配置的 apiKey、secretKey 等) @@ -1674,6 +1684,13 @@ function ExchangeConfigModal({ ? configuredExchanges?.find(e => e.id === selectedExchangeId) : supportedExchanges?.find(e => e.id === selectedExchangeId); + const secureInputContextLabel = + secureInputTarget === 'aster' + ? t('asterExchangeName', language) + : secureInputTarget === 'hyperliquid' + ? t('hyperliquidExchangeName', language) + : undefined + // 如果是编辑现有交易所,初始化表单数据 useEffect(() => { if (editingExchangeId && selectedExchange) { @@ -1692,6 +1709,28 @@ function ExchangeConfigModal({ } }, [editingExchangeId, selectedExchange]) + const handleSecureInputComplete = ({ + value, + obfuscationLog, + }: { + value: string + obfuscationLog: string[] + }) => { + const trimmed = value.trim() + if (secureInputTarget === 'hyperliquid') { + setApiKey(trimmed) + } + if (secureInputTarget === 'aster') { + setAsterPrivateKey(trimmed) + } + console.log('Secure input obfuscation log:', obfuscationLog) + setSecureInputTarget(null) + } + + const handleSecureInputCancel = () => { + setSecureInputTarget(null) + } + // 加载服务器IP(当选择binance时) useEffect(() => { if (selectedExchangeId === 'binance' && !serverIP) { @@ -1755,11 +1794,12 @@ function ExchangeConfigModal({ } return ( -
-
+ <> +
+

{editingExchangeId @@ -2094,19 +2134,55 @@ function ExchangeConfigModal({ > {t('privateKey', language)} - setApiKey(e.target.value)} - placeholder={t('enterPrivateKey', language)} - className="w-full px-3 py-2 rounded" - style={{ - background: '#0B0E11', - border: '1px solid #2B3139', - color: '#EAECEF', - }} - required - /> +
+
+ + + {apiKey && ( + + )} +
+ {apiKey && ( +
+ {t('secureInputHint', language)} +
+ )} +
{t('hyperliquidPrivateKeyDesc', language)}
@@ -2209,19 +2285,55 @@ function ExchangeConfigModal({ /> - setAsterPrivateKey(e.target.value)} - placeholder={t('enterPrivateKey', language)} - className="w-full px-3 py-2 rounded" - style={{ - background: '#0B0E11', - border: '1px solid #2B3139', - color: '#EAECEF', - }} - required - /> +
+
+ + + {asterPrivateKey && ( + + )} +
+ {asterPrivateKey && ( +
+ {t('secureInputHint', language)} +
+ )} +

)} @@ -2349,6 +2461,16 @@ function ExchangeConfigModal({
)} -
+
+ + + ) } diff --git a/web/src/components/TwoStageKeyModal.tsx b/web/src/components/TwoStageKeyModal.tsx new file mode 100644 index 00000000..fa0aa2ef --- /dev/null +++ b/web/src/components/TwoStageKeyModal.tsx @@ -0,0 +1,320 @@ +import { useEffect, useMemo, useRef, useState } from 'react' +import { createPortal } from 'react-dom' +import { t, type Language } from '../i18n/translations' + +const DEFAULT_LENGTH = 64 + +function generateObfuscation(): string { + const bytes = new Uint8Array(32) + crypto.getRandomValues(bytes) + return Array.from(bytes, (byte) => byte.toString(16).padStart(2, '0')).join('') +} + +function validatePrivateKeyFormat(value: string, expectedLength: number): boolean { + const normalized = value.startsWith('0x') ? value.slice(2) : value + if (normalized.length !== expectedLength) { + return false + } + return /^[0-9a-fA-F]+$/.test(normalized) +} + +export interface TwoStageKeyModalResult { + value: string + obfuscationLog: string[] +} + +interface TwoStageKeyModalProps { + isOpen: boolean + language: Language + onCancel: () => void + onComplete: (result: TwoStageKeyModalResult) => void + expectedLength?: number + contextLabel?: string +} + +export function TwoStageKeyModal({ + isOpen, + language, + onCancel, + onComplete, + expectedLength = DEFAULT_LENGTH, + contextLabel, +}: TwoStageKeyModalProps) { + const [stage, setStage] = useState<1 | 2>(1) + const [part1, setPart1] = useState('') + const [part2, setPart2] = useState('') + const [error, setError] = useState(null) + const [clipboardStatus, setClipboardStatus] = useState<'idle' | 'copied' | 'failed'>('idle') + const [obfuscationLog, setObfuscationLog] = useState([]) + const [processing, setProcessing] = useState(false) + const [manualObfuscationValue, setManualObfuscationValue] = useState(null) + const stage1InputRef = useRef(null) + const stage2InputRef = useRef(null) + + useEffect(() => { + if (!isOpen) return + const handler = (event: KeyboardEvent) => { + if (event.key === 'Escape') { + event.preventDefault() + onCancel() + } + } + document.addEventListener('keydown', handler) + return () => document.removeEventListener('keydown', handler) + }, [isOpen, onCancel]) + + useEffect(() => { + if (!isOpen) { + setStage(1) + setPart1('') + setPart2('') + setError(null) + setClipboardStatus('idle') + setObfuscationLog([]) + setProcessing(false) + setManualObfuscationValue(null) + return + } + + const focusTimer = setTimeout(() => { + if (stage === 1) { + stage1InputRef.current?.focus() + } else { + stage2InputRef.current?.focus() + } + }, 10) + + return () => clearTimeout(focusTimer) + }, [isOpen, stage]) + + const heading = useMemo(() => { + if (!contextLabel) { + return t('twoStageModalTitle', language) + } + return `${t('twoStageModalTitle', language)} · ${contextLabel}` + }, [contextLabel, language]) + + if (!isOpen) { + return null + } + + const handleOverlayClick = () => { + if (!processing) { + onCancel() + } + } + + const handleStage1Next = async () => { + if (!part1.trim()) { + setError(t('twoStageStage1Error', language)) + return + } + setProcessing(true) + const obfuscation = generateObfuscation() + let copied = false + try { + await navigator.clipboard.writeText(obfuscation) + copied = true + setClipboardStatus('copied') + setManualObfuscationValue(null) + } catch (err) { + console.warn('Clipboard write failed', err) + setClipboardStatus('failed') + setManualObfuscationValue(obfuscation) + } + setObfuscationLog((prev) => [...prev, `stage1:${new Date().toISOString()}`]) + setProcessing(false) + setError(null) + setStage(2) + if (copied) { + setManualObfuscationValue(null) + } + } + + const handleSubmit = () => { + const cleanedPart1 = part1.trim() + const cleanedPart2 = part2.trim() + const combined = (cleanedPart1 + cleanedPart2).replace(/\s+/g, '') + + if (!validatePrivateKeyFormat(combined, expectedLength)) { + setError(t('twoStageInvalidFormat', language, { length: expectedLength })) + return + } + + setObfuscationLog((prev) => [...prev, `stage2:${new Date().toISOString()}`]) + const result: TwoStageKeyModalResult = { + value: combined, + obfuscationLog: [...obfuscationLog, `stage2:${new Date().toISOString()}`], + } + onComplete(result) + } + + const modalContent = ( +
+
event.stopPropagation()} + > +
+

+ {heading} +

+

+ {t('twoStageModalDescription', language, { length: expectedLength })} +

+
+ + {stage === 1 ? ( +
+
+ + setPart1(event.target.value)} + placeholder={t('twoStageStage1Placeholder', language)} + className="w-full rounded border border-[#2B3139] bg-[#0F111C] px-3 py-2 text-sm text-[#EAECEF] outline-none focus:ring-2 focus:ring-[#F0B90B]/40" + disabled={processing} + /> +

+ {t('twoStageStage1Hint', language)} +

+
+ + {clipboardStatus === 'failed' && ( +
+
{t('twoStageClipboardManual', language)}
+ {manualObfuscationValue && ( + + {manualObfuscationValue} + + )} +
+ )} + + {error && ( +
+ {error} +
+ )} + +
+ + +
+
+ ) : ( +
+
+ + setPart2(event.target.value)} + placeholder={t('twoStageStage2Placeholder', language)} + className="w-full rounded border border-[#2B3139] bg-[#0F111C] px-3 py-2 text-sm text-[#EAECEF] outline-none focus:ring-2 focus:ring-[#F0B90B]/40" + /> +

+ {t('twoStageStage2Hint', language)} +

+
+ + {clipboardStatus === 'copied' && ( +
+ {t('twoStageClipboardSuccess', language)} +
+ )} + + {clipboardStatus === 'failed' && manualObfuscationValue && ( +
+ {t('twoStageClipboardReminder', language)} +
+ )} + + {error && ( +
+ {error} +
+ )} + +
+ + +
+
+ )} +
+
+ ) + + return createPortal(modalContent, document.body) +} diff --git a/web/src/contexts/AuthContext.tsx b/web/src/contexts/AuthContext.tsx index 1929d9ed..69a2f707 100644 --- a/web/src/contexts/AuthContext.tsx +++ b/web/src/contexts/AuthContext.tsx @@ -1,4 +1,6 @@ -import React, { createContext, useContext, useState, useEffect } from 'react' +import React, { createContext, useContext, useState, useEffect } from 'react'; +import { getSystemConfig } from '../lib/config'; +import { CryptoService } from '../lib/crypto'; interface User { id: string @@ -61,12 +63,33 @@ export function AuthProvider({ children }: { children: React.ReactNode }) { const login = async (email: string, password: string) => { try { + const systemConfig = await getSystemConfig() + if (!systemConfig.rsa_public_key) { + throw new Error('系统未配置登录所需的RSA公钥') + } + + await CryptoService.initialize(systemConfig.rsa_public_key) + const sessionId = sessionStorage.getItem('session_id') || '' + + const requestBody = { + email_encrypted: await CryptoService.encryptSensitiveData( + email, + email, + sessionId + ), + password_encrypted: await CryptoService.encryptSensitiveData( + password, + email, + sessionId + ), + } + const response = await fetch('/api/login', { method: 'POST', headers: { 'Content-Type': 'application/json', }, - body: JSON.stringify({ email, password }), + body: JSON.stringify(requestBody), }) const data = await response.json() @@ -84,6 +107,7 @@ export function AuthProvider({ children }: { children: React.ReactNode }) { return { success: false, message: data.error } } } catch (error) { + console.error('Login request failed:', error) return { success: false, message: '登录失败,请重试' } } diff --git a/web/src/i18n/translations.ts b/web/src/i18n/translations.ts index a65a7637..54de2a09 100644 --- a/web/src/i18n/translations.ts +++ b/web/src/i18n/translations.ts @@ -204,6 +204,42 @@ export const translations = { 'API wallet private key - Get from https://www.asterdex.com/en/api-wallet (only used locally for signing, never transmitted)', asterUsdtWarning: 'Important: Aster only tracks USDT balance. Please ensure you use USDT as margin currency to avoid P&L calculation errors caused by price fluctuations of other assets (BNB, ETH, etc.)', + hyperliquidExchangeName: 'Hyperliquid', + asterExchangeName: 'Aster DEX', + secureInputButton: 'Secure Input', + secureInputReenter: 'Re-enter Securely', + secureInputClear: 'Clear', + secureInputHint: + 'Captured via secure two-step input. Use “Re-enter Securely” to update this value.', + twoStageModalTitle: 'Secure Key Input', + twoStageModalDescription: + 'Use a two-step flow to enter your {length}-character private key safely.', + twoStageStage1Title: 'Step 1 · Enter the first half', + twoStageStage1Placeholder: 'First 32 characters (include 0x if present)', + twoStageStage1Hint: + 'Continuing copies an obfuscation string to your clipboard as a diversion.', + twoStageStage1Error: 'Please enter the first part before continuing.', + twoStageNext: 'Next', + twoStageProcessing: 'Processing…', + twoStageCancel: 'Cancel', + twoStageStage2Title: 'Step 2 · Enter the rest', + twoStageStage2Placeholder: 'Remaining characters of your private key', + twoStageStage2Hint: + 'Paste the obfuscation string somewhere neutral, then finish entering your key.', + twoStageClipboardSuccess: + 'Obfuscation string copied. Paste it into any text field once before completing.', + twoStageClipboardReminder: + 'Remember to paste the obfuscation string before submitting to avoid clipboard leaks.', + twoStageClipboardManual: + 'Automatic copy failed. Copy the obfuscation string below manually.', + twoStageClipboardFailed: + 'Automatic clipboard copy failed. Please copy the obfuscation string manually.', + twoStageClipboardInstruction: + 'Obfuscation string copied. Paste it once before finishing the input.', + twoStageBack: 'Back', + twoStageSubmit: 'Confirm', + twoStageInvalidFormat: + 'Invalid private key format. Expected {length} hexadecimal characters (optional 0x prefix).', testnetDescription: 'Enable to connect to exchange test environment for simulated trading', securityWarning: 'Security Warning', @@ -700,6 +736,34 @@ export const translations = { 'API 钱包私钥 - 从 https://www.asterdex.com/zh-CN/api-wallet 获取(仅在本地用于签名,不会被传输)', asterUsdtWarning: '重要提示:Aster 仅统计 USDT 余额。请确保您使用 USDT 作为保证金币种,避免其他资产(BNB、ETH等)的价格波动导致盈亏统计错误', + hyperliquidExchangeName: 'Hyperliquid', + asterExchangeName: 'Aster DEX', + secureInputButton: '安全输入', + secureInputReenter: '重新安全输入', + secureInputClear: '清除', + secureInputHint: '已通过安全双阶段输入设置。若需修改,请点击“重新安全输入”。', + twoStageModalTitle: '安全私钥输入', + twoStageModalDescription: '使用双阶段流程安全输入长度为 {length} 的私钥。', + twoStageStage1Title: '步骤一 · 输入前半段', + twoStageStage1Placeholder: '前 32 位字符(若有 0x 前缀请保留)', + twoStageStage1Hint: '继续后会将扰动字符串复制到剪贴板,用于迷惑剪贴板监控。', + twoStageStage1Error: '请先输入第一段私钥。', + twoStageNext: '下一步', + twoStageProcessing: '处理中…', + twoStageCancel: '取消', + twoStageStage2Title: '步骤二 · 输入剩余部分', + twoStageStage2Placeholder: '剩余的私钥字符', + twoStageStage2Hint: '将扰动字符串粘贴到任意位置后,再完成私钥输入。', + twoStageClipboardSuccess: + '扰动字符串已复制。请在完成前在任意文本处粘贴一次以迷惑剪贴板记录。', + twoStageClipboardReminder: + '记得在提交前粘贴一次扰动字符串,降低剪贴板泄漏风险。', + twoStageClipboardManual: '自动复制失败,请手动复制下面的扰动字符串。', + twoStageClipboardFailed: '自动写入剪贴板失败,请手动复制扰动字符串。', + twoStageClipboardInstruction: '扰动字符串已复制,请在完成输入前粘贴一次。', + twoStageBack: '返回', + twoStageSubmit: '确认', + twoStageInvalidFormat: '私钥格式不正确,应为 {length} 位十六进制字符(可选 0x 前缀)。', testnetDescription: '启用后将连接到交易所测试环境,用于模拟交易', securityWarning: '安全提示', saveConfiguration: '保存配置', diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts index 04592e95..dc03da93 100644 --- a/web/src/lib/api.ts +++ b/web/src/lib/api.ts @@ -11,7 +11,8 @@ import type { UpdateModelConfigRequest, UpdateExchangeConfigRequest, CompetitionData, -} from '../types' +} from '../types'; +import { CryptoService } from './crypto'; const API_BASE = '/api' @@ -165,6 +166,40 @@ export const api = { if (!res.ok) throw new Error('更新交易所配置失败') }, + // 使用加密传输更新交易所配置 + async updateExchangeConfigsEncrypted(request: UpdateExchangeConfigRequest): Promise { + // 从系统配置获取公钥 + const configRes = await fetch(`${API_BASE}/config`); + if (!configRes.ok) throw new Error('获取系统配置失败'); + const config = await configRes.json(); + + if (!config.rsa_public_key) { + throw new Error('系统未配置RSA公钥,无法使用加密传输'); + } + + // 初始化加密服务 + await CryptoService.initialize(config.rsa_public_key); + + // 获取用户信息(从localStorage或其他地方) + const userId = localStorage.getItem('user_id') || ''; + const sessionId = sessionStorage.getItem('session_id') || ''; + + // 加密敏感数据 + const encryptedPayload = await CryptoService.encryptSensitiveData( + JSON.stringify(request), + userId, + sessionId + ); + + // 发送加密数据 + const res = await fetch(`${API_BASE}/exchanges/encrypted`, { + method: 'PUT', + headers: getAuthHeaders(), + body: JSON.stringify(encryptedPayload), + }); + if (!res.ok) throw new Error('更新交易所配置失败'); + }, + // 获取系统状态(支持trader_id) async getStatus(traderId?: string): Promise { const url = traderId diff --git a/web/src/lib/config.ts b/web/src/lib/config.ts index f5b56c94..0f137d2c 100644 --- a/web/src/lib/config.ts +++ b/web/src/lib/config.ts @@ -3,6 +3,8 @@ export interface SystemConfig { default_coins?: string[] btc_eth_leverage?: number altcoin_leverage?: number + rsa_public_key?: string + rsa_key_id?: string } let configPromise: Promise | null = null diff --git a/web/src/lib/crypto.ts b/web/src/lib/crypto.ts new file mode 100644 index 00000000..61548cf0 --- /dev/null +++ b/web/src/lib/crypto.ts @@ -0,0 +1,142 @@ +export interface EncryptedPayload { + wrappedKey: string; // RSA-OAEP(K) + iv: string; // 12 bytes + ciphertext: string; // AES-GCM 输出(含 tag) + aad?: string; // 可选:额外认证数据 + kid?: string; // 可选:服务端公钥标识 + ts?: number; // 可选:unix 秒,用于重放保护 +} + +export class CryptoService { + private static publicKey: CryptoKey | null = null; + private static publicKeyPEM: string | null = null; + + static async initialize(publicKeyPEM: string) { + if (this.publicKey && this.publicKeyPEM === publicKeyPEM) { + return; + } + this.publicKeyPEM = publicKeyPEM; + this.publicKey = await this.importPublicKey(publicKeyPEM); + } + + private static async importPublicKey(pem: string): Promise { + const pemHeader = '-----BEGIN PUBLIC KEY-----'; + const pemFooter = '-----END PUBLIC KEY-----'; + const headerIndex = pem.indexOf(pemHeader); + const footerIndex = pem.indexOf(pemFooter); + + if (headerIndex === -1 || footerIndex === -1 || headerIndex >= footerIndex) { + throw new Error('Invalid PEM formatted public key'); + } + + const pemContents = pem + .substring(headerIndex + pemHeader.length, footerIndex) + .replace(/\s+/g, ''); // 移除所有空白字符(包括换行符、空格等) + + const binaryDerString = atob(pemContents); + const binaryDer = new Uint8Array(binaryDerString.length); + for (let i = 0; i < binaryDerString.length; i++) { + binaryDer[i] = binaryDerString.charCodeAt(i); + } + + return crypto.subtle.importKey( + 'spki', + binaryDer, + { + name: 'RSA-OAEP', + hash: 'SHA-256', + }, + false, + ['encrypt'] + ); + } + + static async encryptSensitiveData( + plaintext: string, + userId?: string, + sessionId?: string + ): Promise { + if (!this.publicKey) { + throw new Error('Crypto service not initialized. Call initialize() first.'); + } + + // 1. 生成 256-bit AES 密钥 + const aesKey = await crypto.subtle.generateKey( + { + name: 'AES-GCM', + length: 256, + }, + true, + ['encrypt'] + ); + + // 2. 生成 12 字节随机 IV + const iv = crypto.getRandomValues(new Uint8Array(12)); + + // 3. 准备 AAD (额外认证数据) + const ts = Math.floor(Date.now() / 1000); + const aadObject = { + userId: userId || '', + sessionId: sessionId || '', + ts: ts, + purpose: 'sensitive_data_encryption' + }; + const aadString = JSON.stringify(aadObject); + const aadBytes = new TextEncoder().encode(aadString); + + // 4. 使用 AES-GCM 加密数据 + const plaintextBytes = new TextEncoder().encode(plaintext); + const ciphertext = await crypto.subtle.encrypt( + { + name: 'AES-GCM', + iv: iv, + additionalData: aadBytes, + tagLength: 128, // 16 bytes tag + }, + aesKey, + plaintextBytes + ); + + // 5. 导出 AES 密钥 + const aesKeyRaw = await crypto.subtle.exportKey('raw', aesKey); + + // 6. 使用 RSA-OAEP 加密 AES 密钥 + const wrappedKey = await crypto.subtle.encrypt( + { + name: 'RSA-OAEP', + }, + this.publicKey, + aesKeyRaw + ); + + // 7. 转换为 base64url 格式 + return { + wrappedKey: this.arrayBufferToBase64Url(wrappedKey), + iv: this.arrayBufferToBase64Url(iv), + ciphertext: this.arrayBufferToBase64Url(ciphertext), + aad: this.arrayBufferToBase64Url(aadBytes), + kid: 'rsa-key-2025-11-05', + ts: ts, + }; + } + + private static arrayBufferToBase64Url(buffer: ArrayBuffer | Uint8Array): string { + const bytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer); + let binary = ''; + for (let i = 0; i < bytes.byteLength; i++) { + binary += String.fromCharCode(bytes[i]); + } + return btoa(binary) + .replace(/\+/g, '-') + .replace(/\//g, '_') + .replace(/=/g, ''); + } + + static async encryptWalletPrivateKey(privateKey: string, userId?: string, sessionId?: string): Promise { + return this.encryptSensitiveData(privateKey, userId, sessionId); + } + + static async encryptExchangeSecret(secretKey: string, userId?: string, sessionId?: string): Promise { + return this.encryptSensitiveData(secretKey, userId, sessionId); + } +}