diff --git a/api/handler_ai_model.go b/api/handler_ai_model.go new file mode 100644 index 00000000..d30d5dec --- /dev/null +++ b/api/handler_ai_model.go @@ -0,0 +1,211 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "nofx/config" + "nofx/crypto" + "nofx/logger" + "nofx/security" + + "github.com/gin-gonic/gin" +) + +type ModelConfig struct { + ID string `json:"id"` + Name string `json:"name"` + Provider string `json:"provider"` + Enabled bool `json:"enabled"` + APIKey string `json:"apiKey,omitempty"` + CustomAPIURL string `json:"customApiUrl,omitempty"` +} + +// SafeModelConfig Safe model configuration structure (does not contain sensitive information) +type SafeModelConfig struct { + ID string `json:"id"` + Name string `json:"name"` + Provider string `json:"provider"` + Enabled bool `json:"enabled"` + CustomAPIURL string `json:"customApiUrl"` // Custom API URL (usually not sensitive) + CustomModelName string `json:"customModelName"` // Custom model name (not sensitive) +} + +type UpdateModelConfigRequest struct { + Models map[string]struct { + Enabled bool `json:"enabled"` + APIKey string `json:"api_key"` + CustomAPIURL string `json:"custom_api_url"` + CustomModelName string `json:"custom_model_name"` + } `json:"models"` +} + +// handleGetModelConfigs Get AI model configurations +func (s *Server) handleGetModelConfigs(c *gin.Context) { + userID := c.GetString("user_id") + logger.Infof("๐Ÿ” Querying AI model configs for user %s", userID) + models, err := s.store.AIModel().List(userID) + if err != nil { + logger.Infof("โŒ Failed to get AI model configs: %v", err) + SafeInternalError(c, "Failed to get AI model configs", err) + return + } + + // If no models in database, return default models + if len(models) == 0 { + logger.Infof("โš ๏ธ No AI models in database, returning defaults") + defaultModels := []SafeModelConfig{ + {ID: "deepseek", Name: "DeepSeek AI", Provider: "deepseek", Enabled: false}, + {ID: "qwen", Name: "Qwen AI", Provider: "qwen", Enabled: false}, + {ID: "openai", Name: "OpenAI", Provider: "openai", Enabled: false}, + {ID: "claude", Name: "Claude AI", Provider: "claude", Enabled: false}, + {ID: "gemini", Name: "Gemini AI", Provider: "gemini", Enabled: false}, + {ID: "grok", Name: "Grok AI", Provider: "grok", Enabled: false}, + {ID: "kimi", Name: "Kimi AI", Provider: "kimi", Enabled: false}, + {ID: "minimax", Name: "MiniMax AI", Provider: "minimax", Enabled: false}, + } + c.JSON(http.StatusOK, defaultModels) + return + } + + logger.Infof("โœ… Found %d AI model configs", len(models)) + + // Convert to safe response structure, remove sensitive information + safeModels := make([]SafeModelConfig, len(models)) + for i, model := range models { + safeModels[i] = SafeModelConfig{ + ID: model.ID, + Name: model.Name, + Provider: model.Provider, + Enabled: model.Enabled, + CustomAPIURL: model.CustomAPIURL, + CustomModelName: model.CustomModelName, + } + } + + c.JSON(http.StatusOK, safeModels) +} + +// handleUpdateModelConfigs Update AI model configurations (supports both encrypted and plain text based on config) +func (s *Server) handleUpdateModelConfigs(c *gin.Context) { + userID := c.GetString("user_id") + cfg := config.Get() + + // Read raw request body + bodyBytes, err := c.GetRawData() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"}) + return + } + + var req UpdateModelConfigRequest + + // Check if transport encryption is enabled + if !cfg.TransportEncryption { + // Transport encryption disabled, accept plain JSON + if err := json.Unmarshal(bodyBytes, &req); err != nil { + logger.Infof("โŒ Failed to parse plain JSON request: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"}) + return + } + logger.Infof("๐Ÿ“ Received plain text model config (UserID: %s)", userID) + } else { + // Transport encryption enabled, require encrypted payload + var encryptedPayload crypto.EncryptedPayload + if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil { + logger.Infof("โŒ Failed to parse encrypted payload: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format, encrypted transmission required"}) + return + } + + // Verify encrypted data + if encryptedPayload.WrappedKey == "" { + logger.Infof("โŒ Detected unencrypted request (UserID: %s)", userID) + c.JSON(http.StatusBadRequest, gin.H{ + "error": "This endpoint only supports encrypted transmission, please use encrypted client", + "code": "ENCRYPTION_REQUIRED", + "message": "Encrypted transmission is required for security reasons", + }) + return + } + + // Decrypt data + decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload) + if err != nil { + logger.Infof("โŒ Failed to decrypt model config (UserID: %s): %v", userID, err) + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to decrypt data"}) + return + } + + // Parse decrypted data + if err := json.Unmarshal([]byte(decrypted), &req); err != nil { + logger.Infof("โŒ Failed to parse decrypted data: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse decrypted data"}) + return + } + logger.Infof("๐Ÿ”“ Decrypted model config data (UserID: %s)", userID) + } + + // Update each model's configuration and track traders that need reload + tradersToReload := make(map[string]bool) + for modelID, modelData := range req.Models { + // SSRF protection: validate custom_api_url before storing + if modelData.CustomAPIURL != "" { + cleanURL := strings.TrimSuffix(modelData.CustomAPIURL, "#") + if err := security.ValidateURL(cleanURL); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid custom_api_url for model %s: %s", modelID, err.Error())}) + return + } + } + + // Find traders using this AI model BEFORE updating + traders, _ := s.store.Trader().ListByAIModelID(userID, modelID) + for _, t := range traders { + tradersToReload[t.ID] = true + } + + err := s.store.AIModel().Update(userID, modelID, modelData.Enabled, modelData.APIKey, modelData.CustomAPIURL, modelData.CustomModelName) + if err != nil { + SafeInternalError(c, fmt.Sprintf("Update model %s", modelID), err) + return + } + } + + // Remove affected traders from memory BEFORE reloading to pick up new config + for traderID := range tradersToReload { + logger.Infof("๐Ÿ”„ Removing trader %s from memory to reload with new AI model config", traderID) + s.traderManager.RemoveTrader(traderID) + } + + // Reload all traders for this user to make new config take effect immediately + err = s.traderManager.LoadUserTradersFromStore(s.store, userID) + if err != nil { + logger.Infof("โš ๏ธ Failed to reload user traders into memory: %v", err) + // Don't return error here since model config was successfully updated to database + } + + logger.Infof("โœ“ AI model config updated: %+v", req.Models) + c.JSON(http.StatusOK, gin.H{"message": "Model configuration updated"}) +} + +// handleGetSupportedModels Get list of AI models supported by the system +func (s *Server) handleGetSupportedModels(c *gin.Context) { + // Return static list of supported AI models with default versions + supportedModels := []map[string]interface{}{ + {"id": "deepseek", "name": "DeepSeek", "provider": "deepseek", "defaultModel": "deepseek-chat"}, + {"id": "qwen", "name": "Qwen", "provider": "qwen", "defaultModel": "qwen3-max"}, + {"id": "openai", "name": "OpenAI", "provider": "openai", "defaultModel": "gpt-5.1"}, + {"id": "claude", "name": "Claude", "provider": "claude", "defaultModel": "claude-opus-4-6"}, + {"id": "gemini", "name": "Google Gemini", "provider": "gemini", "defaultModel": "gemini-3-pro-preview"}, + {"id": "grok", "name": "Grok (xAI)", "provider": "grok", "defaultModel": "grok-3-latest"}, + {"id": "kimi", "name": "Kimi (Moonshot)", "provider": "kimi", "defaultModel": "moonshot-v1-auto"}, + {"id": "minimax", "name": "MiniMax", "provider": "minimax", "defaultModel": "MiniMax-M2.5"}, + {"id": "blockrun-base", "name": "BlockRun (Base Wallet)", "provider": "blockrun-base", "defaultModel": "auto"}, + {"id": "blockrun-sol", "name": "BlockRun (Solana Wallet)", "provider": "blockrun-sol", "defaultModel": "auto"}, + {"id": "claw402", "name": "Claw402 (Base USDC)", "provider": "claw402", "defaultModel": "deepseek"}, + } + + c.JSON(http.StatusOK, supportedModels) +} diff --git a/api/handler_competition.go b/api/handler_competition.go new file mode 100644 index 00000000..be793876 --- /dev/null +++ b/api/handler_competition.go @@ -0,0 +1,469 @@ +package api + +import ( + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "nofx/logger" + "nofx/store" + + "github.com/gin-gonic/gin" +) + +// handleDecisions Decision log list +func (s *Server) handleDecisions(c *gin.Context) { + _, traderID, err := s.getTraderFromQuery(c) + if err != nil { + SafeBadRequest(c, "Invalid trader ID") + return + } + + trader, err := s.traderManager.GetTrader(traderID) + if err != nil { + SafeNotFound(c, "Trader") + return + } + + // Get all historical decision records (unlimited) + records, err := trader.GetStore().Decision().GetLatestRecords(trader.GetID(), 10000) + if err != nil { + SafeInternalError(c, "Get decision log", err) + return + } + + c.JSON(http.StatusOK, records) +} + +// handleLatestDecisions Latest decision logs (newest first, supports limit parameter) +func (s *Server) handleLatestDecisions(c *gin.Context) { + _, traderID, err := s.getTraderFromQuery(c) + if err != nil { + SafeBadRequest(c, "Invalid trader ID") + return + } + + trader, err := s.traderManager.GetTrader(traderID) + if err != nil { + SafeNotFound(c, "Trader") + return + } + + // Get limit from query parameter, default to 5 + limit := 5 + if limitStr := c.Query("limit"); limitStr != "" { + if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 { + limit = parsedLimit + if limit > 100 { + limit = 100 // Max 100 to prevent abuse + } + } + } + + records, err := trader.GetStore().Decision().GetLatestRecords(trader.GetID(), limit) + if err != nil { + SafeInternalError(c, "Get decision log", err) + return + } + + // Reverse array to put newest first (for list display) + // GetLatestRecords returns oldest to newest (for charts), here we need newest to oldest + for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 { + records[i], records[j] = records[j], records[i] + } + + c.JSON(http.StatusOK, records) +} + +// handleStatistics Statistics information +func (s *Server) handleStatistics(c *gin.Context) { + _, traderID, err := s.getTraderFromQuery(c) + if err != nil { + SafeBadRequest(c, "Invalid trader ID") + return + } + + trader, err := s.traderManager.GetTrader(traderID) + if err != nil { + SafeNotFound(c, "Trader") + return + } + + stats, err := trader.GetStore().Decision().GetStatistics(trader.GetID()) + if err != nil { + SafeInternalError(c, "Get statistics", err) + return + } + + c.JSON(http.StatusOK, stats) +} + +// handleCompetition Competition overview (compare all traders) +func (s *Server) handleCompetition(c *gin.Context) { + userID := c.GetString("user_id") + + // Ensure user's traders are loaded into memory + err := s.traderManager.LoadUserTradersFromStore(s.store, userID) + if err != nil { + logger.Infof("โš ๏ธ Failed to load traders for user %s: %v", userID, err) + } + + competition, err := s.traderManager.GetCompetitionData() + if err != nil { + SafeInternalError(c, "Get competition data", err) + return + } + + c.JSON(http.StatusOK, competition) +} + +// handleEquityHistory Return rate historical data +// Query directly from database, not dependent on trader in memory (so historical data can be retrieved after restart) +func (s *Server) handleEquityHistory(c *gin.Context) { + _, traderID, err := s.getTraderFromQuery(c) + if err != nil { + SafeBadRequest(c, "Invalid trader ID") + return + } + + // Get equity historical data from new equity table + // Every 3 minutes per cycle: 10000 records = about 20 days of data + snapshots, err := s.store.Equity().GetLatest(traderID, 10000) + if err != nil { + SafeInternalError(c, "Get historical data", err) + return + } + + if len(snapshots) == 0 { + c.JSON(http.StatusOK, []interface{}{}) + return + } + + // Build return rate historical data points + type EquityPoint struct { + Timestamp string `json:"timestamp"` + TotalEquity float64 `json:"total_equity"` // Account equity (wallet + unrealized) + AvailableBalance float64 `json:"available_balance"` // Available balance + TotalPnL float64 `json:"total_pnl"` // Total PnL (unrealized PnL) + TotalPnLPct float64 `json:"total_pnl_pct"` // Total PnL percentage + PositionCount int `json:"position_count"` // Position count + MarginUsedPct float64 `json:"margin_used_pct"` // Margin used percentage + } + + // Use the balance of the first record as initial balance to calculate return rate + initialBalance := snapshots[0].Balance + if initialBalance == 0 { + initialBalance = 1 // Avoid division by zero + } + + var history []EquityPoint + for _, snap := range snapshots { + // Calculate PnL percentage + totalPnLPct := 0.0 + if initialBalance > 0 { + totalPnLPct = (snap.UnrealizedPnL / initialBalance) * 100 + } + + history = append(history, EquityPoint{ + Timestamp: snap.Timestamp.Format("2006-01-02 15:04:05"), + TotalEquity: snap.TotalEquity, + AvailableBalance: snap.Balance, + TotalPnL: snap.UnrealizedPnL, + TotalPnLPct: totalPnLPct, + PositionCount: snap.PositionCount, + MarginUsedPct: snap.MarginUsedPct, + }) + } + + c.JSON(http.StatusOK, history) +} + +// handlePublicTraderList Get public trader list (no authentication required) +func (s *Server) handlePublicTraderList(c *gin.Context) { + // Get trader information from all users + competition, err := s.traderManager.GetCompetitionData() + if err != nil { + SafeInternalError(c, "Get trader list", err) + return + } + + // Get traders array + tradersData, exists := competition["traders"] + if !exists { + c.JSON(http.StatusOK, []map[string]interface{}{}) + return + } + + traders, ok := tradersData.([]map[string]interface{}) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Trader data format error", + }) + return + } + + // Return trader basic information, filter sensitive information + result := make([]map[string]interface{}, 0, len(traders)) + for _, trader := range traders { + result = append(result, map[string]interface{}{ + "trader_id": trader["trader_id"], + "trader_name": trader["trader_name"], + "ai_model": trader["ai_model"], + "exchange": trader["exchange"], + "is_running": trader["is_running"], + "total_equity": trader["total_equity"], + "total_pnl": trader["total_pnl"], + "total_pnl_pct": trader["total_pnl_pct"], + "position_count": trader["position_count"], + "margin_used_pct": trader["margin_used_pct"], + }) + } + + c.JSON(http.StatusOK, result) +} + +// handlePublicCompetition Get public competition data (no authentication required) +func (s *Server) handlePublicCompetition(c *gin.Context) { + competition, err := s.traderManager.GetCompetitionData() + if err != nil { + SafeInternalError(c, "Get competition data", err) + return + } + + c.JSON(http.StatusOK, competition) +} + +// handleTopTraders Get top 5 trader data (no authentication required, for performance comparison) +func (s *Server) handleTopTraders(c *gin.Context) { + topTraders, err := s.traderManager.GetTopTradersData() + if err != nil { + SafeInternalError(c, "Get top traders data", err) + return + } + + c.JSON(http.StatusOK, topTraders) +} + +// handleEquityHistoryBatch Batch get return rate historical data for multiple traders (no authentication required, for performance comparison) +// Supports optional 'hours' parameter to filter data by time range (e.g., hours=24 for last 24 hours) +func (s *Server) handleEquityHistoryBatch(c *gin.Context) { + var requestBody struct { + TraderIDs []string `json:"trader_ids"` + Hours int `json:"hours"` // Optional: filter by last N hours (0 = all data) + } + + // Try to parse POST request JSON body + if err := c.ShouldBindJSON(&requestBody); err != nil { + // If JSON parse fails, try to get from query parameters (compatible with GET request) + traderIDsParam := c.Query("trader_ids") + if traderIDsParam == "" { + // If no trader_ids specified, return historical data for top 5 + topTraders, err := s.traderManager.GetTopTradersData() + if err != nil { + SafeInternalError(c, "Get top traders", err) + return + } + + traders, ok := topTraders["traders"].([]map[string]interface{}) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Trader data format error"}) + return + } + + // Extract trader IDs + traderIDs := make([]string, 0, len(traders)) + for _, trader := range traders { + if traderID, ok := trader["trader_id"].(string); ok { + traderIDs = append(traderIDs, traderID) + } + } + + // Parse hours parameter from query + hoursParam := c.Query("hours") + hours := 0 + if hoursParam != "" { + fmt.Sscanf(hoursParam, "%d", &hours) + } + + result := s.getEquityHistoryForTraders(traderIDs, hours) + c.JSON(http.StatusOK, result) + return + } + + // Parse comma-separated trader IDs + requestBody.TraderIDs = strings.Split(traderIDsParam, ",") + for i := range requestBody.TraderIDs { + requestBody.TraderIDs[i] = strings.TrimSpace(requestBody.TraderIDs[i]) + } + + // Parse hours parameter from query + hoursParam := c.Query("hours") + if hoursParam != "" { + fmt.Sscanf(hoursParam, "%d", &requestBody.Hours) + } + } + + // Limit to maximum 20 traders to prevent oversized requests + if len(requestBody.TraderIDs) > 20 { + requestBody.TraderIDs = requestBody.TraderIDs[:20] + } + + result := s.getEquityHistoryForTraders(requestBody.TraderIDs, requestBody.Hours) + c.JSON(http.StatusOK, result) +} + +// getEquityHistoryForTraders Get historical data for multiple traders +// Query directly from database, not dependent on trader in memory (so historical data can be retrieved after restart) +// Also appends current real-time data point to ensure chart matches leaderboard +// hours: filter by last N hours (0 = use default limit of 500 records) +func (s *Server) getEquityHistoryForTraders(traderIDs []string, hours int) map[string]interface{} { + result := make(map[string]interface{}) + histories := make(map[string]interface{}) + errors := make(map[string]string) + + // Use a single consistent timestamp for all real-time data points + now := time.Now() + + // Pre-fetch initial balances for all traders + initialBalances := make(map[string]float64) + for _, traderID := range traderIDs { + if traderID == "" { + continue + } + // Get trader's initial balance from database (use GetByID which doesn't require userID) + trader, err := s.store.Trader().GetByID(traderID) + if err == nil && trader != nil && trader.InitialBalance > 0 { + initialBalances[traderID] = trader.InitialBalance + } + } + + for _, traderID := range traderIDs { + if traderID == "" { + continue + } + + // Get equity historical data from new equity table + var snapshots []*store.EquitySnapshot + var err error + + if hours > 0 { + // Filter by time range + startTime := now.Add(-time.Duration(hours) * time.Hour) + snapshots, err = s.store.Equity().GetByTimeRange(traderID, startTime, now) + } else { + // Default: get latest 500 records + snapshots, err = s.store.Equity().GetLatest(traderID, 500) + } + if err != nil { + logger.Errorf("[API] Failed to get equity history for %s: %v", traderID, err) + errors[traderID] = "Failed to get historical data" + continue + } + + // Get initial balance for calculating PnL percentage + initialBalance := initialBalances[traderID] + if initialBalance <= 0 && len(snapshots) > 0 { + // If no initial balance configured, use the first snapshot's equity as baseline + initialBalance = snapshots[0].TotalEquity + } + + // Build return rate historical data with PnL percentage + history := make([]map[string]interface{}, 0, len(snapshots)+1) + var lastSnapshotTime time.Time + for _, snap := range snapshots { + // Calculate PnL percentage: (current_equity - initial_balance) / initial_balance * 100 + pnlPct := 0.0 + if initialBalance > 0 { + pnlPct = (snap.TotalEquity - initialBalance) / initialBalance * 100 + } + + history = append(history, map[string]interface{}{ + "timestamp": snap.Timestamp, + "total_equity": snap.TotalEquity, + "total_pnl": snap.UnrealizedPnL, + "total_pnl_pct": pnlPct, + "balance": snap.Balance, + }) + if snap.Timestamp.After(lastSnapshotTime) { + lastSnapshotTime = snap.Timestamp + } + } + + // Append current real-time data point to ensure chart matches leaderboard + // This ensures the latest point is always current, not from a potentially stale snapshot + if trader, err := s.traderManager.GetTrader(traderID); err == nil { + if accountInfo, err := trader.GetAccountInfo(); err == nil { + // Only append if it's been more than 30 seconds since last snapshot + if now.Sub(lastSnapshotTime) > 30*time.Second { + totalEquity := 0.0 + if v, ok := accountInfo["total_equity"].(float64); ok { + totalEquity = v + } + totalPnL := 0.0 + if v, ok := accountInfo["total_pnl"].(float64); ok { + totalPnL = v + } + walletBalance := 0.0 + if v, ok := accountInfo["wallet_balance"].(float64); ok { + walletBalance = v + } + pnlPct := 0.0 + if initialBalance > 0 { + pnlPct = (totalEquity - initialBalance) / initialBalance * 100 + } + + history = append(history, map[string]interface{}{ + "timestamp": now, + "total_equity": totalEquity, + "total_pnl": totalPnL, + "total_pnl_pct": pnlPct, + "balance": walletBalance, + }) + } + } + } + + histories[traderID] = history + } + + result["histories"] = histories + result["count"] = len(histories) + if len(errors) > 0 { + result["errors"] = errors + } + + return result +} + +// handleGetPublicTraderConfig Get public trader configuration information (no authentication required, does not include sensitive information) +func (s *Server) handleGetPublicTraderConfig(c *gin.Context) { + traderID := c.Param("id") + if traderID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Trader ID cannot be empty"}) + return + } + + trader, err := s.traderManager.GetTrader(traderID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"}) + return + } + + // Get trader status information + status := trader.GetStatus() + + // Only return public configuration information, not including sensitive data like API keys + result := map[string]interface{}{ + "trader_id": trader.GetID(), + "trader_name": trader.GetName(), + "ai_model": trader.GetAIModel(), + "exchange": trader.GetExchange(), + "is_running": status["is_running"], + "ai_provider": status["ai_provider"], + "start_time": status["start_time"], + } + + c.JSON(http.StatusOK, result) +} diff --git a/api/handler_exchange.go b/api/handler_exchange.go new file mode 100644 index 00000000..24884c7e --- /dev/null +++ b/api/handler_exchange.go @@ -0,0 +1,353 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + + "nofx/config" + "nofx/crypto" + "nofx/logger" + + "github.com/gin-gonic/gin" +) + +type ExchangeConfig struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` // "cex" or "dex" + Enabled bool `json:"enabled"` + APIKey string `json:"apiKey,omitempty"` + SecretKey string `json:"secretKey,omitempty"` + Testnet bool `json:"testnet,omitempty"` +} + +// SafeExchangeConfig Safe exchange configuration structure (does not contain sensitive information) +type SafeExchangeConfig struct { + ID string `json:"id"` // UUID + ExchangeType string `json:"exchange_type"` // "binance", "bybit", "okx", "hyperliquid", "aster", "lighter" + AccountName string `json:"account_name"` // User-defined account name + Name string `json:"name"` // Display name + Type string `json:"type"` // "cex" or "dex" + Enabled bool `json:"enabled"` + Testnet bool `json:"testnet,omitempty"` + HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` // Hyperliquid wallet address (not sensitive) + AsterUser string `json:"asterUser"` // Aster username (not sensitive) + AsterSigner string `json:"asterSigner"` // Aster signer (not sensitive) + LighterWalletAddr string `json:"lighterWalletAddr"` // LIGHTER wallet address (not sensitive) +} + +type UpdateExchangeConfigRequest struct { + Exchanges map[string]struct { + Enabled bool `json:"enabled"` + APIKey string `json:"api_key"` + SecretKey string `json:"secret_key"` + Passphrase string `json:"passphrase"` // OKX specific + Testnet bool `json:"testnet"` + HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"` + HyperliquidUnifiedAcct bool `json:"hyperliquid_unified_account"` // Unified Account mode + AsterUser string `json:"aster_user"` + AsterSigner string `json:"aster_signer"` + AsterPrivateKey string `json:"aster_private_key"` + LighterWalletAddr string `json:"lighter_wallet_addr"` + LighterPrivateKey string `json:"lighter_private_key"` + LighterAPIKeyPrivateKey string `json:"lighter_api_key_private_key"` + LighterAPIKeyIndex int `json:"lighter_api_key_index"` + } `json:"exchanges"` +} + +// CreateExchangeRequest request structure for creating a new exchange account +type CreateExchangeRequest struct { + ExchangeType string `json:"exchange_type" binding:"required"` // "binance", "bybit", "okx", "hyperliquid", "aster", "lighter" + AccountName string `json:"account_name"` // User-defined account name + Enabled bool `json:"enabled"` + APIKey string `json:"api_key"` + SecretKey string `json:"secret_key"` + Passphrase string `json:"passphrase"` + Testnet bool `json:"testnet"` + HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"` + HyperliquidUnifiedAcct bool `json:"hyperliquid_unified_account"` // Unified Account mode: Spot as Perp collateral + AsterUser string `json:"aster_user"` + AsterSigner string `json:"aster_signer"` + AsterPrivateKey string `json:"aster_private_key"` + LighterWalletAddr string `json:"lighter_wallet_addr"` + LighterPrivateKey string `json:"lighter_private_key"` + LighterAPIKeyPrivateKey string `json:"lighter_api_key_private_key"` + LighterAPIKeyIndex int `json:"lighter_api_key_index"` +} + +// handleGetExchangeConfigs Get exchange configurations +func (s *Server) handleGetExchangeConfigs(c *gin.Context) { + userID := c.GetString("user_id") + logger.Infof("๐Ÿ” Querying exchange configs for user %s", userID) + exchanges, err := s.store.Exchange().List(userID) + if err != nil { + SafeInternalError(c, "Failed to get exchange configs", err) + return + } + + // If no exchanges in database, return empty array (user needs to create accounts) + if len(exchanges) == 0 { + logger.Infof("โš ๏ธ No exchanges in database for user %s", userID) + c.JSON(http.StatusOK, []SafeExchangeConfig{}) + return + } + + logger.Infof("โœ… Found %d exchange configs", len(exchanges)) + + // Convert to safe response structure, remove sensitive information + safeExchanges := make([]SafeExchangeConfig, len(exchanges)) + for i, exchange := range exchanges { + safeExchanges[i] = SafeExchangeConfig{ + ID: exchange.ID, + ExchangeType: exchange.ExchangeType, + AccountName: exchange.AccountName, + Name: exchange.Name, + Type: exchange.Type, + Enabled: exchange.Enabled, + Testnet: exchange.Testnet, + HyperliquidWalletAddr: exchange.HyperliquidWalletAddr, + AsterUser: exchange.AsterUser, + AsterSigner: exchange.AsterSigner, + LighterWalletAddr: exchange.LighterWalletAddr, + } + } + + c.JSON(http.StatusOK, safeExchanges) +} + +// handleUpdateExchangeConfigs Update exchange configurations (supports both encrypted and plain text based on config) +func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) { + userID := c.GetString("user_id") + cfg := config.Get() + + // Read raw request body + bodyBytes, err := c.GetRawData() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"}) + return + } + + var req UpdateExchangeConfigRequest + + // Check if transport encryption is enabled + if !cfg.TransportEncryption { + // Transport encryption disabled, accept plain JSON + if err := json.Unmarshal(bodyBytes, &req); err != nil { + logger.Infof("โŒ Failed to parse plain JSON request: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"}) + return + } + logger.Infof("๐Ÿ“ Received plain text exchange config (UserID: %s)", userID) + } else { + // Transport encryption enabled, require encrypted payload + var encryptedPayload crypto.EncryptedPayload + if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil { + logger.Infof("โŒ Failed to parse encrypted payload: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format, encrypted transmission required"}) + return + } + + // Verify encrypted data + if encryptedPayload.WrappedKey == "" { + logger.Infof("โŒ Detected unencrypted request (UserID: %s)", userID) + c.JSON(http.StatusBadRequest, gin.H{ + "error": "This endpoint only supports encrypted transmission, please use encrypted client", + "code": "ENCRYPTION_REQUIRED", + "message": "Encrypted transmission is required for security reasons", + }) + return + } + + // Decrypt data + decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload) + if err != nil { + logger.Infof("โŒ Failed to decrypt exchange config (UserID: %s): %v", userID, err) + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to decrypt data"}) + return + } + + // Parse decrypted data + if err := json.Unmarshal([]byte(decrypted), &req); err != nil { + logger.Infof("โŒ Failed to parse decrypted data: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse decrypted data"}) + return + } + logger.Infof("๐Ÿ”“ Decrypted exchange config data (UserID: %s)", userID) + } + + // Update each exchange's configuration and track traders that need reload + tradersToReload := make(map[string]bool) + for exchangeID, exchangeData := range req.Exchanges { + // Find traders using this exchange BEFORE updating + traders, _ := s.store.Trader().ListByExchangeID(userID, exchangeID) + for _, t := range traders { + tradersToReload[t.ID] = true + } + + err := s.store.Exchange().Update(userID, exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Passphrase, exchangeData.Testnet, exchangeData.HyperliquidWalletAddr, exchangeData.HyperliquidUnifiedAcct, exchangeData.AsterUser, exchangeData.AsterSigner, exchangeData.AsterPrivateKey, exchangeData.LighterWalletAddr, exchangeData.LighterPrivateKey, exchangeData.LighterAPIKeyPrivateKey, exchangeData.LighterAPIKeyIndex) + if err != nil { + SafeInternalError(c, fmt.Sprintf("Update exchange %s", exchangeID), err) + return + } + } + + // Remove affected traders from memory BEFORE reloading to pick up new config + for traderID := range tradersToReload { + logger.Infof("๐Ÿ”„ Removing trader %s from memory to reload with new exchange config", traderID) + s.traderManager.RemoveTrader(traderID) + } + + // Reload all traders for this user to make new config take effect immediately + err = s.traderManager.LoadUserTradersFromStore(s.store, userID) + if err != nil { + logger.Infof("โš ๏ธ Failed to reload user traders into memory: %v", err) + // Don't return error here since exchange config was successfully updated to database + } + + logger.Infof("โœ“ Exchange config updated: %+v", req.Exchanges) + c.JSON(http.StatusOK, gin.H{"message": "Exchange configuration updated"}) +} + +// handleCreateExchange Create a new exchange account +func (s *Server) handleCreateExchange(c *gin.Context) { + userID := c.GetString("user_id") + cfg := config.Get() + + // Read raw request body + bodyBytes, err := c.GetRawData() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"}) + return + } + + var req CreateExchangeRequest + + // Check if transport encryption is enabled + if !cfg.TransportEncryption { + // Transport encryption disabled, accept plain JSON + if err := json.Unmarshal(bodyBytes, &req); err != nil { + logger.Infof("โŒ Failed to parse plain JSON request: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"}) + return + } + } else { + // Transport encryption enabled, require encrypted payload + var encryptedPayload crypto.EncryptedPayload + if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format, encrypted transmission required"}) + return + } + + if encryptedPayload.WrappedKey == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "This endpoint only supports encrypted transmission", + "code": "ENCRYPTION_REQUIRED", + "message": "Encrypted transmission is required for security reasons", + }) + return + } + + decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to decrypt data"}) + return + } + + if err := json.Unmarshal([]byte(decrypted), &req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse decrypted data"}) + return + } + } + + // Validate exchange type + validTypes := map[string]bool{ + "binance": true, "bybit": true, "okx": true, "bitget": true, + "hyperliquid": true, "aster": true, "lighter": true, "gate": true, "kucoin": true, "indodax": true, + } + if !validTypes[req.ExchangeType] { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid exchange type: %s", req.ExchangeType)}) + return + } + + // Create new exchange account + id, err := s.store.Exchange().Create( + userID, req.ExchangeType, req.AccountName, req.Enabled, + req.APIKey, req.SecretKey, req.Passphrase, req.Testnet, + req.HyperliquidWalletAddr, req.HyperliquidUnifiedAcct, + req.AsterUser, req.AsterSigner, req.AsterPrivateKey, + req.LighterWalletAddr, req.LighterPrivateKey, req.LighterAPIKeyPrivateKey, req.LighterAPIKeyIndex, + ) + if err != nil { + logger.Infof("โŒ Failed to create exchange account: %v", err) + SafeInternalError(c, "Failed to create exchange account", err) + return + } + + logger.Infof("โœ“ Created exchange account: type=%s, name=%s, id=%s", req.ExchangeType, req.AccountName, id) + c.JSON(http.StatusOK, gin.H{ + "message": "Exchange account created", + "id": id, + }) +} + +// handleDeleteExchange Delete an exchange account +func (s *Server) handleDeleteExchange(c *gin.Context) { + userID := c.GetString("user_id") + exchangeID := c.Param("id") + + if exchangeID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Exchange ID is required"}) + return + } + + // Check if any traders are using this exchange + traders, err := s.store.Trader().List(userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check traders"}) + return + } + + for _, trader := range traders { + if trader.ExchangeID == exchangeID { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Cannot delete exchange account that is in use by traders", + "trader_id": trader.ID, + "trader_name": trader.Name, + }) + return + } + } + + // Delete exchange account + err = s.store.Exchange().Delete(userID, exchangeID) + if err != nil { + logger.Infof("โŒ Failed to delete exchange account: %v", err) + SafeInternalError(c, "Failed to delete exchange account", err) + return + } + + logger.Infof("โœ“ Deleted exchange account: id=%s", exchangeID) + c.JSON(http.StatusOK, gin.H{"message": "Exchange account deleted"}) +} + +// handleGetSupportedExchanges Get list of exchanges supported by the system +func (s *Server) handleGetSupportedExchanges(c *gin.Context) { + // Return static list of supported exchange types + // Note: ID is empty for supported exchanges (they are templates, not actual accounts) + supportedExchanges := []SafeExchangeConfig{ + {ExchangeType: "binance", Name: "Binance Futures", Type: "cex"}, + {ExchangeType: "bybit", Name: "Bybit Futures", Type: "cex"}, + {ExchangeType: "okx", Name: "OKX Futures", Type: "cex"}, + {ExchangeType: "gate", Name: "Gate.io Futures", Type: "cex"}, + {ExchangeType: "kucoin", Name: "KuCoin Futures", Type: "cex"}, + {ExchangeType: "hyperliquid", Name: "Hyperliquid", Type: "dex"}, + {ExchangeType: "aster", Name: "Aster DEX", Type: "dex"}, + {ExchangeType: "lighter", Name: "LIGHTER DEX", Type: "dex"}, + {ExchangeType: "alpaca", Name: "Alpaca (US Stocks)", Type: "stock"}, + {ExchangeType: "forex", Name: "Forex (TwelveData)", Type: "forex"}, + {ExchangeType: "metals", Name: "Metals (TwelveData)", Type: "metals"}, + } + + c.JSON(http.StatusOK, supportedExchanges) +} diff --git a/api/handler_klines.go b/api/handler_klines.go new file mode 100644 index 00000000..cebee1c9 --- /dev/null +++ b/api/handler_klines.go @@ -0,0 +1,392 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "nofx/logger" + "nofx/market" + "nofx/provider/alpaca" + "nofx/provider/coinank/coinank_api" + "nofx/provider/coinank/coinank_enum" + "nofx/provider/hyperliquid" + "nofx/provider/twelvedata" + + "github.com/gin-gonic/gin" +) + +// handleKlines K-line data (supports multiple exchanges via coinank) +func (s *Server) handleKlines(c *gin.Context) { + // Get query parameters + symbol := c.Query("symbol") + if symbol == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "symbol parameter is required"}) + return + } + + interval := c.DefaultQuery("interval", "5m") + exchange := c.DefaultQuery("exchange", "binance") // Default to binance for backward compatibility + limitStr := c.DefaultQuery("limit", "1000") + limit, err := strconv.Atoi(limitStr) + if err != nil || limit <= 0 { + limit = 1000 + } + + // Coinank API has a maximum limit of 1500 klines per request + if limit > 1500 { + limit = 1500 + } + + var klines []market.Kline + exchangeLower := strings.ToLower(exchange) + + // Route to appropriate data source based on exchange type + switch exchangeLower { + case "alpaca": + // US Stocks via Alpaca + klines, err = s.getKlinesFromAlpaca(symbol, interval, limit) + if err != nil { + SafeInternalError(c, "Get klines from Alpaca", err) + return + } + case "forex", "metals": + // Forex and Metals via Twelve Data + klines, err = s.getKlinesFromTwelveData(symbol, interval, limit) + if err != nil { + SafeInternalError(c, "Get klines from TwelveData", err) + return + } + case "hyperliquid", "hyperliquid-xyz", "xyz": + // Hyperliquid native API - supports both crypto perps and stock perps (xyz dex) + klines, err = s.getKlinesFromHyperliquid(symbol, interval, limit) + if err != nil { + SafeInternalError(c, "Get klines from Hyperliquid", err) + return + } + default: + // Crypto exchanges via CoinAnk + symbol = market.Normalize(symbol) + klines, err = s.getKlinesFromCoinank(symbol, interval, exchange, limit) + if err != nil { + SafeInternalError(c, "Get klines from CoinAnk", err) + return + } + } + + c.JSON(http.StatusOK, klines) +} + +// getKlinesFromCoinank fetches kline data from coinank free/open API for multiple exchanges +func (s *Server) getKlinesFromCoinank(symbol, interval, exchange string, limit int) ([]market.Kline, error) { + // Map exchange string to coinank enum + var coinankExchange coinank_enum.Exchange + switch strings.ToLower(exchange) { + case "binance": + coinankExchange = coinank_enum.Binance + case "bybit": + coinankExchange = coinank_enum.Bybit + case "okx": + coinankExchange = coinank_enum.Okex + case "bitget": + coinankExchange = coinank_enum.Bitget + case "gate": + coinankExchange = coinank_enum.Gate + case "aster": + coinankExchange = coinank_enum.Aster + case "lighter": + // Lighter doesn't have direct CoinAnk support, use Binance data as fallback + coinankExchange = coinank_enum.Binance + case "kucoin": + // KuCoin doesn't have direct CoinAnk support, use Binance data as fallback + coinankExchange = coinank_enum.Binance + default: + // For any unknown exchange, default to Binance + logger.Warnf("โš ๏ธ Unknown exchange '%s', defaulting to Binance for CoinAnk", exchange) + coinankExchange = coinank_enum.Binance + } + + // Map interval string to coinank enum + var coinankInterval coinank_enum.Interval + switch interval { + case "1s": + coinankInterval = coinank_enum.Second1 + case "5s": + coinankInterval = coinank_enum.Second5 + case "10s": + coinankInterval = coinank_enum.Second10 + case "30s": + coinankInterval = coinank_enum.Second30 + case "1m": + coinankInterval = coinank_enum.Minute1 + case "3m": + coinankInterval = coinank_enum.Minute3 + case "5m": + coinankInterval = coinank_enum.Minute5 + case "10m": + coinankInterval = coinank_enum.Minute10 + case "15m": + coinankInterval = coinank_enum.Minute15 + case "30m": + coinankInterval = coinank_enum.Minute30 + case "1h": + coinankInterval = coinank_enum.Hour1 + case "2h": + coinankInterval = coinank_enum.Hour2 + case "4h": + coinankInterval = coinank_enum.Hour4 + case "6h": + coinankInterval = coinank_enum.Hour6 + case "8h": + coinankInterval = coinank_enum.Hour8 + case "12h": + coinankInterval = coinank_enum.Hour12 + case "1d": + coinankInterval = coinank_enum.Day1 + case "3d": + coinankInterval = coinank_enum.Day3 + case "1w": + coinankInterval = coinank_enum.Week1 + case "1M": + coinankInterval = coinank_enum.Month1 + default: + return nil, fmt.Errorf("unsupported interval for coinank: %s", interval) + } + + // Convert symbol format for different exchanges + // OKX uses "BTC-USDT-SWAP" format instead of "BTCUSDT" + apiSymbol := symbol + if coinankExchange == coinank_enum.Okex { + // Convert BTCUSDT -> BTC-USDT-SWAP + if strings.HasSuffix(symbol, "USDT") { + base := strings.TrimSuffix(symbol, "USDT") + apiSymbol = fmt.Sprintf("%s-USDT-SWAP", base) + } + } + + // Call coinank free/open API (no authentication required) + ctx := context.Background() + ts := time.Now().UnixMilli() + // Use "To" side to search backward from current time (get historical klines) + coinankKlines, err := coinank_api.Kline(ctx, apiSymbol, coinankExchange, ts, coinank_enum.To, limit, coinankInterval) + if err != nil { + // Free API doesn't support all exchanges (e.g., OKX, Bitget) + // Fallback to Binance data as reference + if coinankExchange != coinank_enum.Binance { + logger.Warnf("โš ๏ธ CoinAnk free API doesn't support %s, falling back to Binance data", coinankExchange) + coinankKlines, err = coinank_api.Kline(ctx, symbol, coinank_enum.Binance, ts, coinank_enum.To, limit, coinankInterval) + if err != nil { + return nil, fmt.Errorf("coinank API error (fallback): %w", err) + } + } else { + return nil, fmt.Errorf("coinank API error: %w", err) + } + } + + // Convert coinank kline format to market.Kline format + // Coinank: Volume = BTC ๆ•ฐ้‡, Quantity = USDT ๆˆไบค้ข + klines := make([]market.Kline, len(coinankKlines)) + for i, ck := range coinankKlines { + klines[i] = market.Kline{ + OpenTime: ck.StartTime, + Open: ck.Open, + High: ck.High, + Low: ck.Low, + Close: ck.Close, + Volume: ck.Volume, // BTC ๆ•ฐ้‡ + QuoteVolume: ck.Quantity, // USDT ๆˆไบค้ข + CloseTime: ck.EndTime, + } + } + + return klines, nil +} + +// getKlinesFromAlpaca fetches kline data from Alpaca API for US stocks +func (s *Server) getKlinesFromAlpaca(symbol, interval string, limit int) ([]market.Kline, error) { + // Create Alpaca client + client := alpaca.NewClient() + + // Map interval to Alpaca timeframe format + timeframe := alpaca.MapTimeframe(interval) + + // Fetch bars from Alpaca + ctx := context.Background() + bars, err := client.GetBars(ctx, symbol, timeframe, limit) + if err != nil { + return nil, fmt.Errorf("alpaca API error: %w", err) + } + + // Convert Alpaca bars to market.Kline format + klines := make([]market.Kline, len(bars)) + for i, bar := range bars { + klines[i] = market.Kline{ + OpenTime: bar.Timestamp.UnixMilli(), + Open: bar.Open, + High: bar.High, + Low: bar.Low, + Close: bar.Close, + Volume: float64(bar.Volume), // ่‚กๆ•ฐ + QuoteVolume: float64(bar.Volume) * bar.Close, // ๆˆไบค้ข = ่‚กๆ•ฐ * ๆ”ถ็›˜ไปท (USD) + CloseTime: bar.Timestamp.UnixMilli(), + } + } + + return klines, nil +} + +// getKlinesFromTwelveData fetches kline data from Twelve Data API for forex and metals +func (s *Server) getKlinesFromTwelveData(symbol, interval string, limit int) ([]market.Kline, error) { + // Create Twelve Data client + client := twelvedata.NewClient() + + // Map interval to Twelve Data timeframe format + timeframe := twelvedata.MapTimeframe(interval) + + // Fetch time series from Twelve Data + ctx := context.Background() + result, err := client.GetTimeSeries(ctx, symbol, timeframe, limit) + if err != nil { + return nil, fmt.Errorf("twelvedata API error: %w", err) + } + + // Convert Twelve Data bars to market.Kline format + // Note: Twelve Data returns bars in reverse order (newest first) + klines := make([]market.Kline, len(result.Values)) + for i, bar := range result.Values { + open, high, low, close, volume, timestamp, err := twelvedata.ParseBar(bar) + if err != nil { + logger.Warnf("โš ๏ธ Failed to parse TwelveData bar: %v", err) + continue + } + + // Reverse order: put oldest first + idx := len(result.Values) - 1 - i + klines[idx] = market.Kline{ + OpenTime: timestamp, + Open: open, + High: high, + Low: low, + Close: close, + Volume: volume, + CloseTime: timestamp, + } + } + + return klines, nil +} + +// getKlinesFromHyperliquid fetches kline data from Hyperliquid API +// Supports both crypto perps (default dex) and stock perps/forex/commodities (xyz dex) +func (s *Server) getKlinesFromHyperliquid(symbol, interval string, limit int) ([]market.Kline, error) { + // Create Hyperliquid client + client := hyperliquid.NewClient() + + // Map interval to Hyperliquid format + timeframe := hyperliquid.MapTimeframe(interval) + + // Fetch candles from Hyperliquid + // FormatCoinForAPI will automatically add xyz: prefix for stock perps + ctx := context.Background() + candles, err := client.GetCandles(ctx, symbol, timeframe, limit) + if err != nil { + return nil, fmt.Errorf("hyperliquid API error: %w", err) + } + + // Convert Hyperliquid candles to market.Kline format + klines := make([]market.Kline, len(candles)) + for i, candle := range candles { + open, _ := strconv.ParseFloat(candle.Open, 64) + high, _ := strconv.ParseFloat(candle.High, 64) + low, _ := strconv.ParseFloat(candle.Low, 64) + close, _ := strconv.ParseFloat(candle.Close, 64) + volume, _ := strconv.ParseFloat(candle.Volume, 64) + + klines[i] = market.Kline{ + OpenTime: candle.OpenTime, + Open: open, + High: high, + Low: low, + Close: close, + Volume: volume, // ๅˆ็บฆๆ•ฐ้‡ + QuoteVolume: volume * close, // ๆˆไบค้ข (USD) + CloseTime: candle.CloseTime, + } + } + + return klines, nil +} + +// handleSymbols returns available symbols for a given exchange +func (s *Server) handleSymbols(c *gin.Context) { + exchange := c.DefaultQuery("exchange", "hyperliquid") + + type SymbolInfo struct { + Symbol string `json:"symbol"` + Name string `json:"name"` + Category string `json:"category"` // crypto, stock, forex, commodity, index + MaxLeverage int `json:"maxLeverage,omitempty"` + } + + var symbols []SymbolInfo + + switch strings.ToLower(exchange) { + case "hyperliquid", "hyperliquid-xyz", "xyz": + // Fetch symbols from Hyperliquid + client := hyperliquid.NewClient() + ctx := context.Background() + + // Get crypto perps from default dex + if exchange == "hyperliquid" || exchange == "hyperliquid-xyz" { + mids, err := client.GetAllMids(ctx) + if err == nil { + for symbol := range mids { + // Skip spot tokens (start with @) + if strings.HasPrefix(symbol, "@") { + continue + } + symbols = append(symbols, SymbolInfo{ + Symbol: symbol, + Name: symbol, + Category: "crypto", + }) + } + } + } + + // Get xyz dex symbols (stocks, forex, commodities) + xyzMids, err := client.GetAllMidsXYZ(ctx) + if err == nil { + for symbol := range xyzMids { + // Remove xyz: prefix for display + displaySymbol := strings.TrimPrefix(symbol, "xyz:") + category := "stock" + if displaySymbol == "GOLD" || displaySymbol == "SILVER" { + category = "commodity" + } else if displaySymbol == "EUR" || displaySymbol == "JPY" { + category = "forex" + } else if displaySymbol == "XYZ100" { + category = "index" + } + symbols = append(symbols, SymbolInfo{ + Symbol: displaySymbol, + Name: displaySymbol, + Category: category, + }) + } + } + + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported exchange for symbol listing"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "exchange": exchange, + "symbols": symbols, + "count": len(symbols), + }) +} diff --git a/api/handler_order.go b/api/handler_order.go new file mode 100644 index 00000000..56939337 --- /dev/null +++ b/api/handler_order.go @@ -0,0 +1,402 @@ +package api + +import ( + "net/http" + "strconv" + + "nofx/logger" + "nofx/market" + + "github.com/gin-gonic/gin" +) + +// handleTraderList Trader list +func (s *Server) handleTraderList(c *gin.Context) { + userID := c.GetString("user_id") + traders, err := s.store.Trader().List(userID) + if err != nil { + SafeInternalError(c, "Failed to get trader list", err) + return + } + + result := make([]map[string]interface{}, 0, len(traders)) + for _, trader := range traders { + // Get real-time running status + isRunning := trader.IsRunning + if at, err := s.traderManager.GetTrader(trader.ID); err == nil { + status := at.GetStatus() + if running, ok := status["is_running"].(bool); ok { + isRunning = running + } + } + + // Get strategy name if strategy_id is set + var strategyName string + if trader.StrategyID != "" { + if strategy, err := s.store.Strategy().Get(userID, trader.StrategyID); err == nil { + strategyName = strategy.Name + } + } + + // Return complete AIModelID (e.g. "admin_deepseek"), don't truncate + // Frontend needs complete ID to verify model exists (consistent with handleGetTraderConfig) + result = append(result, map[string]interface{}{ + "trader_id": trader.ID, + "trader_name": trader.Name, + "ai_model": trader.AIModelID, // Use complete ID + "exchange_id": trader.ExchangeID, + "is_running": isRunning, + "show_in_competition": trader.ShowInCompetition, + "initial_balance": trader.InitialBalance, + "strategy_id": trader.StrategyID, + "strategy_name": strategyName, + }) + } + + c.JSON(http.StatusOK, result) +} + +// handleGetTraderConfig Get trader detailed configuration +func (s *Server) handleGetTraderConfig(c *gin.Context) { + userID := c.GetString("user_id") + traderID := c.Param("id") + + if traderID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Trader ID cannot be empty"}) + return + } + + fullCfg, err := s.store.Trader().GetFullConfig(userID, traderID) + if err != nil { + SafeNotFound(c, "Trader config") + return + } + traderConfig := fullCfg.Trader + + // Get real-time running status + isRunning := traderConfig.IsRunning + if at, err := s.traderManager.GetTrader(traderID); err == nil { + status := at.GetStatus() + if running, ok := status["is_running"].(bool); ok { + isRunning = running + } + } + + // Return complete model ID without conversion, consistent with frontend model list + aiModelID := traderConfig.AIModelID + + result := map[string]interface{}{ + "trader_id": traderConfig.ID, + "trader_name": traderConfig.Name, + "ai_model": aiModelID, + "exchange_id": traderConfig.ExchangeID, + "strategy_id": traderConfig.StrategyID, + "initial_balance": traderConfig.InitialBalance, + "scan_interval_minutes": traderConfig.ScanIntervalMinutes, + "btc_eth_leverage": traderConfig.BTCETHLeverage, + "altcoin_leverage": traderConfig.AltcoinLeverage, + "trading_symbols": traderConfig.TradingSymbols, + "custom_prompt": traderConfig.CustomPrompt, + "override_base_prompt": traderConfig.OverrideBasePrompt, + "is_cross_margin": traderConfig.IsCrossMargin, + "use_ai500": traderConfig.UseAI500, + "use_oi_top": traderConfig.UseOITop, + "is_running": isRunning, + } + + c.JSON(http.StatusOK, result) +} + +// handleStatus System status +func (s *Server) handleStatus(c *gin.Context) { + _, traderID, err := s.getTraderFromQuery(c) + if err != nil { + SafeBadRequest(c, "Invalid trader ID") + return + } + + trader, err := s.traderManager.GetTrader(traderID) + if err != nil { + SafeNotFound(c, "Trader") + return + } + + status := trader.GetStatus() + c.JSON(http.StatusOK, status) +} + +// handleAccount Account information +func (s *Server) handleAccount(c *gin.Context) { + _, traderID, err := s.getTraderFromQuery(c) + if err != nil { + SafeBadRequest(c, "Invalid trader ID") + return + } + + trader, err := s.traderManager.GetTrader(traderID) + if err != nil { + SafeNotFound(c, "Trader") + return + } + + logger.Infof("๐Ÿ“Š Received account info request [%s]", trader.GetName()) + account, err := trader.GetAccountInfo() + if err != nil { + SafeInternalError(c, "Get account info", err) + return + } + + logger.Infof("โœ“ Returning account info [%s]: equity=%.2f, available=%.2f, pnl=%.2f (%.2f%%)", + trader.GetName(), + account["total_equity"], + account["available_balance"], + account["total_pnl"], + account["total_pnl_pct"]) + c.JSON(http.StatusOK, account) +} + +// handlePositions Position list +func (s *Server) handlePositions(c *gin.Context) { + _, traderID, err := s.getTraderFromQuery(c) + if err != nil { + SafeBadRequest(c, "Invalid trader ID") + return + } + + trader, err := s.traderManager.GetTrader(traderID) + if err != nil { + SafeNotFound(c, "Trader") + return + } + + positions, err := trader.GetPositions() + if err != nil { + SafeInternalError(c, "Get positions", err) + return + } + + c.JSON(http.StatusOK, positions) +} + +// handlePositionHistory Historical closed positions with statistics +func (s *Server) handlePositionHistory(c *gin.Context) { + _, traderID, err := s.getTraderFromQuery(c) + if err != nil { + SafeBadRequest(c, "Invalid trader ID") + return + } + + trader, err := s.traderManager.GetTrader(traderID) + if err != nil { + SafeNotFound(c, "Trader") + return + } + + // Get optional query parameters + limitStr := c.DefaultQuery("limit", "100") + limit := 100 + if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 500 { + limit = l + } + + // Get store + store := trader.GetStore() + if store == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"}) + return + } + + // Get closed positions + positions, err := store.Position().GetClosedPositions(trader.GetID(), limit) + if err != nil { + SafeInternalError(c, "Get position history", err) + return + } + + // Get statistics + stats, _ := store.Position().GetFullStats(trader.GetID()) + + // Get symbol stats + symbolStats, _ := store.Position().GetSymbolStats(trader.GetID(), 10) + + // Get direction stats + directionStats, _ := store.Position().GetDirectionStats(trader.GetID()) + + c.JSON(http.StatusOK, gin.H{ + "positions": positions, + "stats": stats, + "symbol_stats": symbolStats, + "direction_stats": directionStats, + }) +} + +// handleTrades Historical trades list +func (s *Server) handleTrades(c *gin.Context) { + _, traderID, err := s.getTraderFromQuery(c) + if err != nil { + SafeBadRequest(c, "Invalid trader ID") + return + } + + trader, err := s.traderManager.GetTrader(traderID) + if err != nil { + SafeNotFound(c, "Trader") + return + } + + // Get optional query parameters + symbol := c.Query("symbol") + limitStr := c.DefaultQuery("limit", "100") + limit := 100 + if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { + limit = l + } + + // Normalize symbol (add USDT suffix if not present) + if symbol != "" { + symbol = market.Normalize(symbol) + } + + // Get trades from store + store := trader.GetStore() + if store == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"}) + return + } + + allTrades, err := store.Position().GetRecentTrades(trader.GetID(), limit) + if err != nil { + SafeInternalError(c, "Get trades", err) + return + } + + // Filter by symbol if specified + if symbol != "" { + var result []interface{} + for _, trade := range allTrades { + if trade.Symbol == symbol { + result = append(result, trade) + } + } + c.JSON(http.StatusOK, result) + return + } + + c.JSON(http.StatusOK, allTrades) +} + +// handleOrders Order list (all orders including open, close, stop loss, take profit, etc.) +func (s *Server) handleOrders(c *gin.Context) { + _, traderID, err := s.getTraderFromQuery(c) + if err != nil { + SafeBadRequest(c, "Invalid trader ID") + return + } + + trader, err := s.traderManager.GetTrader(traderID) + if err != nil { + SafeNotFound(c, "Trader") + return + } + + // Get optional query parameters + symbol := c.Query("symbol") + statusFilter := c.Query("status") // NEW, FILLED, CANCELED, etc. + limitStr := c.DefaultQuery("limit", "100") + limit := 100 + if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { + limit = l + } + + // Normalize symbol (add USDT suffix if not present) + if symbol != "" { + symbol = market.Normalize(symbol) + } + + // Get orders from store + store := trader.GetStore() + if store == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"}) + return + } + + // Get orders with filters applied at database level + orders, err := store.Order().GetTraderOrdersFiltered(trader.GetID(), symbol, statusFilter, limit) + if err != nil { + SafeInternalError(c, "Get orders", err) + return + } + + c.JSON(http.StatusOK, orders) +} + +// handleOrderFills Order fill details (all fills for a specific order) +func (s *Server) handleOrderFills(c *gin.Context) { + orderIDStr := c.Param("id") + orderID, err := strconv.ParseInt(orderIDStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid order ID"}) + return + } + + _, traderID, err := s.getTraderFromQuery(c) + if err != nil { + SafeBadRequest(c, "Invalid trader ID") + return + } + + trader, err := s.traderManager.GetTrader(traderID) + if err != nil { + SafeNotFound(c, "Trader") + return + } + + store := trader.GetStore() + if store == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"}) + return + } + + // Get fills for this order + fills, err := store.Order().GetOrderFills(orderID) + if err != nil { + SafeInternalError(c, "Get order fills", err) + return + } + + c.JSON(http.StatusOK, fills) +} + +// handleOpenOrders Get open orders (pending SL/TP) from exchange +func (s *Server) handleOpenOrders(c *gin.Context) { + _, traderID, err := s.getTraderFromQuery(c) + if err != nil { + SafeBadRequest(c, "Invalid trader ID") + return + } + + trader, err := s.traderManager.GetTrader(traderID) + if err != nil { + SafeNotFound(c, "Trader") + return + } + + // Get symbol parameter (required for exchange query) + symbol := c.Query("symbol") + if symbol == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "symbol parameter is required"}) + return + } + + // Normalize symbol + symbol = market.Normalize(symbol) + + // Get open orders from exchange + openOrders, err := trader.GetOpenOrders(symbol) + if err != nil { + SafeInternalError(c, "Get open orders", err) + return + } + + c.JSON(http.StatusOK, openOrders) +} diff --git a/api/handler_telegram.go b/api/handler_telegram.go new file mode 100644 index 00000000..f3d4c84e --- /dev/null +++ b/api/handler_telegram.go @@ -0,0 +1,105 @@ +package api + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +// handleGetTelegramConfig returns current Telegram bot configuration and binding status +func (s *Server) handleGetTelegramConfig(c *gin.Context) { + cfg, err := s.store.TelegramConfig().Get() + if err != nil { + // Not configured yet - return empty state + c.JSON(http.StatusOK, gin.H{ + "configured": false, + "is_bound": false, + "token_masked": "", + "username": "", + }) + return + } + + // Mask bot token for security (show only last 6 chars) + tokenMasked := "" + if cfg.BotToken != "" { + if len(cfg.BotToken) > 6 { + tokenMasked = "***" + cfg.BotToken[len(cfg.BotToken)-6:] + } else { + tokenMasked = "***" + } + } + + c.JSON(http.StatusOK, gin.H{ + "configured": cfg.BotToken != "", + "is_bound": cfg.ChatID != 0, + "username": cfg.Username, + "bound_at": cfg.BoundAt, + "token_masked": tokenMasked, + "model_id": cfg.ModelID, + }) +} + +// handleUpdateTelegramConfig saves bot token (+ optional model ID) and triggers bot hot-reload +func (s *Server) handleUpdateTelegramConfig(c *gin.Context) { + var req struct { + BotToken string `json:"bot_token"` + ModelID string `json:"model_id"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) + return + } + if req.BotToken == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "bot_token is required"}) + return + } + + if err := s.store.TelegramConfig().Save(req.BotToken, req.ModelID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save config"}) + return + } + + // Signal bot hot-reload if channel is available + if s.telegramReloadCh != nil { + select { + case s.telegramReloadCh <- struct{}{}: + default: // non-blocking + } + } + + c.JSON(http.StatusOK, gin.H{"success": true, "message": "Bot token saved. Bot will reload automatically."}) +} + +// handleUnbindTelegram removes Telegram user binding +func (s *Server) handleUnbindTelegram(c *gin.Context) { + if err := s.store.TelegramConfig().Unbind(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to unbind"}) + return + } + c.JSON(http.StatusOK, gin.H{"success": true, "message": "Telegram binding removed"}) +} + +// handleUpdateTelegramModel updates only the AI model used for Telegram replies (no token re-entry needed) +func (s *Server) handleUpdateTelegramModel(c *gin.Context) { + var req struct { + ModelID string `json:"model_id"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) + return + } + + cfg, err := s.store.TelegramConfig().Get() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "no Telegram config found, save a bot token first"}) + return + } + + if err := s.store.TelegramConfig().Save(cfg.BotToken, req.ModelID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save model config"}) + return + } + + c.JSON(http.StatusOK, gin.H{"success": true, "model_id": req.ModelID}) +} diff --git a/api/handler_trader.go b/api/handler_trader.go new file mode 100644 index 00000000..ead84200 --- /dev/null +++ b/api/handler_trader.go @@ -0,0 +1,1212 @@ +package api + +import ( + "fmt" + "net/http" + "strings" + "time" + + "nofx/logger" + "nofx/store" + "nofx/trader" + "nofx/trader/aster" + "nofx/trader/binance" + "nofx/trader/bitget" + "nofx/trader/bybit" + "nofx/trader/gate" + hyperliquidtrader "nofx/trader/hyperliquid" + "nofx/trader/kucoin" + "nofx/trader/lighter" + "nofx/trader/okx" + + "github.com/gin-gonic/gin" +) + +// AI trader management related structures +type CreateTraderRequest struct { + Name string `json:"name" binding:"required"` + AIModelID string `json:"ai_model_id" binding:"required"` + ExchangeID string `json:"exchange_id" binding:"required"` + StrategyID string `json:"strategy_id"` // Strategy ID (new version) + InitialBalance float64 `json:"initial_balance"` + ScanIntervalMinutes int `json:"scan_interval_minutes"` + IsCrossMargin *bool `json:"is_cross_margin"` // Pointer type, nil means use default value true + ShowInCompetition *bool `json:"show_in_competition"` // Pointer type, nil means use default value true + // The following fields are kept for backward compatibility, new version uses strategy config + BTCETHLeverage int `json:"btc_eth_leverage"` + AltcoinLeverage int `json:"altcoin_leverage"` + TradingSymbols string `json:"trading_symbols"` + CustomPrompt string `json:"custom_prompt"` + OverrideBasePrompt bool `json:"override_base_prompt"` + SystemPromptTemplate string `json:"system_prompt_template"` // System prompt template name + UseAI500 bool `json:"use_ai500"` + UseOITop bool `json:"use_oi_top"` +} + +// UpdateTraderRequest Update trader request +type UpdateTraderRequest struct { + Name string `json:"name" binding:"required"` + AIModelID string `json:"ai_model_id" binding:"required"` + ExchangeID string `json:"exchange_id" binding:"required"` + StrategyID string `json:"strategy_id"` // Strategy ID (new version) + InitialBalance float64 `json:"initial_balance"` + ScanIntervalMinutes int `json:"scan_interval_minutes"` + IsCrossMargin *bool `json:"is_cross_margin"` + ShowInCompetition *bool `json:"show_in_competition"` + // The following fields are kept for backward compatibility, new version uses strategy config + BTCETHLeverage int `json:"btc_eth_leverage"` + AltcoinLeverage int `json:"altcoin_leverage"` + TradingSymbols string `json:"trading_symbols"` + CustomPrompt string `json:"custom_prompt"` + OverrideBasePrompt bool `json:"override_base_prompt"` + SystemPromptTemplate string `json:"system_prompt_template"` +} + +// handleCreateTrader Create new AI trader +func (s *Server) handleCreateTrader(c *gin.Context) { + userID := c.GetString("user_id") + var req CreateTraderRequest + if err := c.ShouldBindJSON(&req); err != nil { + SafeBadRequest(c, "Invalid request parameters") + return + } + + // Validate leverage values + if req.BTCETHLeverage < 0 || req.BTCETHLeverage > 50 { + c.JSON(http.StatusBadRequest, gin.H{"error": "BTC/ETH leverage must be between 1-50x"}) + return + } + if req.AltcoinLeverage < 0 || req.AltcoinLeverage > 20 { + c.JSON(http.StatusBadRequest, gin.H{"error": "Altcoin leverage must be between 1-20x"}) + return + } + + // Validate trading symbol format + if req.TradingSymbols != "" { + symbols := strings.Split(req.TradingSymbols, ",") + for _, symbol := range symbols { + symbol = strings.TrimSpace(symbol) + if symbol != "" && !strings.HasSuffix(strings.ToUpper(symbol), "USDT") { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid symbol format: %s, must end with USDT", symbol)}) + return + } + } + } + + // Generate trader ID (use short UUID prefix for readability) + exchangeIDShort := req.ExchangeID + if len(exchangeIDShort) > 8 { + exchangeIDShort = exchangeIDShort[:8] + } + traderID := fmt.Sprintf("%s_%s_%d", exchangeIDShort, req.AIModelID, time.Now().Unix()) + + // Set default values + isCrossMargin := true // Default to cross margin mode + if req.IsCrossMargin != nil { + isCrossMargin = *req.IsCrossMargin + } + + showInCompetition := true // Default to show in competition + if req.ShowInCompetition != nil { + showInCompetition = *req.ShowInCompetition + } + + // Set leverage default values + btcEthLeverage := 10 // Default value + altcoinLeverage := 5 // Default value + if req.BTCETHLeverage > 0 { + btcEthLeverage = req.BTCETHLeverage + } + if req.AltcoinLeverage > 0 { + altcoinLeverage = req.AltcoinLeverage + } + + // Set system prompt template default value + systemPromptTemplate := "default" + if req.SystemPromptTemplate != "" { + systemPromptTemplate = req.SystemPromptTemplate + } + + // Set scan interval default value + scanIntervalMinutes := req.ScanIntervalMinutes + if scanIntervalMinutes < 3 { + scanIntervalMinutes = 3 // Default 3 minutes, not allowed to be less than 3 + } + + // Query exchange actual balance, override user input + actualBalance := req.InitialBalance // Default to use user input + exchanges, err := s.store.Exchange().List(userID) + if err != nil { + logger.Infof("โš ๏ธ Failed to get exchange config, using user input for initial balance: %v", err) + } + + // Find matching exchange configuration + var exchangeCfg *store.Exchange + for _, ex := range exchanges { + if ex.ID == req.ExchangeID { + exchangeCfg = ex + break + } + } + + if exchangeCfg == nil { + logger.Infof("โš ๏ธ Exchange %s configuration not found, using user input for initial balance", req.ExchangeID) + } else if !exchangeCfg.Enabled { + logger.Infof("โš ๏ธ Exchange %s not enabled, using user input for initial balance", req.ExchangeID) + } else { + // Create temporary trader based on exchange type to query balance + var tempTrader trader.Trader + var createErr error + + // Use ExchangeType (e.g., "binance") instead of ID (UUID) + // Convert EncryptedString fields to string + switch exchangeCfg.ExchangeType { + case "binance": + tempTrader = binance.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID) + case "hyperliquid": + tempTrader, createErr = hyperliquidtrader.NewHyperliquidTrader( + string(exchangeCfg.APIKey), // private key + exchangeCfg.HyperliquidWalletAddr, + exchangeCfg.Testnet, + exchangeCfg.HyperliquidUnifiedAcct, + ) + case "aster": + tempTrader, createErr = aster.NewAsterTrader( + exchangeCfg.AsterUser, + exchangeCfg.AsterSigner, + string(exchangeCfg.AsterPrivateKey), + ) + case "bybit": + tempTrader = bybit.NewBybitTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + ) + case "okx": + tempTrader = okx.NewOKXTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "bitget": + tempTrader = bitget.NewBitgetTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "gate": + tempTrader = gate.NewGateTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + ) + case "kucoin": + tempTrader = kucoin.NewKuCoinTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "lighter": + if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" { + // Lighter only supports mainnet + tempTrader, createErr = lighter.NewLighterTraderV2( + exchangeCfg.LighterWalletAddr, + string(exchangeCfg.LighterAPIKeyPrivateKey), + exchangeCfg.LighterAPIKeyIndex, + false, // Always use mainnet for Lighter + ) + } else { + createErr = fmt.Errorf("Lighter requires wallet address and API Key private key") + } + default: + logger.Infof("โš ๏ธ Unsupported exchange type: %s, using user input for initial balance", exchangeCfg.ExchangeType) + } + + if createErr != nil { + logger.Infof("โš ๏ธ Failed to create temporary trader, using user input for initial balance: %v", createErr) + } else if tempTrader != nil { + // Query actual balance + balanceInfo, balanceErr := tempTrader.GetBalance() + if balanceErr != nil { + logger.Infof("โš ๏ธ Failed to query exchange balance, using user input for initial balance: %v", balanceErr) + } else { + // Extract total equity (account total value = wallet balance + unrealized PnL) + // Priority: total_equity > totalWalletBalance > wallet_balance > totalEq > balance + // Note: Must use total_equity (not availableBalance) for accurate P&L calculation + balanceKeys := []string{"total_equity", "totalWalletBalance", "wallet_balance", "totalEq", "balance"} + for _, key := range balanceKeys { + if balance, ok := balanceInfo[key].(float64); ok && balance > 0 { + actualBalance = balance + logger.Infof("โœ“ Queried exchange total equity (%s): %.2f USDT (user input: %.2f USDT)", key, actualBalance, req.InitialBalance) + break + } + } + if actualBalance <= 0 { + logger.Infof("โš ๏ธ Unable to extract total equity from balance info, balanceInfo=%v, using user input for initial balance", balanceInfo) + } + } + } + } + + // Create trader configuration (database entity) + logger.Infof("๐Ÿ”ง DEBUG: Starting to create trader config, ID=%s, Name=%s, AIModel=%s, Exchange=%s, StrategyID=%s", traderID, req.Name, req.AIModelID, req.ExchangeID, req.StrategyID) + traderRecord := &store.Trader{ + ID: traderID, + UserID: userID, + Name: req.Name, + AIModelID: req.AIModelID, + ExchangeID: req.ExchangeID, + StrategyID: req.StrategyID, // Associated strategy ID (new version) + InitialBalance: actualBalance, // Use actual queried balance + BTCETHLeverage: btcEthLeverage, + AltcoinLeverage: altcoinLeverage, + TradingSymbols: req.TradingSymbols, + UseAI500: req.UseAI500, + UseOITop: req.UseOITop, + CustomPrompt: req.CustomPrompt, + OverrideBasePrompt: req.OverrideBasePrompt, + SystemPromptTemplate: systemPromptTemplate, + IsCrossMargin: isCrossMargin, + ShowInCompetition: showInCompetition, + ScanIntervalMinutes: scanIntervalMinutes, + IsRunning: false, + } + + // Save to database + logger.Infof("๐Ÿ”ง DEBUG: Preparing to call CreateTrader") + err = s.store.Trader().Create(traderRecord) + if err != nil { + logger.Infof("โŒ Failed to create trader: %v", err) + SafeInternalError(c, "Failed to create trader", err) + return + } + logger.Infof("๐Ÿ”ง DEBUG: CreateTrader succeeded") + + // Immediately load new trader into TraderManager + logger.Infof("๐Ÿ”ง DEBUG: Preparing to call LoadUserTraders") + err = s.traderManager.LoadUserTradersFromStore(s.store, userID) + if err != nil { + logger.Infof("โš ๏ธ Failed to load user traders into memory: %v", err) + // Don't return error here since trader was successfully created in database + } + logger.Infof("๐Ÿ”ง DEBUG: LoadUserTraders completed") + + logger.Infof("โœ“ Trader created successfully: %s (model: %s, exchange: %s)", req.Name, req.AIModelID, req.ExchangeID) + + c.JSON(http.StatusCreated, gin.H{ + "trader_id": traderID, + "trader_name": req.Name, + "ai_model": req.AIModelID, + "is_running": false, + }) +} + +// handleUpdateTrader Update trader configuration +func (s *Server) handleUpdateTrader(c *gin.Context) { + userID := c.GetString("user_id") + traderID := c.Param("id") + + var req UpdateTraderRequest + if err := c.ShouldBindJSON(&req); err != nil { + SafeBadRequest(c, "Invalid request parameters") + return + } + + // Check if trader exists and belongs to current user + traders, err := s.store.Trader().List(userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get trader list"}) + return + } + + var existingTrader *store.Trader + for _, t := range traders { + if t.ID == traderID { + existingTrader = t + break + } + } + + if existingTrader == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"}) + return + } + + // Set default values + isCrossMargin := existingTrader.IsCrossMargin // Keep original value + if req.IsCrossMargin != nil { + isCrossMargin = *req.IsCrossMargin + } + + showInCompetition := existingTrader.ShowInCompetition // Keep original value + if req.ShowInCompetition != nil { + showInCompetition = *req.ShowInCompetition + } + + // Set leverage default values + btcEthLeverage := req.BTCETHLeverage + altcoinLeverage := req.AltcoinLeverage + if btcEthLeverage <= 0 { + btcEthLeverage = existingTrader.BTCETHLeverage // Keep original value + } + if altcoinLeverage <= 0 { + altcoinLeverage = existingTrader.AltcoinLeverage // Keep original value + } + + // Set scan interval, allow updates + scanIntervalMinutes := req.ScanIntervalMinutes + logger.Infof("๐Ÿ“Š Update trader scan_interval: req=%d, existing=%d", req.ScanIntervalMinutes, existingTrader.ScanIntervalMinutes) + if scanIntervalMinutes <= 0 { + scanIntervalMinutes = existingTrader.ScanIntervalMinutes // Keep original value + } else if scanIntervalMinutes < 3 { + scanIntervalMinutes = 3 + } + logger.Infof("๐Ÿ“Š Final scan_interval_minutes: %d", scanIntervalMinutes) + + // Set system prompt template + systemPromptTemplate := req.SystemPromptTemplate + if systemPromptTemplate == "" { + systemPromptTemplate = existingTrader.SystemPromptTemplate // Keep original value + } + + // Handle strategy ID (if not provided, keep original value) + strategyID := req.StrategyID + if strategyID == "" { + strategyID = existingTrader.StrategyID + } + + // Update trader configuration + traderRecord := &store.Trader{ + ID: traderID, + UserID: userID, + Name: req.Name, + AIModelID: req.AIModelID, + ExchangeID: req.ExchangeID, + StrategyID: strategyID, // Associated strategy ID + InitialBalance: req.InitialBalance, + BTCETHLeverage: btcEthLeverage, + AltcoinLeverage: altcoinLeverage, + TradingSymbols: req.TradingSymbols, + CustomPrompt: req.CustomPrompt, + OverrideBasePrompt: req.OverrideBasePrompt, + SystemPromptTemplate: systemPromptTemplate, + IsCrossMargin: isCrossMargin, + ShowInCompetition: showInCompetition, + ScanIntervalMinutes: scanIntervalMinutes, + IsRunning: existingTrader.IsRunning, // Keep original value + } + + // Check if trader was running before update (we'll restart it after) + wasRunning := false + if existingMemTrader, memErr := s.traderManager.GetTrader(traderID); memErr == nil { + status := existingMemTrader.GetStatus() + if running, ok := status["is_running"].(bool); ok && running { + wasRunning = true + logger.Infof("๐Ÿ”„ Trader %s was running, will restart with new config after update", traderID) + } + } + + // Update database + logger.Infof("๐Ÿ”„ Updating trader: ID=%s, Name=%s, AIModelID=%s, StrategyID=%s, ScanInterval=%d min", + traderRecord.ID, traderRecord.Name, traderRecord.AIModelID, traderRecord.StrategyID, scanIntervalMinutes) + err = s.store.Trader().Update(traderRecord) + if err != nil { + SafeInternalError(c, "Failed to update trader", err) + return + } + + // Remove old trader from memory first (this also stops if running) + s.traderManager.RemoveTrader(traderID) + + // Reload traders into memory with fresh config + err = s.traderManager.LoadUserTradersFromStore(s.store, userID) + if err != nil { + logger.Infof("โš ๏ธ Failed to reload user traders into memory: %v", err) + } + + // If trader was running before, restart it with new config + if wasRunning { + if reloadedTrader, getErr := s.traderManager.GetTrader(traderID); getErr == nil { + go func() { + logger.Infof("โ–ถ๏ธ Restarting trader %s with new config...", traderID) + if runErr := reloadedTrader.Run(); runErr != nil { + logger.Infof("โŒ Trader %s runtime error: %v", traderID, runErr) + } + }() + } + } + + logger.Infof("โœ“ Trader updated successfully: %s (model: %s, exchange: %s, strategy: %s)", req.Name, req.AIModelID, req.ExchangeID, strategyID) + + c.JSON(http.StatusOK, gin.H{ + "trader_id": traderID, + "trader_name": req.Name, + "ai_model": req.AIModelID, + "message": "Trader updated successfully", + }) +} + +// handleDeleteTrader Delete trader +func (s *Server) handleDeleteTrader(c *gin.Context) { + userID := c.GetString("user_id") + traderID := c.Param("id") + + // Delete from database + err := s.store.Trader().Delete(userID, traderID) + if err != nil { + SafeInternalError(c, "Failed to delete trader", err) + return + } + + // If trader is running, stop it first + if trader, err := s.traderManager.GetTrader(traderID); err == nil { + status := trader.GetStatus() + if isRunning, ok := status["is_running"].(bool); ok && isRunning { + trader.Stop() + logger.Infof("โน Stopped running trader: %s", traderID) + } + } + + // Remove trader from memory + s.traderManager.RemoveTrader(traderID) + + logger.Infof("โœ“ Trader deleted: %s", traderID) + c.JSON(http.StatusOK, gin.H{"message": "Trader deleted"}) +} + +// handleStartTrader Start trader +func (s *Server) handleStartTrader(c *gin.Context) { + userID := c.GetString("user_id") + traderID := c.Param("id") + + // Verify trader belongs to current user + _, err := s.store.Trader().GetFullConfig(userID, traderID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist or no access permission"}) + return + } + + // Check if trader exists in memory and if it's running + existingTrader, _ := s.traderManager.GetTrader(traderID) + if existingTrader != nil { + status := existingTrader.GetStatus() + if isRunning, ok := status["is_running"].(bool); ok && isRunning { + c.JSON(http.StatusBadRequest, gin.H{"error": "Trader is already running"}) + return + } + // Trader exists but is stopped - remove from memory to reload fresh config + logger.Infof("๐Ÿ”„ Removing stopped trader %s from memory to reload config...", traderID) + s.traderManager.RemoveTrader(traderID) + } + + // Load trader from database (always reload to get latest config) + logger.Infof("๐Ÿ”„ Loading trader %s from database...", traderID) + if loadErr := s.traderManager.LoadUserTradersFromStore(s.store, userID); loadErr != nil { + logger.Infof("โŒ Failed to load user traders: %v", loadErr) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to load trader: " + loadErr.Error()}) + return + } + + trader, err := s.traderManager.GetTrader(traderID) + if err != nil { + // Check detailed reason + fullCfg, _ := s.store.Trader().GetFullConfig(userID, traderID) + if fullCfg != nil && fullCfg.Trader != nil { + // Check strategy + if fullCfg.Strategy == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Trader has no strategy configured, please create a strategy in Strategy Studio and associate it with the trader"}) + return + } + // Check AI model + if fullCfg.AIModel == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Trader's AI model does not exist, please check AI model configuration"}) + return + } + if !fullCfg.AIModel.Enabled { + c.JSON(http.StatusBadRequest, gin.H{"error": "Trader's AI model is not enabled, please enable the AI model first"}) + return + } + // Check exchange + if fullCfg.Exchange == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Trader's exchange does not exist, please check exchange configuration"}) + return + } + if !fullCfg.Exchange.Enabled { + c.JSON(http.StatusBadRequest, gin.H{"error": "Trader's exchange is not enabled, please enable the exchange first"}) + return + } + } + // Check if there's a specific load error + if loadErr := s.traderManager.GetLoadError(traderID); loadErr != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to load trader: " + loadErr.Error()}) + return + } + c.JSON(http.StatusNotFound, gin.H{"error": "Failed to load trader, please check AI model, exchange and strategy configuration"}) + return + } + + // Start trader + go func() { + logger.Infof("โ–ถ๏ธ Starting trader %s (%s)", traderID, trader.GetName()) + if err := trader.Run(); err != nil { + logger.Infof("โŒ Trader %s runtime error: %v", trader.GetName(), err) + } + }() + + // Update running status in database + err = s.store.Trader().UpdateStatus(userID, traderID, true) + if err != nil { + logger.Infof("โš ๏ธ Failed to update trader status: %v", err) + } + + logger.Infof("โœ“ Trader %s started", trader.GetName()) + c.JSON(http.StatusOK, gin.H{"message": "Trader started"}) +} + +// handleStopTrader Stop trader +func (s *Server) handleStopTrader(c *gin.Context) { + userID := c.GetString("user_id") + traderID := c.Param("id") + + // Verify trader belongs to current user + _, err := s.store.Trader().GetFullConfig(userID, traderID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist or no access permission"}) + return + } + + trader, err := s.traderManager.GetTrader(traderID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"}) + return + } + + // Check if trader is running + status := trader.GetStatus() + if isRunning, ok := status["is_running"].(bool); ok && !isRunning { + c.JSON(http.StatusBadRequest, gin.H{"error": "Trader is already stopped"}) + return + } + + // Stop trader + trader.Stop() + + // Update running status in database + err = s.store.Trader().UpdateStatus(userID, traderID, false) + if err != nil { + logger.Infof("โš ๏ธ Failed to update trader status: %v", err) + } + + logger.Infof("โน Trader %s stopped", trader.GetName()) + c.JSON(http.StatusOK, gin.H{"message": "Trader stopped"}) +} + +// handleUpdateTraderPrompt Update trader custom prompt +func (s *Server) handleUpdateTraderPrompt(c *gin.Context) { + traderID := c.Param("id") + userID := c.GetString("user_id") + + var req struct { + CustomPrompt string `json:"custom_prompt"` + OverrideBasePrompt bool `json:"override_base_prompt"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + SafeBadRequest(c, "Invalid request parameters") + return + } + + // Update database + err := s.store.Trader().UpdateCustomPrompt(userID, traderID, req.CustomPrompt, req.OverrideBasePrompt) + if err != nil { + SafeInternalError(c, "Failed to update custom prompt", err) + return + } + + // If trader is in memory, update its custom prompt and override settings + trader, err := s.traderManager.GetTrader(traderID) + if err == nil { + trader.SetCustomPrompt(req.CustomPrompt) + trader.SetOverrideBasePrompt(req.OverrideBasePrompt) + logger.Infof("โœ“ Updated trader %s custom prompt (override base=%v)", trader.GetName(), req.OverrideBasePrompt) + } + + c.JSON(http.StatusOK, gin.H{"message": "Custom prompt updated"}) +} + +// handleToggleCompetition Toggle trader competition visibility +func (s *Server) handleToggleCompetition(c *gin.Context) { + traderID := c.Param("id") + userID := c.GetString("user_id") + + var req struct { + ShowInCompetition bool `json:"show_in_competition"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + SafeBadRequest(c, "Invalid request parameters") + return + } + + // Update database + err := s.store.Trader().UpdateShowInCompetition(userID, traderID, req.ShowInCompetition) + if err != nil { + SafeInternalError(c, "Update competition visibility", err) + return + } + + // Update in-memory trader if it exists + if trader, err := s.traderManager.GetTrader(traderID); err == nil { + trader.SetShowInCompetition(req.ShowInCompetition) + } + + status := "shown" + if !req.ShowInCompetition { + status = "hidden" + } + logger.Infof("โœ“ Trader %s competition visibility updated: %s", traderID, status) + c.JSON(http.StatusOK, gin.H{ + "message": "Competition visibility updated", + "show_in_competition": req.ShowInCompetition, + }) +} + +// handleGetGridRiskInfo returns current risk information for a grid trader +func (s *Server) handleGetGridRiskInfo(c *gin.Context) { + traderID := c.Param("id") + + autoTrader, err := s.traderManager.GetTrader(traderID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "trader not found"}) + return + } + + riskInfo := autoTrader.GetGridRiskInfo() + c.JSON(http.StatusOK, riskInfo) +} + +// handleSyncBalance Sync exchange balance to initial_balance (Option B: Manual Sync + Option C: Smart Detection) +func (s *Server) handleSyncBalance(c *gin.Context) { + userID := c.GetString("user_id") + traderID := c.Param("id") + + logger.Infof("๐Ÿ”„ User %s requested balance sync for trader %s", userID, traderID) + + // Get trader configuration from database (including exchange info) + fullConfig, err := s.store.Trader().GetFullConfig(userID, traderID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"}) + return + } + + traderConfig := fullConfig.Trader + exchangeCfg := fullConfig.Exchange + + if exchangeCfg == nil || !exchangeCfg.Enabled { + c.JSON(http.StatusBadRequest, gin.H{"error": "Exchange not configured or not enabled"}) + return + } + + // Create temporary trader to query balance + var tempTrader trader.Trader + var createErr error + + // Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID) + // Convert EncryptedString fields to string + switch exchangeCfg.ExchangeType { + case "binance": + tempTrader = binance.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID) + case "hyperliquid": + tempTrader, createErr = hyperliquidtrader.NewHyperliquidTrader( + string(exchangeCfg.APIKey), + exchangeCfg.HyperliquidWalletAddr, + exchangeCfg.Testnet, + exchangeCfg.HyperliquidUnifiedAcct, + ) + case "aster": + tempTrader, createErr = aster.NewAsterTrader( + exchangeCfg.AsterUser, + exchangeCfg.AsterSigner, + string(exchangeCfg.AsterPrivateKey), + ) + case "bybit": + tempTrader = bybit.NewBybitTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + ) + case "okx": + tempTrader = okx.NewOKXTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "bitget": + tempTrader = bitget.NewBitgetTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "gate": + tempTrader = gate.NewGateTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + ) + case "kucoin": + tempTrader = kucoin.NewKuCoinTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "lighter": + if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" { + // Lighter only supports mainnet + tempTrader, createErr = lighter.NewLighterTraderV2( + exchangeCfg.LighterWalletAddr, + string(exchangeCfg.LighterAPIKeyPrivateKey), + exchangeCfg.LighterAPIKeyIndex, + false, // Always use mainnet for Lighter + ) + } else { + createErr = fmt.Errorf("Lighter requires wallet address and API Key private key") + } + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported exchange type"}) + return + } + + if createErr != nil { + logger.Infof("โš ๏ธ Failed to create temporary trader: %v", createErr) + SafeInternalError(c, "Failed to connect to exchange", createErr) + return + } + + // Query actual balance + balanceInfo, balanceErr := tempTrader.GetBalance() + if balanceErr != nil { + logger.Infof("โš ๏ธ Failed to query exchange balance: %v", balanceErr) + SafeInternalError(c, "Failed to query balance", balanceErr) + return + } + + // Extract total equity (for P&L calculation, we need total account value, not available balance) + var actualBalance float64 + // Priority: total_equity > totalWalletBalance > wallet_balance > totalEq > balance + balanceKeys := []string{"total_equity", "totalWalletBalance", "wallet_balance", "totalEq", "balance"} + for _, key := range balanceKeys { + if balance, ok := balanceInfo[key].(float64); ok && balance > 0 { + actualBalance = balance + break + } + } + if actualBalance <= 0 { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Unable to get total equity"}) + return + } + + oldBalance := traderConfig.InitialBalance + + // โœ… Option C: Smart balance change detection + changePercent := ((actualBalance - oldBalance) / oldBalance) * 100 + changeType := "increase" + if changePercent < 0 { + changeType = "decrease" + } + + logger.Infof("โœ“ Queried actual exchange balance: %.2f USDT (current config: %.2f USDT, change: %.2f%%)", + actualBalance, oldBalance, changePercent) + + // Update initial_balance in database + err = s.store.Trader().UpdateInitialBalance(userID, traderID, actualBalance) + if err != nil { + logger.Infof("โŒ Failed to update initial_balance: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update balance"}) + return + } + + // Reload traders into memory + err = s.traderManager.LoadUserTradersFromStore(s.store, userID) + if err != nil { + logger.Infof("โš ๏ธ Failed to reload user traders into memory: %v", err) + } + + logger.Infof("โœ… Synced balance: %.2f โ†’ %.2f USDT (%s %.2f%%)", oldBalance, actualBalance, changeType, changePercent) + + c.JSON(http.StatusOK, gin.H{ + "message": "Balance synced successfully", + "old_balance": oldBalance, + "new_balance": actualBalance, + "change_percent": changePercent, + "change_type": changeType, + }) +} + +// handleClosePosition One-click close position +func (s *Server) handleClosePosition(c *gin.Context) { + userID := c.GetString("user_id") + traderID := c.Param("id") + + var req struct { + Symbol string `json:"symbol" binding:"required"` + Side string `json:"side" binding:"required"` // "LONG" or "SHORT" + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Parameter error: symbol and side are required"}) + return + } + + logger.Infof("๐Ÿ”ป User %s requested position close: trader=%s, symbol=%s, side=%s", userID, traderID, req.Symbol, req.Side) + + // Get trader configuration from database (including exchange info) + fullConfig, err := s.store.Trader().GetFullConfig(userID, traderID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"}) + return + } + + exchangeCfg := fullConfig.Exchange + + if exchangeCfg == nil || !exchangeCfg.Enabled { + c.JSON(http.StatusBadRequest, gin.H{"error": "Exchange not configured or not enabled"}) + return + } + + // Create temporary trader to execute close position + var tempTrader trader.Trader + var createErr error + + // Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID) + // Convert EncryptedString fields to string + switch exchangeCfg.ExchangeType { + case "binance": + tempTrader = binance.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID) + case "hyperliquid": + tempTrader, createErr = hyperliquidtrader.NewHyperliquidTrader( + string(exchangeCfg.APIKey), + exchangeCfg.HyperliquidWalletAddr, + exchangeCfg.Testnet, + exchangeCfg.HyperliquidUnifiedAcct, + ) + case "aster": + tempTrader, createErr = aster.NewAsterTrader( + exchangeCfg.AsterUser, + exchangeCfg.AsterSigner, + string(exchangeCfg.AsterPrivateKey), + ) + case "bybit": + tempTrader = bybit.NewBybitTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + ) + case "okx": + tempTrader = okx.NewOKXTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "bitget": + tempTrader = bitget.NewBitgetTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "gate": + tempTrader = gate.NewGateTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + ) + case "kucoin": + tempTrader = kucoin.NewKuCoinTrader( + string(exchangeCfg.APIKey), + string(exchangeCfg.SecretKey), + string(exchangeCfg.Passphrase), + ) + case "lighter": + if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" { + // Lighter only supports mainnet + tempTrader, createErr = lighter.NewLighterTraderV2( + exchangeCfg.LighterWalletAddr, + string(exchangeCfg.LighterAPIKeyPrivateKey), + exchangeCfg.LighterAPIKeyIndex, + false, // Always use mainnet for Lighter + ) + } else { + createErr = fmt.Errorf("Lighter requires wallet address and API Key private key") + } + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported exchange type"}) + return + } + + if createErr != nil { + logger.Infof("โš ๏ธ Failed to create temporary trader: %v", createErr) + SafeInternalError(c, "Failed to connect to exchange", createErr) + return + } + + // Get current position info BEFORE closing (to get quantity and price) + positions, err := tempTrader.GetPositions() + if err != nil { + logger.Infof("โš ๏ธ Failed to get positions: %v", err) + } + + var posQty float64 + var entryPrice float64 + for _, pos := range positions { + if pos["symbol"] == req.Symbol && pos["side"] == strings.ToLower(req.Side) { + if amt, ok := pos["positionAmt"].(float64); ok { + posQty = amt + if posQty < 0 { + posQty = -posQty // Make positive + } + } + if price, ok := pos["entryPrice"].(float64); ok { + entryPrice = price + } + break + } + } + + // Execute close position operation + var result map[string]interface{} + var closeErr error + + if req.Side == "LONG" { + result, closeErr = tempTrader.CloseLong(req.Symbol, 0) // 0 means close all + } else if req.Side == "SHORT" { + result, closeErr = tempTrader.CloseShort(req.Symbol, 0) // 0 means close all + } else { + c.JSON(http.StatusBadRequest, gin.H{"error": "side must be LONG or SHORT"}) + return + } + + if closeErr != nil { + logger.Infof("โŒ Close position failed: symbol=%s, side=%s, error=%v", req.Symbol, req.Side, closeErr) + SafeInternalError(c, "Close position", closeErr) + return + } + + logger.Infof("โœ… Position closed successfully: symbol=%s, side=%s, qty=%.6f, result=%v", req.Symbol, req.Side, posQty, result) + + // Record order to database (for chart markers and history) + s.recordClosePositionOrder(traderID, exchangeCfg.ID, exchangeCfg.ExchangeType, req.Symbol, req.Side, posQty, entryPrice, result) + + c.JSON(http.StatusOK, gin.H{ + "message": "Position closed successfully", + "symbol": req.Symbol, + "side": req.Side, + "result": result, + }) +} + +// recordClosePositionOrder Record close position order to database (Lighter version - direct FILLED status) +func (s *Server) recordClosePositionOrder(traderID, exchangeID, exchangeType, symbol, side string, quantity, exitPrice float64, result map[string]interface{}) { + // Skip for exchanges with OrderSync - let the background sync handle it to avoid duplicates + switch exchangeType { + case "binance", "lighter", "hyperliquid", "bybit", "okx", "bitget", "aster", "gate": + logger.Infof(" ๐Ÿ“ Close order will be synced by OrderSync, skipping immediate record") + return + } + + // Check if order was placed (skip if NO_POSITION) + status, _ := result["status"].(string) + if status == "NO_POSITION" { + logger.Infof(" โš ๏ธ No position to close, skipping order record") + return + } + + // Get order ID from result + var orderID string + switch v := result["orderId"].(type) { + case int64: + orderID = fmt.Sprintf("%d", v) + case float64: + orderID = fmt.Sprintf("%.0f", v) + case string: + orderID = v + default: + orderID = fmt.Sprintf("%v", v) + } + + if orderID == "" || orderID == "0" { + logger.Infof(" โš ๏ธ Order ID is empty, skipping record") + return + } + + // Determine order action based on side + var orderAction string + if side == "LONG" { + orderAction = "close_long" + } else { + orderAction = "close_short" + } + + // Use entry price if exit price not available + if exitPrice == 0 { + exitPrice = quantity * 100 // Rough estimate if we don't have price + } + + // Estimate fee (0.04% for Lighter taker) + fee := exitPrice * quantity * 0.0004 + + // Create order record - DIRECTLY as FILLED (Lighter market orders fill immediately) + orderRecord := &store.TraderOrder{ + TraderID: traderID, + ExchangeID: exchangeID, + ExchangeType: exchangeType, + ExchangeOrderID: orderID, + Symbol: symbol, + PositionSide: side, + OrderAction: orderAction, + Type: "MARKET", + Side: getSideFromAction(orderAction), + Quantity: quantity, + Price: 0, // Market order + Status: "FILLED", + FilledQuantity: quantity, + AvgFillPrice: exitPrice, + Commission: fee, + FilledAt: time.Now().UTC().UnixMilli(), + CreatedAt: time.Now().UTC().UnixMilli(), + UpdatedAt: time.Now().UTC().UnixMilli(), + } + + if err := s.store.Order().CreateOrder(orderRecord); err != nil { + logger.Infof(" โš ๏ธ Failed to record order: %v", err) + return + } + + logger.Infof(" โœ… Order recorded as FILLED: %s [%s] %s qty=%.6f price=%.6f", orderID, orderAction, symbol, quantity, exitPrice) + + // Create fill record immediately + tradeID := fmt.Sprintf("%s-%d", orderID, time.Now().UnixNano()) + fillRecord := &store.TraderFill{ + TraderID: traderID, + ExchangeID: exchangeID, + ExchangeType: exchangeType, + OrderID: orderRecord.ID, + ExchangeOrderID: orderID, + ExchangeTradeID: tradeID, + Symbol: symbol, + Side: getSideFromAction(orderAction), + Price: exitPrice, + Quantity: quantity, + QuoteQuantity: exitPrice * quantity, + Commission: fee, + CommissionAsset: "USDT", + RealizedPnL: 0, + IsMaker: false, + CreatedAt: time.Now().UTC().UnixMilli(), + } + + if err := s.store.Order().CreateFill(fillRecord); err != nil { + logger.Infof(" โš ๏ธ Failed to record fill: %v", err) + } else { + logger.Infof(" โœ… Fill record created: price=%.6f qty=%.6f", exitPrice, quantity) + } +} + +// pollAndUpdateOrderStatus Poll order status and update with fill data +func (s *Server) pollAndUpdateOrderStatus(orderRecordID int64, traderID, exchangeID, exchangeType, orderID, symbol, orderAction string, tempTrader trader.Trader) { + var actualPrice float64 + var actualQty float64 + var fee float64 + + // Wait a bit for order to be filled + time.Sleep(500 * time.Millisecond) + + // For Lighter, use GetTrades instead of GetOrderStatus (market orders are filled immediately) + if exchangeType == "lighter" { + s.pollLighterTradeHistory(orderRecordID, traderID, exchangeID, exchangeType, orderID, symbol, orderAction, tempTrader) + return + } + + // For other exchanges, poll GetOrderStatus + for i := 0; i < 5; i++ { + status, err := tempTrader.GetOrderStatus(symbol, orderID) + if err != nil { + logger.Infof(" โš ๏ธ GetOrderStatus failed (attempt %d/5): %v", i+1, err) + time.Sleep(500 * time.Millisecond) + continue + } + if err == nil { + statusStr, _ := status["status"].(string) + if statusStr == "FILLED" { + // Get actual fill price + if avgPrice, ok := status["avgPrice"].(float64); ok && avgPrice > 0 { + actualPrice = avgPrice + } + // Get actual executed quantity + if execQty, ok := status["executedQty"].(float64); ok && execQty > 0 { + actualQty = execQty + } + // Get commission/fee + if commission, ok := status["commission"].(float64); ok { + fee = commission + } + + logger.Infof(" โœ… Order filled: avgPrice=%.6f, qty=%.6f, fee=%.6f", actualPrice, actualQty, fee) + + // Update order status to FILLED + if err := s.store.Order().UpdateOrderStatus(orderRecordID, "FILLED", actualQty, actualPrice, fee); err != nil { + logger.Infof(" โš ๏ธ Failed to update order status: %v", err) + return + } + + // Record fill details + tradeID := fmt.Sprintf("%s-%d", orderID, time.Now().UnixNano()) + fillRecord := &store.TraderFill{ + TraderID: traderID, + ExchangeID: exchangeID, + ExchangeType: exchangeType, + OrderID: orderRecordID, + ExchangeOrderID: orderID, + ExchangeTradeID: tradeID, + Symbol: symbol, + Side: getSideFromAction(orderAction), + Price: actualPrice, + Quantity: actualQty, + QuoteQuantity: actualPrice * actualQty, + Commission: fee, + CommissionAsset: "USDT", + RealizedPnL: 0, + IsMaker: false, + CreatedAt: time.Now().UTC().UnixMilli(), + } + + if err := s.store.Order().CreateFill(fillRecord); err != nil { + logger.Infof(" โš ๏ธ Failed to record fill: %v", err) + } else { + logger.Infof(" ๐Ÿ“ Fill recorded: price=%.6f, qty=%.6f", actualPrice, actualQty) + } + + return + } else if statusStr == "CANCELED" || statusStr == "EXPIRED" || statusStr == "REJECTED" { + logger.Infof(" โš ๏ธ Order %s, updating status", statusStr) + s.store.Order().UpdateOrderStatus(orderRecordID, statusStr, 0, 0, 0) + return + } + } + time.Sleep(500 * time.Millisecond) + } + + logger.Infof(" โš ๏ธ Failed to confirm order fill after polling, order may still be pending") +} + +// pollLighterTradeHistory No longer used - Lighter orders are marked as FILLED immediately +// Keeping this function stub for compatibility with other exchanges +func (s *Server) pollLighterTradeHistory(orderRecordID int64, traderID, exchangeID, exchangeType, orderID, symbol, orderAction string, tempTrader trader.Trader) { + // For Lighter, orders are now recorded as FILLED immediately in recordClosePositionOrder + // This function is no longer called for Lighter exchange + logger.Infof(" โ„น๏ธ pollLighterTradeHistory called but not needed (order already marked FILLED)") +} + +// getSideFromAction Get order side (BUY/SELL) from order action +func getSideFromAction(action string) string { + switch action { + case "open_long", "close_short": + return "BUY" + case "open_short", "close_long": + return "SELL" + default: + return "BUY" + } +} diff --git a/api/handler_user.go b/api/handler_user.go new file mode 100644 index 00000000..1b22010c --- /dev/null +++ b/api/handler_user.go @@ -0,0 +1,223 @@ +package api + +import ( + "net/http" + "strings" + "time" + + "nofx/auth" + "nofx/logger" + "nofx/store" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +// handleLogout Add current token to blacklist +func (s *Server) handleLogout(c *gin.Context) { + authHeader := c.GetHeader("Authorization") + if authHeader == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing Authorization header"}) + return + } + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"}) + return + } + tokenString := parts[1] + claims, err := auth.ValidateJWT(tokenString) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) + return + } + var exp time.Time + if claims.ExpiresAt != nil { + exp = claims.ExpiresAt.Time + } else { + exp = time.Now().Add(24 * time.Hour) + } + auth.BlacklistToken(tokenString, exp) + c.JSON(http.StatusOK, gin.H{"message": "Logged out"}) +} + +// handleRegister Handle user registration request. +// handleRegister allows registration only when no users exist yet (first-time setup). +// This is a single-user system; subsequent registrations are permanently closed. +func (s *Server) handleRegister(c *gin.Context) { + userCount, err := s.store.User().Count() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check user count"}) + return + } + + if userCount > 0 { + c.JSON(http.StatusForbidden, gin.H{"error": "System already initialized"}) + return + } + + var req struct { + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required,min=6"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + SafeBadRequest(c, "Invalid request parameters") + return + } + + // Check if email already exists + _, err = s.store.User().GetByEmail(req.Email) + if err == nil { + c.JSON(http.StatusConflict, gin.H{"error": "Email already registered"}) + return + } + + // Generate password hash + passwordHash, err := auth.HashPassword(req.Password) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Password processing failed"}) + return + } + + // Create user + userID := uuid.New().String() + user := &store.User{ + ID: userID, + Email: req.Email, + PasswordHash: passwordHash, + } + + err = s.store.User().Create(user) + if err != nil { + SafeInternalError(c, "Failed to create user", err) + return + } + + // Generate JWT token + token, err := auth.GenerateJWT(user.ID, user.Email) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"}) + return + } + + // Initialize default model and exchange configs for user + err = s.initUserDefaultConfigs(user.ID) + if err != nil { + logger.Infof("Failed to initialize user default configs: %v", err) + } + + c.JSON(http.StatusOK, gin.H{ + "token": token, + "user_id": user.ID, + "email": user.Email, + "message": "Registration successful", + }) +} + +// handleLogin Handle user login request +func (s *Server) handleLogin(c *gin.Context) { + var req struct { + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + SafeBadRequest(c, "Invalid request parameters") + return + } + + // Get user information + user, err := s.store.User().GetByEmail(req.Email) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Email or password incorrect"}) + return + } + + // Verify password + if !auth.CheckPassword(req.Password, user.PasswordHash) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Email or password incorrect"}) + return + } + + // Issue token directly after password verification. + token, err := auth.GenerateJWT(user.ID, user.Email) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "token": token, + "user_id": user.ID, + "email": user.Email, + "message": "Login successful", + }) +} + +// handleChangePassword changes the password for the currently authenticated user. +func (s *Server) handleChangePassword(c *gin.Context) { + userID := c.GetString("user_id") + var req struct { + NewPassword string `json:"new_password" binding:"required,min=8"` + } + if err := c.ShouldBindJSON(&req); err != nil { + SafeBadRequest(c, "new_password is required (min 8 chars)") + return + } + hash, err := auth.HashPassword(req.NewPassword) + if err != nil { + SafeInternalError(c, "Password processing failed", err) + return + } + if err := s.store.User().UpdatePassword(userID, hash); err != nil { + SafeInternalError(c, "Failed to update password", err) + return + } + c.JSON(http.StatusOK, gin.H{"message": "Password updated"}) +} + +// handleResetPassword Reset password via email and new password +func (s *Server) handleResetPassword(c *gin.Context) { + var req struct { + Email string `json:"email" binding:"required,email"` + NewPassword string `json:"new_password" binding:"required,min=6"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + SafeBadRequest(c, "Invalid request parameters") + return + } + + // Query user + user, err := s.store.User().GetByEmail(req.Email) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Email does not exist"}) + return + } + + // Generate new password hash + newPasswordHash, err := auth.HashPassword(req.NewPassword) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Password processing failed"}) + return + } + + // Update password + err = s.store.User().UpdatePassword(user.ID, newPasswordHash) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Password update failed"}) + return + } + + logger.Infof("โœ“ User %s password has been reset", user.Email) + c.JSON(http.StatusOK, gin.H{"message": "Password reset successful, please login with new password"}) +} + +// initUserDefaultConfigs Initialize default model and exchange configs for new user +func (s *Server) initUserDefaultConfigs(userID string) error { + // Commented out auto-creation of default configs, let users add manually + // This way new users won't have config items automatically after registration + logger.Infof("User %s registration completed, waiting for manual AI model and exchange configuration", userID) + return nil +} diff --git a/api/server.go b/api/server.go index c87cf936..55944022 100644 --- a/api/server.go +++ b/api/server.go @@ -2,40 +2,19 @@ package api import ( "context" - "encoding/json" "fmt" "net" "net/http" "nofx/auth" "nofx/backtest" - "nofx/config" "nofx/crypto" "nofx/logger" "nofx/manager" - "nofx/security" - "nofx/market" - "nofx/provider/alpaca" - "nofx/provider/coinank/coinank_api" - "nofx/provider/coinank/coinank_enum" - "nofx/provider/hyperliquid" - "nofx/provider/twelvedata" "nofx/store" - "nofx/trader" - "nofx/trader/aster" - "nofx/trader/binance" - "nofx/trader/bitget" - "nofx/trader/bybit" - "nofx/trader/gate" - hyperliquidtrader "nofx/trader/hyperliquid" - "nofx/trader/kucoin" - "nofx/trader/lighter" - "nofx/trader/okx" - "strconv" "strings" "time" "github.com/gin-gonic/gin" - "github.com/google/uuid" ) // Server HTTP API server @@ -547,2620 +526,6 @@ func (s *Server) getTraderFromQuery(c *gin.Context) (*manager.TraderManager, str return s.traderManager, traderID, nil } -// AI trader management related structures -type CreateTraderRequest struct { - Name string `json:"name" binding:"required"` - AIModelID string `json:"ai_model_id" binding:"required"` - ExchangeID string `json:"exchange_id" binding:"required"` - StrategyID string `json:"strategy_id"` // Strategy ID (new version) - InitialBalance float64 `json:"initial_balance"` - ScanIntervalMinutes int `json:"scan_interval_minutes"` - IsCrossMargin *bool `json:"is_cross_margin"` // Pointer type, nil means use default value true - ShowInCompetition *bool `json:"show_in_competition"` // Pointer type, nil means use default value true - // The following fields are kept for backward compatibility, new version uses strategy config - BTCETHLeverage int `json:"btc_eth_leverage"` - AltcoinLeverage int `json:"altcoin_leverage"` - TradingSymbols string `json:"trading_symbols"` - CustomPrompt string `json:"custom_prompt"` - OverrideBasePrompt bool `json:"override_base_prompt"` - SystemPromptTemplate string `json:"system_prompt_template"` // System prompt template name - UseAI500 bool `json:"use_ai500"` - UseOITop bool `json:"use_oi_top"` -} - -type ModelConfig struct { - ID string `json:"id"` - Name string `json:"name"` - Provider string `json:"provider"` - Enabled bool `json:"enabled"` - APIKey string `json:"apiKey,omitempty"` - CustomAPIURL string `json:"customApiUrl,omitempty"` -} - -// SafeModelConfig Safe model configuration structure (does not contain sensitive information) -type SafeModelConfig struct { - ID string `json:"id"` - Name string `json:"name"` - Provider string `json:"provider"` - Enabled bool `json:"enabled"` - CustomAPIURL string `json:"customApiUrl"` // Custom API URL (usually not sensitive) - CustomModelName string `json:"customModelName"` // Custom model name (not sensitive) -} - -type ExchangeConfig struct { - ID string `json:"id"` - Name string `json:"name"` - Type string `json:"type"` // "cex" or "dex" - Enabled bool `json:"enabled"` - APIKey string `json:"apiKey,omitempty"` - SecretKey string `json:"secretKey,omitempty"` - Testnet bool `json:"testnet,omitempty"` -} - -// SafeExchangeConfig Safe exchange configuration structure (does not contain sensitive information) -type SafeExchangeConfig struct { - ID string `json:"id"` // UUID - ExchangeType string `json:"exchange_type"` // "binance", "bybit", "okx", "hyperliquid", "aster", "lighter" - AccountName string `json:"account_name"` // User-defined account name - Name string `json:"name"` // Display name - Type string `json:"type"` // "cex" or "dex" - Enabled bool `json:"enabled"` - Testnet bool `json:"testnet,omitempty"` - HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` // Hyperliquid wallet address (not sensitive) - AsterUser string `json:"asterUser"` // Aster username (not sensitive) - AsterSigner string `json:"asterSigner"` // Aster signer (not sensitive) - LighterWalletAddr string `json:"lighterWalletAddr"` // LIGHTER wallet address (not sensitive) -} - -type UpdateModelConfigRequest struct { - Models map[string]struct { - Enabled bool `json:"enabled"` - APIKey string `json:"api_key"` - CustomAPIURL string `json:"custom_api_url"` - CustomModelName string `json:"custom_model_name"` - } `json:"models"` -} - -type UpdateExchangeConfigRequest struct { - Exchanges map[string]struct { - Enabled bool `json:"enabled"` - APIKey string `json:"api_key"` - SecretKey string `json:"secret_key"` - Passphrase string `json:"passphrase"` // OKX specific - Testnet bool `json:"testnet"` - HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"` - HyperliquidUnifiedAcct bool `json:"hyperliquid_unified_account"` // Unified Account mode - AsterUser string `json:"aster_user"` - AsterSigner string `json:"aster_signer"` - AsterPrivateKey string `json:"aster_private_key"` - LighterWalletAddr string `json:"lighter_wallet_addr"` - LighterPrivateKey string `json:"lighter_private_key"` - LighterAPIKeyPrivateKey string `json:"lighter_api_key_private_key"` - LighterAPIKeyIndex int `json:"lighter_api_key_index"` - } `json:"exchanges"` -} - -// handleCreateTrader Create new AI trader -func (s *Server) handleCreateTrader(c *gin.Context) { - userID := c.GetString("user_id") - var req CreateTraderRequest - if err := c.ShouldBindJSON(&req); err != nil { - SafeBadRequest(c, "Invalid request parameters") - return - } - - // Validate leverage values - if req.BTCETHLeverage < 0 || req.BTCETHLeverage > 50 { - c.JSON(http.StatusBadRequest, gin.H{"error": "BTC/ETH leverage must be between 1-50x"}) - return - } - if req.AltcoinLeverage < 0 || req.AltcoinLeverage > 20 { - c.JSON(http.StatusBadRequest, gin.H{"error": "Altcoin leverage must be between 1-20x"}) - return - } - - // Validate trading symbol format - if req.TradingSymbols != "" { - symbols := strings.Split(req.TradingSymbols, ",") - for _, symbol := range symbols { - symbol = strings.TrimSpace(symbol) - if symbol != "" && !strings.HasSuffix(strings.ToUpper(symbol), "USDT") { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid symbol format: %s, must end with USDT", symbol)}) - return - } - } - } - - // Generate trader ID (use short UUID prefix for readability) - exchangeIDShort := req.ExchangeID - if len(exchangeIDShort) > 8 { - exchangeIDShort = exchangeIDShort[:8] - } - traderID := fmt.Sprintf("%s_%s_%d", exchangeIDShort, req.AIModelID, time.Now().Unix()) - - // Set default values - isCrossMargin := true // Default to cross margin mode - if req.IsCrossMargin != nil { - isCrossMargin = *req.IsCrossMargin - } - - showInCompetition := true // Default to show in competition - if req.ShowInCompetition != nil { - showInCompetition = *req.ShowInCompetition - } - - // Set leverage default values - btcEthLeverage := 10 // Default value - altcoinLeverage := 5 // Default value - if req.BTCETHLeverage > 0 { - btcEthLeverage = req.BTCETHLeverage - } - if req.AltcoinLeverage > 0 { - altcoinLeverage = req.AltcoinLeverage - } - - // Set system prompt template default value - systemPromptTemplate := "default" - if req.SystemPromptTemplate != "" { - systemPromptTemplate = req.SystemPromptTemplate - } - - // Set scan interval default value - scanIntervalMinutes := req.ScanIntervalMinutes - if scanIntervalMinutes < 3 { - scanIntervalMinutes = 3 // Default 3 minutes, not allowed to be less than 3 - } - - // Query exchange actual balance, override user input - actualBalance := req.InitialBalance // Default to use user input - exchanges, err := s.store.Exchange().List(userID) - if err != nil { - logger.Infof("โš ๏ธ Failed to get exchange config, using user input for initial balance: %v", err) - } - - // Find matching exchange configuration - var exchangeCfg *store.Exchange - for _, ex := range exchanges { - if ex.ID == req.ExchangeID { - exchangeCfg = ex - break - } - } - - if exchangeCfg == nil { - logger.Infof("โš ๏ธ Exchange %s configuration not found, using user input for initial balance", req.ExchangeID) - } else if !exchangeCfg.Enabled { - logger.Infof("โš ๏ธ Exchange %s not enabled, using user input for initial balance", req.ExchangeID) - } else { - // Create temporary trader based on exchange type to query balance - var tempTrader trader.Trader - var createErr error - - // Use ExchangeType (e.g., "binance") instead of ID (UUID) - // Convert EncryptedString fields to string - switch exchangeCfg.ExchangeType { - case "binance": - tempTrader = binance.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID) - case "hyperliquid": - tempTrader, createErr = hyperliquidtrader.NewHyperliquidTrader( - string(exchangeCfg.APIKey), // private key - exchangeCfg.HyperliquidWalletAddr, - exchangeCfg.Testnet, - exchangeCfg.HyperliquidUnifiedAcct, - ) - case "aster": - tempTrader, createErr = aster.NewAsterTrader( - exchangeCfg.AsterUser, - exchangeCfg.AsterSigner, - string(exchangeCfg.AsterPrivateKey), - ) - case "bybit": - tempTrader = bybit.NewBybitTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - ) - case "okx": - tempTrader = okx.NewOKXTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "bitget": - tempTrader = bitget.NewBitgetTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "gate": - tempTrader = gate.NewGateTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - ) - case "kucoin": - tempTrader = kucoin.NewKuCoinTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "lighter": - if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" { - // Lighter only supports mainnet - tempTrader, createErr = lighter.NewLighterTraderV2( - exchangeCfg.LighterWalletAddr, - string(exchangeCfg.LighterAPIKeyPrivateKey), - exchangeCfg.LighterAPIKeyIndex, - false, // Always use mainnet for Lighter - ) - } else { - createErr = fmt.Errorf("Lighter requires wallet address and API Key private key") - } - default: - logger.Infof("โš ๏ธ Unsupported exchange type: %s, using user input for initial balance", exchangeCfg.ExchangeType) - } - - if createErr != nil { - logger.Infof("โš ๏ธ Failed to create temporary trader, using user input for initial balance: %v", createErr) - } else if tempTrader != nil { - // Query actual balance - balanceInfo, balanceErr := tempTrader.GetBalance() - if balanceErr != nil { - logger.Infof("โš ๏ธ Failed to query exchange balance, using user input for initial balance: %v", balanceErr) - } else { - // Extract total equity (account total value = wallet balance + unrealized PnL) - // Priority: total_equity > totalWalletBalance > wallet_balance > totalEq > balance - // Note: Must use total_equity (not availableBalance) for accurate P&L calculation - balanceKeys := []string{"total_equity", "totalWalletBalance", "wallet_balance", "totalEq", "balance"} - for _, key := range balanceKeys { - if balance, ok := balanceInfo[key].(float64); ok && balance > 0 { - actualBalance = balance - logger.Infof("โœ“ Queried exchange total equity (%s): %.2f USDT (user input: %.2f USDT)", key, actualBalance, req.InitialBalance) - break - } - } - if actualBalance <= 0 { - logger.Infof("โš ๏ธ Unable to extract total equity from balance info, balanceInfo=%v, using user input for initial balance", balanceInfo) - } - } - } - } - - // Create trader configuration (database entity) - logger.Infof("๐Ÿ”ง DEBUG: Starting to create trader config, ID=%s, Name=%s, AIModel=%s, Exchange=%s, StrategyID=%s", traderID, req.Name, req.AIModelID, req.ExchangeID, req.StrategyID) - traderRecord := &store.Trader{ - ID: traderID, - UserID: userID, - Name: req.Name, - AIModelID: req.AIModelID, - ExchangeID: req.ExchangeID, - StrategyID: req.StrategyID, // Associated strategy ID (new version) - InitialBalance: actualBalance, // Use actual queried balance - BTCETHLeverage: btcEthLeverage, - AltcoinLeverage: altcoinLeverage, - TradingSymbols: req.TradingSymbols, - UseAI500: req.UseAI500, - UseOITop: req.UseOITop, - CustomPrompt: req.CustomPrompt, - OverrideBasePrompt: req.OverrideBasePrompt, - SystemPromptTemplate: systemPromptTemplate, - IsCrossMargin: isCrossMargin, - ShowInCompetition: showInCompetition, - ScanIntervalMinutes: scanIntervalMinutes, - IsRunning: false, - } - - // Save to database - logger.Infof("๐Ÿ”ง DEBUG: Preparing to call CreateTrader") - err = s.store.Trader().Create(traderRecord) - if err != nil { - logger.Infof("โŒ Failed to create trader: %v", err) - SafeInternalError(c, "Failed to create trader", err) - return - } - logger.Infof("๐Ÿ”ง DEBUG: CreateTrader succeeded") - - // Immediately load new trader into TraderManager - logger.Infof("๐Ÿ”ง DEBUG: Preparing to call LoadUserTraders") - err = s.traderManager.LoadUserTradersFromStore(s.store, userID) - if err != nil { - logger.Infof("โš ๏ธ Failed to load user traders into memory: %v", err) - // Don't return error here since trader was successfully created in database - } - logger.Infof("๐Ÿ”ง DEBUG: LoadUserTraders completed") - - logger.Infof("โœ“ Trader created successfully: %s (model: %s, exchange: %s)", req.Name, req.AIModelID, req.ExchangeID) - - c.JSON(http.StatusCreated, gin.H{ - "trader_id": traderID, - "trader_name": req.Name, - "ai_model": req.AIModelID, - "is_running": false, - }) -} - -// UpdateTraderRequest Update trader request -type UpdateTraderRequest struct { - Name string `json:"name" binding:"required"` - AIModelID string `json:"ai_model_id" binding:"required"` - ExchangeID string `json:"exchange_id" binding:"required"` - StrategyID string `json:"strategy_id"` // Strategy ID (new version) - InitialBalance float64 `json:"initial_balance"` - ScanIntervalMinutes int `json:"scan_interval_minutes"` - IsCrossMargin *bool `json:"is_cross_margin"` - ShowInCompetition *bool `json:"show_in_competition"` - // The following fields are kept for backward compatibility, new version uses strategy config - BTCETHLeverage int `json:"btc_eth_leverage"` - AltcoinLeverage int `json:"altcoin_leverage"` - TradingSymbols string `json:"trading_symbols"` - CustomPrompt string `json:"custom_prompt"` - OverrideBasePrompt bool `json:"override_base_prompt"` - SystemPromptTemplate string `json:"system_prompt_template"` -} - -// handleUpdateTrader Update trader configuration -func (s *Server) handleUpdateTrader(c *gin.Context) { - userID := c.GetString("user_id") - traderID := c.Param("id") - - var req UpdateTraderRequest - if err := c.ShouldBindJSON(&req); err != nil { - SafeBadRequest(c, "Invalid request parameters") - return - } - - // Check if trader exists and belongs to current user - traders, err := s.store.Trader().List(userID) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get trader list"}) - return - } - - var existingTrader *store.Trader - for _, t := range traders { - if t.ID == traderID { - existingTrader = t - break - } - } - - if existingTrader == nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"}) - return - } - - // Set default values - isCrossMargin := existingTrader.IsCrossMargin // Keep original value - if req.IsCrossMargin != nil { - isCrossMargin = *req.IsCrossMargin - } - - showInCompetition := existingTrader.ShowInCompetition // Keep original value - if req.ShowInCompetition != nil { - showInCompetition = *req.ShowInCompetition - } - - // Set leverage default values - btcEthLeverage := req.BTCETHLeverage - altcoinLeverage := req.AltcoinLeverage - if btcEthLeverage <= 0 { - btcEthLeverage = existingTrader.BTCETHLeverage // Keep original value - } - if altcoinLeverage <= 0 { - altcoinLeverage = existingTrader.AltcoinLeverage // Keep original value - } - - // Set scan interval, allow updates - scanIntervalMinutes := req.ScanIntervalMinutes - logger.Infof("๐Ÿ“Š Update trader scan_interval: req=%d, existing=%d", req.ScanIntervalMinutes, existingTrader.ScanIntervalMinutes) - if scanIntervalMinutes <= 0 { - scanIntervalMinutes = existingTrader.ScanIntervalMinutes // Keep original value - } else if scanIntervalMinutes < 3 { - scanIntervalMinutes = 3 - } - logger.Infof("๐Ÿ“Š Final scan_interval_minutes: %d", scanIntervalMinutes) - - // Set system prompt template - systemPromptTemplate := req.SystemPromptTemplate - if systemPromptTemplate == "" { - systemPromptTemplate = existingTrader.SystemPromptTemplate // Keep original value - } - - // Handle strategy ID (if not provided, keep original value) - strategyID := req.StrategyID - if strategyID == "" { - strategyID = existingTrader.StrategyID - } - - // Update trader configuration - traderRecord := &store.Trader{ - ID: traderID, - UserID: userID, - Name: req.Name, - AIModelID: req.AIModelID, - ExchangeID: req.ExchangeID, - StrategyID: strategyID, // Associated strategy ID - InitialBalance: req.InitialBalance, - BTCETHLeverage: btcEthLeverage, - AltcoinLeverage: altcoinLeverage, - TradingSymbols: req.TradingSymbols, - CustomPrompt: req.CustomPrompt, - OverrideBasePrompt: req.OverrideBasePrompt, - SystemPromptTemplate: systemPromptTemplate, - IsCrossMargin: isCrossMargin, - ShowInCompetition: showInCompetition, - ScanIntervalMinutes: scanIntervalMinutes, - IsRunning: existingTrader.IsRunning, // Keep original value - } - - // Check if trader was running before update (we'll restart it after) - wasRunning := false - if existingMemTrader, memErr := s.traderManager.GetTrader(traderID); memErr == nil { - status := existingMemTrader.GetStatus() - if running, ok := status["is_running"].(bool); ok && running { - wasRunning = true - logger.Infof("๐Ÿ”„ Trader %s was running, will restart with new config after update", traderID) - } - } - - // Update database - logger.Infof("๐Ÿ”„ Updating trader: ID=%s, Name=%s, AIModelID=%s, StrategyID=%s, ScanInterval=%d min", - traderRecord.ID, traderRecord.Name, traderRecord.AIModelID, traderRecord.StrategyID, scanIntervalMinutes) - err = s.store.Trader().Update(traderRecord) - if err != nil { - SafeInternalError(c, "Failed to update trader", err) - return - } - - // Remove old trader from memory first (this also stops if running) - s.traderManager.RemoveTrader(traderID) - - // Reload traders into memory with fresh config - err = s.traderManager.LoadUserTradersFromStore(s.store, userID) - if err != nil { - logger.Infof("โš ๏ธ Failed to reload user traders into memory: %v", err) - } - - // If trader was running before, restart it with new config - if wasRunning { - if reloadedTrader, getErr := s.traderManager.GetTrader(traderID); getErr == nil { - go func() { - logger.Infof("โ–ถ๏ธ Restarting trader %s with new config...", traderID) - if runErr := reloadedTrader.Run(); runErr != nil { - logger.Infof("โŒ Trader %s runtime error: %v", traderID, runErr) - } - }() - } - } - - logger.Infof("โœ“ Trader updated successfully: %s (model: %s, exchange: %s, strategy: %s)", req.Name, req.AIModelID, req.ExchangeID, strategyID) - - c.JSON(http.StatusOK, gin.H{ - "trader_id": traderID, - "trader_name": req.Name, - "ai_model": req.AIModelID, - "message": "Trader updated successfully", - }) -} - -// handleDeleteTrader Delete trader -func (s *Server) handleDeleteTrader(c *gin.Context) { - userID := c.GetString("user_id") - traderID := c.Param("id") - - // Delete from database - err := s.store.Trader().Delete(userID, traderID) - if err != nil { - SafeInternalError(c, "Failed to delete trader", err) - return - } - - // If trader is running, stop it first - if trader, err := s.traderManager.GetTrader(traderID); err == nil { - status := trader.GetStatus() - if isRunning, ok := status["is_running"].(bool); ok && isRunning { - trader.Stop() - logger.Infof("โน Stopped running trader: %s", traderID) - } - } - - // Remove trader from memory - s.traderManager.RemoveTrader(traderID) - - logger.Infof("โœ“ Trader deleted: %s", traderID) - c.JSON(http.StatusOK, gin.H{"message": "Trader deleted"}) -} - -// handleStartTrader Start trader -func (s *Server) handleStartTrader(c *gin.Context) { - userID := c.GetString("user_id") - traderID := c.Param("id") - - // Verify trader belongs to current user - _, err := s.store.Trader().GetFullConfig(userID, traderID) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist or no access permission"}) - return - } - - // Check if trader exists in memory and if it's running - existingTrader, _ := s.traderManager.GetTrader(traderID) - if existingTrader != nil { - status := existingTrader.GetStatus() - if isRunning, ok := status["is_running"].(bool); ok && isRunning { - c.JSON(http.StatusBadRequest, gin.H{"error": "Trader is already running"}) - return - } - // Trader exists but is stopped - remove from memory to reload fresh config - logger.Infof("๐Ÿ”„ Removing stopped trader %s from memory to reload config...", traderID) - s.traderManager.RemoveTrader(traderID) - } - - // Load trader from database (always reload to get latest config) - logger.Infof("๐Ÿ”„ Loading trader %s from database...", traderID) - if loadErr := s.traderManager.LoadUserTradersFromStore(s.store, userID); loadErr != nil { - logger.Infof("โŒ Failed to load user traders: %v", loadErr) - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to load trader: " + loadErr.Error()}) - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - // Check detailed reason - fullCfg, _ := s.store.Trader().GetFullConfig(userID, traderID) - if fullCfg != nil && fullCfg.Trader != nil { - // Check strategy - if fullCfg.Strategy == nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Trader has no strategy configured, please create a strategy in Strategy Studio and associate it with the trader"}) - return - } - // Check AI model - if fullCfg.AIModel == nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Trader's AI model does not exist, please check AI model configuration"}) - return - } - if !fullCfg.AIModel.Enabled { - c.JSON(http.StatusBadRequest, gin.H{"error": "Trader's AI model is not enabled, please enable the AI model first"}) - return - } - // Check exchange - if fullCfg.Exchange == nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Trader's exchange does not exist, please check exchange configuration"}) - return - } - if !fullCfg.Exchange.Enabled { - c.JSON(http.StatusBadRequest, gin.H{"error": "Trader's exchange is not enabled, please enable the exchange first"}) - return - } - } - // Check if there's a specific load error - if loadErr := s.traderManager.GetLoadError(traderID); loadErr != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to load trader: " + loadErr.Error()}) - return - } - c.JSON(http.StatusNotFound, gin.H{"error": "Failed to load trader, please check AI model, exchange and strategy configuration"}) - return - } - - // Start trader - go func() { - logger.Infof("โ–ถ๏ธ Starting trader %s (%s)", traderID, trader.GetName()) - if err := trader.Run(); err != nil { - logger.Infof("โŒ Trader %s runtime error: %v", trader.GetName(), err) - } - }() - - // Update running status in database - err = s.store.Trader().UpdateStatus(userID, traderID, true) - if err != nil { - logger.Infof("โš ๏ธ Failed to update trader status: %v", err) - } - - logger.Infof("โœ“ Trader %s started", trader.GetName()) - c.JSON(http.StatusOK, gin.H{"message": "Trader started"}) -} - -// handleStopTrader Stop trader -func (s *Server) handleStopTrader(c *gin.Context) { - userID := c.GetString("user_id") - traderID := c.Param("id") - - // Verify trader belongs to current user - _, err := s.store.Trader().GetFullConfig(userID, traderID) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist or no access permission"}) - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"}) - return - } - - // Check if trader is running - status := trader.GetStatus() - if isRunning, ok := status["is_running"].(bool); ok && !isRunning { - c.JSON(http.StatusBadRequest, gin.H{"error": "Trader is already stopped"}) - return - } - - // Stop trader - trader.Stop() - - // Update running status in database - err = s.store.Trader().UpdateStatus(userID, traderID, false) - if err != nil { - logger.Infof("โš ๏ธ Failed to update trader status: %v", err) - } - - logger.Infof("โน Trader %s stopped", trader.GetName()) - c.JSON(http.StatusOK, gin.H{"message": "Trader stopped"}) -} - -// handleUpdateTraderPrompt Update trader custom prompt -func (s *Server) handleUpdateTraderPrompt(c *gin.Context) { - traderID := c.Param("id") - userID := c.GetString("user_id") - - var req struct { - CustomPrompt string `json:"custom_prompt"` - OverrideBasePrompt bool `json:"override_base_prompt"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - SafeBadRequest(c, "Invalid request parameters") - return - } - - // Update database - err := s.store.Trader().UpdateCustomPrompt(userID, traderID, req.CustomPrompt, req.OverrideBasePrompt) - if err != nil { - SafeInternalError(c, "Failed to update custom prompt", err) - return - } - - // If trader is in memory, update its custom prompt and override settings - trader, err := s.traderManager.GetTrader(traderID) - if err == nil { - trader.SetCustomPrompt(req.CustomPrompt) - trader.SetOverrideBasePrompt(req.OverrideBasePrompt) - logger.Infof("โœ“ Updated trader %s custom prompt (override base=%v)", trader.GetName(), req.OverrideBasePrompt) - } - - c.JSON(http.StatusOK, gin.H{"message": "Custom prompt updated"}) -} - -// handleToggleCompetition Toggle trader competition visibility -func (s *Server) handleToggleCompetition(c *gin.Context) { - traderID := c.Param("id") - userID := c.GetString("user_id") - - var req struct { - ShowInCompetition bool `json:"show_in_competition"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - SafeBadRequest(c, "Invalid request parameters") - return - } - - // Update database - err := s.store.Trader().UpdateShowInCompetition(userID, traderID, req.ShowInCompetition) - if err != nil { - SafeInternalError(c, "Update competition visibility", err) - return - } - - // Update in-memory trader if it exists - if trader, err := s.traderManager.GetTrader(traderID); err == nil { - trader.SetShowInCompetition(req.ShowInCompetition) - } - - status := "shown" - if !req.ShowInCompetition { - status = "hidden" - } - logger.Infof("โœ“ Trader %s competition visibility updated: %s", traderID, status) - c.JSON(http.StatusOK, gin.H{ - "message": "Competition visibility updated", - "show_in_competition": req.ShowInCompetition, - }) -} - -// handleGetGridRiskInfo returns current risk information for a grid trader -func (s *Server) handleGetGridRiskInfo(c *gin.Context) { - traderID := c.Param("id") - - autoTrader, err := s.traderManager.GetTrader(traderID) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "trader not found"}) - return - } - - riskInfo := autoTrader.GetGridRiskInfo() - c.JSON(http.StatusOK, riskInfo) -} - -// handleSyncBalance Sync exchange balance to initial_balance (Option B: Manual Sync + Option C: Smart Detection) -func (s *Server) handleSyncBalance(c *gin.Context) { - userID := c.GetString("user_id") - traderID := c.Param("id") - - logger.Infof("๐Ÿ”„ User %s requested balance sync for trader %s", userID, traderID) - - // Get trader configuration from database (including exchange info) - fullConfig, err := s.store.Trader().GetFullConfig(userID, traderID) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"}) - return - } - - traderConfig := fullConfig.Trader - exchangeCfg := fullConfig.Exchange - - if exchangeCfg == nil || !exchangeCfg.Enabled { - c.JSON(http.StatusBadRequest, gin.H{"error": "Exchange not configured or not enabled"}) - return - } - - // Create temporary trader to query balance - var tempTrader trader.Trader - var createErr error - - // Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID) - // Convert EncryptedString fields to string - switch exchangeCfg.ExchangeType { - case "binance": - tempTrader = binance.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID) - case "hyperliquid": - tempTrader, createErr = hyperliquidtrader.NewHyperliquidTrader( - string(exchangeCfg.APIKey), - exchangeCfg.HyperliquidWalletAddr, - exchangeCfg.Testnet, - exchangeCfg.HyperliquidUnifiedAcct, - ) - case "aster": - tempTrader, createErr = aster.NewAsterTrader( - exchangeCfg.AsterUser, - exchangeCfg.AsterSigner, - string(exchangeCfg.AsterPrivateKey), - ) - case "bybit": - tempTrader = bybit.NewBybitTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - ) - case "okx": - tempTrader = okx.NewOKXTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "bitget": - tempTrader = bitget.NewBitgetTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "gate": - tempTrader = gate.NewGateTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - ) - case "kucoin": - tempTrader = kucoin.NewKuCoinTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "lighter": - if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" { - // Lighter only supports mainnet - tempTrader, createErr = lighter.NewLighterTraderV2( - exchangeCfg.LighterWalletAddr, - string(exchangeCfg.LighterAPIKeyPrivateKey), - exchangeCfg.LighterAPIKeyIndex, - false, // Always use mainnet for Lighter - ) - } else { - createErr = fmt.Errorf("Lighter requires wallet address and API Key private key") - } - default: - c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported exchange type"}) - return - } - - if createErr != nil { - logger.Infof("โš ๏ธ Failed to create temporary trader: %v", createErr) - SafeInternalError(c, "Failed to connect to exchange", createErr) - return - } - - // Query actual balance - balanceInfo, balanceErr := tempTrader.GetBalance() - if balanceErr != nil { - logger.Infof("โš ๏ธ Failed to query exchange balance: %v", balanceErr) - SafeInternalError(c, "Failed to query balance", balanceErr) - return - } - - // Extract total equity (for P&L calculation, we need total account value, not available balance) - var actualBalance float64 - // Priority: total_equity > totalWalletBalance > wallet_balance > totalEq > balance - balanceKeys := []string{"total_equity", "totalWalletBalance", "wallet_balance", "totalEq", "balance"} - for _, key := range balanceKeys { - if balance, ok := balanceInfo[key].(float64); ok && balance > 0 { - actualBalance = balance - break - } - } - if actualBalance <= 0 { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Unable to get total equity"}) - return - } - - oldBalance := traderConfig.InitialBalance - - // โœ… Option C: Smart balance change detection - changePercent := ((actualBalance - oldBalance) / oldBalance) * 100 - changeType := "increase" - if changePercent < 0 { - changeType = "decrease" - } - - logger.Infof("โœ“ Queried actual exchange balance: %.2f USDT (current config: %.2f USDT, change: %.2f%%)", - actualBalance, oldBalance, changePercent) - - // Update initial_balance in database - err = s.store.Trader().UpdateInitialBalance(userID, traderID, actualBalance) - if err != nil { - logger.Infof("โŒ Failed to update initial_balance: %v", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update balance"}) - return - } - - // Reload traders into memory - err = s.traderManager.LoadUserTradersFromStore(s.store, userID) - if err != nil { - logger.Infof("โš ๏ธ Failed to reload user traders into memory: %v", err) - } - - logger.Infof("โœ… Synced balance: %.2f โ†’ %.2f USDT (%s %.2f%%)", oldBalance, actualBalance, changeType, changePercent) - - c.JSON(http.StatusOK, gin.H{ - "message": "Balance synced successfully", - "old_balance": oldBalance, - "new_balance": actualBalance, - "change_percent": changePercent, - "change_type": changeType, - }) -} - -// handleClosePosition One-click close position -func (s *Server) handleClosePosition(c *gin.Context) { - userID := c.GetString("user_id") - traderID := c.Param("id") - - var req struct { - Symbol string `json:"symbol" binding:"required"` - Side string `json:"side" binding:"required"` // "LONG" or "SHORT" - } - - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Parameter error: symbol and side are required"}) - return - } - - logger.Infof("๐Ÿ”ป User %s requested position close: trader=%s, symbol=%s, side=%s", userID, traderID, req.Symbol, req.Side) - - // Get trader configuration from database (including exchange info) - fullConfig, err := s.store.Trader().GetFullConfig(userID, traderID) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"}) - return - } - - exchangeCfg := fullConfig.Exchange - - if exchangeCfg == nil || !exchangeCfg.Enabled { - c.JSON(http.StatusBadRequest, gin.H{"error": "Exchange not configured or not enabled"}) - return - } - - // Create temporary trader to execute close position - var tempTrader trader.Trader - var createErr error - - // Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID) - // Convert EncryptedString fields to string - switch exchangeCfg.ExchangeType { - case "binance": - tempTrader = binance.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID) - case "hyperliquid": - tempTrader, createErr = hyperliquidtrader.NewHyperliquidTrader( - string(exchangeCfg.APIKey), - exchangeCfg.HyperliquidWalletAddr, - exchangeCfg.Testnet, - exchangeCfg.HyperliquidUnifiedAcct, - ) - case "aster": - tempTrader, createErr = aster.NewAsterTrader( - exchangeCfg.AsterUser, - exchangeCfg.AsterSigner, - string(exchangeCfg.AsterPrivateKey), - ) - case "bybit": - tempTrader = bybit.NewBybitTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - ) - case "okx": - tempTrader = okx.NewOKXTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "bitget": - tempTrader = bitget.NewBitgetTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "gate": - tempTrader = gate.NewGateTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - ) - case "kucoin": - tempTrader = kucoin.NewKuCoinTrader( - string(exchangeCfg.APIKey), - string(exchangeCfg.SecretKey), - string(exchangeCfg.Passphrase), - ) - case "lighter": - if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" { - // Lighter only supports mainnet - tempTrader, createErr = lighter.NewLighterTraderV2( - exchangeCfg.LighterWalletAddr, - string(exchangeCfg.LighterAPIKeyPrivateKey), - exchangeCfg.LighterAPIKeyIndex, - false, // Always use mainnet for Lighter - ) - } else { - createErr = fmt.Errorf("Lighter requires wallet address and API Key private key") - } - default: - c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported exchange type"}) - return - } - - if createErr != nil { - logger.Infof("โš ๏ธ Failed to create temporary trader: %v", createErr) - SafeInternalError(c, "Failed to connect to exchange", createErr) - return - } - - // Get current position info BEFORE closing (to get quantity and price) - positions, err := tempTrader.GetPositions() - if err != nil { - logger.Infof("โš ๏ธ Failed to get positions: %v", err) - } - - var posQty float64 - var entryPrice float64 - for _, pos := range positions { - if pos["symbol"] == req.Symbol && pos["side"] == strings.ToLower(req.Side) { - if amt, ok := pos["positionAmt"].(float64); ok { - posQty = amt - if posQty < 0 { - posQty = -posQty // Make positive - } - } - if price, ok := pos["entryPrice"].(float64); ok { - entryPrice = price - } - break - } - } - - // Execute close position operation - var result map[string]interface{} - var closeErr error - - if req.Side == "LONG" { - result, closeErr = tempTrader.CloseLong(req.Symbol, 0) // 0 means close all - } else if req.Side == "SHORT" { - result, closeErr = tempTrader.CloseShort(req.Symbol, 0) // 0 means close all - } else { - c.JSON(http.StatusBadRequest, gin.H{"error": "side must be LONG or SHORT"}) - return - } - - if closeErr != nil { - logger.Infof("โŒ Close position failed: symbol=%s, side=%s, error=%v", req.Symbol, req.Side, closeErr) - SafeInternalError(c, "Close position", closeErr) - return - } - - logger.Infof("โœ… Position closed successfully: symbol=%s, side=%s, qty=%.6f, result=%v", req.Symbol, req.Side, posQty, result) - - // Record order to database (for chart markers and history) - s.recordClosePositionOrder(traderID, exchangeCfg.ID, exchangeCfg.ExchangeType, req.Symbol, req.Side, posQty, entryPrice, result) - - c.JSON(http.StatusOK, gin.H{ - "message": "Position closed successfully", - "symbol": req.Symbol, - "side": req.Side, - "result": result, - }) -} - -// recordClosePositionOrder Record close position order to database (Lighter version - direct FILLED status) -func (s *Server) recordClosePositionOrder(traderID, exchangeID, exchangeType, symbol, side string, quantity, exitPrice float64, result map[string]interface{}) { - // Skip for exchanges with OrderSync - let the background sync handle it to avoid duplicates - switch exchangeType { - case "binance", "lighter", "hyperliquid", "bybit", "okx", "bitget", "aster", "gate": - logger.Infof(" ๐Ÿ“ Close order will be synced by OrderSync, skipping immediate record") - return - } - - // Check if order was placed (skip if NO_POSITION) - status, _ := result["status"].(string) - if status == "NO_POSITION" { - logger.Infof(" โš ๏ธ No position to close, skipping order record") - return - } - - // Get order ID from result - var orderID string - switch v := result["orderId"].(type) { - case int64: - orderID = fmt.Sprintf("%d", v) - case float64: - orderID = fmt.Sprintf("%.0f", v) - case string: - orderID = v - default: - orderID = fmt.Sprintf("%v", v) - } - - if orderID == "" || orderID == "0" { - logger.Infof(" โš ๏ธ Order ID is empty, skipping record") - return - } - - // Determine order action based on side - var orderAction string - if side == "LONG" { - orderAction = "close_long" - } else { - orderAction = "close_short" - } - - // Use entry price if exit price not available - if exitPrice == 0 { - exitPrice = quantity * 100 // Rough estimate if we don't have price - } - - // Estimate fee (0.04% for Lighter taker) - fee := exitPrice * quantity * 0.0004 - - // Create order record - DIRECTLY as FILLED (Lighter market orders fill immediately) - orderRecord := &store.TraderOrder{ - TraderID: traderID, - ExchangeID: exchangeID, - ExchangeType: exchangeType, - ExchangeOrderID: orderID, - Symbol: symbol, - PositionSide: side, - OrderAction: orderAction, - Type: "MARKET", - Side: getSideFromAction(orderAction), - Quantity: quantity, - Price: 0, // Market order - Status: "FILLED", - FilledQuantity: quantity, - AvgFillPrice: exitPrice, - Commission: fee, - FilledAt: time.Now().UTC().UnixMilli(), - CreatedAt: time.Now().UTC().UnixMilli(), - UpdatedAt: time.Now().UTC().UnixMilli(), - } - - if err := s.store.Order().CreateOrder(orderRecord); err != nil { - logger.Infof(" โš ๏ธ Failed to record order: %v", err) - return - } - - logger.Infof(" โœ… Order recorded as FILLED: %s [%s] %s qty=%.6f price=%.6f", orderID, orderAction, symbol, quantity, exitPrice) - - // Create fill record immediately - tradeID := fmt.Sprintf("%s-%d", orderID, time.Now().UnixNano()) - fillRecord := &store.TraderFill{ - TraderID: traderID, - ExchangeID: exchangeID, - ExchangeType: exchangeType, - OrderID: orderRecord.ID, - ExchangeOrderID: orderID, - ExchangeTradeID: tradeID, - Symbol: symbol, - Side: getSideFromAction(orderAction), - Price: exitPrice, - Quantity: quantity, - QuoteQuantity: exitPrice * quantity, - Commission: fee, - CommissionAsset: "USDT", - RealizedPnL: 0, - IsMaker: false, - CreatedAt: time.Now().UTC().UnixMilli(), - } - - if err := s.store.Order().CreateFill(fillRecord); err != nil { - logger.Infof(" โš ๏ธ Failed to record fill: %v", err) - } else { - logger.Infof(" โœ… Fill record created: price=%.6f qty=%.6f", exitPrice, quantity) - } -} - -// pollAndUpdateOrderStatus Poll order status and update with fill data -func (s *Server) pollAndUpdateOrderStatus(orderRecordID int64, traderID, exchangeID, exchangeType, orderID, symbol, orderAction string, tempTrader trader.Trader) { - var actualPrice float64 - var actualQty float64 - var fee float64 - - // Wait a bit for order to be filled - time.Sleep(500 * time.Millisecond) - - // For Lighter, use GetTrades instead of GetOrderStatus (market orders are filled immediately) - if exchangeType == "lighter" { - s.pollLighterTradeHistory(orderRecordID, traderID, exchangeID, exchangeType, orderID, symbol, orderAction, tempTrader) - return - } - - // For other exchanges, poll GetOrderStatus - for i := 0; i < 5; i++ { - status, err := tempTrader.GetOrderStatus(symbol, orderID) - if err != nil { - logger.Infof(" โš ๏ธ GetOrderStatus failed (attempt %d/5): %v", i+1, err) - time.Sleep(500 * time.Millisecond) - continue - } - if err == nil { - statusStr, _ := status["status"].(string) - if statusStr == "FILLED" { - // Get actual fill price - if avgPrice, ok := status["avgPrice"].(float64); ok && avgPrice > 0 { - actualPrice = avgPrice - } - // Get actual executed quantity - if execQty, ok := status["executedQty"].(float64); ok && execQty > 0 { - actualQty = execQty - } - // Get commission/fee - if commission, ok := status["commission"].(float64); ok { - fee = commission - } - - logger.Infof(" โœ… Order filled: avgPrice=%.6f, qty=%.6f, fee=%.6f", actualPrice, actualQty, fee) - - // Update order status to FILLED - if err := s.store.Order().UpdateOrderStatus(orderRecordID, "FILLED", actualQty, actualPrice, fee); err != nil { - logger.Infof(" โš ๏ธ Failed to update order status: %v", err) - return - } - - // Record fill details - tradeID := fmt.Sprintf("%s-%d", orderID, time.Now().UnixNano()) - fillRecord := &store.TraderFill{ - TraderID: traderID, - ExchangeID: exchangeID, - ExchangeType: exchangeType, - OrderID: orderRecordID, - ExchangeOrderID: orderID, - ExchangeTradeID: tradeID, - Symbol: symbol, - Side: getSideFromAction(orderAction), - Price: actualPrice, - Quantity: actualQty, - QuoteQuantity: actualPrice * actualQty, - Commission: fee, - CommissionAsset: "USDT", - RealizedPnL: 0, - IsMaker: false, - CreatedAt: time.Now().UTC().UnixMilli(), - } - - if err := s.store.Order().CreateFill(fillRecord); err != nil { - logger.Infof(" โš ๏ธ Failed to record fill: %v", err) - } else { - logger.Infof(" ๐Ÿ“ Fill recorded: price=%.6f, qty=%.6f", actualPrice, actualQty) - } - - return - } else if statusStr == "CANCELED" || statusStr == "EXPIRED" || statusStr == "REJECTED" { - logger.Infof(" โš ๏ธ Order %s, updating status", statusStr) - s.store.Order().UpdateOrderStatus(orderRecordID, statusStr, 0, 0, 0) - return - } - } - time.Sleep(500 * time.Millisecond) - } - - logger.Infof(" โš ๏ธ Failed to confirm order fill after polling, order may still be pending") -} - -// pollLighterTradeHistory No longer used - Lighter orders are marked as FILLED immediately -// Keeping this function stub for compatibility with other exchanges -func (s *Server) pollLighterTradeHistory(orderRecordID int64, traderID, exchangeID, exchangeType, orderID, symbol, orderAction string, tempTrader trader.Trader) { - // For Lighter, orders are now recorded as FILLED immediately in recordClosePositionOrder - // This function is no longer called for Lighter exchange - logger.Infof(" โ„น๏ธ pollLighterTradeHistory called but not needed (order already marked FILLED)") -} - -// getSideFromAction Get order side (BUY/SELL) from order action -func getSideFromAction(action string) string { - switch action { - case "open_long", "close_short": - return "BUY" - case "open_short", "close_long": - return "SELL" - default: - return "BUY" - } -} - -// handleGetModelConfigs Get AI model configurations -func (s *Server) handleGetModelConfigs(c *gin.Context) { - userID := c.GetString("user_id") - logger.Infof("๐Ÿ” Querying AI model configs for user %s", userID) - models, err := s.store.AIModel().List(userID) - if err != nil { - logger.Infof("โŒ Failed to get AI model configs: %v", err) - SafeInternalError(c, "Failed to get AI model configs", err) - return - } - - // If no models in database, return default models - if len(models) == 0 { - logger.Infof("โš ๏ธ No AI models in database, returning defaults") - defaultModels := []SafeModelConfig{ - {ID: "deepseek", Name: "DeepSeek AI", Provider: "deepseek", Enabled: false}, - {ID: "qwen", Name: "Qwen AI", Provider: "qwen", Enabled: false}, - {ID: "openai", Name: "OpenAI", Provider: "openai", Enabled: false}, - {ID: "claude", Name: "Claude AI", Provider: "claude", Enabled: false}, - {ID: "gemini", Name: "Gemini AI", Provider: "gemini", Enabled: false}, - {ID: "grok", Name: "Grok AI", Provider: "grok", Enabled: false}, - {ID: "kimi", Name: "Kimi AI", Provider: "kimi", Enabled: false}, - {ID: "minimax", Name: "MiniMax AI", Provider: "minimax", Enabled: false}, - } - c.JSON(http.StatusOK, defaultModels) - return - } - - logger.Infof("โœ… Found %d AI model configs", len(models)) - - // Convert to safe response structure, remove sensitive information - safeModels := make([]SafeModelConfig, len(models)) - for i, model := range models { - safeModels[i] = SafeModelConfig{ - ID: model.ID, - Name: model.Name, - Provider: model.Provider, - Enabled: model.Enabled, - CustomAPIURL: model.CustomAPIURL, - CustomModelName: model.CustomModelName, - } - } - - c.JSON(http.StatusOK, safeModels) -} - -// handleUpdateModelConfigs Update AI model configurations (supports both encrypted and plain text based on config) -func (s *Server) handleUpdateModelConfigs(c *gin.Context) { - userID := c.GetString("user_id") - cfg := config.Get() - - // Read raw request body - bodyBytes, err := c.GetRawData() - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"}) - return - } - - var req UpdateModelConfigRequest - - // Check if transport encryption is enabled - if !cfg.TransportEncryption { - // Transport encryption disabled, accept plain JSON - if err := json.Unmarshal(bodyBytes, &req); err != nil { - logger.Infof("โŒ Failed to parse plain JSON request: %v", err) - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"}) - return - } - logger.Infof("๐Ÿ“ Received plain text model config (UserID: %s)", userID) - } else { - // Transport encryption enabled, require encrypted payload - var encryptedPayload crypto.EncryptedPayload - if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil { - logger.Infof("โŒ Failed to parse encrypted payload: %v", err) - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format, encrypted transmission required"}) - return - } - - // Verify encrypted data - if encryptedPayload.WrappedKey == "" { - logger.Infof("โŒ Detected unencrypted request (UserID: %s)", userID) - c.JSON(http.StatusBadRequest, gin.H{ - "error": "This endpoint only supports encrypted transmission, please use encrypted client", - "code": "ENCRYPTION_REQUIRED", - "message": "Encrypted transmission is required for security reasons", - }) - return - } - - // Decrypt data - decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload) - if err != nil { - logger.Infof("โŒ Failed to decrypt model config (UserID: %s): %v", userID, err) - c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to decrypt data"}) - return - } - - // Parse decrypted data - if err := json.Unmarshal([]byte(decrypted), &req); err != nil { - logger.Infof("โŒ Failed to parse decrypted data: %v", err) - c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse decrypted data"}) - return - } - logger.Infof("๐Ÿ”“ Decrypted model config data (UserID: %s)", userID) - } - - // Update each model's configuration and track traders that need reload - tradersToReload := make(map[string]bool) - for modelID, modelData := range req.Models { - // SSRF protection: validate custom_api_url before storing - if modelData.CustomAPIURL != "" { - cleanURL := strings.TrimSuffix(modelData.CustomAPIURL, "#") - if err := security.ValidateURL(cleanURL); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid custom_api_url for model %s: %s", modelID, err.Error())}) - return - } - } - - // Find traders using this AI model BEFORE updating - traders, _ := s.store.Trader().ListByAIModelID(userID, modelID) - for _, t := range traders { - tradersToReload[t.ID] = true - } - - err := s.store.AIModel().Update(userID, modelID, modelData.Enabled, modelData.APIKey, modelData.CustomAPIURL, modelData.CustomModelName) - if err != nil { - SafeInternalError(c, fmt.Sprintf("Update model %s", modelID), err) - return - } - } - - // Remove affected traders from memory BEFORE reloading to pick up new config - for traderID := range tradersToReload { - logger.Infof("๐Ÿ”„ Removing trader %s from memory to reload with new AI model config", traderID) - s.traderManager.RemoveTrader(traderID) - } - - // Reload all traders for this user to make new config take effect immediately - err = s.traderManager.LoadUserTradersFromStore(s.store, userID) - if err != nil { - logger.Infof("โš ๏ธ Failed to reload user traders into memory: %v", err) - // Don't return error here since model config was successfully updated to database - } - - logger.Infof("โœ“ AI model config updated: %+v", req.Models) - c.JSON(http.StatusOK, gin.H{"message": "Model configuration updated"}) -} - -// handleGetExchangeConfigs Get exchange configurations -func (s *Server) handleGetExchangeConfigs(c *gin.Context) { - userID := c.GetString("user_id") - logger.Infof("๐Ÿ” Querying exchange configs for user %s", userID) - exchanges, err := s.store.Exchange().List(userID) - if err != nil { - SafeInternalError(c, "Failed to get exchange configs", err) - return - } - - // If no exchanges in database, return empty array (user needs to create accounts) - if len(exchanges) == 0 { - logger.Infof("โš ๏ธ No exchanges in database for user %s", userID) - c.JSON(http.StatusOK, []SafeExchangeConfig{}) - return - } - - logger.Infof("โœ… Found %d exchange configs", len(exchanges)) - - // Convert to safe response structure, remove sensitive information - safeExchanges := make([]SafeExchangeConfig, len(exchanges)) - for i, exchange := range exchanges { - safeExchanges[i] = SafeExchangeConfig{ - ID: exchange.ID, - ExchangeType: exchange.ExchangeType, - AccountName: exchange.AccountName, - Name: exchange.Name, - Type: exchange.Type, - Enabled: exchange.Enabled, - Testnet: exchange.Testnet, - HyperliquidWalletAddr: exchange.HyperliquidWalletAddr, - AsterUser: exchange.AsterUser, - AsterSigner: exchange.AsterSigner, - LighterWalletAddr: exchange.LighterWalletAddr, - } - } - - c.JSON(http.StatusOK, safeExchanges) -} - -// handleUpdateExchangeConfigs Update exchange configurations (supports both encrypted and plain text based on config) -func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) { - userID := c.GetString("user_id") - cfg := config.Get() - - // Read raw request body - bodyBytes, err := c.GetRawData() - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"}) - return - } - - var req UpdateExchangeConfigRequest - - // Check if transport encryption is enabled - if !cfg.TransportEncryption { - // Transport encryption disabled, accept plain JSON - if err := json.Unmarshal(bodyBytes, &req); err != nil { - logger.Infof("โŒ Failed to parse plain JSON request: %v", err) - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"}) - return - } - logger.Infof("๐Ÿ“ Received plain text exchange config (UserID: %s)", userID) - } else { - // Transport encryption enabled, require encrypted payload - var encryptedPayload crypto.EncryptedPayload - if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil { - logger.Infof("โŒ Failed to parse encrypted payload: %v", err) - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format, encrypted transmission required"}) - return - } - - // Verify encrypted data - if encryptedPayload.WrappedKey == "" { - logger.Infof("โŒ Detected unencrypted request (UserID: %s)", userID) - c.JSON(http.StatusBadRequest, gin.H{ - "error": "This endpoint only supports encrypted transmission, please use encrypted client", - "code": "ENCRYPTION_REQUIRED", - "message": "Encrypted transmission is required for security reasons", - }) - return - } - - // Decrypt data - decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload) - if err != nil { - logger.Infof("โŒ Failed to decrypt exchange config (UserID: %s): %v", userID, err) - c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to decrypt data"}) - return - } - - // Parse decrypted data - if err := json.Unmarshal([]byte(decrypted), &req); err != nil { - logger.Infof("โŒ Failed to parse decrypted data: %v", err) - c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse decrypted data"}) - return - } - logger.Infof("๐Ÿ”“ Decrypted exchange config data (UserID: %s)", userID) - } - - // Update each exchange's configuration and track traders that need reload - tradersToReload := make(map[string]bool) - for exchangeID, exchangeData := range req.Exchanges { - // Find traders using this exchange BEFORE updating - traders, _ := s.store.Trader().ListByExchangeID(userID, exchangeID) - for _, t := range traders { - tradersToReload[t.ID] = true - } - - err := s.store.Exchange().Update(userID, exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Passphrase, exchangeData.Testnet, exchangeData.HyperliquidWalletAddr, exchangeData.HyperliquidUnifiedAcct, exchangeData.AsterUser, exchangeData.AsterSigner, exchangeData.AsterPrivateKey, exchangeData.LighterWalletAddr, exchangeData.LighterPrivateKey, exchangeData.LighterAPIKeyPrivateKey, exchangeData.LighterAPIKeyIndex) - if err != nil { - SafeInternalError(c, fmt.Sprintf("Update exchange %s", exchangeID), err) - return - } - } - - // Remove affected traders from memory BEFORE reloading to pick up new config - for traderID := range tradersToReload { - logger.Infof("๐Ÿ”„ Removing trader %s from memory to reload with new exchange config", traderID) - s.traderManager.RemoveTrader(traderID) - } - - // Reload all traders for this user to make new config take effect immediately - err = s.traderManager.LoadUserTradersFromStore(s.store, userID) - if err != nil { - logger.Infof("โš ๏ธ Failed to reload user traders into memory: %v", err) - // Don't return error here since exchange config was successfully updated to database - } - - logger.Infof("โœ“ Exchange config updated: %+v", req.Exchanges) - c.JSON(http.StatusOK, gin.H{"message": "Exchange configuration updated"}) -} - -// CreateExchangeRequest request structure for creating a new exchange account -type CreateExchangeRequest struct { - ExchangeType string `json:"exchange_type" binding:"required"` // "binance", "bybit", "okx", "hyperliquid", "aster", "lighter" - AccountName string `json:"account_name"` // User-defined account name - Enabled bool `json:"enabled"` - APIKey string `json:"api_key"` - SecretKey string `json:"secret_key"` - Passphrase string `json:"passphrase"` - Testnet bool `json:"testnet"` - HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"` - HyperliquidUnifiedAcct bool `json:"hyperliquid_unified_account"` // Unified Account mode: Spot as Perp collateral - AsterUser string `json:"aster_user"` - AsterSigner string `json:"aster_signer"` - AsterPrivateKey string `json:"aster_private_key"` - LighterWalletAddr string `json:"lighter_wallet_addr"` - LighterPrivateKey string `json:"lighter_private_key"` - LighterAPIKeyPrivateKey string `json:"lighter_api_key_private_key"` - LighterAPIKeyIndex int `json:"lighter_api_key_index"` -} - -// handleCreateExchange Create a new exchange account -func (s *Server) handleCreateExchange(c *gin.Context) { - userID := c.GetString("user_id") - cfg := config.Get() - - // Read raw request body - bodyBytes, err := c.GetRawData() - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"}) - return - } - - var req CreateExchangeRequest - - // Check if transport encryption is enabled - if !cfg.TransportEncryption { - // Transport encryption disabled, accept plain JSON - if err := json.Unmarshal(bodyBytes, &req); err != nil { - logger.Infof("โŒ Failed to parse plain JSON request: %v", err) - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"}) - return - } - } else { - // Transport encryption enabled, require encrypted payload - var encryptedPayload crypto.EncryptedPayload - if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format, encrypted transmission required"}) - return - } - - if encryptedPayload.WrappedKey == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "This endpoint only supports encrypted transmission", - "code": "ENCRYPTION_REQUIRED", - "message": "Encrypted transmission is required for security reasons", - }) - return - } - - decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to decrypt data"}) - return - } - - if err := json.Unmarshal([]byte(decrypted), &req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse decrypted data"}) - return - } - } - - // Validate exchange type - validTypes := map[string]bool{ - "binance": true, "bybit": true, "okx": true, "bitget": true, - "hyperliquid": true, "aster": true, "lighter": true, "gate": true, "kucoin": true, "indodax": true, - } - if !validTypes[req.ExchangeType] { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid exchange type: %s", req.ExchangeType)}) - return - } - - // Create new exchange account - id, err := s.store.Exchange().Create( - userID, req.ExchangeType, req.AccountName, req.Enabled, - req.APIKey, req.SecretKey, req.Passphrase, req.Testnet, - req.HyperliquidWalletAddr, req.HyperliquidUnifiedAcct, - req.AsterUser, req.AsterSigner, req.AsterPrivateKey, - req.LighterWalletAddr, req.LighterPrivateKey, req.LighterAPIKeyPrivateKey, req.LighterAPIKeyIndex, - ) - if err != nil { - logger.Infof("โŒ Failed to create exchange account: %v", err) - SafeInternalError(c, "Failed to create exchange account", err) - return - } - - logger.Infof("โœ“ Created exchange account: type=%s, name=%s, id=%s", req.ExchangeType, req.AccountName, id) - c.JSON(http.StatusOK, gin.H{ - "message": "Exchange account created", - "id": id, - }) -} - -// handleDeleteExchange Delete an exchange account -func (s *Server) handleDeleteExchange(c *gin.Context) { - userID := c.GetString("user_id") - exchangeID := c.Param("id") - - if exchangeID == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "Exchange ID is required"}) - return - } - - // Check if any traders are using this exchange - traders, err := s.store.Trader().List(userID) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check traders"}) - return - } - - for _, trader := range traders { - if trader.ExchangeID == exchangeID { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Cannot delete exchange account that is in use by traders", - "trader_id": trader.ID, - "trader_name": trader.Name, - }) - return - } - } - - // Delete exchange account - err = s.store.Exchange().Delete(userID, exchangeID) - if err != nil { - logger.Infof("โŒ Failed to delete exchange account: %v", err) - SafeInternalError(c, "Failed to delete exchange account", err) - return - } - - logger.Infof("โœ“ Deleted exchange account: id=%s", exchangeID) - c.JSON(http.StatusOK, gin.H{"message": "Exchange account deleted"}) -} - -// handleTraderList Trader list -func (s *Server) handleTraderList(c *gin.Context) { - userID := c.GetString("user_id") - traders, err := s.store.Trader().List(userID) - if err != nil { - SafeInternalError(c, "Failed to get trader list", err) - return - } - - result := make([]map[string]interface{}, 0, len(traders)) - for _, trader := range traders { - // Get real-time running status - isRunning := trader.IsRunning - if at, err := s.traderManager.GetTrader(trader.ID); err == nil { - status := at.GetStatus() - if running, ok := status["is_running"].(bool); ok { - isRunning = running - } - } - - // Get strategy name if strategy_id is set - var strategyName string - if trader.StrategyID != "" { - if strategy, err := s.store.Strategy().Get(userID, trader.StrategyID); err == nil { - strategyName = strategy.Name - } - } - - // Return complete AIModelID (e.g. "admin_deepseek"), don't truncate - // Frontend needs complete ID to verify model exists (consistent with handleGetTraderConfig) - result = append(result, map[string]interface{}{ - "trader_id": trader.ID, - "trader_name": trader.Name, - "ai_model": trader.AIModelID, // Use complete ID - "exchange_id": trader.ExchangeID, - "is_running": isRunning, - "show_in_competition": trader.ShowInCompetition, - "initial_balance": trader.InitialBalance, - "strategy_id": trader.StrategyID, - "strategy_name": strategyName, - }) - } - - c.JSON(http.StatusOK, result) -} - -// handleGetTraderConfig Get trader detailed configuration -func (s *Server) handleGetTraderConfig(c *gin.Context) { - userID := c.GetString("user_id") - traderID := c.Param("id") - - if traderID == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "Trader ID cannot be empty"}) - return - } - - fullCfg, err := s.store.Trader().GetFullConfig(userID, traderID) - if err != nil { - SafeNotFound(c, "Trader config") - return - } - traderConfig := fullCfg.Trader - - // Get real-time running status - isRunning := traderConfig.IsRunning - if at, err := s.traderManager.GetTrader(traderID); err == nil { - status := at.GetStatus() - if running, ok := status["is_running"].(bool); ok { - isRunning = running - } - } - - // Return complete model ID without conversion, consistent with frontend model list - aiModelID := traderConfig.AIModelID - - result := map[string]interface{}{ - "trader_id": traderConfig.ID, - "trader_name": traderConfig.Name, - "ai_model": aiModelID, - "exchange_id": traderConfig.ExchangeID, - "strategy_id": traderConfig.StrategyID, - "initial_balance": traderConfig.InitialBalance, - "scan_interval_minutes": traderConfig.ScanIntervalMinutes, - "btc_eth_leverage": traderConfig.BTCETHLeverage, - "altcoin_leverage": traderConfig.AltcoinLeverage, - "trading_symbols": traderConfig.TradingSymbols, - "custom_prompt": traderConfig.CustomPrompt, - "override_base_prompt": traderConfig.OverrideBasePrompt, - "is_cross_margin": traderConfig.IsCrossMargin, - "use_ai500": traderConfig.UseAI500, - "use_oi_top": traderConfig.UseOITop, - "is_running": isRunning, - } - - c.JSON(http.StatusOK, result) -} - -// handleStatus System status -func (s *Server) handleStatus(c *gin.Context) { - _, traderID, err := s.getTraderFromQuery(c) - if err != nil { - SafeBadRequest(c, "Invalid trader ID") - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - SafeNotFound(c, "Trader") - return - } - - status := trader.GetStatus() - c.JSON(http.StatusOK, status) -} - -// handleAccount Account information -func (s *Server) handleAccount(c *gin.Context) { - _, traderID, err := s.getTraderFromQuery(c) - if err != nil { - SafeBadRequest(c, "Invalid trader ID") - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - SafeNotFound(c, "Trader") - return - } - - logger.Infof("๐Ÿ“Š Received account info request [%s]", trader.GetName()) - account, err := trader.GetAccountInfo() - if err != nil { - SafeInternalError(c, "Get account info", err) - return - } - - logger.Infof("โœ“ Returning account info [%s]: equity=%.2f, available=%.2f, pnl=%.2f (%.2f%%)", - trader.GetName(), - account["total_equity"], - account["available_balance"], - account["total_pnl"], - account["total_pnl_pct"]) - c.JSON(http.StatusOK, account) -} - -// handlePositions Position list -func (s *Server) handlePositions(c *gin.Context) { - _, traderID, err := s.getTraderFromQuery(c) - if err != nil { - SafeBadRequest(c, "Invalid trader ID") - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - SafeNotFound(c, "Trader") - return - } - - positions, err := trader.GetPositions() - if err != nil { - SafeInternalError(c, "Get positions", err) - return - } - - c.JSON(http.StatusOK, positions) -} - -// handlePositionHistory Historical closed positions with statistics -func (s *Server) handlePositionHistory(c *gin.Context) { - _, traderID, err := s.getTraderFromQuery(c) - if err != nil { - SafeBadRequest(c, "Invalid trader ID") - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - SafeNotFound(c, "Trader") - return - } - - // Get optional query parameters - limitStr := c.DefaultQuery("limit", "100") - limit := 100 - if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 500 { - limit = l - } - - // Get store - store := trader.GetStore() - if store == nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"}) - return - } - - // Get closed positions - positions, err := store.Position().GetClosedPositions(trader.GetID(), limit) - if err != nil { - SafeInternalError(c, "Get position history", err) - return - } - - // Get statistics - stats, _ := store.Position().GetFullStats(trader.GetID()) - - // Get symbol stats - symbolStats, _ := store.Position().GetSymbolStats(trader.GetID(), 10) - - // Get direction stats - directionStats, _ := store.Position().GetDirectionStats(trader.GetID()) - - c.JSON(http.StatusOK, gin.H{ - "positions": positions, - "stats": stats, - "symbol_stats": symbolStats, - "direction_stats": directionStats, - }) -} - -// handleTrades Historical trades list -func (s *Server) handleTrades(c *gin.Context) { - _, traderID, err := s.getTraderFromQuery(c) - if err != nil { - SafeBadRequest(c, "Invalid trader ID") - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - SafeNotFound(c, "Trader") - return - } - - // Get optional query parameters - symbol := c.Query("symbol") - limitStr := c.DefaultQuery("limit", "100") - limit := 100 - if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { - limit = l - } - - // Normalize symbol (add USDT suffix if not present) - if symbol != "" { - symbol = market.Normalize(symbol) - } - - // Get trades from store - store := trader.GetStore() - if store == nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"}) - return - } - - allTrades, err := store.Position().GetRecentTrades(trader.GetID(), limit) - if err != nil { - SafeInternalError(c, "Get trades", err) - return - } - - // Filter by symbol if specified - if symbol != "" { - var result []interface{} - for _, trade := range allTrades { - if trade.Symbol == symbol { - result = append(result, trade) - } - } - c.JSON(http.StatusOK, result) - return - } - - c.JSON(http.StatusOK, allTrades) -} - -// handleOrders Order list (all orders including open, close, stop loss, take profit, etc.) -func (s *Server) handleOrders(c *gin.Context) { - _, traderID, err := s.getTraderFromQuery(c) - if err != nil { - SafeBadRequest(c, "Invalid trader ID") - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - SafeNotFound(c, "Trader") - return - } - - // Get optional query parameters - symbol := c.Query("symbol") - statusFilter := c.Query("status") // NEW, FILLED, CANCELED, etc. - limitStr := c.DefaultQuery("limit", "100") - limit := 100 - if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { - limit = l - } - - // Normalize symbol (add USDT suffix if not present) - if symbol != "" { - symbol = market.Normalize(symbol) - } - - // Get orders from store - store := trader.GetStore() - if store == nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"}) - return - } - - // Get orders with filters applied at database level - orders, err := store.Order().GetTraderOrdersFiltered(trader.GetID(), symbol, statusFilter, limit) - if err != nil { - SafeInternalError(c, "Get orders", err) - return - } - - c.JSON(http.StatusOK, orders) -} - -// handleOrderFills Order fill details (all fills for a specific order) -func (s *Server) handleOrderFills(c *gin.Context) { - orderIDStr := c.Param("id") - orderID, err := strconv.ParseInt(orderIDStr, 10, 64) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid order ID"}) - return - } - - _, traderID, err := s.getTraderFromQuery(c) - if err != nil { - SafeBadRequest(c, "Invalid trader ID") - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - SafeNotFound(c, "Trader") - return - } - - store := trader.GetStore() - if store == nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"}) - return - } - - // Get fills for this order - fills, err := store.Order().GetOrderFills(orderID) - if err != nil { - SafeInternalError(c, "Get order fills", err) - return - } - - c.JSON(http.StatusOK, fills) -} - -// handleOpenOrders Get open orders (pending SL/TP) from exchange -func (s *Server) handleOpenOrders(c *gin.Context) { - _, traderID, err := s.getTraderFromQuery(c) - if err != nil { - SafeBadRequest(c, "Invalid trader ID") - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - SafeNotFound(c, "Trader") - return - } - - // Get symbol parameter (required for exchange query) - symbol := c.Query("symbol") - if symbol == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "symbol parameter is required"}) - return - } - - // Normalize symbol - symbol = market.Normalize(symbol) - - // Get open orders from exchange - openOrders, err := trader.GetOpenOrders(symbol) - if err != nil { - SafeInternalError(c, "Get open orders", err) - return - } - - c.JSON(http.StatusOK, openOrders) -} - -// handleKlines K-line data (supports multiple exchanges via coinank) -func (s *Server) handleKlines(c *gin.Context) { - // Get query parameters - symbol := c.Query("symbol") - if symbol == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "symbol parameter is required"}) - return - } - - interval := c.DefaultQuery("interval", "5m") - exchange := c.DefaultQuery("exchange", "binance") // Default to binance for backward compatibility - limitStr := c.DefaultQuery("limit", "1000") - limit, err := strconv.Atoi(limitStr) - if err != nil || limit <= 0 { - limit = 1000 - } - - // Coinank API has a maximum limit of 1500 klines per request - if limit > 1500 { - limit = 1500 - } - - var klines []market.Kline - exchangeLower := strings.ToLower(exchange) - - // Route to appropriate data source based on exchange type - switch exchangeLower { - case "alpaca": - // US Stocks via Alpaca - klines, err = s.getKlinesFromAlpaca(symbol, interval, limit) - if err != nil { - SafeInternalError(c, "Get klines from Alpaca", err) - return - } - case "forex", "metals": - // Forex and Metals via Twelve Data - klines, err = s.getKlinesFromTwelveData(symbol, interval, limit) - if err != nil { - SafeInternalError(c, "Get klines from TwelveData", err) - return - } - case "hyperliquid", "hyperliquid-xyz", "xyz": - // Hyperliquid native API - supports both crypto perps and stock perps (xyz dex) - klines, err = s.getKlinesFromHyperliquid(symbol, interval, limit) - if err != nil { - SafeInternalError(c, "Get klines from Hyperliquid", err) - return - } - default: - // Crypto exchanges via CoinAnk - symbol = market.Normalize(symbol) - klines, err = s.getKlinesFromCoinank(symbol, interval, exchange, limit) - if err != nil { - SafeInternalError(c, "Get klines from CoinAnk", err) - return - } - } - - c.JSON(http.StatusOK, klines) -} - -// getKlinesFromCoinank fetches kline data from coinank free/open API for multiple exchanges -func (s *Server) getKlinesFromCoinank(symbol, interval, exchange string, limit int) ([]market.Kline, error) { - // Map exchange string to coinank enum - var coinankExchange coinank_enum.Exchange - switch strings.ToLower(exchange) { - case "binance": - coinankExchange = coinank_enum.Binance - case "bybit": - coinankExchange = coinank_enum.Bybit - case "okx": - coinankExchange = coinank_enum.Okex - case "bitget": - coinankExchange = coinank_enum.Bitget - case "gate": - coinankExchange = coinank_enum.Gate - case "aster": - coinankExchange = coinank_enum.Aster - case "lighter": - // Lighter doesn't have direct CoinAnk support, use Binance data as fallback - coinankExchange = coinank_enum.Binance - case "kucoin": - // KuCoin doesn't have direct CoinAnk support, use Binance data as fallback - coinankExchange = coinank_enum.Binance - default: - // For any unknown exchange, default to Binance - logger.Warnf("โš ๏ธ Unknown exchange '%s', defaulting to Binance for CoinAnk", exchange) - coinankExchange = coinank_enum.Binance - } - - // Map interval string to coinank enum - var coinankInterval coinank_enum.Interval - switch interval { - case "1s": - coinankInterval = coinank_enum.Second1 - case "5s": - coinankInterval = coinank_enum.Second5 - case "10s": - coinankInterval = coinank_enum.Second10 - case "30s": - coinankInterval = coinank_enum.Second30 - case "1m": - coinankInterval = coinank_enum.Minute1 - case "3m": - coinankInterval = coinank_enum.Minute3 - case "5m": - coinankInterval = coinank_enum.Minute5 - case "10m": - coinankInterval = coinank_enum.Minute10 - case "15m": - coinankInterval = coinank_enum.Minute15 - case "30m": - coinankInterval = coinank_enum.Minute30 - case "1h": - coinankInterval = coinank_enum.Hour1 - case "2h": - coinankInterval = coinank_enum.Hour2 - case "4h": - coinankInterval = coinank_enum.Hour4 - case "6h": - coinankInterval = coinank_enum.Hour6 - case "8h": - coinankInterval = coinank_enum.Hour8 - case "12h": - coinankInterval = coinank_enum.Hour12 - case "1d": - coinankInterval = coinank_enum.Day1 - case "3d": - coinankInterval = coinank_enum.Day3 - case "1w": - coinankInterval = coinank_enum.Week1 - case "1M": - coinankInterval = coinank_enum.Month1 - default: - return nil, fmt.Errorf("unsupported interval for coinank: %s", interval) - } - - // Convert symbol format for different exchanges - // OKX uses "BTC-USDT-SWAP" format instead of "BTCUSDT" - apiSymbol := symbol - if coinankExchange == coinank_enum.Okex { - // Convert BTCUSDT -> BTC-USDT-SWAP - if strings.HasSuffix(symbol, "USDT") { - base := strings.TrimSuffix(symbol, "USDT") - apiSymbol = fmt.Sprintf("%s-USDT-SWAP", base) - } - } - - // Call coinank free/open API (no authentication required) - ctx := context.Background() - ts := time.Now().UnixMilli() - // Use "To" side to search backward from current time (get historical klines) - coinankKlines, err := coinank_api.Kline(ctx, apiSymbol, coinankExchange, ts, coinank_enum.To, limit, coinankInterval) - if err != nil { - // Free API doesn't support all exchanges (e.g., OKX, Bitget) - // Fallback to Binance data as reference - if coinankExchange != coinank_enum.Binance { - logger.Warnf("โš ๏ธ CoinAnk free API doesn't support %s, falling back to Binance data", coinankExchange) - coinankKlines, err = coinank_api.Kline(ctx, symbol, coinank_enum.Binance, ts, coinank_enum.To, limit, coinankInterval) - if err != nil { - return nil, fmt.Errorf("coinank API error (fallback): %w", err) - } - } else { - return nil, fmt.Errorf("coinank API error: %w", err) - } - } - - // Convert coinank kline format to market.Kline format - // Coinank: Volume = BTC ๆ•ฐ้‡, Quantity = USDT ๆˆไบค้ข - klines := make([]market.Kline, len(coinankKlines)) - for i, ck := range coinankKlines { - klines[i] = market.Kline{ - OpenTime: ck.StartTime, - Open: ck.Open, - High: ck.High, - Low: ck.Low, - Close: ck.Close, - Volume: ck.Volume, // BTC ๆ•ฐ้‡ - QuoteVolume: ck.Quantity, // USDT ๆˆไบค้ข - CloseTime: ck.EndTime, - } - } - - return klines, nil -} - -// getKlinesFromAlpaca fetches kline data from Alpaca API for US stocks -func (s *Server) getKlinesFromAlpaca(symbol, interval string, limit int) ([]market.Kline, error) { - // Create Alpaca client - client := alpaca.NewClient() - - // Map interval to Alpaca timeframe format - timeframe := alpaca.MapTimeframe(interval) - - // Fetch bars from Alpaca - ctx := context.Background() - bars, err := client.GetBars(ctx, symbol, timeframe, limit) - if err != nil { - return nil, fmt.Errorf("alpaca API error: %w", err) - } - - // Convert Alpaca bars to market.Kline format - klines := make([]market.Kline, len(bars)) - for i, bar := range bars { - klines[i] = market.Kline{ - OpenTime: bar.Timestamp.UnixMilli(), - Open: bar.Open, - High: bar.High, - Low: bar.Low, - Close: bar.Close, - Volume: float64(bar.Volume), // ่‚กๆ•ฐ - QuoteVolume: float64(bar.Volume) * bar.Close, // ๆˆไบค้ข = ่‚กๆ•ฐ * ๆ”ถ็›˜ไปท (USD) - CloseTime: bar.Timestamp.UnixMilli(), - } - } - - return klines, nil -} - -// getKlinesFromTwelveData fetches kline data from Twelve Data API for forex and metals -func (s *Server) getKlinesFromTwelveData(symbol, interval string, limit int) ([]market.Kline, error) { - // Create Twelve Data client - client := twelvedata.NewClient() - - // Map interval to Twelve Data timeframe format - timeframe := twelvedata.MapTimeframe(interval) - - // Fetch time series from Twelve Data - ctx := context.Background() - result, err := client.GetTimeSeries(ctx, symbol, timeframe, limit) - if err != nil { - return nil, fmt.Errorf("twelvedata API error: %w", err) - } - - // Convert Twelve Data bars to market.Kline format - // Note: Twelve Data returns bars in reverse order (newest first) - klines := make([]market.Kline, len(result.Values)) - for i, bar := range result.Values { - open, high, low, close, volume, timestamp, err := twelvedata.ParseBar(bar) - if err != nil { - logger.Warnf("โš ๏ธ Failed to parse TwelveData bar: %v", err) - continue - } - - // Reverse order: put oldest first - idx := len(result.Values) - 1 - i - klines[idx] = market.Kline{ - OpenTime: timestamp, - Open: open, - High: high, - Low: low, - Close: close, - Volume: volume, - CloseTime: timestamp, - } - } - - return klines, nil -} - -// getKlinesFromHyperliquid fetches kline data from Hyperliquid API -// Supports both crypto perps (default dex) and stock perps/forex/commodities (xyz dex) -func (s *Server) getKlinesFromHyperliquid(symbol, interval string, limit int) ([]market.Kline, error) { - // Create Hyperliquid client - client := hyperliquid.NewClient() - - // Map interval to Hyperliquid format - timeframe := hyperliquid.MapTimeframe(interval) - - // Fetch candles from Hyperliquid - // FormatCoinForAPI will automatically add xyz: prefix for stock perps - ctx := context.Background() - candles, err := client.GetCandles(ctx, symbol, timeframe, limit) - if err != nil { - return nil, fmt.Errorf("hyperliquid API error: %w", err) - } - - // Convert Hyperliquid candles to market.Kline format - klines := make([]market.Kline, len(candles)) - for i, candle := range candles { - open, _ := strconv.ParseFloat(candle.Open, 64) - high, _ := strconv.ParseFloat(candle.High, 64) - low, _ := strconv.ParseFloat(candle.Low, 64) - close, _ := strconv.ParseFloat(candle.Close, 64) - volume, _ := strconv.ParseFloat(candle.Volume, 64) - - klines[i] = market.Kline{ - OpenTime: candle.OpenTime, - Open: open, - High: high, - Low: low, - Close: close, - Volume: volume, // ๅˆ็บฆๆ•ฐ้‡ - QuoteVolume: volume * close, // ๆˆไบค้ข (USD) - CloseTime: candle.CloseTime, - } - } - - return klines, nil -} - -// handleSymbols returns available symbols for a given exchange -func (s *Server) handleSymbols(c *gin.Context) { - exchange := c.DefaultQuery("exchange", "hyperliquid") - - type SymbolInfo struct { - Symbol string `json:"symbol"` - Name string `json:"name"` - Category string `json:"category"` // crypto, stock, forex, commodity, index - MaxLeverage int `json:"maxLeverage,omitempty"` - } - - var symbols []SymbolInfo - - switch strings.ToLower(exchange) { - case "hyperliquid", "hyperliquid-xyz", "xyz": - // Fetch symbols from Hyperliquid - client := hyperliquid.NewClient() - ctx := context.Background() - - // Get crypto perps from default dex - if exchange == "hyperliquid" || exchange == "hyperliquid-xyz" { - mids, err := client.GetAllMids(ctx) - if err == nil { - for symbol := range mids { - // Skip spot tokens (start with @) - if strings.HasPrefix(symbol, "@") { - continue - } - symbols = append(symbols, SymbolInfo{ - Symbol: symbol, - Name: symbol, - Category: "crypto", - }) - } - } - } - - // Get xyz dex symbols (stocks, forex, commodities) - xyzMids, err := client.GetAllMidsXYZ(ctx) - if err == nil { - for symbol := range xyzMids { - // Remove xyz: prefix for display - displaySymbol := strings.TrimPrefix(symbol, "xyz:") - category := "stock" - if displaySymbol == "GOLD" || displaySymbol == "SILVER" { - category = "commodity" - } else if displaySymbol == "EUR" || displaySymbol == "JPY" { - category = "forex" - } else if displaySymbol == "XYZ100" { - category = "index" - } - symbols = append(symbols, SymbolInfo{ - Symbol: displaySymbol, - Name: displaySymbol, - Category: category, - }) - } - } - - default: - c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported exchange for symbol listing"}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "exchange": exchange, - "symbols": symbols, - "count": len(symbols), - }) -} - -// handleDecisions Decision log list -func (s *Server) handleDecisions(c *gin.Context) { - _, traderID, err := s.getTraderFromQuery(c) - if err != nil { - SafeBadRequest(c, "Invalid trader ID") - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - SafeNotFound(c, "Trader") - return - } - - // Get all historical decision records (unlimited) - records, err := trader.GetStore().Decision().GetLatestRecords(trader.GetID(), 10000) - if err != nil { - SafeInternalError(c, "Get decision log", err) - return - } - - c.JSON(http.StatusOK, records) -} - -// handleLatestDecisions Latest decision logs (newest first, supports limit parameter) -func (s *Server) handleLatestDecisions(c *gin.Context) { - _, traderID, err := s.getTraderFromQuery(c) - if err != nil { - SafeBadRequest(c, "Invalid trader ID") - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - SafeNotFound(c, "Trader") - return - } - - // Get limit from query parameter, default to 5 - limit := 5 - if limitStr := c.Query("limit"); limitStr != "" { - if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 { - limit = parsedLimit - if limit > 100 { - limit = 100 // Max 100 to prevent abuse - } - } - } - - records, err := trader.GetStore().Decision().GetLatestRecords(trader.GetID(), limit) - if err != nil { - SafeInternalError(c, "Get decision log", err) - return - } - - // Reverse array to put newest first (for list display) - // GetLatestRecords returns oldest to newest (for charts), here we need newest to oldest - for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 { - records[i], records[j] = records[j], records[i] - } - - c.JSON(http.StatusOK, records) -} - -// handleStatistics Statistics information -func (s *Server) handleStatistics(c *gin.Context) { - _, traderID, err := s.getTraderFromQuery(c) - if err != nil { - SafeBadRequest(c, "Invalid trader ID") - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - SafeNotFound(c, "Trader") - return - } - - stats, err := trader.GetStore().Decision().GetStatistics(trader.GetID()) - if err != nil { - SafeInternalError(c, "Get statistics", err) - return - } - - c.JSON(http.StatusOK, stats) -} - -// handleCompetition Competition overview (compare all traders) -func (s *Server) handleCompetition(c *gin.Context) { - userID := c.GetString("user_id") - - // Ensure user's traders are loaded into memory - err := s.traderManager.LoadUserTradersFromStore(s.store, userID) - if err != nil { - logger.Infof("โš ๏ธ Failed to load traders for user %s: %v", userID, err) - } - - competition, err := s.traderManager.GetCompetitionData() - if err != nil { - SafeInternalError(c, "Get competition data", err) - return - } - - c.JSON(http.StatusOK, competition) -} - -// handleEquityHistory Return rate historical data -// Query directly from database, not dependent on trader in memory (so historical data can be retrieved after restart) -func (s *Server) handleEquityHistory(c *gin.Context) { - _, traderID, err := s.getTraderFromQuery(c) - if err != nil { - SafeBadRequest(c, "Invalid trader ID") - return - } - - // Get equity historical data from new equity table - // Every 3 minutes per cycle: 10000 records = about 20 days of data - snapshots, err := s.store.Equity().GetLatest(traderID, 10000) - if err != nil { - SafeInternalError(c, "Get historical data", err) - return - } - - if len(snapshots) == 0 { - c.JSON(http.StatusOK, []interface{}{}) - return - } - - // Build return rate historical data points - type EquityPoint struct { - Timestamp string `json:"timestamp"` - TotalEquity float64 `json:"total_equity"` // Account equity (wallet + unrealized) - AvailableBalance float64 `json:"available_balance"` // Available balance - TotalPnL float64 `json:"total_pnl"` // Total PnL (unrealized PnL) - TotalPnLPct float64 `json:"total_pnl_pct"` // Total PnL percentage - PositionCount int `json:"position_count"` // Position count - MarginUsedPct float64 `json:"margin_used_pct"` // Margin used percentage - } - - // Use the balance of the first record as initial balance to calculate return rate - initialBalance := snapshots[0].Balance - if initialBalance == 0 { - initialBalance = 1 // Avoid division by zero - } - - var history []EquityPoint - for _, snap := range snapshots { - // Calculate PnL percentage - totalPnLPct := 0.0 - if initialBalance > 0 { - totalPnLPct = (snap.UnrealizedPnL / initialBalance) * 100 - } - - history = append(history, EquityPoint{ - Timestamp: snap.Timestamp.Format("2006-01-02 15:04:05"), - TotalEquity: snap.TotalEquity, - AvailableBalance: snap.Balance, - TotalPnL: snap.UnrealizedPnL, - TotalPnLPct: totalPnLPct, - PositionCount: snap.PositionCount, - MarginUsedPct: snap.MarginUsedPct, - }) - } - - c.JSON(http.StatusOK, history) -} - // authMiddleware JWT authentication middleware func (s *Server) authMiddleware() gin.HandlerFunc { return func(c *gin.Context) { @@ -3204,256 +569,6 @@ func (s *Server) authMiddleware() gin.HandlerFunc { } } -// handleLogout Add current token to blacklist -func (s *Server) handleLogout(c *gin.Context) { - authHeader := c.GetHeader("Authorization") - if authHeader == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing Authorization header"}) - return - } - parts := strings.Split(authHeader, " ") - if len(parts) != 2 || parts[0] != "Bearer" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"}) - return - } - tokenString := parts[1] - claims, err := auth.ValidateJWT(tokenString) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) - return - } - var exp time.Time - if claims.ExpiresAt != nil { - exp = claims.ExpiresAt.Time - } else { - exp = time.Now().Add(24 * time.Hour) - } - auth.BlacklistToken(tokenString, exp) - c.JSON(http.StatusOK, gin.H{"message": "Logged out"}) -} - -// handleRegister Handle user registration request. -// handleRegister allows registration only when no users exist yet (first-time setup). -// This is a single-user system; subsequent registrations are permanently closed. -func (s *Server) handleRegister(c *gin.Context) { - userCount, err := s.store.User().Count() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check user count"}) - return - } - - if userCount > 0 { - c.JSON(http.StatusForbidden, gin.H{"error": "System already initialized"}) - return - } - - var req struct { - Email string `json:"email" binding:"required,email"` - Password string `json:"password" binding:"required,min=6"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - SafeBadRequest(c, "Invalid request parameters") - return - } - - // Check if email already exists - _, err = s.store.User().GetByEmail(req.Email) - if err == nil { - c.JSON(http.StatusConflict, gin.H{"error": "Email already registered"}) - return - } - - // Generate password hash - passwordHash, err := auth.HashPassword(req.Password) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Password processing failed"}) - return - } - - // Create user - userID := uuid.New().String() - user := &store.User{ - ID: userID, - Email: req.Email, - PasswordHash: passwordHash, - } - - err = s.store.User().Create(user) - if err != nil { - SafeInternalError(c, "Failed to create user", err) - return - } - - // Generate JWT token - token, err := auth.GenerateJWT(user.ID, user.Email) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"}) - return - } - - // Initialize default model and exchange configs for user - err = s.initUserDefaultConfigs(user.ID) - if err != nil { - logger.Infof("Failed to initialize user default configs: %v", err) - } - - c.JSON(http.StatusOK, gin.H{ - "token": token, - "user_id": user.ID, - "email": user.Email, - "message": "Registration successful", - }) -} - -// handleLogin Handle user login request -func (s *Server) handleLogin(c *gin.Context) { - var req struct { - Email string `json:"email" binding:"required,email"` - Password string `json:"password" binding:"required"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - SafeBadRequest(c, "Invalid request parameters") - return - } - - // Get user information - user, err := s.store.User().GetByEmail(req.Email) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Email or password incorrect"}) - return - } - - // Verify password - if !auth.CheckPassword(req.Password, user.PasswordHash) { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Email or password incorrect"}) - return - } - - // Issue token directly after password verification. - token, err := auth.GenerateJWT(user.ID, user.Email) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "token": token, - "user_id": user.ID, - "email": user.Email, - "message": "Login successful", - }) -} - -// handleChangePassword changes the password for the currently authenticated user. -func (s *Server) handleChangePassword(c *gin.Context) { - userID := c.GetString("user_id") - var req struct { - NewPassword string `json:"new_password" binding:"required,min=8"` - } - if err := c.ShouldBindJSON(&req); err != nil { - SafeBadRequest(c, "new_password is required (min 8 chars)") - return - } - hash, err := auth.HashPassword(req.NewPassword) - if err != nil { - SafeInternalError(c, "Password processing failed", err) - return - } - if err := s.store.User().UpdatePassword(userID, hash); err != nil { - SafeInternalError(c, "Failed to update password", err) - return - } - c.JSON(http.StatusOK, gin.H{"message": "Password updated"}) -} - -// handleResetPassword Reset password via email and new password -func (s *Server) handleResetPassword(c *gin.Context) { - var req struct { - Email string `json:"email" binding:"required,email"` - NewPassword string `json:"new_password" binding:"required,min=6"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - SafeBadRequest(c, "Invalid request parameters") - return - } - - // Query user - user, err := s.store.User().GetByEmail(req.Email) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Email does not exist"}) - return - } - - // Generate new password hash - newPasswordHash, err := auth.HashPassword(req.NewPassword) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Password processing failed"}) - return - } - - // Update password - err = s.store.User().UpdatePassword(user.ID, newPasswordHash) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Password update failed"}) - return - } - - logger.Infof("โœ“ User %s password has been reset", user.Email) - c.JSON(http.StatusOK, gin.H{"message": "Password reset successful, please login with new password"}) -} - -// initUserDefaultConfigs Initialize default model and exchange configs for new user -func (s *Server) initUserDefaultConfigs(userID string) error { - // Commented out auto-creation of default configs, let users add manually - // This way new users won't have config items automatically after registration - logger.Infof("User %s registration completed, waiting for manual AI model and exchange configuration", userID) - return nil -} - -// handleGetSupportedModels Get list of AI models supported by the system -func (s *Server) handleGetSupportedModels(c *gin.Context) { - // Return static list of supported AI models with default versions - supportedModels := []map[string]interface{}{ - {"id": "deepseek", "name": "DeepSeek", "provider": "deepseek", "defaultModel": "deepseek-chat"}, - {"id": "qwen", "name": "Qwen", "provider": "qwen", "defaultModel": "qwen3-max"}, - {"id": "openai", "name": "OpenAI", "provider": "openai", "defaultModel": "gpt-5.1"}, - {"id": "claude", "name": "Claude", "provider": "claude", "defaultModel": "claude-opus-4-6"}, - {"id": "gemini", "name": "Google Gemini", "provider": "gemini", "defaultModel": "gemini-3-pro-preview"}, - {"id": "grok", "name": "Grok (xAI)", "provider": "grok", "defaultModel": "grok-3-latest"}, - {"id": "kimi", "name": "Kimi (Moonshot)", "provider": "kimi", "defaultModel": "moonshot-v1-auto"}, - {"id": "minimax", "name": "MiniMax", "provider": "minimax", "defaultModel": "MiniMax-M2.5"}, - {"id": "blockrun-base", "name": "BlockRun (Base Wallet)", "provider": "blockrun-base", "defaultModel": "auto"}, - {"id": "blockrun-sol", "name": "BlockRun (Solana Wallet)", "provider": "blockrun-sol", "defaultModel": "auto"}, - {"id": "claw402", "name": "Claw402 (Base USDC)", "provider": "claw402", "defaultModel": "deepseek"}, - } - - c.JSON(http.StatusOK, supportedModels) -} - -// handleGetSupportedExchanges Get list of exchanges supported by the system -func (s *Server) handleGetSupportedExchanges(c *gin.Context) { - // Return static list of supported exchange types - // Note: ID is empty for supported exchanges (they are templates, not actual accounts) - supportedExchanges := []SafeExchangeConfig{ - {ExchangeType: "binance", Name: "Binance Futures", Type: "cex"}, - {ExchangeType: "bybit", Name: "Bybit Futures", Type: "cex"}, - {ExchangeType: "okx", Name: "OKX Futures", Type: "cex"}, - {ExchangeType: "gate", Name: "Gate.io Futures", Type: "cex"}, - {ExchangeType: "kucoin", Name: "KuCoin Futures", Type: "cex"}, - {ExchangeType: "hyperliquid", Name: "Hyperliquid", Type: "dex"}, - {ExchangeType: "aster", Name: "Aster DEX", Type: "dex"}, - {ExchangeType: "lighter", Name: "LIGHTER DEX", Type: "dex"}, - {ExchangeType: "alpaca", Name: "Alpaca (US Stocks)", Type: "stock"}, - {ExchangeType: "forex", Name: "Forex (TwelveData)", Type: "forex"}, - {ExchangeType: "metals", Name: "Metals (TwelveData)", Type: "metals"}, - } - - c.JSON(http.StatusOK, supportedExchanges) -} - // Start Start server func (s *Server) Start() error { addr := fmt.Sprintf(":%d", s.port) @@ -3500,393 +615,7 @@ func (s *Server) Shutdown() error { return s.httpServer.Shutdown(ctx) } -// handlePublicTraderList Get public trader list (no authentication required) -func (s *Server) handlePublicTraderList(c *gin.Context) { - // Get trader information from all users - competition, err := s.traderManager.GetCompetitionData() - if err != nil { - SafeInternalError(c, "Get trader list", err) - return - } - - // Get traders array - tradersData, exists := competition["traders"] - if !exists { - c.JSON(http.StatusOK, []map[string]interface{}{}) - return - } - - traders, ok := tradersData.([]map[string]interface{}) - if !ok { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Trader data format error", - }) - return - } - - // Return trader basic information, filter sensitive information - result := make([]map[string]interface{}, 0, len(traders)) - for _, trader := range traders { - result = append(result, map[string]interface{}{ - "trader_id": trader["trader_id"], - "trader_name": trader["trader_name"], - "ai_model": trader["ai_model"], - "exchange": trader["exchange"], - "is_running": trader["is_running"], - "total_equity": trader["total_equity"], - "total_pnl": trader["total_pnl"], - "total_pnl_pct": trader["total_pnl_pct"], - "position_count": trader["position_count"], - "margin_used_pct": trader["margin_used_pct"], - }) - } - - c.JSON(http.StatusOK, result) -} - -// handlePublicCompetition Get public competition data (no authentication required) -func (s *Server) handlePublicCompetition(c *gin.Context) { - competition, err := s.traderManager.GetCompetitionData() - if err != nil { - SafeInternalError(c, "Get competition data", err) - return - } - - c.JSON(http.StatusOK, competition) -} - -// handleTopTraders Get top 5 trader data (no authentication required, for performance comparison) -func (s *Server) handleTopTraders(c *gin.Context) { - topTraders, err := s.traderManager.GetTopTradersData() - if err != nil { - SafeInternalError(c, "Get top traders data", err) - return - } - - c.JSON(http.StatusOK, topTraders) -} - -// handleEquityHistoryBatch Batch get return rate historical data for multiple traders (no authentication required, for performance comparison) -// Supports optional 'hours' parameter to filter data by time range (e.g., hours=24 for last 24 hours) -func (s *Server) handleEquityHistoryBatch(c *gin.Context) { - var requestBody struct { - TraderIDs []string `json:"trader_ids"` - Hours int `json:"hours"` // Optional: filter by last N hours (0 = all data) - } - - // Try to parse POST request JSON body - if err := c.ShouldBindJSON(&requestBody); err != nil { - // If JSON parse fails, try to get from query parameters (compatible with GET request) - traderIDsParam := c.Query("trader_ids") - if traderIDsParam == "" { - // If no trader_ids specified, return historical data for top 5 - topTraders, err := s.traderManager.GetTopTradersData() - if err != nil { - SafeInternalError(c, "Get top traders", err) - return - } - - traders, ok := topTraders["traders"].([]map[string]interface{}) - if !ok { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Trader data format error"}) - return - } - - // Extract trader IDs - traderIDs := make([]string, 0, len(traders)) - for _, trader := range traders { - if traderID, ok := trader["trader_id"].(string); ok { - traderIDs = append(traderIDs, traderID) - } - } - - // Parse hours parameter from query - hoursParam := c.Query("hours") - hours := 0 - if hoursParam != "" { - fmt.Sscanf(hoursParam, "%d", &hours) - } - - result := s.getEquityHistoryForTraders(traderIDs, hours) - c.JSON(http.StatusOK, result) - return - } - - // Parse comma-separated trader IDs - requestBody.TraderIDs = strings.Split(traderIDsParam, ",") - for i := range requestBody.TraderIDs { - requestBody.TraderIDs[i] = strings.TrimSpace(requestBody.TraderIDs[i]) - } - - // Parse hours parameter from query - hoursParam := c.Query("hours") - if hoursParam != "" { - fmt.Sscanf(hoursParam, "%d", &requestBody.Hours) - } - } - - // Limit to maximum 20 traders to prevent oversized requests - if len(requestBody.TraderIDs) > 20 { - requestBody.TraderIDs = requestBody.TraderIDs[:20] - } - - result := s.getEquityHistoryForTraders(requestBody.TraderIDs, requestBody.Hours) - c.JSON(http.StatusOK, result) -} - -// getEquityHistoryForTraders Get historical data for multiple traders -// Query directly from database, not dependent on trader in memory (so historical data can be retrieved after restart) -// Also appends current real-time data point to ensure chart matches leaderboard -// hours: filter by last N hours (0 = use default limit of 500 records) -func (s *Server) getEquityHistoryForTraders(traderIDs []string, hours int) map[string]interface{} { - result := make(map[string]interface{}) - histories := make(map[string]interface{}) - errors := make(map[string]string) - - // Use a single consistent timestamp for all real-time data points - now := time.Now() - - // Pre-fetch initial balances for all traders - initialBalances := make(map[string]float64) - for _, traderID := range traderIDs { - if traderID == "" { - continue - } - // Get trader's initial balance from database (use GetByID which doesn't require userID) - trader, err := s.store.Trader().GetByID(traderID) - if err == nil && trader != nil && trader.InitialBalance > 0 { - initialBalances[traderID] = trader.InitialBalance - } - } - - for _, traderID := range traderIDs { - if traderID == "" { - continue - } - - // Get equity historical data from new equity table - var snapshots []*store.EquitySnapshot - var err error - - if hours > 0 { - // Filter by time range - startTime := now.Add(-time.Duration(hours) * time.Hour) - snapshots, err = s.store.Equity().GetByTimeRange(traderID, startTime, now) - } else { - // Default: get latest 500 records - snapshots, err = s.store.Equity().GetLatest(traderID, 500) - } - if err != nil { - logger.Errorf("[API] Failed to get equity history for %s: %v", traderID, err) - errors[traderID] = "Failed to get historical data" - continue - } - - // Get initial balance for calculating PnL percentage - initialBalance := initialBalances[traderID] - if initialBalance <= 0 && len(snapshots) > 0 { - // If no initial balance configured, use the first snapshot's equity as baseline - initialBalance = snapshots[0].TotalEquity - } - - // Build return rate historical data with PnL percentage - history := make([]map[string]interface{}, 0, len(snapshots)+1) - var lastSnapshotTime time.Time - for _, snap := range snapshots { - // Calculate PnL percentage: (current_equity - initial_balance) / initial_balance * 100 - pnlPct := 0.0 - if initialBalance > 0 { - pnlPct = (snap.TotalEquity - initialBalance) / initialBalance * 100 - } - - history = append(history, map[string]interface{}{ - "timestamp": snap.Timestamp, - "total_equity": snap.TotalEquity, - "total_pnl": snap.UnrealizedPnL, - "total_pnl_pct": pnlPct, - "balance": snap.Balance, - }) - if snap.Timestamp.After(lastSnapshotTime) { - lastSnapshotTime = snap.Timestamp - } - } - - // Append current real-time data point to ensure chart matches leaderboard - // This ensures the latest point is always current, not from a potentially stale snapshot - if trader, err := s.traderManager.GetTrader(traderID); err == nil { - if accountInfo, err := trader.GetAccountInfo(); err == nil { - // Only append if it's been more than 30 seconds since last snapshot - if now.Sub(lastSnapshotTime) > 30*time.Second { - totalEquity := 0.0 - if v, ok := accountInfo["total_equity"].(float64); ok { - totalEquity = v - } - totalPnL := 0.0 - if v, ok := accountInfo["total_pnl"].(float64); ok { - totalPnL = v - } - walletBalance := 0.0 - if v, ok := accountInfo["wallet_balance"].(float64); ok { - walletBalance = v - } - pnlPct := 0.0 - if initialBalance > 0 { - pnlPct = (totalEquity - initialBalance) / initialBalance * 100 - } - - history = append(history, map[string]interface{}{ - "timestamp": now, - "total_equity": totalEquity, - "total_pnl": totalPnL, - "total_pnl_pct": pnlPct, - "balance": walletBalance, - }) - } - } - } - - histories[traderID] = history - } - - result["histories"] = histories - result["count"] = len(histories) - if len(errors) > 0 { - result["errors"] = errors - } - - return result -} - -// handleGetPublicTraderConfig Get public trader configuration information (no authentication required, does not include sensitive information) -func (s *Server) handleGetPublicTraderConfig(c *gin.Context) { - traderID := c.Param("id") - if traderID == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "Trader ID cannot be empty"}) - return - } - - trader, err := s.traderManager.GetTrader(traderID) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"}) - return - } - - // Get trader status information - status := trader.GetStatus() - - // Only return public configuration information, not including sensitive data like API keys - result := map[string]interface{}{ - "trader_id": trader.GetID(), - "trader_name": trader.GetName(), - "ai_model": trader.GetAIModel(), - "exchange": trader.GetExchange(), - "is_running": status["is_running"], - "ai_provider": status["ai_provider"], - "start_time": status["start_time"], - } - - c.JSON(http.StatusOK, result) -} - // SetTelegramReloadCh sets the channel used to signal the Telegram bot to reload func (s *Server) SetTelegramReloadCh(ch chan<- struct{}) { s.telegramReloadCh = ch } - -// handleGetTelegramConfig returns current Telegram bot configuration and binding status -func (s *Server) handleGetTelegramConfig(c *gin.Context) { - cfg, err := s.store.TelegramConfig().Get() - if err != nil { - // Not configured yet - return empty state - c.JSON(http.StatusOK, gin.H{ - "configured": false, - "is_bound": false, - "token_masked": "", - "username": "", - }) - return - } - - // Mask bot token for security (show only last 6 chars) - tokenMasked := "" - if cfg.BotToken != "" { - if len(cfg.BotToken) > 6 { - tokenMasked = "***" + cfg.BotToken[len(cfg.BotToken)-6:] - } else { - tokenMasked = "***" - } - } - - c.JSON(http.StatusOK, gin.H{ - "configured": cfg.BotToken != "", - "is_bound": cfg.ChatID != 0, - "username": cfg.Username, - "bound_at": cfg.BoundAt, - "token_masked": tokenMasked, - "model_id": cfg.ModelID, - }) -} - -// handleUpdateTelegramConfig saves bot token (+ optional model ID) and triggers bot hot-reload -func (s *Server) handleUpdateTelegramConfig(c *gin.Context) { - var req struct { - BotToken string `json:"bot_token"` - ModelID string `json:"model_id"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) - return - } - if req.BotToken == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "bot_token is required"}) - return - } - - if err := s.store.TelegramConfig().Save(req.BotToken, req.ModelID); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save config"}) - return - } - - // Signal bot hot-reload if channel is available - if s.telegramReloadCh != nil { - select { - case s.telegramReloadCh <- struct{}{}: - default: // non-blocking - } - } - - c.JSON(http.StatusOK, gin.H{"success": true, "message": "Bot token saved. Bot will reload automatically."}) -} - -// handleUnbindTelegram removes Telegram user binding -func (s *Server) handleUnbindTelegram(c *gin.Context) { - if err := s.store.TelegramConfig().Unbind(); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to unbind"}) - return - } - c.JSON(http.StatusOK, gin.H{"success": true, "message": "Telegram binding removed"}) -} - -// handleUpdateTelegramModel updates only the AI model used for Telegram replies (no token re-entry needed) -func (s *Server) handleUpdateTelegramModel(c *gin.Context) { - var req struct { - ModelID string `json:"model_id"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) - return - } - - cfg, err := s.store.TelegramConfig().Get() - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "no Telegram config found, save a bot token first"}) - return - } - - if err := s.store.TelegramConfig().Save(cfg.BotToken, req.ModelID); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save model config"}) - return - } - - c.JSON(http.StatusOK, gin.H{"success": true, "model_id": req.ModelID}) -} diff --git a/api/strategy.go b/api/strategy.go index c6599d1d..2fc9c9fb 100644 --- a/api/strategy.go +++ b/api/strategy.go @@ -8,6 +8,8 @@ import ( "nofx/logger" "nofx/market" "nofx/mcp" + _ "nofx/mcp/payment" + _ "nofx/mcp/provider" "nofx/store" "time" @@ -637,49 +639,20 @@ func (s *Server) runRealAITest(userID, modelID, systemPrompt, userPrompt string) return "", fmt.Errorf("AI model %s is missing API Key", model.Name) } - // Create AI client - var aiClient mcp.AIClient + // Create AI client via registry provider := model.Provider - - // Convert EncryptedString to string for API key apiKey := string(model.APIKey) + + aiClient := mcp.NewAIClientByProvider(provider) + if aiClient == nil { + aiClient = mcp.NewClient() + } + + // Payment providers ignore custom URL switch provider { - case "qwen": - aiClient = mcp.NewQwenClient() - aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) - case "deepseek": - aiClient = mcp.NewDeepSeekClient() - aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) - case "claude": - aiClient = mcp.NewClaudeClient() - aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) - case "kimi": - aiClient = mcp.NewKimiClient() - aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) - case "gemini": - aiClient = mcp.NewGeminiClient() - aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) - case "grok": - aiClient = mcp.NewGrokClient() - aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) - case "openai": - aiClient = mcp.NewOpenAIClient() - aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) - case "minimax": - aiClient = mcp.NewMiniMaxClient() - aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) - case "blockrun-base": - aiClient = mcp.NewBlockRunBaseClient() - aiClient.SetAPIKey(apiKey, "", model.CustomModelName) - case "blockrun-sol": - aiClient = mcp.NewBlockRunSolClient() - aiClient.SetAPIKey(apiKey, "", model.CustomModelName) - case "claw402": - aiClient = mcp.NewClaw402Client() + case "blockrun-base", "blockrun-sol", "claw402": aiClient.SetAPIKey(apiKey, "", model.CustomModelName) default: - // Use generic client - aiClient = mcp.NewClient() aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName) } diff --git a/backtest/ai_client.go b/backtest/ai_client.go index 203e33d1..5395b430 100644 --- a/backtest/ai_client.go +++ b/backtest/ai_client.go @@ -5,15 +5,17 @@ import ( "strings" "nofx/mcp" + _ "nofx/mcp/payment" + _ "nofx/mcp/provider" ) // configureMCPClient creates/clones an MCP client based on configuration (returns mcp.AIClient interface). // Note: mcp.New() returns an interface type; here we convert to concrete implementation before copying to avoid concurrent shared state. func configureMCPClient(cfg BacktestConfig, base mcp.AIClient) (mcp.AIClient, error) { - provider := strings.ToLower(strings.TrimSpace(cfg.AICfg.Provider)) + providerName := strings.ToLower(strings.TrimSpace(cfg.AICfg.Provider)) - // DeepSeek - if provider == "" || provider == "inherit" || provider == "default" { + // Inherit base client + if providerName == "" || providerName == "inherit" || providerName == "default" { client := cloneBaseClient(base) if cfg.AICfg.APIKey != "" || cfg.AICfg.BaseURL != "" || cfg.AICfg.Model != "" { client.SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model) @@ -21,143 +23,49 @@ func configureMCPClient(cfg BacktestConfig, base mcp.AIClient) (mcp.AIClient, er return client, nil } - switch provider { - case "deepseek": - if cfg.AICfg.APIKey == "" { - return nil, fmt.Errorf("deepseek provider requires api key") - } - ds := mcp.NewDeepSeekClientWithOptions() - ds.(*mcp.DeepSeekClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model) - return ds, nil - case "qwen": - if cfg.AICfg.APIKey == "" { - return nil, fmt.Errorf("qwen provider requires api key") - } - qc := mcp.NewQwenClientWithOptions() - qc.(*mcp.QwenClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model) - return qc, nil - case "claude": - if cfg.AICfg.APIKey == "" { - return nil, fmt.Errorf("claude provider requires api key") - } - cc := mcp.NewClaudeClientWithOptions() - cc.(*mcp.ClaudeClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model) - return cc, nil - case "kimi": - if cfg.AICfg.APIKey == "" { - return nil, fmt.Errorf("kimi provider requires api key") - } - kc := mcp.NewKimiClientWithOptions() - kc.(*mcp.KimiClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model) - return kc, nil - case "gemini": - if cfg.AICfg.APIKey == "" { - return nil, fmt.Errorf("gemini provider requires api key") - } - gc := mcp.NewGeminiClientWithOptions() - gc.(*mcp.GeminiClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model) - return gc, nil - case "grok": - if cfg.AICfg.APIKey == "" { - return nil, fmt.Errorf("grok provider requires api key") - } - grokC := mcp.NewGrokClientWithOptions() - grokC.(*mcp.GrokClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model) - return grokC, nil - case "openai": - if cfg.AICfg.APIKey == "" { - return nil, fmt.Errorf("openai provider requires api key") - } - oaiC := mcp.NewOpenAIClientWithOptions() - oaiC.(*mcp.OpenAIClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model) - return oaiC, nil - case "minimax": - if cfg.AICfg.APIKey == "" { - return nil, fmt.Errorf("minimax provider requires api key") - } - mmC := mcp.NewMiniMaxClientWithOptions() - mmC.(*mcp.MiniMaxClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model) - return mmC, nil - case "blockrun-base": - if cfg.AICfg.APIKey == "" { - return nil, fmt.Errorf("blockrun-base provider requires wallet private key") - } - brBase := mcp.NewBlockRunBaseClient() - brBase.SetAPIKey(cfg.AICfg.APIKey, "", cfg.AICfg.Model) - return brBase, nil - case "blockrun-sol": - if cfg.AICfg.APIKey == "" { - return nil, fmt.Errorf("blockrun-sol provider requires wallet keypair") - } - brSol := mcp.NewBlockRunSolClient() - brSol.SetAPIKey(cfg.AICfg.APIKey, "", cfg.AICfg.Model) - return brSol, nil - case "claw402": - if cfg.AICfg.APIKey == "" { - return nil, fmt.Errorf("claw402 provider requires wallet private key") - } - claw := mcp.NewClaw402Client() - claw.SetAPIKey(cfg.AICfg.APIKey, "", cfg.AICfg.Model) - return claw, nil - case "custom": + // Custom provider uses cloned base + if providerName == "custom" { if cfg.AICfg.BaseURL == "" || cfg.AICfg.APIKey == "" || cfg.AICfg.Model == "" { return nil, fmt.Errorf("custom provider requires base_url, api key and model") } client := cloneBaseClient(base) client.SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model) return client, nil - default: + } + + // Create client via registry + client := mcp.NewAIClientByProvider(providerName) + if client == nil { return nil, fmt.Errorf("unsupported ai provider %s", cfg.AICfg.Provider) } + + if cfg.AICfg.APIKey == "" { + return nil, fmt.Errorf("%s provider requires api key", providerName) + } + + // Payment providers ignore custom URL + switch providerName { + case "blockrun-base", "blockrun-sol", "claw402": + client.SetAPIKey(cfg.AICfg.APIKey, "", cfg.AICfg.Model) + default: + client.SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model) + } + return client, nil } // cloneBaseClient copies the base client to avoid shared mutable state. +// Uses the ClientEmbedder interface to extract the underlying *mcp.Client +// from any provider type that embeds it. func cloneBaseClient(base mcp.AIClient) *mcp.Client { - // Prefer to reuse the passed-in base client (deep copy) - switch c := base.(type) { - case *mcp.Client: + if embedder, ok := base.(mcp.ClientEmbedder); ok { + if inner := embedder.BaseClient(); inner != nil { + cp := *inner + return &cp + } + } + if c, ok := base.(*mcp.Client); ok { cp := *c return &cp - case *mcp.DeepSeekClient: - if c != nil && c.Client != nil { - cp := *c.Client - return &cp - } - case *mcp.QwenClient: - if c != nil && c.Client != nil { - cp := *c.Client - return &cp - } - case *mcp.ClaudeClient: - if c != nil && c.Client != nil { - cp := *c.Client - return &cp - } - case *mcp.KimiClient: - if c != nil && c.Client != nil { - cp := *c.Client - return &cp - } - case *mcp.GeminiClient: - if c != nil && c.Client != nil { - cp := *c.Client - return &cp - } - case *mcp.GrokClient: - if c != nil && c.Client != nil { - cp := *c.Client - return &cp - } - case *mcp.OpenAIClient: - if c != nil && c.Client != nil { - cp := *c.Client - return &cp - } - case *mcp.MiniMaxClient: - if c != nil && c.Client != nil { - cp := *c.Client - return &cp - } } // Fall back to a new default client return mcp.NewClient().(*mcp.Client) diff --git a/llm/qwen_agent.go b/llm/qwen_agent.go deleted file mode 100644 index 8f5db189..00000000 --- a/llm/qwen_agent.go +++ /dev/null @@ -1,351 +0,0 @@ -package llm - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "time" -) - -// ้˜ฟ้‡Œไบ‘ API ้…็ฝฎ -const ( - DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/api/v1/apps" - // ๆ ‡ๅ‡† OpenAI ๅ…ผๅฎนๆจกๅผ API - QwenCompatibleURL = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions" -) - -// QwenAgent ้˜ฟ้‡Œไบ‘็™พ็‚ผๆ™บ่ƒฝไฝ“ๅฎขๆˆท็ซฏ -type QwenAgent struct { - AppID string - APIKey string - BaseURL string - SessionID string - Client *http.Client -} - -// QwenRequest ่ฏทๆฑ‚็ป“ๆž„ -type QwenRequest struct { - Input QwenInput `json:"input"` - Parameters QwenParameters `json:"parameters,omitempty"` -} - -// QwenInput ่พ“ๅ…ฅ็ป“ๆž„ -type QwenInput struct { - Prompt string `json:"prompt"` - BizParams map[string]interface{} `json:"biz_params,omitempty"` -} - -// QwenParameters ๅ‚ๆ•ฐ็ป“ๆž„ -type QwenParameters struct { - SessionID string `json:"session_id,omitempty"` - IncrementalOutput bool `json:"incremental_output,omitempty"` -} - -// QwenResponse ๅ“ๅบ”็ป“ๆž„ -type QwenResponse struct { - Output QwenOutput `json:"output"` - Usage QwenUsage `json:"usage,omitempty"` - RequestID string `json:"request_id"` - Code string `json:"code,omitempty"` - Message string `json:"message,omitempty"` -} - -// QwenOutput ่พ“ๅ‡บ็ป“ๆž„ -type QwenOutput struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason,omitempty"` - SessionID string `json:"session_id,omitempty"` -} - -// QwenUsage ็”จ้‡็ปŸ่ฎก -type QwenUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` -} - -// NewQwenAgent ๅˆ›ๅปบๆ–ฐ็š„ๆ™บ่ƒฝไฝ“ๅฎขๆˆท็ซฏ -func NewQwenAgent(appID, apiKey string) *QwenAgent { - return &QwenAgent{ - AppID: appID, - APIKey: apiKey, - BaseURL: DefaultQwenBaseURL, - Client: &http.Client{ - Timeout: 180 * time.Second, - }, - } -} - -// Chat ๅŒๆญฅๅฏน่ฏ -func (a *QwenAgent) Chat(ctx context.Context, prompt string) (*QwenResponse, error) { - reqBody := QwenRequest{ - Input: QwenInput{ - Prompt: prompt, - }, - Parameters: QwenParameters{ - SessionID: a.SessionID, - }, - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("marshal request failed: %w", err) - } - - url := fmt.Sprintf("%s/%s/completion", a.BaseURL, a.AppID) - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) - if err != nil { - return nil, fmt.Errorf("create request failed: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+a.APIKey) - - resp, err := a.Client.Do(req) - if err != nil { - return nil, fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read response failed: %w", err) - } - - var result QwenResponse - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("unmarshal response failed: %w, body: %s", err, string(body)) - } - - // ๆ›ดๆ–ฐ session_id ็”จไบŽๅคš่ฝฎๅฏน่ฏ - if result.Output.SessionID != "" { - a.SessionID = result.Output.SessionID - } - - // ๆฃ€ๆŸฅ API ้”™่ฏฏ - if result.Code != "" { - return &result, fmt.Errorf("API error: code=%s, message=%s", result.Code, result.Message) - } - - return &result, nil -} - -// ChatStream ๆตๅผๅฏน่ฏ -func (a *QwenAgent) ChatStream(ctx context.Context, prompt string, callback func(chunk string)) error { - reqBody := QwenRequest{ - Input: QwenInput{ - Prompt: prompt, - }, - Parameters: QwenParameters{ - SessionID: a.SessionID, - IncrementalOutput: true, - }, - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return fmt.Errorf("marshal request failed: %w", err) - } - - url := fmt.Sprintf("%s/%s/completion", a.BaseURL, a.AppID) - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) - if err != nil { - return fmt.Errorf("create request failed: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+a.APIKey) - req.Header.Set("X-DashScope-SSE", "enable") - - resp, err := a.Client.Do(req) - if err != nil { - return fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() - - reader := bufio.NewReader(resp.Body) - for { - line, err := reader.ReadString('\n') - if err != nil { - if err == io.EOF { - break - } - return fmt.Errorf("read stream failed: %w", err) - } - - line = strings.TrimSpace(line) - if !strings.HasPrefix(line, "data:") { - continue - } - - data := strings.TrimPrefix(line, "data:") - var chunk QwenResponse - if err := json.Unmarshal([]byte(data), &chunk); err != nil { - continue - } - - // ๆ›ดๆ–ฐ session_id - if chunk.Output.SessionID != "" { - a.SessionID = chunk.Output.SessionID - } - - // ๅ›ž่ฐƒ่พ“ๅ‡บๆ–‡ๆœฌ - if chunk.Output.Text != "" { - callback(chunk.Output.Text) - } - } - - return nil -} - -// ChatWithBizParams ๅธฆไธšๅŠกๅ‚ๆ•ฐ็š„ๅฏน่ฏ -func (a *QwenAgent) ChatWithBizParams(ctx context.Context, prompt string, bizParams map[string]interface{}) (*QwenResponse, error) { - reqBody := QwenRequest{ - Input: QwenInput{ - Prompt: prompt, - BizParams: bizParams, - }, - Parameters: QwenParameters{ - SessionID: a.SessionID, - }, - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("marshal request failed: %w", err) - } - - url := fmt.Sprintf("%s/%s/completion", a.BaseURL, a.AppID) - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) - if err != nil { - return nil, fmt.Errorf("create request failed: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+a.APIKey) - - resp, err := a.Client.Do(req) - if err != nil { - return nil, fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read response failed: %w", err) - } - - var result QwenResponse - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("unmarshal response failed: %w, body: %s", err, string(body)) - } - - if result.Output.SessionID != "" { - a.SessionID = result.Output.SessionID - } - - if result.Code != "" { - return &result, fmt.Errorf("API error: code=%s, message=%s", result.Code, result.Message) - } - - return &result, nil -} - -// ResetSession ้‡็ฝฎไผš่ฏ -func (a *QwenAgent) ResetSession() { - a.SessionID = "" -} - -// ========== ๆ ‡ๅ‡† OpenAI ๅ…ผๅฎน API ========== - -// ChatCompletionRequest OpenAI ๅ…ผๅฎนๆ ผๅผ่ฏทๆฑ‚ -type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []ChatCompletionMessage `json:"messages"` -} - -// ChatCompletionMessage ๆถˆๆฏ็ป“ๆž„ -type ChatCompletionMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -// ChatCompletionResponse OpenAI ๅ…ผๅฎนๆ ผๅผๅ“ๅบ” -type ChatCompletionResponse struct { - ID string `json:"id"` - Model string `json:"model"` - Choices []struct { - Message struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` - Error *struct { - Code string `json:"code"` - Message string `json:"message"` - } `json:"error,omitempty"` -} - -// ChatWithModel ไฝฟ็”จๆ ‡ๅ‡† OpenAI ๅ…ผๅฎน API ่ฐƒ็”จๆŒ‡ๅฎšๆจกๅž‹ -func (a *QwenAgent) ChatWithModel(ctx context.Context, model, prompt string) (*ChatCompletionResponse, error) { - reqBody := ChatCompletionRequest{ - Model: model, - Messages: []ChatCompletionMessage{ - {Role: "user", Content: prompt}, - }, - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("marshal request failed: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", QwenCompatibleURL, bytes.NewBuffer(jsonData)) - if err != nil { - return nil, fmt.Errorf("create request failed: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+a.APIKey) - - resp, err := a.Client.Do(req) - if err != nil { - return nil, fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read response failed: %w", err) - } - - var result ChatCompletionResponse - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("unmarshal response failed: %w, body: %s", err, string(body)) - } - - if result.Error != nil { - return &result, fmt.Errorf("API error: code=%s, message=%s", result.Error.Code, result.Error.Message) - } - - return &result, nil -} - -// GetContent ไปŽๅ“ๅบ”ไธญ่Žทๅ–ๅ†…ๅฎน -func (r *ChatCompletionResponse) GetContent() string { - if len(r.Choices) > 0 { - return r.Choices[0].Message.Content - } - return "" -} diff --git a/llm/qwen_agent_test.go b/llm/qwen_agent_test.go deleted file mode 100644 index 27d1b275..00000000 --- a/llm/qwen_agent_test.go +++ /dev/null @@ -1,425 +0,0 @@ -package llm - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "strings" - "testing" - "time" -) - -// ้˜ฟ้‡Œไบ‘็™พ็‚ผๅนณๅฐ้…็ฝฎ (ไปŽ็Žฏๅขƒๅ˜้‡่Žทๅ–) -var ( - QwenAppID = os.Getenv("QWEN_APP_ID") - QwenAPIKey = os.Getenv("QWEN_API_KEY") -) - -// ============== ๆต‹่ฏ•็”จไพ‹ ============== - -// TestQwenBasicChat ๆต‹่ฏ•ๅŸบๆœฌๅŒๆญฅๅฏน่ฏ -func TestQwenBasicChat(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - prompt := "ไฝ ๅฅฝ๏ผŒ่ฏท็”จไธ€ๅฅ่ฏไป‹็ปไฝ ่‡ชๅทฑ" - t.Logf("็”จๆˆท: %s", prompt) - - start := time.Now() - resp, err := agent.Chat(ctx, prompt) - elapsed := time.Since(start) - - if err != nil { - t.Fatalf("Chat failed: %v", err) - } - - if resp.Output.Text == "" { - t.Fatal("Empty response text") - } - - t.Logf("ๅŠฉๆ‰‹: %s", resp.Output.Text) - t.Logf("่€—ๆ—ถ: %v, Token: %d", elapsed, resp.Usage.TotalTokens) -} - -// TestQwenStreamChat ๆต‹่ฏ•ๆตๅผ่พ“ๅ‡บ -func TestQwenStreamChat(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - prompt := "่ฏท็”จ3ๅฅ่ฏ่งฃ้‡Šไป€ไนˆๆ˜ฏ้‡ๅŒ–ไบคๆ˜“" - t.Logf("็”จๆˆท: %s", prompt) - - var fullText strings.Builder - start := time.Now() - - err := agent.ChatStream(ctx, prompt, func(chunk string) { - fullText.WriteString(chunk) - }) - - elapsed := time.Since(start) - - if err != nil { - t.Fatalf("ChatStream failed: %v", err) - } - - if fullText.Len() == 0 { - t.Fatal("Empty stream response") - } - - t.Logf("ๅŠฉๆ‰‹: %s", fullText.String()) - t.Logf("่€—ๆ—ถ: %v, ๅญ—็ฌฆๆ•ฐ: %d", elapsed, fullText.Len()) -} - -// TestQwenMultiTurn ๆต‹่ฏ•ๅคš่ฝฎๅฏน่ฏ๏ผˆไธŠไธ‹ๆ–‡่ฎฐๅฟ†๏ผ‰ -func TestQwenMultiTurn(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - // ็ฌฌไธ€่ฝฎ๏ผš่ฎพ็ฝฎไธŠไธ‹ๆ–‡ - resp1, err := agent.Chat(ctx, "ๆˆ‘ๅซๅฐๆ˜Ž๏ผŒๆˆ‘ๆ˜ฏไธ€ๅ Go ็จ‹ๅบๅ‘˜๏ผŒ่ฏท่ฎฐไฝ่ฟ™ไบ›ไฟกๆฏ") - if err != nil { - t.Fatalf("Round 1 failed: %v", err) - } - t.Logf("[Round 1] ็”จๆˆท: ๆˆ‘ๅซๅฐๆ˜Ž๏ผŒๆˆ‘ๆ˜ฏไธ€ๅ Go ็จ‹ๅบๅ‘˜") - t.Logf("[Round 1] ๅŠฉๆ‰‹: %s", resp1.Output.Text) - t.Logf("[Round 1] SessionID: %s", agent.SessionID) - - // ็ฌฌไบŒ่ฝฎ๏ผš้ชŒ่ฏ่ฎฐๅฟ† - resp2, err := agent.Chat(ctx, "่ฏท้—ฎๆˆ‘ๅซไป€ไนˆๅๅญ—๏ผŸๆˆ‘ๆ˜ฏๅšไป€ไนˆ็š„๏ผŸ") - if err != nil { - t.Fatalf("Round 2 failed: %v", err) - } - t.Logf("[Round 2] ็”จๆˆท: ่ฏท้—ฎๆˆ‘ๅซไป€ไนˆๅๅญ—๏ผŸๆˆ‘ๆ˜ฏๅšไป€ไนˆ็š„๏ผŸ") - t.Logf("[Round 2] ๅŠฉๆ‰‹: %s", resp2.Output.Text) - - // ๆฃ€ๆŸฅๆ˜ฏๅฆ่ฎฐไฝไบ†ไฟกๆฏ - text := strings.ToLower(resp2.Output.Text) - if !strings.Contains(text, "ๅฐๆ˜Ž") && !strings.Contains(text, "go") { - t.Logf("่ญฆๅ‘Š: ๆจกๅž‹ๅฏ่ƒฝๆฒกๆœ‰ๆญฃ็กฎ่ฎฐไฝไธŠไธ‹ๆ–‡") - } -} - -// TestQwenResetSession ๆต‹่ฏ•้‡็ฝฎไผš่ฏ -func TestQwenResetSession(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - // ๅปบ็ซ‹ไธŠไธ‹ๆ–‡ - resp1, err := agent.Chat(ctx, "่ฎฐไฝ่ฟ™ไธชๅฏ†็ : ABC123XYZ") - if err != nil { - t.Fatalf("Setup context failed: %v", err) - } - t.Logf("่ฎพ็ฝฎไธŠไธ‹ๆ–‡: %s", resp1.Output.Text) - - oldSession := agent.SessionID - t.Logf("ๅŽŸ SessionID: %s", oldSession) - - // ้‡็ฝฎไผš่ฏ - agent.ResetSession() - t.Log("ไผš่ฏๅทฒ้‡็ฝฎ") - - // ๆ–ฐๅฏน่ฏ - ๅบ”่ฏฅไธ่ฎฐๅพ—ไน‹ๅ‰็š„ๅ†…ๅฎน - resp2, err := agent.Chat(ctx, "ๆˆ‘ไน‹ๅ‰ๅ‘Š่ฏ‰ไฝ ็š„ๅฏ†็ ๆ˜ฏไป€ไนˆ๏ผŸ") - if err != nil { - t.Fatalf("New session chat failed: %v", err) - } - t.Logf("ๆ–ฐๅฏน่ฏๅ›žๅค: %s", resp2.Output.Text) - t.Logf("ๆ–ฐ SessionID: %s", agent.SessionID) - - if oldSession == agent.SessionID { - t.Error("Session was not reset properly") - } -} - -// TestQwenCodeGeneration ๆต‹่ฏ•ไปฃ็ ็”Ÿๆˆ่ƒฝๅŠ› -func TestQwenCodeGeneration(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - prompt := "่ฏท็”จ Go ่ฏญ่จ€ๅ†™ไธ€ไธช่ฎก็ฎ—็งปๅŠจๅนณๅ‡็บฟ(MA)็š„ๅ‡ฝๆ•ฐ๏ผŒ่พ“ๅ…ฅๆ˜ฏ []float64 ไปทๆ ผๅˆ‡็‰‡ๅ’Œ int ๅ‘จๆœŸ" - t.Logf("็”จๆˆท: %s", prompt) - - resp, err := agent.Chat(ctx, prompt) - if err != nil { - t.Fatalf("Code generation failed: %v", err) - } - - t.Logf("ๅŠฉๆ‰‹:\n%s", resp.Output.Text) - - // ๆฃ€ๆŸฅๆ˜ฏๅฆๅŒ…ๅซไปฃ็ ็‰นๅพ - text := resp.Output.Text - if !strings.Contains(text, "func") || !strings.Contains(text, "float64") { - t.Log("่ญฆๅ‘Š: ๅ“ๅบ”ๅฏ่ƒฝไธๅŒ…ๅซๆœ‰ๆ•ˆ็š„ Go ไปฃ็ ") - } -} - -// TestQwenJSONOutput ๆต‹่ฏ• JSON ๆ ผๅผ่พ“ๅ‡บ -func TestQwenJSONOutput(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - prompt := `่ฏทๅˆ†ๆž BTC ็š„ๅŸบๆœฌไฟกๆฏ๏ผŒไปฅ็บฏ JSON ๆ ผๅผ่ฟ”ๅ›ž๏ผˆไธ่ฆ markdown ไปฃ็ ๅ—๏ผ‰๏ผŒๅŒ…ๅซไปฅไธ‹ๅญ—ๆฎต: -{"name": "่ต„ไบงๅ็งฐ", "type": "่ต„ไบง็ฑปๅž‹", "risk": 1-10็š„้ฃŽ้™ฉ็ญ‰็บงๆ•ฐๅญ—} -ๅช่ฟ”ๅ›ž JSON ๅฏน่ฑก๏ผŒไธ่ฆไปปไฝ•ๅ…ถไป–ๆ–‡ๅญ—` - - t.Logf("็”จๆˆท: %s", prompt) - - resp, err := agent.Chat(ctx, prompt) - if err != nil { - t.Fatalf("JSON output test failed: %v", err) - } - - t.Logf("ๅŠฉๆ‰‹: %s", resp.Output.Text) - - // ๅฐ่ฏ•่งฃๆž JSON - text := resp.Output.Text - // ๆๅ– JSON ้ƒจๅˆ† - start := strings.Index(text, "{") - end := strings.LastIndex(text, "}") - if start != -1 && end != -1 && end > start { - jsonStr := text[start : end+1] - var result map[string]interface{} - if err := json.Unmarshal([]byte(jsonStr), &result); err != nil { - t.Logf("JSON ่งฃๆžๅคฑ่ดฅ: %v", err) - } else { - t.Logf("JSON ่งฃๆžๆˆๅŠŸ: %+v", result) - } - } -} - -// TestQwenLongResponse ๆต‹่ฏ•้•ฟๆ–‡ๆœฌ็”Ÿๆˆ -func TestQwenLongResponse(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - prompt := "่ฏท่ฏฆ็ป†ไป‹็ปๅŠ ๅฏ†่ดงๅธๆฐธ็ปญๅˆ็บฆไบคๆ˜“ไธญ็š„้ฃŽ้™ฉ็ฎก็†็ญ–็•ฅ๏ผŒๅŒ…ๆ‹ฌๆญขๆŸ่ฎพ็ฝฎใ€ไป“ไฝ็ฎก็†ใ€ๆ ๆ†้€‰ๆ‹ฉใ€่ต„้‡‘่ดน็އ่€ƒ่™‘็ญ‰ๆ–น้ข๏ผŒ่‡ณๅฐ‘500ๅญ—" - t.Logf("็”จๆˆท: %s", prompt) - - start := time.Now() - resp, err := agent.Chat(ctx, prompt) - elapsed := time.Since(start) - - if err != nil { - t.Fatalf("Long response test failed: %v", err) - } - - text := resp.Output.Text - t.Logf("ๅ“ๅบ”้•ฟๅบฆ: %d ๅญ—็ฌฆ", len(text)) - t.Logf("่€—ๆ—ถ: %v", elapsed) - t.Logf("Token ไฝฟ็”จ: input=%d, output=%d, total=%d", - resp.Usage.InputTokens, resp.Usage.OutputTokens, resp.Usage.TotalTokens) - - // ๅชๆ˜พ็คบๅ‰500ๅญ—็ฌฆ - if len(text) > 500 { - t.Logf("ๅŠฉๆ‰‹(ๅ‰500ๅญ—): %s...", text[:500]) - } else { - t.Logf("ๅŠฉๆ‰‹: %s", text) - } -} - -// TestQwenTradingScenario ๆต‹่ฏ•ไบคๆ˜“ๅœบๆ™ฏ้—ฎ็ญ” -func TestQwenTradingScenario(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - questions := []string{ - "BTC ๅฝ“ๅ‰ไปทๆ ผ 95000 ็พŽๅ…ƒ๏ผŒRSI ๅœจ 75 ้™„่ฟ‘๏ผŒMACD ้‡‘ๅ‰๏ผŒไฝ ๅปบ่ฎฎ็Žฐๅœจๅผ€ๅคš่ฟ˜ๆ˜ฏๅผ€็ฉบ๏ผŸ็ฎ€็Ÿญๅ›ž็ญ”", - "ๅฆ‚ๆžœๆˆ‘ๆœ‰ 10000 USDT๏ผŒๆƒณ็”จ 10 ๅ€ๆ ๆ†ๅšๅคš ETH๏ผŒๅปบ่ฎฎๅผ€ๅคšๅคงไป“ไฝ๏ผŸ", - "ไป€ไนˆๆ˜ฏ่ต„้‡‘่ดน็އ๏ผŸๆญฃ็š„่ต„้‡‘่ดน็އๅฏนๅคšๅคดๆœ‰ไป€ไนˆๅฝฑๅ“๏ผŸ", - } - - for i, q := range questions { - agent.ResetSession() // ๆฏไธช้—ฎ้ข˜็‹ฌ็ซ‹ - - t.Logf("\n[้—ฎ้ข˜%d] %s", i+1, q) - resp, err := agent.Chat(ctx, q) - if err != nil { - t.Errorf("Question %d failed: %v", i+1, err) - continue - } - - // ๆˆชๅ–ๆ˜พ็คบ - text := resp.Output.Text - if len(text) > 300 { - text = text[:300] + "..." - } - t.Logf("[ๅ›ž็ญ”%d] %s", i+1, text) - } -} - -// TestQwenErrorHandling ๆต‹่ฏ•้”™่ฏฏๅค„็† -func TestQwenErrorHandling(t *testing.T) { - ctx := context.Background() - - // ๆต‹่ฏ•ๆ— ๆ•ˆ API Key - t.Run("InvalidAPIKey", func(t *testing.T) { - agent := NewQwenAgent(QwenAppID, "invalid-api-key") - _, err := agent.Chat(ctx, "ๆต‹่ฏ•") - if err == nil { - t.Log("่ญฆๅ‘Š: ๆ— ๆ•ˆ API Key ๆฒกๆœ‰่ฟ”ๅ›ž้”™่ฏฏ") - } else { - t.Logf("้ข„ๆœŸ้”™่ฏฏ: %v", err) - } - }) - - // ๆต‹่ฏ•ๆ— ๆ•ˆ App ID - t.Run("InvalidAppID", func(t *testing.T) { - agent := NewQwenAgent("invalid-app-id", QwenAPIKey) - _, err := agent.Chat(ctx, "ๆต‹่ฏ•") - if err == nil { - t.Log("่ญฆๅ‘Š: ๆ— ๆ•ˆ App ID ๆฒกๆœ‰่ฟ”ๅ›ž้”™่ฏฏ") - } else { - t.Logf("้ข„ๆœŸ้”™่ฏฏ: %v", err) - } - }) -} - -// TestQwenSpecialCharacters ๆต‹่ฏ•็‰นๆฎŠๅญ—็ฌฆๅค„็† -func TestQwenSpecialCharacters(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - testCases := []string{ - "่ฏท่งฃ้‡Š่ฟ™ไธช่กจๆƒ…: ๐Ÿ˜€๐ŸŽ‰๐Ÿš€", - "ไธญ่‹ฑๆ–‡ๆททๅˆ: Helloไธ–็•Œ๏ผ", - "็‰นๆฎŠ็ฌฆๅท: <>&\"'", - } - - for _, prompt := range testCases { - agent.ResetSession() - t.Logf("็”จๆˆท: %s", prompt) - - resp, err := agent.Chat(ctx, prompt) - if err != nil { - t.Errorf("็‰นๆฎŠๅญ—็ฌฆๆต‹่ฏ•ๅคฑ่ดฅ: %v", err) - continue - } - - if len(resp.Output.Text) > 100 { - t.Logf("ๅŠฉๆ‰‹: %s...", resp.Output.Text[:100]) - } else { - t.Logf("ๅŠฉๆ‰‹: %s", resp.Output.Text) - } - } -} - -// TestQwenConcurrentSessions ๆต‹่ฏ•ๅนถๅ‘ไผš่ฏ -func TestQwenConcurrentSessions(t *testing.T) { - agent1 := NewQwenAgent(QwenAppID, QwenAPIKey) - agent2 := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - // Agent1 ๅฏน่ฏ - resp1, err := agent1.Chat(ctx, "ๆˆ‘ๆ˜ฏ Alice๏ผŒ่ฏท่ฎฐไฝ") - if err != nil { - t.Fatalf("Agent1 chat failed: %v", err) - } - t.Logf("[Agent1] ่ฎพ็ฝฎ: ๆˆ‘ๆ˜ฏ Alice -> %s", resp1.Output.Text[:min(100, len(resp1.Output.Text))]) - - // Agent2 ๅฏน่ฏ - resp2, err := agent2.Chat(ctx, "ๆˆ‘ๆ˜ฏ Bob๏ผŒ่ฏท่ฎฐไฝ") - if err != nil { - t.Fatalf("Agent2 chat failed: %v", err) - } - t.Logf("[Agent2] ่ฎพ็ฝฎ: ๆˆ‘ๆ˜ฏ Bob -> %s", resp2.Output.Text[:min(100, len(resp2.Output.Text))]) - - // ้ชŒ่ฏไผš่ฏ้š”็ฆป - resp1Check, _ := agent1.Chat(ctx, "ๆˆ‘ๅซไป€ไนˆ๏ผŸ") - resp2Check, _ := agent2.Chat(ctx, "ๆˆ‘ๅซไป€ไนˆ๏ผŸ") - - t.Logf("[Agent1] ้ชŒ่ฏ: %s", resp1Check.Output.Text[:min(100, len(resp1Check.Output.Text))]) - t.Logf("[Agent2] ้ชŒ่ฏ: %s", resp2Check.Output.Text[:min(100, len(resp2Check.Output.Text))]) - - if agent1.SessionID == agent2.SessionID { - t.Error("ไธคไธช Agent ็š„ SessionID ไธๅบ”่ฏฅ็›ธๅŒ") - } else { - t.Logf("Session ้š”็ฆปๆญฃๅธธ: Agent1=%s..., Agent2=%s...", - agent1.SessionID[:min(20, len(agent1.SessionID))], - agent2.SessionID[:min(20, len(agent2.SessionID))]) - } -} - -// TestQwenTimeout ๆต‹่ฏ•่ถ…ๆ—ถๅค„็† -func TestQwenTimeout(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - agent.Client.Timeout = 1 * time.Millisecond // ๆž็Ÿญ่ถ…ๆ—ถ - - ctx := context.Background() - _, err := agent.Chat(ctx, "ๆต‹่ฏ•่ถ…ๆ—ถ") - - if err == nil { - t.Log("่ญฆๅ‘Š: ๆž็Ÿญ่ถ…ๆ—ถๆฒกๆœ‰่งฆๅ‘้”™่ฏฏ") - } else { - t.Logf("้ข„ๆœŸ่ถ…ๆ—ถ้”™่ฏฏ: %v", err) - } - - // ๆขๅคๆญฃๅธธ่ถ…ๆ—ถ - agent.Client.Timeout = 120 * time.Second -} - -// TestQwenContextCancel ๆต‹่ฏ•ไธŠไธ‹ๆ–‡ๅ–ๆถˆ -func TestQwenContextCancel(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - - ctx, cancel := context.WithCancel(context.Background()) - cancel() // ็ซ‹ๅณๅ–ๆถˆ - - _, err := agent.Chat(ctx, "ๆต‹่ฏ•ๅ–ๆถˆ") - if err == nil { - t.Error("ๅ–ๆถˆ็š„ไธŠไธ‹ๆ–‡ๅบ”่ฏฅ่ฟ”ๅ›ž้”™่ฏฏ") - } else { - t.Logf("้ข„ๆœŸๅ–ๆถˆ้”™่ฏฏ: %v", err) - } -} - -// TestQwenWithBizParams ๆต‹่ฏ•ๅธฆไธšๅŠกๅ‚ๆ•ฐ็š„่ฐƒ็”จ -func TestQwenWithBizParams(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - // ๆž„้€ ๅธฆไธšๅŠกๅ‚ๆ•ฐ็š„่ฏทๆฑ‚ - reqBody := QwenRequest{ - Input: QwenInput{ - Prompt: "ๆ นๆฎๆไพ›็š„็”จๆˆทไฟกๆฏ๏ผŒ็ป™ๅ‡บไธชๆ€งๅŒ–็š„ๆŠ•่ต„ๅปบ่ฎฎ", - BizParams: map[string]interface{}{ - "user_risk_level": "moderate", - "capital": 10000, - "experience": "intermediate", - }, - }, - } - - jsonData, _ := json.Marshal(reqBody) - url := fmt.Sprintf("%s/%s/completion", agent.BaseURL, agent.AppID) - - req, _ := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+agent.APIKey) - - resp, err := agent.Client.Do(req) - if err != nil { - t.Fatalf("Request with biz params failed: %v", err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - var result QwenResponse - json.Unmarshal(body, &result) - - if result.Output.Text != "" { - t.Logf("ๅธฆไธšๅŠกๅ‚ๆ•ฐๅ“ๅบ”: %s", result.Output.Text[:min(200, len(result.Output.Text))]) - } else { - t.Logf("ๅ“ๅบ”: %s", string(body)) - } -} - -func min(a, b int) int { - if a < b { - return a - } - return b -} diff --git a/llm/qwen_indicator_test.go b/llm/qwen_indicator_test.go deleted file mode 100644 index 873c6967..00000000 --- a/llm/qwen_indicator_test.go +++ /dev/null @@ -1,737 +0,0 @@ -package llm - -import ( - "context" - "encoding/json" - "fmt" - "math" - "nofx/market" - "nofx/provider/coinank" - "nofx/provider/coinank/coinank_api" - "nofx/provider/coinank/coinank_enum" - "regexp" - "strconv" - "strings" - "testing" - "time" -) - -// IndicatorResult AI ่ฎก็ฎ—็š„ๆŒ‡ๆ ‡็ป“ๆžœ -type IndicatorResult struct { - EMA12 float64 `json:"ema12"` - EMA26 float64 `json:"ema26"` - MACD float64 `json:"macd"` - RSI14 float64 `json:"rsi14"` - BOLLUp float64 `json:"boll_upper"` - BOLLMid float64 `json:"boll_middle"` - BOLLLow float64 `json:"boll_lower"` - ATR14 float64 `json:"atr14"` - SMA20 float64 `json:"sma20"` -} - -// ๆœฌๅœฐ่ฎก็ฎ—ๆŒ‡ๆ ‡๏ผˆไฝฟ็”จ market ๅŒ…็š„ๅ‡ฝๆ•ฐ๏ผ‰ -func calculateLocalIndicators(klines []market.Kline) IndicatorResult { - result := IndicatorResult{} - - if len(klines) >= 12 { - result.EMA12 = market.ExportCalculateEMA(klines, 12) - } - if len(klines) >= 26 { - result.EMA26 = market.ExportCalculateEMA(klines, 26) - result.MACD = market.ExportCalculateMACD(klines) - } - if len(klines) > 14 { - result.RSI14 = market.ExportCalculateRSI(klines, 14) - } - if len(klines) >= 20 { - result.BOLLUp, result.BOLLMid, result.BOLLLow = market.ExportCalculateBOLL(klines, 20, 2.0) - // SMA20 ๅฐฑๆ˜ฏ BOLL ไธญ่ฝจ - result.SMA20 = result.BOLLMid - } - if len(klines) > 14 { - result.ATR14 = market.ExportCalculateATR(klines, 14) - } - - return result -} - -// ๆ ผๅผๅŒ– K ็บฟๆ•ฐๆฎไธบๆ–‡ๆœฌ๏ผŒๅ‘็ป™ AI -func formatKlinesForAI(klines []market.Kline) string { - var sb strings.Builder - sb.WriteString("ไปฅไธ‹ๆ˜ฏK็บฟๆ•ฐๆฎ๏ผˆไปŽๆ—งๅˆฐๆ–ฐๆŽ’ๅˆ—๏ผ‰๏ผš\n") - sb.WriteString("ๅบๅท | ๆ—ถ้—ด | ๅผ€็›˜ไปท | ๆœ€้ซ˜ไปท | ๆœ€ไฝŽไปท | ๆ”ถ็›˜ไปท | ๆˆไบค้‡\n") - sb.WriteString("-----|------|--------|--------|--------|--------|--------\n") - - for i, k := range klines { - t := time.UnixMilli(k.OpenTime) - sb.WriteString(fmt.Sprintf("%d | %s | %.2f | %.2f | %.2f | %.2f | %.2f\n", - i+1, t.Format("01-02 15:04"), k.Open, k.High, k.Low, k.Close, k.Volume)) - } - - return sb.String() -} - -// ๆž„ๅปบ AI ่ฎก็ฎ—ๆŒ‡ๆ ‡็š„ prompt -func buildIndicatorPrompt(klines []market.Kline) string { - klinesText := formatKlinesForAI(klines) - - prompt := fmt.Sprintf(`%s - -่ฏทๆ นๆฎไปฅไธŠ %d ๆ นK็บฟๆ•ฐๆฎ๏ผŒ่ฎก็ฎ—ไปฅไธ‹ๆŠ€ๆœฏๆŒ‡ๆ ‡๏ผˆไฝฟ็”จๆ ‡ๅ‡†็ฎ—ๆณ•๏ผ‰๏ผš - -1. EMA12๏ผˆ12ๅ‘จๆœŸๆŒ‡ๆ•ฐ็งปๅŠจๅนณๅ‡็บฟ๏ผ‰ -2. EMA26๏ผˆ26ๅ‘จๆœŸๆŒ‡ๆ•ฐ็งปๅŠจๅนณๅ‡็บฟ๏ผ‰ -3. MACD๏ผˆEMA12 - EMA26๏ผ‰ -4. RSI14๏ผˆ14ๅ‘จๆœŸ็›ธๅฏนๅผบๅผฑๆŒ‡ๆ ‡๏ผŒไฝฟ็”จWilderๅนณๆป‘ๆณ•๏ผ‰ -5. BOLLๅธƒๆž—ๅธฆ๏ผˆ20ๅ‘จๆœŸ๏ผŒ2ๅ€ๆ ‡ๅ‡†ๅทฎ๏ผ‰๏ผšไธŠ่ฝจใ€ไธญ่ฝจใ€ไธ‹่ฝจ -6. ATR14๏ผˆ14ๅ‘จๆœŸๅนณๅ‡็œŸๅฎžๆณขๅน…๏ผŒไฝฟ็”จWilderๅนณๆป‘ๆณ•๏ผ‰ -7. SMA20๏ผˆ20ๅ‘จๆœŸ็ฎ€ๅ•็งปๅŠจๅนณๅ‡็บฟ๏ผ‰ - -่ฏทไธฅๆ ผๆŒ‰็…งไปฅไธ‹ JSON ๆ ผๅผ่ฟ”ๅ›ž็ป“ๆžœ๏ผŒไธ่ฆๆทปๅŠ ไปปไฝ•ๅ…ถไป–ๆ–‡ๅญ—๏ผš -{ - "ema12": ๆ•ฐๅ€ผ, - "ema26": ๆ•ฐๅ€ผ, - "macd": ๆ•ฐๅ€ผ, - "rsi14": ๆ•ฐๅ€ผ, - "boll_upper": ๆ•ฐๅ€ผ, - "boll_middle": ๆ•ฐๅ€ผ, - "boll_lower": ๆ•ฐๅ€ผ, - "atr14": ๆ•ฐๅ€ผ, - "sma20": ๆ•ฐๅ€ผ -} - -ๆณจๆ„๏ผš -- ๆ‰€ๆœ‰ๆ•ฐๅ€ผไฟ็•™2ไฝๅฐๆ•ฐ -- EMA่ฎก็ฎ—ไฝฟ็”จSMAไฝœไธบๅˆๅง‹ๅ€ผ๏ผŒไน˜ๆ•ฐไธบ 2/(period+1) -- RSIไฝฟ็”จWilderๅนณๆป‘ๆณ• -- ๅช่ฟ”ๅ›žJSON๏ผŒไธ่ฆ่งฃ้‡Š่ฟ‡็จ‹`, klinesText, len(klines)) - - return prompt -} - -// ไปŽ AI ๅ“ๅบ”ไธญๆๅ– JSON -func extractJSONFromResponse(text string) (IndicatorResult, error) { - var result IndicatorResult - - // ๅฐ่ฏ•็›ดๆŽฅ่งฃๆž - if err := json.Unmarshal([]byte(text), &result); err == nil { - return result, nil - } - - // ๆๅ– JSON ้ƒจๅˆ† - re := regexp.MustCompile(`\{[^{}]*"ema12"[^{}]*\}`) - match := re.FindString(text) - if match == "" { - // ๅฐ่ฏ•ๆ›ดๅฎฝๆพ็š„ๅŒน้… - start := strings.Index(text, "{") - end := strings.LastIndex(text, "}") - if start != -1 && end != -1 && end > start { - match = text[start : end+1] - } - } - - if match == "" { - return result, fmt.Errorf("no JSON found in response: %s", text[:min(200, len(text))]) - } - - if err := json.Unmarshal([]byte(match), &result); err != nil { - return result, fmt.Errorf("parse JSON failed: %w, json: %s", err, match) - } - - return result, nil -} - -// ๆฏ”่พƒไธคไธชๆŒ‡ๆ ‡็ป“ๆžœ๏ผŒ่ฟ”ๅ›ž่ฏฏๅทฎ็™พๅˆ†ๆฏ” -func compareIndicators(local, ai IndicatorResult) map[string]float64 { - errors := make(map[string]float64) - - calcError := func(name string, localVal, aiVal float64) { - if localVal == 0 { - if aiVal == 0 { - errors[name] = 0 - } else { - errors[name] = 100 // ๆœฌๅœฐไธบ0ไฝ†AIไธไธบ0 - } - return - } - errors[name] = math.Abs(localVal-aiVal) / math.Abs(localVal) * 100 - } - - calcError("EMA12", local.EMA12, ai.EMA12) - calcError("EMA26", local.EMA26, ai.EMA26) - calcError("MACD", local.MACD, ai.MACD) - calcError("RSI14", local.RSI14, ai.RSI14) - calcError("BOLL_UP", local.BOLLUp, ai.BOLLUp) - calcError("BOLL_MID", local.BOLLMid, ai.BOLLMid) - calcError("BOLL_LOW", local.BOLLLow, ai.BOLLLow) - calcError("ATR14", local.ATR14, ai.ATR14) - calcError("SMA20", local.SMA20, ai.SMA20) - - return errors -} - -// ็”Ÿๆˆๆต‹่ฏ•็”จ K ็บฟๆ•ฐๆฎ -func generateTestKlines(count int, basePrice float64) []market.Kline { - klines := make([]market.Kline, count) - price := basePrice - now := time.Now() - - for i := 0; i < count; i++ { - // ๆจกๆ‹Ÿไปทๆ ผๆณขๅŠจ - change := (float64(i%7) - 3) * 0.5 // -1.5 ๅˆฐ +1.5 ็š„ๆณขๅŠจ - price = price + change - - open := price - high := price + math.Abs(change)*0.5 + 0.5 - low := price - math.Abs(change)*0.5 - 0.3 - close := price + (change * 0.3) - - klines[i] = market.Kline{ - OpenTime: now.Add(time.Duration(-count+i) * time.Hour).UnixMilli(), - Open: open, - High: high, - Low: low, - Close: close, - Volume: 1000 + float64(i*100), - CloseTime: now.Add(time.Duration(-count+i+1) * time.Hour).UnixMilli(), - } - } - - return klines -} - -// TestQwenIndicatorCalculation ๆต‹่ฏ• AI ่ฎก็ฎ—ๆŠ€ๆœฏๆŒ‡ๆ ‡ -func TestQwenIndicatorCalculation(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - // ็”Ÿๆˆ 30 ๆ นๆต‹่ฏ• K ็บฟ - klines := generateTestKlines(30, 95000) - - t.Log("===== K็บฟๆ•ฐๆฎ (ๆœ€ๅŽ5ๆ น) =====") - for i := len(klines) - 5; i < len(klines); i++ { - k := klines[i] - t.Logf(" [%d] O:%.2f H:%.2f L:%.2f C:%.2f", i+1, k.Open, k.High, k.Low, k.Close) - } - - // ๆœฌๅœฐ่ฎก็ฎ— - t.Log("\n===== ๆœฌๅœฐ่ฎก็ฎ—็ป“ๆžœ =====") - localResult := calculateLocalIndicators(klines) - t.Logf(" EMA12: %.2f", localResult.EMA12) - t.Logf(" EMA26: %.2f", localResult.EMA26) - t.Logf(" MACD: %.2f", localResult.MACD) - t.Logf(" RSI14: %.2f", localResult.RSI14) - t.Logf(" BOLLไธŠ่ฝจ: %.2f", localResult.BOLLUp) - t.Logf(" BOLLไธญ่ฝจ: %.2f", localResult.BOLLMid) - t.Logf(" BOLLไธ‹่ฝจ: %.2f", localResult.BOLLLow) - t.Logf(" ATR14: %.2f", localResult.ATR14) - t.Logf(" SMA20: %.2f", localResult.SMA20) - - // AI ่ฎก็ฎ— - t.Log("\n===== ่ฐƒ็”จ AI ่ฎก็ฎ— =====") - prompt := buildIndicatorPrompt(klines) - t.Logf("Prompt ้•ฟๅบฆ: %d ๅญ—็ฌฆ", len(prompt)) - - start := time.Now() - resp, err := agent.Chat(ctx, prompt) - elapsed := time.Since(start) - - if err != nil { - t.Fatalf("AI ่ฐƒ็”จๅคฑ่ดฅ: %v", err) - } - - t.Logf("AI ๅ“ๅบ”่€—ๆ—ถ: %v", elapsed) - t.Logf("AI ๅŽŸๅง‹ๅ“ๅบ”:\n%s", resp.Output.Text) - - // ่งฃๆž AI ็ป“ๆžœ - aiResult, err := extractJSONFromResponse(resp.Output.Text) - if err != nil { - t.Fatalf("่งฃๆž AI ็ป“ๆžœๅคฑ่ดฅ: %v", err) - } - - t.Log("\n===== AI ่ฎก็ฎ—็ป“ๆžœ =====") - t.Logf(" EMA12: %.2f", aiResult.EMA12) - t.Logf(" EMA26: %.2f", aiResult.EMA26) - t.Logf(" MACD: %.2f", aiResult.MACD) - t.Logf(" RSI14: %.2f", aiResult.RSI14) - t.Logf(" BOLLไธŠ่ฝจ: %.2f", aiResult.BOLLUp) - t.Logf(" BOLLไธญ่ฝจ: %.2f", aiResult.BOLLMid) - t.Logf(" BOLLไธ‹่ฝจ: %.2f", aiResult.BOLLLow) - t.Logf(" ATR14: %.2f", aiResult.ATR14) - t.Logf(" SMA20: %.2f", aiResult.SMA20) - - // ๅฏนๆฏ”็ป“ๆžœ - t.Log("\n===== ่ฏฏๅทฎๅฏนๆฏ” (%) =====") - errors := compareIndicators(localResult, aiResult) - - totalError := 0.0 - for name, errPct := range errors { - status := "โœ“" - if errPct > 5 { - status = "โš " - } - if errPct > 10 { - status = "โœ—" - } - t.Logf(" %s %s: %.2f%%", status, name, errPct) - totalError += errPct - } - - avgError := totalError / float64(len(errors)) - t.Logf("\n ๅนณๅ‡่ฏฏๅทฎ: %.2f%%", avgError) - - if avgError > 10 { - t.Logf("่ญฆๅ‘Š: AI ่ฎก็ฎ—่ฏฏๅทฎ่พƒๅคง๏ผŒๅฏ่ƒฝ็ฎ—ๆณ•็†่งฃๆœ‰ๅทฎๅผ‚") - } else if avgError < 5 { - t.Log("AI ่ฎก็ฎ—็ฒพๅบฆ่‰ฏๅฅฝ๏ผ") - } -} - -// TestQwenIndicatorWithRealKlines ไฝฟ็”จ็œŸๅฎž K ็บฟๆต‹่ฏ• -func TestQwenIndicatorWithRealKlines(t *testing.T) { - // ๅฐ่ฏ•่Žทๅ–็œŸๅฎž K ็บฟๆ•ฐๆฎ - client := market.NewAPIClient() - klines, err := client.GetKlines("BTC", "1h", 30) - if err != nil { - t.Skipf("่Žทๅ–็œŸๅฎž K ็บฟๅคฑ่ดฅ๏ผŒ่ทณ่ฟ‡ๆต‹่ฏ•: %v", err) - return - } - - if len(klines) < 26 { - t.Skipf("K ็บฟๆ•ฐ้‡ไธ่ถณ: %d", len(klines)) - return - } - - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - t.Logf("่Žทๅ–ๅˆฐ %d ๆ น BTC 1h K็บฟ", len(klines)) - t.Log("ๆœ€ๆ–ฐไปทๆ ผ:", klines[len(klines)-1].Close) - - // ๆœฌๅœฐ่ฎก็ฎ— - localResult := calculateLocalIndicators(klines) - t.Log("\n===== ๆœฌๅœฐ่ฎก็ฎ— =====") - t.Logf(" EMA12: %.2f, EMA26: %.2f, MACD: %.2f", localResult.EMA12, localResult.EMA26, localResult.MACD) - t.Logf(" RSI14: %.2f", localResult.RSI14) - t.Logf(" BOLL: %.2f / %.2f / %.2f", localResult.BOLLUp, localResult.BOLLMid, localResult.BOLLLow) - - // AI ่ฎก็ฎ— - prompt := buildIndicatorPrompt(klines) - resp, err := agent.Chat(ctx, prompt) - if err != nil { - t.Fatalf("AI ่ฐƒ็”จๅคฑ่ดฅ: %v", err) - } - - t.Log("\n===== AI ๅ“ๅบ” =====") - t.Log(resp.Output.Text) - - aiResult, err := extractJSONFromResponse(resp.Output.Text) - if err != nil { - t.Logf("่งฃๆžๅคฑ่ดฅ: %v", err) - return - } - - // ๅฏนๆฏ” - errors := compareIndicators(localResult, aiResult) - t.Log("\n===== ่ฏฏๅทฎ =====") - for name, errPct := range errors { - t.Logf(" %s: %.2f%%", name, errPct) - } -} - -// TestQwenIndicatorMultiTimeframe ๆต‹่ฏ•ๅคšไธชๆ—ถ้—ดๅ‘จๆœŸ -func TestQwenIndicatorMultiTimeframe(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - timeframes := []struct { - name string - count int - price float64 - }{ - {"5mๅ‘จๆœŸ", 30, 95000}, - {"1hๅ‘จๆœŸ", 50, 95000}, - {"4hๅ‘จๆœŸ", 40, 95000}, - } - - for _, tf := range timeframes { - t.Run(tf.name, func(t *testing.T) { - klines := generateTestKlines(tf.count, tf.price) - - localResult := calculateLocalIndicators(klines) - - // ็ฎ€ๅŒ–็š„ prompt - prompt := buildSimpleIndicatorPrompt(klines) - - resp, err := agent.Chat(ctx, prompt) - if err != nil { - t.Fatalf("AI ่ฐƒ็”จๅคฑ่ดฅ: %v", err) - } - - aiResult, err := extractJSONFromResponse(resp.Output.Text) - if err != nil { - t.Logf("่งฃๆžๅคฑ่ดฅ: %v", err) - t.Logf("AI ๅ“ๅบ”: %s", resp.Output.Text[:min(500, len(resp.Output.Text))]) - return - } - - errors := compareIndicators(localResult, aiResult) - - // ่ฎก็ฎ—ๅนณๅ‡่ฏฏๅทฎ - total := 0.0 - for _, e := range errors { - total += e - } - avgErr := total / float64(len(errors)) - - t.Logf("ๆœฌๅœฐ MACD: %.2f, AI MACD: %.2f, ่ฏฏๅทฎ: %.2f%%", localResult.MACD, aiResult.MACD, errors["MACD"]) - t.Logf("ๆœฌๅœฐ RSI: %.2f, AI RSI: %.2f, ่ฏฏๅทฎ: %.2f%%", localResult.RSI14, aiResult.RSI14, errors["RSI14"]) - t.Logf("ๅนณๅ‡่ฏฏๅทฎ: %.2f%%", avgErr) - }) - - time.Sleep(2 * time.Second) // ้ฟๅ…่ฏทๆฑ‚่ฟ‡ๅฟซ - } -} - -// ็ฎ€ๅŒ–็š„ prompt -func buildSimpleIndicatorPrompt(klines []market.Kline) string { - // ๅชๆไพ›ๆ”ถ็›˜ไปทๅบๅˆ—๏ผŒๅ‡ๅฐ‘ token - var prices []string - for _, k := range klines { - prices = append(prices, fmt.Sprintf("%.2f", k.Close)) - } - - return fmt.Sprintf(`ๆ”ถ็›˜ไปทๅบๅˆ—๏ผˆไปŽๆ—งๅˆฐๆ–ฐ๏ผ‰: [%s] - -่ฏท่ฎก็ฎ—ๆŠ€ๆœฏๆŒ‡ๆ ‡ๅนถ่ฟ”ๅ›ž JSON๏ผš -- ema12: 12ๅ‘จๆœŸEMA -- ema26: 26ๅ‘จๆœŸEMA -- macd: EMA12-EMA26 -- rsi14: 14ๅ‘จๆœŸRSI(Wilderๅนณๆป‘) -- boll_upper, boll_middle, boll_lower: 20ๅ‘จๆœŸBOLL(2ๅ€ๆ ‡ๅ‡†ๅทฎ) -- atr14: 0 (ๆ— ้ซ˜ไฝŽไปทๆ•ฐๆฎ) -- sma20: 20ๅ‘จๆœŸSMA - -ๅช่ฟ”ๅ›žJSONๆ ผๅผ๏ผš{"ema12":ๆ•ฐๅ€ผ,"ema26":ๆ•ฐๅ€ผ,...}`, strings.Join(prices, ",")) -} - -// TestQwenIndicatorAccuracy ็ฒพๅบฆๆต‹่ฏ•๏ผšไฝฟ็”จ็ฎ€ๅ•ๆ•ฐๆฎ้ชŒ่ฏ็ฎ—ๆณ• -func TestQwenIndicatorAccuracy(t *testing.T) { - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - ctx := context.Background() - - // ไฝฟ็”จ็ฎ€ๅ•้€’ๅขžๆ•ฐๆฎ๏ผŒไพฟไบŽ้ชŒ่ฏ - prices := []float64{ - 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, // 1-10 - 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, // 11-20 - 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, // 21-30 - } - - // ๆž„ๅปบ K ็บฟ - klines := make([]market.Kline, len(prices)) - for i, p := range prices { - klines[i] = market.Kline{ - Open: p - 0.5, - High: p + 1, - Low: p - 1, - Close: p, - } - } - - // ๆœฌๅœฐ่ฎก็ฎ— - localResult := calculateLocalIndicators(klines) - - t.Log("===== ็ฎ€ๅ•้€’ๅขžๆ•ฐๆฎๆต‹่ฏ• =====") - t.Logf("ไปทๆ ผๅบๅˆ—: %v", prices) - t.Logf("ๆœฌๅœฐ่ฎก็ฎ—:") - t.Logf(" SMA20 = %.4f (็†่ฎบๅ€ผ: 119.5)", localResult.SMA20) - t.Logf(" EMA12 = %.4f", localResult.EMA12) - t.Logf(" RSI14 = %.4f (ๆŒ็ปญไธŠๆถจๅบ”ๆŽฅ่ฟ‘100)", localResult.RSI14) - - // AI ่ฎก็ฎ— - var priceStrs []string - for _, p := range prices { - priceStrs = append(priceStrs, strconv.FormatFloat(p, 'f', 0, 64)) - } - - prompt := fmt.Sprintf(`ๆ”ถ็›˜ไปทๅบๅˆ—: [%s] - -่ฏท่ฎก็ฎ—: -1. SMA20 (20ๅ‘จๆœŸ็ฎ€ๅ•็งปๅŠจๅนณๅ‡) -2. EMA12 (12ๅ‘จๆœŸๆŒ‡ๆ•ฐ็งปๅŠจๅนณๅ‡๏ผŒๅˆๅง‹ๅ€ผ็”จSMA๏ผŒไน˜ๆ•ฐ=2/13) -3. RSI14 (14ๅ‘จๆœŸRSI๏ผŒWilderๅนณๆป‘ๆณ•) - -่ฟ”ๅ›žJSON: {"sma20":ๆ•ฐๅ€ผ,"ema12":ๆ•ฐๅ€ผ,"rsi14":ๆ•ฐๅ€ผ} -ๅช่ฟ”ๅ›žJSON`, strings.Join(priceStrs, ",")) - - resp, err := agent.Chat(ctx, prompt) - if err != nil { - t.Fatalf("AI ่ฐƒ็”จๅคฑ่ดฅ: %v", err) - } - - t.Logf("\nAI ๅ“ๅบ”: %s", resp.Output.Text) - - // ็ฎ€ๅ•่งฃๆž - var aiSimple struct { - SMA20 float64 `json:"sma20"` - EMA12 float64 `json:"ema12"` - RSI14 float64 `json:"rsi14"` - } - - text := resp.Output.Text - start := strings.Index(text, "{") - end := strings.LastIndex(text, "}") - if start != -1 && end > start { - json.Unmarshal([]byte(text[start:end+1]), &aiSimple) - } - - t.Logf("\nAI ่ฎก็ฎ—:") - t.Logf(" SMA20 = %.4f", aiSimple.SMA20) - t.Logf(" EMA12 = %.4f", aiSimple.EMA12) - t.Logf(" RSI14 = %.4f", aiSimple.RSI14) - - // ้ชŒ่ฏ SMA20 (็†่ฎบๅ€ผๅบ”่ฏฅๆ˜ฏ 110+...+129 ็š„ๅนณๅ‡ = 119.5) - expectedSMA := 119.5 - if math.Abs(aiSimple.SMA20-expectedSMA) < 0.1 { - t.Log("\nโœ“ AI ็š„ SMA20 ่ฎก็ฎ—ๆญฃ็กฎ!") - } else { - t.Logf("\nโœ— AI ็š„ SMA20 ๆœ‰่ฏฏๅทฎ๏ผŒๆœŸๆœ› %.2f", expectedSMA) - } -} - -// coinankKlinesToMarket ๅฐ† coinank K็บฟ่ฝฌๆขไธบ market.Kline -func coinankKlinesToMarket(klines []coinank.KlineResult) []market.Kline { - result := make([]market.Kline, len(klines)) - for i, k := range klines { - result[i] = market.Kline{ - OpenTime: k.StartTime, - Open: k.Open, - High: k.High, - Low: k.Low, - Close: k.Close, - Volume: k.Volume, - CloseTime: k.EndTime, - } - } - return result -} - -// TestQwenETHMultiTimeframe ไฝฟ็”จ Coinank ๅ…่ดน API ่Žทๅ–็œŸๅฎž ETH ๆ•ฐๆฎๆต‹่ฏ•ๅคšๅ‘จๆœŸๆŒ‡ๆ ‡ -func TestQwenETHMultiTimeframe(t *testing.T) { - ctx := context.Background() - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - - // ๆต‹่ฏ•ๅคšไธชๆ—ถ้—ดๅ‘จๆœŸ - timeframes := []struct { - name string - interval coinank_enum.Interval - size int - }{ - {"5ๅˆ†้’Ÿ", coinank_enum.Minute5, 50}, - {"1ๅฐๆ—ถ", coinank_enum.Hour1, 50}, - {"4ๅฐๆ—ถ", coinank_enum.Hour4, 50}, - {"ๆ—ฅ็บฟ", coinank_enum.Day1, 30}, - } - - now := time.Now() - - for _, tf := range timeframes { - t.Run(tf.name, func(t *testing.T) { - // ไฝฟ็”จ coinank ๅ…่ดน API ่Žทๅ– ETH K็บฟๆ•ฐๆฎ - coinankKlines, err := coinank_api.Kline(ctx, "ETHUSDT", coinank_enum.Binance, - now.UnixMilli(), coinank_enum.To, tf.size, tf.interval) - if err != nil { - t.Fatalf("่Žทๅ– %s K็บฟๅคฑ่ดฅ: %v", tf.name, err) - } - - if len(coinankKlines) < 26 { - t.Skipf("K็บฟๆ•ฐ้‡ไธ่ถณ: %d", len(coinankKlines)) - return - } - - // ่ฝฌๆขไธบ market.Kline - klines := coinankKlinesToMarket(coinankKlines) - - t.Logf("่Žทๅ–ๅˆฐ %d ๆ น ETH %s K็บฟ", len(klines), tf.name) - t.Logf("ๆœ€ๆ–ฐๆ”ถ็›˜ไปท: %.2f, ๆ—ถ้—ด: %s", - klines[len(klines)-1].Close, - time.UnixMilli(klines[len(klines)-1].CloseTime).Format("2006-01-02 15:04")) - - // ๆœฌๅœฐ่ฎก็ฎ— - localResult := calculateLocalIndicators(klines) - t.Log("\n===== ๆœฌๅœฐ่ฎก็ฎ— =====") - t.Logf(" EMA12: %.2f, EMA26: %.2f, MACD: %.4f", - localResult.EMA12, localResult.EMA26, localResult.MACD) - t.Logf(" RSI14: %.2f", localResult.RSI14) - t.Logf(" BOLL: %.2f / %.2f / %.2f", - localResult.BOLLUp, localResult.BOLLMid, localResult.BOLLLow) - t.Logf(" ATR14: %.4f", localResult.ATR14) - - // AI ่ฎก็ฎ— - ไฝฟ็”จ็ฎ€ๅŒ– prompt๏ผˆๅชๅ‘ๆ”ถ็›˜ไปท๏ผ‰ - prompt := buildSimpleIndicatorPrompt(klines) - t.Logf("\nPrompt ้•ฟๅบฆ: %d ๅญ—็ฌฆ", len(prompt)) - - start := time.Now() - resp, err := agent.Chat(ctx, prompt) - elapsed := time.Since(start) - - if err != nil { - t.Fatalf("AI ่ฐƒ็”จๅคฑ่ดฅ: %v", err) - } - - t.Logf("AI ๅ“ๅบ”่€—ๆ—ถ: %v", elapsed) - - // ่งฃๆž AI ็ป“ๆžœ - aiResult, err := extractJSONFromResponse(resp.Output.Text) - if err != nil { - t.Logf("AI ๅŽŸๅง‹ๅ“ๅบ”:\n%s", resp.Output.Text[:min(500, len(resp.Output.Text))]) - t.Fatalf("่งฃๆžๅคฑ่ดฅ: %v", err) - } - - t.Log("\n===== AI ่ฎก็ฎ— =====") - t.Logf(" EMA12: %.2f, EMA26: %.2f, MACD: %.4f", - aiResult.EMA12, aiResult.EMA26, aiResult.MACD) - t.Logf(" RSI14: %.2f", aiResult.RSI14) - t.Logf(" BOLL: %.2f / %.2f / %.2f", - aiResult.BOLLUp, aiResult.BOLLMid, aiResult.BOLLLow) - - // ๅฏนๆฏ”่ฏฏๅทฎ - t.Log("\n===== ่ฏฏๅทฎๅฏนๆฏ” =====") - errors := compareIndicators(localResult, aiResult) - totalErr := 0.0 - for name, errPct := range errors { - status := "โœ“" - if errPct > 1 { - status = "โš " - } - if errPct > 5 { - status = "โœ—" - } - t.Logf(" %s %-10s: %.2f%%", status, name, errPct) - totalErr += errPct - } - - avgErr := totalErr / float64(len(errors)) - t.Logf("\n ๅนณๅ‡่ฏฏๅทฎ: %.2f%%", avgErr) - - if avgErr < 1 { - t.Log(" โœ“ AI ่ฎก็ฎ—็ฒพๅบฆไผ˜็ง€!") - } else if avgErr < 5 { - t.Log(" โš  AI ่ฎก็ฎ—็ฒพๅบฆ่‰ฏๅฅฝ") - } else { - t.Log(" โœ— AI ่ฎก็ฎ—่ฏฏๅทฎ่พƒๅคง") - } - - // ็ญ‰ๅพ…้ฟๅ…่ฏทๆฑ‚่ฟ‡ๅฟซ - time.Sleep(2 * time.Second) - }) - } -} - -// TestQwenETHIndicatorComparison ETH ๆŒ‡ๆ ‡ๅฏนๆฏ”๏ผšไฝฟ็”จ Coinank ๅ…่ดน API + Qwen ๆ ‡ๅ‡† API -func TestQwenETHIndicatorComparison(t *testing.T) { - ctx := context.Background() - agent := NewQwenAgent(QwenAppID, QwenAPIKey) - - // ไฝฟ็”จ coinank ๅ…่ดน API ่Žทๅ– ETH 1ๅฐๆ—ถ K็บฟ - now := time.Now() - coinankKlines, err := coinank_api.Kline(ctx, "ETHUSDT", coinank_enum.Binance, - now.UnixMilli(), coinank_enum.To, 30, coinank_enum.Hour1) - if err != nil { - t.Fatalf("่Žทๅ– K็บฟๅคฑ่ดฅ: %v", err) - } - - // ่ฝฌๆขไธบ market.Kline - klines := coinankKlinesToMarket(coinankKlines) - - t.Logf("่Žทๅ–ๅˆฐ %d ๆ น ETH 1h K็บฟ", len(klines)) - - // ๅช็”จๆ”ถ็›˜ไปท๏ผŒ็ฎ€ๅŒ– prompt - var prices []string - for _, k := range klines { - prices = append(prices, fmt.Sprintf("%.2f", k.Close)) - } - - // ๆœฌๅœฐ่ฎก็ฎ— - localResult := calculateLocalIndicators(klines) - - t.Log("\n===== ๆœฌๅœฐ่ฎก็ฎ—็ป“ๆžœ =====") - t.Logf("SMA20: %.2f", localResult.SMA20) - t.Logf("EMA12: %.2f", localResult.EMA12) - t.Logf("EMA26: %.2f", localResult.EMA26) - t.Logf("MACD: %.4f", localResult.MACD) - t.Logf("RSI14: %.2f", localResult.RSI14) - - // ็ฎ€ๅŒ–็š„ AI prompt - prompt := fmt.Sprintf(`ETH ๆœ€่ฟ‘30ๆ น1ๅฐๆ—ถK็บฟๆ”ถ็›˜ไปท๏ผˆไปŽๆ—งๅˆฐๆ–ฐ๏ผ‰: -[%s] - -่ฏท่ฎก็ฎ—ไปฅไธ‹ๆŒ‡ๆ ‡ๅนถ่ฟ”ๅ›ž็บฏ JSON: -1. sma20: ๆœ€ๅŽ20ไธชไปทๆ ผ็š„็ฎ€ๅ•็งปๅŠจๅนณๅ‡ -2. ema12: 12ๅ‘จๆœŸEMA๏ผˆๅˆๅง‹ๅ€ผ็”จๅ‰12ไธชไปทๆ ผ็š„SMA๏ผŒไน˜ๆ•ฐ=2/13๏ผ‰ -3. ema26: 26ๅ‘จๆœŸEMA๏ผˆๅˆๅง‹ๅ€ผ็”จๅ‰26ไธชไปทๆ ผ็š„SMA๏ผŒไน˜ๆ•ฐ=2/27๏ผ‰ -4. macd: EMA12 - EMA26 -5. rsi14: 14ๅ‘จๆœŸRSI๏ผˆWilderๅนณๆป‘ๆณ•๏ผ‰ - -ๅช่ฟ”ๅ›žJSONๆ ผๅผ: {"sma20":ๆ•ฐๅ€ผ,"ema12":ๆ•ฐๅ€ผ,"ema26":ๆ•ฐๅ€ผ,"macd":ๆ•ฐๅ€ผ,"rsi14":ๆ•ฐๅ€ผ} -ไธ่ฆไปปไฝ•่งฃ้‡Šๆ–‡ๅญ—`, strings.Join(prices, ", ")) - - t.Logf("\nๅ‘้€ Prompt (%d ๅญ—็ฌฆ)", len(prompt)) - - // ไฝฟ็”จๆ ‡ๅ‡† API - resp, err := agent.ChatWithModel(ctx, "qwen-max", prompt) - if err != nil { - t.Fatalf("AI ่ฐƒ็”จๅคฑ่ดฅ: %v", err) - } - - aiText := resp.GetContent() - t.Logf("\nAI ๅ“ๅบ”:\n%s", aiText) - - // ่งฃๆž - var aiResult struct { - SMA20 float64 `json:"sma20"` - EMA12 float64 `json:"ema12"` - EMA26 float64 `json:"ema26"` - MACD float64 `json:"macd"` - RSI14 float64 `json:"rsi14"` - } - - start := strings.Index(aiText, "{") - end := strings.LastIndex(aiText, "}") - if start != -1 && end > start { - if err := json.Unmarshal([]byte(aiText[start:end+1]), &aiResult); err != nil { - t.Logf("JSON ่งฃๆžๅคฑ่ดฅ: %v", err) - } - } - - t.Log("\n===== AI ่ฎก็ฎ—็ป“ๆžœ =====") - t.Logf("SMA20: %.2f", aiResult.SMA20) - t.Logf("EMA12: %.2f", aiResult.EMA12) - t.Logf("EMA26: %.2f", aiResult.EMA26) - t.Logf("MACD: %.4f", aiResult.MACD) - t.Logf("RSI14: %.2f", aiResult.RSI14) - - // ่ฎก็ฎ—่ฏฏๅทฎ - t.Log("\n===== ่ฏฏๅทฎ =====") - calcErr := func(name string, local, ai float64) { - if local == 0 { - t.Logf(" %s: ๆœฌๅœฐ=0, AI=%.2f", name, ai) - return - } - errPct := math.Abs(local-ai) / math.Abs(local) * 100 - status := "โœ“" - if errPct > 1 { - status = "โš " - } - if errPct > 5 { - status = "โœ—" - } - t.Logf(" %s %s: ๆœฌๅœฐ=%.2f, AI=%.2f, ่ฏฏๅทฎ=%.2f%%", status, name, local, ai, errPct) - } - - calcErr("SMA20", localResult.SMA20, aiResult.SMA20) - calcErr("EMA12", localResult.EMA12, aiResult.EMA12) - calcErr("EMA26", localResult.EMA26, aiResult.EMA26) - calcErr("MACD", localResult.MACD, aiResult.MACD) - calcErr("RSI14", localResult.RSI14, aiResult.RSI14) -} diff --git a/main.go b/main.go index 5758d4fb..06f76a17 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,8 @@ import ( "nofx/logger" "nofx/manager" "nofx/mcp" + _ "nofx/mcp/payment" + _ "nofx/mcp/provider" "nofx/store" "nofx/telegram" "os" @@ -168,7 +170,7 @@ func newSharedMCPClient() mcp.AIClient { logger.Warn("โš ๏ธ DEEPSEEK_API_KEY not set, AI features will be unavailable") return nil } - return mcp.NewDeepSeekClient() + return mcp.NewAIClientByProvider("deepseek") } // initInstallationID initializes the anonymous installation ID for experience improvement diff --git a/mcp/claude_client_test.go b/mcp/claude_client_test.go deleted file mode 100644 index 268c2849..00000000 --- a/mcp/claude_client_test.go +++ /dev/null @@ -1,248 +0,0 @@ -package mcp - -import ( - "encoding/json" - "net/http" - "testing" -) - -// โ”€โ”€ buildRequestBodyFromRequest โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - -func TestClaudeClient_BuildRequestBody_SystemPromptLifted(t *testing.T) { - c := newTestClaudeClient() - req := &Request{ - Model: "claude-opus-4-6", - Messages: []Message{ - {Role: "system", Content: "You are helpful."}, - {Role: "user", Content: "Hello"}, - }, - } - body := c.buildRequestBodyFromRequest(req) - - if body["system"] != "You are helpful." { - t.Errorf("system not lifted to top level: %v", body["system"]) - } - msgs := body["messages"].([]map[string]any) - if len(msgs) != 1 || msgs[0]["role"] != "user" { - t.Errorf("system message should be removed from messages array: %v", msgs) - } -} - -func TestClaudeClient_BuildRequestBody_ToolsUseInputSchema(t *testing.T) { - c := newTestClaudeClient() - req := &Request{ - Model: "claude-opus-4-6", - Messages: []Message{{Role: "user", Content: "hi"}}, - Tools: []Tool{{ - Type: "function", - Function: FunctionDef{ - Name: "my_tool", - Description: "does stuff", - Parameters: map[string]any{"type": "object"}, - }, - }}, - } - body := c.buildRequestBodyFromRequest(req) - - tools, ok := body["tools"].([]map[string]any) - if !ok || len(tools) != 1 { - t.Fatalf("tools not set correctly: %v", body["tools"]) - } - tool := tools[0] - if tool["name"] != "my_tool" { - t.Errorf("tool name wrong: %v", tool["name"]) - } - if tool["input_schema"] == nil { - t.Error("tool must use input_schema, not parameters") - } - if _, hasParams := tool["parameters"]; hasParams { - t.Error("tool must NOT have parameters key (Anthropic uses input_schema)") - } -} - -func TestClaudeClient_BuildRequestBody_ToolChoiceObject(t *testing.T) { - c := newTestClaudeClient() - req := &Request{ - Model: "claude-opus-4-6", - Messages: []Message{{Role: "user", Content: "hi"}}, - ToolChoice: "auto", - } - body := c.buildRequestBodyFromRequest(req) - - tc, ok := body["tool_choice"].(map[string]any) - if !ok { - t.Fatalf("tool_choice must be an object, got: %T %v", body["tool_choice"], body["tool_choice"]) - } - if tc["type"] != "auto" { - t.Errorf("tool_choice.type must be 'auto', got: %v", tc["type"]) - } -} - -// โ”€โ”€ convertMessagesToAnthropic โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - -func TestConvertMessages_AssistantToolCall(t *testing.T) { - msgs := []Message{ - { - Role: "assistant", - ToolCalls: []ToolCall{{ - ID: "tc1", - Type: "function", - Function: ToolCallFunction{Name: "api_request", Arguments: `{"method":"GET","path":"/api/x","body":{}}`}, - }}, - }, - } - out := convertMessagesToAnthropic(msgs) - - if len(out) != 1 { - t.Fatalf("expected 1 message, got %d", len(out)) - } - msg := out[0] - if msg["role"] != "assistant" { - t.Errorf("role should be assistant: %v", msg["role"]) - } - blocks := msg["content"].([]map[string]any) - if len(blocks) != 1 || blocks[0]["type"] != "tool_use" { - t.Errorf("content should be tool_use block: %v", blocks) - } - if blocks[0]["id"] != "tc1" { - t.Errorf("tool_use id wrong: %v", blocks[0]["id"]) - } - // Input must be parsed JSON object, not a string. - input, ok := blocks[0]["input"].(map[string]any) - if !ok { - t.Errorf("tool_use input must be map, got %T", blocks[0]["input"]) - } - if input["method"] != "GET" { - t.Errorf("input.method wrong: %v", input) - } -} - -func TestConvertMessages_ToolResultMergedIntoUserTurn(t *testing.T) { - // Anthropic requires strictly alternating turns; consecutive tool results - // must be merged into a single user message. - msgs := []Message{ - {Role: "tool", ToolCallID: "tc1", Content: `{"result":"a"}`}, - {Role: "tool", ToolCallID: "tc2", Content: `{"result":"b"}`}, - } - out := convertMessagesToAnthropic(msgs) - - if len(out) != 1 { - t.Fatalf("consecutive tool results must be merged into one user turn, got %d messages", len(out)) - } - if out[0]["role"] != "user" { - t.Errorf("tool results must become role=user: %v", out[0]["role"]) - } - blocks := out[0]["content"].([]map[string]any) - if len(blocks) != 2 { - t.Errorf("expected 2 tool_result blocks, got %d", len(blocks)) - } - if blocks[0]["type"] != "tool_result" || blocks[1]["type"] != "tool_result" { - t.Errorf("blocks should be tool_result: %v", blocks) - } - if blocks[0]["tool_use_id"] != "tc1" || blocks[1]["tool_use_id"] != "tc2" { - t.Errorf("tool_use_id mismatch: %v", blocks) - } -} - -// โ”€โ”€ parseMCPResponseFull โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - -func TestClaudeClient_ParseResponse_TextOnly(t *testing.T) { - c := newTestClaudeClient() - body := []byte(`{ - "content": [{"type":"text","text":"Hello from Claude"}], - "usage": {"input_tokens": 10, "output_tokens": 5} - }`) - resp, err := c.parseMCPResponseFull(body) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.Content != "Hello from Claude" { - t.Errorf("content mismatch: %q", resp.Content) - } - if len(resp.ToolCalls) != 0 { - t.Errorf("expected no tool calls: %v", resp.ToolCalls) - } -} - -func TestClaudeClient_ParseResponse_ToolUse(t *testing.T) { - c := newTestClaudeClient() - body := []byte(`{ - "content": [{ - "type": "tool_use", - "id": "toolu_01abc", - "name": "api_request", - "input": {"method":"POST","path":"/api/strategies","body":{"name":"BTC็ญ–็•ฅ"}} - }], - "usage": {"input_tokens": 100, "output_tokens": 30} - }`) - resp, err := c.parseMCPResponseFull(body) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(resp.ToolCalls) != 1 { - t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls)) - } - tc := resp.ToolCalls[0] - if tc.ID != "toolu_01abc" { - t.Errorf("tool call ID wrong: %v", tc.ID) - } - if tc.Function.Name != "api_request" { - t.Errorf("function name wrong: %v", tc.Function.Name) - } - // Arguments must be a valid JSON string. - var args map[string]any - if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil { - t.Errorf("arguments not valid JSON: %q โ€” %v", tc.Function.Arguments, err) - } - if args["method"] != "POST" { - t.Errorf("args.method wrong: %v", args) - } -} - -func TestClaudeClient_ParseResponse_APIError(t *testing.T) { - c := newTestClaudeClient() - body := []byte(`{"error":{"type":"authentication_error","message":"invalid x-api-key"}}`) - _, err := c.parseMCPResponseFull(body) - if err == nil { - t.Fatal("expected error for API error response") - } - if err.Error() == "" { - t.Error("error message should not be empty") - } -} - -// โ”€โ”€ Auth header โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - -func TestClaudeClient_SetAuthHeader(t *testing.T) { - c := newTestClaudeClient() - c.APIKey = "sk-ant-test123" - - // net/http.Header canonicalizes keys (x-api-key โ†’ X-Api-Key). - h := make(http.Header) - c.setAuthHeader(h) - - if got := h.Get("x-api-key"); got != "sk-ant-test123" { - t.Errorf("x-api-key header not set correctly: %q", got) - } - if h.Get("anthropic-version") == "" { - t.Error("anthropic-version header must be set") - } - // Must NOT use Authorization: Bearer (that's OpenAI format). - if h.Get("Authorization") != "" { - t.Error("Claude must use x-api-key, not Authorization header") - } -} - -func TestClaudeClient_BuildUrl(t *testing.T) { - c := newTestClaudeClient() - url := c.buildUrl() - if url != DefaultClaudeBaseURL+"/messages" { - t.Errorf("URL should be /messages endpoint, got: %s", url) - } -} - -// โ”€โ”€ helpers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - -func newTestClaudeClient() *ClaudeClient { - return NewClaudeClientWithOptions().(*ClaudeClient) -} diff --git a/mcp/client.go b/mcp/client.go index 72f01782..d4923b38 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -60,14 +60,14 @@ type Client struct { UseFullURL bool // Whether to use full URL (without appending /chat/completions) MaxTokens int // Maximum tokens for AI response - httpClient *http.Client - logger Logger // Logger (replaceable) - config *Config // Config object (stores all configurations) + HTTPClient *http.Client // Exported for sub-packages + Log Logger // Exported for sub-packages + Cfg *Config // Exported for sub-packages - // hooks are used to implement dynamic dispatch (polymorphism) - // When DeepSeekClient embeds Client, hooks point to DeepSeekClient - // This way methods called in call() are automatically dispatched to the overridden version in subclass - hooks clientHooks + // Hooks are used to implement dynamic dispatch (polymorphism) + // When provider.DeepSeekClient embeds Client, Hooks point to DeepSeekClient + // This way methods called in Call() are automatically dispatched to the overridden version + Hooks ClientHooks } // New creates default client (backward compatible) @@ -80,21 +80,22 @@ func New() AIClient { // NewClient creates client (supports options pattern) // // Usage examples: -// // Basic usage (backward compatible) -// client := mcp.NewClient() // -// // Custom logger -// client := mcp.NewClient(mcp.WithLogger(customLogger)) +// // Basic usage (backward compatible) +// client := mcp.NewClient() // -// // Custom timeout -// client := mcp.NewClient(mcp.WithTimeout(60*time.Second)) +// // Custom logger +// client := mcp.NewClient(mcp.WithLogger(customLogger)) // -// // Combine multiple options -// client := mcp.NewClient( -// mcp.WithDeepSeekConfig("sk-xxx"), -// mcp.WithLogger(customLogger), -// mcp.WithTimeout(60*time.Second), -// ) +// // Custom timeout +// client := mcp.NewClient(mcp.WithTimeout(60*time.Second)) +// +// // Combine multiple options +// client := mcp.NewClient( +// mcp.WithDeepSeekConfig("sk-xxx"), +// mcp.WithLogger(customLogger), +// mcp.WithTimeout(60*time.Second), +// ) func NewClient(opts ...ClientOption) AIClient { // 1. Create default config cfg := DefaultConfig() @@ -112,9 +113,9 @@ func NewClient(opts ...ClientOption) AIClient { Model: cfg.Model, MaxTokens: cfg.MaxTokens, UseFullURL: cfg.UseFullURL, - httpClient: cfg.HTTPClient, - logger: cfg.Logger, - config: cfg, + HTTPClient: cfg.HTTPClient, + Log: cfg.Logger, + Cfg: cfg, } // 4. Set default Provider (if not set) @@ -125,7 +126,7 @@ func NewClient(opts ...ClientOption) AIClient { } // 5. Set hooks to point to self - client.hooks = client + client.Hooks = client return client } @@ -148,7 +149,7 @@ func (client *Client) SetAPIKey(apiKey, apiURL, customModel string) { } func (client *Client) SetTimeout(timeout time.Duration) { - client.httpClient.Timeout = timeout + client.HTTPClient.Timeout = timeout } // CallWithMessages template method - fixed retry flow (cannot be overridden) @@ -159,32 +160,32 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, // Fixed retry flow var lastErr error - maxRetries := client.config.MaxRetries + maxRetries := client.Cfg.MaxRetries for attempt := 1; attempt <= maxRetries; attempt++ { if attempt > 1 { - client.logger.Warnf("โš ๏ธ AI API call failed, retrying (%d/%d)...", attempt, maxRetries) + client.Log.Warnf("โš ๏ธ AI API call failed, retrying (%d/%d)...", attempt, maxRetries) } // Call the fixed single-call flow - result, err := client.hooks.call(systemPrompt, userPrompt) + result, err := client.Hooks.Call(systemPrompt, userPrompt) if err == nil { if attempt > 1 { - client.logger.Infof("โœ“ AI API retry succeeded") + client.Log.Infof("โœ“ AI API retry succeeded") } return result, nil } lastErr = err - // Check if error is retryable via hooks (supports custom retry strategy in subclass) - if !client.hooks.isRetryableError(err) { + // Check if error is retryable via hooks (supports custom retry strategy) + if !client.Hooks.IsRetryableError(err) { return "", err } // Wait before retry if attempt < maxRetries { - waitTime := client.config.RetryWaitBase * time.Duration(attempt) - client.logger.Infof("โณ Waiting %v before retry...", waitTime) + waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt) + client.Log.Infof("โณ Waiting %v before retry...", waitTime) time.Sleep(waitTime) } } @@ -192,11 +193,11 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr) } -func (client *Client) setAuthHeader(reqHeader http.Header) { +func (client *Client) SetAuthHeader(reqHeader http.Header) { reqHeader.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey)) } -func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any { +func (client *Client) BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any { // Build messages array messages := []map[string]string{} @@ -217,7 +218,7 @@ func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[s requestBody := map[string]interface{}{ "model": client.Model, "messages": messages, - "temperature": client.config.Temperature, // Use configured temperature + "temperature": client.Cfg.Temperature, // Use configured temperature } // OpenAI newer models use max_completion_tokens instead of max_tokens if client.Provider == ProviderOpenAI { @@ -228,8 +229,8 @@ func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[s return requestBody } -// can be used to marshal the request body and can be overridden -func (client *Client) marshalRequestBody(requestBody map[string]any) ([]byte, error) { +// MarshalRequestBody can be used to marshal the request body and can be overridden +func (client *Client) MarshalRequestBody(requestBody map[string]any) ([]byte, error) { jsonData, err := json.Marshal(requestBody) if err != nil { return nil, fmt.Errorf("failed to serialize request: %w", err) @@ -237,17 +238,17 @@ func (client *Client) marshalRequestBody(requestBody map[string]any) ([]byte, er return jsonData, nil } -func (client *Client) parseMCPResponse(body []byte) (string, error) { - r, err := client.parseMCPResponseFull(body) +func (client *Client) ParseMCPResponse(body []byte) (string, error) { + r, err := client.ParseMCPResponseFull(body) if err != nil { return "", err } return r.Content, nil } -// parseMCPResponseFull parses the OpenAI-format response body and returns both +// ParseMCPResponseFull parses the OpenAI-format response body and returns both // the text content and any tool calls. -func (client *Client) parseMCPResponseFull(body []byte) (*LLMResponse, error) { +func (client *Client) ParseMCPResponseFull(body []byte) (*LLMResponse, error) { var result struct { Choices []struct { Message struct { @@ -288,14 +289,14 @@ func (client *Client) parseMCPResponseFull(body []byte) (*LLMResponse, error) { }, nil } -func (client *Client) buildUrl() string { +func (client *Client) BuildUrl() string { if client.UseFullURL { return client.BaseURL } return fmt.Sprintf("%s/chat/completions", client.BaseURL) } -func (client *Client) buildRequest(url string, jsonData []byte) (*http.Request, error) { +func (client *Client) BuildRequest(url string, jsonData []byte) (*http.Request, error) { // Create HTTP request req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil { @@ -304,42 +305,42 @@ func (client *Client) buildRequest(url string, jsonData []byte) (*http.Request, req.Header.Set("Content-Type", "application/json") - // Set auth header via hooks (supports overriding in subclass) - client.hooks.setAuthHeader(req.Header) + // Set auth header via hooks (supports overriding) + client.Hooks.SetAuthHeader(req.Header) return req, nil } -// call single AI API call (fixed flow, cannot be overridden) -func (client *Client) call(systemPrompt, userPrompt string) (string, error) { +// Call single AI API call (fixed flow, cannot be overridden) +func (client *Client) Call(systemPrompt, userPrompt string) (string, error) { // Print current AI configuration - client.logger.Infof("๐Ÿ“ก [%s] Request AI Server: BaseURL: %s", client.String(), client.BaseURL) - client.logger.Debugf("[%s] UseFullURL: %v", client.String(), client.UseFullURL) + client.Log.Infof("๐Ÿ“ก [%s] Request AI Server: BaseURL: %s", client.String(), client.BaseURL) + client.Log.Debugf("[%s] UseFullURL: %v", client.String(), client.UseFullURL) if len(client.APIKey) > 8 { - client.logger.Debugf("[%s] API Key: %s...%s", client.String(), client.APIKey[:4], client.APIKey[len(client.APIKey)-4:]) + client.Log.Debugf("[%s] API Key: %s...%s", client.String(), client.APIKey[:4], client.APIKey[len(client.APIKey)-4:]) } // Step 1: Build request body (via hooks for dynamic dispatch) - requestBody := client.hooks.buildMCPRequestBody(systemPrompt, userPrompt) + requestBody := client.Hooks.BuildMCPRequestBody(systemPrompt, userPrompt) // Step 2: Serialize request body (via hooks for dynamic dispatch) - jsonData, err := client.hooks.marshalRequestBody(requestBody) + jsonData, err := client.Hooks.MarshalRequestBody(requestBody) if err != nil { return "", err } // Step 3: Build URL (via hooks for dynamic dispatch) - url := client.hooks.buildUrl() - client.logger.Infof("๐Ÿ“ก [MCP %s] Request URL: %s", client.String(), url) + url := client.Hooks.BuildUrl() + client.Log.Infof("๐Ÿ“ก [MCP %s] Request URL: %s", client.String(), url) // Step 4: Create HTTP request (fixed logic) - req, err := client.hooks.buildRequest(url, jsonData) + req, err := client.Hooks.BuildRequest(url, jsonData) if err != nil { return "", fmt.Errorf("failed to create request: %w", err) } // Step 5: Send HTTP request (fixed logic) - resp, err := client.httpClient.Do(req) + resp, err := client.HTTPClient.Do(req) if err != nil { return "", fmt.Errorf("failed to send request: %w", err) } @@ -357,7 +358,7 @@ func (client *Client) call(systemPrompt, userPrompt string) (string, error) { } // Step 8: Parse response (via hooks for dynamic dispatch) - result, err := client.hooks.parseMCPResponse(body) + result, err := client.Hooks.ParseMCPResponse(body) if err != nil { return "", fmt.Errorf("fail to parse AI server response: %w", err) } @@ -370,11 +371,11 @@ func (client *Client) String() string { client.Provider, client.Model) } -// isRetryableError determines if error is retryable (network errors, timeouts, etc.) -func (client *Client) isRetryableError(err error) bool { +// IsRetryableError determines if error is retryable (network errors, timeouts, etc.) +func (client *Client) IsRetryableError(err error) bool { errStr := err.Error() // Network errors, timeouts, EOF, etc. can be retried - for _, retryable := range client.config.RetryableErrors { + for _, retryable := range client.Cfg.RetryableErrors { if strings.Contains(errStr, retryable) { return true } @@ -387,20 +388,6 @@ func (client *Client) isRetryableError(err error) bool { // ============================================================ // CallWithRequest calls AI API using Request object (supports advanced features) -// -// This method supports: -// - Multi-turn conversation history -// - Fine-grained parameter control (temperature, top_p, penalties, etc.) -// - Function Calling / Tools -// - Streaming response (future support) -// -// Usage example: -// request := NewRequestBuilder(). -// WithSystemPrompt("You are helpful"). -// WithUserPrompt("Hello"). -// WithTemperature(0.8). -// Build() -// result, err := client.CallWithRequest(request) func (client *Client) CallWithRequest(req *Request) (string, error) { if client.APIKey == "" { return "", fmt.Errorf("AI API key not set, please call SetAPIKey first") @@ -413,32 +400,32 @@ func (client *Client) CallWithRequest(req *Request) (string, error) { // Fixed retry flow var lastErr error - maxRetries := client.config.MaxRetries + maxRetries := client.Cfg.MaxRetries for attempt := 1; attempt <= maxRetries; attempt++ { if attempt > 1 { - client.logger.Warnf("โš ๏ธ AI API call failed, retrying (%d/%d)...", attempt, maxRetries) + client.Log.Warnf("โš ๏ธ AI API call failed, retrying (%d/%d)...", attempt, maxRetries) } // Call single request result, err := client.callWithRequest(req) if err == nil { if attempt > 1 { - client.logger.Infof("โœ“ AI API retry succeeded") + client.Log.Infof("โœ“ AI API retry succeeded") } return result, nil } lastErr = err // Check if error is retryable - if !client.hooks.isRetryableError(err) { + if !client.Hooks.IsRetryableError(err) { return "", err } // Wait before retry if attempt < maxRetries { - waitTime := client.config.RetryWaitBase * time.Duration(attempt) - client.logger.Infof("โณ Waiting %v before retry...", waitTime) + waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt) + client.Log.Infof("โณ Waiting %v before retry...", waitTime) time.Sleep(waitTime) } } @@ -456,21 +443,21 @@ func (client *Client) CallWithRequestFull(req *Request) (*LLMResponse, error) { } var lastErr error - maxRetries := client.config.MaxRetries + maxRetries := client.Cfg.MaxRetries for attempt := 1; attempt <= maxRetries; attempt++ { if attempt > 1 { - client.logger.Warnf("โš ๏ธ AI API call failed, retrying (%d/%d)...", attempt, maxRetries) + client.Log.Warnf("โš ๏ธ AI API call failed, retrying (%d/%d)...", attempt, maxRetries) } result, err := client.callWithRequestFull(req) if err == nil { return result, nil } lastErr = err - if !client.hooks.isRetryableError(err) { + if !client.Hooks.IsRetryableError(err) { return nil, err } if attempt < maxRetries { - waitTime := client.config.RetryWaitBase * time.Duration(attempt) + waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt) time.Sleep(waitTime) } } @@ -479,21 +466,21 @@ func (client *Client) CallWithRequestFull(req *Request) (*LLMResponse, error) { // callWithRequestFull single call that returns LLMResponse (content + tool calls). func (client *Client) callWithRequestFull(req *Request) (*LLMResponse, error) { - client.logger.Infof("๐Ÿ“ก [%s] Request AI Server (full): BaseURL: %s", client.String(), client.BaseURL) + client.Log.Infof("๐Ÿ“ก [%s] Request AI Server (full): BaseURL: %s", client.String(), client.BaseURL) - requestBody := client.hooks.buildRequestBodyFromRequest(req) - jsonData, err := client.hooks.marshalRequestBody(requestBody) + requestBody := client.Hooks.BuildRequestBodyFromRequest(req) + jsonData, err := client.Hooks.MarshalRequestBody(requestBody) if err != nil { return nil, err } - url := client.hooks.buildUrl() - httpReq, err := client.hooks.buildRequest(url, jsonData) + url := client.Hooks.BuildUrl() + httpReq, err := client.Hooks.BuildRequest(url, jsonData) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - resp, err := client.httpClient.Do(httpReq) + resp, err := client.HTTPClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } @@ -507,31 +494,31 @@ func (client *Client) callWithRequestFull(req *Request) (*LLMResponse, error) { return nil, fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body)) } - return client.hooks.parseMCPResponseFull(body) + return client.Hooks.ParseMCPResponseFull(body) } // callWithRequest single AI API call (using Request object) func (client *Client) callWithRequest(req *Request) (string, error) { // Print current AI configuration - client.logger.Infof("๐Ÿ“ก [%s] Request AI Server with Builder: BaseURL: %s", client.String(), client.BaseURL) - client.logger.Debugf("[%s] Messages count: %d", client.String(), len(req.Messages)) + client.Log.Infof("๐Ÿ“ก [%s] Request AI Server with Builder: BaseURL: %s", client.String(), client.BaseURL) + client.Log.Debugf("[%s] Messages count: %d", client.String(), len(req.Messages)) - requestBody := client.hooks.buildRequestBodyFromRequest(req) + requestBody := client.Hooks.BuildRequestBodyFromRequest(req) - jsonData, err := client.hooks.marshalRequestBody(requestBody) + jsonData, err := client.Hooks.MarshalRequestBody(requestBody) if err != nil { return "", err } - url := client.hooks.buildUrl() - client.logger.Infof("๐Ÿ“ก [MCP %s] Request URL: %s", client.String(), url) + url := client.Hooks.BuildUrl() + client.Log.Infof("๐Ÿ“ก [MCP %s] Request URL: %s", client.String(), url) - httpReq, err := client.hooks.buildRequest(url, jsonData) + httpReq, err := client.Hooks.BuildRequest(url, jsonData) if err != nil { return "", fmt.Errorf("failed to create request: %w", err) } - resp, err := client.httpClient.Do(httpReq) + resp, err := client.HTTPClient.Do(httpReq) if err != nil { return "", fmt.Errorf("failed to send request: %w", err) } @@ -546,7 +533,7 @@ func (client *Client) callWithRequest(req *Request) (string, error) { return "", fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body)) } - result, err := client.hooks.parseMCPResponse(body) + result, err := client.Hooks.ParseMCPResponse(body) if err != nil { return "", fmt.Errorf("fail to parse AI server response: %w", err) } @@ -554,8 +541,8 @@ func (client *Client) callWithRequest(req *Request) (string, error) { return result, nil } -// buildRequestBodyFromRequest builds request body from Request object -func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any { +// BuildRequestBodyFromRequest builds request body from Request object +func (client *Client) BuildRequestBodyFromRequest(req *Request) map[string]any { // Convert Message to API format โ€” must use map[string]any to support // tool-call messages (tool_calls, tool_call_id fields). messages := make([]map[string]any, 0, len(req.Messages)) @@ -586,7 +573,7 @@ func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any { requestBody["temperature"] = *req.Temperature } else { // If not set in Request, use Client's configuration - requestBody["temperature"] = client.config.Temperature + requestBody["temperature"] = client.Cfg.Temperature } // OpenAI newer models use max_completion_tokens instead of max_tokens @@ -647,19 +634,19 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string)) } req.Stream = true - requestBody := client.hooks.buildRequestBodyFromRequest(req) - jsonData, err := client.hooks.marshalRequestBody(requestBody) + requestBody := client.Hooks.BuildRequestBodyFromRequest(req) + jsonData, err := client.Hooks.MarshalRequestBody(requestBody) if err != nil { return "", err } - url := client.hooks.buildUrl() - httpReq, err := client.hooks.buildRequest(url, jsonData) + url := client.Hooks.BuildUrl() + httpReq, err := client.Hooks.BuildRequest(url, jsonData) if err != nil { return "", err } - // Idle-timeout watchdog: cancel the request if no SSE line arrives for 30 seconds. + // Idle-timeout watchdog: cancel the request if no SSE line arrives for 60 seconds. // This breaks the scanner out of an indefinitely blocking Read on a hung connection. const idleTimeout = 60 * time.Second ctx, cancel := context.WithCancel(context.Background()) @@ -689,7 +676,7 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string)) }() httpReq = httpReq.WithContext(ctx) - resp, err := client.httpClient.Do(httpReq) + resp, err := client.HTTPClient.Do(httpReq) if err != nil { return "", fmt.Errorf("streaming request failed: %w", err) } diff --git a/mcp/client_test.go b/mcp/client_test.go index 4af38703..b76890dd 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -27,16 +27,16 @@ func TestNewClient_Default(t *testing.T) { t.Error("MaxTokens should be positive") } - if c.logger == nil { - t.Error("logger should not be nil") + if c.Log == nil { + t.Error("Log should not be nil") } - if c.httpClient == nil { - t.Error("httpClient should not be nil") + if c.HTTPClient == nil { + t.Error("HTTPClient should not be nil") } - if c.hooks == nil { - t.Error("hooks should not be nil") + if c.Hooks == nil { + t.Error("Hooks should not be nil") } } @@ -54,12 +54,12 @@ func TestNewClient_WithOptions(t *testing.T) { c := client.(*Client) - if c.logger != mockLogger { - t.Error("logger should be set from option") + if c.Log != mockLogger { + t.Error("Log should be set from option") } - if c.httpClient != mockHTTP { - t.Error("httpClient should be set from option") + if c.HTTPClient != mockHTTP { + t.Error("HTTPClient should be set from option") } if c.MaxTokens != 4000 { @@ -174,7 +174,7 @@ func TestClient_Retry_Success(t *testing.T) { WithMaxRetries(3), ) - // Since our client uses hooks.call, need special handling + // Since our client uses Hooks.Call, need special handling // Here we test that CallWithMessages will invoke retry logic c := client.(*Client) @@ -242,7 +242,7 @@ func TestClient_BuildMCPRequestBody(t *testing.T) { client := NewClient() c := client.(*Client) - body := c.buildMCPRequestBody("system prompt", "user prompt") + body := c.BuildMCPRequestBody("system prompt", "user prompt") if body == nil { t.Fatal("body should not be nil") @@ -300,7 +300,7 @@ func TestClient_BuildUrl(t *testing.T) { ) c := client.(*Client) - url := c.buildUrl() + url := c.BuildUrl() if url != tt.expected { t.Errorf("expected '%s', got '%s'", tt.expected, url) } @@ -313,7 +313,7 @@ func TestClient_SetAuthHeader(t *testing.T) { c := client.(*Client) headers := make(http.Header) - c.setAuthHeader(headers) + c.SetAuthHeader(headers) authHeader := headers.Get("Authorization") if authHeader != "Bearer test-api-key" { @@ -359,7 +359,7 @@ func TestClient_IsRetryableError(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := c.isRetryableError(tt.err) + result := c.IsRetryableError(tt.err) if result != tt.expected { t.Errorf("expected %v, got %v", tt.expected, result) } @@ -378,8 +378,8 @@ func TestClient_SetTimeout(t *testing.T) { client.SetTimeout(newTimeout) c := client.(*Client) - if c.httpClient.Timeout != newTimeout { - t.Errorf("expected timeout %v, got %v", newTimeout, c.httpClient.Timeout) + if c.HTTPClient.Timeout != newTimeout { + t.Errorf("expected timeout %v, got %v", newTimeout, c.HTTPClient.Timeout) } } diff --git a/mcp/config_usage_test.go b/mcp/config_usage_test.go index c0e984f8..8186506d 100644 --- a/mcp/config_usage_test.go +++ b/mcp/config_usage_test.go @@ -84,7 +84,7 @@ func TestConfig_Temperature_IsUsed(t *testing.T) { c := client.(*Client) // Build request body - requestBody := c.buildMCPRequestBody("system", "user") + requestBody := c.BuildMCPRequestBody("system", "user") // Verify temperature field temp, ok := requestBody["temperature"].(float64) @@ -201,7 +201,7 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) { c := client.(*Client) // Modify config's RetryableErrors (no WithRetryableErrors option yet) - c.config.RetryableErrors = customRetryableErrors + c.Cfg.RetryableErrors = customRetryableErrors tests := []struct { name string @@ -227,7 +227,7 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := c.isRetryableError(tt.err) + result := c.IsRetryableError(tt.err) if result != tt.retryable { t.Errorf("expected isRetryableError(%v) = %v, got %v", tt.err, tt.retryable, result) } @@ -244,19 +244,19 @@ func TestConfig_DefaultValues(t *testing.T) { c := client.(*Client) // Verify default values - if c.config.MaxRetries != 3 { - t.Errorf("default MaxRetries should be 3, got %d", c.config.MaxRetries) + if c.Cfg.MaxRetries != 3 { + t.Errorf("default MaxRetries should be 3, got %d", c.Cfg.MaxRetries) } - if c.config.Temperature != 0.5 { - t.Errorf("default Temperature should be 0.5, got %f", c.config.Temperature) + if c.Cfg.Temperature != 0.5 { + t.Errorf("default Temperature should be 0.5, got %f", c.Cfg.Temperature) } - if c.config.RetryWaitBase != 2*time.Second { - t.Errorf("default RetryWaitBase should be 2s, got %v", c.config.RetryWaitBase) + if c.Cfg.RetryWaitBase != 2*time.Second { + t.Errorf("default RetryWaitBase should be 2s, got %v", c.Cfg.RetryWaitBase) } - if len(c.config.RetryableErrors) == 0 { + if len(c.Cfg.RetryableErrors) == 0 { t.Error("default RetryableErrors should not be empty") } } diff --git a/mcp/deepseek_client.go b/mcp/deepseek_client.go deleted file mode 100644 index 5972c806..00000000 --- a/mcp/deepseek_client.go +++ /dev/null @@ -1,83 +0,0 @@ -package mcp - -import ( - "net/http" -) - -const ( - ProviderDeepSeek = "deepseek" - DefaultDeepSeekBaseURL = "https://api.deepseek.com" - DefaultDeepSeekModel = "deepseek-chat" -) - -type DeepSeekClient struct { - *Client -} - -// NewDeepSeekClient creates DeepSeek client (backward compatible) -// -// Deprecated: Recommend using NewDeepSeekClientWithOptions for better flexibility -func NewDeepSeekClient() AIClient { - return NewDeepSeekClientWithOptions() -} - -// NewDeepSeekClientWithOptions creates DeepSeek client (supports options pattern) -// -// Usage examples: -// // Basic usage -// client := mcp.NewDeepSeekClientWithOptions() -// -// // Custom configuration -// client := mcp.NewDeepSeekClientWithOptions( -// mcp.WithAPIKey("sk-xxx"), -// mcp.WithLogger(customLogger), -// mcp.WithTimeout(60*time.Second), -// ) -func NewDeepSeekClientWithOptions(opts ...ClientOption) AIClient { - // 1. Create DeepSeek preset options - deepseekOpts := []ClientOption{ - WithProvider(ProviderDeepSeek), - WithModel(DefaultDeepSeekModel), - WithBaseURL(DefaultDeepSeekBaseURL), - } - - // 2. Merge user options (user options have higher priority) - allOpts := append(deepseekOpts, opts...) - - // 3. Create base client - baseClient := NewClient(allOpts...).(*Client) - - // 4. Create DeepSeek client - dsClient := &DeepSeekClient{ - Client: baseClient, - } - - // 5. Set hooks to point to DeepSeekClient (implement dynamic dispatch) - baseClient.hooks = dsClient - - return dsClient -} - -func (dsClient *DeepSeekClient) SetAPIKey(apiKey string, customURL string, customModel string) { - dsClient.APIKey = apiKey - - if len(apiKey) > 8 { - dsClient.logger.Infof("๐Ÿ”ง [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) - } - if customURL != "" { - dsClient.BaseURL = customURL - dsClient.logger.Infof("๐Ÿ”ง [MCP] DeepSeek using custom BaseURL: %s", customURL) - } else { - dsClient.logger.Infof("๐Ÿ”ง [MCP] DeepSeek using default BaseURL: %s", dsClient.BaseURL) - } - if customModel != "" { - dsClient.Model = customModel - dsClient.logger.Infof("๐Ÿ”ง [MCP] DeepSeek using custom Model: %s", customModel) - } else { - dsClient.logger.Infof("๐Ÿ”ง [MCP] DeepSeek using default Model: %s", dsClient.Model) - } -} - -func (dsClient *DeepSeekClient) setAuthHeader(reqHeaders http.Header) { - dsClient.Client.setAuthHeader(reqHeaders) -} diff --git a/mcp/deepseek_client_test.go b/mcp/deepseek_client_test.go deleted file mode 100644 index afcd3b81..00000000 --- a/mcp/deepseek_client_test.go +++ /dev/null @@ -1,272 +0,0 @@ -package mcp - -import ( - "testing" - "time" -) - -// ============================================================ -// Test DeepSeekClient Creation and Configuration -// ============================================================ - -func TestNewDeepSeekClient_Default(t *testing.T) { - client := NewDeepSeekClient() - - if client == nil { - t.Fatal("client should not be nil") - } - - // Type assertion check - dsClient, ok := client.(*DeepSeekClient) - if !ok { - t.Fatal("client should be *DeepSeekClient") - } - - // Verify default values - if dsClient.Provider != ProviderDeepSeek { - t.Errorf("Provider should be '%s', got '%s'", ProviderDeepSeek, dsClient.Provider) - } - - if dsClient.BaseURL != DefaultDeepSeekBaseURL { - t.Errorf("BaseURL should be '%s', got '%s'", DefaultDeepSeekBaseURL, dsClient.BaseURL) - } - - if dsClient.Model != DefaultDeepSeekModel { - t.Errorf("Model should be '%s', got '%s'", DefaultDeepSeekModel, dsClient.Model) - } - - if dsClient.logger == nil { - t.Error("logger should not be nil") - } - - if dsClient.httpClient == nil { - t.Error("httpClient should not be nil") - } -} - -func TestNewDeepSeekClientWithOptions(t *testing.T) { - mockLogger := NewMockLogger() - customModel := "deepseek-v2" - customAPIKey := "sk-custom-key" - - client := NewDeepSeekClientWithOptions( - WithLogger(mockLogger), - WithModel(customModel), - WithAPIKey(customAPIKey), - WithMaxTokens(4000), - ) - - dsClient := client.(*DeepSeekClient) - - // Verify custom options are applied - if dsClient.logger != mockLogger { - t.Error("logger should be set from option") - } - - if dsClient.Model != customModel { - t.Error("Model should be set from option") - } - - if dsClient.APIKey != customAPIKey { - t.Error("APIKey should be set from option") - } - - if dsClient.MaxTokens != 4000 { - t.Error("MaxTokens should be 4000") - } - - // Verify DeepSeek default values are retained - if dsClient.Provider != ProviderDeepSeek { - t.Errorf("Provider should still be '%s'", ProviderDeepSeek) - } - - if dsClient.BaseURL != DefaultDeepSeekBaseURL { - t.Errorf("BaseURL should still be '%s'", DefaultDeepSeekBaseURL) - } -} - -// ============================================================ -// Test SetAPIKey -// ============================================================ - -func TestDeepSeekClient_SetAPIKey(t *testing.T) { - mockLogger := NewMockLogger() - client := NewDeepSeekClientWithOptions( - WithLogger(mockLogger), - ) - - dsClient := client.(*DeepSeekClient) - - // Test setting API Key (default URL and Model) - dsClient.SetAPIKey("sk-test-key-12345678", "", "") - - if dsClient.APIKey != "sk-test-key-12345678" { - t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", dsClient.APIKey) - } - - // Verify logging - logs := mockLogger.GetLogsByLevel("INFO") - if len(logs) == 0 { - t.Error("should have logged API key setting") - } - - // Verify BaseURL and Model remain default - if dsClient.BaseURL != DefaultDeepSeekBaseURL { - t.Error("BaseURL should remain default") - } - - if dsClient.Model != DefaultDeepSeekModel { - t.Error("Model should remain default") - } -} - -func TestDeepSeekClient_SetAPIKey_WithCustomURL(t *testing.T) { - mockLogger := NewMockLogger() - client := NewDeepSeekClientWithOptions( - WithLogger(mockLogger), - ) - - dsClient := client.(*DeepSeekClient) - - customURL := "https://custom.api.com/v1" - dsClient.SetAPIKey("sk-test-key-12345678", customURL, "") - - if dsClient.BaseURL != customURL { - t.Errorf("BaseURL should be '%s', got '%s'", customURL, dsClient.BaseURL) - } - - // Verify logging - logs := mockLogger.GetLogsByLevel("INFO") - hasCustomURLLog := false - for _, log := range logs { - if log.Format == "๐Ÿ”ง [MCP] DeepSeek using custom BaseURL: %s" { - hasCustomURLLog = true - break - } - } - - if !hasCustomURLLog { - t.Error("should have logged custom BaseURL") - } -} - -func TestDeepSeekClient_SetAPIKey_WithCustomModel(t *testing.T) { - mockLogger := NewMockLogger() - client := NewDeepSeekClientWithOptions( - WithLogger(mockLogger), - ) - - dsClient := client.(*DeepSeekClient) - - customModel := "deepseek-v3" - dsClient.SetAPIKey("sk-test-key-12345678", "", customModel) - - if dsClient.Model != customModel { - t.Errorf("Model should be '%s', got '%s'", customModel, dsClient.Model) - } - - // Verify logging - logs := mockLogger.GetLogsByLevel("INFO") - hasCustomModelLog := false - for _, log := range logs { - if log.Format == "๐Ÿ”ง [MCP] DeepSeek using custom Model: %s" { - hasCustomModelLog = true - break - } - } - - if !hasCustomModelLog { - t.Error("should have logged custom Model") - } -} - -// ============================================================ -// Test Integration Features -// ============================================================ - -func TestDeepSeekClient_CallWithMessages_Success(t *testing.T) { - mockHTTP := NewMockHTTPClient() - mockHTTP.SetSuccessResponse("DeepSeek AI response") - mockLogger := NewMockLogger() - - client := NewDeepSeekClientWithOptions( - WithHTTPClient(mockHTTP.ToHTTPClient()), - WithLogger(mockLogger), - WithAPIKey("sk-test-key"), - ) - - result, err := client.CallWithMessages("system prompt", "user prompt") - - if err != nil { - t.Fatalf("should not error: %v", err) - } - - if result != "DeepSeek AI response" { - t.Errorf("expected 'DeepSeek AI response', got '%s'", result) - } - - // Verify request - requests := mockHTTP.GetRequests() - if len(requests) != 1 { - t.Fatalf("expected 1 request, got %d", len(requests)) - } - - req := requests[0] - - // Verify URL - expectedURL := DefaultDeepSeekBaseURL + "/chat/completions" - if req.URL.String() != expectedURL { - t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String()) - } - - // Verify Authorization header - authHeader := req.Header.Get("Authorization") - if authHeader != "Bearer sk-test-key" { - t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader) - } - - // Verify Content-Type - if req.Header.Get("Content-Type") != "application/json" { - t.Error("Content-Type should be application/json") - } -} - -func TestDeepSeekClient_Timeout(t *testing.T) { - client := NewDeepSeekClientWithOptions( - WithTimeout(30 * time.Second), - ) - - dsClient := client.(*DeepSeekClient) - - if dsClient.httpClient.Timeout != 30*time.Second { - t.Errorf("expected timeout 30s, got %v", dsClient.httpClient.Timeout) - } - - // Test SetTimeout - client.SetTimeout(60 * time.Second) - - if dsClient.httpClient.Timeout != 60*time.Second { - t.Errorf("expected timeout 60s after SetTimeout, got %v", dsClient.httpClient.Timeout) - } -} - -// ============================================================ -// Test hooks Mechanism -// ============================================================ - -func TestDeepSeekClient_HooksIntegration(t *testing.T) { - client := NewDeepSeekClientWithOptions() - dsClient := client.(*DeepSeekClient) - - // Verify hooks point to dsClient itself (implements polymorphism) - if dsClient.hooks != dsClient { - t.Error("hooks should point to dsClient for polymorphism") - } - - // Verify buildUrl uses DeepSeek configuration - url := dsClient.buildUrl() - expectedURL := DefaultDeepSeekBaseURL + "/chat/completions" - if url != expectedURL { - t.Errorf("expected URL '%s', got '%s'", expectedURL, url) - } -} diff --git a/mcp/examples_test.go b/mcp/examples_test.go index 78125421..36888007 100644 --- a/mcp/examples_test.go +++ b/mcp/examples_test.go @@ -6,6 +6,7 @@ import ( "time" "nofx/mcp" + "nofx/mcp/provider" ) // ============================================================ @@ -24,7 +25,7 @@ func Example_backward_compatible() { func Example_deepseek_backward_compatible() { // DeepSeek old code continues to work - client := mcp.NewDeepSeekClient() + client := provider.NewDeepSeekClient() client.SetAPIKey("sk-xxx", "", "") result, _ := client.CallWithMessages("system", "user") @@ -141,12 +142,12 @@ func Example_custom_http_client() { func Example_deepseek_new_api() { // Basic usage - client := mcp.NewDeepSeekClientWithOptions( + client := provider.NewDeepSeekClientWithOptions( mcp.WithAPIKey("sk-xxx"), ) // Advanced usage - client = mcp.NewDeepSeekClientWithOptions( + client = provider.NewDeepSeekClientWithOptions( mcp.WithAPIKey("sk-xxx"), mcp.WithLogger(&CustomLogger{}), mcp.WithTimeout(90*time.Second), @@ -163,12 +164,12 @@ func Example_deepseek_new_api() { func Example_qwen_new_api() { // Basic usage - client := mcp.NewQwenClientWithOptions( + client := provider.NewQwenClientWithOptions( mcp.WithAPIKey("sk-xxx"), ) // Advanced usage - client = mcp.NewQwenClientWithOptions( + client = provider.NewQwenClientWithOptions( mcp.WithAPIKey("sk-xxx"), mcp.WithLogger(&CustomLogger{}), mcp.WithTimeout(90*time.Second), @@ -185,7 +186,7 @@ func Example_qwen_new_api() { func Example_trader_migration() { // Old code (continues to work) oldStyleClient := func(apiKey, customURL, customModel string) mcp.AIClient { - client := mcp.NewDeepSeekClient() + client := provider.NewDeepSeekClient() client.SetAPIKey(apiKey, customURL, customModel) return client } @@ -204,7 +205,7 @@ func Example_trader_migration() { opts = append(opts, mcp.WithModel(customModel)) } - return mcp.NewDeepSeekClientWithOptions(opts...) + return provider.NewDeepSeekClientWithOptions(opts...) } // Both approaches work @@ -230,13 +231,7 @@ func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) { } func Example_testing_with_mock() { - // Use Mock during testing - // mockHTTP := &MockHTTPClient{ - // Response: `{"choices":[{"message":{"content":"test response"}}]}`, - // } - client := mcp.NewClient( - // mcp.WithHTTPClient(mockHTTP), // Use mockHTTP in actual tests mcp.WithLogger(mcp.NewNoopLogger()), // Disable logging ) @@ -258,7 +253,6 @@ func Example_environment_specific() { // Production environment: structured logging + timeout protection prodClient := mcp.NewClient( mcp.WithDeepSeekConfig("sk-xxx"), - // mcp.WithLogger(&ZapLogger{}), // Production-grade logging mcp.WithTimeout(30*time.Second), mcp.WithMaxRetries(3), ) @@ -273,7 +267,7 @@ func Example_environment_specific() { func Example_real_world_usage() { // Create client with complete configuration - client := mcp.NewDeepSeekClientWithOptions( + client := provider.NewDeepSeekClientWithOptions( mcp.WithAPIKey("sk-xxxxxxxxxx"), mcp.WithTimeout(60*time.Second), mcp.WithMaxRetries(5), diff --git a/mcp/gemini_client.go b/mcp/gemini_client.go deleted file mode 100644 index 43b469a4..00000000 --- a/mcp/gemini_client.go +++ /dev/null @@ -1,71 +0,0 @@ -package mcp - -import ( - "net/http" -) - -const ( - ProviderGemini = "gemini" - DefaultGeminiBaseURL = "https://generativelanguage.googleapis.com/v1beta/openai" - DefaultGeminiModel = "gemini-3-pro-preview" -) - -type GeminiClient struct { - *Client -} - -// NewGeminiClient creates Gemini client (backward compatible) -func NewGeminiClient() AIClient { - return NewGeminiClientWithOptions() -} - -// NewGeminiClientWithOptions creates Gemini client (supports options pattern) -func NewGeminiClientWithOptions(opts ...ClientOption) AIClient { - // 1. Create Gemini preset options - geminiOpts := []ClientOption{ - WithProvider(ProviderGemini), - WithModel(DefaultGeminiModel), - WithBaseURL(DefaultGeminiBaseURL), - } - - // 2. Merge user options (user options have higher priority) - allOpts := append(geminiOpts, opts...) - - // 3. Create base client - baseClient := NewClient(allOpts...).(*Client) - - // 4. Create Gemini client - geminiClient := &GeminiClient{ - Client: baseClient, - } - - // 5. Set hooks to point to GeminiClient (implement dynamic dispatch) - baseClient.hooks = geminiClient - - return geminiClient -} - -func (c *GeminiClient) SetAPIKey(apiKey string, customURL string, customModel string) { - c.APIKey = apiKey - - if len(apiKey) > 8 { - c.logger.Infof("๐Ÿ”ง [MCP] Gemini API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) - } - if customURL != "" { - c.BaseURL = customURL - c.logger.Infof("๐Ÿ”ง [MCP] Gemini using custom BaseURL: %s", customURL) - } else { - c.logger.Infof("๐Ÿ”ง [MCP] Gemini using default BaseURL: %s", c.BaseURL) - } - if customModel != "" { - c.Model = customModel - c.logger.Infof("๐Ÿ”ง [MCP] Gemini using custom Model: %s", customModel) - } else { - c.logger.Infof("๐Ÿ”ง [MCP] Gemini using default Model: %s", c.Model) - } -} - -// Gemini OpenAI-compatible API uses standard Bearer auth -func (c *GeminiClient) setAuthHeader(reqHeaders http.Header) { - c.Client.setAuthHeader(reqHeaders) -} diff --git a/mcp/grok_client.go b/mcp/grok_client.go deleted file mode 100644 index c08be624..00000000 --- a/mcp/grok_client.go +++ /dev/null @@ -1,71 +0,0 @@ -package mcp - -import ( - "net/http" -) - -const ( - ProviderGrok = "grok" - DefaultGrokBaseURL = "https://api.x.ai/v1" - DefaultGrokModel = "grok-3-latest" -) - -type GrokClient struct { - *Client -} - -// NewGrokClient creates Grok client (backward compatible) -func NewGrokClient() AIClient { - return NewGrokClientWithOptions() -} - -// NewGrokClientWithOptions creates Grok client (supports options pattern) -func NewGrokClientWithOptions(opts ...ClientOption) AIClient { - // 1. Create Grok preset options - grokOpts := []ClientOption{ - WithProvider(ProviderGrok), - WithModel(DefaultGrokModel), - WithBaseURL(DefaultGrokBaseURL), - } - - // 2. Merge user options (user options have higher priority) - allOpts := append(grokOpts, opts...) - - // 3. Create base client - baseClient := NewClient(allOpts...).(*Client) - - // 4. Create Grok client - grokClient := &GrokClient{ - Client: baseClient, - } - - // 5. Set hooks to point to GrokClient (implement dynamic dispatch) - baseClient.hooks = grokClient - - return grokClient -} - -func (c *GrokClient) SetAPIKey(apiKey string, customURL string, customModel string) { - c.APIKey = apiKey - - if len(apiKey) > 8 { - c.logger.Infof("๐Ÿ”ง [MCP] Grok API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) - } - if customURL != "" { - c.BaseURL = customURL - c.logger.Infof("๐Ÿ”ง [MCP] Grok using custom BaseURL: %s", customURL) - } else { - c.logger.Infof("๐Ÿ”ง [MCP] Grok using default BaseURL: %s", c.BaseURL) - } - if customModel != "" { - c.Model = customModel - c.logger.Infof("๐Ÿ”ง [MCP] Grok using custom Model: %s", customModel) - } else { - c.logger.Infof("๐Ÿ”ง [MCP] Grok using default Model: %s", c.Model) - } -} - -// Grok uses standard OpenAI-compatible API with Bearer auth -func (c *GrokClient) setAuthHeader(reqHeaders http.Header) { - c.Client.setAuthHeader(reqHeaders) -} diff --git a/mcp/hooks.go b/mcp/hooks.go new file mode 100644 index 00000000..9b19f02a --- /dev/null +++ b/mcp/hooks.go @@ -0,0 +1,37 @@ +package mcp + +import "net/http" + +// ClientHooks is the dispatch interface used to implement per-provider +// polymorphism without Go's lack of virtual methods. +// +// Each method can be overridden by an embedding struct (e.g. provider.ClaudeClient). +// The base *Client provides OpenAI-compatible defaults; providers with a +// different wire format (Anthropic, Gemini native, etc.) override only what +// differs. All call-path methods in client.go invoke these via c.Hooks so +// that the override is always picked up at runtime. +type ClientHooks interface { + // โ”€โ”€ Simple CallWithMessages path โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + Call(systemPrompt, userPrompt string) (string, error) + BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any + + // โ”€โ”€ Shared request plumbing โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + BuildUrl() string + BuildRequest(url string, jsonData []byte) (*http.Request, error) + SetAuthHeader(reqHeaders http.Header) + MarshalRequestBody(requestBody map[string]any) ([]byte, error) + + // โ”€โ”€ Advanced (Request-object) path โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + // BuildRequestBodyFromRequest converts a *Request into the provider's + // native wire-format map. + BuildRequestBodyFromRequest(req *Request) map[string]any + + // ParseMCPResponse extracts the plain-text reply from a non-streaming + // response body. + ParseMCPResponse(body []byte) (string, error) + + // ParseMCPResponseFull extracts both text and tool calls. + ParseMCPResponseFull(body []byte) (*LLMResponse, error) + + IsRetryableError(err error) bool +} diff --git a/mcp/interface.go b/mcp/interface.go index c7e62edc..deecbe2a 100644 --- a/mcp/interface.go +++ b/mcp/interface.go @@ -1,10 +1,15 @@ package mcp import ( - "net/http" "time" ) +// ClientEmbedder is implemented by provider types that embed *Client, +// allowing generic extraction of the underlying base client (e.g. for cloning). +type ClientEmbedder interface { + BaseClient() *Client +} + // AIClient public AI client interface (for external use) type AIClient interface { SetAPIKey(apiKey string, customURL string, customModel string) @@ -21,41 +26,3 @@ type AIClient interface { // (LLMResponse.ToolCalls), but not both. CallWithRequestFull(req *Request) (*LLMResponse, error) } - -// clientHooks is the internal dispatch interface used to implement per-provider -// polymorphism without Go's lack of virtual methods. -// -// Each method can be overridden by an embedding struct (e.g. ClaudeClient). -// The base *Client provides OpenAI-compatible defaults; providers with a -// different wire format (Anthropic, Gemini native, etc.) override only what -// differs. All call-path methods in client.go invoke these via c.hooks so -// that the override is always picked up at runtime. -type clientHooks interface { - // โ”€โ”€ Simple CallWithMessages path โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - call(systemPrompt, userPrompt string) (string, error) - buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any - - // โ”€โ”€ Shared request plumbing โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - buildUrl() string - buildRequest(url string, jsonData []byte) (*http.Request, error) - setAuthHeader(reqHeaders http.Header) - marshalRequestBody(requestBody map[string]any) ([]byte, error) - - // โ”€โ”€ Advanced (Request-object) path โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - // buildRequestBodyFromRequest converts a *Request into the provider's - // native wire-format map. Providers that use a different protocol (e.g. - // Anthropic uses "input_schema" for tools, "tool_use" content blocks, and - // a top-level "system" field) override this method. - buildRequestBodyFromRequest(req *Request) map[string]any - - // parseMCPResponse extracts the plain-text reply from a non-streaming - // response body. - parseMCPResponse(body []byte) (string, error) - - // parseMCPResponseFull extracts both text and tool calls. Providers whose - // response envelope differs from the OpenAI choices[] structure (e.g. - // Anthropic content[] with tool_use blocks) override this method. - parseMCPResponseFull(body []byte) (*LLMResponse, error) - - isRetryableError(err error) bool -} diff --git a/mcp/kimi_client.go b/mcp/kimi_client.go deleted file mode 100644 index 0337b10c..00000000 --- a/mcp/kimi_client.go +++ /dev/null @@ -1,71 +0,0 @@ -package mcp - -import ( - "net/http" -) - -const ( - ProviderKimi = "kimi" - DefaultKimiBaseURL = "https://api.moonshot.ai/v1" // Global endpoint (use api.moonshot.cn for China) - DefaultKimiModel = "moonshot-v1-auto" -) - -type KimiClient struct { - *Client -} - -// NewKimiClient creates Kimi (Moonshot) client (backward compatible) -func NewKimiClient() AIClient { - return NewKimiClientWithOptions() -} - -// NewKimiClientWithOptions creates Kimi client (supports options pattern) -func NewKimiClientWithOptions(opts ...ClientOption) AIClient { - // 1. Create Kimi preset options - kimiOpts := []ClientOption{ - WithProvider(ProviderKimi), - WithModel(DefaultKimiModel), - WithBaseURL(DefaultKimiBaseURL), - } - - // 2. Merge user options (user options have higher priority) - allOpts := append(kimiOpts, opts...) - - // 3. Create base client - baseClient := NewClient(allOpts...).(*Client) - - // 4. Create Kimi client - kimiClient := &KimiClient{ - Client: baseClient, - } - - // 5. Set hooks to point to KimiClient (implement dynamic dispatch) - baseClient.hooks = kimiClient - - return kimiClient -} - -func (c *KimiClient) SetAPIKey(apiKey string, customURL string, customModel string) { - c.APIKey = apiKey - - if len(apiKey) > 8 { - c.logger.Infof("๐Ÿ”ง [MCP] Kimi API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) - } - if customURL != "" { - c.BaseURL = customURL - c.logger.Infof("๐Ÿ”ง [MCP] Kimi using custom BaseURL: %s", customURL) - } else { - c.logger.Infof("๐Ÿ”ง [MCP] Kimi using default BaseURL: %s", c.BaseURL) - } - if customModel != "" { - c.Model = customModel - c.logger.Infof("๐Ÿ”ง [MCP] Kimi using custom Model: %s", customModel) - } else { - c.logger.Infof("๐Ÿ”ง [MCP] Kimi using default Model: %s", c.Model) - } -} - -// Kimi uses standard OpenAI-compatible API, so we just use the base client methods -func (c *KimiClient) setAuthHeader(reqHeaders http.Header) { - c.Client.setAuthHeader(reqHeaders) -} diff --git a/mcp/minimax_client.go b/mcp/minimax_client.go deleted file mode 100644 index 7bedb15a..00000000 --- a/mcp/minimax_client.go +++ /dev/null @@ -1,83 +0,0 @@ -package mcp - -import ( - "net/http" -) - -const ( - ProviderMiniMax = "minimax" - DefaultMiniMaxBaseURL = "https://api.minimax.io/v1" - DefaultMiniMaxModel = "MiniMax-M2.5" -) - -type MiniMaxClient struct { - *Client -} - -// NewMiniMaxClient creates MiniMax client (backward compatible) -func NewMiniMaxClient() AIClient { - return NewMiniMaxClientWithOptions() -} - -// NewMiniMaxClientWithOptions creates MiniMax client (supports options pattern) -// -// Usage examples: -// -// // Basic usage -// client := mcp.NewMiniMaxClientWithOptions() -// -// // Custom configuration -// client := mcp.NewMiniMaxClientWithOptions( -// mcp.WithAPIKey("sk-xxx"), -// mcp.WithLogger(customLogger), -// mcp.WithTimeout(60*time.Second), -// ) -func NewMiniMaxClientWithOptions(opts ...ClientOption) AIClient { - // 1. Create MiniMax preset options - minimaxOpts := []ClientOption{ - WithProvider(ProviderMiniMax), - WithModel(DefaultMiniMaxModel), - WithBaseURL(DefaultMiniMaxBaseURL), - } - - // 2. Merge user options (user options have higher priority) - allOpts := append(minimaxOpts, opts...) - - // 3. Create base client - baseClient := NewClient(allOpts...).(*Client) - - // 4. Create MiniMax client - minimaxClient := &MiniMaxClient{ - Client: baseClient, - } - - // 5. Set hooks to point to MiniMaxClient (implement dynamic dispatch) - baseClient.hooks = minimaxClient - - return minimaxClient -} - -func (c *MiniMaxClient) SetAPIKey(apiKey string, customURL string, customModel string) { - c.APIKey = apiKey - - if len(apiKey) > 8 { - c.logger.Infof("๐Ÿ”ง [MCP] MiniMax API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) - } - if customURL != "" { - c.BaseURL = customURL - c.logger.Infof("๐Ÿ”ง [MCP] MiniMax using custom BaseURL: %s", customURL) - } else { - c.logger.Infof("๐Ÿ”ง [MCP] MiniMax using default BaseURL: %s", c.BaseURL) - } - if customModel != "" { - c.Model = customModel - c.logger.Infof("๐Ÿ”ง [MCP] MiniMax using custom Model: %s", customModel) - } else { - c.logger.Infof("๐Ÿ”ง [MCP] MiniMax using default Model: %s", c.Model) - } -} - -// MiniMax uses standard OpenAI-compatible API with Bearer auth -func (c *MiniMaxClient) setAuthHeader(reqHeaders http.Header) { - c.Client.setAuthHeader(reqHeaders) -} diff --git a/mcp/minimax_client_test.go b/mcp/minimax_client_test.go deleted file mode 100644 index 21e45e22..00000000 --- a/mcp/minimax_client_test.go +++ /dev/null @@ -1,272 +0,0 @@ -package mcp - -import ( - "testing" - "time" -) - -// ============================================================ -// Test MiniMaxClient Creation and Configuration -// ============================================================ - -func TestNewMiniMaxClient_Default(t *testing.T) { - client := NewMiniMaxClient() - - if client == nil { - t.Fatal("client should not be nil") - } - - // Type assertion check - mmClient, ok := client.(*MiniMaxClient) - if !ok { - t.Fatal("client should be *MiniMaxClient") - } - - // Verify default values - if mmClient.Provider != ProviderMiniMax { - t.Errorf("Provider should be '%s', got '%s'", ProviderMiniMax, mmClient.Provider) - } - - if mmClient.BaseURL != DefaultMiniMaxBaseURL { - t.Errorf("BaseURL should be '%s', got '%s'", DefaultMiniMaxBaseURL, mmClient.BaseURL) - } - - if mmClient.Model != DefaultMiniMaxModel { - t.Errorf("Model should be '%s', got '%s'", DefaultMiniMaxModel, mmClient.Model) - } - - if mmClient.logger == nil { - t.Error("logger should not be nil") - } - - if mmClient.httpClient == nil { - t.Error("httpClient should not be nil") - } -} - -func TestNewMiniMaxClientWithOptions(t *testing.T) { - mockLogger := NewMockLogger() - customModel := "MiniMax-M2.5-highspeed" - customAPIKey := "sk-custom-key" - - client := NewMiniMaxClientWithOptions( - WithLogger(mockLogger), - WithModel(customModel), - WithAPIKey(customAPIKey), - WithMaxTokens(4000), - ) - - mmClient := client.(*MiniMaxClient) - - // Verify custom options are applied - if mmClient.logger != mockLogger { - t.Error("logger should be set from option") - } - - if mmClient.Model != customModel { - t.Error("Model should be set from option") - } - - if mmClient.APIKey != customAPIKey { - t.Error("APIKey should be set from option") - } - - if mmClient.MaxTokens != 4000 { - t.Error("MaxTokens should be 4000") - } - - // Verify MiniMax default values are retained - if mmClient.Provider != ProviderMiniMax { - t.Errorf("Provider should still be '%s'", ProviderMiniMax) - } - - if mmClient.BaseURL != DefaultMiniMaxBaseURL { - t.Errorf("BaseURL should still be '%s'", DefaultMiniMaxBaseURL) - } -} - -// ============================================================ -// Test SetAPIKey -// ============================================================ - -func TestMiniMaxClient_SetAPIKey(t *testing.T) { - mockLogger := NewMockLogger() - client := NewMiniMaxClientWithOptions( - WithLogger(mockLogger), - ) - - mmClient := client.(*MiniMaxClient) - - // Test setting API Key (default URL and Model) - mmClient.SetAPIKey("sk-test-key-12345678", "", "") - - if mmClient.APIKey != "sk-test-key-12345678" { - t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", mmClient.APIKey) - } - - // Verify logging - logs := mockLogger.GetLogsByLevel("INFO") - if len(logs) == 0 { - t.Error("should have logged API key setting") - } - - // Verify BaseURL and Model remain default - if mmClient.BaseURL != DefaultMiniMaxBaseURL { - t.Error("BaseURL should remain default") - } - - if mmClient.Model != DefaultMiniMaxModel { - t.Error("Model should remain default") - } -} - -func TestMiniMaxClient_SetAPIKey_WithCustomURL(t *testing.T) { - mockLogger := NewMockLogger() - client := NewMiniMaxClientWithOptions( - WithLogger(mockLogger), - ) - - mmClient := client.(*MiniMaxClient) - - customURL := "https://api.minimaxi.com/v1" - mmClient.SetAPIKey("sk-test-key-12345678", customURL, "") - - if mmClient.BaseURL != customURL { - t.Errorf("BaseURL should be '%s', got '%s'", customURL, mmClient.BaseURL) - } - - // Verify logging - logs := mockLogger.GetLogsByLevel("INFO") - hasCustomURLLog := false - for _, log := range logs { - if log.Format == "๐Ÿ”ง [MCP] MiniMax using custom BaseURL: %s" { - hasCustomURLLog = true - break - } - } - - if !hasCustomURLLog { - t.Error("should have logged custom BaseURL") - } -} - -func TestMiniMaxClient_SetAPIKey_WithCustomModel(t *testing.T) { - mockLogger := NewMockLogger() - client := NewMiniMaxClientWithOptions( - WithLogger(mockLogger), - ) - - mmClient := client.(*MiniMaxClient) - - customModel := "MiniMax-M2.5-highspeed" - mmClient.SetAPIKey("sk-test-key-12345678", "", customModel) - - if mmClient.Model != customModel { - t.Errorf("Model should be '%s', got '%s'", customModel, mmClient.Model) - } - - // Verify logging - logs := mockLogger.GetLogsByLevel("INFO") - hasCustomModelLog := false - for _, log := range logs { - if log.Format == "๐Ÿ”ง [MCP] MiniMax using custom Model: %s" { - hasCustomModelLog = true - break - } - } - - if !hasCustomModelLog { - t.Error("should have logged custom Model") - } -} - -// ============================================================ -// Test Integration Features -// ============================================================ - -func TestMiniMaxClient_CallWithMessages_Success(t *testing.T) { - mockHTTP := NewMockHTTPClient() - mockHTTP.SetSuccessResponse("MiniMax AI response") - mockLogger := NewMockLogger() - - client := NewMiniMaxClientWithOptions( - WithHTTPClient(mockHTTP.ToHTTPClient()), - WithLogger(mockLogger), - WithAPIKey("sk-test-key"), - ) - - result, err := client.CallWithMessages("system prompt", "user prompt") - - if err != nil { - t.Fatalf("should not error: %v", err) - } - - if result != "MiniMax AI response" { - t.Errorf("expected 'MiniMax AI response', got '%s'", result) - } - - // Verify request - requests := mockHTTP.GetRequests() - if len(requests) != 1 { - t.Fatalf("expected 1 request, got %d", len(requests)) - } - - req := requests[0] - - // Verify URL - expectedURL := DefaultMiniMaxBaseURL + "/chat/completions" - if req.URL.String() != expectedURL { - t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String()) - } - - // Verify Authorization header - authHeader := req.Header.Get("Authorization") - if authHeader != "Bearer sk-test-key" { - t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader) - } - - // Verify Content-Type - if req.Header.Get("Content-Type") != "application/json" { - t.Error("Content-Type should be application/json") - } -} - -func TestMiniMaxClient_Timeout(t *testing.T) { - client := NewMiniMaxClientWithOptions( - WithTimeout(30 * time.Second), - ) - - mmClient := client.(*MiniMaxClient) - - if mmClient.httpClient.Timeout != 30*time.Second { - t.Errorf("expected timeout 30s, got %v", mmClient.httpClient.Timeout) - } - - // Test SetTimeout - client.SetTimeout(60 * time.Second) - - if mmClient.httpClient.Timeout != 60*time.Second { - t.Errorf("expected timeout 60s after SetTimeout, got %v", mmClient.httpClient.Timeout) - } -} - -// ============================================================ -// Test hooks Mechanism -// ============================================================ - -func TestMiniMaxClient_HooksIntegration(t *testing.T) { - client := NewMiniMaxClientWithOptions() - mmClient := client.(*MiniMaxClient) - - // Verify hooks point to mmClient itself (implements polymorphism) - if mmClient.hooks != mmClient { - t.Error("hooks should point to mmClient for polymorphism") - } - - // Verify buildUrl uses MiniMax configuration - url := mmClient.buildUrl() - expectedURL := DefaultMiniMaxBaseURL + "/chat/completions" - if url != expectedURL { - t.Errorf("expected URL '%s', got '%s'", expectedURL, url) - } -} diff --git a/mcp/mock_test.go b/mcp/mock_test.go index 1f93042d..38d373c6 100644 --- a/mcp/mock_test.go +++ b/mcp/mock_test.go @@ -247,7 +247,7 @@ func NewMockClientHooks() *MockClientHooks { return &MockClientHooks{} } -func (m *MockClientHooks) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any { +func (m *MockClientHooks) BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any { m.BuildRequestBodyCalled++ if m.BuildRequestBodyFunc != nil { return m.BuildRequestBodyFunc(systemPrompt, userPrompt) @@ -261,7 +261,7 @@ func (m *MockClientHooks) buildMCPRequestBody(systemPrompt, userPrompt string) m } } -func (m *MockClientHooks) buildUrl() string { +func (m *MockClientHooks) BuildUrl() string { m.BuildUrlCalled++ if m.BuildUrlFunc != nil { return m.BuildUrlFunc() @@ -269,12 +269,12 @@ func (m *MockClientHooks) buildUrl() string { return "https://api.test.com/chat/completions" } -func (m *MockClientHooks) setAuthHeader(headers http.Header) { +func (m *MockClientHooks) SetAuthHeader(headers http.Header) { m.SetAuthHeaderCalled++ headers.Set("Authorization", "Bearer test-key") } -func (m *MockClientHooks) marshalRequestBody(body map[string]any) ([]byte, error) { +func (m *MockClientHooks) MarshalRequestBody(body map[string]any) ([]byte, error) { m.MarshalRequestCalled++ if m.MarshalRequestBodyFunc != nil { return m.MarshalRequestBodyFunc(body) @@ -282,7 +282,7 @@ func (m *MockClientHooks) marshalRequestBody(body map[string]any) ([]byte, error return json.Marshal(body) } -func (m *MockClientHooks) parseMCPResponse(body []byte) (string, error) { +func (m *MockClientHooks) ParseMCPResponse(body []byte) (string, error) { m.ParseResponseCalled++ if m.ParseResponseFunc != nil { return m.ParseResponseFunc(body) @@ -290,7 +290,15 @@ func (m *MockClientHooks) parseMCPResponse(body []byte) (string, error) { return "mocked response", nil } -func (m *MockClientHooks) isRetryableError(err error) bool { +func (m *MockClientHooks) ParseMCPResponseFull(body []byte) (*LLMResponse, error) { + r, err := m.ParseMCPResponse(body) + if err != nil { + return nil, err + } + return &LLMResponse{Content: r}, nil +} + +func (m *MockClientHooks) IsRetryableError(err error) bool { m.IsRetryableErrorCalled++ if m.IsRetryableErrorFunc != nil { return m.IsRetryableErrorFunc(err) @@ -298,13 +306,17 @@ func (m *MockClientHooks) isRetryableError(err error) bool { return false } -func (m *MockClientHooks) buildRequest(url string, jsonData []byte) (*http.Request, error) { +func (m *MockClientHooks) BuildRequest(url string, jsonData []byte) (*http.Request, error) { req, _ := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) req.Header.Set("Content-Type", "application/json") - m.setAuthHeader(req.Header) + m.SetAuthHeader(req.Header) return req, nil } -func (m *MockClientHooks) call(systemPrompt, userPrompt string) (string, error) { +func (m *MockClientHooks) Call(systemPrompt, userPrompt string) (string, error) { return "mocked call result", nil } + +func (m *MockClientHooks) BuildRequestBodyFromRequest(req *Request) map[string]any { + return map[string]any{"model": "test-model"} +} diff --git a/mcp/openai_client.go b/mcp/openai_client.go deleted file mode 100644 index 03e535b5..00000000 --- a/mcp/openai_client.go +++ /dev/null @@ -1,71 +0,0 @@ -package mcp - -import ( - "net/http" -) - -const ( - ProviderOpenAI = "openai" - DefaultOpenAIBaseURL = "https://api.openai.com/v1" - DefaultOpenAIModel = "gpt-5.4" -) - -type OpenAIClient struct { - *Client -} - -// NewOpenAIClient creates OpenAI client (backward compatible) -func NewOpenAIClient() AIClient { - return NewOpenAIClientWithOptions() -} - -// NewOpenAIClientWithOptions creates OpenAI client (supports options pattern) -func NewOpenAIClientWithOptions(opts ...ClientOption) AIClient { - // 1. Create OpenAI preset options - openaiOpts := []ClientOption{ - WithProvider(ProviderOpenAI), - WithModel(DefaultOpenAIModel), - WithBaseURL(DefaultOpenAIBaseURL), - } - - // 2. Merge user options (user options have higher priority) - allOpts := append(openaiOpts, opts...) - - // 3. Create base client - baseClient := NewClient(allOpts...).(*Client) - - // 4. Create OpenAI client - openaiClient := &OpenAIClient{ - Client: baseClient, - } - - // 5. Set hooks to point to OpenAIClient (implement dynamic dispatch) - baseClient.hooks = openaiClient - - return openaiClient -} - -func (c *OpenAIClient) SetAPIKey(apiKey string, customURL string, customModel string) { - c.APIKey = apiKey - - if len(apiKey) > 8 { - c.logger.Infof("๐Ÿ”ง [MCP] OpenAI API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) - } - if customURL != "" { - c.BaseURL = customURL - c.logger.Infof("๐Ÿ”ง [MCP] OpenAI using custom BaseURL: %s", customURL) - } else { - c.logger.Infof("๐Ÿ”ง [MCP] OpenAI using default BaseURL: %s", c.BaseURL) - } - if customModel != "" { - c.Model = customModel - c.logger.Infof("๐Ÿ”ง [MCP] OpenAI using custom Model: %s", customModel) - } else { - c.logger.Infof("๐Ÿ”ง [MCP] OpenAI using default Model: %s", c.Model) - } -} - -// OpenAI uses standard Bearer auth -func (c *OpenAIClient) setAuthHeader(reqHeaders http.Header) { - c.Client.setAuthHeader(reqHeaders) -} diff --git a/mcp/options_test.go b/mcp/options_test.go index f39ee33b..4bbe4719 100644 --- a/mcp/options_test.go +++ b/mcp/options_test.go @@ -279,8 +279,8 @@ func TestOptionsWithNewClient(t *testing.T) { t.Error("Model should be set from options") } - if c.logger != mockLogger { - t.Error("logger should be set from options") + if c.Log != mockLogger { + t.Error("Log should be set from options") } if c.MaxTokens != 4000 { @@ -288,78 +288,4 @@ func TestOptionsWithNewClient(t *testing.T) { } } -func TestOptionsWithDeepSeekClient(t *testing.T) { - mockLogger := NewMockLogger() - - client := NewDeepSeekClientWithOptions( - WithAPIKey("sk-deepseek-key"), - WithLogger(mockLogger), - WithMaxTokens(5000), - ) - - dsClient := client.(*DeepSeekClient) - - // Verify DeepSeek default values - if dsClient.Provider != ProviderDeepSeek { - t.Error("Provider should be DeepSeek") - } - - if dsClient.BaseURL != DefaultDeepSeekBaseURL { - t.Error("BaseURL should be DeepSeek default") - } - - if dsClient.Model != DefaultDeepSeekModel { - t.Error("Model should be DeepSeek default") - } - - // Verify custom options - if dsClient.APIKey != "sk-deepseek-key" { - t.Error("APIKey should be set from options") - } - - if dsClient.logger != mockLogger { - t.Error("logger should be set from options") - } - - if dsClient.MaxTokens != 5000 { - t.Error("MaxTokens should be 5000") - } -} - -func TestOptionsWithQwenClient(t *testing.T) { - mockLogger := NewMockLogger() - - client := NewQwenClientWithOptions( - WithAPIKey("sk-qwen-key"), - WithLogger(mockLogger), - WithMaxTokens(6000), - ) - - qwenClient := client.(*QwenClient) - - // Verify Qwen default values - if qwenClient.Provider != ProviderQwen { - t.Error("Provider should be Qwen") - } - - if qwenClient.BaseURL != DefaultQwenBaseURL { - t.Error("BaseURL should be Qwen default") - } - - if qwenClient.Model != DefaultQwenModel { - t.Error("Model should be Qwen default") - } - - // Verify custom options - if qwenClient.APIKey != "sk-qwen-key" { - t.Error("APIKey should be set from options") - } - - if qwenClient.logger != mockLogger { - t.Error("logger should be set from options") - } - - if qwenClient.MaxTokens != 6000 { - t.Error("MaxTokens should be 6000") - } -} +// Provider-specific option tests are in mcp/provider/options_test.go diff --git a/mcp/blockrun_base.go b/mcp/payment/blockrun_base.go similarity index 81% rename from mcp/blockrun_base.go rename to mcp/payment/blockrun_base.go index c97efa4a..5953417e 100644 --- a/mcp/blockrun_base.go +++ b/mcp/payment/blockrun_base.go @@ -1,4 +1,4 @@ -package mcp +package payment import ( "crypto/ecdsa" @@ -14,12 +14,13 @@ import ( "github.com/ethereum/go-ethereum/crypto" "golang.org/x/crypto/sha3" + + "nofx/mcp" ) const ( - ProviderBlockRunBase = "blockrun-base" DefaultBlockRunBaseURL = "https://blockrun.ai" - DefaultBlockRunModel = "gpt-5.4" + DefaultBlockRunModel = "gpt-5.4" BlockRunChatEndpoint = "/api/v1/chat/completions" BaseUSDCContract = "0x833589fCD6eDb6E08f4c7C32D4f71b54bdA02913" BaseChainID int64 = 8453 @@ -28,10 +29,16 @@ const ( // EIP-712 type hashes for USDC TransferWithAuthorization (ERC-3009) var ( - eip712DomainTypeHash = keccak256String("EIP712Domain(string name,string version,uint256 chainId,address verifyingContract)") + eip712DomainTypeHash = keccak256String("EIP712Domain(string name,string version,uint256 chainId,address verifyingContract)") transferWithAuthTypeHash = keccak256String("TransferWithAuthorization(address from,address to,uint256 value,uint256 validAfter,uint256 validBefore,bytes32 nonce)") ) +func init() { + mcp.RegisterProvider(mcp.ProviderBlockRunBase, func(opts ...mcp.ClientOption) mcp.AIClient { + return NewBlockRunBaseClientWithOptions(opts...) + }) +} + func keccak256String(s string) []byte { h := sha3.NewLegacyKeccak256() h.Write([]byte(s)) @@ -48,71 +55,72 @@ func keccak256Bytes(data ...[]byte) []byte { // BlockRunBaseClient implements AIClient using BlockRun's API with x402 v2 EIP-712 payment signing. type BlockRunBaseClient struct { - *Client + *mcp.Client privateKey *ecdsa.PrivateKey } +func (c *BlockRunBaseClient) BaseClient() *mcp.Client { return c.Client } + // NewBlockRunBaseClient creates a BlockRun Base wallet client (backward compatible). -func NewBlockRunBaseClient() AIClient { +func NewBlockRunBaseClient() mcp.AIClient { return NewBlockRunBaseClientWithOptions() } // NewBlockRunBaseClientWithOptions creates a BlockRun Base wallet client. -func NewBlockRunBaseClientWithOptions(opts ...ClientOption) AIClient { - baseOpts := []ClientOption{ - WithProvider(ProviderBlockRunBase), - WithModel(DefaultBlockRunModel), - WithBaseURL(DefaultBlockRunBaseURL), +func NewBlockRunBaseClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient { + baseOpts := []mcp.ClientOption{ + mcp.WithProvider(mcp.ProviderBlockRunBase), + mcp.WithModel(DefaultBlockRunModel), + mcp.WithBaseURL(DefaultBlockRunBaseURL), } allOpts := append(baseOpts, opts...) - baseClient := NewClient(allOpts...).(*Client) + baseClient := mcp.NewClient(allOpts...).(*mcp.Client) baseClient.UseFullURL = true baseClient.BaseURL = DefaultBlockRunBaseURL + BlockRunChatEndpoint c := &BlockRunBaseClient{Client: baseClient} - baseClient.hooks = c + baseClient.Hooks = c return c } // SetAPIKey stores the EVM private key (hex, with or without 0x prefix). -// customModel selects the AI model to use (e.g. "claude-sonnet-4.6"); empty means default. func (c *BlockRunBaseClient) SetAPIKey(apiKey string, customURL string, customModel string) { hexKey := strings.TrimPrefix(apiKey, "0x") privKey, err := crypto.HexToECDSA(hexKey) if err != nil { - c.logger.Warnf("โš ๏ธ [MCP] BlockRun Base: invalid private key: %v", err) + c.Log.Warnf("โš ๏ธ [MCP] BlockRun Base: invalid private key: %v", err) } else { c.privateKey = privKey c.APIKey = apiKey addr := crypto.PubkeyToAddress(privKey.PublicKey).Hex() - c.logger.Infof("๐Ÿ”ง [MCP] BlockRun Base wallet: %s", addr) + c.Log.Infof("๐Ÿ”ง [MCP] BlockRun Base wallet: %s", addr) } if customModel != "" { c.Model = customModel - c.logger.Infof("๐Ÿ”ง [MCP] BlockRun Base model: %s", customModel) + c.Log.Infof("๐Ÿ”ง [MCP] BlockRun Base model: %s", customModel) } else { - c.logger.Infof("๐Ÿ”ง [MCP] BlockRun Base model: %s", DefaultBlockRunModel) + c.Log.Infof("๐Ÿ”ง [MCP] BlockRun Base model: %s", DefaultBlockRunModel) } } -func (c *BlockRunBaseClient) setAuthHeader(h http.Header) { x402SetAuthHeader(h) } +func (c *BlockRunBaseClient) SetAuthHeader(h http.Header) { X402SetAuthHeader(h) } -func (c *BlockRunBaseClient) call(systemPrompt, userPrompt string) (string, error) { - return x402Call(c.Client, c.signPayment, "BlockRun Base", systemPrompt, userPrompt) +func (c *BlockRunBaseClient) Call(systemPrompt, userPrompt string) (string, error) { + return X402Call(c.Client, c.signPayment, "BlockRun Base", systemPrompt, userPrompt) } -func (c *BlockRunBaseClient) CallWithRequestFull(req *Request) (*LLMResponse, error) { - return x402CallFull(c.Client, c.signPayment, "BlockRun Base", req) +func (c *BlockRunBaseClient) CallWithRequestFull(req *mcp.Request) (*mcp.LLMResponse, error) { + return X402CallFull(c.Client, c.signPayment, "BlockRun Base", req) } // signPayment parses the Payment-Required header (x402 v2) and returns a signed payment value. func (c *BlockRunBaseClient) signPayment(paymentHeaderB64 string) (string, error) { - return signBasePaymentHeader(c.privateKey, paymentHeaderB64, "BlockRun Base") + return SignBasePaymentHeader(c.privateKey, paymentHeaderB64, "BlockRun Base") } -// signX402Payment is the shared EIP-712 signing logic for x402 v2 on Base USDC. +// SignX402Payment is the shared EIP-712 signing logic for x402 v2 on Base USDC. // Used by both BlockRunBaseClient and Claw402Client. -func signX402Payment(privateKey *ecdsa.PrivateKey, senderAddr string, opt x402AcceptOption, resource *x402Resource) (string, error) { +func SignX402Payment(privateKey *ecdsa.PrivateKey, senderAddr string, opt X402AcceptOption, resource *X402Resource) (string, error) { recipient := opt.PayTo amount := opt.Amount network := opt.Network @@ -224,7 +232,6 @@ func signX402Payment(privateKey *ecdsa.PrivateKey, senderAddr string, opt x402Ac // buildDomainSeparatorDynamic builds the EIP-712 domain separator using runtime values. func buildDomainSeparatorDynamic(name, version, network, asset string) ([]byte, error) { - // Extract chain ID from network string like "eip155:8453" chainID := new(big.Int).SetInt64(BaseChainID) if strings.HasPrefix(network, "eip155:") { parts := strings.SplitN(network, ":", 2) @@ -311,8 +318,6 @@ func hexToBytes32(s string) ([]byte, error) { func parseBigInt(s string) (*big.Int, error) { n := new(big.Int) - // Only treat as hex when explicitly prefixed with 0x/0X. - // x402 amounts are always decimal strings (e.g. "3000" = 0.003 USDC). if strings.HasPrefix(s, "0x") || strings.HasPrefix(s, "0X") { if _, ok := n.SetString(s[2:], 16); ok { return n, nil @@ -335,11 +340,11 @@ func leftPad32(b []byte) []byte { return padded } -// buildUrl returns the full BlockRun endpoint URL. -func (c *BlockRunBaseClient) buildUrl() string { +// BuildUrl returns the full BlockRun endpoint URL. +func (c *BlockRunBaseClient) BuildUrl() string { return DefaultBlockRunBaseURL + BlockRunChatEndpoint } -func (c *BlockRunBaseClient) buildRequest(url string, jsonData []byte) (*http.Request, error) { - return x402BuildRequest(url, jsonData) +func (c *BlockRunBaseClient) BuildRequest(url string, jsonData []byte) (*http.Request, error) { + return X402BuildRequest(url, jsonData) } diff --git a/mcp/blockrun_sol.go b/mcp/payment/blockrun_sol.go similarity index 77% rename from mcp/blockrun_sol.go rename to mcp/payment/blockrun_sol.go index cb88636e..dc78c29e 100644 --- a/mcp/blockrun_sol.go +++ b/mcp/payment/blockrun_sol.go @@ -1,4 +1,4 @@ -package mcp +package payment import ( "context" @@ -12,10 +12,11 @@ import ( "github.com/gagliardetto/solana-go/programs/compute-budget" "github.com/gagliardetto/solana-go/programs/token" "github.com/gagliardetto/solana-go/rpc" + + "nofx/mcp" ) const ( - ProviderBlockRunSol = "blockrun-sol" DefaultBlockRunSolURL = "https://sol.blockrun.ai" SolanaUSDCMint = "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v" SolanaNetwork = "solana:5eykt4UsFv8P8NJdTREpY1vzqKqZKvdp" @@ -26,62 +27,69 @@ const ( computeUnitPrice = uint64(1) ) +func init() { + mcp.RegisterProvider(mcp.ProviderBlockRunSol, func(opts ...mcp.ClientOption) mcp.AIClient { + return NewBlockRunSolClientWithOptions(opts...) + }) +} + // BlockRunSolClient implements AIClient using BlockRun's Solana x402 v2 payment protocol. type BlockRunSolClient struct { - *Client + *mcp.Client keypair solana.PrivateKey } +func (c *BlockRunSolClient) BaseClient() *mcp.Client { return c.Client } + // NewBlockRunSolClient creates a BlockRun Solana wallet client (backward compatible). -func NewBlockRunSolClient() AIClient { +func NewBlockRunSolClient() mcp.AIClient { return NewBlockRunSolClientWithOptions() } // NewBlockRunSolClientWithOptions creates a BlockRun Solana wallet client. -func NewBlockRunSolClientWithOptions(opts ...ClientOption) AIClient { - baseOpts := []ClientOption{ - WithProvider(ProviderBlockRunSol), - WithModel(DefaultBlockRunModel), - WithBaseURL(DefaultBlockRunSolURL), +func NewBlockRunSolClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient { + baseOpts := []mcp.ClientOption{ + mcp.WithProvider(mcp.ProviderBlockRunSol), + mcp.WithModel(DefaultBlockRunModel), + mcp.WithBaseURL(DefaultBlockRunSolURL), } allOpts := append(baseOpts, opts...) - baseClient := NewClient(allOpts...).(*Client) + baseClient := mcp.NewClient(allOpts...).(*mcp.Client) baseClient.UseFullURL = true baseClient.BaseURL = DefaultBlockRunSolURL + BlockRunChatEndpoint c := &BlockRunSolClient{Client: baseClient} - baseClient.hooks = c + baseClient.Hooks = c return c } // SetAPIKey stores the Solana wallet private key (base58-encoded 64-byte keypair). -// customModel selects the AI model; empty means default. func (c *BlockRunSolClient) SetAPIKey(apiKey string, customURL string, customModel string) { kp, err := solana.PrivateKeyFromBase58(strings.TrimSpace(apiKey)) if err != nil { - c.logger.Warnf("โš ๏ธ [MCP] BlockRun Sol: failed to parse private key: %v", err) + c.Log.Warnf("โš ๏ธ [MCP] BlockRun Sol: failed to parse private key: %v", err) return } c.keypair = kp c.APIKey = apiKey - c.logger.Infof("๐Ÿ”ง [MCP] BlockRun Sol wallet: %s", kp.PublicKey().String()) + c.Log.Infof("๐Ÿ”ง [MCP] BlockRun Sol wallet: %s", kp.PublicKey().String()) if customModel != "" { c.Model = customModel - c.logger.Infof("๐Ÿ”ง [MCP] BlockRun Sol model: %s", customModel) + c.Log.Infof("๐Ÿ”ง [MCP] BlockRun Sol model: %s", customModel) } else { - c.logger.Infof("๐Ÿ”ง [MCP] BlockRun Sol model: %s", DefaultBlockRunModel) + c.Log.Infof("๐Ÿ”ง [MCP] BlockRun Sol model: %s", DefaultBlockRunModel) } } -func (c *BlockRunSolClient) setAuthHeader(h http.Header) { x402SetAuthHeader(h) } +func (c *BlockRunSolClient) SetAuthHeader(h http.Header) { X402SetAuthHeader(h) } -func (c *BlockRunSolClient) call(systemPrompt, userPrompt string) (string, error) { - return x402Call(c.Client, c.signSolanaPayment, "BlockRun Sol", systemPrompt, userPrompt) +func (c *BlockRunSolClient) Call(systemPrompt, userPrompt string) (string, error) { + return X402Call(c.Client, c.signSolanaPayment, "BlockRun Sol", systemPrompt, userPrompt) } -func (c *BlockRunSolClient) CallWithRequestFull(req *Request) (*LLMResponse, error) { - return x402CallFull(c.Client, c.signSolanaPayment, "BlockRun Sol", req) +func (c *BlockRunSolClient) CallWithRequestFull(req *mcp.Request) (*mcp.LLMResponse, error) { + return X402CallFull(c.Client, c.signSolanaPayment, "BlockRun Sol", req) } // signSolanaPayment parses the Payment-Required header and builds a signed x402 v2 Solana payload. @@ -90,18 +98,18 @@ func (c *BlockRunSolClient) signSolanaPayment(paymentHeaderB64 string) (string, return "", fmt.Errorf("no private key set for BlockRun Sol wallet") } - decoded, err := x402DecodeHeader(paymentHeaderB64) + decoded, err := X402DecodeHeader(paymentHeaderB64) if err != nil { return "", err } - var req x402v2PaymentRequired + var req X402v2PaymentRequired if err := json.Unmarshal(decoded, &req); err != nil { return "", fmt.Errorf("failed to parse x402 v2 Solana header: %w", err) } // Find the Solana option - var opt *x402AcceptOption + var opt *X402AcceptOption for i := range req.Accepts { if strings.HasPrefix(req.Accepts[i].Network, "solana:") { opt = &req.Accepts[i] @@ -174,11 +182,9 @@ func (c *BlockRunSolClient) signSolanaPayment(paymentHeaderB64 string) (string, } // buildSolanaTransferTx builds a partial-signed VersionedTransaction for SPL USDC TransferChecked. -// The fee payer (CDP facilitator) slot is left with a zero signature; only the user signs. func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr string) (string, error) { ownerPubkey := c.keypair.PublicKey() - // Parse recipient and feePayer recipientPK, err := solana.PublicKeyFromBase58(recipient) if err != nil { return "", fmt.Errorf("invalid recipient address: %w", err) @@ -189,13 +195,11 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr } mintPK := solana.MustPublicKeyFromBase58(SolanaUSDCMint) - // Parse amount var amountU64 uint64 if _, err := fmt.Sscanf(amountStr, "%d", &amountU64); err != nil { return "", fmt.Errorf("invalid amount %q: %w", amountStr, err) } - // Derive ATAs sourceATA, _, err := solana.FindAssociatedTokenAddress(ownerPubkey, mintPK) if err != nil { return "", fmt.Errorf("failed to derive source ATA: %w", err) @@ -205,7 +209,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr return "", fmt.Errorf("failed to derive dest ATA: %w", err) } - // Fetch latest blockhash from Solana mainnet rpcClient := rpc.New(SolanaMainnetRPC) bhResp, err := rpcClient.GetLatestBlockhash(context.Background(), rpc.CommitmentFinalized) if err != nil { @@ -213,7 +216,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr } recentBlockhash := bhResp.Value.Blockhash - // Build instructions: ComputeBudgetSetLimit, ComputeBudgetSetPrice, TransferChecked setLimitIx, err := computebudget.NewSetComputeUnitLimitInstruction(computeUnitLimit).ValidateAndBuild() if err != nil { return "", fmt.Errorf("failed to build SetComputeUnitLimit: %w", err) @@ -235,7 +237,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr return "", fmt.Errorf("failed to build TransferChecked: %w", err) } - // Build transaction with feePayer as payer (matches Python SDK) tx, err := solana.NewTransaction( []solana.Instruction{setLimitIx, setPriceIx, transferIx}, recentBlockhash, @@ -245,9 +246,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr return "", fmt.Errorf("failed to build transaction: %w", err) } - // Partial sign: user signs; fee_payer (CDP) co-signs on server side - // The transaction has 2 signers: [feePayer (index 0), owner (index 1)] - // We sign only our index (owner). _, err = tx.Sign(func(key solana.PublicKey) *solana.PrivateKey { if key.Equals(ownerPubkey) { return &c.keypair @@ -258,7 +256,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr return "", fmt.Errorf("failed to sign transaction: %w", err) } - // Serialize transaction txBytes, err := tx.MarshalBinary() if err != nil { return "", fmt.Errorf("failed to serialize transaction: %w", err) @@ -267,11 +264,11 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr return base64.StdEncoding.EncodeToString(txBytes), nil } -// buildUrl returns the full BlockRun Solana endpoint URL. -func (c *BlockRunSolClient) buildUrl() string { +// BuildUrl returns the full BlockRun Solana endpoint URL. +func (c *BlockRunSolClient) BuildUrl() string { return DefaultBlockRunSolURL + BlockRunChatEndpoint } -func (c *BlockRunSolClient) buildRequest(url string, jsonData []byte) (*http.Request, error) { - return x402BuildRequest(url, jsonData) +func (c *BlockRunSolClient) BuildRequest(url string, jsonData []byte) (*http.Request, error) { + return X402BuildRequest(url, jsonData) } diff --git a/mcp/claw402.go b/mcp/payment/claw402.go similarity index 58% rename from mcp/claw402.go rename to mcp/payment/claw402.go index 03bb9ca6..177f0626 100644 --- a/mcp/claw402.go +++ b/mcp/payment/claw402.go @@ -1,4 +1,4 @@ -package mcp +package payment import ( "crypto/ecdsa" @@ -6,11 +6,13 @@ import ( "strings" "github.com/ethereum/go-ethereum/crypto" + + "nofx/mcp" + "nofx/mcp/provider" ) const ( - ProviderClaw402 = "claw402" - DefaultClaw402URL = "https://claw402.ai" + DefaultClaw402URL = "https://claw402.ai" DefaultClaw402Model = "deepseek" ) @@ -39,35 +41,42 @@ var claw402ModelEndpoints = map[string]string{ "kimi-k2.5": "/api/v1/ai/kimi/chat/k2.5", } +func init() { + mcp.RegisterProvider(mcp.ProviderClaw402, func(opts ...mcp.ClientOption) mcp.AIClient { + return NewClaw402ClientWithOptions(opts...) + }) +} + // Claw402Client implements AIClient using claw402.ai's x402 v2 USDC payment gateway. -// Reuses the same EIP-712 signing as BlockRunBaseClient (same Base chain + USDC contract). // When the selected model routes to an Anthropic endpoint, it automatically uses // the Anthropic wire format for requests and responses (via an internal ClaudeClient). type Claw402Client struct { - *Client + *mcp.Client privateKey *ecdsa.PrivateKey - claudeProxy *ClaudeClient // non-nil when endpoint is /anthropic/ + claudeProxy *provider.ClaudeClient // non-nil when endpoint is /anthropic/ } +func (c *Claw402Client) BaseClient() *mcp.Client { return c.Client } + // NewClaw402Client creates a claw402 client (backward compatible). -func NewClaw402Client() AIClient { +func NewClaw402Client() mcp.AIClient { return NewClaw402ClientWithOptions() } // NewClaw402ClientWithOptions creates a claw402 client with options. -func NewClaw402ClientWithOptions(opts ...ClientOption) AIClient { - baseOpts := []ClientOption{ - WithProvider(ProviderClaw402), - WithModel(DefaultClaw402Model), - WithBaseURL(DefaultClaw402URL), +func NewClaw402ClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient { + baseOpts := []mcp.ClientOption{ + mcp.WithProvider(mcp.ProviderClaw402), + mcp.WithModel(DefaultClaw402Model), + mcp.WithBaseURL(DefaultClaw402URL), } allOpts := append(baseOpts, opts...) - baseClient := NewClient(allOpts...).(*Client) + baseClient := mcp.NewClient(allOpts...).(*mcp.Client) baseClient.UseFullURL = true baseClient.BaseURL = DefaultClaw402URL + claw402ModelEndpoints[DefaultClaw402Model] c := &Claw402Client{Client: baseClient} - baseClient.hooks = c + baseClient.Hooks = c return c } @@ -76,12 +85,12 @@ func (c *Claw402Client) SetAPIKey(apiKey string, _ string, customModel string) { hexKey := strings.TrimPrefix(apiKey, "0x") privKey, err := crypto.HexToECDSA(hexKey) if err != nil { - c.logger.Warnf("โš ๏ธ [MCP] Claw402: invalid private key: %v", err) + c.Log.Warnf("โš ๏ธ [MCP] Claw402: invalid private key: %v", err) } else { c.privateKey = privKey c.APIKey = apiKey addr := crypto.PubkeyToAddress(privKey.PublicKey).Hex() - c.logger.Infof("๐Ÿ”ง [MCP] Claw402 wallet: %s", addr) + c.Log.Infof("๐Ÿ”ง [MCP] Claw402 wallet: %s", addr) } if customModel != "" { c.Model = customModel @@ -91,11 +100,11 @@ func (c *Claw402Client) SetAPIKey(apiKey string, _ string, customModel string) { // Anthropic endpoints need different wire format (Messages API) if strings.Contains(endpoint, "/anthropic/") { - c.claudeProxy = &ClaudeClient{Client: c.Client} - c.logger.Infof("๐Ÿ”ง [MCP] Claw402 model: %s โ†’ %s (Anthropic format)", c.Model, endpoint) + c.claudeProxy = &provider.ClaudeClient{Client: c.Client} + c.Log.Infof("๐Ÿ”ง [MCP] Claw402 model: %s โ†’ %s (Anthropic format)", c.Model, endpoint) } else { c.claudeProxy = nil - c.logger.Infof("๐Ÿ”ง [MCP] Claw402 model: %s โ†’ %s", c.Model, endpoint) + c.Log.Infof("๐Ÿ”ง [MCP] Claw402 model: %s โ†’ %s", c.Model, endpoint) } } @@ -111,56 +120,56 @@ func (c *Claw402Client) resolveEndpoint() string { return claw402ModelEndpoints[DefaultClaw402Model] } -func (c *Claw402Client) setAuthHeader(h http.Header) { x402SetAuthHeader(h) } +func (c *Claw402Client) SetAuthHeader(h http.Header) { X402SetAuthHeader(h) } -func (c *Claw402Client) call(systemPrompt, userPrompt string) (string, error) { - return x402Call(c.Client, c.signPayment, "Claw402", systemPrompt, userPrompt) +func (c *Claw402Client) Call(systemPrompt, userPrompt string) (string, error) { + return X402Call(c.Client, c.signPayment, "Claw402", systemPrompt, userPrompt) } -func (c *Claw402Client) CallWithRequestFull(req *Request) (*LLMResponse, error) { - return x402CallFull(c.Client, c.signPayment, "Claw402", req) +func (c *Claw402Client) CallWithRequestFull(req *mcp.Request) (*mcp.LLMResponse, error) { + return X402CallFull(c.Client, c.signPayment, "Claw402", req) } // signPayment signs x402 v2 EIP-712 payment (same Base chain + USDC as BlockRunBase). func (c *Claw402Client) signPayment(paymentHeaderB64 string) (string, error) { - return signBasePaymentHeader(c.privateKey, paymentHeaderB64, "Claw402") + return SignBasePaymentHeader(c.privateKey, paymentHeaderB64, "Claw402") } // โ”€โ”€ Format overrides for Anthropic endpoints โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -func (c *Claw402Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any { +func (c *Claw402Client) BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any { if c.claudeProxy != nil { - return c.claudeProxy.buildMCPRequestBody(systemPrompt, userPrompt) + return c.claudeProxy.BuildMCPRequestBody(systemPrompt, userPrompt) } - return c.Client.buildMCPRequestBody(systemPrompt, userPrompt) + return c.Client.BuildMCPRequestBody(systemPrompt, userPrompt) } -func (c *Claw402Client) buildRequestBodyFromRequest(req *Request) map[string]any { +func (c *Claw402Client) BuildRequestBodyFromRequest(req *mcp.Request) map[string]any { if c.claudeProxy != nil { - return c.claudeProxy.buildRequestBodyFromRequest(req) + return c.claudeProxy.BuildRequestBodyFromRequest(req) } - return c.Client.buildRequestBodyFromRequest(req) + return c.Client.BuildRequestBodyFromRequest(req) } -func (c *Claw402Client) parseMCPResponse(body []byte) (string, error) { +func (c *Claw402Client) ParseMCPResponse(body []byte) (string, error) { if c.claudeProxy != nil { - return c.claudeProxy.parseMCPResponse(body) + return c.claudeProxy.ParseMCPResponse(body) } - return c.Client.parseMCPResponse(body) + return c.Client.ParseMCPResponse(body) } -func (c *Claw402Client) parseMCPResponseFull(body []byte) (*LLMResponse, error) { +func (c *Claw402Client) ParseMCPResponseFull(body []byte) (*mcp.LLMResponse, error) { if c.claudeProxy != nil { - return c.claudeProxy.parseMCPResponseFull(body) + return c.claudeProxy.ParseMCPResponseFull(body) } - return c.Client.parseMCPResponseFull(body) + return c.Client.ParseMCPResponseFull(body) } -// buildUrl returns the full claw402 endpoint URL. -func (c *Claw402Client) buildUrl() string { +// BuildUrl returns the full claw402 endpoint URL. +func (c *Claw402Client) BuildUrl() string { return c.BaseURL } -func (c *Claw402Client) buildRequest(url string, jsonData []byte) (*http.Request, error) { - return x402BuildRequest(url, jsonData) +func (c *Claw402Client) BuildRequest(url string, jsonData []byte) (*http.Request, error) { + return X402BuildRequest(url, jsonData) } diff --git a/mcp/x402.go b/mcp/payment/x402.go similarity index 62% rename from mcp/x402.go rename to mcp/payment/x402.go index 8efdb6c0..852d2fc1 100644 --- a/mcp/x402.go +++ b/mcp/payment/x402.go @@ -1,4 +1,4 @@ -package mcp +package payment import ( "bytes" @@ -11,28 +11,30 @@ import ( "time" "github.com/ethereum/go-ethereum/crypto" + + "nofx/mcp" ) const ( - // x402MaxPaymentRetries is the number of retries for 5xx errors on the + // X402MaxPaymentRetries is the number of retries for 5xx errors on the // payment-signed request. The same payment signature is reused (no double-charge). - x402MaxPaymentRetries = 3 + X402MaxPaymentRetries = 3 - // x402RetryBaseWait is the base wait between payment retry attempts. - x402RetryBaseWait = 3 * time.Second + // X402RetryBaseWait is the base wait between payment retry attempts. + X402RetryBaseWait = 3 * time.Second ) // โ”€โ”€ Shared x402 types โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -// x402v2PaymentRequired is the structure of the Payment-Required header (x402 v2). -type x402v2PaymentRequired struct { +// X402v2PaymentRequired is the structure of the Payment-Required header (x402 v2). +type X402v2PaymentRequired struct { X402Version int `json:"x402Version"` - Accepts []x402AcceptOption `json:"accepts"` - Resource *x402Resource `json:"resource"` + Accepts []X402AcceptOption `json:"accepts"` + Resource *X402Resource `json:"resource"` } -// x402AcceptOption is a payment option from the x402 v2 header. -type x402AcceptOption struct { +// X402AcceptOption is a payment option from the x402 v2 header. +type X402AcceptOption struct { Scheme string `json:"scheme"` Network string `json:"network"` Amount string `json:"amount"` @@ -42,22 +44,22 @@ type x402AcceptOption struct { Extra map[string]string `json:"extra"` } -// x402Resource describes the resource being paid for. -type x402Resource struct { +// X402Resource describes the resource being paid for. +type X402Resource struct { URL string `json:"url"` Description string `json:"description"` MimeType string `json:"mimeType"` } -// x402SignFunc is a callback that signs an x402 payment header and returns the +// X402SignFunc is a callback that signs an x402 payment header and returns the // base64-encoded payment signature. -type x402SignFunc func(paymentHeaderB64 string) (string, error) +type X402SignFunc func(paymentHeaderB64 string) (string, error) // โ”€โ”€ Shared x402 helpers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -// x402DecodeHeader decodes a base64-encoded x402 Payment-Required header, +// X402DecodeHeader decodes a base64-encoded x402 Payment-Required header, // trying RawStdEncoding first then StdEncoding as fallback. -func x402DecodeHeader(b64 string) ([]byte, error) { +func X402DecodeHeader(b64 string) ([]byte, error) { decoded, err := base64.RawStdEncoding.DecodeString(b64) if err != nil { decoded, err = base64.StdEncoding.DecodeString(b64) @@ -68,19 +70,19 @@ func x402DecodeHeader(b64 string) ([]byte, error) { return decoded, nil } -// signBasePaymentHeader decodes a base64 x402 header, parses it, and signs with +// SignBasePaymentHeader decodes a base64 x402 header, parses it, and signs with // EIP-712 (USDC TransferWithAuthorization). Shared by BlockRunBase and Claw402. -func signBasePaymentHeader(privateKey *ecdsa.PrivateKey, paymentHeaderB64 string, providerName string) (string, error) { +func SignBasePaymentHeader(privateKey *ecdsa.PrivateKey, paymentHeaderB64 string, providerName string) (string, error) { if privateKey == nil { return "", fmt.Errorf("no private key set for %s wallet", providerName) } - decoded, err := x402DecodeHeader(paymentHeaderB64) + decoded, err := X402DecodeHeader(paymentHeaderB64) if err != nil { return "", err } - var req x402v2PaymentRequired + var req X402v2PaymentRequired if err := json.Unmarshal(decoded, &req); err != nil { return "", fmt.Errorf("failed to parse x402 v2 payment header: %w", err) } @@ -89,19 +91,16 @@ func signBasePaymentHeader(privateKey *ecdsa.PrivateKey, paymentHeaderB64 string } senderAddr := crypto.PubkeyToAddress(privateKey.PublicKey).Hex() - return signX402Payment(privateKey, senderAddr, req.Accepts[0], req.Resource) + return SignX402Payment(privateKey, senderAddr, req.Accepts[0], req.Resource) } -// doX402Request executes an HTTP request and handles the x402 v2 payment flow. -// On a 402 response it reads the Payment-Required (or X-Payment-Required) header, -// signs via signFn, retries with Payment-Signature, and logs the Payment-Response -// header (tx hash) on success. -func doX402Request( +// DoX402Request executes an HTTP request and handles the x402 v2 payment flow. +func DoX402Request( httpClient *http.Client, buildReqFn func() (*http.Request, error), - signFn x402SignFunc, + signFn X402SignFunc, providerTag string, - logger Logger, + logger mcp.Logger, ) ([]byte, error) { req, err := buildReqFn() if err != nil { @@ -133,10 +132,9 @@ func doX402Request( } // Retry loop for 5xx errors on the payment-signed request. - // Reuses the same payment signature โ€” no double-charge. var lastBody []byte var lastStatus int - for attempt := 1; attempt <= x402MaxPaymentRetries; attempt++ { + for attempt := 1; attempt <= X402MaxPaymentRetries; attempt++ { req2, err := buildReqFn() if err != nil { return nil, fmt.Errorf("failed to build retry request: %w", err) @@ -146,10 +144,10 @@ func doX402Request( resp2, err := httpClient.Do(req2) if err != nil { - if attempt < x402MaxPaymentRetries { - wait := x402RetryBaseWait * time.Duration(attempt) + if attempt < X402MaxPaymentRetries { + wait := X402RetryBaseWait * time.Duration(attempt) logger.Warnf("โš ๏ธ [%s] Payment request failed: %v, retrying in %v (%d/%d)...", - providerTag, err, wait, attempt+1, x402MaxPaymentRetries) + providerTag, err, wait, attempt+1, X402MaxPaymentRetries) time.Sleep(wait) continue } @@ -175,11 +173,11 @@ func doX402Request( lastBody = body2 lastStatus = resp2.StatusCode - // Retry on 5xx server errors (502, 503, 520, etc.) - if resp2.StatusCode >= 500 && attempt < x402MaxPaymentRetries { - wait := x402RetryBaseWait * time.Duration(attempt) + // Retry on 5xx server errors + if resp2.StatusCode >= 500 && attempt < X402MaxPaymentRetries { + wait := X402RetryBaseWait * time.Duration(attempt) logger.Warnf("โš ๏ธ [%s] Server error (status %d), retrying in %v (%d/%d)...", - providerTag, resp2.StatusCode, wait, attempt+1, x402MaxPaymentRetries) + providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries) time.Sleep(wait) continue } @@ -201,8 +199,8 @@ func doX402Request( return body, nil } -// x402BuildRequest creates a POST request with Content-Type but no auth header. -func x402BuildRequest(url string, jsonData []byte) (*http.Request, error) { +// X402BuildRequest creates a POST request with Content-Type but no auth header. +func X402BuildRequest(url string, jsonData []byte) (*http.Request, error) { req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil { return nil, fmt.Errorf("fail to build request: %w", err) @@ -211,30 +209,30 @@ func x402BuildRequest(url string, jsonData []byte) (*http.Request, error) { return req, nil } -// x402SetAuthHeader is a no-op โ€” x402 providers authenticate via payment signing. -func x402SetAuthHeader(_ http.Header) {} +// X402SetAuthHeader is a no-op โ€” x402 providers authenticate via payment signing. +func X402SetAuthHeader(_ http.Header) {} -// x402Call handles the x402 payment flow for the simple CallWithMessages path. -func x402Call(c *Client, signFn x402SignFunc, tag string, systemPrompt, userPrompt string) (string, error) { - c.logger.Infof("๐Ÿ“ก [%s] Request AI Server: %s", tag, c.BaseURL) +// X402Call handles the x402 payment flow for the simple CallWithMessages path. +func X402Call(c *mcp.Client, signFn X402SignFunc, tag string, systemPrompt, userPrompt string) (string, error) { + c.Log.Infof("๐Ÿ“ก [%s] Request AI Server: %s", tag, c.BaseURL) - requestBody := c.hooks.buildMCPRequestBody(systemPrompt, userPrompt) - jsonData, err := c.hooks.marshalRequestBody(requestBody) + requestBody := c.Hooks.BuildMCPRequestBody(systemPrompt, userPrompt) + jsonData, err := c.Hooks.MarshalRequestBody(requestBody) if err != nil { return "", err } - body, err := doX402Request(c.httpClient, func() (*http.Request, error) { - return c.hooks.buildRequest(c.hooks.buildUrl(), jsonData) - }, signFn, tag, c.logger) + body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) { + return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData) + }, signFn, tag, c.Log) if err != nil { return "", err } - return c.hooks.parseMCPResponse(body) + return c.Hooks.ParseMCPResponse(body) } -// x402CallFull handles the x402 payment flow for the advanced Request path. -func x402CallFull(c *Client, signFn x402SignFunc, tag string, req *Request) (*LLMResponse, error) { +// X402CallFull handles the x402 payment flow for the advanced Request path. +func X402CallFull(c *mcp.Client, signFn X402SignFunc, tag string, req *mcp.Request) (*mcp.LLMResponse, error) { if c.APIKey == "" { return nil, fmt.Errorf("AI API key not set, please call SetAPIKey first") } @@ -242,19 +240,19 @@ func x402CallFull(c *Client, signFn x402SignFunc, tag string, req *Request) (*LL req.Model = c.Model } - c.logger.Infof("๐Ÿ“ก [%s] Request AI (full): %s", tag, c.BaseURL) + c.Log.Infof("๐Ÿ“ก [%s] Request AI (full): %s", tag, c.BaseURL) - requestBody := c.hooks.buildRequestBodyFromRequest(req) - jsonData, err := c.hooks.marshalRequestBody(requestBody) + requestBody := c.Hooks.BuildRequestBodyFromRequest(req) + jsonData, err := c.Hooks.MarshalRequestBody(requestBody) if err != nil { return nil, err } - body, err := doX402Request(c.httpClient, func() (*http.Request, error) { - return c.hooks.buildRequest(c.hooks.buildUrl(), jsonData) - }, signFn, tag, c.logger) + body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) { + return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData) + }, signFn, tag, c.Log) if err != nil { return nil, err } - return c.hooks.parseMCPResponseFull(body) + return c.Hooks.ParseMCPResponseFull(body) } diff --git a/mcp/claude_client.go b/mcp/provider/claude.go similarity index 72% rename from mcp/claude_client.go rename to mcp/provider/claude.go index d91bdb81..a9330687 100644 --- a/mcp/claude_client.go +++ b/mcp/provider/claude.go @@ -1,4 +1,4 @@ -// Package mcp โ€” ClaudeClient implements the Anthropic Messages API. +// Package provider โ€” ClaudeClient implements the Anthropic Messages API. // // Wire-format differences from the OpenAI-compatible base Client: // @@ -14,42 +14,51 @@ // โ”‚ Tool result โ”‚ role=tool + tool_call_id โ”‚ role=user content[tool_result] โ”‚ // โ”‚ Max tokens โ”‚ max_tokens โ”‚ max_tokens (same) โ”‚ // โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ -package mcp +package provider import ( "encoding/json" "fmt" "net/http" + + "nofx/mcp" ) const ( - ProviderClaude = "claude" DefaultClaudeBaseURL = "https://api.anthropic.com/v1" DefaultClaudeModel = "claude-opus-4-6" ) +func init() { + mcp.RegisterProvider(mcp.ProviderClaude, func(opts ...mcp.ClientOption) mcp.AIClient { + return NewClaudeClientWithOptions(opts...) + }) +} + // ClaudeClient wraps the base Client and overrides the methods that differ // for the Anthropic Messages API. All other behaviour (retry, timeout, // logging) is inherited unchanged. type ClaudeClient struct { - *Client + *mcp.Client } +func (c *ClaudeClient) BaseClient() *mcp.Client { return c.Client } + // NewClaudeClient creates a ClaudeClient with default settings. -func NewClaudeClient() AIClient { +func NewClaudeClient() mcp.AIClient { return NewClaudeClientWithOptions() } // NewClaudeClientWithOptions creates a ClaudeClient with optional overrides. -func NewClaudeClientWithOptions(opts ...ClientOption) AIClient { - baseClient := NewClient(append([]ClientOption{ - WithProvider(ProviderClaude), - WithModel(DefaultClaudeModel), - WithBaseURL(DefaultClaudeBaseURL), - }, opts...)...).(*Client) +func NewClaudeClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient { + baseClient := mcp.NewClient(append([]mcp.ClientOption{ + mcp.WithProvider(mcp.ProviderClaude), + mcp.WithModel(DefaultClaudeModel), + mcp.WithBaseURL(DefaultClaudeBaseURL), + }, opts...)...).(*mcp.Client) c := &ClaudeClient{Client: baseClient} - baseClient.hooks = c // wire dynamic dispatch to ClaudeClient + baseClient.Hooks = c // wire dynamic dispatch to ClaudeClient return c } @@ -59,32 +68,32 @@ func NewClaudeClientWithOptions(opts ...ClientOption) AIClient { func (c *ClaudeClient) SetAPIKey(apiKey, customURL, customModel string) { c.APIKey = apiKey if len(apiKey) > 8 { - c.logger.Infof("๐Ÿ”ง [MCP] Claude API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) + c.Log.Infof("๐Ÿ”ง [MCP] Claude API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) } if customURL != "" { c.BaseURL = customURL - c.logger.Infof("๐Ÿ”ง [MCP] Claude BaseURL: %s", customURL) + c.Log.Infof("๐Ÿ”ง [MCP] Claude BaseURL: %s", customURL) } if customModel != "" { c.Model = customModel - c.logger.Infof("๐Ÿ”ง [MCP] Claude Model: %s", customModel) + c.Log.Infof("๐Ÿ”ง [MCP] Claude Model: %s", customModel) } } -// setAuthHeader uses x-api-key instead of Authorization: Bearer. -func (c *ClaudeClient) setAuthHeader(h http.Header) { +// SetAuthHeader uses x-api-key instead of Authorization: Bearer. +func (c *ClaudeClient) SetAuthHeader(h http.Header) { h.Set("x-api-key", c.APIKey) h.Set("anthropic-version", "2023-06-01") } -// buildUrl targets /messages instead of /chat/completions. -func (c *ClaudeClient) buildUrl() string { +// BuildUrl targets /messages instead of /chat/completions. +func (c *ClaudeClient) BuildUrl() string { return fmt.Sprintf("%s/messages", c.BaseURL) } -// buildMCPRequestBody builds the Anthropic wire format for the simple +// BuildMCPRequestBody builds the Anthropic wire format for the simple // CallWithMessages path (no tool support). -func (c *ClaudeClient) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any { +func (c *ClaudeClient) BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any { return map[string]any{ "model": c.Model, "max_tokens": c.MaxTokens, @@ -95,23 +104,12 @@ func (c *ClaudeClient) buildMCPRequestBody(systemPrompt, userPrompt string) map[ } } -// buildRequestBodyFromRequest converts a *Request into the Anthropic Messages -// API wire format. This is the key override that makes tool calling work -// correctly with Claude. -// -// Conversions applied: -// -// - System messages are lifted to the top-level "system" field. -// - Tool definitions: parameters โ†’ input_schema, wrapper removed. -// - Assistant messages with ToolCalls โ†’ content[{type:tool_use,...}]. -// - Tool result messages (role=tool) โ†’ role=user with tool_result blocks. -// Consecutive tool results are merged into a single user turn (Anthropic -// requires strictly alternating user/assistant turns). -// - tool_choice "auto"/"any" โ†’ {"type":"auto"/"any"} object. -func (c *ClaudeClient) buildRequestBodyFromRequest(req *Request) map[string]any { +// BuildRequestBodyFromRequest converts a *Request into the Anthropic Messages +// API wire format. +func (c *ClaudeClient) BuildRequestBodyFromRequest(req *mcp.Request) map[string]any { // โ”€โ”€ 1. Separate system prompt from conversation messages โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ var systemPrompt string - var convMsgs []Message + var convMsgs []mcp.Message for _, m := range req.Messages { if m.Role == "system" { systemPrompt = m.Content @@ -121,7 +119,7 @@ func (c *ClaudeClient) buildRequestBodyFromRequest(req *Request) map[string]any } // โ”€โ”€ 2. Convert messages to Anthropic format โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - anthropicMsgs := convertMessagesToAnthropic(convMsgs) + anthropicMsgs := ConvertMessagesToAnthropic(convMsgs) // โ”€โ”€ 3. Convert tool definitions (parameters โ†’ input_schema) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ var anthropicTools []map[string]any @@ -162,16 +160,9 @@ func (c *ClaudeClient) buildRequestBodyFromRequest(req *Request) map[string]any return body } -// convertMessagesToAnthropic translates from the OpenAI-shaped mcp.Message +// ConvertMessagesToAnthropic translates from the OpenAI-shaped mcp.Message // slice to Anthropic's messages array. -// -// Rules: -// 1. role=assistant + ToolCalls โ†’ role=assistant, content=[tool_use, ...] -// 2. role=tool (result) โ†’ role=user, content=[tool_result, ...] -// Consecutive tool-result messages are merged into one user turn so the -// conversation always alternates user/assistant. -// 3. All other messages โ†’ {role, content} as-is. -func convertMessagesToAnthropic(msgs []Message) []map[string]any { +func ConvertMessagesToAnthropic(msgs []mcp.Message) []map[string]any { var out []map[string]any for i := 0; i < len(msgs); { @@ -232,29 +223,18 @@ func convertMessagesToAnthropic(msgs []Message) []map[string]any { // โ”€โ”€ Response parsers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -// parseMCPResponse extracts the plain-text reply from an Anthropic response. -// Used by CallWithMessages / CallWithRequest (no tool support). -func (c *ClaudeClient) parseMCPResponse(body []byte) (string, error) { - r, err := c.parseMCPResponseFull(body) +// ParseMCPResponse extracts the plain-text reply from an Anthropic response. +func (c *ClaudeClient) ParseMCPResponse(body []byte) (string, error) { + r, err := c.ParseMCPResponseFull(body) if err != nil { return "", err } return r.Content, nil } -// parseMCPResponseFull extracts both text and tool calls from an Anthropic +// ParseMCPResponseFull extracts both text and tool calls from an Anthropic // response envelope. -// -// Anthropic response shape: -// -// { -// "content": [ -// {"type": "text", "text": "..."}, -// {"type": "tool_use", "id": "...", "name": "...", "input": {...}} -// ], -// "stop_reason": "tool_use" | "end_turn" -// } -func (c *ClaudeClient) parseMCPResponseFull(body []byte) (*LLMResponse, error) { +func (c *ClaudeClient) ParseMCPResponseFull(body []byte) (*mcp.LLMResponse, error) { var raw struct { Content []struct { Type string `json:"type"` @@ -281,8 +261,8 @@ func (c *ClaudeClient) parseMCPResponseFull(body []byte) (*LLMResponse, error) { } total := raw.Usage.InputTokens + raw.Usage.OutputTokens - if TokenUsageCallback != nil && total > 0 { - TokenUsageCallback(TokenUsage{ + if mcp.TokenUsageCallback != nil && total > 0 { + mcp.TokenUsageCallback(mcp.TokenUsage{ Provider: c.Provider, Model: c.Model, PromptTokens: raw.Usage.InputTokens, @@ -291,7 +271,7 @@ func (c *ClaudeClient) parseMCPResponseFull(body []byte) (*LLMResponse, error) { }) } - result := &LLMResponse{} + result := &mcp.LLMResponse{} for _, block := range raw.Content { switch block.Type { case "text": @@ -304,10 +284,10 @@ func (c *ClaudeClient) parseMCPResponseFull(body []byte) (*LLMResponse, error) { if err != nil { argsJSON = []byte("{}") } - result.ToolCalls = append(result.ToolCalls, ToolCall{ + result.ToolCalls = append(result.ToolCalls, mcp.ToolCall{ ID: block.ID, Type: "function", - Function: ToolCallFunction{ + Function: mcp.ToolCallFunction{ Name: block.Name, Arguments: string(argsJSON), }, diff --git a/mcp/provider/deepseek.go b/mcp/provider/deepseek.go new file mode 100644 index 00000000..638db4f7 --- /dev/null +++ b/mcp/provider/deepseek.go @@ -0,0 +1,69 @@ +package provider + +import ( + "net/http" + + "nofx/mcp" +) + +func init() { + mcp.RegisterProvider(mcp.ProviderDeepSeek, func(opts ...mcp.ClientOption) mcp.AIClient { + return NewDeepSeekClientWithOptions(opts...) + }) +} + +type DeepSeekClient struct { + *mcp.Client +} + +func (c *DeepSeekClient) BaseClient() *mcp.Client { return c.Client } + +// NewDeepSeekClient creates DeepSeek client (backward compatible) +// +// Deprecated: Recommend using NewDeepSeekClientWithOptions for better flexibility +func NewDeepSeekClient() mcp.AIClient { + return NewDeepSeekClientWithOptions() +} + +// NewDeepSeekClientWithOptions creates DeepSeek client (supports options pattern) +func NewDeepSeekClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient { + deepseekOpts := []mcp.ClientOption{ + mcp.WithProvider(mcp.ProviderDeepSeek), + mcp.WithModel(mcp.DefaultDeepSeekModel), + mcp.WithBaseURL(mcp.DefaultDeepSeekBaseURL), + } + + allOpts := append(deepseekOpts, opts...) + baseClient := mcp.NewClient(allOpts...).(*mcp.Client) + + dsClient := &DeepSeekClient{ + Client: baseClient, + } + + baseClient.Hooks = dsClient + return dsClient +} + +func (dsClient *DeepSeekClient) SetAPIKey(apiKey string, customURL string, customModel string) { + dsClient.APIKey = apiKey + + if len(apiKey) > 8 { + dsClient.Log.Infof("๐Ÿ”ง [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) + } + if customURL != "" { + dsClient.BaseURL = customURL + dsClient.Log.Infof("๐Ÿ”ง [MCP] DeepSeek using custom BaseURL: %s", customURL) + } else { + dsClient.Log.Infof("๐Ÿ”ง [MCP] DeepSeek using default BaseURL: %s", dsClient.BaseURL) + } + if customModel != "" { + dsClient.Model = customModel + dsClient.Log.Infof("๐Ÿ”ง [MCP] DeepSeek using custom Model: %s", customModel) + } else { + dsClient.Log.Infof("๐Ÿ”ง [MCP] DeepSeek using default Model: %s", dsClient.Model) + } +} + +func (dsClient *DeepSeekClient) SetAuthHeader(reqHeaders http.Header) { + dsClient.Client.SetAuthHeader(reqHeaders) +} diff --git a/mcp/provider/gemini.go b/mcp/provider/gemini.go new file mode 100644 index 00000000..c90ec3ed --- /dev/null +++ b/mcp/provider/gemini.go @@ -0,0 +1,73 @@ +package provider + +import ( + "net/http" + + "nofx/mcp" +) + +const ( + DefaultGeminiBaseURL = "https://generativelanguage.googleapis.com/v1beta/openai" + DefaultGeminiModel = "gemini-3-pro-preview" +) + +func init() { + mcp.RegisterProvider(mcp.ProviderGemini, func(opts ...mcp.ClientOption) mcp.AIClient { + return NewGeminiClientWithOptions(opts...) + }) +} + +type GeminiClient struct { + *mcp.Client +} + +func (c *GeminiClient) BaseClient() *mcp.Client { return c.Client } + +// NewGeminiClient creates Gemini client (backward compatible) +func NewGeminiClient() mcp.AIClient { + return NewGeminiClientWithOptions() +} + +// NewGeminiClientWithOptions creates Gemini client (supports options pattern) +func NewGeminiClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient { + geminiOpts := []mcp.ClientOption{ + mcp.WithProvider(mcp.ProviderGemini), + mcp.WithModel(DefaultGeminiModel), + mcp.WithBaseURL(DefaultGeminiBaseURL), + } + + allOpts := append(geminiOpts, opts...) + baseClient := mcp.NewClient(allOpts...).(*mcp.Client) + + geminiClient := &GeminiClient{ + Client: baseClient, + } + + baseClient.Hooks = geminiClient + return geminiClient +} + +func (c *GeminiClient) SetAPIKey(apiKey string, customURL string, customModel string) { + c.APIKey = apiKey + + if len(apiKey) > 8 { + c.Log.Infof("๐Ÿ”ง [MCP] Gemini API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) + } + if customURL != "" { + c.BaseURL = customURL + c.Log.Infof("๐Ÿ”ง [MCP] Gemini using custom BaseURL: %s", customURL) + } else { + c.Log.Infof("๐Ÿ”ง [MCP] Gemini using default BaseURL: %s", c.BaseURL) + } + if customModel != "" { + c.Model = customModel + c.Log.Infof("๐Ÿ”ง [MCP] Gemini using custom Model: %s", customModel) + } else { + c.Log.Infof("๐Ÿ”ง [MCP] Gemini using default Model: %s", c.Model) + } +} + +// Gemini OpenAI-compatible API uses standard Bearer auth +func (c *GeminiClient) SetAuthHeader(reqHeaders http.Header) { + c.Client.SetAuthHeader(reqHeaders) +} diff --git a/mcp/provider/grok.go b/mcp/provider/grok.go new file mode 100644 index 00000000..6eb885fc --- /dev/null +++ b/mcp/provider/grok.go @@ -0,0 +1,73 @@ +package provider + +import ( + "net/http" + + "nofx/mcp" +) + +const ( + DefaultGrokBaseURL = "https://api.x.ai/v1" + DefaultGrokModel = "grok-3-latest" +) + +func init() { + mcp.RegisterProvider(mcp.ProviderGrok, func(opts ...mcp.ClientOption) mcp.AIClient { + return NewGrokClientWithOptions(opts...) + }) +} + +type GrokClient struct { + *mcp.Client +} + +func (c *GrokClient) BaseClient() *mcp.Client { return c.Client } + +// NewGrokClient creates Grok client (backward compatible) +func NewGrokClient() mcp.AIClient { + return NewGrokClientWithOptions() +} + +// NewGrokClientWithOptions creates Grok client (supports options pattern) +func NewGrokClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient { + grokOpts := []mcp.ClientOption{ + mcp.WithProvider(mcp.ProviderGrok), + mcp.WithModel(DefaultGrokModel), + mcp.WithBaseURL(DefaultGrokBaseURL), + } + + allOpts := append(grokOpts, opts...) + baseClient := mcp.NewClient(allOpts...).(*mcp.Client) + + grokClient := &GrokClient{ + Client: baseClient, + } + + baseClient.Hooks = grokClient + return grokClient +} + +func (c *GrokClient) SetAPIKey(apiKey string, customURL string, customModel string) { + c.APIKey = apiKey + + if len(apiKey) > 8 { + c.Log.Infof("๐Ÿ”ง [MCP] Grok API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) + } + if customURL != "" { + c.BaseURL = customURL + c.Log.Infof("๐Ÿ”ง [MCP] Grok using custom BaseURL: %s", customURL) + } else { + c.Log.Infof("๐Ÿ”ง [MCP] Grok using default BaseURL: %s", c.BaseURL) + } + if customModel != "" { + c.Model = customModel + c.Log.Infof("๐Ÿ”ง [MCP] Grok using custom Model: %s", customModel) + } else { + c.Log.Infof("๐Ÿ”ง [MCP] Grok using default Model: %s", c.Model) + } +} + +// Grok uses standard OpenAI-compatible API with Bearer auth +func (c *GrokClient) SetAuthHeader(reqHeaders http.Header) { + c.Client.SetAuthHeader(reqHeaders) +} diff --git a/mcp/provider/kimi.go b/mcp/provider/kimi.go new file mode 100644 index 00000000..e4301c31 --- /dev/null +++ b/mcp/provider/kimi.go @@ -0,0 +1,73 @@ +package provider + +import ( + "net/http" + + "nofx/mcp" +) + +const ( + DefaultKimiBaseURL = "https://api.moonshot.ai/v1" // Global endpoint (use api.moonshot.cn for China) + DefaultKimiModel = "moonshot-v1-auto" +) + +func init() { + mcp.RegisterProvider(mcp.ProviderKimi, func(opts ...mcp.ClientOption) mcp.AIClient { + return NewKimiClientWithOptions(opts...) + }) +} + +type KimiClient struct { + *mcp.Client +} + +func (c *KimiClient) BaseClient() *mcp.Client { return c.Client } + +// NewKimiClient creates Kimi (Moonshot) client (backward compatible) +func NewKimiClient() mcp.AIClient { + return NewKimiClientWithOptions() +} + +// NewKimiClientWithOptions creates Kimi client (supports options pattern) +func NewKimiClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient { + kimiOpts := []mcp.ClientOption{ + mcp.WithProvider(mcp.ProviderKimi), + mcp.WithModel(DefaultKimiModel), + mcp.WithBaseURL(DefaultKimiBaseURL), + } + + allOpts := append(kimiOpts, opts...) + baseClient := mcp.NewClient(allOpts...).(*mcp.Client) + + kimiClient := &KimiClient{ + Client: baseClient, + } + + baseClient.Hooks = kimiClient + return kimiClient +} + +func (c *KimiClient) SetAPIKey(apiKey string, customURL string, customModel string) { + c.APIKey = apiKey + + if len(apiKey) > 8 { + c.Log.Infof("๐Ÿ”ง [MCP] Kimi API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) + } + if customURL != "" { + c.BaseURL = customURL + c.Log.Infof("๐Ÿ”ง [MCP] Kimi using custom BaseURL: %s", customURL) + } else { + c.Log.Infof("๐Ÿ”ง [MCP] Kimi using default BaseURL: %s", c.BaseURL) + } + if customModel != "" { + c.Model = customModel + c.Log.Infof("๐Ÿ”ง [MCP] Kimi using custom Model: %s", customModel) + } else { + c.Log.Infof("๐Ÿ”ง [MCP] Kimi using default Model: %s", c.Model) + } +} + +// Kimi uses standard OpenAI-compatible API +func (c *KimiClient) SetAuthHeader(reqHeaders http.Header) { + c.Client.SetAuthHeader(reqHeaders) +} diff --git a/mcp/provider/minimax.go b/mcp/provider/minimax.go new file mode 100644 index 00000000..23e66686 --- /dev/null +++ b/mcp/provider/minimax.go @@ -0,0 +1,73 @@ +package provider + +import ( + "net/http" + + "nofx/mcp" +) + +const ( + DefaultMiniMaxBaseURL = "https://api.minimax.io/v1" + DefaultMiniMaxModel = "MiniMax-M2.5" +) + +func init() { + mcp.RegisterProvider(mcp.ProviderMiniMax, func(opts ...mcp.ClientOption) mcp.AIClient { + return NewMiniMaxClientWithOptions(opts...) + }) +} + +type MiniMaxClient struct { + *mcp.Client +} + +func (c *MiniMaxClient) BaseClient() *mcp.Client { return c.Client } + +// NewMiniMaxClient creates MiniMax client (backward compatible) +func NewMiniMaxClient() mcp.AIClient { + return NewMiniMaxClientWithOptions() +} + +// NewMiniMaxClientWithOptions creates MiniMax client (supports options pattern) +func NewMiniMaxClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient { + minimaxOpts := []mcp.ClientOption{ + mcp.WithProvider(mcp.ProviderMiniMax), + mcp.WithModel(DefaultMiniMaxModel), + mcp.WithBaseURL(DefaultMiniMaxBaseURL), + } + + allOpts := append(minimaxOpts, opts...) + baseClient := mcp.NewClient(allOpts...).(*mcp.Client) + + minimaxClient := &MiniMaxClient{ + Client: baseClient, + } + + baseClient.Hooks = minimaxClient + return minimaxClient +} + +func (c *MiniMaxClient) SetAPIKey(apiKey string, customURL string, customModel string) { + c.APIKey = apiKey + + if len(apiKey) > 8 { + c.Log.Infof("๐Ÿ”ง [MCP] MiniMax API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) + } + if customURL != "" { + c.BaseURL = customURL + c.Log.Infof("๐Ÿ”ง [MCP] MiniMax using custom BaseURL: %s", customURL) + } else { + c.Log.Infof("๐Ÿ”ง [MCP] MiniMax using default BaseURL: %s", c.BaseURL) + } + if customModel != "" { + c.Model = customModel + c.Log.Infof("๐Ÿ”ง [MCP] MiniMax using custom Model: %s", customModel) + } else { + c.Log.Infof("๐Ÿ”ง [MCP] MiniMax using default Model: %s", c.Model) + } +} + +// MiniMax uses standard OpenAI-compatible API with Bearer auth +func (c *MiniMaxClient) SetAuthHeader(reqHeaders http.Header) { + c.Client.SetAuthHeader(reqHeaders) +} diff --git a/mcp/provider/openai.go b/mcp/provider/openai.go new file mode 100644 index 00000000..4cd13c37 --- /dev/null +++ b/mcp/provider/openai.go @@ -0,0 +1,73 @@ +package provider + +import ( + "net/http" + + "nofx/mcp" +) + +const ( + DefaultOpenAIBaseURL = "https://api.openai.com/v1" + DefaultOpenAIModel = "gpt-5.4" +) + +func init() { + mcp.RegisterProvider(mcp.ProviderOpenAI, func(opts ...mcp.ClientOption) mcp.AIClient { + return NewOpenAIClientWithOptions(opts...) + }) +} + +type OpenAIClient struct { + *mcp.Client +} + +func (c *OpenAIClient) BaseClient() *mcp.Client { return c.Client } + +// NewOpenAIClient creates OpenAI client (backward compatible) +func NewOpenAIClient() mcp.AIClient { + return NewOpenAIClientWithOptions() +} + +// NewOpenAIClientWithOptions creates OpenAI client (supports options pattern) +func NewOpenAIClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient { + openaiOpts := []mcp.ClientOption{ + mcp.WithProvider(mcp.ProviderOpenAI), + mcp.WithModel(DefaultOpenAIModel), + mcp.WithBaseURL(DefaultOpenAIBaseURL), + } + + allOpts := append(openaiOpts, opts...) + baseClient := mcp.NewClient(allOpts...).(*mcp.Client) + + openaiClient := &OpenAIClient{ + Client: baseClient, + } + + baseClient.Hooks = openaiClient + return openaiClient +} + +func (c *OpenAIClient) SetAPIKey(apiKey string, customURL string, customModel string) { + c.APIKey = apiKey + + if len(apiKey) > 8 { + c.Log.Infof("๐Ÿ”ง [MCP] OpenAI API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) + } + if customURL != "" { + c.BaseURL = customURL + c.Log.Infof("๐Ÿ”ง [MCP] OpenAI using custom BaseURL: %s", customURL) + } else { + c.Log.Infof("๐Ÿ”ง [MCP] OpenAI using default BaseURL: %s", c.BaseURL) + } + if customModel != "" { + c.Model = customModel + c.Log.Infof("๐Ÿ”ง [MCP] OpenAI using custom Model: %s", customModel) + } else { + c.Log.Infof("๐Ÿ”ง [MCP] OpenAI using default Model: %s", c.Model) + } +} + +// OpenAI uses standard Bearer auth +func (c *OpenAIClient) SetAuthHeader(reqHeaders http.Header) { + c.Client.SetAuthHeader(reqHeaders) +} diff --git a/mcp/provider/options_test.go b/mcp/provider/options_test.go new file mode 100644 index 00000000..cb03f9cf --- /dev/null +++ b/mcp/provider/options_test.go @@ -0,0 +1,83 @@ +package provider + +import ( + "testing" + + "nofx/mcp" +) + +func TestOptionsWithDeepSeekClient(t *testing.T) { + logger := mcp.NewNoopLogger() + + client := NewDeepSeekClientWithOptions( + mcp.WithAPIKey("sk-deepseek-key"), + mcp.WithLogger(logger), + mcp.WithMaxTokens(5000), + ) + + dsClient := client.(*DeepSeekClient) + + // Verify DeepSeek default values + if dsClient.Provider != mcp.ProviderDeepSeek { + t.Error("Provider should be DeepSeek") + } + + if dsClient.BaseURL != mcp.DefaultDeepSeekBaseURL { + t.Error("BaseURL should be DeepSeek default") + } + + if dsClient.Model != mcp.DefaultDeepSeekModel { + t.Error("Model should be DeepSeek default") + } + + // Verify custom options + if dsClient.APIKey != "sk-deepseek-key" { + t.Error("APIKey should be set from options") + } + + if dsClient.Log != logger { + t.Error("Log should be set from options") + } + + if dsClient.MaxTokens != 5000 { + t.Error("MaxTokens should be 5000") + } +} + +func TestOptionsWithQwenClient(t *testing.T) { + logger := mcp.NewNoopLogger() + + client := NewQwenClientWithOptions( + mcp.WithAPIKey("sk-qwen-key"), + mcp.WithLogger(logger), + mcp.WithMaxTokens(6000), + ) + + qwenClient := client.(*QwenClient) + + // Verify Qwen default values + if qwenClient.Provider != mcp.ProviderQwen { + t.Error("Provider should be Qwen") + } + + if qwenClient.BaseURL != mcp.DefaultQwenBaseURL { + t.Error("BaseURL should be Qwen default") + } + + if qwenClient.Model != mcp.DefaultQwenModel { + t.Error("Model should be Qwen default") + } + + // Verify custom options + if qwenClient.APIKey != "sk-qwen-key" { + t.Error("APIKey should be set from options") + } + + if qwenClient.Log != logger { + t.Error("Log should be set from options") + } + + if qwenClient.MaxTokens != 6000 { + t.Error("MaxTokens should be 6000") + } +} diff --git a/mcp/provider/qwen.go b/mcp/provider/qwen.go new file mode 100644 index 00000000..40968ee4 --- /dev/null +++ b/mcp/provider/qwen.go @@ -0,0 +1,74 @@ +package provider + +import ( + "net/http" + + "nofx/mcp" +) + +const ( + DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1" + DefaultQwenModel = "qwen3-max" +) + +func init() { + mcp.RegisterProvider(mcp.ProviderQwen, func(opts ...mcp.ClientOption) mcp.AIClient { + return NewQwenClientWithOptions(opts...) + }) +} + +type QwenClient struct { + *mcp.Client +} + +func (c *QwenClient) BaseClient() *mcp.Client { return c.Client } + +// NewQwenClient creates Qwen client (backward compatible) +// +// Deprecated: Recommend using NewQwenClientWithOptions for better flexibility +func NewQwenClient() mcp.AIClient { + return NewQwenClientWithOptions() +} + +// NewQwenClientWithOptions creates Qwen client (supports options pattern) +func NewQwenClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient { + qwenOpts := []mcp.ClientOption{ + mcp.WithProvider(mcp.ProviderQwen), + mcp.WithModel(DefaultQwenModel), + mcp.WithBaseURL(DefaultQwenBaseURL), + } + + allOpts := append(qwenOpts, opts...) + baseClient := mcp.NewClient(allOpts...).(*mcp.Client) + + qwenClient := &QwenClient{ + Client: baseClient, + } + + baseClient.Hooks = qwenClient + return qwenClient +} + +func (qwenClient *QwenClient) SetAPIKey(apiKey string, customURL string, customModel string) { + qwenClient.APIKey = apiKey + + if len(apiKey) > 8 { + qwenClient.Log.Infof("๐Ÿ”ง [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) + } + if customURL != "" { + qwenClient.BaseURL = customURL + qwenClient.Log.Infof("๐Ÿ”ง [MCP] Qwen using custom BaseURL: %s", customURL) + } else { + qwenClient.Log.Infof("๐Ÿ”ง [MCP] Qwen using default BaseURL: %s", qwenClient.BaseURL) + } + if customModel != "" { + qwenClient.Model = customModel + qwenClient.Log.Infof("๐Ÿ”ง [MCP] Qwen using custom Model: %s", customModel) + } else { + qwenClient.Log.Infof("๐Ÿ”ง [MCP] Qwen using default Model: %s", qwenClient.Model) + } +} + +func (qwenClient *QwenClient) SetAuthHeader(reqHeaders http.Header) { + qwenClient.Client.SetAuthHeader(reqHeaders) +} diff --git a/mcp/providers.go b/mcp/providers.go new file mode 100644 index 00000000..ad7f2eed --- /dev/null +++ b/mcp/providers.go @@ -0,0 +1,31 @@ +package mcp + +// Provider name constants โ€” kept in the mcp package so that client.go can +// reference them for default configuration without importing sub-packages. +// Provider sub-packages re-use these same values. +const ( + ProviderDeepSeek = "deepseek" + ProviderOpenAI = "openai" + ProviderClaude = "claude" + ProviderQwen = "qwen" + ProviderGemini = "gemini" + ProviderGrok = "grok" + ProviderKimi = "kimi" + ProviderMiniMax = "minimax" + + ProviderBlockRunBase = "blockrun-base" + ProviderBlockRunSol = "blockrun-sol" + ProviderClaw402 = "claw402" + + // Default DeepSeek configuration (used as fallback in NewClient) + DefaultDeepSeekBaseURL = "https://api.deepseek.com" + DefaultDeepSeekModel = "deepseek-chat" + + // Default Qwen configuration (used by WithQwenConfig convenience option) + DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1" + DefaultQwenModel = "qwen3-max" + + // Default MiniMax configuration (used by WithMiniMaxConfig convenience option) + DefaultMiniMaxBaseURL = "https://api.minimax.io/v1" + DefaultMiniMaxModel = "MiniMax-M2.5" +) diff --git a/mcp/qwen_client.go b/mcp/qwen_client.go deleted file mode 100644 index 4c1b6ae3..00000000 --- a/mcp/qwen_client.go +++ /dev/null @@ -1,83 +0,0 @@ -package mcp - -import ( - "net/http" -) - -const ( - ProviderQwen = "qwen" - DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1" - DefaultQwenModel = "qwen3-max" -) - -type QwenClient struct { - *Client -} - -// NewQwenClient creates Qwen client (backward compatible) -// -// Deprecated: Recommend using NewQwenClientWithOptions for better flexibility -func NewQwenClient() AIClient { - return NewQwenClientWithOptions() -} - -// NewQwenClientWithOptions creates Qwen client (supports options pattern) -// -// Usage examples: -// // Basic usage -// client := mcp.NewQwenClientWithOptions() -// -// // Custom configuration -// client := mcp.NewQwenClientWithOptions( -// mcp.WithAPIKey("sk-xxx"), -// mcp.WithLogger(customLogger), -// mcp.WithTimeout(60*time.Second), -// ) -func NewQwenClientWithOptions(opts ...ClientOption) AIClient { - // 1. Create Qwen preset options - qwenOpts := []ClientOption{ - WithProvider(ProviderQwen), - WithModel(DefaultQwenModel), - WithBaseURL(DefaultQwenBaseURL), - } - - // 2. Merge user options (user options have higher priority) - allOpts := append(qwenOpts, opts...) - - // 3. Create base client - baseClient := NewClient(allOpts...).(*Client) - - // 4. Create Qwen client - qwenClient := &QwenClient{ - Client: baseClient, - } - - // 5. Set hooks to point to QwenClient (implement dynamic dispatch) - baseClient.hooks = qwenClient - - return qwenClient -} - -func (qwenClient *QwenClient) SetAPIKey(apiKey string, customURL string, customModel string) { - qwenClient.APIKey = apiKey - - if len(apiKey) > 8 { - qwenClient.logger.Infof("๐Ÿ”ง [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:]) - } - if customURL != "" { - qwenClient.BaseURL = customURL - qwenClient.logger.Infof("๐Ÿ”ง [MCP] Qwen using custom BaseURL: %s", customURL) - } else { - qwenClient.logger.Infof("๐Ÿ”ง [MCP] Qwen using default BaseURL: %s", qwenClient.BaseURL) - } - if customModel != "" { - qwenClient.Model = customModel - qwenClient.logger.Infof("๐Ÿ”ง [MCP] Qwen using custom Model: %s", customModel) - } else { - qwenClient.logger.Infof("๐Ÿ”ง [MCP] Qwen using default Model: %s", qwenClient.Model) - } -} - -func (qwenClient *QwenClient) setAuthHeader(reqHeaders http.Header) { - qwenClient.Client.setAuthHeader(reqHeaders) -} diff --git a/mcp/qwen_client_test.go b/mcp/qwen_client_test.go deleted file mode 100644 index 90149fc7..00000000 --- a/mcp/qwen_client_test.go +++ /dev/null @@ -1,272 +0,0 @@ -package mcp - -import ( - "testing" - "time" -) - -// ============================================================ -// Test QwenClient Creation and Configuration -// ============================================================ - -func TestNewQwenClient_Default(t *testing.T) { - client := NewQwenClient() - - if client == nil { - t.Fatal("client should not be nil") - } - - // Type assertion check - qwenClient, ok := client.(*QwenClient) - if !ok { - t.Fatal("client should be *QwenClient") - } - - // Verify default values - if qwenClient.Provider != ProviderQwen { - t.Errorf("Provider should be '%s', got '%s'", ProviderQwen, qwenClient.Provider) - } - - if qwenClient.BaseURL != DefaultQwenBaseURL { - t.Errorf("BaseURL should be '%s', got '%s'", DefaultQwenBaseURL, qwenClient.BaseURL) - } - - if qwenClient.Model != DefaultQwenModel { - t.Errorf("Model should be '%s', got '%s'", DefaultQwenModel, qwenClient.Model) - } - - if qwenClient.logger == nil { - t.Error("logger should not be nil") - } - - if qwenClient.httpClient == nil { - t.Error("httpClient should not be nil") - } -} - -func TestNewQwenClientWithOptions(t *testing.T) { - mockLogger := NewMockLogger() - customModel := "qwen-plus" - customAPIKey := "sk-custom-qwen-key" - - client := NewQwenClientWithOptions( - WithLogger(mockLogger), - WithModel(customModel), - WithAPIKey(customAPIKey), - WithMaxTokens(4000), - ) - - qwenClient := client.(*QwenClient) - - // Verify custom options are applied - if qwenClient.logger != mockLogger { - t.Error("logger should be set from option") - } - - if qwenClient.Model != customModel { - t.Error("Model should be set from option") - } - - if qwenClient.APIKey != customAPIKey { - t.Error("APIKey should be set from option") - } - - if qwenClient.MaxTokens != 4000 { - t.Error("MaxTokens should be 4000") - } - - // Verify Qwen default values are retained - if qwenClient.Provider != ProviderQwen { - t.Errorf("Provider should still be '%s'", ProviderQwen) - } - - if qwenClient.BaseURL != DefaultQwenBaseURL { - t.Errorf("BaseURL should still be '%s'", DefaultQwenBaseURL) - } -} - -// ============================================================ -// Test SetAPIKey -// ============================================================ - -func TestQwenClient_SetAPIKey(t *testing.T) { - mockLogger := NewMockLogger() - client := NewQwenClientWithOptions( - WithLogger(mockLogger), - ) - - qwenClient := client.(*QwenClient) - - // Test setting API Key (default URL and Model) - qwenClient.SetAPIKey("sk-test-key-12345678", "", "") - - if qwenClient.APIKey != "sk-test-key-12345678" { - t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", qwenClient.APIKey) - } - - // Verify logging - logs := mockLogger.GetLogsByLevel("INFO") - if len(logs) == 0 { - t.Error("should have logged API key setting") - } - - // Verify BaseURL and Model remain default - if qwenClient.BaseURL != DefaultQwenBaseURL { - t.Error("BaseURL should remain default") - } - - if qwenClient.Model != DefaultQwenModel { - t.Error("Model should remain default") - } -} - -func TestQwenClient_SetAPIKey_WithCustomURL(t *testing.T) { - mockLogger := NewMockLogger() - client := NewQwenClientWithOptions( - WithLogger(mockLogger), - ) - - qwenClient := client.(*QwenClient) - - customURL := "https://custom.qwen.api.com/v1" - qwenClient.SetAPIKey("sk-test-key-12345678", customURL, "") - - if qwenClient.BaseURL != customURL { - t.Errorf("BaseURL should be '%s', got '%s'", customURL, qwenClient.BaseURL) - } - - // Verify logging - logs := mockLogger.GetLogsByLevel("INFO") - hasCustomURLLog := false - for _, log := range logs { - if log.Format == "๐Ÿ”ง [MCP] Qwen using custom BaseURL: %s" { - hasCustomURLLog = true - break - } - } - - if !hasCustomURLLog { - t.Error("should have logged custom BaseURL") - } -} - -func TestQwenClient_SetAPIKey_WithCustomModel(t *testing.T) { - mockLogger := NewMockLogger() - client := NewQwenClientWithOptions( - WithLogger(mockLogger), - ) - - qwenClient := client.(*QwenClient) - - customModel := "qwen-turbo" - qwenClient.SetAPIKey("sk-test-key-12345678", "", customModel) - - if qwenClient.Model != customModel { - t.Errorf("Model should be '%s', got '%s'", customModel, qwenClient.Model) - } - - // Verify logging - logs := mockLogger.GetLogsByLevel("INFO") - hasCustomModelLog := false - for _, log := range logs { - if log.Format == "๐Ÿ”ง [MCP] Qwen using custom Model: %s" { - hasCustomModelLog = true - break - } - } - - if !hasCustomModelLog { - t.Error("should have logged custom Model") - } -} - -// ============================================================ -// Test Integration Features -// ============================================================ - -func TestQwenClient_CallWithMessages_Success(t *testing.T) { - mockHTTP := NewMockHTTPClient() - mockHTTP.SetSuccessResponse("Qwen AI response") - mockLogger := NewMockLogger() - - client := NewQwenClientWithOptions( - WithHTTPClient(mockHTTP.ToHTTPClient()), - WithLogger(mockLogger), - WithAPIKey("sk-test-key"), - ) - - result, err := client.CallWithMessages("system prompt", "user prompt") - - if err != nil { - t.Fatalf("should not error: %v", err) - } - - if result != "Qwen AI response" { - t.Errorf("expected 'Qwen AI response', got '%s'", result) - } - - // Verify request - requests := mockHTTP.GetRequests() - if len(requests) != 1 { - t.Fatalf("expected 1 request, got %d", len(requests)) - } - - req := requests[0] - - // Verify URL - expectedURL := DefaultQwenBaseURL + "/chat/completions" - if req.URL.String() != expectedURL { - t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String()) - } - - // Verify Authorization header - authHeader := req.Header.Get("Authorization") - if authHeader != "Bearer sk-test-key" { - t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader) - } - - // Verify Content-Type - if req.Header.Get("Content-Type") != "application/json" { - t.Error("Content-Type should be application/json") - } -} - -func TestQwenClient_Timeout(t *testing.T) { - client := NewQwenClientWithOptions( - WithTimeout(30 * time.Second), - ) - - qwenClient := client.(*QwenClient) - - if qwenClient.httpClient.Timeout != 30*time.Second { - t.Errorf("expected timeout 30s, got %v", qwenClient.httpClient.Timeout) - } - - // Test SetTimeout - client.SetTimeout(60 * time.Second) - - if qwenClient.httpClient.Timeout != 60*time.Second { - t.Errorf("expected timeout 60s after SetTimeout, got %v", qwenClient.httpClient.Timeout) - } -} - -// ============================================================ -// Test hooks Mechanism -// ============================================================ - -func TestQwenClient_HooksIntegration(t *testing.T) { - client := NewQwenClientWithOptions() - qwenClient := client.(*QwenClient) - - // Verify hooks point to qwenClient itself (implements polymorphism) - if qwenClient.hooks != qwenClient { - t.Error("hooks should point to qwenClient for polymorphism") - } - - // Verify buildUrl uses Qwen configuration - url := qwenClient.buildUrl() - expectedURL := DefaultQwenBaseURL + "/chat/completions" - if url != expectedURL { - t.Errorf("expected URL '%s', got '%s'", expectedURL, url) - } -} diff --git a/mcp/registry.go b/mcp/registry.go new file mode 100644 index 00000000..1152e0d8 --- /dev/null +++ b/mcp/registry.go @@ -0,0 +1,20 @@ +package mcp + +// providerRegistry maps provider names to factory functions. +var providerRegistry = map[string]func(...ClientOption) AIClient{} + +// RegisterProvider registers a provider factory function. +// Called by provider/payment sub-packages in their init() functions. +func RegisterProvider(name string, factory func(...ClientOption) AIClient) { + providerRegistry[name] = factory +} + +// NewAIClientByProvider creates an AIClient by provider name using the registry. +// Returns nil if the provider is not registered. +func NewAIClientByProvider(name string, opts ...ClientOption) AIClient { + factory, ok := providerRegistry[name] + if !ok { + return nil + } + return factory(opts...) +} diff --git a/mcp/request_builder_test.go b/mcp/request_builder_test.go index 6c5de6db..4ec10a9f 100644 --- a/mcp/request_builder_test.go +++ b/mcp/request_builder_test.go @@ -450,10 +450,10 @@ func TestClient_CallWithRequest_UsesClientModel(t *testing.T) { mockHTTP.SetSuccessResponse("Response") mockLogger := NewMockLogger() - client := NewDeepSeekClientWithOptions( + client := NewClient( + WithDeepSeekConfig("sk-test-key"), WithHTTPClient(mockHTTP.ToHTTPClient()), WithLogger(mockLogger), - WithAPIKey("sk-test-key"), ) // Request does not set model, should use Client's model diff --git a/telegram/bot.go b/telegram/bot.go index e80e4c5a..95085257 100644 --- a/telegram/bot.go +++ b/telegram/bot.go @@ -5,6 +5,8 @@ import ( "nofx/config" "nofx/logger" "nofx/mcp" + _ "nofx/mcp/payment" + _ "nofx/mcp/provider" "nofx/store" "nofx/telegram/agent" "os" @@ -319,32 +321,11 @@ func isUSDCProvider(provider string) bool { } func clientForProvider(provider string) mcp.AIClient { - switch provider { - case "openai": - return mcp.NewOpenAIClient() - case "deepseek": - return mcp.NewDeepSeekClient() - case "claude": - return mcp.NewClaudeClient() - case "qwen": - return mcp.NewQwenClient() - case "kimi": - return mcp.NewKimiClient() - case "grok": - return mcp.NewGrokClient() - case "gemini": - return mcp.NewGeminiClient() - case "minimax": - return mcp.NewMiniMaxClient() - case "blockrun-base": - return mcp.NewBlockRunBaseClient() - case "blockrun-sol": - return mcp.NewBlockRunSolClient() - case "claw402": - return mcp.NewClaw402Client() - default: - return mcp.NewDeepSeekClient() + client := mcp.NewAIClientByProvider(provider) + if client == nil { + client = mcp.NewAIClientByProvider("deepseek") } + return client } // โ”€โ”€ Status message โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ diff --git a/trader/auto_trader.go b/trader/auto_trader.go index ef3a3478..6c958509 100644 --- a/trader/auto_trader.go +++ b/trader/auto_trader.go @@ -1,14 +1,12 @@ package trader import ( - "encoding/json" "fmt" - "math" - "nofx/experience" "nofx/kernel" "nofx/logger" - "nofx/market" "nofx/mcp" + _ "nofx/mcp/payment" + _ "nofx/mcp/provider" "nofx/store" "nofx/trader/aster" "nofx/trader/binance" @@ -20,7 +18,6 @@ import ( "nofx/trader/kucoin" "nofx/trader/lighter" "nofx/trader/okx" - "strings" "sync" "time" ) @@ -175,76 +172,42 @@ func NewAutoTrader(config AutoTraderConfig, st *store.Store, userID string) (*Au aiModel = "qwen" } + // Resolve API key (provider-specific overrides) + apiKey := config.CustomAPIKey + customURL := config.CustomAPIURL switch aiModel { - case "claude": - mcpClient = mcp.NewClaudeClient() - mcpClient.SetAPIKey(config.CustomAPIKey, config.CustomAPIURL, config.CustomModelName) - logger.Infof("๐Ÿค– [%s] Using Claude AI", config.Name) - - case "kimi": - mcpClient = mcp.NewKimiClient() - mcpClient.SetAPIKey(config.CustomAPIKey, config.CustomAPIURL, config.CustomModelName) - logger.Infof("๐Ÿค– [%s] Using Kimi (Moonshot) AI", config.Name) - - case "gemini": - mcpClient = mcp.NewGeminiClient() - mcpClient.SetAPIKey(config.CustomAPIKey, config.CustomAPIURL, config.CustomModelName) - logger.Infof("๐Ÿค– [%s] Using Google Gemini AI", config.Name) - - case "grok": - mcpClient = mcp.NewGrokClient() - mcpClient.SetAPIKey(config.CustomAPIKey, config.CustomAPIURL, config.CustomModelName) - logger.Infof("๐Ÿค– [%s] Using xAI Grok AI", config.Name) - - case "openai": - mcpClient = mcp.NewOpenAIClient() - mcpClient.SetAPIKey(config.CustomAPIKey, config.CustomAPIURL, config.CustomModelName) - logger.Infof("๐Ÿค– [%s] Using OpenAI", config.Name) - - case "minimax": - mcpClient = mcp.NewMiniMaxClient() - mcpClient.SetAPIKey(config.CustomAPIKey, config.CustomAPIURL, config.CustomModelName) - logger.Infof("๐Ÿค– [%s] Using MiniMax AI", config.Name) - - case "blockrun-base": - mcpClient = mcp.NewBlockRunBaseClient() - mcpClient.SetAPIKey(config.CustomAPIKey, "", config.CustomModelName) - logger.Infof("๐Ÿค– [%s] Using BlockRun (Base Wallet) AI", config.Name) - - case "blockrun-sol": - mcpClient = mcp.NewBlockRunSolClient() - mcpClient.SetAPIKey(config.CustomAPIKey, "", config.CustomModelName) - logger.Infof("๐Ÿค– [%s] Using BlockRun (Solana Wallet) AI", config.Name) - - case "claw402": - mcpClient = mcp.NewClaw402Client() - mcpClient.SetAPIKey(config.CustomAPIKey, "", config.CustomModelName) - logger.Infof("๐Ÿค– [%s] Using Claw402 (Base USDC) AI", config.Name) - case "qwen": - mcpClient = mcp.NewQwenClient() - apiKey := config.QwenKey - if apiKey == "" { - apiKey = config.CustomAPIKey + if config.QwenKey != "" { + apiKey = config.QwenKey } - mcpClient.SetAPIKey(apiKey, config.CustomAPIURL, config.CustomModelName) - logger.Infof("๐Ÿค– [%s] Using Alibaba Cloud Qwen AI", config.Name) - - case "custom": - mcpClient = mcp.New() - mcpClient.SetAPIKey(config.CustomAPIKey, config.CustomAPIURL, config.CustomModelName) - logger.Infof("๐Ÿค– [%s] Using custom AI API: %s (model: %s)", config.Name, config.CustomAPIURL, config.CustomModelName) - - default: // deepseek or empty - mcpClient = mcp.NewDeepSeekClient() - apiKey := config.DeepSeekKey - if apiKey == "" { - apiKey = config.CustomAPIKey + case "deepseek", "": + if config.DeepSeekKey != "" { + apiKey = config.DeepSeekKey } - mcpClient.SetAPIKey(apiKey, config.CustomAPIURL, config.CustomModelName) - logger.Infof("๐Ÿค– [%s] Using DeepSeek AI", config.Name) } + // Create client via registry (covers all registered providers) + if aiModel == "custom" { + mcpClient = mcp.New() + } else if aiModel == "" { + aiModel = "deepseek" + mcpClient = mcp.NewAIClientByProvider(aiModel) + } else { + mcpClient = mcp.NewAIClientByProvider(aiModel) + } + if mcpClient == nil { + mcpClient = mcp.New() + } + + // Payment providers (blockrun-*, claw402) ignore customURL + switch aiModel { + case "blockrun-base", "blockrun-sol", "claw402": + mcpClient.SetAPIKey(apiKey, "", config.CustomModelName) + default: + mcpClient.SetAPIKey(apiKey, customURL, config.CustomModelName) + } + logger.Infof("๐Ÿค– [%s] Using %s AI", config.Name, aiModel) + if config.CustomAPIURL != "" || config.CustomModelName != "" { logger.Infof("๐Ÿ”ง [%s] Custom config - URL: %s, Model: %s", config.Name, config.CustomAPIURL, config.CustomModelName) } @@ -554,899 +517,6 @@ func (at *AutoTrader) Stop() { logger.Info("โน Automatic trading system stopped") } -// runCycle runs one trading cycle (using AI full decision-making) -func (at *AutoTrader) runCycle() error { - at.callCount++ - - logger.Info("\n" + strings.Repeat("=", 70) + "\n") - logger.Infof("โฐ %s - AI decision cycle #%d", time.Now().Format("2006-01-02 15:04:05"), at.callCount) - logger.Info(strings.Repeat("=", 70)) - - // 0. Check if trader is stopped (early exit to prevent trades after Stop() is called) - at.isRunningMutex.RLock() - running := at.isRunning - at.isRunningMutex.RUnlock() - if !running { - logger.Infof("โน Trader is stopped, aborting cycle #%d", at.callCount) - return nil - } - - // Create decision record - record := &store.DecisionRecord{ - ExecutionLog: []string{}, - Success: true, - } - - // 1. Check if trading needs to be stopped - if time.Now().Before(at.stopUntil) { - remaining := at.stopUntil.Sub(time.Now()) - logger.Infof("โธ Risk control: Trading paused, remaining %.0f minutes", remaining.Minutes()) - record.Success = false - record.ErrorMessage = fmt.Sprintf("Risk control paused, remaining %.0f minutes", remaining.Minutes()) - at.saveDecision(record) - return nil - } - - // 2. Reset daily P&L (reset every day) - if time.Since(at.lastResetTime) > 24*time.Hour { - at.dailyPnL = 0 - at.lastResetTime = time.Now() - logger.Info("๐Ÿ“… Daily P&L reset") - } - - // 4. Collect trading context - ctx, err := at.buildTradingContext() - if err != nil { - record.Success = false - record.ErrorMessage = fmt.Sprintf("Failed to build trading context: %v", err) - at.saveDecision(record) - return fmt.Errorf("failed to build trading context: %w", err) - } - - // Save equity snapshot independently (decoupled from AI decision, used for drawing profit curve) - // NOTE: Must be called BEFORE candidate coins check to ensure equity is always recorded - at.saveEquitySnapshot(ctx) - - // ๅฆ‚ๆžœๆฒกๆœ‰ๅ€™้€‰ๅธ็ง๏ผŒ่ฎฐๅฝ•ไฝ†ไธๆŠฅ้”™ - if len(ctx.CandidateCoins) == 0 { - logger.Infof("โ„น๏ธ No candidate coins available, skipping this cycle") - record.Success = true // ไธๆ˜ฏ้”™่ฏฏ๏ผŒๅชๆ˜ฏๆฒกๆœ‰ๅ€™้€‰ๅธ - record.ExecutionLog = append(record.ExecutionLog, "No candidate coins available, cycle skipped") - record.AccountState = store.AccountSnapshot{ - TotalBalance: ctx.Account.TotalEquity, - AvailableBalance: ctx.Account.AvailableBalance, - TotalUnrealizedProfit: ctx.Account.UnrealizedPnL, - PositionCount: ctx.Account.PositionCount, - InitialBalance: at.initialBalance, - } - at.saveDecision(record) - return nil - } - - logger.Info(strings.Repeat("=", 70)) - for _, coin := range ctx.CandidateCoins { - record.CandidateCoins = append(record.CandidateCoins, coin.Symbol) - } - - logger.Infof("๐Ÿ“Š Account equity: %.2f USDT | Available: %.2f USDT | Positions: %d", - ctx.Account.TotalEquity, ctx.Account.AvailableBalance, ctx.Account.PositionCount) - - // 5. Use strategy engine to call AI for decision - logger.Infof("๐Ÿค– Requesting AI analysis and decision... [Strategy Engine]") - aiDecision, err := kernel.GetFullDecisionWithStrategy(ctx, at.mcpClient, at.strategyEngine, "balanced") - - if aiDecision != nil && aiDecision.AIRequestDurationMs > 0 { - record.AIRequestDurationMs = aiDecision.AIRequestDurationMs - logger.Infof("โฑ๏ธ AI call duration: %.2f seconds", float64(record.AIRequestDurationMs)/1000) - record.ExecutionLog = append(record.ExecutionLog, - fmt.Sprintf("AI call duration: %d ms", record.AIRequestDurationMs)) - } - - // Save chain of thought, decisions, and input prompt even if there's an error (for debugging) - if aiDecision != nil { - record.SystemPrompt = aiDecision.SystemPrompt // Save system prompt - record.InputPrompt = aiDecision.UserPrompt - record.CoTTrace = aiDecision.CoTTrace - record.RawResponse = aiDecision.RawResponse // Save raw AI response for debugging - if len(aiDecision.Decisions) > 0 { - decisionJSON, _ := json.MarshalIndent(aiDecision.Decisions, "", " ") - record.DecisionJSON = string(decisionJSON) - } - } - - if err != nil { - record.Success = false - record.ErrorMessage = fmt.Sprintf("Failed to get AI decision: %v", err) - - // Print system prompt and AI chain of thought (output even with errors for debugging) - if aiDecision != nil { - logger.Info("\n" + strings.Repeat("=", 70) + "\n") - logger.Infof("๐Ÿ“‹ System prompt (error case)") - logger.Info(strings.Repeat("=", 70)) - logger.Info(aiDecision.SystemPrompt) - logger.Info(strings.Repeat("=", 70)) - - if aiDecision.CoTTrace != "" { - logger.Info("\n" + strings.Repeat("-", 70) + "\n") - logger.Info("๐Ÿ’ญ AI chain of thought analysis (error case):") - logger.Info(strings.Repeat("-", 70)) - logger.Info(aiDecision.CoTTrace) - logger.Info(strings.Repeat("-", 70)) - } - } - - at.saveDecision(record) - return fmt.Errorf("failed to get AI decision: %w", err) - } - - // // 5. Print system prompt - // logger.Infof("\n" + strings.Repeat("=", 70)) - // logger.Infof("๐Ÿ“‹ System prompt [template: %s]", at.systemPromptTemplate) - // logger.Info(strings.Repeat("=", 70)) - // logger.Info(decision.SystemPrompt) - // logger.Infof(strings.Repeat("=", 70) + "\n") - - // 6. Print AI chain of thought - // logger.Infof("\n" + strings.Repeat("-", 70)) - // logger.Info("๐Ÿ’ญ AI chain of thought analysis:") - // logger.Info(strings.Repeat("-", 70)) - // logger.Info(decision.CoTTrace) - // logger.Infof(strings.Repeat("-", 70) + "\n") - - // 7. Print AI decisions - // logger.Infof("๐Ÿ“‹ AI decision list (%d items):\n", len(kernel.Decisions)) - // for i, d := range kernel.Decisions { - // logger.Infof(" [%d] %s: %s - %s", i+1, d.Symbol, d.Action, d.Reasoning) - // if d.Action == "open_long" || d.Action == "open_short" { - // logger.Infof(" Leverage: %dx | Position: %.2f USDT | Stop loss: %.4f | Take profit: %.4f", - // d.Leverage, d.PositionSizeUSD, d.StopLoss, d.TakeProfit) - // } - // } - logger.Info() - logger.Info(strings.Repeat("-", 70)) - // 8. Sort decisions: ensure close positions first, then open positions (prevent position stacking overflow) - logger.Info(strings.Repeat("-", 70)) - - // 8. Sort decisions: ensure close positions first, then open positions (prevent position stacking overflow) - sortedDecisions := sortDecisionsByPriority(aiDecision.Decisions) - - logger.Info("๐Ÿ”„ Execution order (optimized): Close positions first โ†’ Open positions later") - for i, d := range sortedDecisions { - logger.Infof(" [%d] %s %s", i+1, d.Symbol, d.Action) - } - logger.Info() - - // Check if trader is stopped before executing any decisions (prevent trades after Stop()) - at.isRunningMutex.RLock() - running = at.isRunning - at.isRunningMutex.RUnlock() - if !running { - logger.Infof("โน Trader stopped before decision execution, aborting cycle #%d", at.callCount) - return nil - } - - // Execute decisions and record results - for _, d := range sortedDecisions { - // Check if trader is stopped before each decision (allow immediate stop during execution) - at.isRunningMutex.RLock() - running = at.isRunning - at.isRunningMutex.RUnlock() - if !running { - logger.Infof("โน Trader stopped during decision execution, aborting remaining decisions") - break - } - - actionRecord := store.DecisionAction{ - Action: d.Action, - Symbol: d.Symbol, - Quantity: 0, - Leverage: d.Leverage, - Price: 0, - StopLoss: d.StopLoss, - TakeProfit: d.TakeProfit, - Confidence: d.Confidence, - Reasoning: d.Reasoning, - Timestamp: time.Now().UTC(), - Success: false, - } - - if err := at.executeDecisionWithRecord(&d, &actionRecord); err != nil { - logger.Infof("โŒ Failed to execute decision (%s %s): %v", d.Symbol, d.Action, err) - actionRecord.Error = err.Error() - record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("โŒ %s %s failed: %v", d.Symbol, d.Action, err)) - } else { - actionRecord.Success = true - record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("โœ“ %s %s succeeded", d.Symbol, d.Action)) - // Brief delay after successful execution - time.Sleep(1 * time.Second) - } - - record.Decisions = append(record.Decisions, actionRecord) - } - - // 9. Save decision record - if err := at.saveDecision(record); err != nil { - logger.Infof("โš  Failed to save decision record: %v", err) - } - - return nil -} - -// buildTradingContext builds trading context -func (at *AutoTrader) buildTradingContext() (*kernel.Context, error) { - // 1. Get account information - balance, err := at.trader.GetBalance() - if err != nil { - return nil, fmt.Errorf("failed to get account balance: %w", err) - } - - // Get account fields - totalWalletBalance := 0.0 - totalUnrealizedProfit := 0.0 - availableBalance := 0.0 - totalEquity := 0.0 - - if wallet, ok := balance["totalWalletBalance"].(float64); ok { - totalWalletBalance = wallet - } - if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok { - totalUnrealizedProfit = unrealized - } - if avail, ok := balance["availableBalance"].(float64); ok { - availableBalance = avail - } - - // Use totalEquity directly if provided by trader (more accurate) - if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 { - totalEquity = eq - } else { - // Fallback: Total Equity = Wallet balance + Unrealized profit - totalEquity = totalWalletBalance + totalUnrealizedProfit - } - - // 2. Get position information - positions, err := at.trader.GetPositions() - if err != nil { - return nil, fmt.Errorf("failed to get positions: %w", err) - } - - var positionInfos []kernel.PositionInfo - totalMarginUsed := 0.0 - - // Current position key set (for cleaning up closed position records) - currentPositionKeys := make(map[string]bool) - - for _, pos := range positions { - symbol := pos["symbol"].(string) - side := pos["side"].(string) - entryPrice := pos["entryPrice"].(float64) - markPrice := pos["markPrice"].(float64) - quantity := pos["positionAmt"].(float64) - if quantity < 0 { - quantity = -quantity // Short position quantity is negative, convert to positive - } - - // Skip closed positions (quantity = 0), prevent "ghost positions" from being passed to AI - if quantity == 0 { - continue - } - - unrealizedPnl := pos["unRealizedProfit"].(float64) - liquidationPrice := pos["liquidationPrice"].(float64) - - // Calculate margin used (estimated) - leverage := 10 // Default value, should actually be fetched from position info - if lev, ok := pos["leverage"].(float64); ok { - leverage = int(lev) - } - marginUsed := (quantity * markPrice) / float64(leverage) - totalMarginUsed += marginUsed - - // Calculate P&L percentage (based on margin, considering leverage) - pnlPct := calculatePnLPercentage(unrealizedPnl, marginUsed) - - // Get position open time from exchange (preferred) or fallback to local tracking - posKey := symbol + "_" + side - currentPositionKeys[posKey] = true - - var updateTime int64 - // Priority 1: Get from database (trader_positions table) - most accurate - if at.store != nil { - if dbPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, symbol, side); err == nil && dbPos != nil { - if dbPos.EntryTime > 0 { - updateTime = dbPos.EntryTime - } - } - } - // Priority 2: Get from exchange API (Bybit: createdTime, OKX: createdTime) - if updateTime == 0 { - if createdTime, ok := pos["createdTime"].(int64); ok && createdTime > 0 { - updateTime = createdTime - } - } - // Priority 3: Fallback to local tracking - if updateTime == 0 { - if _, exists := at.positionFirstSeenTime[posKey]; !exists { - at.positionFirstSeenTime[posKey] = time.Now().UnixMilli() - } - updateTime = at.positionFirstSeenTime[posKey] - } - - // Get peak profit rate for this position - at.peakPnLCacheMutex.RLock() - peakPnlPct := at.peakPnLCache[posKey] - at.peakPnLCacheMutex.RUnlock() - - positionInfos = append(positionInfos, kernel.PositionInfo{ - Symbol: symbol, - Side: side, - EntryPrice: entryPrice, - MarkPrice: markPrice, - Quantity: quantity, - Leverage: leverage, - UnrealizedPnL: unrealizedPnl, - UnrealizedPnLPct: pnlPct, - PeakPnLPct: peakPnlPct, - LiquidationPrice: liquidationPrice, - MarginUsed: marginUsed, - UpdateTime: updateTime, - }) - } - - // Clean up closed position records - for key := range at.positionFirstSeenTime { - if !currentPositionKeys[key] { - delete(at.positionFirstSeenTime, key) - } - } - - // 3. Use strategy engine to get candidate coins (must have strategy engine) - var candidateCoins []kernel.CandidateCoin - if at.strategyEngine == nil { - logger.Infof("โš ๏ธ [%s] No strategy engine configured, skipping candidate coins", at.name) - } else { - coins, err := at.strategyEngine.GetCandidateCoins() - if err != nil { - // Log warning but don't fail - equity snapshot should still be saved - logger.Infof("โš ๏ธ [%s] Failed to get candidate coins: %v (will use empty list)", at.name, err) - } else { - candidateCoins = coins - logger.Infof("๐Ÿ“‹ [%s] Strategy engine fetched candidate coins: %d", at.name, len(candidateCoins)) - } - } - - // 4. Calculate total P&L - totalPnL := totalEquity - at.initialBalance - totalPnLPct := 0.0 - if at.initialBalance > 0 { - totalPnLPct = (totalPnL / at.initialBalance) * 100 - } - - marginUsedPct := 0.0 - if totalEquity > 0 { - marginUsedPct = (totalMarginUsed / totalEquity) * 100 - } - - // 5. Get leverage from strategy config - strategyConfig := at.strategyEngine.GetConfig() - btcEthLeverage := strategyConfig.RiskControl.BTCETHMaxLeverage - altcoinLeverage := strategyConfig.RiskControl.AltcoinMaxLeverage - logger.Infof("๐Ÿ“‹ [%s] Strategy leverage config: BTC/ETH=%dx, Altcoin=%dx", at.name, btcEthLeverage, altcoinLeverage) - - // 6. Build context - ctx := &kernel.Context{ - CurrentTime: time.Now().UTC().Format("2006-01-02 15:04:05 UTC"), - RuntimeMinutes: int(time.Since(at.startTime).Minutes()), - CallCount: at.callCount, - BTCETHLeverage: btcEthLeverage, - AltcoinLeverage: altcoinLeverage, - Account: kernel.AccountInfo{ - TotalEquity: totalEquity, - AvailableBalance: availableBalance, - UnrealizedPnL: totalUnrealizedProfit, - TotalPnL: totalPnL, - TotalPnLPct: totalPnLPct, - MarginUsed: totalMarginUsed, - MarginUsedPct: marginUsedPct, - PositionCount: len(positionInfos), - }, - Positions: positionInfos, - CandidateCoins: candidateCoins, - } - - // 7. Add recent closed trades (if store is available) - if at.store != nil { - // Get recent 10 closed trades for AI context - recentTrades, err := at.store.Position().GetRecentTrades(at.id, 10) - if err != nil { - logger.Infof("โš ๏ธ [%s] Failed to get recent trades: %v", at.name, err) - } else { - logger.Infof("๐Ÿ“Š [%s] Found %d recent closed trades for AI context", at.name, len(recentTrades)) - for _, trade := range recentTrades { - // Convert Unix timestamps to formatted strings for AI readability - entryTimeStr := "" - if trade.EntryTime > 0 { - entryTimeStr = time.Unix(trade.EntryTime, 0).UTC().Format("01-02 15:04 UTC") - } - exitTimeStr := "" - if trade.ExitTime > 0 { - exitTimeStr = time.Unix(trade.ExitTime, 0).UTC().Format("01-02 15:04 UTC") - } - - ctx.RecentOrders = append(ctx.RecentOrders, kernel.RecentOrder{ - Symbol: trade.Symbol, - Side: trade.Side, - EntryPrice: trade.EntryPrice, - ExitPrice: trade.ExitPrice, - RealizedPnL: trade.RealizedPnL, - PnLPct: trade.PnLPct, - EntryTime: entryTimeStr, - ExitTime: exitTimeStr, - HoldDuration: trade.HoldDuration, - }) - } - } - // Get trading statistics for AI context - stats, err := at.store.Position().GetFullStats(at.id) - if err != nil { - logger.Infof("โš ๏ธ [%s] Failed to get trading stats: %v", at.name, err) - } else if stats == nil { - logger.Infof("โš ๏ธ [%s] GetFullStats returned nil", at.name) - } else if stats.TotalTrades == 0 { - logger.Infof("โš ๏ธ [%s] GetFullStats returned 0 trades (traderID=%s)", at.name, at.id) - } else { - ctx.TradingStats = &kernel.TradingStats{ - TotalTrades: stats.TotalTrades, - WinRate: stats.WinRate, - ProfitFactor: stats.ProfitFactor, - SharpeRatio: stats.SharpeRatio, - TotalPnL: stats.TotalPnL, - AvgWin: stats.AvgWin, - AvgLoss: stats.AvgLoss, - MaxDrawdownPct: stats.MaxDrawdownPct, - } - logger.Infof("๐Ÿ“ˆ [%s] Trading stats: %d trades, %.1f%% win rate, PF=%.2f, Sharpe=%.2f, DD=%.1f%%", - at.name, stats.TotalTrades, stats.WinRate, stats.ProfitFactor, stats.SharpeRatio, stats.MaxDrawdownPct) - } - } else { - logger.Infof("โš ๏ธ [%s] Store is nil, cannot get recent trades", at.name) - } - - // 8. Get quantitative data (if enabled in strategy config) - if strategyConfig.Indicators.EnableQuantData { - // Collect symbols to query (candidate coins + position coins) - symbolsToQuery := make(map[string]bool) - for _, coin := range candidateCoins { - symbolsToQuery[coin.Symbol] = true - } - for _, pos := range positionInfos { - symbolsToQuery[pos.Symbol] = true - } - - symbols := make([]string, 0, len(symbolsToQuery)) - for sym := range symbolsToQuery { - symbols = append(symbols, sym) - } - - logger.Infof("๐Ÿ“Š [%s] Fetching quantitative data for %d symbols...", at.name, len(symbols)) - ctx.QuantDataMap = at.strategyEngine.FetchQuantDataBatch(symbols) - logger.Infof("๐Ÿ“Š [%s] Successfully fetched quantitative data for %d symbols", at.name, len(ctx.QuantDataMap)) - } - - // 9. Get OI ranking data (market-wide position changes) - if strategyConfig.Indicators.EnableOIRanking { - logger.Infof("๐Ÿ“Š [%s] Fetching OI ranking data...", at.name) - ctx.OIRankingData = at.strategyEngine.FetchOIRankingData() - if ctx.OIRankingData != nil { - logger.Infof("๐Ÿ“Š [%s] OI ranking data ready: %d top, %d low positions", - at.name, len(ctx.OIRankingData.TopPositions), len(ctx.OIRankingData.LowPositions)) - } - } - - // 10. Get NetFlow ranking data (market-wide fund flow) - if strategyConfig.Indicators.EnableNetFlowRanking { - logger.Infof("๐Ÿ’ฐ [%s] Fetching NetFlow ranking data...", at.name) - ctx.NetFlowRankingData = at.strategyEngine.FetchNetFlowRankingData() - if ctx.NetFlowRankingData != nil { - logger.Infof("๐Ÿ’ฐ [%s] NetFlow ranking data ready: inst_in=%d, inst_out=%d", - at.name, len(ctx.NetFlowRankingData.InstitutionFutureTop), len(ctx.NetFlowRankingData.InstitutionFutureLow)) - } - } - - // 11. Get Price ranking data (market-wide gainers/losers) - if strategyConfig.Indicators.EnablePriceRanking { - logger.Infof("๐Ÿ“ˆ [%s] Fetching Price ranking data...", at.name) - ctx.PriceRankingData = at.strategyEngine.FetchPriceRankingData() - if ctx.PriceRankingData != nil { - logger.Infof("๐Ÿ“ˆ [%s] Price ranking data ready for %d durations", - at.name, len(ctx.PriceRankingData.Durations)) - } - } - - return ctx, nil -} - -// executeDecisionWithRecord executes AI decision and records detailed information -func (at *AutoTrader) executeDecisionWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error { - switch decision.Action { - case "open_long": - return at.executeOpenLongWithRecord(decision, actionRecord) - case "open_short": - return at.executeOpenShortWithRecord(decision, actionRecord) - case "close_long": - return at.executeCloseLongWithRecord(decision, actionRecord) - case "close_short": - return at.executeCloseShortWithRecord(decision, actionRecord) - case "hold", "wait": - // No execution needed, just record - return nil - default: - return fmt.Errorf("unknown action: %s", decision.Action) - } -} - -// executeOpenLongWithRecord executes open long position and records detailed information -func (at *AutoTrader) executeOpenLongWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error { - logger.Infof(" ๐Ÿ“ˆ Open long: %s", decision.Symbol) - - // โš ๏ธ Get current positions for multiple checks - positions, err := at.trader.GetPositions() - if err != nil { - return fmt.Errorf("failed to get positions: %w", err) - } - - // [CODE ENFORCED] Check max positions limit - if err := at.enforceMaxPositions(len(positions)); err != nil { - return err - } - - // Check if there's already a position in the same symbol and direction - for _, pos := range positions { - if pos["symbol"] == decision.Symbol && pos["side"] == "long" { - return fmt.Errorf("โŒ %s already has long position, close it first", decision.Symbol) - } - } - - // Get current price - marketData, err := market.GetWithExchange(decision.Symbol, at.exchange) - if err != nil { - return err - } - - // Get balance (needed for multiple checks) - balance, err := at.trader.GetBalance() - if err != nil { - return fmt.Errorf("failed to get account balance: %w", err) - } - availableBalance := 0.0 - if avail, ok := balance["availableBalance"].(float64); ok { - availableBalance = avail - } - - // Get equity for position value ratio check - equity := 0.0 - if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 { - equity = eq - } else if eq, ok := balance["totalWalletBalance"].(float64); ok && eq > 0 { - equity = eq - } else { - equity = availableBalance // Fallback to available balance - } - - // [CODE ENFORCED] Position Value Ratio Check: position_value <= equity ร— ratio - adjustedPositionSize, wasCapped := at.enforcePositionValueRatio(decision.PositionSizeUSD, equity, decision.Symbol) - if wasCapped { - decision.PositionSizeUSD = adjustedPositionSize - } - - // โš ๏ธ Auto-adjust position size if insufficient margin - // Formula: totalRequired = positionSize/leverage + positionSize*0.001 + positionSize/leverage*0.01 - // = positionSize * (1.01/leverage + 0.001) - marginFactor := 1.01/float64(decision.Leverage) + 0.001 - maxAffordablePositionSize := availableBalance / marginFactor - - actualPositionSize := decision.PositionSizeUSD - if actualPositionSize > maxAffordablePositionSize { - // Use 98% of max to leave buffer for price fluctuation - adjustedSize := maxAffordablePositionSize * 0.98 - logger.Infof(" โš ๏ธ Position size %.2f exceeds max affordable %.2f, auto-reducing to %.2f", - actualPositionSize, maxAffordablePositionSize, adjustedSize) - actualPositionSize = adjustedSize - decision.PositionSizeUSD = actualPositionSize - } - - // [CODE ENFORCED] Minimum position size check - if err := at.enforceMinPositionSize(decision.PositionSizeUSD); err != nil { - return err - } - - // Calculate quantity with adjusted position size - quantity := actualPositionSize / marketData.CurrentPrice - actionRecord.Quantity = quantity - actionRecord.Price = marketData.CurrentPrice - - // Set margin mode - if err := at.trader.SetMarginMode(decision.Symbol, at.config.IsCrossMargin); err != nil { - logger.Infof(" โš ๏ธ Failed to set margin mode: %v", err) - // Continue execution, doesn't affect trading - } - - // Open position - order, err := at.trader.OpenLong(decision.Symbol, quantity, decision.Leverage) - if err != nil { - return err - } - - // Record order ID - if orderID, ok := order["orderId"].(int64); ok { - actionRecord.OrderID = orderID - } - - logger.Infof(" โœ“ Position opened successfully, order ID: %v, quantity: %.4f", order["orderId"], quantity) - - // Record order to database and poll for confirmation - at.recordAndConfirmOrder(order, decision.Symbol, "open_long", quantity, marketData.CurrentPrice, decision.Leverage, 0) - - // Record position opening time - posKey := decision.Symbol + "_long" - at.positionFirstSeenTime[posKey] = time.Now().UnixMilli() - - // Set stop loss and take profit - if err := at.trader.SetStopLoss(decision.Symbol, "LONG", quantity, decision.StopLoss); err != nil { - logger.Infof(" โš  Failed to set stop loss: %v", err) - } - if err := at.trader.SetTakeProfit(decision.Symbol, "LONG", quantity, decision.TakeProfit); err != nil { - logger.Infof(" โš  Failed to set take profit: %v", err) - } - - return nil -} - -// executeOpenShortWithRecord executes open short position and records detailed information -func (at *AutoTrader) executeOpenShortWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error { - logger.Infof(" ๐Ÿ“‰ Open short: %s", decision.Symbol) - - // โš ๏ธ Get current positions for multiple checks - positions, err := at.trader.GetPositions() - if err != nil { - return fmt.Errorf("failed to get positions: %w", err) - } - - // [CODE ENFORCED] Check max positions limit - if err := at.enforceMaxPositions(len(positions)); err != nil { - return err - } - - // Check if there's already a position in the same symbol and direction - for _, pos := range positions { - if pos["symbol"] == decision.Symbol && pos["side"] == "short" { - return fmt.Errorf("โŒ %s already has short position, close it first", decision.Symbol) - } - } - - // Get current price - marketData, err := market.GetWithExchange(decision.Symbol, at.exchange) - if err != nil { - return err - } - - // Get balance (needed for multiple checks) - balance, err := at.trader.GetBalance() - if err != nil { - return fmt.Errorf("failed to get account balance: %w", err) - } - availableBalance := 0.0 - if avail, ok := balance["availableBalance"].(float64); ok { - availableBalance = avail - } - - // Get equity for position value ratio check - equity := 0.0 - if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 { - equity = eq - } else if eq, ok := balance["totalWalletBalance"].(float64); ok && eq > 0 { - equity = eq - } else { - equity = availableBalance // Fallback to available balance - } - - // [CODE ENFORCED] Position Value Ratio Check: position_value <= equity ร— ratio - adjustedPositionSize, wasCapped := at.enforcePositionValueRatio(decision.PositionSizeUSD, equity, decision.Symbol) - if wasCapped { - decision.PositionSizeUSD = adjustedPositionSize - } - - // โš ๏ธ Auto-adjust position size if insufficient margin - // Formula: totalRequired = positionSize/leverage + positionSize*0.001 + positionSize/leverage*0.01 - // = positionSize * (1.01/leverage + 0.001) - marginFactor := 1.01/float64(decision.Leverage) + 0.001 - maxAffordablePositionSize := availableBalance / marginFactor - - actualPositionSize := decision.PositionSizeUSD - if actualPositionSize > maxAffordablePositionSize { - // Use 98% of max to leave buffer for price fluctuation - adjustedSize := maxAffordablePositionSize * 0.98 - logger.Infof(" โš ๏ธ Position size %.2f exceeds max affordable %.2f, auto-reducing to %.2f", - actualPositionSize, maxAffordablePositionSize, adjustedSize) - actualPositionSize = adjustedSize - decision.PositionSizeUSD = actualPositionSize - } - - // [CODE ENFORCED] Minimum position size check - if err := at.enforceMinPositionSize(decision.PositionSizeUSD); err != nil { - return err - } - - // Calculate quantity with adjusted position size - quantity := actualPositionSize / marketData.CurrentPrice - actionRecord.Quantity = quantity - actionRecord.Price = marketData.CurrentPrice - - // Set margin mode - if err := at.trader.SetMarginMode(decision.Symbol, at.config.IsCrossMargin); err != nil { - logger.Infof(" โš ๏ธ Failed to set margin mode: %v", err) - // Continue execution, doesn't affect trading - } - - // Open position - order, err := at.trader.OpenShort(decision.Symbol, quantity, decision.Leverage) - if err != nil { - return err - } - - // Record order ID - if orderID, ok := order["orderId"].(int64); ok { - actionRecord.OrderID = orderID - } - - logger.Infof(" โœ“ Position opened successfully, order ID: %v, quantity: %.4f", order["orderId"], quantity) - - // Record order to database and poll for confirmation - at.recordAndConfirmOrder(order, decision.Symbol, "open_short", quantity, marketData.CurrentPrice, decision.Leverage, 0) - - // Record position opening time - posKey := decision.Symbol + "_short" - at.positionFirstSeenTime[posKey] = time.Now().UnixMilli() - - // Set stop loss and take profit - if err := at.trader.SetStopLoss(decision.Symbol, "SHORT", quantity, decision.StopLoss); err != nil { - logger.Infof(" โš  Failed to set stop loss: %v", err) - } - if err := at.trader.SetTakeProfit(decision.Symbol, "SHORT", quantity, decision.TakeProfit); err != nil { - logger.Infof(" โš  Failed to set take profit: %v", err) - } - - return nil -} - -// executeCloseLongWithRecord executes close long position and records detailed information -func (at *AutoTrader) executeCloseLongWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error { - logger.Infof(" ๐Ÿ”„ Close long: %s", decision.Symbol) - - // Get current price - marketData, err := market.GetWithExchange(decision.Symbol, at.exchange) - if err != nil { - return err - } - actionRecord.Price = marketData.CurrentPrice - - // Normalize symbol for database lookup - normalizedSymbol := market.Normalize(decision.Symbol) - - // Get entry price and quantity - prioritize local database for accurate quantity - var entryPrice float64 - var quantity float64 - - // First try to get from local database (more accurate for quantity) - if at.store != nil { - if openPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, normalizedSymbol, "LONG"); err == nil && openPos != nil { - quantity = openPos.Quantity - entryPrice = openPos.EntryPrice - logger.Infof(" ๐Ÿ“Š Using local position data: qty=%.8f, entry=%.2f", quantity, entryPrice) - } - } - - // Fallback to exchange API if local data not found - if quantity == 0 { - positions, err := at.trader.GetPositions() - if err == nil { - for _, pos := range positions { - if pos["symbol"] == decision.Symbol && pos["side"] == "long" { - if ep, ok := pos["entryPrice"].(float64); ok { - entryPrice = ep - } - if amt, ok := pos["positionAmt"].(float64); ok && amt > 0 { - quantity = amt - } - break - } - } - } - logger.Infof(" ๐Ÿ“Š Using exchange position data: qty=%.8f, entry=%.2f", quantity, entryPrice) - } - - // Close position - order, err := at.trader.CloseLong(decision.Symbol, 0) // 0 = close all - if err != nil { - return err - } - - // Record order ID - if orderID, ok := order["orderId"].(int64); ok { - actionRecord.OrderID = orderID - } - - // Record order to database and poll for confirmation - at.recordAndConfirmOrder(order, decision.Symbol, "close_long", quantity, marketData.CurrentPrice, 0, entryPrice) - - logger.Infof(" โœ“ Position closed successfully") - return nil -} - -// executeCloseShortWithRecord executes close short position and records detailed information -func (at *AutoTrader) executeCloseShortWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error { - logger.Infof(" ๐Ÿ”„ Close short: %s", decision.Symbol) - - // Get current price - marketData, err := market.GetWithExchange(decision.Symbol, at.exchange) - if err != nil { - return err - } - actionRecord.Price = marketData.CurrentPrice - - // Normalize symbol for database lookup - normalizedSymbol := market.Normalize(decision.Symbol) - - // Get entry price and quantity - prioritize local database for accurate quantity - var entryPrice float64 - var quantity float64 - - // First try to get from local database (more accurate for quantity) - if at.store != nil { - if openPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, normalizedSymbol, "SHORT"); err == nil && openPos != nil { - quantity = openPos.Quantity - entryPrice = openPos.EntryPrice - logger.Infof(" ๐Ÿ“Š Using local position data: qty=%.8f, entry=%.2f", quantity, entryPrice) - } - } - - // Fallback to exchange API if local data not found - if quantity == 0 { - positions, err := at.trader.GetPositions() - if err == nil { - for _, pos := range positions { - if pos["symbol"] == decision.Symbol && pos["side"] == "short" { - if ep, ok := pos["entryPrice"].(float64); ok { - entryPrice = ep - } - if amt, ok := pos["positionAmt"].(float64); ok { - quantity = -amt // positionAmt is negative for short - } - break - } - } - } - logger.Infof(" ๐Ÿ“Š Using exchange position data: qty=%.8f, entry=%.2f", quantity, entryPrice) - } - - // Close position - order, err := at.trader.CloseShort(decision.Symbol, 0) // 0 = close all - if err != nil { - return err - } - - // Record order ID - if orderID, ok := order["orderId"].(int64); ok { - actionRecord.OrderID = orderID - } - - // Record order to database and poll for confirmation - at.recordAndConfirmOrder(order, decision.Symbol, "close_short", quantity, marketData.CurrentPrice, 0, entryPrice) - - logger.Infof(" โœ“ Position closed successfully") - return nil -} - // GetID gets trader ID func (at *AutoTrader) GetID() string { return at.id @@ -1504,824 +574,16 @@ func (at *AutoTrader) GetSystemPromptTemplate() string { return "strategy" } -// saveEquitySnapshot saves equity snapshot independently (for drawing profit curve, decoupled from AI decision) -func (at *AutoTrader) saveEquitySnapshot(ctx *kernel.Context) { - if at.store == nil || ctx == nil { - return - } - - snapshot := &store.EquitySnapshot{ - TraderID: at.id, - Timestamp: time.Now().UTC(), - TotalEquity: ctx.Account.TotalEquity, - Balance: ctx.Account.TotalEquity - ctx.Account.UnrealizedPnL, - UnrealizedPnL: ctx.Account.UnrealizedPnL, - PositionCount: ctx.Account.PositionCount, - MarginUsedPct: ctx.Account.MarginUsedPct, - } - - if err := at.store.Equity().Save(snapshot); err != nil { - logger.Infof("โš ๏ธ Failed to save equity snapshot: %v", err) - } -} - -// saveDecision saves AI decision log to database (only records AI input/output, for debugging) -func (at *AutoTrader) saveDecision(record *store.DecisionRecord) error { - if at.store == nil { - return nil - } - - at.cycleNumber++ - record.CycleNumber = at.cycleNumber - record.TraderID = at.id - - if record.Timestamp.IsZero() { - record.Timestamp = time.Now().UTC() - } - - if err := at.store.Decision().LogDecision(record); err != nil { - logger.Infof("โš ๏ธ Failed to save decision record: %v", err) - return err - } - - logger.Infof("๐Ÿ“ Decision record saved: trader=%s, cycle=%d", at.id, at.cycleNumber) - return nil -} - // GetStore gets data store (for external access to decision records, etc.) func (at *AutoTrader) GetStore() *store.Store { return at.store } -// GetStatus gets system status (for API) -func (at *AutoTrader) GetStatus() map[string]interface{} { - aiProvider := "DeepSeek" - if at.config.UseQwen { - aiProvider = "Qwen" - } - - at.isRunningMutex.RLock() - isRunning := at.isRunning - at.isRunningMutex.RUnlock() - - result := map[string]interface{}{ - "trader_id": at.id, - "trader_name": at.name, - "ai_model": at.aiModel, - "exchange": at.exchange, - "is_running": isRunning, - "start_time": at.startTime.Format(time.RFC3339), - "runtime_minutes": int(time.Since(at.startTime).Minutes()), - "call_count": at.callCount, - "initial_balance": at.initialBalance, - "scan_interval": at.config.ScanInterval.String(), - "stop_until": at.stopUntil.Format(time.RFC3339), - "last_reset_time": at.lastResetTime.Format(time.RFC3339), - "ai_provider": aiProvider, - } - - // Add strategy info - if at.config.StrategyConfig != nil { - result["strategy_type"] = at.config.StrategyConfig.StrategyType - if at.config.StrategyConfig.GridConfig != nil { - result["grid_symbol"] = at.config.StrategyConfig.GridConfig.Symbol - } - } - - return result -} - -// GetAccountInfo gets account information (for API) -func (at *AutoTrader) GetAccountInfo() (map[string]interface{}, error) { - balance, err := at.trader.GetBalance() - if err != nil { - return nil, fmt.Errorf("failed to get balance: %w", err) - } - - // Get account fields - totalWalletBalance := 0.0 - totalUnrealizedProfit := 0.0 - availableBalance := 0.0 - totalEquity := 0.0 - - if wallet, ok := balance["totalWalletBalance"].(float64); ok { - totalWalletBalance = wallet - } - if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok { - totalUnrealizedProfit = unrealized - } - if avail, ok := balance["availableBalance"].(float64); ok { - availableBalance = avail - } - - // Use totalEquity directly if provided by trader (more accurate) - if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 { - totalEquity = eq - } else { - // Fallback: Total Equity = Wallet balance + Unrealized profit - totalEquity = totalWalletBalance + totalUnrealizedProfit - } - - // Get positions to calculate total margin - positions, err := at.trader.GetPositions() - if err != nil { - return nil, fmt.Errorf("failed to get positions: %w", err) - } - - totalMarginUsed := 0.0 - totalUnrealizedPnLCalculated := 0.0 - for _, pos := range positions { - markPrice := pos["markPrice"].(float64) - quantity := pos["positionAmt"].(float64) - if quantity < 0 { - quantity = -quantity - } - unrealizedPnl := pos["unRealizedProfit"].(float64) - totalUnrealizedPnLCalculated += unrealizedPnl - - leverage := 10 - if lev, ok := pos["leverage"].(float64); ok { - leverage = int(lev) - } - marginUsed := (quantity * markPrice) / float64(leverage) - totalMarginUsed += marginUsed - } - - // Verify unrealized P&L consistency (API value vs calculated from positions) - // Note: Lighter API may return 0 for unrealized PnL, this is a known limitation - diff := math.Abs(totalUnrealizedProfit - totalUnrealizedPnLCalculated) - if diff > 5.0 { // Only warn if difference is significant (> 5 USDT) - logger.Infof("โš ๏ธ Unrealized P&L inconsistency (Lighter API limitation): API=%.4f, Calculated=%.4f, Diff=%.4f", - totalUnrealizedProfit, totalUnrealizedPnLCalculated, diff) - } - - totalPnL := totalEquity - at.initialBalance - totalPnLPct := 0.0 - if at.initialBalance > 0 { - totalPnLPct = (totalPnL / at.initialBalance) * 100 - } else { - logger.Infof("โš ๏ธ Initial Balance abnormal: %.2f, cannot calculate P&L percentage", at.initialBalance) - } - - marginUsedPct := 0.0 - if totalEquity > 0 { - marginUsedPct = (totalMarginUsed / totalEquity) * 100 - } - - return map[string]interface{}{ - // Core fields - "total_equity": totalEquity, // Account equity = wallet + unrealized - "wallet_balance": totalWalletBalance, // Wallet balance (excluding unrealized P&L) - "unrealized_profit": totalUnrealizedProfit, // Unrealized P&L (official value from exchange API) - "available_balance": availableBalance, // Available balance - - // P&L statistics - "total_pnl": totalPnL, // Total P&L = equity - initial - "total_pnl_pct": totalPnLPct, // Total P&L percentage - "initial_balance": at.initialBalance, // Initial balance - "daily_pnl": at.dailyPnL, // Daily P&L - - // Position information - "position_count": len(positions), // Position count - "margin_used": totalMarginUsed, // Margin used - "margin_used_pct": marginUsedPct, // Margin usage rate - }, nil -} - -// GetPositions gets position list (for API) -func (at *AutoTrader) GetPositions() ([]map[string]interface{}, error) { - positions, err := at.trader.GetPositions() - if err != nil { - return nil, fmt.Errorf("failed to get positions: %w", err) - } - - var result []map[string]interface{} - for _, pos := range positions { - symbol := pos["symbol"].(string) - side := pos["side"].(string) - entryPrice := pos["entryPrice"].(float64) - markPrice := pos["markPrice"].(float64) - quantity := pos["positionAmt"].(float64) - if quantity < 0 { - quantity = -quantity - } - unrealizedPnl := pos["unRealizedProfit"].(float64) - liquidationPrice := pos["liquidationPrice"].(float64) - - leverage := 10 - if lev, ok := pos["leverage"].(float64); ok { - leverage = int(lev) - } - - // Calculate margin used - marginUsed := (quantity * markPrice) / float64(leverage) - - // Calculate P&L percentage (based on margin) - pnlPct := calculatePnLPercentage(unrealizedPnl, marginUsed) - - result = append(result, map[string]interface{}{ - "symbol": symbol, - "side": side, - "entry_price": entryPrice, - "mark_price": markPrice, - "quantity": quantity, - "leverage": leverage, - "unrealized_pnl": unrealizedPnl, - "unrealized_pnl_pct": pnlPct, - "liquidation_price": liquidationPrice, - "margin_used": marginUsed, - }) - } - - return result, nil -} - // calculatePnLPercentage calculates P&L percentage (based on margin, automatically considers leverage) -// Return rate = Unrealized P&L / Margin ร— 100% +// Return rate = Unrealized P&L / Margin x 100% func calculatePnLPercentage(unrealizedPnl, marginUsed float64) float64 { if marginUsed > 0 { return (unrealizedPnl / marginUsed) * 100 } return 0.0 } - -// sortDecisionsByPriority sorts decisions: close positions first, then open positions, finally hold/wait -// This avoids position stacking overflow when changing positions -func sortDecisionsByPriority(decisions []kernel.Decision) []kernel.Decision { - if len(decisions) <= 1 { - return decisions - } - - // Define priority - getActionPriority := func(action string) int { - switch action { - case "close_long", "close_short": - return 1 // Highest priority: close positions first - case "open_long", "open_short": - return 2 // Second priority: open positions later - case "hold", "wait": - return 3 // Lowest priority: wait - default: - return 999 // Unknown actions at the end - } - } - - // Copy decision list - sorted := make([]kernel.Decision, len(decisions)) - copy(sorted, decisions) - - // Sort by priority - for i := 0; i < len(sorted)-1; i++ { - for j := i + 1; j < len(sorted); j++ { - if getActionPriority(sorted[i].Action) > getActionPriority(sorted[j].Action) { - sorted[i], sorted[j] = sorted[j], sorted[i] - } - } - } - - return sorted -} - -// startDrawdownMonitor starts drawdown monitoring -func (at *AutoTrader) startDrawdownMonitor() { - at.monitorWg.Add(1) - go func() { - defer at.monitorWg.Done() - - ticker := time.NewTicker(1 * time.Minute) // Check every minute - defer ticker.Stop() - - logger.Info("๐Ÿ“Š Started position drawdown monitoring (check every minute)") - - for { - select { - case <-ticker.C: - at.checkPositionDrawdown() - case <-at.stopMonitorCh: - logger.Info("โน Stopped position drawdown monitoring") - return - } - } - }() -} - -// checkPositionDrawdown checks position drawdown situation -func (at *AutoTrader) checkPositionDrawdown() { - // Get current positions - positions, err := at.trader.GetPositions() - if err != nil { - logger.Infof("โŒ Drawdown monitoring: failed to get positions: %v", err) - return - } - - for _, pos := range positions { - symbol := pos["symbol"].(string) - side := pos["side"].(string) - entryPrice := pos["entryPrice"].(float64) - markPrice := pos["markPrice"].(float64) - quantity := pos["positionAmt"].(float64) - if quantity < 0 { - quantity = -quantity // Short position quantity is negative, convert to positive - } - - // Calculate current P&L percentage - leverage := 10 // Default value - if lev, ok := pos["leverage"].(float64); ok { - leverage = int(lev) - } - - var currentPnLPct float64 - if side == "long" { - currentPnLPct = ((markPrice - entryPrice) / entryPrice) * float64(leverage) * 100 - } else { - currentPnLPct = ((entryPrice - markPrice) / entryPrice) * float64(leverage) * 100 - } - - // Construct unique position identifier (distinguish long/short) - posKey := symbol + "_" + side - - // Get historical peak profit for this position - at.peakPnLCacheMutex.RLock() - peakPnLPct, exists := at.peakPnLCache[posKey] - at.peakPnLCacheMutex.RUnlock() - - if !exists { - // If no historical peak record, use current P&L as initial value - peakPnLPct = currentPnLPct - at.UpdatePeakPnL(symbol, side, currentPnLPct) - } else { - // Update peak cache - at.UpdatePeakPnL(symbol, side, currentPnLPct) - } - - // Calculate drawdown (magnitude of decline from peak) - var drawdownPct float64 - if peakPnLPct > 0 && currentPnLPct < peakPnLPct { - drawdownPct = ((peakPnLPct - currentPnLPct) / peakPnLPct) * 100 - } - - // Check close position condition: profit > 5% and drawdown >= 40% - if currentPnLPct > 5.0 && drawdownPct >= 40.0 { - logger.Infof("๐Ÿšจ Drawdown close position condition triggered: %s %s | Current profit: %.2f%% | Peak profit: %.2f%% | Drawdown: %.2f%%", - symbol, side, currentPnLPct, peakPnLPct, drawdownPct) - - // Execute close position - if err := at.emergencyClosePosition(symbol, side); err != nil { - logger.Infof("โŒ Drawdown close position failed (%s %s): %v", symbol, side, err) - } else { - logger.Infof("โœ… Drawdown close position succeeded: %s %s", symbol, side) - // Clear cache for this position after closing - at.ClearPeakPnLCache(symbol, side) - } - } else if currentPnLPct > 5.0 { - // Record situations close to close position condition (for debugging) - logger.Infof("๐Ÿ“Š Drawdown monitoring: %s %s | Profit: %.2f%% | Peak: %.2f%% | Drawdown: %.2f%%", - symbol, side, currentPnLPct, peakPnLPct, drawdownPct) - } - } -} - -// emergencyClosePosition emergency close position function -func (at *AutoTrader) emergencyClosePosition(symbol, side string) error { - switch side { - case "long": - order, err := at.trader.CloseLong(symbol, 0) // 0 = close all - if err != nil { - return err - } - logger.Infof("โœ… Emergency close long position succeeded, order ID: %v", order["orderId"]) - case "short": - order, err := at.trader.CloseShort(symbol, 0) // 0 = close all - if err != nil { - return err - } - logger.Infof("โœ… Emergency close short position succeeded, order ID: %v", order["orderId"]) - default: - return fmt.Errorf("unknown position direction: %s", side) - } - - return nil -} - -// GetPeakPnLCache gets peak profit cache -func (at *AutoTrader) GetPeakPnLCache() map[string]float64 { - at.peakPnLCacheMutex.RLock() - defer at.peakPnLCacheMutex.RUnlock() - - // Return a copy of the cache - cache := make(map[string]float64) - for k, v := range at.peakPnLCache { - cache[k] = v - } - return cache -} - -// UpdatePeakPnL updates peak profit cache -func (at *AutoTrader) UpdatePeakPnL(symbol, side string, currentPnLPct float64) { - at.peakPnLCacheMutex.Lock() - defer at.peakPnLCacheMutex.Unlock() - - posKey := symbol + "_" + side - if peak, exists := at.peakPnLCache[posKey]; exists { - // Update peak (if long, take larger value; if short, currentPnLPct is negative, also compare) - if currentPnLPct > peak { - at.peakPnLCache[posKey] = currentPnLPct - } - } else { - // First time recording - at.peakPnLCache[posKey] = currentPnLPct - } -} - -// ClearPeakPnLCache clears peak cache for specified position -func (at *AutoTrader) ClearPeakPnLCache(symbol, side string) { - at.peakPnLCacheMutex.Lock() - defer at.peakPnLCacheMutex.Unlock() - - posKey := symbol + "_" + side - delete(at.peakPnLCache, posKey) -} - -// recordAndConfirmOrder polls order status for actual fill data and records position -// action: open_long, open_short, close_long, close_short -// entryPrice: entry price when closing (0 when opening) -func (at *AutoTrader) recordAndConfirmOrder(orderResult map[string]interface{}, symbol, action string, quantity float64, price float64, leverage int, entryPrice float64) { - if at.store == nil { - return - } - - // Get order ID (supports multiple types) - var orderID string - switch v := orderResult["orderId"].(type) { - case int64: - orderID = fmt.Sprintf("%d", v) - case float64: - orderID = fmt.Sprintf("%.0f", v) - case string: - orderID = v - default: - orderID = fmt.Sprintf("%v", v) - } - - if orderID == "" || orderID == "0" { - logger.Infof(" โš ๏ธ Order ID is empty, skipping record") - return - } - - // Determine positionSide - var positionSide string - switch action { - case "open_long", "close_long": - positionSide = "LONG" - case "open_short", "close_short": - positionSide = "SHORT" - } - - var actualPrice = price - var actualQty = quantity - var fee float64 - - // Exchanges with OrderSync: Skip immediate order recording, let OrderSync handle it - // This ensures accurate data from GetTrades API and avoids duplicate records - switch at.exchange { - case "binance", "lighter", "hyperliquid", "bybit", "okx", "bitget", "aster", "kucoin", "gate": - logger.Infof(" ๐Ÿ“ Order submitted (id: %s), will be synced by OrderSync", orderID) - return - } - - // For exchanges without OrderSync (e.g., Binance): record immediately and poll for fill data - orderRecord := at.createOrderRecord(orderID, symbol, action, positionSide, quantity, price, leverage) - if err := at.store.Order().CreateOrder(orderRecord); err != nil { - logger.Infof(" โš ๏ธ Failed to record order: %v", err) - } else { - logger.Infof(" ๐Ÿ“ Order recorded: %s [%s] %s", orderID, action, symbol) - } - - // Wait for order to be filled and get actual fill data - time.Sleep(500 * time.Millisecond) - for i := 0; i < 5; i++ { - status, err := at.trader.GetOrderStatus(symbol, orderID) - if err == nil { - statusStr, _ := status["status"].(string) - if statusStr == "FILLED" { - // Get actual fill price - if avgPrice, ok := status["avgPrice"].(float64); ok && avgPrice > 0 { - actualPrice = avgPrice - } - // Get actual executed quantity - if execQty, ok := status["executedQty"].(float64); ok && execQty > 0 { - actualQty = execQty - } - // Get commission/fee - if commission, ok := status["commission"].(float64); ok { - fee = commission - } - logger.Infof(" โœ… Order filled: avgPrice=%.6f, qty=%.6f, fee=%.6f", actualPrice, actualQty, fee) - - // Update order status to FILLED - if err := at.store.Order().UpdateOrderStatus(orderRecord.ID, "FILLED", actualQty, actualPrice, fee); err != nil { - logger.Infof(" โš ๏ธ Failed to update order status: %v", err) - } - - // Record fill details - at.recordOrderFill(orderRecord.ID, orderID, symbol, action, actualPrice, actualQty, fee) - break - } else if statusStr == "CANCELED" || statusStr == "EXPIRED" || statusStr == "REJECTED" { - logger.Infof(" โš ๏ธ Order %s, skipping position record", statusStr) - - // Update order status - if err := at.store.Order().UpdateOrderStatus(orderRecord.ID, statusStr, 0, 0, 0); err != nil { - logger.Infof(" โš ๏ธ Failed to update order status: %v", err) - } - return - } - } - time.Sleep(500 * time.Millisecond) - } - - // Normalize symbol for position record consistency - normalizedSymbolForPosition := market.Normalize(symbol) - - logger.Infof(" ๐Ÿ“ Recording position (ID: %s, action: %s, price: %.6f, qty: %.6f, fee: %.4f)", - orderID, action, actualPrice, actualQty, fee) - - // Record position change with actual fill data (use normalized symbol) - at.recordPositionChange(orderID, normalizedSymbolForPosition, positionSide, action, actualQty, actualPrice, leverage, entryPrice, fee) - - // Send anonymous trade statistics for experience improvement (async, non-blocking) - // This helps us understand overall product usage across all deployments - experience.TrackTrade(experience.TradeEvent{ - Exchange: at.exchange, - TradeType: action, - Symbol: symbol, - AmountUSD: actualPrice * actualQty, - Leverage: leverage, - UserID: at.userID, - TraderID: at.id, - }) -} - -// recordPositionChange records position change (create record on open, update record on close) -func (at *AutoTrader) recordPositionChange(orderID, symbol, side, action string, quantity, price float64, leverage int, entryPrice float64, fee float64) { - if at.store == nil { - return - } - - switch action { - case "open_long", "open_short": - // Open position: create new position record - nowMs := time.Now().UTC().UnixMilli() - pos := &store.TraderPosition{ - TraderID: at.id, - ExchangeID: at.exchangeID, // Exchange account UUID - ExchangeType: at.exchange, // Exchange type: binance/bybit/okx/etc - Symbol: symbol, - Side: side, // LONG or SHORT - Quantity: quantity, - EntryPrice: price, - EntryOrderID: orderID, - EntryTime: nowMs, - Leverage: leverage, - Status: "OPEN", - CreatedAt: nowMs, - UpdatedAt: nowMs, - } - if err := at.store.Position().Create(pos); err != nil { - logger.Infof(" โš ๏ธ Failed to record position: %v", err) - } else { - logger.Infof(" ๐Ÿ“Š Position recorded [%s] %s %s @ %.4f", at.id[:8], symbol, side, price) - } - - case "close_long", "close_short": - // Close position using PositionBuilder for consistent handling - // PositionBuilder will handle both cases: - // 1. If open position exists: close it properly - // 2. If no open position (e.g., table cleared): create a closed position record - posBuilder := store.NewPositionBuilder(at.store.Position()) - if err := posBuilder.ProcessTrade( - at.id, at.exchangeID, at.exchange, - symbol, side, action, - quantity, price, fee, 0, // realizedPnL will be calculated - time.Now().UTC().UnixMilli(), orderID, - ); err != nil { - logger.Infof(" โš ๏ธ Failed to process close position: %v", err) - } else { - logger.Infof(" โœ… Position closed [%s] %s %s @ %.4f", at.id[:8], symbol, side, price) - } - } -} - -// createOrderRecord creates an order record struct from order details -func (at *AutoTrader) createOrderRecord(orderID, symbol, action, positionSide string, quantity, price float64, leverage int) *store.TraderOrder { - // Determine order type (market for auto trader) - orderType := "MARKET" - - // Determine side (BUY/SELL) - var side string - switch action { - case "open_long", "close_short": - side = "BUY" - case "open_short", "close_long": - side = "SELL" - } - - // Use action as orderAction directly (keep lowercase format) - orderAction := action - - // Determine if it's a reduce only order - reduceOnly := (action == "close_long" || action == "close_short") - - // Normalize symbol for consistency - normalizedSymbol := market.Normalize(symbol) - - return &store.TraderOrder{ - TraderID: at.id, - ExchangeID: at.exchangeID, - ExchangeType: at.exchange, - ExchangeOrderID: orderID, - Symbol: normalizedSymbol, - Side: side, - PositionSide: positionSide, - Type: orderType, - TimeInForce: "GTC", - Quantity: quantity, - Price: price, - Status: "NEW", - FilledQuantity: 0, - AvgFillPrice: 0, - Commission: 0, - CommissionAsset: "USDT", - Leverage: leverage, - ReduceOnly: reduceOnly, - ClosePosition: reduceOnly, - OrderAction: orderAction, - CreatedAt: time.Now().UTC().UnixMilli(), - UpdatedAt: time.Now().UTC().UnixMilli(), - } -} - -// recordOrderFill records order fill/trade details -func (at *AutoTrader) recordOrderFill(orderRecordID int64, exchangeOrderID, symbol, action string, price, quantity, fee float64) { - if at.store == nil { - return - } - - // Determine side (BUY/SELL) - var side string - switch action { - case "open_long", "close_short": - side = "BUY" - case "open_short", "close_long": - side = "SELL" - } - - // Generate a simple trade ID (exchange doesn't always provide one) - tradeID := fmt.Sprintf("%s-%d", exchangeOrderID, time.Now().UnixNano()) - - // Normalize symbol for consistency - normalizedSymbol := market.Normalize(symbol) - - fill := &store.TraderFill{ - TraderID: at.id, - ExchangeID: at.exchangeID, - ExchangeType: at.exchange, - OrderID: orderRecordID, - ExchangeOrderID: exchangeOrderID, - ExchangeTradeID: tradeID, - Symbol: normalizedSymbol, - Side: side, - Price: price, - Quantity: quantity, - QuoteQuantity: price * quantity, - Commission: fee, - CommissionAsset: "USDT", - RealizedPnL: 0, // Will be calculated for close orders - IsMaker: false, // Market orders are usually taker - CreatedAt: time.Now().UTC().UnixMilli(), - } - - // Calculate realized PnL for close orders - if action == "close_long" || action == "close_short" { - // Try to get the entry price from the open position - var positionSide string - if action == "close_long" { - positionSide = "LONG" - } else { - positionSide = "SHORT" - } - - if openPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, symbol, positionSide); err == nil && openPos != nil { - if positionSide == "LONG" { - fill.RealizedPnL = (price - openPos.EntryPrice) * quantity - } else { - fill.RealizedPnL = (openPos.EntryPrice - price) * quantity - } - } - } - - if err := at.store.Order().CreateFill(fill); err != nil { - logger.Infof(" โš ๏ธ Failed to record fill: %v", err) - } else { - logger.Infof(" ๐Ÿ“‹ Fill recorded: %.4f @ %.6f, fee: %.4f", quantity, price, fee) - } -} - -// ============================================================================ -// Risk Control Helpers -// ============================================================================ - -// isBTCETH checks if a symbol is BTC or ETH -func isBTCETH(symbol string) bool { - symbol = strings.ToUpper(symbol) - return strings.HasPrefix(symbol, "BTC") || strings.HasPrefix(symbol, "ETH") -} - -// enforcePositionValueRatio checks and enforces position value ratio limits (CODE ENFORCED) -// Returns the adjusted position size (capped if necessary) and whether the position was capped -// positionSizeUSD: the original position size in USD -// equity: the account equity -// symbol: the trading symbol -func (at *AutoTrader) enforcePositionValueRatio(positionSizeUSD float64, equity float64, symbol string) (float64, bool) { - if at.config.StrategyConfig == nil { - return positionSizeUSD, false - } - - riskControl := at.config.StrategyConfig.RiskControl - - // Get the appropriate position value ratio limit - var maxPositionValueRatio float64 - if isBTCETH(symbol) { - maxPositionValueRatio = riskControl.BTCETHMaxPositionValueRatio - if maxPositionValueRatio <= 0 { - maxPositionValueRatio = 5.0 // Default: 5x for BTC/ETH - } - } else { - maxPositionValueRatio = riskControl.AltcoinMaxPositionValueRatio - if maxPositionValueRatio <= 0 { - maxPositionValueRatio = 1.0 // Default: 1x for altcoins - } - } - - // Calculate max allowed position value = equity ร— ratio - maxPositionValue := equity * maxPositionValueRatio - - // Check if position size exceeds limit - if positionSizeUSD > maxPositionValue { - logger.Infof(" โš ๏ธ [RISK CONTROL] Position %.2f USDT exceeds limit (equity %.2f ร— %.1fx = %.2f USDT max for %s), capping", - positionSizeUSD, equity, maxPositionValueRatio, maxPositionValue, symbol) - return maxPositionValue, true - } - - return positionSizeUSD, false -} - -// enforceMinPositionSize checks minimum position size (CODE ENFORCED) -func (at *AutoTrader) enforceMinPositionSize(positionSizeUSD float64) error { - if at.config.StrategyConfig == nil { - return nil - } - - minSize := at.config.StrategyConfig.RiskControl.MinPositionSize - if minSize <= 0 { - minSize = 12 // Default: 12 USDT - } - - if positionSizeUSD < minSize { - return fmt.Errorf("โŒ [RISK CONTROL] Position %.2f USDT below minimum (%.2f USDT)", positionSizeUSD, minSize) - } - return nil -} - -// enforceMaxPositions checks maximum positions count (CODE ENFORCED) -func (at *AutoTrader) enforceMaxPositions(currentPositionCount int) error { - if at.config.StrategyConfig == nil { - return nil - } - - maxPositions := at.config.StrategyConfig.RiskControl.MaxPositions - if maxPositions <= 0 { - maxPositions = 3 // Default: 3 positions - } - - if currentPositionCount >= maxPositions { - return fmt.Errorf("โŒ [RISK CONTROL] Already at max positions (%d/%d)", currentPositionCount, maxPositions) - } - return nil -} - -// getSideFromAction converts order action to side (BUY/SELL) -func getSideFromAction(action string) string { - switch action { - case "open_long", "close_short": - return "BUY" - case "open_short", "close_long": - return "SELL" - default: - return "BUY" - } -} - -// GetOpenOrders returns open orders (pending SL/TP) from exchange -func (at *AutoTrader) GetOpenOrders(symbol string) ([]OpenOrder, error) { - return at.trader.GetOpenOrders(symbol) -} diff --git a/trader/auto_trader_decision.go b/trader/auto_trader_decision.go new file mode 100644 index 00000000..2cec395c --- /dev/null +++ b/trader/auto_trader_decision.go @@ -0,0 +1,527 @@ +package trader + +import ( + "fmt" + "math" + "nofx/experience" + "nofx/kernel" + "nofx/logger" + "nofx/market" + "nofx/store" + "time" +) + +// saveEquitySnapshot saves equity snapshot independently (for drawing profit curve, decoupled from AI decision) +func (at *AutoTrader) saveEquitySnapshot(ctx *kernel.Context) { + if at.store == nil || ctx == nil { + return + } + + snapshot := &store.EquitySnapshot{ + TraderID: at.id, + Timestamp: time.Now().UTC(), + TotalEquity: ctx.Account.TotalEquity, + Balance: ctx.Account.TotalEquity - ctx.Account.UnrealizedPnL, + UnrealizedPnL: ctx.Account.UnrealizedPnL, + PositionCount: ctx.Account.PositionCount, + MarginUsedPct: ctx.Account.MarginUsedPct, + } + + if err := at.store.Equity().Save(snapshot); err != nil { + logger.Infof("โš ๏ธ Failed to save equity snapshot: %v", err) + } +} + +// saveDecision saves AI decision log to database (only records AI input/output, for debugging) +func (at *AutoTrader) saveDecision(record *store.DecisionRecord) error { + if at.store == nil { + return nil + } + + at.cycleNumber++ + record.CycleNumber = at.cycleNumber + record.TraderID = at.id + + if record.Timestamp.IsZero() { + record.Timestamp = time.Now().UTC() + } + + if err := at.store.Decision().LogDecision(record); err != nil { + logger.Infof("โš ๏ธ Failed to save decision record: %v", err) + return err + } + + logger.Infof("๐Ÿ“ Decision record saved: trader=%s, cycle=%d", at.id, at.cycleNumber) + return nil +} + +// GetStatus gets system status (for API) +func (at *AutoTrader) GetStatus() map[string]interface{} { + aiProvider := "DeepSeek" + if at.config.UseQwen { + aiProvider = "Qwen" + } + + at.isRunningMutex.RLock() + isRunning := at.isRunning + at.isRunningMutex.RUnlock() + + result := map[string]interface{}{ + "trader_id": at.id, + "trader_name": at.name, + "ai_model": at.aiModel, + "exchange": at.exchange, + "is_running": isRunning, + "start_time": at.startTime.Format(time.RFC3339), + "runtime_minutes": int(time.Since(at.startTime).Minutes()), + "call_count": at.callCount, + "initial_balance": at.initialBalance, + "scan_interval": at.config.ScanInterval.String(), + "stop_until": at.stopUntil.Format(time.RFC3339), + "last_reset_time": at.lastResetTime.Format(time.RFC3339), + "ai_provider": aiProvider, + } + + // Add strategy info + if at.config.StrategyConfig != nil { + result["strategy_type"] = at.config.StrategyConfig.StrategyType + if at.config.StrategyConfig.GridConfig != nil { + result["grid_symbol"] = at.config.StrategyConfig.GridConfig.Symbol + } + } + + return result +} + +// GetAccountInfo gets account information (for API) +func (at *AutoTrader) GetAccountInfo() (map[string]interface{}, error) { + balance, err := at.trader.GetBalance() + if err != nil { + return nil, fmt.Errorf("failed to get balance: %w", err) + } + + // Get account fields + totalWalletBalance := 0.0 + totalUnrealizedProfit := 0.0 + availableBalance := 0.0 + totalEquity := 0.0 + + if wallet, ok := balance["totalWalletBalance"].(float64); ok { + totalWalletBalance = wallet + } + if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok { + totalUnrealizedProfit = unrealized + } + if avail, ok := balance["availableBalance"].(float64); ok { + availableBalance = avail + } + + // Use totalEquity directly if provided by trader (more accurate) + if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 { + totalEquity = eq + } else { + // Fallback: Total Equity = Wallet balance + Unrealized profit + totalEquity = totalWalletBalance + totalUnrealizedProfit + } + + // Get positions to calculate total margin + positions, err := at.trader.GetPositions() + if err != nil { + return nil, fmt.Errorf("failed to get positions: %w", err) + } + + totalMarginUsed := 0.0 + totalUnrealizedPnLCalculated := 0.0 + for _, pos := range positions { + markPrice := pos["markPrice"].(float64) + quantity := pos["positionAmt"].(float64) + if quantity < 0 { + quantity = -quantity + } + unrealizedPnl := pos["unRealizedProfit"].(float64) + totalUnrealizedPnLCalculated += unrealizedPnl + + leverage := 10 + if lev, ok := pos["leverage"].(float64); ok { + leverage = int(lev) + } + marginUsed := (quantity * markPrice) / float64(leverage) + totalMarginUsed += marginUsed + } + + // Verify unrealized P&L consistency (API value vs calculated from positions) + // Note: Lighter API may return 0 for unrealized PnL, this is a known limitation + diff := math.Abs(totalUnrealizedProfit - totalUnrealizedPnLCalculated) + if diff > 5.0 { // Only warn if difference is significant (> 5 USDT) + logger.Infof("โš ๏ธ Unrealized P&L inconsistency (Lighter API limitation): API=%.4f, Calculated=%.4f, Diff=%.4f", + totalUnrealizedProfit, totalUnrealizedPnLCalculated, diff) + } + + totalPnL := totalEquity - at.initialBalance + totalPnLPct := 0.0 + if at.initialBalance > 0 { + totalPnLPct = (totalPnL / at.initialBalance) * 100 + } else { + logger.Infof("โš ๏ธ Initial Balance abnormal: %.2f, cannot calculate P&L percentage", at.initialBalance) + } + + marginUsedPct := 0.0 + if totalEquity > 0 { + marginUsedPct = (totalMarginUsed / totalEquity) * 100 + } + + return map[string]interface{}{ + // Core fields + "total_equity": totalEquity, // Account equity = wallet + unrealized + "wallet_balance": totalWalletBalance, // Wallet balance (excluding unrealized P&L) + "unrealized_profit": totalUnrealizedProfit, // Unrealized P&L (official value from exchange API) + "available_balance": availableBalance, // Available balance + + // P&L statistics + "total_pnl": totalPnL, // Total P&L = equity - initial + "total_pnl_pct": totalPnLPct, // Total P&L percentage + "initial_balance": at.initialBalance, // Initial balance + "daily_pnl": at.dailyPnL, // Daily P&L + + // Position information + "position_count": len(positions), // Position count + "margin_used": totalMarginUsed, // Margin used + "margin_used_pct": marginUsedPct, // Margin usage rate + }, nil +} + +// GetPositions gets position list (for API) +func (at *AutoTrader) GetPositions() ([]map[string]interface{}, error) { + positions, err := at.trader.GetPositions() + if err != nil { + return nil, fmt.Errorf("failed to get positions: %w", err) + } + + var result []map[string]interface{} + for _, pos := range positions { + symbol := pos["symbol"].(string) + side := pos["side"].(string) + entryPrice := pos["entryPrice"].(float64) + markPrice := pos["markPrice"].(float64) + quantity := pos["positionAmt"].(float64) + if quantity < 0 { + quantity = -quantity + } + unrealizedPnl := pos["unRealizedProfit"].(float64) + liquidationPrice := pos["liquidationPrice"].(float64) + + leverage := 10 + if lev, ok := pos["leverage"].(float64); ok { + leverage = int(lev) + } + + // Calculate margin used + marginUsed := (quantity * markPrice) / float64(leverage) + + // Calculate P&L percentage (based on margin) + pnlPct := calculatePnLPercentage(unrealizedPnl, marginUsed) + + result = append(result, map[string]interface{}{ + "symbol": symbol, + "side": side, + "entry_price": entryPrice, + "mark_price": markPrice, + "quantity": quantity, + "leverage": leverage, + "unrealized_pnl": unrealizedPnl, + "unrealized_pnl_pct": pnlPct, + "liquidation_price": liquidationPrice, + "margin_used": marginUsed, + }) + } + + return result, nil +} + +// recordAndConfirmOrder polls order status for actual fill data and records position +// action: open_long, open_short, close_long, close_short +// entryPrice: entry price when closing (0 when opening) +func (at *AutoTrader) recordAndConfirmOrder(orderResult map[string]interface{}, symbol, action string, quantity float64, price float64, leverage int, entryPrice float64) { + if at.store == nil { + return + } + + // Get order ID (supports multiple types) + var orderID string + switch v := orderResult["orderId"].(type) { + case int64: + orderID = fmt.Sprintf("%d", v) + case float64: + orderID = fmt.Sprintf("%.0f", v) + case string: + orderID = v + default: + orderID = fmt.Sprintf("%v", v) + } + + if orderID == "" || orderID == "0" { + logger.Infof(" โš ๏ธ Order ID is empty, skipping record") + return + } + + // Determine positionSide + var positionSide string + switch action { + case "open_long", "close_long": + positionSide = "LONG" + case "open_short", "close_short": + positionSide = "SHORT" + } + + var actualPrice = price + var actualQty = quantity + var fee float64 + + // Exchanges with OrderSync: Skip immediate order recording, let OrderSync handle it + // This ensures accurate data from GetTrades API and avoids duplicate records + switch at.exchange { + case "binance", "lighter", "hyperliquid", "bybit", "okx", "bitget", "aster", "kucoin", "gate": + logger.Infof(" ๐Ÿ“ Order submitted (id: %s), will be synced by OrderSync", orderID) + return + } + + // For exchanges without OrderSync (e.g., Binance): record immediately and poll for fill data + orderRecord := at.createOrderRecord(orderID, symbol, action, positionSide, quantity, price, leverage) + if err := at.store.Order().CreateOrder(orderRecord); err != nil { + logger.Infof(" โš ๏ธ Failed to record order: %v", err) + } else { + logger.Infof(" ๐Ÿ“ Order recorded: %s [%s] %s", orderID, action, symbol) + } + + // Wait for order to be filled and get actual fill data + time.Sleep(500 * time.Millisecond) + for i := 0; i < 5; i++ { + status, err := at.trader.GetOrderStatus(symbol, orderID) + if err == nil { + statusStr, _ := status["status"].(string) + if statusStr == "FILLED" { + // Get actual fill price + if avgPrice, ok := status["avgPrice"].(float64); ok && avgPrice > 0 { + actualPrice = avgPrice + } + // Get actual executed quantity + if execQty, ok := status["executedQty"].(float64); ok && execQty > 0 { + actualQty = execQty + } + // Get commission/fee + if commission, ok := status["commission"].(float64); ok { + fee = commission + } + logger.Infof(" โœ… Order filled: avgPrice=%.6f, qty=%.6f, fee=%.6f", actualPrice, actualQty, fee) + + // Update order status to FILLED + if err := at.store.Order().UpdateOrderStatus(orderRecord.ID, "FILLED", actualQty, actualPrice, fee); err != nil { + logger.Infof(" โš ๏ธ Failed to update order status: %v", err) + } + + // Record fill details + at.recordOrderFill(orderRecord.ID, orderID, symbol, action, actualPrice, actualQty, fee) + break + } else if statusStr == "CANCELED" || statusStr == "EXPIRED" || statusStr == "REJECTED" { + logger.Infof(" โš ๏ธ Order %s, skipping position record", statusStr) + // Update order status + if err := at.store.Order().UpdateOrderStatus(orderRecord.ID, statusStr, 0, 0, 0); err != nil { + logger.Infof(" โš ๏ธ Failed to update order status: %v", err) + } + return + } + } + time.Sleep(500 * time.Millisecond) + } + + // Normalize symbol for position record consistency + normalizedSymbolForPosition := market.Normalize(symbol) + + logger.Infof(" ๐Ÿ“ Recording position (ID: %s, action: %s, price: %.6f, qty: %.6f, fee: %.4f)", + orderID, action, actualPrice, actualQty, fee) + + // Record position change with actual fill data (use normalized symbol) + at.recordPositionChange(orderID, normalizedSymbolForPosition, positionSide, action, actualQty, actualPrice, leverage, entryPrice, fee) + + // Send anonymous trade statistics for experience improvement (async, non-blocking) + // This helps us understand overall product usage across all deployments + experience.TrackTrade(experience.TradeEvent{ + Exchange: at.exchange, + TradeType: action, + Symbol: symbol, + AmountUSD: actualPrice * actualQty, + Leverage: leverage, + UserID: at.userID, + TraderID: at.id, + }) +} + +// recordPositionChange records position change (create record on open, update record on close) +func (at *AutoTrader) recordPositionChange(orderID, symbol, side, action string, quantity, price float64, leverage int, entryPrice float64, fee float64) { + if at.store == nil { + return + } + + switch action { + case "open_long", "open_short": + // Open position: create new position record + nowMs := time.Now().UTC().UnixMilli() + pos := &store.TraderPosition{ + TraderID: at.id, + ExchangeID: at.exchangeID, // Exchange account UUID + ExchangeType: at.exchange, // Exchange type: binance/bybit/okx/etc + Symbol: symbol, + Side: side, // LONG or SHORT + Quantity: quantity, + EntryPrice: price, + EntryOrderID: orderID, + EntryTime: nowMs, + Leverage: leverage, + Status: "OPEN", + CreatedAt: nowMs, + UpdatedAt: nowMs, + } + if err := at.store.Position().Create(pos); err != nil { + logger.Infof(" โš ๏ธ Failed to record position: %v", err) + } else { + logger.Infof(" ๐Ÿ“Š Position recorded [%s] %s %s @ %.4f", at.id[:8], symbol, side, price) + } + + case "close_long", "close_short": + // Close position using PositionBuilder for consistent handling + // PositionBuilder will handle both cases: + // 1. If open position exists: close it properly + // 2. If no open position (e.g., table cleared): create a closed position record + posBuilder := store.NewPositionBuilder(at.store.Position()) + if err := posBuilder.ProcessTrade( + at.id, at.exchangeID, at.exchange, + symbol, side, action, + quantity, price, fee, 0, // realizedPnL will be calculated + time.Now().UTC().UnixMilli(), orderID, + ); err != nil { + logger.Infof(" โš ๏ธ Failed to process close position: %v", err) + } else { + logger.Infof(" โœ… Position closed [%s] %s %s @ %.4f", at.id[:8], symbol, side, price) + } + } +} + +// createOrderRecord creates an order record struct from order details +func (at *AutoTrader) createOrderRecord(orderID, symbol, action, positionSide string, quantity, price float64, leverage int) *store.TraderOrder { + // Determine order type (market for auto trader) + orderType := "MARKET" + + // Determine side (BUY/SELL) + var side string + switch action { + case "open_long", "close_short": + side = "BUY" + case "open_short", "close_long": + side = "SELL" + } + + // Use action as orderAction directly (keep lowercase format) + orderAction := action + + // Determine if it's a reduce only order + reduceOnly := (action == "close_long" || action == "close_short") + + // Normalize symbol for consistency + normalizedSymbol := market.Normalize(symbol) + + return &store.TraderOrder{ + TraderID: at.id, + ExchangeID: at.exchangeID, + ExchangeType: at.exchange, + ExchangeOrderID: orderID, + Symbol: normalizedSymbol, + Side: side, + PositionSide: positionSide, + Type: orderType, + TimeInForce: "GTC", + Quantity: quantity, + Price: price, + Status: "NEW", + FilledQuantity: 0, + AvgFillPrice: 0, + Commission: 0, + CommissionAsset: "USDT", + Leverage: leverage, + ReduceOnly: reduceOnly, + ClosePosition: reduceOnly, + OrderAction: orderAction, + CreatedAt: time.Now().UTC().UnixMilli(), + UpdatedAt: time.Now().UTC().UnixMilli(), + } +} + +// recordOrderFill records order fill/trade details +func (at *AutoTrader) recordOrderFill(orderRecordID int64, exchangeOrderID, symbol, action string, price, quantity, fee float64) { + if at.store == nil { + return + } + + // Determine side (BUY/SELL) + var side string + switch action { + case "open_long", "close_short": + side = "BUY" + case "open_short", "close_long": + side = "SELL" + } + + // Generate a simple trade ID (exchange doesn't always provide one) + tradeID := fmt.Sprintf("%s-%d", exchangeOrderID, time.Now().UnixNano()) + + // Normalize symbol for consistency + normalizedSymbol := market.Normalize(symbol) + + fill := &store.TraderFill{ + TraderID: at.id, + ExchangeID: at.exchangeID, + ExchangeType: at.exchange, + OrderID: orderRecordID, + ExchangeOrderID: exchangeOrderID, + ExchangeTradeID: tradeID, + Symbol: normalizedSymbol, + Side: side, + Price: price, + Quantity: quantity, + QuoteQuantity: price * quantity, + Commission: fee, + CommissionAsset: "USDT", + RealizedPnL: 0, // Will be calculated for close orders + IsMaker: false, // Market orders are usually taker + CreatedAt: time.Now().UTC().UnixMilli(), + } + + // Calculate realized PnL for close orders + if action == "close_long" || action == "close_short" { + // Try to get the entry price from the open position + var positionSide string + if action == "close_long" { + positionSide = "LONG" + } else { + positionSide = "SHORT" + } + + if openPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, symbol, positionSide); err == nil && openPos != nil { + if positionSide == "LONG" { + fill.RealizedPnL = (price - openPos.EntryPrice) * quantity + } else { + fill.RealizedPnL = (openPos.EntryPrice - price) * quantity + } + } + } + + if err := at.store.Order().CreateFill(fill); err != nil { + logger.Infof(" โš ๏ธ Failed to record fill: %v", err) + } else { + logger.Infof(" ๐Ÿ“‹ Fill recorded: %.4f @ %.6f, fee: %.4f", quantity, price, fee) + } +} + +// GetOpenOrders returns open orders (pending SL/TP) from exchange +func (at *AutoTrader) GetOpenOrders(symbol string) ([]OpenOrder, error) { + return at.trader.GetOpenOrders(symbol) +} diff --git a/trader/auto_trader_loop.go b/trader/auto_trader_loop.go new file mode 100644 index 00000000..e274c0da --- /dev/null +++ b/trader/auto_trader_loop.go @@ -0,0 +1,560 @@ +package trader + +import ( + "encoding/json" + "fmt" + "nofx/kernel" + "nofx/logger" + "nofx/store" + "strings" + "time" +) + +// runCycle runs one trading cycle (using AI full decision-making) +func (at *AutoTrader) runCycle() error { + at.callCount++ + + logger.Info("\n" + strings.Repeat("=", 70) + "\n") + logger.Infof("โฐ %s - AI decision cycle #%d", time.Now().Format("2006-01-02 15:04:05"), at.callCount) + logger.Info(strings.Repeat("=", 70)) + + // 0. Check if trader is stopped (early exit to prevent trades after Stop() is called) + at.isRunningMutex.RLock() + running := at.isRunning + at.isRunningMutex.RUnlock() + if !running { + logger.Infof("โน Trader is stopped, aborting cycle #%d", at.callCount) + return nil + } + + // Create decision record + record := &store.DecisionRecord{ + ExecutionLog: []string{}, + Success: true, + } + + // 1. Check if trading needs to be stopped + if time.Now().Before(at.stopUntil) { + remaining := at.stopUntil.Sub(time.Now()) + logger.Infof("โธ Risk control: Trading paused, remaining %.0f minutes", remaining.Minutes()) + record.Success = false + record.ErrorMessage = fmt.Sprintf("Risk control paused, remaining %.0f minutes", remaining.Minutes()) + at.saveDecision(record) + return nil + } + + // 2. Reset daily P&L (reset every day) + if time.Since(at.lastResetTime) > 24*time.Hour { + at.dailyPnL = 0 + at.lastResetTime = time.Now() + logger.Info("๐Ÿ“… Daily P&L reset") + } + + // 4. Collect trading context + ctx, err := at.buildTradingContext() + if err != nil { + record.Success = false + record.ErrorMessage = fmt.Sprintf("Failed to build trading context: %v", err) + at.saveDecision(record) + return fmt.Errorf("failed to build trading context: %w", err) + } + + // Save equity snapshot independently (decoupled from AI decision, used for drawing profit curve) + // NOTE: Must be called BEFORE candidate coins check to ensure equity is always recorded + at.saveEquitySnapshot(ctx) + + // ๅฆ‚ๆžœๆฒกๆœ‰ๅ€™้€‰ๅธ็ง๏ผŒ่ฎฐๅฝ•ไฝ†ไธๆŠฅ้”™ + if len(ctx.CandidateCoins) == 0 { + logger.Infof("โ„น๏ธ No candidate coins available, skipping this cycle") + record.Success = true // ไธๆ˜ฏ้”™่ฏฏ๏ผŒๅชๆ˜ฏๆฒกๆœ‰ๅ€™้€‰ๅธ + record.ExecutionLog = append(record.ExecutionLog, "No candidate coins available, cycle skipped") + record.AccountState = store.AccountSnapshot{ + TotalBalance: ctx.Account.TotalEquity, + AvailableBalance: ctx.Account.AvailableBalance, + TotalUnrealizedProfit: ctx.Account.UnrealizedPnL, + PositionCount: ctx.Account.PositionCount, + InitialBalance: at.initialBalance, + } + at.saveDecision(record) + return nil + } + + logger.Info(strings.Repeat("=", 70)) + for _, coin := range ctx.CandidateCoins { + record.CandidateCoins = append(record.CandidateCoins, coin.Symbol) + } + + logger.Infof("๐Ÿ“Š Account equity: %.2f USDT | Available: %.2f USDT | Positions: %d", + ctx.Account.TotalEquity, ctx.Account.AvailableBalance, ctx.Account.PositionCount) + + // 5. Use strategy engine to call AI for decision + logger.Infof("๐Ÿค– Requesting AI analysis and decision... [Strategy Engine]") + aiDecision, err := kernel.GetFullDecisionWithStrategy(ctx, at.mcpClient, at.strategyEngine, "balanced") + + if aiDecision != nil && aiDecision.AIRequestDurationMs > 0 { + record.AIRequestDurationMs = aiDecision.AIRequestDurationMs + logger.Infof("โฑ๏ธ AI call duration: %.2f seconds", float64(record.AIRequestDurationMs)/1000) + record.ExecutionLog = append(record.ExecutionLog, + fmt.Sprintf("AI call duration: %d ms", record.AIRequestDurationMs)) + } + + // Save chain of thought, decisions, and input prompt even if there's an error (for debugging) + if aiDecision != nil { + record.SystemPrompt = aiDecision.SystemPrompt // Save system prompt + record.InputPrompt = aiDecision.UserPrompt + record.CoTTrace = aiDecision.CoTTrace + record.RawResponse = aiDecision.RawResponse // Save raw AI response for debugging + if len(aiDecision.Decisions) > 0 { + decisionJSON, _ := json.MarshalIndent(aiDecision.Decisions, "", " ") + record.DecisionJSON = string(decisionJSON) + } + } + + if err != nil { + record.Success = false + record.ErrorMessage = fmt.Sprintf("Failed to get AI decision: %v", err) + + // Print system prompt and AI chain of thought (output even with errors for debugging) + if aiDecision != nil { + logger.Info("\n" + strings.Repeat("=", 70) + "\n") + logger.Infof("๐Ÿ“‹ System prompt (error case)") + logger.Info(strings.Repeat("=", 70)) + logger.Info(aiDecision.SystemPrompt) + logger.Info(strings.Repeat("=", 70)) + + if aiDecision.CoTTrace != "" { + logger.Info("\n" + strings.Repeat("-", 70) + "\n") + logger.Info("๐Ÿ’ญ AI chain of thought analysis (error case):") + logger.Info(strings.Repeat("-", 70)) + logger.Info(aiDecision.CoTTrace) + logger.Info(strings.Repeat("-", 70)) + } + } + + at.saveDecision(record) + return fmt.Errorf("failed to get AI decision: %w", err) + } + + // // 5. Print system prompt + // logger.Infof("\n" + strings.Repeat("=", 70)) + // logger.Infof("๐Ÿ“‹ System prompt [template: %s]", at.systemPromptTemplate) + // logger.Info(strings.Repeat("=", 70)) + // logger.Info(decision.SystemPrompt) + // logger.Infof(strings.Repeat("=", 70) + "\n") + + // 6. Print AI chain of thought + // logger.Infof("\n" + strings.Repeat("-", 70)) + // logger.Info("๐Ÿ’ญ AI chain of thought analysis:") + // logger.Info(strings.Repeat("-", 70)) + // logger.Info(decision.CoTTrace) + // logger.Infof(strings.Repeat("-", 70) + "\n") + + // 7. Print AI decisions + // logger.Infof("๐Ÿ“‹ AI decision list (%d items):\n", len(kernel.Decisions)) + // for i, d := range kernel.Decisions { + // logger.Infof(" [%d] %s: %s - %s", i+1, d.Symbol, d.Action, d.Reasoning) + // if d.Action == "open_long" || d.Action == "open_short" { + // logger.Infof(" Leverage: %dx | Position: %.2f USDT | Stop loss: %.4f | Take profit: %.4f", + // d.Leverage, d.PositionSizeUSD, d.StopLoss, d.TakeProfit) + // } + // } + logger.Info() + logger.Info(strings.Repeat("-", 70)) + // 8. Sort decisions: ensure close positions first, then open positions (prevent position stacking overflow) + logger.Info(strings.Repeat("-", 70)) + + // 8. Sort decisions: ensure close positions first, then open positions (prevent position stacking overflow) + sortedDecisions := sortDecisionsByPriority(aiDecision.Decisions) + + logger.Info("๐Ÿ”„ Execution order (optimized): Close positions first โ†’ Open positions later") + for i, d := range sortedDecisions { + logger.Infof(" [%d] %s %s", i+1, d.Symbol, d.Action) + } + logger.Info() + + // Check if trader is stopped before executing any decisions (prevent trades after Stop()) + at.isRunningMutex.RLock() + running = at.isRunning + at.isRunningMutex.RUnlock() + if !running { + logger.Infof("โน Trader stopped before decision execution, aborting cycle #%d", at.callCount) + return nil + } + + // Execute decisions and record results + for _, d := range sortedDecisions { + // Check if trader is stopped before each decision (allow immediate stop during execution) + at.isRunningMutex.RLock() + running = at.isRunning + at.isRunningMutex.RUnlock() + if !running { + logger.Infof("โน Trader stopped during decision execution, aborting remaining decisions") + break + } + + actionRecord := store.DecisionAction{ + Action: d.Action, + Symbol: d.Symbol, + Quantity: 0, + Leverage: d.Leverage, + Price: 0, + StopLoss: d.StopLoss, + TakeProfit: d.TakeProfit, + Confidence: d.Confidence, + Reasoning: d.Reasoning, + Timestamp: time.Now().UTC(), + Success: false, + } + + if err := at.executeDecisionWithRecord(&d, &actionRecord); err != nil { + logger.Infof("โŒ Failed to execute decision (%s %s): %v", d.Symbol, d.Action, err) + actionRecord.Error = err.Error() + record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("โŒ %s %s failed: %v", d.Symbol, d.Action, err)) + } else { + actionRecord.Success = true + record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("โœ“ %s %s succeeded", d.Symbol, d.Action)) + // Brief delay after successful execution + time.Sleep(1 * time.Second) + } + + record.Decisions = append(record.Decisions, actionRecord) + } + + // 9. Save decision record + if err := at.saveDecision(record); err != nil { + logger.Infof("โš  Failed to save decision record: %v", err) + } + + return nil +} + +// buildTradingContext builds trading context +func (at *AutoTrader) buildTradingContext() (*kernel.Context, error) { + // 1. Get account information + balance, err := at.trader.GetBalance() + if err != nil { + return nil, fmt.Errorf("failed to get account balance: %w", err) + } + + // Get account fields + totalWalletBalance := 0.0 + totalUnrealizedProfit := 0.0 + availableBalance := 0.0 + totalEquity := 0.0 + + if wallet, ok := balance["totalWalletBalance"].(float64); ok { + totalWalletBalance = wallet + } + if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok { + totalUnrealizedProfit = unrealized + } + if avail, ok := balance["availableBalance"].(float64); ok { + availableBalance = avail + } + + // Use totalEquity directly if provided by trader (more accurate) + if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 { + totalEquity = eq + } else { + // Fallback: Total Equity = Wallet balance + Unrealized profit + totalEquity = totalWalletBalance + totalUnrealizedProfit + } + + // 2. Get position information + positions, err := at.trader.GetPositions() + if err != nil { + return nil, fmt.Errorf("failed to get positions: %w", err) + } + + var positionInfos []kernel.PositionInfo + totalMarginUsed := 0.0 + + // Current position key set (for cleaning up closed position records) + currentPositionKeys := make(map[string]bool) + + for _, pos := range positions { + symbol := pos["symbol"].(string) + side := pos["side"].(string) + entryPrice := pos["entryPrice"].(float64) + markPrice := pos["markPrice"].(float64) + quantity := pos["positionAmt"].(float64) + if quantity < 0 { + quantity = -quantity // Short position quantity is negative, convert to positive + } + + // Skip closed positions (quantity = 0), prevent "ghost positions" from being passed to AI + if quantity == 0 { + continue + } + + unrealizedPnl := pos["unRealizedProfit"].(float64) + liquidationPrice := pos["liquidationPrice"].(float64) + + // Calculate margin used (estimated) + leverage := 10 // Default value, should actually be fetched from position info + if lev, ok := pos["leverage"].(float64); ok { + leverage = int(lev) + } + marginUsed := (quantity * markPrice) / float64(leverage) + totalMarginUsed += marginUsed + + // Calculate P&L percentage (based on margin, considering leverage) + pnlPct := calculatePnLPercentage(unrealizedPnl, marginUsed) + + // Get position open time from exchange (preferred) or fallback to local tracking + posKey := symbol + "_" + side + currentPositionKeys[posKey] = true + + var updateTime int64 + // Priority 1: Get from database (trader_positions table) - most accurate + if at.store != nil { + if dbPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, symbol, side); err == nil && dbPos != nil { + if dbPos.EntryTime > 0 { + updateTime = dbPos.EntryTime + } + } + } + // Priority 2: Get from exchange API (Bybit: createdTime, OKX: createdTime) + if updateTime == 0 { + if createdTime, ok := pos["createdTime"].(int64); ok && createdTime > 0 { + updateTime = createdTime + } + } + // Priority 3: Fallback to local tracking + if updateTime == 0 { + if _, exists := at.positionFirstSeenTime[posKey]; !exists { + at.positionFirstSeenTime[posKey] = time.Now().UnixMilli() + } + updateTime = at.positionFirstSeenTime[posKey] + } + + // Get peak profit rate for this position + at.peakPnLCacheMutex.RLock() + peakPnlPct := at.peakPnLCache[posKey] + at.peakPnLCacheMutex.RUnlock() + + positionInfos = append(positionInfos, kernel.PositionInfo{ + Symbol: symbol, + Side: side, + EntryPrice: entryPrice, + MarkPrice: markPrice, + Quantity: quantity, + Leverage: leverage, + UnrealizedPnL: unrealizedPnl, + UnrealizedPnLPct: pnlPct, + PeakPnLPct: peakPnlPct, + LiquidationPrice: liquidationPrice, + MarginUsed: marginUsed, + UpdateTime: updateTime, + }) + } + + // Clean up closed position records + for key := range at.positionFirstSeenTime { + if !currentPositionKeys[key] { + delete(at.positionFirstSeenTime, key) + } + } + + // 3. Use strategy engine to get candidate coins (must have strategy engine) + var candidateCoins []kernel.CandidateCoin + if at.strategyEngine == nil { + logger.Infof("โš ๏ธ [%s] No strategy engine configured, skipping candidate coins", at.name) + } else { + coins, err := at.strategyEngine.GetCandidateCoins() + if err != nil { + // Log warning but don't fail - equity snapshot should still be saved + logger.Infof("โš ๏ธ [%s] Failed to get candidate coins: %v (will use empty list)", at.name, err) + } else { + candidateCoins = coins + logger.Infof("๐Ÿ“‹ [%s] Strategy engine fetched candidate coins: %d", at.name, len(candidateCoins)) + } + } + + // 4. Calculate total P&L + totalPnL := totalEquity - at.initialBalance + totalPnLPct := 0.0 + if at.initialBalance > 0 { + totalPnLPct = (totalPnL / at.initialBalance) * 100 + } + + marginUsedPct := 0.0 + if totalEquity > 0 { + marginUsedPct = (totalMarginUsed / totalEquity) * 100 + } + + // 5. Get leverage from strategy config + strategyConfig := at.strategyEngine.GetConfig() + btcEthLeverage := strategyConfig.RiskControl.BTCETHMaxLeverage + altcoinLeverage := strategyConfig.RiskControl.AltcoinMaxLeverage + logger.Infof("๐Ÿ“‹ [%s] Strategy leverage config: BTC/ETH=%dx, Altcoin=%dx", at.name, btcEthLeverage, altcoinLeverage) + + // 6. Build context + ctx := &kernel.Context{ + CurrentTime: time.Now().UTC().Format("2006-01-02 15:04:05 UTC"), + RuntimeMinutes: int(time.Since(at.startTime).Minutes()), + CallCount: at.callCount, + BTCETHLeverage: btcEthLeverage, + AltcoinLeverage: altcoinLeverage, + Account: kernel.AccountInfo{ + TotalEquity: totalEquity, + AvailableBalance: availableBalance, + UnrealizedPnL: totalUnrealizedProfit, + TotalPnL: totalPnL, + TotalPnLPct: totalPnLPct, + MarginUsed: totalMarginUsed, + MarginUsedPct: marginUsedPct, + PositionCount: len(positionInfos), + }, + Positions: positionInfos, + CandidateCoins: candidateCoins, + } + + // 7. Add recent closed trades (if store is available) + if at.store != nil { + // Get recent 10 closed trades for AI context + recentTrades, err := at.store.Position().GetRecentTrades(at.id, 10) + if err != nil { + logger.Infof("โš ๏ธ [%s] Failed to get recent trades: %v", at.name, err) + } else { + logger.Infof("๐Ÿ“Š [%s] Found %d recent closed trades for AI context", at.name, len(recentTrades)) + for _, trade := range recentTrades { + // Convert Unix timestamps to formatted strings for AI readability + entryTimeStr := "" + if trade.EntryTime > 0 { + entryTimeStr = time.Unix(trade.EntryTime, 0).UTC().Format("01-02 15:04 UTC") + } + exitTimeStr := "" + if trade.ExitTime > 0 { + exitTimeStr = time.Unix(trade.ExitTime, 0).UTC().Format("01-02 15:04 UTC") + } + + ctx.RecentOrders = append(ctx.RecentOrders, kernel.RecentOrder{ + Symbol: trade.Symbol, + Side: trade.Side, + EntryPrice: trade.EntryPrice, + ExitPrice: trade.ExitPrice, + RealizedPnL: trade.RealizedPnL, + PnLPct: trade.PnLPct, + EntryTime: entryTimeStr, + ExitTime: exitTimeStr, + HoldDuration: trade.HoldDuration, + }) + } + } + // Get trading statistics for AI context + stats, err := at.store.Position().GetFullStats(at.id) + if err != nil { + logger.Infof("โš ๏ธ [%s] Failed to get trading stats: %v", at.name, err) + } else if stats == nil { + logger.Infof("โš ๏ธ [%s] GetFullStats returned nil", at.name) + } else if stats.TotalTrades == 0 { + logger.Infof("โš ๏ธ [%s] GetFullStats returned 0 trades (traderID=%s)", at.name, at.id) + } else { + ctx.TradingStats = &kernel.TradingStats{ + TotalTrades: stats.TotalTrades, + WinRate: stats.WinRate, + ProfitFactor: stats.ProfitFactor, + SharpeRatio: stats.SharpeRatio, + TotalPnL: stats.TotalPnL, + AvgWin: stats.AvgWin, + AvgLoss: stats.AvgLoss, + MaxDrawdownPct: stats.MaxDrawdownPct, + } + logger.Infof("๐Ÿ“ˆ [%s] Trading stats: %d trades, %.1f%% win rate, PF=%.2f, Sharpe=%.2f, DD=%.1f%%", + at.name, stats.TotalTrades, stats.WinRate, stats.ProfitFactor, stats.SharpeRatio, stats.MaxDrawdownPct) + } + } else { + logger.Infof("โš ๏ธ [%s] Store is nil, cannot get recent trades", at.name) + } + + // 8. Get quantitative data (if enabled in strategy config) + if strategyConfig.Indicators.EnableQuantData { + // Collect symbols to query (candidate coins + position coins) + symbolsToQuery := make(map[string]bool) + for _, coin := range candidateCoins { + symbolsToQuery[coin.Symbol] = true + } + for _, pos := range positionInfos { + symbolsToQuery[pos.Symbol] = true + } + + symbols := make([]string, 0, len(symbolsToQuery)) + for sym := range symbolsToQuery { + symbols = append(symbols, sym) + } + + logger.Infof("๐Ÿ“Š [%s] Fetching quantitative data for %d symbols...", at.name, len(symbols)) + ctx.QuantDataMap = at.strategyEngine.FetchQuantDataBatch(symbols) + logger.Infof("๐Ÿ“Š [%s] Successfully fetched quantitative data for %d symbols", at.name, len(ctx.QuantDataMap)) + } + + // 9. Get OI ranking data (market-wide position changes) + if strategyConfig.Indicators.EnableOIRanking { + logger.Infof("๐Ÿ“Š [%s] Fetching OI ranking data...", at.name) + ctx.OIRankingData = at.strategyEngine.FetchOIRankingData() + if ctx.OIRankingData != nil { + logger.Infof("๐Ÿ“Š [%s] OI ranking data ready: %d top, %d low positions", + at.name, len(ctx.OIRankingData.TopPositions), len(ctx.OIRankingData.LowPositions)) + } + } + + // 10. Get NetFlow ranking data (market-wide fund flow) + if strategyConfig.Indicators.EnableNetFlowRanking { + logger.Infof("๐Ÿ’ฐ [%s] Fetching NetFlow ranking data...", at.name) + ctx.NetFlowRankingData = at.strategyEngine.FetchNetFlowRankingData() + if ctx.NetFlowRankingData != nil { + logger.Infof("๐Ÿ’ฐ [%s] NetFlow ranking data ready: inst_in=%d, inst_out=%d", + at.name, len(ctx.NetFlowRankingData.InstitutionFutureTop), len(ctx.NetFlowRankingData.InstitutionFutureLow)) + } + } + + // 11. Get Price ranking data (market-wide gainers/losers) + if strategyConfig.Indicators.EnablePriceRanking { + logger.Infof("๐Ÿ“ˆ [%s] Fetching Price ranking data...", at.name) + ctx.PriceRankingData = at.strategyEngine.FetchPriceRankingData() + if ctx.PriceRankingData != nil { + logger.Infof("๐Ÿ“ˆ [%s] Price ranking data ready for %d durations", + at.name, len(ctx.PriceRankingData.Durations)) + } + } + + return ctx, nil +} + +// sortDecisionsByPriority sorts decisions: close positions first, then open positions, finally hold/wait +// This avoids position stacking overflow when changing positions +func sortDecisionsByPriority(decisions []kernel.Decision) []kernel.Decision { + if len(decisions) <= 1 { + return decisions + } + + // Define priority + getActionPriority := func(action string) int { + switch action { + case "close_long", "close_short": + return 1 // Highest priority: close positions first + case "open_long", "open_short": + return 2 // Second priority: open positions later + case "hold", "wait": + return 3 // Lowest priority: wait + default: + return 999 // Unknown actions at the end + } + } + + // Copy decision list + sorted := make([]kernel.Decision, len(decisions)) + copy(sorted, decisions) + + // Sort by priority + for i := 0; i < len(sorted)-1; i++ { + for j := i + 1; j < len(sorted); j++ { + if getActionPriority(sorted[i].Action) > getActionPriority(sorted[j].Action) { + sorted[i], sorted[j] = sorted[j], sorted[i] + } + } + } + + return sorted +} diff --git a/trader/auto_trader_orders.go b/trader/auto_trader_orders.go new file mode 100644 index 00000000..edfc1995 --- /dev/null +++ b/trader/auto_trader_orders.go @@ -0,0 +1,391 @@ +package trader + +import ( + "fmt" + "nofx/kernel" + "nofx/logger" + "nofx/market" + "nofx/store" + "time" +) + +// executeDecisionWithRecord executes AI decision and records detailed information +func (at *AutoTrader) executeDecisionWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error { + switch decision.Action { + case "open_long": + return at.executeOpenLongWithRecord(decision, actionRecord) + case "open_short": + return at.executeOpenShortWithRecord(decision, actionRecord) + case "close_long": + return at.executeCloseLongWithRecord(decision, actionRecord) + case "close_short": + return at.executeCloseShortWithRecord(decision, actionRecord) + case "hold", "wait": + // No execution needed, just record + return nil + default: + return fmt.Errorf("unknown action: %s", decision.Action) + } +} + +// executeOpenLongWithRecord executes open long position and records detailed information +func (at *AutoTrader) executeOpenLongWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error { + logger.Infof(" ๐Ÿ“ˆ Open long: %s", decision.Symbol) + + // โš ๏ธ Get current positions for multiple checks + positions, err := at.trader.GetPositions() + if err != nil { + return fmt.Errorf("failed to get positions: %w", err) + } + + // [CODE ENFORCED] Check max positions limit + if err := at.enforceMaxPositions(len(positions)); err != nil { + return err + } + + // Check if there's already a position in the same symbol and direction + for _, pos := range positions { + if pos["symbol"] == decision.Symbol && pos["side"] == "long" { + return fmt.Errorf("โŒ %s already has long position, close it first", decision.Symbol) + } + } + + // Get current price + marketData, err := market.GetWithExchange(decision.Symbol, at.exchange) + if err != nil { + return err + } + + // Get balance (needed for multiple checks) + balance, err := at.trader.GetBalance() + if err != nil { + return fmt.Errorf("failed to get account balance: %w", err) + } + availableBalance := 0.0 + if avail, ok := balance["availableBalance"].(float64); ok { + availableBalance = avail + } + + // Get equity for position value ratio check + equity := 0.0 + if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 { + equity = eq + } else if eq, ok := balance["totalWalletBalance"].(float64); ok && eq > 0 { + equity = eq + } else { + equity = availableBalance // Fallback to available balance + } + + // [CODE ENFORCED] Position Value Ratio Check: position_value <= equity ร— ratio + adjustedPositionSize, wasCapped := at.enforcePositionValueRatio(decision.PositionSizeUSD, equity, decision.Symbol) + if wasCapped { + decision.PositionSizeUSD = adjustedPositionSize + } + + // โš ๏ธ Auto-adjust position size if insufficient margin + // Formula: totalRequired = positionSize/leverage + positionSize*0.001 + positionSize/leverage*0.01 + // = positionSize * (1.01/leverage + 0.001) + marginFactor := 1.01/float64(decision.Leverage) + 0.001 + maxAffordablePositionSize := availableBalance / marginFactor + + actualPositionSize := decision.PositionSizeUSD + if actualPositionSize > maxAffordablePositionSize { + // Use 98% of max to leave buffer for price fluctuation + adjustedSize := maxAffordablePositionSize * 0.98 + logger.Infof(" โš ๏ธ Position size %.2f exceeds max affordable %.2f, auto-reducing to %.2f", + actualPositionSize, maxAffordablePositionSize, adjustedSize) + actualPositionSize = adjustedSize + decision.PositionSizeUSD = actualPositionSize + } + + // [CODE ENFORCED] Minimum position size check + if err := at.enforceMinPositionSize(decision.PositionSizeUSD); err != nil { + return err + } + + // Calculate quantity with adjusted position size + quantity := actualPositionSize / marketData.CurrentPrice + actionRecord.Quantity = quantity + actionRecord.Price = marketData.CurrentPrice + + // Set margin mode + if err := at.trader.SetMarginMode(decision.Symbol, at.config.IsCrossMargin); err != nil { + logger.Infof(" โš ๏ธ Failed to set margin mode: %v", err) + // Continue execution, doesn't affect trading + } + + // Open position + order, err := at.trader.OpenLong(decision.Symbol, quantity, decision.Leverage) + if err != nil { + return err + } + + // Record order ID + if orderID, ok := order["orderId"].(int64); ok { + actionRecord.OrderID = orderID + } + + logger.Infof(" โœ“ Position opened successfully, order ID: %v, quantity: %.4f", order["orderId"], quantity) + + // Record order to database and poll for confirmation + at.recordAndConfirmOrder(order, decision.Symbol, "open_long", quantity, marketData.CurrentPrice, decision.Leverage, 0) + + // Record position opening time + posKey := decision.Symbol + "_long" + at.positionFirstSeenTime[posKey] = time.Now().UnixMilli() + + // Set stop loss and take profit + if err := at.trader.SetStopLoss(decision.Symbol, "LONG", quantity, decision.StopLoss); err != nil { + logger.Infof(" โš  Failed to set stop loss: %v", err) + } + if err := at.trader.SetTakeProfit(decision.Symbol, "LONG", quantity, decision.TakeProfit); err != nil { + logger.Infof(" โš  Failed to set take profit: %v", err) + } + + return nil +} + +// executeOpenShortWithRecord executes open short position and records detailed information +func (at *AutoTrader) executeOpenShortWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error { + logger.Infof(" ๐Ÿ“‰ Open short: %s", decision.Symbol) + + // โš ๏ธ Get current positions for multiple checks + positions, err := at.trader.GetPositions() + if err != nil { + return fmt.Errorf("failed to get positions: %w", err) + } + + // [CODE ENFORCED] Check max positions limit + if err := at.enforceMaxPositions(len(positions)); err != nil { + return err + } + + // Check if there's already a position in the same symbol and direction + for _, pos := range positions { + if pos["symbol"] == decision.Symbol && pos["side"] == "short" { + return fmt.Errorf("โŒ %s already has short position, close it first", decision.Symbol) + } + } + + // Get current price + marketData, err := market.GetWithExchange(decision.Symbol, at.exchange) + if err != nil { + return err + } + + // Get balance (needed for multiple checks) + balance, err := at.trader.GetBalance() + if err != nil { + return fmt.Errorf("failed to get account balance: %w", err) + } + availableBalance := 0.0 + if avail, ok := balance["availableBalance"].(float64); ok { + availableBalance = avail + } + + // Get equity for position value ratio check + equity := 0.0 + if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 { + equity = eq + } else if eq, ok := balance["totalWalletBalance"].(float64); ok && eq > 0 { + equity = eq + } else { + equity = availableBalance // Fallback to available balance + } + + // [CODE ENFORCED] Position Value Ratio Check: position_value <= equity ร— ratio + adjustedPositionSize, wasCapped := at.enforcePositionValueRatio(decision.PositionSizeUSD, equity, decision.Symbol) + if wasCapped { + decision.PositionSizeUSD = adjustedPositionSize + } + + // โš ๏ธ Auto-adjust position size if insufficient margin + // Formula: totalRequired = positionSize/leverage + positionSize*0.001 + positionSize/leverage*0.01 + // = positionSize * (1.01/leverage + 0.001) + marginFactor := 1.01/float64(decision.Leverage) + 0.001 + maxAffordablePositionSize := availableBalance / marginFactor + + actualPositionSize := decision.PositionSizeUSD + if actualPositionSize > maxAffordablePositionSize { + // Use 98% of max to leave buffer for price fluctuation + adjustedSize := maxAffordablePositionSize * 0.98 + logger.Infof(" โš ๏ธ Position size %.2f exceeds max affordable %.2f, auto-reducing to %.2f", + actualPositionSize, maxAffordablePositionSize, adjustedSize) + actualPositionSize = adjustedSize + decision.PositionSizeUSD = actualPositionSize + } + + // [CODE ENFORCED] Minimum position size check + if err := at.enforceMinPositionSize(decision.PositionSizeUSD); err != nil { + return err + } + + // Calculate quantity with adjusted position size + quantity := actualPositionSize / marketData.CurrentPrice + actionRecord.Quantity = quantity + actionRecord.Price = marketData.CurrentPrice + + // Set margin mode + if err := at.trader.SetMarginMode(decision.Symbol, at.config.IsCrossMargin); err != nil { + logger.Infof(" โš ๏ธ Failed to set margin mode: %v", err) + // Continue execution, doesn't affect trading + } + + // Open position + order, err := at.trader.OpenShort(decision.Symbol, quantity, decision.Leverage) + if err != nil { + return err + } + + // Record order ID + if orderID, ok := order["orderId"].(int64); ok { + actionRecord.OrderID = orderID + } + + logger.Infof(" โœ“ Position opened successfully, order ID: %v, quantity: %.4f", order["orderId"], quantity) + + // Record order to database and poll for confirmation + at.recordAndConfirmOrder(order, decision.Symbol, "open_short", quantity, marketData.CurrentPrice, decision.Leverage, 0) + + // Record position opening time + posKey := decision.Symbol + "_short" + at.positionFirstSeenTime[posKey] = time.Now().UnixMilli() + + // Set stop loss and take profit + if err := at.trader.SetStopLoss(decision.Symbol, "SHORT", quantity, decision.StopLoss); err != nil { + logger.Infof(" โš  Failed to set stop loss: %v", err) + } + if err := at.trader.SetTakeProfit(decision.Symbol, "SHORT", quantity, decision.TakeProfit); err != nil { + logger.Infof(" โš  Failed to set take profit: %v", err) + } + + return nil +} + +// executeCloseLongWithRecord executes close long position and records detailed information +func (at *AutoTrader) executeCloseLongWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error { + logger.Infof(" ๐Ÿ”„ Close long: %s", decision.Symbol) + + // Get current price + marketData, err := market.GetWithExchange(decision.Symbol, at.exchange) + if err != nil { + return err + } + actionRecord.Price = marketData.CurrentPrice + + // Normalize symbol for database lookup + normalizedSymbol := market.Normalize(decision.Symbol) + + // Get entry price and quantity - prioritize local database for accurate quantity + var entryPrice float64 + var quantity float64 + + // First try to get from local database (more accurate for quantity) + if at.store != nil { + if openPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, normalizedSymbol, "LONG"); err == nil && openPos != nil { + quantity = openPos.Quantity + entryPrice = openPos.EntryPrice + logger.Infof(" ๐Ÿ“Š Using local position data: qty=%.8f, entry=%.2f", quantity, entryPrice) + } + } + + // Fallback to exchange API if local data not found + if quantity == 0 { + positions, err := at.trader.GetPositions() + if err == nil { + for _, pos := range positions { + if pos["symbol"] == decision.Symbol && pos["side"] == "long" { + if ep, ok := pos["entryPrice"].(float64); ok { + entryPrice = ep + } + if amt, ok := pos["positionAmt"].(float64); ok && amt > 0 { + quantity = amt + } + break + } + } + } + logger.Infof(" ๐Ÿ“Š Using exchange position data: qty=%.8f, entry=%.2f", quantity, entryPrice) + } + + // Close position + order, err := at.trader.CloseLong(decision.Symbol, 0) // 0 = close all + if err != nil { + return err + } + + // Record order ID + if orderID, ok := order["orderId"].(int64); ok { + actionRecord.OrderID = orderID + } + + // Record order to database and poll for confirmation + at.recordAndConfirmOrder(order, decision.Symbol, "close_long", quantity, marketData.CurrentPrice, 0, entryPrice) + + logger.Infof(" โœ“ Position closed successfully") + return nil +} + +// executeCloseShortWithRecord executes close short position and records detailed information +func (at *AutoTrader) executeCloseShortWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error { + logger.Infof(" ๐Ÿ”„ Close short: %s", decision.Symbol) + + // Get current price + marketData, err := market.GetWithExchange(decision.Symbol, at.exchange) + if err != nil { + return err + } + actionRecord.Price = marketData.CurrentPrice + + // Normalize symbol for database lookup + normalizedSymbol := market.Normalize(decision.Symbol) + + // Get entry price and quantity - prioritize local database for accurate quantity + var entryPrice float64 + var quantity float64 + + // First try to get from local database (more accurate for quantity) + if at.store != nil { + if openPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, normalizedSymbol, "SHORT"); err == nil && openPos != nil { + quantity = openPos.Quantity + entryPrice = openPos.EntryPrice + logger.Infof(" ๐Ÿ“Š Using local position data: qty=%.8f, entry=%.2f", quantity, entryPrice) + } + } + + // Fallback to exchange API if local data not found + if quantity == 0 { + positions, err := at.trader.GetPositions() + if err == nil { + for _, pos := range positions { + if pos["symbol"] == decision.Symbol && pos["side"] == "short" { + if ep, ok := pos["entryPrice"].(float64); ok { + entryPrice = ep + } + if amt, ok := pos["positionAmt"].(float64); ok { + quantity = -amt // positionAmt is negative for short + } + break + } + } + } + logger.Infof(" ๐Ÿ“Š Using exchange position data: qty=%.8f, entry=%.2f", quantity, entryPrice) + } + + // Close position + order, err := at.trader.CloseShort(decision.Symbol, 0) // 0 = close all + if err != nil { + return err + } + + // Record order ID + if orderID, ok := order["orderId"].(int64); ok { + actionRecord.OrderID = orderID + } + + // Record order to database and poll for confirmation + at.recordAndConfirmOrder(order, decision.Symbol, "close_short", quantity, marketData.CurrentPrice, 0, entryPrice) + + logger.Infof(" โœ“ Position closed successfully") + return nil +} diff --git a/trader/auto_trader_risk.go b/trader/auto_trader_risk.go new file mode 100644 index 00000000..b8c94fb3 --- /dev/null +++ b/trader/auto_trader_risk.go @@ -0,0 +1,263 @@ +package trader + +import ( + "fmt" + "nofx/logger" + "strings" + "time" +) + +// startDrawdownMonitor starts drawdown monitoring +func (at *AutoTrader) startDrawdownMonitor() { + at.monitorWg.Add(1) + go func() { + defer at.monitorWg.Done() + + ticker := time.NewTicker(1 * time.Minute) // Check every minute + defer ticker.Stop() + + logger.Info("๐Ÿ“Š Started position drawdown monitoring (check every minute)") + + for { + select { + case <-ticker.C: + at.checkPositionDrawdown() + case <-at.stopMonitorCh: + logger.Info("โน Stopped position drawdown monitoring") + return + } + } + }() +} + +// checkPositionDrawdown checks position drawdown situation +func (at *AutoTrader) checkPositionDrawdown() { + // Get current positions + positions, err := at.trader.GetPositions() + if err != nil { + logger.Infof("โŒ Drawdown monitoring: failed to get positions: %v", err) + return + } + + for _, pos := range positions { + symbol := pos["symbol"].(string) + side := pos["side"].(string) + entryPrice := pos["entryPrice"].(float64) + markPrice := pos["markPrice"].(float64) + quantity := pos["positionAmt"].(float64) + if quantity < 0 { + quantity = -quantity // Short position quantity is negative, convert to positive + } + + // Calculate current P&L percentage + leverage := 10 // Default value + if lev, ok := pos["leverage"].(float64); ok { + leverage = int(lev) + } + + var currentPnLPct float64 + if side == "long" { + currentPnLPct = ((markPrice - entryPrice) / entryPrice) * float64(leverage) * 100 + } else { + currentPnLPct = ((entryPrice - markPrice) / entryPrice) * float64(leverage) * 100 + } + + // Construct unique position identifier (distinguish long/short) + posKey := symbol + "_" + side + + // Get historical peak profit for this position + at.peakPnLCacheMutex.RLock() + peakPnLPct, exists := at.peakPnLCache[posKey] + at.peakPnLCacheMutex.RUnlock() + + if !exists { + // If no historical peak record, use current P&L as initial value + peakPnLPct = currentPnLPct + at.UpdatePeakPnL(symbol, side, currentPnLPct) + } else { + // Update peak cache + at.UpdatePeakPnL(symbol, side, currentPnLPct) + } + + // Calculate drawdown (magnitude of decline from peak) + var drawdownPct float64 + if peakPnLPct > 0 && currentPnLPct < peakPnLPct { + drawdownPct = ((peakPnLPct - currentPnLPct) / peakPnLPct) * 100 + } + + // Check close position condition: profit > 5% and drawdown >= 40% + if currentPnLPct > 5.0 && drawdownPct >= 40.0 { + logger.Infof("๐Ÿšจ Drawdown close position condition triggered: %s %s | Current profit: %.2f%% | Peak profit: %.2f%% | Drawdown: %.2f%%", + symbol, side, currentPnLPct, peakPnLPct, drawdownPct) + + // Execute close position + if err := at.emergencyClosePosition(symbol, side); err != nil { + logger.Infof("โŒ Drawdown close position failed (%s %s): %v", symbol, side, err) + } else { + logger.Infof("โœ… Drawdown close position succeeded: %s %s", symbol, side) + // Clear cache for this position after closing + at.ClearPeakPnLCache(symbol, side) + } + } else if currentPnLPct > 5.0 { + // Record situations close to close position condition (for debugging) + logger.Infof("๐Ÿ“Š Drawdown monitoring: %s %s | Profit: %.2f%% | Peak: %.2f%% | Drawdown: %.2f%%", + symbol, side, currentPnLPct, peakPnLPct, drawdownPct) + } + } +} + +// emergencyClosePosition emergency close position function +func (at *AutoTrader) emergencyClosePosition(symbol, side string) error { + switch side { + case "long": + order, err := at.trader.CloseLong(symbol, 0) // 0 = close all + if err != nil { + return err + } + logger.Infof("โœ… Emergency close long position succeeded, order ID: %v", order["orderId"]) + case "short": + order, err := at.trader.CloseShort(symbol, 0) // 0 = close all + if err != nil { + return err + } + logger.Infof("โœ… Emergency close short position succeeded, order ID: %v", order["orderId"]) + default: + return fmt.Errorf("unknown position direction: %s", side) + } + + return nil +} + +// GetPeakPnLCache gets peak profit cache +func (at *AutoTrader) GetPeakPnLCache() map[string]float64 { + at.peakPnLCacheMutex.RLock() + defer at.peakPnLCacheMutex.RUnlock() + + // Return a copy of the cache + cache := make(map[string]float64) + for k, v := range at.peakPnLCache { + cache[k] = v + } + return cache +} + +// UpdatePeakPnL updates peak profit cache +func (at *AutoTrader) UpdatePeakPnL(symbol, side string, currentPnLPct float64) { + at.peakPnLCacheMutex.Lock() + defer at.peakPnLCacheMutex.Unlock() + + posKey := symbol + "_" + side + if peak, exists := at.peakPnLCache[posKey]; exists { + // Update peak (if long, take larger value; if short, currentPnLPct is negative, also compare) + if currentPnLPct > peak { + at.peakPnLCache[posKey] = currentPnLPct + } + } else { + // First time recording + at.peakPnLCache[posKey] = currentPnLPct + } +} + +// ClearPeakPnLCache clears peak cache for specified position +func (at *AutoTrader) ClearPeakPnLCache(symbol, side string) { + at.peakPnLCacheMutex.Lock() + defer at.peakPnLCacheMutex.Unlock() + + posKey := symbol + "_" + side + delete(at.peakPnLCache, posKey) +} + +// ============================================================================ +// Risk Control Helpers +// ============================================================================ + +// isBTCETH checks if a symbol is BTC or ETH +func isBTCETH(symbol string) bool { + symbol = strings.ToUpper(symbol) + return strings.HasPrefix(symbol, "BTC") || strings.HasPrefix(symbol, "ETH") +} + +// enforcePositionValueRatio checks and enforces position value ratio limits (CODE ENFORCED) +// Returns the adjusted position size (capped if necessary) and whether the position was capped +// positionSizeUSD: the original position size in USD +// equity: the account equity +// symbol: the trading symbol +func (at *AutoTrader) enforcePositionValueRatio(positionSizeUSD float64, equity float64, symbol string) (float64, bool) { + if at.config.StrategyConfig == nil { + return positionSizeUSD, false + } + + riskControl := at.config.StrategyConfig.RiskControl + + // Get the appropriate position value ratio limit + var maxPositionValueRatio float64 + if isBTCETH(symbol) { + maxPositionValueRatio = riskControl.BTCETHMaxPositionValueRatio + if maxPositionValueRatio <= 0 { + maxPositionValueRatio = 5.0 // Default: 5x for BTC/ETH + } + } else { + maxPositionValueRatio = riskControl.AltcoinMaxPositionValueRatio + if maxPositionValueRatio <= 0 { + maxPositionValueRatio = 1.0 // Default: 1x for altcoins + } + } + + // Calculate max allowed position value = equity ร— ratio + maxPositionValue := equity * maxPositionValueRatio + + // Check if position size exceeds limit + if positionSizeUSD > maxPositionValue { + logger.Infof(" โš ๏ธ [RISK CONTROL] Position %.2f USDT exceeds limit (equity %.2f ร— %.1fx = %.2f USDT max for %s), capping", + positionSizeUSD, equity, maxPositionValueRatio, maxPositionValue, symbol) + return maxPositionValue, true + } + + return positionSizeUSD, false +} + +// enforceMinPositionSize checks minimum position size (CODE ENFORCED) +func (at *AutoTrader) enforceMinPositionSize(positionSizeUSD float64) error { + if at.config.StrategyConfig == nil { + return nil + } + + minSize := at.config.StrategyConfig.RiskControl.MinPositionSize + if minSize <= 0 { + minSize = 12 // Default: 12 USDT + } + + if positionSizeUSD < minSize { + return fmt.Errorf("โŒ [RISK CONTROL] Position %.2f USDT below minimum (%.2f USDT)", positionSizeUSD, minSize) + } + return nil +} + +// enforceMaxPositions checks maximum positions count (CODE ENFORCED) +func (at *AutoTrader) enforceMaxPositions(currentPositionCount int) error { + if at.config.StrategyConfig == nil { + return nil + } + + maxPositions := at.config.StrategyConfig.RiskControl.MaxPositions + if maxPositions <= 0 { + maxPositions = 3 // Default: 3 positions + } + + if currentPositionCount >= maxPositions { + return fmt.Errorf("โŒ [RISK CONTROL] Already at max positions (%d/%d)", currentPositionCount, maxPositions) + } + return nil +} + +// getSideFromAction converts order action to side (BUY/SELL) +func getSideFromAction(action string) string { + switch action { + case "open_long", "close_short": + return "BUY" + case "open_short", "close_long": + return "SELL" + default: + return "BUY" + } +} diff --git a/web/src/App.tsx b/web/src/App.tsx index e270264d..01d497db 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -4,27 +4,27 @@ import useSWR from 'swr' import { api } from './lib/api' import { TraderDashboardPage } from './pages/TraderDashboardPage' -import { AITradersPage } from './components/AITradersPage' -import { LoginPage } from './components/LoginPage' -import { SetupPage } from './components/SetupPage' +import { AITradersPage } from './components/trader/AITradersPage' +import { LoginPage } from './components/auth/LoginPage' +import { SetupPage } from './components/modals/SetupPage' import { SettingsPage } from './pages/SettingsPage' -import { ResetPasswordPage } from './components/ResetPasswordPage' -import { CompetitionPage } from './components/CompetitionPage' +import { ResetPasswordPage } from './components/auth/ResetPasswordPage' +import { CompetitionPage } from './components/trader/CompetitionPage' import { LandingPage } from './pages/LandingPage' import { FAQPage } from './pages/FAQPage' import { StrategyStudioPage } from './pages/StrategyStudioPage' import { StrategyMarketPage } from './pages/StrategyMarketPage' import { DataPage } from './pages/DataPage' -import { LoginRequiredOverlay } from './components/LoginRequiredOverlay' -import HeaderBar from './components/HeaderBar' +import { LoginRequiredOverlay } from './components/auth/LoginRequiredOverlay' +import HeaderBar from './components/common/HeaderBar' import { LanguageProvider, useLanguage } from './contexts/LanguageContext' import { AuthProvider, useAuth } from './contexts/AuthContext' -import { ConfirmDialogProvider } from './components/ConfirmDialog' +import { ConfirmDialogProvider } from './components/common/ConfirmDialog' import { t } from './i18n/translations' import { useSystemConfig } from './hooks/useSystemConfig' import { OFFICIAL_LINKS } from './constants/branding' -import { BacktestPage } from './components/BacktestPage' +import { BacktestPage } from './components/backtest/BacktestPage' import type { SystemStatus, AccountInfo, diff --git a/web/src/components/LoginPage.tsx b/web/src/components/auth/LoginPage.tsx similarity index 95% rename from web/src/components/LoginPage.tsx rename to web/src/components/auth/LoginPage.tsx index 07b4725f..d7fe3845 100644 --- a/web/src/components/LoginPage.tsx +++ b/web/src/components/auth/LoginPage.tsx @@ -1,10 +1,10 @@ import React, { useState, useEffect } from 'react' import { Eye, EyeOff } from 'lucide-react' import { toast } from 'sonner' -import { useAuth } from '../contexts/AuthContext' -import { useLanguage } from '../contexts/LanguageContext' -import { t } from '../i18n/translations' -import { DeepVoidBackground } from './DeepVoidBackground' +import { useAuth } from '../../contexts/AuthContext' +import { useLanguage } from '../../contexts/LanguageContext' +import { t } from '../../i18n/translations' +import { DeepVoidBackground } from '../common/DeepVoidBackground' export function LoginPage() { const { language } = useLanguage() diff --git a/web/src/components/LoginRequiredOverlay.tsx b/web/src/components/auth/LoginRequiredOverlay.tsx similarity index 98% rename from web/src/components/LoginRequiredOverlay.tsx rename to web/src/components/auth/LoginRequiredOverlay.tsx index 4b2cfcaa..b362cbf9 100644 --- a/web/src/components/LoginRequiredOverlay.tsx +++ b/web/src/components/auth/LoginRequiredOverlay.tsx @@ -1,7 +1,7 @@ import { motion, AnimatePresence } from 'framer-motion' import { LogIn, UserPlus, X, AlertTriangle, Terminal } from 'lucide-react' -import { DeepVoidBackground } from './DeepVoidBackground' -import { useLanguage } from '../contexts/LanguageContext' +import { DeepVoidBackground } from '../common/DeepVoidBackground' +import { useLanguage } from '../../contexts/LanguageContext' interface LoginRequiredOverlayProps { isOpen: boolean diff --git a/web/src/components/RegisterPage.test.tsx b/web/src/components/auth/RegisterPage.test.tsx similarity index 100% rename from web/src/components/RegisterPage.test.tsx rename to web/src/components/auth/RegisterPage.test.tsx diff --git a/web/src/components/RegisterPage.tsx b/web/src/components/auth/RegisterPage.tsx similarity index 97% rename from web/src/components/RegisterPage.tsx rename to web/src/components/auth/RegisterPage.tsx index 7485501b..e03af180 100644 --- a/web/src/components/RegisterPage.tsx +++ b/web/src/components/auth/RegisterPage.tsx @@ -2,13 +2,13 @@ import React, { useEffect, useState } from 'react' import { Eye, EyeOff } from 'lucide-react' import PasswordChecklist from 'react-password-checklist' import { toast } from 'sonner' -import { useAuth } from '../contexts/AuthContext' -import { useLanguage } from '../contexts/LanguageContext' -import { t } from '../i18n/translations' -import { getSystemConfig } from '../lib/config' -import { DeepVoidBackground } from './DeepVoidBackground' +import { useAuth } from '../../contexts/AuthContext' +import { useLanguage } from '../../contexts/LanguageContext' +import { t } from '../../i18n/translations' +import { getSystemConfig } from '../../lib/config' +import { DeepVoidBackground } from '../common/DeepVoidBackground' import { RegistrationDisabled } from './RegistrationDisabled' -import { WhitelistFullPage } from './WhitelistFullPage' +import { WhitelistFullPage } from '../common/WhitelistFullPage' export function RegisterPage() { const { language } = useLanguage() diff --git a/web/src/components/RegistrationDisabled.test.tsx b/web/src/components/auth/RegistrationDisabled.test.tsx similarity index 94% rename from web/src/components/RegistrationDisabled.test.tsx rename to web/src/components/auth/RegistrationDisabled.test.tsx index beec1b83..c7747f22 100644 --- a/web/src/components/RegistrationDisabled.test.tsx +++ b/web/src/components/auth/RegistrationDisabled.test.tsx @@ -1,11 +1,11 @@ import { describe, it, expect, vi } from 'vitest' import { render, screen, fireEvent } from '@testing-library/react' import { RegistrationDisabled } from './RegistrationDisabled' -import { LanguageProvider } from '../contexts/LanguageContext' +import { LanguageProvider } from '../../contexts/LanguageContext' // Mock useLanguage hook -vi.mock('../contexts/LanguageContext', async () => { - const actual = await vi.importActual('../contexts/LanguageContext') +vi.mock('../../contexts/LanguageContext', async () => { + const actual = await vi.importActual('../../contexts/LanguageContext') return { ...actual, useLanguage: () => ({ language: 'en' }), diff --git a/web/src/components/RegistrationDisabled.tsx b/web/src/components/auth/RegistrationDisabled.tsx similarity index 91% rename from web/src/components/RegistrationDisabled.tsx rename to web/src/components/auth/RegistrationDisabled.tsx index 048a9d62..3d2cc631 100644 --- a/web/src/components/RegistrationDisabled.tsx +++ b/web/src/components/auth/RegistrationDisabled.tsx @@ -1,5 +1,5 @@ -import { useLanguage } from '../contexts/LanguageContext' -import { t } from '../i18n/translations' +import { useLanguage } from '../../contexts/LanguageContext' +import { t } from '../../i18n/translations' export function RegistrationDisabled() { const { language } = useLanguage() diff --git a/web/src/components/ResetPasswordPage.tsx b/web/src/components/auth/ResetPasswordPage.tsx similarity index 97% rename from web/src/components/ResetPasswordPage.tsx rename to web/src/components/auth/ResetPasswordPage.tsx index 1eb569fe..5818a8d5 100644 --- a/web/src/components/ResetPasswordPage.tsx +++ b/web/src/components/auth/ResetPasswordPage.tsx @@ -1,11 +1,11 @@ import React, { useState } from 'react' -import { useAuth } from '../contexts/AuthContext' -import { useLanguage } from '../contexts/LanguageContext' -import { t } from '../i18n/translations' -import { Header } from './Header' +import { useAuth } from '../../contexts/AuthContext' +import { useLanguage } from '../../contexts/LanguageContext' +import { t } from '../../i18n/translations' +import { Header } from '../common/Header' import { ArrowLeft, KeyRound, Eye, EyeOff } from 'lucide-react' import PasswordChecklist from 'react-password-checklist' -import { Input } from './ui/input' +import { Input } from '../ui/input' import { toast } from 'sonner' export function ResetPasswordPage() { diff --git a/web/src/components/BacktestPage.tsx b/web/src/components/backtest/BacktestPage.tsx similarity index 99% rename from web/src/components/BacktestPage.tsx rename to web/src/components/backtest/BacktestPage.tsx index 70ffa6f3..4611832d 100644 --- a/web/src/components/BacktestPage.tsx +++ b/web/src/components/backtest/BacktestPage.tsx @@ -28,7 +28,7 @@ import { ArrowDownRight, CandlestickChart as CandlestickIcon, } from 'lucide-react' -import { DeepVoidBackground } from './DeepVoidBackground' +import { DeepVoidBackground } from '../common/DeepVoidBackground' import { ResponsiveContainer, AreaChart, @@ -39,12 +39,12 @@ import { Tooltip, ReferenceDot, } from 'recharts' -import { api } from '../lib/api' -import { useLanguage } from '../contexts/LanguageContext' -import { t } from '../i18n/translations' -import { confirmToast } from '../lib/notify' -import { DecisionCard } from './DecisionCard' -import { MetricTooltip } from './MetricTooltip' +import { api } from '../../lib/api' +import { useLanguage } from '../../contexts/LanguageContext' +import { t } from '../../i18n/translations' +import { confirmToast } from '../../lib/notify' +import { DecisionCard } from '../trader/DecisionCard' +import { MetricTooltip } from '../common/MetricTooltip' import type { BacktestStatusPayload, BacktestPositionStatus, @@ -55,7 +55,7 @@ import type { DecisionRecord, AIModel, Strategy, -} from '../types' +} from '../../types' // ============ Types ============ type WizardStep = 1 | 2 | 3 diff --git a/web/src/components/AdvancedChart.tsx b/web/src/components/charts/AdvancedChart.tsx similarity index 99% rename from web/src/components/AdvancedChart.tsx rename to web/src/components/charts/AdvancedChart.tsx index 749f42d6..9d3d1c59 100644 --- a/web/src/components/AdvancedChart.tsx +++ b/web/src/components/charts/AdvancedChart.tsx @@ -10,14 +10,14 @@ import { HistogramSeries, createSeriesMarkers, } from 'lightweight-charts' -import { useLanguage } from '../contexts/LanguageContext' -import { httpClient } from '../lib/httpClient' +import { useLanguage } from '../../contexts/LanguageContext' +import { httpClient } from '../../lib/httpClient' import { calculateSMA, calculateEMA, calculateBollingerBands, type Kline, -} from '../utils/indicators' +} from '../../utils/indicators' import { Settings, BarChart2 } from 'lucide-react' // ่ฎขๅ•ๆŽฅๅฃๅฎšไน‰ diff --git a/web/src/components/ChartTabs.tsx b/web/src/components/charts/ChartTabs.tsx similarity index 99% rename from web/src/components/ChartTabs.tsx rename to web/src/components/charts/ChartTabs.tsx index bdefb697..65db8a6c 100644 --- a/web/src/components/ChartTabs.tsx +++ b/web/src/components/charts/ChartTabs.tsx @@ -1,8 +1,8 @@ import { useState, useEffect, useRef } from 'react' import { EquityChart } from './EquityChart' import { AdvancedChart } from './AdvancedChart' -import { useLanguage } from '../contexts/LanguageContext' -import { t } from '../i18n/translations' +import { useLanguage } from '../../contexts/LanguageContext' +import { t } from '../../i18n/translations' import { BarChart3, CandlestickChart, ChevronDown, Search } from 'lucide-react' import { motion, AnimatePresence } from 'framer-motion' diff --git a/web/src/components/ChartWithOrders.tsx b/web/src/components/charts/ChartWithOrders.tsx similarity index 99% rename from web/src/components/ChartWithOrders.tsx rename to web/src/components/charts/ChartWithOrders.tsx index 5d25ae25..23495bf2 100644 --- a/web/src/components/ChartWithOrders.tsx +++ b/web/src/components/charts/ChartWithOrders.tsx @@ -8,8 +8,8 @@ import { CandlestickSeries, createSeriesMarkers, } from 'lightweight-charts' -import { useLanguage } from '../contexts/LanguageContext' -import { httpClient } from '../lib/httpClient' +import { useLanguage } from '../../contexts/LanguageContext' +import { httpClient } from '../../lib/httpClient' // ่ฎขๅ•ๆŽฅๅฃๅฎšไน‰ interface OrderMarker { diff --git a/web/src/components/ChartWithOrdersSimple.tsx b/web/src/components/charts/ChartWithOrdersSimple.tsx similarity index 98% rename from web/src/components/ChartWithOrdersSimple.tsx rename to web/src/components/charts/ChartWithOrdersSimple.tsx index 00576fd1..529c95b8 100644 --- a/web/src/components/ChartWithOrdersSimple.tsx +++ b/web/src/components/charts/ChartWithOrdersSimple.tsx @@ -1,5 +1,5 @@ import { useEffect, useState } from 'react' -import { httpClient } from '../lib/httpClient' +import { httpClient } from '../../lib/httpClient' interface ChartWithOrdersSimpleProps { symbol: string diff --git a/web/src/components/ComparisonChart.tsx b/web/src/components/charts/ComparisonChart.tsx similarity index 98% rename from web/src/components/ComparisonChart.tsx rename to web/src/components/charts/ComparisonChart.tsx index 4c7a08e4..4c91ba29 100644 --- a/web/src/components/ComparisonChart.tsx +++ b/web/src/components/charts/ComparisonChart.tsx @@ -12,11 +12,11 @@ import { ComposedChart, } from 'recharts' import useSWR from 'swr' -import { api } from '../lib/api' -import type { CompetitionTraderData } from '../types' -import { getTraderColor } from '../utils/traderColors' -import { useLanguage } from '../contexts/LanguageContext' -import { t } from '../i18n/translations' +import { api } from '../../lib/api' +import type { CompetitionTraderData } from '../../types' +import { getTraderColor } from '../../utils/traderColors' +import { useLanguage } from '../../contexts/LanguageContext' +import { t } from '../../i18n/translations' import { BarChart3, TrendingUp, TrendingDown, Zap } from 'lucide-react' // Time period options: 1D, 3D, 7D, 30D, All diff --git a/web/src/components/EquityChart.tsx b/web/src/components/charts/EquityChart.tsx similarity index 98% rename from web/src/components/EquityChart.tsx rename to web/src/components/charts/EquityChart.tsx index b0b0d203..407baeda 100644 --- a/web/src/components/EquityChart.tsx +++ b/web/src/components/charts/EquityChart.tsx @@ -10,10 +10,10 @@ import { ReferenceLine, } from 'recharts' import useSWR from 'swr' -import { api } from '../lib/api' -import { useLanguage } from '../contexts/LanguageContext' -import { useAuth } from '../contexts/AuthContext' -import { t } from '../i18n/translations' +import { api } from '../../lib/api' +import { useLanguage } from '../../contexts/LanguageContext' +import { useAuth } from '../../contexts/AuthContext' +import { t } from '../../i18n/translations' import { AlertTriangle, BarChart3, diff --git a/web/src/components/TradingViewChart.tsx b/web/src/components/charts/TradingViewChart.tsx similarity index 99% rename from web/src/components/TradingViewChart.tsx rename to web/src/components/charts/TradingViewChart.tsx index 7b52a287..a3850343 100644 --- a/web/src/components/TradingViewChart.tsx +++ b/web/src/components/charts/TradingViewChart.tsx @@ -1,6 +1,6 @@ import { useEffect, useRef, useState, memo } from 'react' -import { useLanguage } from '../contexts/LanguageContext' -import { t } from '../i18n/translations' +import { useLanguage } from '../../contexts/LanguageContext' +import { t } from '../../i18n/translations' import { ChevronDown, TrendingUp, X } from 'lucide-react' // ๆ”ฏๆŒ็š„ไบคๆ˜“ๆ‰€ๅˆ—่กจ (ๅˆ็บฆๆ ผๅผ) diff --git a/web/src/components/ConfirmDialog.tsx b/web/src/components/common/ConfirmDialog.tsx similarity index 97% rename from web/src/components/ConfirmDialog.tsx rename to web/src/components/common/ConfirmDialog.tsx index c2d9b711..2769d63d 100644 --- a/web/src/components/ConfirmDialog.tsx +++ b/web/src/components/common/ConfirmDialog.tsx @@ -13,8 +13,8 @@ import { AlertDialogDescription, AlertDialogFooter, AlertDialogTitle, -} from './ui/alert-dialog' -import { setGlobalConfirm } from '../lib/notify' +} from '../ui/alert-dialog' +import { setGlobalConfirm } from '../../lib/notify' interface ConfirmOptions { title?: string diff --git a/web/src/components/Container.tsx b/web/src/components/common/Container.tsx similarity index 100% rename from web/src/components/Container.tsx rename to web/src/components/common/Container.tsx diff --git a/web/src/components/DeepVoidBackground.tsx b/web/src/components/common/DeepVoidBackground.tsx similarity index 100% rename from web/src/components/DeepVoidBackground.tsx rename to web/src/components/common/DeepVoidBackground.tsx diff --git a/web/src/components/ExchangeIcons.tsx b/web/src/components/common/ExchangeIcons.tsx similarity index 100% rename from web/src/components/ExchangeIcons.tsx rename to web/src/components/common/ExchangeIcons.tsx diff --git a/web/src/components/Header.tsx b/web/src/components/common/Header.tsx similarity index 95% rename from web/src/components/Header.tsx rename to web/src/components/common/Header.tsx index c5ae5e1e..cf95ab55 100644 --- a/web/src/components/Header.tsx +++ b/web/src/components/common/Header.tsx @@ -1,5 +1,5 @@ -import { useLanguage } from '../contexts/LanguageContext' -import { t } from '../i18n/translations' +import { useLanguage } from '../../contexts/LanguageContext' +import { t } from '../../i18n/translations' import { Container } from './Container' interface HeaderProps { diff --git a/web/src/components/HeaderBar.tsx b/web/src/components/common/HeaderBar.tsx similarity index 99% rename from web/src/components/HeaderBar.tsx rename to web/src/components/common/HeaderBar.tsx index cfe46903..4427fec9 100644 --- a/web/src/components/HeaderBar.tsx +++ b/web/src/components/common/HeaderBar.tsx @@ -2,8 +2,8 @@ import { useState, useEffect, useRef } from 'react' import { useNavigate } from 'react-router-dom' import { motion, AnimatePresence } from 'framer-motion' import { Menu, X, ChevronDown, Settings } from 'lucide-react' -import { t, type Language } from '../i18n/translations' -import { OFFICIAL_LINKS } from '../constants/branding' +import { t, type Language } from '../../i18n/translations' +import { OFFICIAL_LINKS } from '../../constants/branding' type Page = | 'competition' diff --git a/web/src/components/MetricTooltip.tsx b/web/src/components/common/MetricTooltip.tsx similarity index 100% rename from web/src/components/MetricTooltip.tsx rename to web/src/components/common/MetricTooltip.tsx diff --git a/web/src/components/ModelIcons.tsx b/web/src/components/common/ModelIcons.tsx similarity index 100% rename from web/src/components/ModelIcons.tsx rename to web/src/components/common/ModelIcons.tsx diff --git a/web/src/components/PunkAvatar.tsx b/web/src/components/common/PunkAvatar.tsx similarity index 100% rename from web/src/components/PunkAvatar.tsx rename to web/src/components/common/PunkAvatar.tsx diff --git a/web/src/components/WebCryptoEnvironmentCheck.tsx b/web/src/components/common/WebCryptoEnvironmentCheck.tsx similarity index 98% rename from web/src/components/WebCryptoEnvironmentCheck.tsx rename to web/src/components/common/WebCryptoEnvironmentCheck.tsx index aaa9deb6..7532d067 100644 --- a/web/src/components/WebCryptoEnvironmentCheck.tsx +++ b/web/src/components/common/WebCryptoEnvironmentCheck.tsx @@ -1,7 +1,7 @@ import { useCallback, useEffect, useState, type ReactNode } from 'react' import { Loader2, ShieldAlert, ShieldCheck, ShieldMinus } from 'lucide-react' -import { CryptoService, diagnoseWebCryptoEnvironment } from '../lib/crypto' -import { t, type Language } from '../i18n/translations' +import { CryptoService, diagnoseWebCryptoEnvironment } from '../../lib/crypto' +import { t, type Language } from '../../i18n/translations' export type WebCryptoCheckStatus = | 'idle' diff --git a/web/src/components/WhitelistFullPage.tsx b/web/src/components/common/WhitelistFullPage.tsx similarity index 99% rename from web/src/components/WhitelistFullPage.tsx rename to web/src/components/common/WhitelistFullPage.tsx index 506b4afd..d18593fa 100644 --- a/web/src/components/WhitelistFullPage.tsx +++ b/web/src/components/common/WhitelistFullPage.tsx @@ -1,6 +1,6 @@ import { motion } from 'framer-motion' import { ShieldAlert, ArrowLeft, Twitter, Send, Lock } from 'lucide-react' -import { OFFICIAL_LINKS } from '../constants/branding' +import { OFFICIAL_LINKS } from '../../constants/branding' interface WhitelistFullPageProps { onBack?: () => void diff --git a/web/src/components/faq/FAQLayout.tsx b/web/src/components/faq/FAQLayout.tsx index ecff7659..cd70e31a 100644 --- a/web/src/components/faq/FAQLayout.tsx +++ b/web/src/components/faq/FAQLayout.tsx @@ -1,6 +1,6 @@ import { useState, useMemo } from 'react' import { HelpCircle } from 'lucide-react' -import { DeepVoidBackground } from '../DeepVoidBackground' +import { DeepVoidBackground } from '../common/DeepVoidBackground' import { t, type Language } from '../../i18n/translations' import { FAQSearchBar } from './FAQSearchBar' import { FAQSidebar } from './FAQSidebar' diff --git a/web/src/components/SetupPage.tsx b/web/src/components/modals/SetupPage.tsx similarity index 96% rename from web/src/components/SetupPage.tsx rename to web/src/components/modals/SetupPage.tsx index 5ad87df7..6243e4a4 100644 --- a/web/src/components/SetupPage.tsx +++ b/web/src/components/modals/SetupPage.tsx @@ -1,8 +1,8 @@ import React, { useState } from 'react' import { Eye, EyeOff } from 'lucide-react' -import { useAuth } from '../contexts/AuthContext' -import { DeepVoidBackground } from './DeepVoidBackground' -import { invalidateSystemConfig } from '../lib/config' +import { useAuth } from '../../contexts/AuthContext' +import { DeepVoidBackground } from '../common/DeepVoidBackground' +import { invalidateSystemConfig } from '../../lib/config' export function SetupPage() { const { register } = useAuth() diff --git a/web/src/components/TwoStageKeyModal.tsx b/web/src/components/modals/TwoStageKeyModal.tsx similarity index 98% rename from web/src/components/TwoStageKeyModal.tsx rename to web/src/components/modals/TwoStageKeyModal.tsx index d960b28a..2f2f692d 100644 --- a/web/src/components/TwoStageKeyModal.tsx +++ b/web/src/components/modals/TwoStageKeyModal.tsx @@ -1,8 +1,8 @@ import { useEffect, useMemo, useRef, useState } from 'react' import { createPortal } from 'react-dom' -import { t, type Language } from '../i18n/translations' +import { t, type Language } from '../../i18n/translations' import { toast } from 'sonner' -import { WebCryptoEnvironmentCheck } from './WebCryptoEnvironmentCheck' +import { WebCryptoEnvironmentCheck } from '../common/WebCryptoEnvironmentCheck' const DEFAULT_LENGTH = 64 diff --git a/web/src/components/AITradersPage.tsx b/web/src/components/trader/AITradersPage.tsx similarity index 99% rename from web/src/components/AITradersPage.tsx rename to web/src/components/trader/AITradersPage.tsx index 43e6fed4..41674f30 100644 --- a/web/src/components/AITradersPage.tsx +++ b/web/src/components/trader/AITradersPage.tsx @@ -1,23 +1,23 @@ import React, { useState, useEffect } from 'react' import { useNavigate } from 'react-router-dom' import useSWR from 'swr' -import { api } from '../lib/api' +import { api } from '../../lib/api' import type { TraderInfo, CreateTraderRequest, AIModel, Exchange, -} from '../types' -import { useLanguage } from '../contexts/LanguageContext' -import { t, type Language } from '../i18n/translations' -import { useAuth } from '../contexts/AuthContext' -import { getExchangeIcon } from './ExchangeIcons' -import { getModelIcon } from './ModelIcons' +} from '../../types' +import { useLanguage } from '../../contexts/LanguageContext' +import { t, type Language } from '../../i18n/translations' +import { useAuth } from '../../contexts/AuthContext' +import { getExchangeIcon } from '../common/ExchangeIcons' +import { getModelIcon } from '../common/ModelIcons' import { TraderConfigModal } from './TraderConfigModal' -import { DeepVoidBackground } from './DeepVoidBackground' -import { ExchangeConfigModal } from './traders/ExchangeConfigModal' -import { TelegramConfigModal } from './traders/TelegramConfigModal' -import { PunkAvatar, getTraderAvatar } from './PunkAvatar' +import { DeepVoidBackground } from '../common/DeepVoidBackground' +import { ExchangeConfigModal } from './ExchangeConfigModal' +import { TelegramConfigModal } from './TelegramConfigModal' +import { PunkAvatar, getTraderAvatar } from '../common/PunkAvatar' import { Bot, Brain, @@ -34,7 +34,7 @@ import { Check, MessageCircle, } from 'lucide-react' -import { confirmToast } from '../lib/notify' +import { confirmToast } from '../../lib/notify' import { toast } from 'sonner' // ่Žทๅ–ๅ‹ๅฅฝ็š„AIๆจกๅž‹ๅ็งฐ diff --git a/web/src/components/CompetitionPage.test.tsx b/web/src/components/trader/CompetitionPage.test.tsx similarity index 100% rename from web/src/components/CompetitionPage.test.tsx rename to web/src/components/trader/CompetitionPage.test.tsx diff --git a/web/src/components/CompetitionPage.tsx b/web/src/components/trader/CompetitionPage.tsx similarity index 97% rename from web/src/components/CompetitionPage.tsx rename to web/src/components/trader/CompetitionPage.tsx index afb88c05..71395eb0 100644 --- a/web/src/components/CompetitionPage.tsx +++ b/web/src/components/trader/CompetitionPage.tsx @@ -1,15 +1,15 @@ import { useState } from 'react' import { Trophy } from 'lucide-react' import useSWR from 'swr' -import { api } from '../lib/api' -import type { CompetitionData } from '../types' -import { ComparisonChart } from './ComparisonChart' +import { api } from '../../lib/api' +import type { CompetitionData } from '../../types' +import { ComparisonChart } from '../charts/ComparisonChart' import { TraderConfigViewModal } from './TraderConfigViewModal' -import { getTraderColor } from '../utils/traderColors' -import { useLanguage } from '../contexts/LanguageContext' -import { t } from '../i18n/translations' -import { PunkAvatar, getTraderAvatar } from './PunkAvatar' -import { DeepVoidBackground } from './DeepVoidBackground' +import { getTraderColor } from '../../utils/traderColors' +import { useLanguage } from '../../contexts/LanguageContext' +import { t } from '../../i18n/translations' +import { PunkAvatar, getTraderAvatar } from '../common/PunkAvatar' +import { DeepVoidBackground } from '../common/DeepVoidBackground' export function CompetitionPage() { const { language } = useLanguage() diff --git a/web/src/components/DecisionCard.tsx b/web/src/components/trader/DecisionCard.tsx similarity index 99% rename from web/src/components/DecisionCard.tsx rename to web/src/components/trader/DecisionCard.tsx index 8446c0a6..75b18c67 100644 --- a/web/src/components/DecisionCard.tsx +++ b/web/src/components/trader/DecisionCard.tsx @@ -1,6 +1,6 @@ import { useState } from 'react' -import type { DecisionRecord, DecisionAction } from '../types' -import { t, type Language } from '../i18n/translations' +import type { DecisionRecord, DecisionAction } from '../../types' +import { t, type Language } from '../../i18n/translations' interface DecisionCardProps { decision: DecisionRecord diff --git a/web/src/components/traders/ExchangeConfigModal.tsx b/web/src/components/trader/ExchangeConfigModal.tsx similarity index 99% rename from web/src/components/traders/ExchangeConfigModal.tsx rename to web/src/components/trader/ExchangeConfigModal.tsx index 769ba666..ecf01b1a 100644 --- a/web/src/components/traders/ExchangeConfigModal.tsx +++ b/web/src/components/trader/ExchangeConfigModal.tsx @@ -2,15 +2,15 @@ import React, { useState, useEffect } from 'react' import type { Exchange } from '../../types' import { t, type Language } from '../../i18n/translations' import { api } from '../../lib/api' -import { getExchangeIcon } from '../ExchangeIcons' +import { getExchangeIcon } from '../common/ExchangeIcons' import { TwoStageKeyModal, type TwoStageKeyModalResult, -} from '../TwoStageKeyModal' +} from '../modals/TwoStageKeyModal' import { WebCryptoEnvironmentCheck, type WebCryptoCheckStatus, -} from '../WebCryptoEnvironmentCheck' +} from '../common/WebCryptoEnvironmentCheck' import { BookOpen, Trash2, HelpCircle, ExternalLink, UserPlus, Key, Shield, ChevronLeft, Check, Copy, ArrowRight diff --git a/web/src/components/PositionHistory.tsx b/web/src/components/trader/PositionHistory.tsx similarity index 99% rename from web/src/components/PositionHistory.tsx rename to web/src/components/trader/PositionHistory.tsx index 4d669f64..03b17596 100644 --- a/web/src/components/PositionHistory.tsx +++ b/web/src/components/trader/PositionHistory.tsx @@ -1,15 +1,15 @@ import { useState, useEffect, useMemo } from 'react' -import { api } from '../lib/api' -import { useLanguage } from '../contexts/LanguageContext' -import { t, type Language } from '../i18n/translations' -import { MetricTooltip } from './MetricTooltip' -import { formatPrice, formatQuantity } from '../utils/format' +import { api } from '../../lib/api' +import { useLanguage } from '../../contexts/LanguageContext' +import { t, type Language } from '../../i18n/translations' +import { MetricTooltip } from '../common/MetricTooltip' +import { formatPrice, formatQuantity } from '../../utils/format' import type { HistoricalPosition, TraderStats, SymbolStats, DirectionStats, -} from '../types' +} from '../../types' interface PositionHistoryProps { traderId: string diff --git a/web/src/components/traders/TelegramConfigModal.tsx b/web/src/components/trader/TelegramConfigModal.tsx similarity index 100% rename from web/src/components/traders/TelegramConfigModal.tsx rename to web/src/components/trader/TelegramConfigModal.tsx diff --git a/web/src/components/traders/Tooltip.tsx b/web/src/components/trader/Tooltip.tsx similarity index 100% rename from web/src/components/traders/Tooltip.tsx rename to web/src/components/trader/Tooltip.tsx diff --git a/web/src/components/TraderConfigModal.tsx b/web/src/components/trader/TraderConfigModal.tsx similarity index 99% rename from web/src/components/TraderConfigModal.tsx rename to web/src/components/trader/TraderConfigModal.tsx index c1c5a27a..b0889a1b 100644 --- a/web/src/components/TraderConfigModal.tsx +++ b/web/src/components/trader/TraderConfigModal.tsx @@ -1,10 +1,10 @@ import { useState, useEffect } from 'react' -import type { AIModel, Exchange, CreateTraderRequest, Strategy } from '../types' -import { useLanguage } from '../contexts/LanguageContext' -import { t } from '../i18n/translations' +import type { AIModel, Exchange, CreateTraderRequest, Strategy } from '../../types' +import { useLanguage } from '../../contexts/LanguageContext' +import { t } from '../../i18n/translations' import { toast } from 'sonner' import { Pencil, Plus, X as IconX, Sparkles, ExternalLink, UserPlus } from 'lucide-react' -import { httpClient } from '../lib/httpClient' +import { httpClient } from '../../lib/httpClient' // ๆๅ–ไธ‹ๅˆ’็บฟๅŽ้ข็š„ๅ็งฐ้ƒจๅˆ† function getShortName(fullName: string): string { @@ -22,7 +22,7 @@ const EXCHANGE_REGISTRATION_LINKS: Record