refactor: restructure project directories for better modularity

- Delete llm/ dead code (3 files, zero references)
- Split mcp/ into sub-packages: mcp/provider/ (8 providers) and
  mcp/payment/ (4 payment clients) with registry pattern
- Export Client internal fields and ClientHooks interface for
  sub-package access
- Split api/server.go (3892 lines) into 8 domain-specific handler files
- Split trader/auto_trader.go (2296 lines) into 5 focused files
- Reorganize web/src/components/ flat files into auth/, charts/,
  trader/, common/, modals/, backtest/ subdirectories
- Update all consumer imports to use registry-based provider creation
This commit is contained in:
tinkle-community
2026-03-11 23:58:13 +08:00
parent 6a30e11ee5
commit 8e294a5eed
103 changed files with 6391 additions and 8984 deletions
+211
View File
@@ -0,0 +1,211 @@
package api
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"nofx/config"
"nofx/crypto"
"nofx/logger"
"nofx/security"
"github.com/gin-gonic/gin"
)
type ModelConfig struct {
ID string `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
Enabled bool `json:"enabled"`
APIKey string `json:"apiKey,omitempty"`
CustomAPIURL string `json:"customApiUrl,omitempty"`
}
// SafeModelConfig Safe model configuration structure (does not contain sensitive information)
type SafeModelConfig struct {
ID string `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
Enabled bool `json:"enabled"`
CustomAPIURL string `json:"customApiUrl"` // Custom API URL (usually not sensitive)
CustomModelName string `json:"customModelName"` // Custom model name (not sensitive)
}
type UpdateModelConfigRequest struct {
Models map[string]struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
CustomAPIURL string `json:"custom_api_url"`
CustomModelName string `json:"custom_model_name"`
} `json:"models"`
}
// handleGetModelConfigs Get AI model configurations
func (s *Server) handleGetModelConfigs(c *gin.Context) {
userID := c.GetString("user_id")
logger.Infof("🔍 Querying AI model configs for user %s", userID)
models, err := s.store.AIModel().List(userID)
if err != nil {
logger.Infof("❌ Failed to get AI model configs: %v", err)
SafeInternalError(c, "Failed to get AI model configs", err)
return
}
// If no models in database, return default models
if len(models) == 0 {
logger.Infof("⚠️ No AI models in database, returning defaults")
defaultModels := []SafeModelConfig{
{ID: "deepseek", Name: "DeepSeek AI", Provider: "deepseek", Enabled: false},
{ID: "qwen", Name: "Qwen AI", Provider: "qwen", Enabled: false},
{ID: "openai", Name: "OpenAI", Provider: "openai", Enabled: false},
{ID: "claude", Name: "Claude AI", Provider: "claude", Enabled: false},
{ID: "gemini", Name: "Gemini AI", Provider: "gemini", Enabled: false},
{ID: "grok", Name: "Grok AI", Provider: "grok", Enabled: false},
{ID: "kimi", Name: "Kimi AI", Provider: "kimi", Enabled: false},
{ID: "minimax", Name: "MiniMax AI", Provider: "minimax", Enabled: false},
}
c.JSON(http.StatusOK, defaultModels)
return
}
logger.Infof("✅ Found %d AI model configs", len(models))
// Convert to safe response structure, remove sensitive information
safeModels := make([]SafeModelConfig, len(models))
for i, model := range models {
safeModels[i] = SafeModelConfig{
ID: model.ID,
Name: model.Name,
Provider: model.Provider,
Enabled: model.Enabled,
CustomAPIURL: model.CustomAPIURL,
CustomModelName: model.CustomModelName,
}
}
c.JSON(http.StatusOK, safeModels)
}
// handleUpdateModelConfigs Update AI model configurations (supports both encrypted and plain text based on config)
func (s *Server) handleUpdateModelConfigs(c *gin.Context) {
userID := c.GetString("user_id")
cfg := config.Get()
// Read raw request body
bodyBytes, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"})
return
}
var req UpdateModelConfigRequest
// Check if transport encryption is enabled
if !cfg.TransportEncryption {
// Transport encryption disabled, accept plain JSON
if err := json.Unmarshal(bodyBytes, &req); err != nil {
logger.Infof("❌ Failed to parse plain JSON request: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
return
}
logger.Infof("📝 Received plain text model config (UserID: %s)", userID)
} else {
// Transport encryption enabled, require encrypted payload
var encryptedPayload crypto.EncryptedPayload
if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil {
logger.Infof("❌ Failed to parse encrypted payload: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format, encrypted transmission required"})
return
}
// Verify encrypted data
if encryptedPayload.WrappedKey == "" {
logger.Infof("❌ Detected unencrypted request (UserID: %s)", userID)
c.JSON(http.StatusBadRequest, gin.H{
"error": "This endpoint only supports encrypted transmission, please use encrypted client",
"code": "ENCRYPTION_REQUIRED",
"message": "Encrypted transmission is required for security reasons",
})
return
}
// Decrypt data
decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload)
if err != nil {
logger.Infof("❌ Failed to decrypt model config (UserID: %s): %v", userID, err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to decrypt data"})
return
}
// Parse decrypted data
if err := json.Unmarshal([]byte(decrypted), &req); err != nil {
logger.Infof("❌ Failed to parse decrypted data: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse decrypted data"})
return
}
logger.Infof("🔓 Decrypted model config data (UserID: %s)", userID)
}
// Update each model's configuration and track traders that need reload
tradersToReload := make(map[string]bool)
for modelID, modelData := range req.Models {
// SSRF protection: validate custom_api_url before storing
if modelData.CustomAPIURL != "" {
cleanURL := strings.TrimSuffix(modelData.CustomAPIURL, "#")
if err := security.ValidateURL(cleanURL); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid custom_api_url for model %s: %s", modelID, err.Error())})
return
}
}
// Find traders using this AI model BEFORE updating
traders, _ := s.store.Trader().ListByAIModelID(userID, modelID)
for _, t := range traders {
tradersToReload[t.ID] = true
}
err := s.store.AIModel().Update(userID, modelID, modelData.Enabled, modelData.APIKey, modelData.CustomAPIURL, modelData.CustomModelName)
if err != nil {
SafeInternalError(c, fmt.Sprintf("Update model %s", modelID), err)
return
}
}
// Remove affected traders from memory BEFORE reloading to pick up new config
for traderID := range tradersToReload {
logger.Infof("🔄 Removing trader %s from memory to reload with new AI model config", traderID)
s.traderManager.RemoveTrader(traderID)
}
// Reload all traders for this user to make new config take effect immediately
err = s.traderManager.LoadUserTradersFromStore(s.store, userID)
if err != nil {
logger.Infof("⚠️ Failed to reload user traders into memory: %v", err)
// Don't return error here since model config was successfully updated to database
}
logger.Infof("✓ AI model config updated: %+v", req.Models)
c.JSON(http.StatusOK, gin.H{"message": "Model configuration updated"})
}
// handleGetSupportedModels Get list of AI models supported by the system
func (s *Server) handleGetSupportedModels(c *gin.Context) {
// Return static list of supported AI models with default versions
supportedModels := []map[string]interface{}{
{"id": "deepseek", "name": "DeepSeek", "provider": "deepseek", "defaultModel": "deepseek-chat"},
{"id": "qwen", "name": "Qwen", "provider": "qwen", "defaultModel": "qwen3-max"},
{"id": "openai", "name": "OpenAI", "provider": "openai", "defaultModel": "gpt-5.1"},
{"id": "claude", "name": "Claude", "provider": "claude", "defaultModel": "claude-opus-4-6"},
{"id": "gemini", "name": "Google Gemini", "provider": "gemini", "defaultModel": "gemini-3-pro-preview"},
{"id": "grok", "name": "Grok (xAI)", "provider": "grok", "defaultModel": "grok-3-latest"},
{"id": "kimi", "name": "Kimi (Moonshot)", "provider": "kimi", "defaultModel": "moonshot-v1-auto"},
{"id": "minimax", "name": "MiniMax", "provider": "minimax", "defaultModel": "MiniMax-M2.5"},
{"id": "blockrun-base", "name": "BlockRun (Base Wallet)", "provider": "blockrun-base", "defaultModel": "auto"},
{"id": "blockrun-sol", "name": "BlockRun (Solana Wallet)", "provider": "blockrun-sol", "defaultModel": "auto"},
{"id": "claw402", "name": "Claw402 (Base USDC)", "provider": "claw402", "defaultModel": "deepseek"},
}
c.JSON(http.StatusOK, supportedModels)
}
+469
View File
@@ -0,0 +1,469 @@
package api
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
"nofx/logger"
"nofx/store"
"github.com/gin-gonic/gin"
)
// handleDecisions Decision log list
func (s *Server) handleDecisions(c *gin.Context) {
_, traderID, err := s.getTraderFromQuery(c)
if err != nil {
SafeBadRequest(c, "Invalid trader ID")
return
}
trader, err := s.traderManager.GetTrader(traderID)
if err != nil {
SafeNotFound(c, "Trader")
return
}
// Get all historical decision records (unlimited)
records, err := trader.GetStore().Decision().GetLatestRecords(trader.GetID(), 10000)
if err != nil {
SafeInternalError(c, "Get decision log", err)
return
}
c.JSON(http.StatusOK, records)
}
// handleLatestDecisions Latest decision logs (newest first, supports limit parameter)
func (s *Server) handleLatestDecisions(c *gin.Context) {
_, traderID, err := s.getTraderFromQuery(c)
if err != nil {
SafeBadRequest(c, "Invalid trader ID")
return
}
trader, err := s.traderManager.GetTrader(traderID)
if err != nil {
SafeNotFound(c, "Trader")
return
}
// Get limit from query parameter, default to 5
limit := 5
if limitStr := c.Query("limit"); limitStr != "" {
if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 {
limit = parsedLimit
if limit > 100 {
limit = 100 // Max 100 to prevent abuse
}
}
}
records, err := trader.GetStore().Decision().GetLatestRecords(trader.GetID(), limit)
if err != nil {
SafeInternalError(c, "Get decision log", err)
return
}
// Reverse array to put newest first (for list display)
// GetLatestRecords returns oldest to newest (for charts), here we need newest to oldest
for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 {
records[i], records[j] = records[j], records[i]
}
c.JSON(http.StatusOK, records)
}
// handleStatistics Statistics information
func (s *Server) handleStatistics(c *gin.Context) {
_, traderID, err := s.getTraderFromQuery(c)
if err != nil {
SafeBadRequest(c, "Invalid trader ID")
return
}
trader, err := s.traderManager.GetTrader(traderID)
if err != nil {
SafeNotFound(c, "Trader")
return
}
stats, err := trader.GetStore().Decision().GetStatistics(trader.GetID())
if err != nil {
SafeInternalError(c, "Get statistics", err)
return
}
c.JSON(http.StatusOK, stats)
}
// handleCompetition Competition overview (compare all traders)
func (s *Server) handleCompetition(c *gin.Context) {
userID := c.GetString("user_id")
// Ensure user's traders are loaded into memory
err := s.traderManager.LoadUserTradersFromStore(s.store, userID)
if err != nil {
logger.Infof("⚠️ Failed to load traders for user %s: %v", userID, err)
}
competition, err := s.traderManager.GetCompetitionData()
if err != nil {
SafeInternalError(c, "Get competition data", err)
return
}
c.JSON(http.StatusOK, competition)
}
// handleEquityHistory Return rate historical data
// Query directly from database, not dependent on trader in memory (so historical data can be retrieved after restart)
func (s *Server) handleEquityHistory(c *gin.Context) {
_, traderID, err := s.getTraderFromQuery(c)
if err != nil {
SafeBadRequest(c, "Invalid trader ID")
return
}
// Get equity historical data from new equity table
// Every 3 minutes per cycle: 10000 records = about 20 days of data
snapshots, err := s.store.Equity().GetLatest(traderID, 10000)
if err != nil {
SafeInternalError(c, "Get historical data", err)
return
}
if len(snapshots) == 0 {
c.JSON(http.StatusOK, []interface{}{})
return
}
// Build return rate historical data points
type EquityPoint struct {
Timestamp string `json:"timestamp"`
TotalEquity float64 `json:"total_equity"` // Account equity (wallet + unrealized)
AvailableBalance float64 `json:"available_balance"` // Available balance
TotalPnL float64 `json:"total_pnl"` // Total PnL (unrealized PnL)
TotalPnLPct float64 `json:"total_pnl_pct"` // Total PnL percentage
PositionCount int `json:"position_count"` // Position count
MarginUsedPct float64 `json:"margin_used_pct"` // Margin used percentage
}
// Use the balance of the first record as initial balance to calculate return rate
initialBalance := snapshots[0].Balance
if initialBalance == 0 {
initialBalance = 1 // Avoid division by zero
}
var history []EquityPoint
for _, snap := range snapshots {
// Calculate PnL percentage
totalPnLPct := 0.0
if initialBalance > 0 {
totalPnLPct = (snap.UnrealizedPnL / initialBalance) * 100
}
history = append(history, EquityPoint{
Timestamp: snap.Timestamp.Format("2006-01-02 15:04:05"),
TotalEquity: snap.TotalEquity,
AvailableBalance: snap.Balance,
TotalPnL: snap.UnrealizedPnL,
TotalPnLPct: totalPnLPct,
PositionCount: snap.PositionCount,
MarginUsedPct: snap.MarginUsedPct,
})
}
c.JSON(http.StatusOK, history)
}
// handlePublicTraderList Get public trader list (no authentication required)
func (s *Server) handlePublicTraderList(c *gin.Context) {
// Get trader information from all users
competition, err := s.traderManager.GetCompetitionData()
if err != nil {
SafeInternalError(c, "Get trader list", err)
return
}
// Get traders array
tradersData, exists := competition["traders"]
if !exists {
c.JSON(http.StatusOK, []map[string]interface{}{})
return
}
traders, ok := tradersData.([]map[string]interface{})
if !ok {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Trader data format error",
})
return
}
// Return trader basic information, filter sensitive information
result := make([]map[string]interface{}, 0, len(traders))
for _, trader := range traders {
result = append(result, map[string]interface{}{
"trader_id": trader["trader_id"],
"trader_name": trader["trader_name"],
"ai_model": trader["ai_model"],
"exchange": trader["exchange"],
"is_running": trader["is_running"],
"total_equity": trader["total_equity"],
"total_pnl": trader["total_pnl"],
"total_pnl_pct": trader["total_pnl_pct"],
"position_count": trader["position_count"],
"margin_used_pct": trader["margin_used_pct"],
})
}
c.JSON(http.StatusOK, result)
}
// handlePublicCompetition Get public competition data (no authentication required)
func (s *Server) handlePublicCompetition(c *gin.Context) {
competition, err := s.traderManager.GetCompetitionData()
if err != nil {
SafeInternalError(c, "Get competition data", err)
return
}
c.JSON(http.StatusOK, competition)
}
// handleTopTraders Get top 5 trader data (no authentication required, for performance comparison)
func (s *Server) handleTopTraders(c *gin.Context) {
topTraders, err := s.traderManager.GetTopTradersData()
if err != nil {
SafeInternalError(c, "Get top traders data", err)
return
}
c.JSON(http.StatusOK, topTraders)
}
// handleEquityHistoryBatch Batch get return rate historical data for multiple traders (no authentication required, for performance comparison)
// Supports optional 'hours' parameter to filter data by time range (e.g., hours=24 for last 24 hours)
func (s *Server) handleEquityHistoryBatch(c *gin.Context) {
var requestBody struct {
TraderIDs []string `json:"trader_ids"`
Hours int `json:"hours"` // Optional: filter by last N hours (0 = all data)
}
// Try to parse POST request JSON body
if err := c.ShouldBindJSON(&requestBody); err != nil {
// If JSON parse fails, try to get from query parameters (compatible with GET request)
traderIDsParam := c.Query("trader_ids")
if traderIDsParam == "" {
// If no trader_ids specified, return historical data for top 5
topTraders, err := s.traderManager.GetTopTradersData()
if err != nil {
SafeInternalError(c, "Get top traders", err)
return
}
traders, ok := topTraders["traders"].([]map[string]interface{})
if !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Trader data format error"})
return
}
// Extract trader IDs
traderIDs := make([]string, 0, len(traders))
for _, trader := range traders {
if traderID, ok := trader["trader_id"].(string); ok {
traderIDs = append(traderIDs, traderID)
}
}
// Parse hours parameter from query
hoursParam := c.Query("hours")
hours := 0
if hoursParam != "" {
fmt.Sscanf(hoursParam, "%d", &hours)
}
result := s.getEquityHistoryForTraders(traderIDs, hours)
c.JSON(http.StatusOK, result)
return
}
// Parse comma-separated trader IDs
requestBody.TraderIDs = strings.Split(traderIDsParam, ",")
for i := range requestBody.TraderIDs {
requestBody.TraderIDs[i] = strings.TrimSpace(requestBody.TraderIDs[i])
}
// Parse hours parameter from query
hoursParam := c.Query("hours")
if hoursParam != "" {
fmt.Sscanf(hoursParam, "%d", &requestBody.Hours)
}
}
// Limit to maximum 20 traders to prevent oversized requests
if len(requestBody.TraderIDs) > 20 {
requestBody.TraderIDs = requestBody.TraderIDs[:20]
}
result := s.getEquityHistoryForTraders(requestBody.TraderIDs, requestBody.Hours)
c.JSON(http.StatusOK, result)
}
// getEquityHistoryForTraders Get historical data for multiple traders
// Query directly from database, not dependent on trader in memory (so historical data can be retrieved after restart)
// Also appends current real-time data point to ensure chart matches leaderboard
// hours: filter by last N hours (0 = use default limit of 500 records)
func (s *Server) getEquityHistoryForTraders(traderIDs []string, hours int) map[string]interface{} {
result := make(map[string]interface{})
histories := make(map[string]interface{})
errors := make(map[string]string)
// Use a single consistent timestamp for all real-time data points
now := time.Now()
// Pre-fetch initial balances for all traders
initialBalances := make(map[string]float64)
for _, traderID := range traderIDs {
if traderID == "" {
continue
}
// Get trader's initial balance from database (use GetByID which doesn't require userID)
trader, err := s.store.Trader().GetByID(traderID)
if err == nil && trader != nil && trader.InitialBalance > 0 {
initialBalances[traderID] = trader.InitialBalance
}
}
for _, traderID := range traderIDs {
if traderID == "" {
continue
}
// Get equity historical data from new equity table
var snapshots []*store.EquitySnapshot
var err error
if hours > 0 {
// Filter by time range
startTime := now.Add(-time.Duration(hours) * time.Hour)
snapshots, err = s.store.Equity().GetByTimeRange(traderID, startTime, now)
} else {
// Default: get latest 500 records
snapshots, err = s.store.Equity().GetLatest(traderID, 500)
}
if err != nil {
logger.Errorf("[API] Failed to get equity history for %s: %v", traderID, err)
errors[traderID] = "Failed to get historical data"
continue
}
// Get initial balance for calculating PnL percentage
initialBalance := initialBalances[traderID]
if initialBalance <= 0 && len(snapshots) > 0 {
// If no initial balance configured, use the first snapshot's equity as baseline
initialBalance = snapshots[0].TotalEquity
}
// Build return rate historical data with PnL percentage
history := make([]map[string]interface{}, 0, len(snapshots)+1)
var lastSnapshotTime time.Time
for _, snap := range snapshots {
// Calculate PnL percentage: (current_equity - initial_balance) / initial_balance * 100
pnlPct := 0.0
if initialBalance > 0 {
pnlPct = (snap.TotalEquity - initialBalance) / initialBalance * 100
}
history = append(history, map[string]interface{}{
"timestamp": snap.Timestamp,
"total_equity": snap.TotalEquity,
"total_pnl": snap.UnrealizedPnL,
"total_pnl_pct": pnlPct,
"balance": snap.Balance,
})
if snap.Timestamp.After(lastSnapshotTime) {
lastSnapshotTime = snap.Timestamp
}
}
// Append current real-time data point to ensure chart matches leaderboard
// This ensures the latest point is always current, not from a potentially stale snapshot
if trader, err := s.traderManager.GetTrader(traderID); err == nil {
if accountInfo, err := trader.GetAccountInfo(); err == nil {
// Only append if it's been more than 30 seconds since last snapshot
if now.Sub(lastSnapshotTime) > 30*time.Second {
totalEquity := 0.0
if v, ok := accountInfo["total_equity"].(float64); ok {
totalEquity = v
}
totalPnL := 0.0
if v, ok := accountInfo["total_pnl"].(float64); ok {
totalPnL = v
}
walletBalance := 0.0
if v, ok := accountInfo["wallet_balance"].(float64); ok {
walletBalance = v
}
pnlPct := 0.0
if initialBalance > 0 {
pnlPct = (totalEquity - initialBalance) / initialBalance * 100
}
history = append(history, map[string]interface{}{
"timestamp": now,
"total_equity": totalEquity,
"total_pnl": totalPnL,
"total_pnl_pct": pnlPct,
"balance": walletBalance,
})
}
}
}
histories[traderID] = history
}
result["histories"] = histories
result["count"] = len(histories)
if len(errors) > 0 {
result["errors"] = errors
}
return result
}
// handleGetPublicTraderConfig Get public trader configuration information (no authentication required, does not include sensitive information)
func (s *Server) handleGetPublicTraderConfig(c *gin.Context) {
traderID := c.Param("id")
if traderID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Trader ID cannot be empty"})
return
}
trader, err := s.traderManager.GetTrader(traderID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"})
return
}
// Get trader status information
status := trader.GetStatus()
// Only return public configuration information, not including sensitive data like API keys
result := map[string]interface{}{
"trader_id": trader.GetID(),
"trader_name": trader.GetName(),
"ai_model": trader.GetAIModel(),
"exchange": trader.GetExchange(),
"is_running": status["is_running"],
"ai_provider": status["ai_provider"],
"start_time": status["start_time"],
}
c.JSON(http.StatusOK, result)
}
+353
View File
@@ -0,0 +1,353 @@
package api
import (
"encoding/json"
"fmt"
"net/http"
"nofx/config"
"nofx/crypto"
"nofx/logger"
"github.com/gin-gonic/gin"
)
type ExchangeConfig struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"` // "cex" or "dex"
Enabled bool `json:"enabled"`
APIKey string `json:"apiKey,omitempty"`
SecretKey string `json:"secretKey,omitempty"`
Testnet bool `json:"testnet,omitempty"`
}
// SafeExchangeConfig Safe exchange configuration structure (does not contain sensitive information)
type SafeExchangeConfig struct {
ID string `json:"id"` // UUID
ExchangeType string `json:"exchange_type"` // "binance", "bybit", "okx", "hyperliquid", "aster", "lighter"
AccountName string `json:"account_name"` // User-defined account name
Name string `json:"name"` // Display name
Type string `json:"type"` // "cex" or "dex"
Enabled bool `json:"enabled"`
Testnet bool `json:"testnet,omitempty"`
HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` // Hyperliquid wallet address (not sensitive)
AsterUser string `json:"asterUser"` // Aster username (not sensitive)
AsterSigner string `json:"asterSigner"` // Aster signer (not sensitive)
LighterWalletAddr string `json:"lighterWalletAddr"` // LIGHTER wallet address (not sensitive)
}
type UpdateExchangeConfigRequest struct {
Exchanges map[string]struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
SecretKey string `json:"secret_key"`
Passphrase string `json:"passphrase"` // OKX specific
Testnet bool `json:"testnet"`
HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"`
HyperliquidUnifiedAcct bool `json:"hyperliquid_unified_account"` // Unified Account mode
AsterUser string `json:"aster_user"`
AsterSigner string `json:"aster_signer"`
AsterPrivateKey string `json:"aster_private_key"`
LighterWalletAddr string `json:"lighter_wallet_addr"`
LighterPrivateKey string `json:"lighter_private_key"`
LighterAPIKeyPrivateKey string `json:"lighter_api_key_private_key"`
LighterAPIKeyIndex int `json:"lighter_api_key_index"`
} `json:"exchanges"`
}
// CreateExchangeRequest request structure for creating a new exchange account
type CreateExchangeRequest struct {
ExchangeType string `json:"exchange_type" binding:"required"` // "binance", "bybit", "okx", "hyperliquid", "aster", "lighter"
AccountName string `json:"account_name"` // User-defined account name
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
SecretKey string `json:"secret_key"`
Passphrase string `json:"passphrase"`
Testnet bool `json:"testnet"`
HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"`
HyperliquidUnifiedAcct bool `json:"hyperliquid_unified_account"` // Unified Account mode: Spot as Perp collateral
AsterUser string `json:"aster_user"`
AsterSigner string `json:"aster_signer"`
AsterPrivateKey string `json:"aster_private_key"`
LighterWalletAddr string `json:"lighter_wallet_addr"`
LighterPrivateKey string `json:"lighter_private_key"`
LighterAPIKeyPrivateKey string `json:"lighter_api_key_private_key"`
LighterAPIKeyIndex int `json:"lighter_api_key_index"`
}
// handleGetExchangeConfigs Get exchange configurations
func (s *Server) handleGetExchangeConfigs(c *gin.Context) {
userID := c.GetString("user_id")
logger.Infof("🔍 Querying exchange configs for user %s", userID)
exchanges, err := s.store.Exchange().List(userID)
if err != nil {
SafeInternalError(c, "Failed to get exchange configs", err)
return
}
// If no exchanges in database, return empty array (user needs to create accounts)
if len(exchanges) == 0 {
logger.Infof("⚠️ No exchanges in database for user %s", userID)
c.JSON(http.StatusOK, []SafeExchangeConfig{})
return
}
logger.Infof("✅ Found %d exchange configs", len(exchanges))
// Convert to safe response structure, remove sensitive information
safeExchanges := make([]SafeExchangeConfig, len(exchanges))
for i, exchange := range exchanges {
safeExchanges[i] = SafeExchangeConfig{
ID: exchange.ID,
ExchangeType: exchange.ExchangeType,
AccountName: exchange.AccountName,
Name: exchange.Name,
Type: exchange.Type,
Enabled: exchange.Enabled,
Testnet: exchange.Testnet,
HyperliquidWalletAddr: exchange.HyperliquidWalletAddr,
AsterUser: exchange.AsterUser,
AsterSigner: exchange.AsterSigner,
LighterWalletAddr: exchange.LighterWalletAddr,
}
}
c.JSON(http.StatusOK, safeExchanges)
}
// handleUpdateExchangeConfigs Update exchange configurations (supports both encrypted and plain text based on config)
func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) {
userID := c.GetString("user_id")
cfg := config.Get()
// Read raw request body
bodyBytes, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"})
return
}
var req UpdateExchangeConfigRequest
// Check if transport encryption is enabled
if !cfg.TransportEncryption {
// Transport encryption disabled, accept plain JSON
if err := json.Unmarshal(bodyBytes, &req); err != nil {
logger.Infof("❌ Failed to parse plain JSON request: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
return
}
logger.Infof("📝 Received plain text exchange config (UserID: %s)", userID)
} else {
// Transport encryption enabled, require encrypted payload
var encryptedPayload crypto.EncryptedPayload
if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil {
logger.Infof("❌ Failed to parse encrypted payload: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format, encrypted transmission required"})
return
}
// Verify encrypted data
if encryptedPayload.WrappedKey == "" {
logger.Infof("❌ Detected unencrypted request (UserID: %s)", userID)
c.JSON(http.StatusBadRequest, gin.H{
"error": "This endpoint only supports encrypted transmission, please use encrypted client",
"code": "ENCRYPTION_REQUIRED",
"message": "Encrypted transmission is required for security reasons",
})
return
}
// Decrypt data
decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload)
if err != nil {
logger.Infof("❌ Failed to decrypt exchange config (UserID: %s): %v", userID, err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to decrypt data"})
return
}
// Parse decrypted data
if err := json.Unmarshal([]byte(decrypted), &req); err != nil {
logger.Infof("❌ Failed to parse decrypted data: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse decrypted data"})
return
}
logger.Infof("🔓 Decrypted exchange config data (UserID: %s)", userID)
}
// Update each exchange's configuration and track traders that need reload
tradersToReload := make(map[string]bool)
for exchangeID, exchangeData := range req.Exchanges {
// Find traders using this exchange BEFORE updating
traders, _ := s.store.Trader().ListByExchangeID(userID, exchangeID)
for _, t := range traders {
tradersToReload[t.ID] = true
}
err := s.store.Exchange().Update(userID, exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Passphrase, exchangeData.Testnet, exchangeData.HyperliquidWalletAddr, exchangeData.HyperliquidUnifiedAcct, exchangeData.AsterUser, exchangeData.AsterSigner, exchangeData.AsterPrivateKey, exchangeData.LighterWalletAddr, exchangeData.LighterPrivateKey, exchangeData.LighterAPIKeyPrivateKey, exchangeData.LighterAPIKeyIndex)
if err != nil {
SafeInternalError(c, fmt.Sprintf("Update exchange %s", exchangeID), err)
return
}
}
// Remove affected traders from memory BEFORE reloading to pick up new config
for traderID := range tradersToReload {
logger.Infof("🔄 Removing trader %s from memory to reload with new exchange config", traderID)
s.traderManager.RemoveTrader(traderID)
}
// Reload all traders for this user to make new config take effect immediately
err = s.traderManager.LoadUserTradersFromStore(s.store, userID)
if err != nil {
logger.Infof("⚠️ Failed to reload user traders into memory: %v", err)
// Don't return error here since exchange config was successfully updated to database
}
logger.Infof("✓ Exchange config updated: %+v", req.Exchanges)
c.JSON(http.StatusOK, gin.H{"message": "Exchange configuration updated"})
}
// handleCreateExchange Create a new exchange account
func (s *Server) handleCreateExchange(c *gin.Context) {
userID := c.GetString("user_id")
cfg := config.Get()
// Read raw request body
bodyBytes, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"})
return
}
var req CreateExchangeRequest
// Check if transport encryption is enabled
if !cfg.TransportEncryption {
// Transport encryption disabled, accept plain JSON
if err := json.Unmarshal(bodyBytes, &req); err != nil {
logger.Infof("❌ Failed to parse plain JSON request: %v", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
return
}
} else {
// Transport encryption enabled, require encrypted payload
var encryptedPayload crypto.EncryptedPayload
if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format, encrypted transmission required"})
return
}
if encryptedPayload.WrappedKey == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "This endpoint only supports encrypted transmission",
"code": "ENCRYPTION_REQUIRED",
"message": "Encrypted transmission is required for security reasons",
})
return
}
decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to decrypt data"})
return
}
if err := json.Unmarshal([]byte(decrypted), &req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse decrypted data"})
return
}
}
// Validate exchange type
validTypes := map[string]bool{
"binance": true, "bybit": true, "okx": true, "bitget": true,
"hyperliquid": true, "aster": true, "lighter": true, "gate": true, "kucoin": true, "indodax": true,
}
if !validTypes[req.ExchangeType] {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid exchange type: %s", req.ExchangeType)})
return
}
// Create new exchange account
id, err := s.store.Exchange().Create(
userID, req.ExchangeType, req.AccountName, req.Enabled,
req.APIKey, req.SecretKey, req.Passphrase, req.Testnet,
req.HyperliquidWalletAddr, req.HyperliquidUnifiedAcct,
req.AsterUser, req.AsterSigner, req.AsterPrivateKey,
req.LighterWalletAddr, req.LighterPrivateKey, req.LighterAPIKeyPrivateKey, req.LighterAPIKeyIndex,
)
if err != nil {
logger.Infof("❌ Failed to create exchange account: %v", err)
SafeInternalError(c, "Failed to create exchange account", err)
return
}
logger.Infof("✓ Created exchange account: type=%s, name=%s, id=%s", req.ExchangeType, req.AccountName, id)
c.JSON(http.StatusOK, gin.H{
"message": "Exchange account created",
"id": id,
})
}
// handleDeleteExchange Delete an exchange account
func (s *Server) handleDeleteExchange(c *gin.Context) {
userID := c.GetString("user_id")
exchangeID := c.Param("id")
if exchangeID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Exchange ID is required"})
return
}
// Check if any traders are using this exchange
traders, err := s.store.Trader().List(userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check traders"})
return
}
for _, trader := range traders {
if trader.ExchangeID == exchangeID {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Cannot delete exchange account that is in use by traders",
"trader_id": trader.ID,
"trader_name": trader.Name,
})
return
}
}
// Delete exchange account
err = s.store.Exchange().Delete(userID, exchangeID)
if err != nil {
logger.Infof("❌ Failed to delete exchange account: %v", err)
SafeInternalError(c, "Failed to delete exchange account", err)
return
}
logger.Infof("✓ Deleted exchange account: id=%s", exchangeID)
c.JSON(http.StatusOK, gin.H{"message": "Exchange account deleted"})
}
// handleGetSupportedExchanges Get list of exchanges supported by the system
func (s *Server) handleGetSupportedExchanges(c *gin.Context) {
// Return static list of supported exchange types
// Note: ID is empty for supported exchanges (they are templates, not actual accounts)
supportedExchanges := []SafeExchangeConfig{
{ExchangeType: "binance", Name: "Binance Futures", Type: "cex"},
{ExchangeType: "bybit", Name: "Bybit Futures", Type: "cex"},
{ExchangeType: "okx", Name: "OKX Futures", Type: "cex"},
{ExchangeType: "gate", Name: "Gate.io Futures", Type: "cex"},
{ExchangeType: "kucoin", Name: "KuCoin Futures", Type: "cex"},
{ExchangeType: "hyperliquid", Name: "Hyperliquid", Type: "dex"},
{ExchangeType: "aster", Name: "Aster DEX", Type: "dex"},
{ExchangeType: "lighter", Name: "LIGHTER DEX", Type: "dex"},
{ExchangeType: "alpaca", Name: "Alpaca (US Stocks)", Type: "stock"},
{ExchangeType: "forex", Name: "Forex (TwelveData)", Type: "forex"},
{ExchangeType: "metals", Name: "Metals (TwelveData)", Type: "metals"},
}
c.JSON(http.StatusOK, supportedExchanges)
}
+392
View File
@@ -0,0 +1,392 @@
package api
import (
"context"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"nofx/logger"
"nofx/market"
"nofx/provider/alpaca"
"nofx/provider/coinank/coinank_api"
"nofx/provider/coinank/coinank_enum"
"nofx/provider/hyperliquid"
"nofx/provider/twelvedata"
"github.com/gin-gonic/gin"
)
// handleKlines K-line data (supports multiple exchanges via coinank)
func (s *Server) handleKlines(c *gin.Context) {
// Get query parameters
symbol := c.Query("symbol")
if symbol == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "symbol parameter is required"})
return
}
interval := c.DefaultQuery("interval", "5m")
exchange := c.DefaultQuery("exchange", "binance") // Default to binance for backward compatibility
limitStr := c.DefaultQuery("limit", "1000")
limit, err := strconv.Atoi(limitStr)
if err != nil || limit <= 0 {
limit = 1000
}
// Coinank API has a maximum limit of 1500 klines per request
if limit > 1500 {
limit = 1500
}
var klines []market.Kline
exchangeLower := strings.ToLower(exchange)
// Route to appropriate data source based on exchange type
switch exchangeLower {
case "alpaca":
// US Stocks via Alpaca
klines, err = s.getKlinesFromAlpaca(symbol, interval, limit)
if err != nil {
SafeInternalError(c, "Get klines from Alpaca", err)
return
}
case "forex", "metals":
// Forex and Metals via Twelve Data
klines, err = s.getKlinesFromTwelveData(symbol, interval, limit)
if err != nil {
SafeInternalError(c, "Get klines from TwelveData", err)
return
}
case "hyperliquid", "hyperliquid-xyz", "xyz":
// Hyperliquid native API - supports both crypto perps and stock perps (xyz dex)
klines, err = s.getKlinesFromHyperliquid(symbol, interval, limit)
if err != nil {
SafeInternalError(c, "Get klines from Hyperliquid", err)
return
}
default:
// Crypto exchanges via CoinAnk
symbol = market.Normalize(symbol)
klines, err = s.getKlinesFromCoinank(symbol, interval, exchange, limit)
if err != nil {
SafeInternalError(c, "Get klines from CoinAnk", err)
return
}
}
c.JSON(http.StatusOK, klines)
}
// getKlinesFromCoinank fetches kline data from coinank free/open API for multiple exchanges
func (s *Server) getKlinesFromCoinank(symbol, interval, exchange string, limit int) ([]market.Kline, error) {
// Map exchange string to coinank enum
var coinankExchange coinank_enum.Exchange
switch strings.ToLower(exchange) {
case "binance":
coinankExchange = coinank_enum.Binance
case "bybit":
coinankExchange = coinank_enum.Bybit
case "okx":
coinankExchange = coinank_enum.Okex
case "bitget":
coinankExchange = coinank_enum.Bitget
case "gate":
coinankExchange = coinank_enum.Gate
case "aster":
coinankExchange = coinank_enum.Aster
case "lighter":
// Lighter doesn't have direct CoinAnk support, use Binance data as fallback
coinankExchange = coinank_enum.Binance
case "kucoin":
// KuCoin doesn't have direct CoinAnk support, use Binance data as fallback
coinankExchange = coinank_enum.Binance
default:
// For any unknown exchange, default to Binance
logger.Warnf("⚠️ Unknown exchange '%s', defaulting to Binance for CoinAnk", exchange)
coinankExchange = coinank_enum.Binance
}
// Map interval string to coinank enum
var coinankInterval coinank_enum.Interval
switch interval {
case "1s":
coinankInterval = coinank_enum.Second1
case "5s":
coinankInterval = coinank_enum.Second5
case "10s":
coinankInterval = coinank_enum.Second10
case "30s":
coinankInterval = coinank_enum.Second30
case "1m":
coinankInterval = coinank_enum.Minute1
case "3m":
coinankInterval = coinank_enum.Minute3
case "5m":
coinankInterval = coinank_enum.Minute5
case "10m":
coinankInterval = coinank_enum.Minute10
case "15m":
coinankInterval = coinank_enum.Minute15
case "30m":
coinankInterval = coinank_enum.Minute30
case "1h":
coinankInterval = coinank_enum.Hour1
case "2h":
coinankInterval = coinank_enum.Hour2
case "4h":
coinankInterval = coinank_enum.Hour4
case "6h":
coinankInterval = coinank_enum.Hour6
case "8h":
coinankInterval = coinank_enum.Hour8
case "12h":
coinankInterval = coinank_enum.Hour12
case "1d":
coinankInterval = coinank_enum.Day1
case "3d":
coinankInterval = coinank_enum.Day3
case "1w":
coinankInterval = coinank_enum.Week1
case "1M":
coinankInterval = coinank_enum.Month1
default:
return nil, fmt.Errorf("unsupported interval for coinank: %s", interval)
}
// Convert symbol format for different exchanges
// OKX uses "BTC-USDT-SWAP" format instead of "BTCUSDT"
apiSymbol := symbol
if coinankExchange == coinank_enum.Okex {
// Convert BTCUSDT -> BTC-USDT-SWAP
if strings.HasSuffix(symbol, "USDT") {
base := strings.TrimSuffix(symbol, "USDT")
apiSymbol = fmt.Sprintf("%s-USDT-SWAP", base)
}
}
// Call coinank free/open API (no authentication required)
ctx := context.Background()
ts := time.Now().UnixMilli()
// Use "To" side to search backward from current time (get historical klines)
coinankKlines, err := coinank_api.Kline(ctx, apiSymbol, coinankExchange, ts, coinank_enum.To, limit, coinankInterval)
if err != nil {
// Free API doesn't support all exchanges (e.g., OKX, Bitget)
// Fallback to Binance data as reference
if coinankExchange != coinank_enum.Binance {
logger.Warnf("⚠️ CoinAnk free API doesn't support %s, falling back to Binance data", coinankExchange)
coinankKlines, err = coinank_api.Kline(ctx, symbol, coinank_enum.Binance, ts, coinank_enum.To, limit, coinankInterval)
if err != nil {
return nil, fmt.Errorf("coinank API error (fallback): %w", err)
}
} else {
return nil, fmt.Errorf("coinank API error: %w", err)
}
}
// Convert coinank kline format to market.Kline format
// Coinank: Volume = BTC 数量, Quantity = USDT 成交额
klines := make([]market.Kline, len(coinankKlines))
for i, ck := range coinankKlines {
klines[i] = market.Kline{
OpenTime: ck.StartTime,
Open: ck.Open,
High: ck.High,
Low: ck.Low,
Close: ck.Close,
Volume: ck.Volume, // BTC 数量
QuoteVolume: ck.Quantity, // USDT 成交额
CloseTime: ck.EndTime,
}
}
return klines, nil
}
// getKlinesFromAlpaca fetches kline data from Alpaca API for US stocks
func (s *Server) getKlinesFromAlpaca(symbol, interval string, limit int) ([]market.Kline, error) {
// Create Alpaca client
client := alpaca.NewClient()
// Map interval to Alpaca timeframe format
timeframe := alpaca.MapTimeframe(interval)
// Fetch bars from Alpaca
ctx := context.Background()
bars, err := client.GetBars(ctx, symbol, timeframe, limit)
if err != nil {
return nil, fmt.Errorf("alpaca API error: %w", err)
}
// Convert Alpaca bars to market.Kline format
klines := make([]market.Kline, len(bars))
for i, bar := range bars {
klines[i] = market.Kline{
OpenTime: bar.Timestamp.UnixMilli(),
Open: bar.Open,
High: bar.High,
Low: bar.Low,
Close: bar.Close,
Volume: float64(bar.Volume), // 股数
QuoteVolume: float64(bar.Volume) * bar.Close, // 成交额 = 股数 * 收盘价 (USD)
CloseTime: bar.Timestamp.UnixMilli(),
}
}
return klines, nil
}
// getKlinesFromTwelveData fetches kline data from Twelve Data API for forex and metals
func (s *Server) getKlinesFromTwelveData(symbol, interval string, limit int) ([]market.Kline, error) {
// Create Twelve Data client
client := twelvedata.NewClient()
// Map interval to Twelve Data timeframe format
timeframe := twelvedata.MapTimeframe(interval)
// Fetch time series from Twelve Data
ctx := context.Background()
result, err := client.GetTimeSeries(ctx, symbol, timeframe, limit)
if err != nil {
return nil, fmt.Errorf("twelvedata API error: %w", err)
}
// Convert Twelve Data bars to market.Kline format
// Note: Twelve Data returns bars in reverse order (newest first)
klines := make([]market.Kline, len(result.Values))
for i, bar := range result.Values {
open, high, low, close, volume, timestamp, err := twelvedata.ParseBar(bar)
if err != nil {
logger.Warnf("⚠️ Failed to parse TwelveData bar: %v", err)
continue
}
// Reverse order: put oldest first
idx := len(result.Values) - 1 - i
klines[idx] = market.Kline{
OpenTime: timestamp,
Open: open,
High: high,
Low: low,
Close: close,
Volume: volume,
CloseTime: timestamp,
}
}
return klines, nil
}
// getKlinesFromHyperliquid fetches kline data from Hyperliquid API
// Supports both crypto perps (default dex) and stock perps/forex/commodities (xyz dex)
func (s *Server) getKlinesFromHyperliquid(symbol, interval string, limit int) ([]market.Kline, error) {
// Create Hyperliquid client
client := hyperliquid.NewClient()
// Map interval to Hyperliquid format
timeframe := hyperliquid.MapTimeframe(interval)
// Fetch candles from Hyperliquid
// FormatCoinForAPI will automatically add xyz: prefix for stock perps
ctx := context.Background()
candles, err := client.GetCandles(ctx, symbol, timeframe, limit)
if err != nil {
return nil, fmt.Errorf("hyperliquid API error: %w", err)
}
// Convert Hyperliquid candles to market.Kline format
klines := make([]market.Kline, len(candles))
for i, candle := range candles {
open, _ := strconv.ParseFloat(candle.Open, 64)
high, _ := strconv.ParseFloat(candle.High, 64)
low, _ := strconv.ParseFloat(candle.Low, 64)
close, _ := strconv.ParseFloat(candle.Close, 64)
volume, _ := strconv.ParseFloat(candle.Volume, 64)
klines[i] = market.Kline{
OpenTime: candle.OpenTime,
Open: open,
High: high,
Low: low,
Close: close,
Volume: volume, // 合约数量
QuoteVolume: volume * close, // 成交额 (USD)
CloseTime: candle.CloseTime,
}
}
return klines, nil
}
// handleSymbols returns available symbols for a given exchange
func (s *Server) handleSymbols(c *gin.Context) {
exchange := c.DefaultQuery("exchange", "hyperliquid")
type SymbolInfo struct {
Symbol string `json:"symbol"`
Name string `json:"name"`
Category string `json:"category"` // crypto, stock, forex, commodity, index
MaxLeverage int `json:"maxLeverage,omitempty"`
}
var symbols []SymbolInfo
switch strings.ToLower(exchange) {
case "hyperliquid", "hyperliquid-xyz", "xyz":
// Fetch symbols from Hyperliquid
client := hyperliquid.NewClient()
ctx := context.Background()
// Get crypto perps from default dex
if exchange == "hyperliquid" || exchange == "hyperliquid-xyz" {
mids, err := client.GetAllMids(ctx)
if err == nil {
for symbol := range mids {
// Skip spot tokens (start with @)
if strings.HasPrefix(symbol, "@") {
continue
}
symbols = append(symbols, SymbolInfo{
Symbol: symbol,
Name: symbol,
Category: "crypto",
})
}
}
}
// Get xyz dex symbols (stocks, forex, commodities)
xyzMids, err := client.GetAllMidsXYZ(ctx)
if err == nil {
for symbol := range xyzMids {
// Remove xyz: prefix for display
displaySymbol := strings.TrimPrefix(symbol, "xyz:")
category := "stock"
if displaySymbol == "GOLD" || displaySymbol == "SILVER" {
category = "commodity"
} else if displaySymbol == "EUR" || displaySymbol == "JPY" {
category = "forex"
} else if displaySymbol == "XYZ100" {
category = "index"
}
symbols = append(symbols, SymbolInfo{
Symbol: displaySymbol,
Name: displaySymbol,
Category: category,
})
}
}
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported exchange for symbol listing"})
return
}
c.JSON(http.StatusOK, gin.H{
"exchange": exchange,
"symbols": symbols,
"count": len(symbols),
})
}
+402
View File
@@ -0,0 +1,402 @@
package api
import (
"net/http"
"strconv"
"nofx/logger"
"nofx/market"
"github.com/gin-gonic/gin"
)
// handleTraderList Trader list
func (s *Server) handleTraderList(c *gin.Context) {
userID := c.GetString("user_id")
traders, err := s.store.Trader().List(userID)
if err != nil {
SafeInternalError(c, "Failed to get trader list", err)
return
}
result := make([]map[string]interface{}, 0, len(traders))
for _, trader := range traders {
// Get real-time running status
isRunning := trader.IsRunning
if at, err := s.traderManager.GetTrader(trader.ID); err == nil {
status := at.GetStatus()
if running, ok := status["is_running"].(bool); ok {
isRunning = running
}
}
// Get strategy name if strategy_id is set
var strategyName string
if trader.StrategyID != "" {
if strategy, err := s.store.Strategy().Get(userID, trader.StrategyID); err == nil {
strategyName = strategy.Name
}
}
// Return complete AIModelID (e.g. "admin_deepseek"), don't truncate
// Frontend needs complete ID to verify model exists (consistent with handleGetTraderConfig)
result = append(result, map[string]interface{}{
"trader_id": trader.ID,
"trader_name": trader.Name,
"ai_model": trader.AIModelID, // Use complete ID
"exchange_id": trader.ExchangeID,
"is_running": isRunning,
"show_in_competition": trader.ShowInCompetition,
"initial_balance": trader.InitialBalance,
"strategy_id": trader.StrategyID,
"strategy_name": strategyName,
})
}
c.JSON(http.StatusOK, result)
}
// handleGetTraderConfig Get trader detailed configuration
func (s *Server) handleGetTraderConfig(c *gin.Context) {
userID := c.GetString("user_id")
traderID := c.Param("id")
if traderID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Trader ID cannot be empty"})
return
}
fullCfg, err := s.store.Trader().GetFullConfig(userID, traderID)
if err != nil {
SafeNotFound(c, "Trader config")
return
}
traderConfig := fullCfg.Trader
// Get real-time running status
isRunning := traderConfig.IsRunning
if at, err := s.traderManager.GetTrader(traderID); err == nil {
status := at.GetStatus()
if running, ok := status["is_running"].(bool); ok {
isRunning = running
}
}
// Return complete model ID without conversion, consistent with frontend model list
aiModelID := traderConfig.AIModelID
result := map[string]interface{}{
"trader_id": traderConfig.ID,
"trader_name": traderConfig.Name,
"ai_model": aiModelID,
"exchange_id": traderConfig.ExchangeID,
"strategy_id": traderConfig.StrategyID,
"initial_balance": traderConfig.InitialBalance,
"scan_interval_minutes": traderConfig.ScanIntervalMinutes,
"btc_eth_leverage": traderConfig.BTCETHLeverage,
"altcoin_leverage": traderConfig.AltcoinLeverage,
"trading_symbols": traderConfig.TradingSymbols,
"custom_prompt": traderConfig.CustomPrompt,
"override_base_prompt": traderConfig.OverrideBasePrompt,
"is_cross_margin": traderConfig.IsCrossMargin,
"use_ai500": traderConfig.UseAI500,
"use_oi_top": traderConfig.UseOITop,
"is_running": isRunning,
}
c.JSON(http.StatusOK, result)
}
// handleStatus System status
func (s *Server) handleStatus(c *gin.Context) {
_, traderID, err := s.getTraderFromQuery(c)
if err != nil {
SafeBadRequest(c, "Invalid trader ID")
return
}
trader, err := s.traderManager.GetTrader(traderID)
if err != nil {
SafeNotFound(c, "Trader")
return
}
status := trader.GetStatus()
c.JSON(http.StatusOK, status)
}
// handleAccount Account information
func (s *Server) handleAccount(c *gin.Context) {
_, traderID, err := s.getTraderFromQuery(c)
if err != nil {
SafeBadRequest(c, "Invalid trader ID")
return
}
trader, err := s.traderManager.GetTrader(traderID)
if err != nil {
SafeNotFound(c, "Trader")
return
}
logger.Infof("📊 Received account info request [%s]", trader.GetName())
account, err := trader.GetAccountInfo()
if err != nil {
SafeInternalError(c, "Get account info", err)
return
}
logger.Infof("✓ Returning account info [%s]: equity=%.2f, available=%.2f, pnl=%.2f (%.2f%%)",
trader.GetName(),
account["total_equity"],
account["available_balance"],
account["total_pnl"],
account["total_pnl_pct"])
c.JSON(http.StatusOK, account)
}
// handlePositions Position list
func (s *Server) handlePositions(c *gin.Context) {
_, traderID, err := s.getTraderFromQuery(c)
if err != nil {
SafeBadRequest(c, "Invalid trader ID")
return
}
trader, err := s.traderManager.GetTrader(traderID)
if err != nil {
SafeNotFound(c, "Trader")
return
}
positions, err := trader.GetPositions()
if err != nil {
SafeInternalError(c, "Get positions", err)
return
}
c.JSON(http.StatusOK, positions)
}
// handlePositionHistory Historical closed positions with statistics
func (s *Server) handlePositionHistory(c *gin.Context) {
_, traderID, err := s.getTraderFromQuery(c)
if err != nil {
SafeBadRequest(c, "Invalid trader ID")
return
}
trader, err := s.traderManager.GetTrader(traderID)
if err != nil {
SafeNotFound(c, "Trader")
return
}
// Get optional query parameters
limitStr := c.DefaultQuery("limit", "100")
limit := 100
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 500 {
limit = l
}
// Get store
store := trader.GetStore()
if store == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"})
return
}
// Get closed positions
positions, err := store.Position().GetClosedPositions(trader.GetID(), limit)
if err != nil {
SafeInternalError(c, "Get position history", err)
return
}
// Get statistics
stats, _ := store.Position().GetFullStats(trader.GetID())
// Get symbol stats
symbolStats, _ := store.Position().GetSymbolStats(trader.GetID(), 10)
// Get direction stats
directionStats, _ := store.Position().GetDirectionStats(trader.GetID())
c.JSON(http.StatusOK, gin.H{
"positions": positions,
"stats": stats,
"symbol_stats": symbolStats,
"direction_stats": directionStats,
})
}
// handleTrades Historical trades list
func (s *Server) handleTrades(c *gin.Context) {
_, traderID, err := s.getTraderFromQuery(c)
if err != nil {
SafeBadRequest(c, "Invalid trader ID")
return
}
trader, err := s.traderManager.GetTrader(traderID)
if err != nil {
SafeNotFound(c, "Trader")
return
}
// Get optional query parameters
symbol := c.Query("symbol")
limitStr := c.DefaultQuery("limit", "100")
limit := 100
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
limit = l
}
// Normalize symbol (add USDT suffix if not present)
if symbol != "" {
symbol = market.Normalize(symbol)
}
// Get trades from store
store := trader.GetStore()
if store == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"})
return
}
allTrades, err := store.Position().GetRecentTrades(trader.GetID(), limit)
if err != nil {
SafeInternalError(c, "Get trades", err)
return
}
// Filter by symbol if specified
if symbol != "" {
var result []interface{}
for _, trade := range allTrades {
if trade.Symbol == symbol {
result = append(result, trade)
}
}
c.JSON(http.StatusOK, result)
return
}
c.JSON(http.StatusOK, allTrades)
}
// handleOrders Order list (all orders including open, close, stop loss, take profit, etc.)
func (s *Server) handleOrders(c *gin.Context) {
_, traderID, err := s.getTraderFromQuery(c)
if err != nil {
SafeBadRequest(c, "Invalid trader ID")
return
}
trader, err := s.traderManager.GetTrader(traderID)
if err != nil {
SafeNotFound(c, "Trader")
return
}
// Get optional query parameters
symbol := c.Query("symbol")
statusFilter := c.Query("status") // NEW, FILLED, CANCELED, etc.
limitStr := c.DefaultQuery("limit", "100")
limit := 100
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
limit = l
}
// Normalize symbol (add USDT suffix if not present)
if symbol != "" {
symbol = market.Normalize(symbol)
}
// Get orders from store
store := trader.GetStore()
if store == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"})
return
}
// Get orders with filters applied at database level
orders, err := store.Order().GetTraderOrdersFiltered(trader.GetID(), symbol, statusFilter, limit)
if err != nil {
SafeInternalError(c, "Get orders", err)
return
}
c.JSON(http.StatusOK, orders)
}
// handleOrderFills Order fill details (all fills for a specific order)
func (s *Server) handleOrderFills(c *gin.Context) {
orderIDStr := c.Param("id")
orderID, err := strconv.ParseInt(orderIDStr, 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid order ID"})
return
}
_, traderID, err := s.getTraderFromQuery(c)
if err != nil {
SafeBadRequest(c, "Invalid trader ID")
return
}
trader, err := s.traderManager.GetTrader(traderID)
if err != nil {
SafeNotFound(c, "Trader")
return
}
store := trader.GetStore()
if store == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"})
return
}
// Get fills for this order
fills, err := store.Order().GetOrderFills(orderID)
if err != nil {
SafeInternalError(c, "Get order fills", err)
return
}
c.JSON(http.StatusOK, fills)
}
// handleOpenOrders Get open orders (pending SL/TP) from exchange
func (s *Server) handleOpenOrders(c *gin.Context) {
_, traderID, err := s.getTraderFromQuery(c)
if err != nil {
SafeBadRequest(c, "Invalid trader ID")
return
}
trader, err := s.traderManager.GetTrader(traderID)
if err != nil {
SafeNotFound(c, "Trader")
return
}
// Get symbol parameter (required for exchange query)
symbol := c.Query("symbol")
if symbol == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "symbol parameter is required"})
return
}
// Normalize symbol
symbol = market.Normalize(symbol)
// Get open orders from exchange
openOrders, err := trader.GetOpenOrders(symbol)
if err != nil {
SafeInternalError(c, "Get open orders", err)
return
}
c.JSON(http.StatusOK, openOrders)
}
+105
View File
@@ -0,0 +1,105 @@
package api
import (
"net/http"
"github.com/gin-gonic/gin"
)
// handleGetTelegramConfig returns current Telegram bot configuration and binding status
func (s *Server) handleGetTelegramConfig(c *gin.Context) {
cfg, err := s.store.TelegramConfig().Get()
if err != nil {
// Not configured yet - return empty state
c.JSON(http.StatusOK, gin.H{
"configured": false,
"is_bound": false,
"token_masked": "",
"username": "",
})
return
}
// Mask bot token for security (show only last 6 chars)
tokenMasked := ""
if cfg.BotToken != "" {
if len(cfg.BotToken) > 6 {
tokenMasked = "***" + cfg.BotToken[len(cfg.BotToken)-6:]
} else {
tokenMasked = "***"
}
}
c.JSON(http.StatusOK, gin.H{
"configured": cfg.BotToken != "",
"is_bound": cfg.ChatID != 0,
"username": cfg.Username,
"bound_at": cfg.BoundAt,
"token_masked": tokenMasked,
"model_id": cfg.ModelID,
})
}
// handleUpdateTelegramConfig saves bot token (+ optional model ID) and triggers bot hot-reload
func (s *Server) handleUpdateTelegramConfig(c *gin.Context) {
var req struct {
BotToken string `json:"bot_token"`
ModelID string `json:"model_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
if req.BotToken == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "bot_token is required"})
return
}
if err := s.store.TelegramConfig().Save(req.BotToken, req.ModelID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save config"})
return
}
// Signal bot hot-reload if channel is available
if s.telegramReloadCh != nil {
select {
case s.telegramReloadCh <- struct{}{}:
default: // non-blocking
}
}
c.JSON(http.StatusOK, gin.H{"success": true, "message": "Bot token saved. Bot will reload automatically."})
}
// handleUnbindTelegram removes Telegram user binding
func (s *Server) handleUnbindTelegram(c *gin.Context) {
if err := s.store.TelegramConfig().Unbind(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to unbind"})
return
}
c.JSON(http.StatusOK, gin.H{"success": true, "message": "Telegram binding removed"})
}
// handleUpdateTelegramModel updates only the AI model used for Telegram replies (no token re-entry needed)
func (s *Server) handleUpdateTelegramModel(c *gin.Context) {
var req struct {
ModelID string `json:"model_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
cfg, err := s.store.TelegramConfig().Get()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "no Telegram config found, save a bot token first"})
return
}
if err := s.store.TelegramConfig().Save(cfg.BotToken, req.ModelID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save model config"})
return
}
c.JSON(http.StatusOK, gin.H{"success": true, "model_id": req.ModelID})
}
File diff suppressed because it is too large Load Diff
+223
View File
@@ -0,0 +1,223 @@
package api
import (
"net/http"
"strings"
"time"
"nofx/auth"
"nofx/logger"
"nofx/store"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// handleLogout Add current token to blacklist
func (s *Server) handleLogout(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing Authorization header"})
return
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"})
return
}
tokenString := parts[1]
claims, err := auth.ValidateJWT(tokenString)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
return
}
var exp time.Time
if claims.ExpiresAt != nil {
exp = claims.ExpiresAt.Time
} else {
exp = time.Now().Add(24 * time.Hour)
}
auth.BlacklistToken(tokenString, exp)
c.JSON(http.StatusOK, gin.H{"message": "Logged out"})
}
// handleRegister Handle user registration request.
// handleRegister allows registration only when no users exist yet (first-time setup).
// This is a single-user system; subsequent registrations are permanently closed.
func (s *Server) handleRegister(c *gin.Context) {
userCount, err := s.store.User().Count()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check user count"})
return
}
if userCount > 0 {
c.JSON(http.StatusForbidden, gin.H{"error": "System already initialized"})
return
}
var req struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
}
if err := c.ShouldBindJSON(&req); err != nil {
SafeBadRequest(c, "Invalid request parameters")
return
}
// Check if email already exists
_, err = s.store.User().GetByEmail(req.Email)
if err == nil {
c.JSON(http.StatusConflict, gin.H{"error": "Email already registered"})
return
}
// Generate password hash
passwordHash, err := auth.HashPassword(req.Password)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Password processing failed"})
return
}
// Create user
userID := uuid.New().String()
user := &store.User{
ID: userID,
Email: req.Email,
PasswordHash: passwordHash,
}
err = s.store.User().Create(user)
if err != nil {
SafeInternalError(c, "Failed to create user", err)
return
}
// Generate JWT token
token, err := auth.GenerateJWT(user.ID, user.Email)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
return
}
// Initialize default model and exchange configs for user
err = s.initUserDefaultConfigs(user.ID)
if err != nil {
logger.Infof("Failed to initialize user default configs: %v", err)
}
c.JSON(http.StatusOK, gin.H{
"token": token,
"user_id": user.ID,
"email": user.Email,
"message": "Registration successful",
})
}
// handleLogin Handle user login request
func (s *Server) handleLogin(c *gin.Context) {
var req struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
SafeBadRequest(c, "Invalid request parameters")
return
}
// Get user information
user, err := s.store.User().GetByEmail(req.Email)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Email or password incorrect"})
return
}
// Verify password
if !auth.CheckPassword(req.Password, user.PasswordHash) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Email or password incorrect"})
return
}
// Issue token directly after password verification.
token, err := auth.GenerateJWT(user.ID, user.Email)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
return
}
c.JSON(http.StatusOK, gin.H{
"token": token,
"user_id": user.ID,
"email": user.Email,
"message": "Login successful",
})
}
// handleChangePassword changes the password for the currently authenticated user.
func (s *Server) handleChangePassword(c *gin.Context) {
userID := c.GetString("user_id")
var req struct {
NewPassword string `json:"new_password" binding:"required,min=8"`
}
if err := c.ShouldBindJSON(&req); err != nil {
SafeBadRequest(c, "new_password is required (min 8 chars)")
return
}
hash, err := auth.HashPassword(req.NewPassword)
if err != nil {
SafeInternalError(c, "Password processing failed", err)
return
}
if err := s.store.User().UpdatePassword(userID, hash); err != nil {
SafeInternalError(c, "Failed to update password", err)
return
}
c.JSON(http.StatusOK, gin.H{"message": "Password updated"})
}
// handleResetPassword Reset password via email and new password
func (s *Server) handleResetPassword(c *gin.Context) {
var req struct {
Email string `json:"email" binding:"required,email"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
if err := c.ShouldBindJSON(&req); err != nil {
SafeBadRequest(c, "Invalid request parameters")
return
}
// Query user
user, err := s.store.User().GetByEmail(req.Email)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Email does not exist"})
return
}
// Generate new password hash
newPasswordHash, err := auth.HashPassword(req.NewPassword)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Password processing failed"})
return
}
// Update password
err = s.store.User().UpdatePassword(user.ID, newPasswordHash)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Password update failed"})
return
}
logger.Infof("✓ User %s password has been reset", user.Email)
c.JSON(http.StatusOK, gin.H{"message": "Password reset successful, please login with new password"})
}
// initUserDefaultConfigs Initialize default model and exchange configs for new user
func (s *Server) initUserDefaultConfigs(userID string) error {
// Commented out auto-creation of default configs, let users add manually
// This way new users won't have config items automatically after registration
logger.Infof("User %s registration completed, waiting for manual AI model and exchange configuration", userID)
return nil
}
-3271
View File
File diff suppressed because it is too large Load Diff
+11 -38
View File
@@ -8,6 +8,8 @@ import (
"nofx/logger"
"nofx/market"
"nofx/mcp"
_ "nofx/mcp/payment"
_ "nofx/mcp/provider"
"nofx/store"
"time"
@@ -637,49 +639,20 @@ func (s *Server) runRealAITest(userID, modelID, systemPrompt, userPrompt string)
return "", fmt.Errorf("AI model %s is missing API Key", model.Name)
}
// Create AI client
var aiClient mcp.AIClient
// Create AI client via registry
provider := model.Provider
// Convert EncryptedString to string for API key
apiKey := string(model.APIKey)
aiClient := mcp.NewAIClientByProvider(provider)
if aiClient == nil {
aiClient = mcp.NewClient()
}
// Payment providers ignore custom URL
switch provider {
case "qwen":
aiClient = mcp.NewQwenClient()
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "deepseek":
aiClient = mcp.NewDeepSeekClient()
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "claude":
aiClient = mcp.NewClaudeClient()
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "kimi":
aiClient = mcp.NewKimiClient()
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "gemini":
aiClient = mcp.NewGeminiClient()
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "grok":
aiClient = mcp.NewGrokClient()
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "openai":
aiClient = mcp.NewOpenAIClient()
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "minimax":
aiClient = mcp.NewMiniMaxClient()
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
case "blockrun-base":
aiClient = mcp.NewBlockRunBaseClient()
aiClient.SetAPIKey(apiKey, "", model.CustomModelName)
case "blockrun-sol":
aiClient = mcp.NewBlockRunSolClient()
aiClient.SetAPIKey(apiKey, "", model.CustomModelName)
case "claw402":
aiClient = mcp.NewClaw402Client()
case "blockrun-base", "blockrun-sol", "claw402":
aiClient.SetAPIKey(apiKey, "", model.CustomModelName)
default:
// Use generic client
aiClient = mcp.NewClient()
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
}
+34 -126
View File
@@ -5,15 +5,17 @@ import (
"strings"
"nofx/mcp"
_ "nofx/mcp/payment"
_ "nofx/mcp/provider"
)
// configureMCPClient creates/clones an MCP client based on configuration (returns mcp.AIClient interface).
// Note: mcp.New() returns an interface type; here we convert to concrete implementation before copying to avoid concurrent shared state.
func configureMCPClient(cfg BacktestConfig, base mcp.AIClient) (mcp.AIClient, error) {
provider := strings.ToLower(strings.TrimSpace(cfg.AICfg.Provider))
providerName := strings.ToLower(strings.TrimSpace(cfg.AICfg.Provider))
// DeepSeek
if provider == "" || provider == "inherit" || provider == "default" {
// Inherit base client
if providerName == "" || providerName == "inherit" || providerName == "default" {
client := cloneBaseClient(base)
if cfg.AICfg.APIKey != "" || cfg.AICfg.BaseURL != "" || cfg.AICfg.Model != "" {
client.SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
@@ -21,143 +23,49 @@ func configureMCPClient(cfg BacktestConfig, base mcp.AIClient) (mcp.AIClient, er
return client, nil
}
switch provider {
case "deepseek":
if cfg.AICfg.APIKey == "" {
return nil, fmt.Errorf("deepseek provider requires api key")
}
ds := mcp.NewDeepSeekClientWithOptions()
ds.(*mcp.DeepSeekClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
return ds, nil
case "qwen":
if cfg.AICfg.APIKey == "" {
return nil, fmt.Errorf("qwen provider requires api key")
}
qc := mcp.NewQwenClientWithOptions()
qc.(*mcp.QwenClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
return qc, nil
case "claude":
if cfg.AICfg.APIKey == "" {
return nil, fmt.Errorf("claude provider requires api key")
}
cc := mcp.NewClaudeClientWithOptions()
cc.(*mcp.ClaudeClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
return cc, nil
case "kimi":
if cfg.AICfg.APIKey == "" {
return nil, fmt.Errorf("kimi provider requires api key")
}
kc := mcp.NewKimiClientWithOptions()
kc.(*mcp.KimiClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
return kc, nil
case "gemini":
if cfg.AICfg.APIKey == "" {
return nil, fmt.Errorf("gemini provider requires api key")
}
gc := mcp.NewGeminiClientWithOptions()
gc.(*mcp.GeminiClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
return gc, nil
case "grok":
if cfg.AICfg.APIKey == "" {
return nil, fmt.Errorf("grok provider requires api key")
}
grokC := mcp.NewGrokClientWithOptions()
grokC.(*mcp.GrokClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
return grokC, nil
case "openai":
if cfg.AICfg.APIKey == "" {
return nil, fmt.Errorf("openai provider requires api key")
}
oaiC := mcp.NewOpenAIClientWithOptions()
oaiC.(*mcp.OpenAIClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
return oaiC, nil
case "minimax":
if cfg.AICfg.APIKey == "" {
return nil, fmt.Errorf("minimax provider requires api key")
}
mmC := mcp.NewMiniMaxClientWithOptions()
mmC.(*mcp.MiniMaxClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
return mmC, nil
case "blockrun-base":
if cfg.AICfg.APIKey == "" {
return nil, fmt.Errorf("blockrun-base provider requires wallet private key")
}
brBase := mcp.NewBlockRunBaseClient()
brBase.SetAPIKey(cfg.AICfg.APIKey, "", cfg.AICfg.Model)
return brBase, nil
case "blockrun-sol":
if cfg.AICfg.APIKey == "" {
return nil, fmt.Errorf("blockrun-sol provider requires wallet keypair")
}
brSol := mcp.NewBlockRunSolClient()
brSol.SetAPIKey(cfg.AICfg.APIKey, "", cfg.AICfg.Model)
return brSol, nil
case "claw402":
if cfg.AICfg.APIKey == "" {
return nil, fmt.Errorf("claw402 provider requires wallet private key")
}
claw := mcp.NewClaw402Client()
claw.SetAPIKey(cfg.AICfg.APIKey, "", cfg.AICfg.Model)
return claw, nil
case "custom":
// Custom provider uses cloned base
if providerName == "custom" {
if cfg.AICfg.BaseURL == "" || cfg.AICfg.APIKey == "" || cfg.AICfg.Model == "" {
return nil, fmt.Errorf("custom provider requires base_url, api key and model")
}
client := cloneBaseClient(base)
client.SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
return client, nil
default:
}
// Create client via registry
client := mcp.NewAIClientByProvider(providerName)
if client == nil {
return nil, fmt.Errorf("unsupported ai provider %s", cfg.AICfg.Provider)
}
if cfg.AICfg.APIKey == "" {
return nil, fmt.Errorf("%s provider requires api key", providerName)
}
// Payment providers ignore custom URL
switch providerName {
case "blockrun-base", "blockrun-sol", "claw402":
client.SetAPIKey(cfg.AICfg.APIKey, "", cfg.AICfg.Model)
default:
client.SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
}
return client, nil
}
// cloneBaseClient copies the base client to avoid shared mutable state.
// Uses the ClientEmbedder interface to extract the underlying *mcp.Client
// from any provider type that embeds it.
func cloneBaseClient(base mcp.AIClient) *mcp.Client {
// Prefer to reuse the passed-in base client (deep copy)
switch c := base.(type) {
case *mcp.Client:
if embedder, ok := base.(mcp.ClientEmbedder); ok {
if inner := embedder.BaseClient(); inner != nil {
cp := *inner
return &cp
}
}
if c, ok := base.(*mcp.Client); ok {
cp := *c
return &cp
case *mcp.DeepSeekClient:
if c != nil && c.Client != nil {
cp := *c.Client
return &cp
}
case *mcp.QwenClient:
if c != nil && c.Client != nil {
cp := *c.Client
return &cp
}
case *mcp.ClaudeClient:
if c != nil && c.Client != nil {
cp := *c.Client
return &cp
}
case *mcp.KimiClient:
if c != nil && c.Client != nil {
cp := *c.Client
return &cp
}
case *mcp.GeminiClient:
if c != nil && c.Client != nil {
cp := *c.Client
return &cp
}
case *mcp.GrokClient:
if c != nil && c.Client != nil {
cp := *c.Client
return &cp
}
case *mcp.OpenAIClient:
if c != nil && c.Client != nil {
cp := *c.Client
return &cp
}
case *mcp.MiniMaxClient:
if c != nil && c.Client != nil {
cp := *c.Client
return &cp
}
}
// Fall back to a new default client
return mcp.NewClient().(*mcp.Client)
-351
View File
@@ -1,351 +0,0 @@
package llm
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// 阿里云 API 配置
const (
DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/api/v1/apps"
// 标准 OpenAI 兼容模式 API
QwenCompatibleURL = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
)
// QwenAgent 阿里云百炼智能体客户端
type QwenAgent struct {
AppID string
APIKey string
BaseURL string
SessionID string
Client *http.Client
}
// QwenRequest 请求结构
type QwenRequest struct {
Input QwenInput `json:"input"`
Parameters QwenParameters `json:"parameters,omitempty"`
}
// QwenInput 输入结构
type QwenInput struct {
Prompt string `json:"prompt"`
BizParams map[string]interface{} `json:"biz_params,omitempty"`
}
// QwenParameters 参数结构
type QwenParameters struct {
SessionID string `json:"session_id,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
}
// QwenResponse 响应结构
type QwenResponse struct {
Output QwenOutput `json:"output"`
Usage QwenUsage `json:"usage,omitempty"`
RequestID string `json:"request_id"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
// QwenOutput 输出结构
type QwenOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason,omitempty"`
SessionID string `json:"session_id,omitempty"`
}
// QwenUsage 用量统计
type QwenUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
// NewQwenAgent 创建新的智能体客户端
func NewQwenAgent(appID, apiKey string) *QwenAgent {
return &QwenAgent{
AppID: appID,
APIKey: apiKey,
BaseURL: DefaultQwenBaseURL,
Client: &http.Client{
Timeout: 180 * time.Second,
},
}
}
// Chat 同步对话
func (a *QwenAgent) Chat(ctx context.Context, prompt string) (*QwenResponse, error) {
reqBody := QwenRequest{
Input: QwenInput{
Prompt: prompt,
},
Parameters: QwenParameters{
SessionID: a.SessionID,
},
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request failed: %w", err)
}
url := fmt.Sprintf("%s/%s/completion", a.BaseURL, a.AppID)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+a.APIKey)
resp, err := a.Client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var result QwenResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("unmarshal response failed: %w, body: %s", err, string(body))
}
// 更新 session_id 用于多轮对话
if result.Output.SessionID != "" {
a.SessionID = result.Output.SessionID
}
// 检查 API 错误
if result.Code != "" {
return &result, fmt.Errorf("API error: code=%s, message=%s", result.Code, result.Message)
}
return &result, nil
}
// ChatStream 流式对话
func (a *QwenAgent) ChatStream(ctx context.Context, prompt string, callback func(chunk string)) error {
reqBody := QwenRequest{
Input: QwenInput{
Prompt: prompt,
},
Parameters: QwenParameters{
SessionID: a.SessionID,
IncrementalOutput: true,
},
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("marshal request failed: %w", err)
}
url := fmt.Sprintf("%s/%s/completion", a.BaseURL, a.AppID)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+a.APIKey)
req.Header.Set("X-DashScope-SSE", "enable")
resp, err := a.Client.Do(req)
if err != nil {
return fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
break
}
return fmt.Errorf("read stream failed: %w", err)
}
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "data:") {
continue
}
data := strings.TrimPrefix(line, "data:")
var chunk QwenResponse
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
// 更新 session_id
if chunk.Output.SessionID != "" {
a.SessionID = chunk.Output.SessionID
}
// 回调输出文本
if chunk.Output.Text != "" {
callback(chunk.Output.Text)
}
}
return nil
}
// ChatWithBizParams 带业务参数的对话
func (a *QwenAgent) ChatWithBizParams(ctx context.Context, prompt string, bizParams map[string]interface{}) (*QwenResponse, error) {
reqBody := QwenRequest{
Input: QwenInput{
Prompt: prompt,
BizParams: bizParams,
},
Parameters: QwenParameters{
SessionID: a.SessionID,
},
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request failed: %w", err)
}
url := fmt.Sprintf("%s/%s/completion", a.BaseURL, a.AppID)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+a.APIKey)
resp, err := a.Client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var result QwenResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("unmarshal response failed: %w, body: %s", err, string(body))
}
if result.Output.SessionID != "" {
a.SessionID = result.Output.SessionID
}
if result.Code != "" {
return &result, fmt.Errorf("API error: code=%s, message=%s", result.Code, result.Message)
}
return &result, nil
}
// ResetSession 重置会话
func (a *QwenAgent) ResetSession() {
a.SessionID = ""
}
// ========== 标准 OpenAI 兼容 API ==========
// ChatCompletionRequest OpenAI 兼容格式请求
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatCompletionMessage `json:"messages"`
}
// ChatCompletionMessage 消息结构
type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// ChatCompletionResponse OpenAI 兼容格式响应
type ChatCompletionResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Choices []struct {
Message struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
Error *struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error,omitempty"`
}
// ChatWithModel 使用标准 OpenAI 兼容 API 调用指定模型
func (a *QwenAgent) ChatWithModel(ctx context.Context, model, prompt string) (*ChatCompletionResponse, error) {
reqBody := ChatCompletionRequest{
Model: model,
Messages: []ChatCompletionMessage{
{Role: "user", Content: prompt},
},
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request failed: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", QwenCompatibleURL, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+a.APIKey)
resp, err := a.Client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var result ChatCompletionResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("unmarshal response failed: %w, body: %s", err, string(body))
}
if result.Error != nil {
return &result, fmt.Errorf("API error: code=%s, message=%s", result.Error.Code, result.Error.Message)
}
return &result, nil
}
// GetContent 从响应中获取内容
func (r *ChatCompletionResponse) GetContent() string {
if len(r.Choices) > 0 {
return r.Choices[0].Message.Content
}
return ""
}
-425
View File
@@ -1,425 +0,0 @@
package llm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"testing"
"time"
)
// 阿里云百炼平台配置 (从环境变量获取)
var (
QwenAppID = os.Getenv("QWEN_APP_ID")
QwenAPIKey = os.Getenv("QWEN_API_KEY")
)
// ============== 测试用例 ==============
// TestQwenBasicChat 测试基本同步对话
func TestQwenBasicChat(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
prompt := "你好,请用一句话介绍你自己"
t.Logf("用户: %s", prompt)
start := time.Now()
resp, err := agent.Chat(ctx, prompt)
elapsed := time.Since(start)
if err != nil {
t.Fatalf("Chat failed: %v", err)
}
if resp.Output.Text == "" {
t.Fatal("Empty response text")
}
t.Logf("助手: %s", resp.Output.Text)
t.Logf("耗时: %v, Token: %d", elapsed, resp.Usage.TotalTokens)
}
// TestQwenStreamChat 测试流式输出
func TestQwenStreamChat(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
prompt := "请用3句话解释什么是量化交易"
t.Logf("用户: %s", prompt)
var fullText strings.Builder
start := time.Now()
err := agent.ChatStream(ctx, prompt, func(chunk string) {
fullText.WriteString(chunk)
})
elapsed := time.Since(start)
if err != nil {
t.Fatalf("ChatStream failed: %v", err)
}
if fullText.Len() == 0 {
t.Fatal("Empty stream response")
}
t.Logf("助手: %s", fullText.String())
t.Logf("耗时: %v, 字符数: %d", elapsed, fullText.Len())
}
// TestQwenMultiTurn 测试多轮对话(上下文记忆)
func TestQwenMultiTurn(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
// 第一轮:设置上下文
resp1, err := agent.Chat(ctx, "我叫小明,我是一名 Go 程序员,请记住这些信息")
if err != nil {
t.Fatalf("Round 1 failed: %v", err)
}
t.Logf("[Round 1] 用户: 我叫小明,我是一名 Go 程序员")
t.Logf("[Round 1] 助手: %s", resp1.Output.Text)
t.Logf("[Round 1] SessionID: %s", agent.SessionID)
// 第二轮:验证记忆
resp2, err := agent.Chat(ctx, "请问我叫什么名字?我是做什么的?")
if err != nil {
t.Fatalf("Round 2 failed: %v", err)
}
t.Logf("[Round 2] 用户: 请问我叫什么名字?我是做什么的?")
t.Logf("[Round 2] 助手: %s", resp2.Output.Text)
// 检查是否记住了信息
text := strings.ToLower(resp2.Output.Text)
if !strings.Contains(text, "小明") && !strings.Contains(text, "go") {
t.Logf("警告: 模型可能没有正确记住上下文")
}
}
// TestQwenResetSession 测试重置会话
func TestQwenResetSession(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
// 建立上下文
resp1, err := agent.Chat(ctx, "记住这个密码: ABC123XYZ")
if err != nil {
t.Fatalf("Setup context failed: %v", err)
}
t.Logf("设置上下文: %s", resp1.Output.Text)
oldSession := agent.SessionID
t.Logf("原 SessionID: %s", oldSession)
// 重置会话
agent.ResetSession()
t.Log("会话已重置")
// 新对话 - 应该不记得之前的内容
resp2, err := agent.Chat(ctx, "我之前告诉你的密码是什么?")
if err != nil {
t.Fatalf("New session chat failed: %v", err)
}
t.Logf("新对话回复: %s", resp2.Output.Text)
t.Logf("新 SessionID: %s", agent.SessionID)
if oldSession == agent.SessionID {
t.Error("Session was not reset properly")
}
}
// TestQwenCodeGeneration 测试代码生成能力
func TestQwenCodeGeneration(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
prompt := "请用 Go 语言写一个计算移动平均线(MA)的函数,输入是 []float64 价格切片和 int 周期"
t.Logf("用户: %s", prompt)
resp, err := agent.Chat(ctx, prompt)
if err != nil {
t.Fatalf("Code generation failed: %v", err)
}
t.Logf("助手:\n%s", resp.Output.Text)
// 检查是否包含代码特征
text := resp.Output.Text
if !strings.Contains(text, "func") || !strings.Contains(text, "float64") {
t.Log("警告: 响应可能不包含有效的 Go 代码")
}
}
// TestQwenJSONOutput 测试 JSON 格式输出
func TestQwenJSONOutput(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
prompt := `请分析 BTC 的基本信息,以纯 JSON 格式返回(不要 markdown 代码块),包含以下字段:
{"name": "资产名称", "type": "资产类型", "risk": 1-10的风险等级数字}
只返回 JSON 对象,不要任何其他文字`
t.Logf("用户: %s", prompt)
resp, err := agent.Chat(ctx, prompt)
if err != nil {
t.Fatalf("JSON output test failed: %v", err)
}
t.Logf("助手: %s", resp.Output.Text)
// 尝试解析 JSON
text := resp.Output.Text
// 提取 JSON 部分
start := strings.Index(text, "{")
end := strings.LastIndex(text, "}")
if start != -1 && end != -1 && end > start {
jsonStr := text[start : end+1]
var result map[string]interface{}
if err := json.Unmarshal([]byte(jsonStr), &result); err != nil {
t.Logf("JSON 解析失败: %v", err)
} else {
t.Logf("JSON 解析成功: %+v", result)
}
}
}
// TestQwenLongResponse 测试长文本生成
func TestQwenLongResponse(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
prompt := "请详细介绍加密货币永续合约交易中的风险管理策略,包括止损设置、仓位管理、杠杆选择、资金费率考虑等方面,至少500字"
t.Logf("用户: %s", prompt)
start := time.Now()
resp, err := agent.Chat(ctx, prompt)
elapsed := time.Since(start)
if err != nil {
t.Fatalf("Long response test failed: %v", err)
}
text := resp.Output.Text
t.Logf("响应长度: %d 字符", len(text))
t.Logf("耗时: %v", elapsed)
t.Logf("Token 使用: input=%d, output=%d, total=%d",
resp.Usage.InputTokens, resp.Usage.OutputTokens, resp.Usage.TotalTokens)
// 只显示前500字符
if len(text) > 500 {
t.Logf("助手(前500字): %s...", text[:500])
} else {
t.Logf("助手: %s", text)
}
}
// TestQwenTradingScenario 测试交易场景问答
func TestQwenTradingScenario(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
questions := []string{
"BTC 当前价格 95000 美元,RSI 在 75 附近,MACD 金叉,你建议现在开多还是开空?简短回答",
"如果我有 10000 USDT,想用 10 倍杠杆做多 ETH,建议开多大仓位?",
"什么是资金费率?正的资金费率对多头有什么影响?",
}
for i, q := range questions {
agent.ResetSession() // 每个问题独立
t.Logf("\n[问题%d] %s", i+1, q)
resp, err := agent.Chat(ctx, q)
if err != nil {
t.Errorf("Question %d failed: %v", i+1, err)
continue
}
// 截取显示
text := resp.Output.Text
if len(text) > 300 {
text = text[:300] + "..."
}
t.Logf("[回答%d] %s", i+1, text)
}
}
// TestQwenErrorHandling 测试错误处理
func TestQwenErrorHandling(t *testing.T) {
ctx := context.Background()
// 测试无效 API Key
t.Run("InvalidAPIKey", func(t *testing.T) {
agent := NewQwenAgent(QwenAppID, "invalid-api-key")
_, err := agent.Chat(ctx, "测试")
if err == nil {
t.Log("警告: 无效 API Key 没有返回错误")
} else {
t.Logf("预期错误: %v", err)
}
})
// 测试无效 App ID
t.Run("InvalidAppID", func(t *testing.T) {
agent := NewQwenAgent("invalid-app-id", QwenAPIKey)
_, err := agent.Chat(ctx, "测试")
if err == nil {
t.Log("警告: 无效 App ID 没有返回错误")
} else {
t.Logf("预期错误: %v", err)
}
})
}
// TestQwenSpecialCharacters 测试特殊字符处理
func TestQwenSpecialCharacters(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
testCases := []string{
"请解释这个表情: 😀🎉🚀",
"中英文混合: Hello世界!",
"特殊符号: <>&\"'",
}
for _, prompt := range testCases {
agent.ResetSession()
t.Logf("用户: %s", prompt)
resp, err := agent.Chat(ctx, prompt)
if err != nil {
t.Errorf("特殊字符测试失败: %v", err)
continue
}
if len(resp.Output.Text) > 100 {
t.Logf("助手: %s...", resp.Output.Text[:100])
} else {
t.Logf("助手: %s", resp.Output.Text)
}
}
}
// TestQwenConcurrentSessions 测试并发会话
func TestQwenConcurrentSessions(t *testing.T) {
agent1 := NewQwenAgent(QwenAppID, QwenAPIKey)
agent2 := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
// Agent1 对话
resp1, err := agent1.Chat(ctx, "我是 Alice,请记住")
if err != nil {
t.Fatalf("Agent1 chat failed: %v", err)
}
t.Logf("[Agent1] 设置: 我是 Alice -> %s", resp1.Output.Text[:min(100, len(resp1.Output.Text))])
// Agent2 对话
resp2, err := agent2.Chat(ctx, "我是 Bob,请记住")
if err != nil {
t.Fatalf("Agent2 chat failed: %v", err)
}
t.Logf("[Agent2] 设置: 我是 Bob -> %s", resp2.Output.Text[:min(100, len(resp2.Output.Text))])
// 验证会话隔离
resp1Check, _ := agent1.Chat(ctx, "我叫什么?")
resp2Check, _ := agent2.Chat(ctx, "我叫什么?")
t.Logf("[Agent1] 验证: %s", resp1Check.Output.Text[:min(100, len(resp1Check.Output.Text))])
t.Logf("[Agent2] 验证: %s", resp2Check.Output.Text[:min(100, len(resp2Check.Output.Text))])
if agent1.SessionID == agent2.SessionID {
t.Error("两个 Agent 的 SessionID 不应该相同")
} else {
t.Logf("Session 隔离正常: Agent1=%s..., Agent2=%s...",
agent1.SessionID[:min(20, len(agent1.SessionID))],
agent2.SessionID[:min(20, len(agent2.SessionID))])
}
}
// TestQwenTimeout 测试超时处理
func TestQwenTimeout(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
agent.Client.Timeout = 1 * time.Millisecond // 极短超时
ctx := context.Background()
_, err := agent.Chat(ctx, "测试超时")
if err == nil {
t.Log("警告: 极短超时没有触发错误")
} else {
t.Logf("预期超时错误: %v", err)
}
// 恢复正常超时
agent.Client.Timeout = 120 * time.Second
}
// TestQwenContextCancel 测试上下文取消
func TestQwenContextCancel(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx, cancel := context.WithCancel(context.Background())
cancel() // 立即取消
_, err := agent.Chat(ctx, "测试取消")
if err == nil {
t.Error("取消的上下文应该返回错误")
} else {
t.Logf("预期取消错误: %v", err)
}
}
// TestQwenWithBizParams 测试带业务参数的调用
func TestQwenWithBizParams(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
// 构造带业务参数的请求
reqBody := QwenRequest{
Input: QwenInput{
Prompt: "根据提供的用户信息,给出个性化的投资建议",
BizParams: map[string]interface{}{
"user_risk_level": "moderate",
"capital": 10000,
"experience": "intermediate",
},
},
}
jsonData, _ := json.Marshal(reqBody)
url := fmt.Sprintf("%s/%s/completion", agent.BaseURL, agent.AppID)
req, _ := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+agent.APIKey)
resp, err := agent.Client.Do(req)
if err != nil {
t.Fatalf("Request with biz params failed: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var result QwenResponse
json.Unmarshal(body, &result)
if result.Output.Text != "" {
t.Logf("带业务参数响应: %s", result.Output.Text[:min(200, len(result.Output.Text))])
} else {
t.Logf("响应: %s", string(body))
}
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
-737
View File
@@ -1,737 +0,0 @@
package llm
import (
"context"
"encoding/json"
"fmt"
"math"
"nofx/market"
"nofx/provider/coinank"
"nofx/provider/coinank/coinank_api"
"nofx/provider/coinank/coinank_enum"
"regexp"
"strconv"
"strings"
"testing"
"time"
)
// IndicatorResult AI 计算的指标结果
type IndicatorResult struct {
EMA12 float64 `json:"ema12"`
EMA26 float64 `json:"ema26"`
MACD float64 `json:"macd"`
RSI14 float64 `json:"rsi14"`
BOLLUp float64 `json:"boll_upper"`
BOLLMid float64 `json:"boll_middle"`
BOLLLow float64 `json:"boll_lower"`
ATR14 float64 `json:"atr14"`
SMA20 float64 `json:"sma20"`
}
// 本地计算指标(使用 market 包的函数)
func calculateLocalIndicators(klines []market.Kline) IndicatorResult {
result := IndicatorResult{}
if len(klines) >= 12 {
result.EMA12 = market.ExportCalculateEMA(klines, 12)
}
if len(klines) >= 26 {
result.EMA26 = market.ExportCalculateEMA(klines, 26)
result.MACD = market.ExportCalculateMACD(klines)
}
if len(klines) > 14 {
result.RSI14 = market.ExportCalculateRSI(klines, 14)
}
if len(klines) >= 20 {
result.BOLLUp, result.BOLLMid, result.BOLLLow = market.ExportCalculateBOLL(klines, 20, 2.0)
// SMA20 就是 BOLL 中轨
result.SMA20 = result.BOLLMid
}
if len(klines) > 14 {
result.ATR14 = market.ExportCalculateATR(klines, 14)
}
return result
}
// 格式化 K 线数据为文本,发给 AI
func formatKlinesForAI(klines []market.Kline) string {
var sb strings.Builder
sb.WriteString("以下是K线数据(从旧到新排列):\n")
sb.WriteString("序号 | 时间 | 开盘价 | 最高价 | 最低价 | 收盘价 | 成交量\n")
sb.WriteString("-----|------|--------|--------|--------|--------|--------\n")
for i, k := range klines {
t := time.UnixMilli(k.OpenTime)
sb.WriteString(fmt.Sprintf("%d | %s | %.2f | %.2f | %.2f | %.2f | %.2f\n",
i+1, t.Format("01-02 15:04"), k.Open, k.High, k.Low, k.Close, k.Volume))
}
return sb.String()
}
// 构建 AI 计算指标的 prompt
func buildIndicatorPrompt(klines []market.Kline) string {
klinesText := formatKlinesForAI(klines)
prompt := fmt.Sprintf(`%s
请根据以上 %d 根K线数据,计算以下技术指标(使用标准算法):
1. EMA12(12周期指数移动平均线)
2. EMA26(26周期指数移动平均线)
3. MACDEMA12 - EMA26
4. RSI1414周期相对强弱指标,使用Wilder平滑法)
5. BOLL布林带(20周期,2倍标准差):上轨、中轨、下轨
6. ATR1414周期平均真实波幅,使用Wilder平滑法)
7. SMA20(20周期简单移动平均线)
请严格按照以下 JSON 格式返回结果,不要添加任何其他文字:
{
"ema12": 数值,
"ema26": 数值,
"macd": 数值,
"rsi14": 数值,
"boll_upper": 数值,
"boll_middle": 数值,
"boll_lower": 数值,
"atr14": 数值,
"sma20": 数值
}
注意:
- 所有数值保留2位小数
- EMA计算使用SMA作为初始值,乘数为 2/(period+1)
- RSI使用Wilder平滑法
- 只返回JSON,不要解释过程`, klinesText, len(klines))
return prompt
}
// 从 AI 响应中提取 JSON
func extractJSONFromResponse(text string) (IndicatorResult, error) {
var result IndicatorResult
// 尝试直接解析
if err := json.Unmarshal([]byte(text), &result); err == nil {
return result, nil
}
// 提取 JSON 部分
re := regexp.MustCompile(`\{[^{}]*"ema12"[^{}]*\}`)
match := re.FindString(text)
if match == "" {
// 尝试更宽松的匹配
start := strings.Index(text, "{")
end := strings.LastIndex(text, "}")
if start != -1 && end != -1 && end > start {
match = text[start : end+1]
}
}
if match == "" {
return result, fmt.Errorf("no JSON found in response: %s", text[:min(200, len(text))])
}
if err := json.Unmarshal([]byte(match), &result); err != nil {
return result, fmt.Errorf("parse JSON failed: %w, json: %s", err, match)
}
return result, nil
}
// 比较两个指标结果,返回误差百分比
func compareIndicators(local, ai IndicatorResult) map[string]float64 {
errors := make(map[string]float64)
calcError := func(name string, localVal, aiVal float64) {
if localVal == 0 {
if aiVal == 0 {
errors[name] = 0
} else {
errors[name] = 100 // 本地为0但AI不为0
}
return
}
errors[name] = math.Abs(localVal-aiVal) / math.Abs(localVal) * 100
}
calcError("EMA12", local.EMA12, ai.EMA12)
calcError("EMA26", local.EMA26, ai.EMA26)
calcError("MACD", local.MACD, ai.MACD)
calcError("RSI14", local.RSI14, ai.RSI14)
calcError("BOLL_UP", local.BOLLUp, ai.BOLLUp)
calcError("BOLL_MID", local.BOLLMid, ai.BOLLMid)
calcError("BOLL_LOW", local.BOLLLow, ai.BOLLLow)
calcError("ATR14", local.ATR14, ai.ATR14)
calcError("SMA20", local.SMA20, ai.SMA20)
return errors
}
// 生成测试用 K 线数据
func generateTestKlines(count int, basePrice float64) []market.Kline {
klines := make([]market.Kline, count)
price := basePrice
now := time.Now()
for i := 0; i < count; i++ {
// 模拟价格波动
change := (float64(i%7) - 3) * 0.5 // -1.5 到 +1.5 的波动
price = price + change
open := price
high := price + math.Abs(change)*0.5 + 0.5
low := price - math.Abs(change)*0.5 - 0.3
close := price + (change * 0.3)
klines[i] = market.Kline{
OpenTime: now.Add(time.Duration(-count+i) * time.Hour).UnixMilli(),
Open: open,
High: high,
Low: low,
Close: close,
Volume: 1000 + float64(i*100),
CloseTime: now.Add(time.Duration(-count+i+1) * time.Hour).UnixMilli(),
}
}
return klines
}
// TestQwenIndicatorCalculation 测试 AI 计算技术指标
func TestQwenIndicatorCalculation(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
// 生成 30 根测试 K 线
klines := generateTestKlines(30, 95000)
t.Log("===== K线数据 (最后5根) =====")
for i := len(klines) - 5; i < len(klines); i++ {
k := klines[i]
t.Logf(" [%d] O:%.2f H:%.2f L:%.2f C:%.2f", i+1, k.Open, k.High, k.Low, k.Close)
}
// 本地计算
t.Log("\n===== 本地计算结果 =====")
localResult := calculateLocalIndicators(klines)
t.Logf(" EMA12: %.2f", localResult.EMA12)
t.Logf(" EMA26: %.2f", localResult.EMA26)
t.Logf(" MACD: %.2f", localResult.MACD)
t.Logf(" RSI14: %.2f", localResult.RSI14)
t.Logf(" BOLL上轨: %.2f", localResult.BOLLUp)
t.Logf(" BOLL中轨: %.2f", localResult.BOLLMid)
t.Logf(" BOLL下轨: %.2f", localResult.BOLLLow)
t.Logf(" ATR14: %.2f", localResult.ATR14)
t.Logf(" SMA20: %.2f", localResult.SMA20)
// AI 计算
t.Log("\n===== 调用 AI 计算 =====")
prompt := buildIndicatorPrompt(klines)
t.Logf("Prompt 长度: %d 字符", len(prompt))
start := time.Now()
resp, err := agent.Chat(ctx, prompt)
elapsed := time.Since(start)
if err != nil {
t.Fatalf("AI 调用失败: %v", err)
}
t.Logf("AI 响应耗时: %v", elapsed)
t.Logf("AI 原始响应:\n%s", resp.Output.Text)
// 解析 AI 结果
aiResult, err := extractJSONFromResponse(resp.Output.Text)
if err != nil {
t.Fatalf("解析 AI 结果失败: %v", err)
}
t.Log("\n===== AI 计算结果 =====")
t.Logf(" EMA12: %.2f", aiResult.EMA12)
t.Logf(" EMA26: %.2f", aiResult.EMA26)
t.Logf(" MACD: %.2f", aiResult.MACD)
t.Logf(" RSI14: %.2f", aiResult.RSI14)
t.Logf(" BOLL上轨: %.2f", aiResult.BOLLUp)
t.Logf(" BOLL中轨: %.2f", aiResult.BOLLMid)
t.Logf(" BOLL下轨: %.2f", aiResult.BOLLLow)
t.Logf(" ATR14: %.2f", aiResult.ATR14)
t.Logf(" SMA20: %.2f", aiResult.SMA20)
// 对比结果
t.Log("\n===== 误差对比 (%) =====")
errors := compareIndicators(localResult, aiResult)
totalError := 0.0
for name, errPct := range errors {
status := "✓"
if errPct > 5 {
status = "⚠"
}
if errPct > 10 {
status = "✗"
}
t.Logf(" %s %s: %.2f%%", status, name, errPct)
totalError += errPct
}
avgError := totalError / float64(len(errors))
t.Logf("\n 平均误差: %.2f%%", avgError)
if avgError > 10 {
t.Logf("警告: AI 计算误差较大,可能算法理解有差异")
} else if avgError < 5 {
t.Log("AI 计算精度良好!")
}
}
// TestQwenIndicatorWithRealKlines 使用真实 K 线测试
func TestQwenIndicatorWithRealKlines(t *testing.T) {
// 尝试获取真实 K 线数据
client := market.NewAPIClient()
klines, err := client.GetKlines("BTC", "1h", 30)
if err != nil {
t.Skipf("获取真实 K 线失败,跳过测试: %v", err)
return
}
if len(klines) < 26 {
t.Skipf("K 线数量不足: %d", len(klines))
return
}
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
t.Logf("获取到 %d 根 BTC 1h K线", len(klines))
t.Log("最新价格:", klines[len(klines)-1].Close)
// 本地计算
localResult := calculateLocalIndicators(klines)
t.Log("\n===== 本地计算 =====")
t.Logf(" EMA12: %.2f, EMA26: %.2f, MACD: %.2f", localResult.EMA12, localResult.EMA26, localResult.MACD)
t.Logf(" RSI14: %.2f", localResult.RSI14)
t.Logf(" BOLL: %.2f / %.2f / %.2f", localResult.BOLLUp, localResult.BOLLMid, localResult.BOLLLow)
// AI 计算
prompt := buildIndicatorPrompt(klines)
resp, err := agent.Chat(ctx, prompt)
if err != nil {
t.Fatalf("AI 调用失败: %v", err)
}
t.Log("\n===== AI 响应 =====")
t.Log(resp.Output.Text)
aiResult, err := extractJSONFromResponse(resp.Output.Text)
if err != nil {
t.Logf("解析失败: %v", err)
return
}
// 对比
errors := compareIndicators(localResult, aiResult)
t.Log("\n===== 误差 =====")
for name, errPct := range errors {
t.Logf(" %s: %.2f%%", name, errPct)
}
}
// TestQwenIndicatorMultiTimeframe 测试多个时间周期
func TestQwenIndicatorMultiTimeframe(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
timeframes := []struct {
name string
count int
price float64
}{
{"5m周期", 30, 95000},
{"1h周期", 50, 95000},
{"4h周期", 40, 95000},
}
for _, tf := range timeframes {
t.Run(tf.name, func(t *testing.T) {
klines := generateTestKlines(tf.count, tf.price)
localResult := calculateLocalIndicators(klines)
// 简化的 prompt
prompt := buildSimpleIndicatorPrompt(klines)
resp, err := agent.Chat(ctx, prompt)
if err != nil {
t.Fatalf("AI 调用失败: %v", err)
}
aiResult, err := extractJSONFromResponse(resp.Output.Text)
if err != nil {
t.Logf("解析失败: %v", err)
t.Logf("AI 响应: %s", resp.Output.Text[:min(500, len(resp.Output.Text))])
return
}
errors := compareIndicators(localResult, aiResult)
// 计算平均误差
total := 0.0
for _, e := range errors {
total += e
}
avgErr := total / float64(len(errors))
t.Logf("本地 MACD: %.2f, AI MACD: %.2f, 误差: %.2f%%", localResult.MACD, aiResult.MACD, errors["MACD"])
t.Logf("本地 RSI: %.2f, AI RSI: %.2f, 误差: %.2f%%", localResult.RSI14, aiResult.RSI14, errors["RSI14"])
t.Logf("平均误差: %.2f%%", avgErr)
})
time.Sleep(2 * time.Second) // 避免请求过快
}
}
// 简化的 prompt
func buildSimpleIndicatorPrompt(klines []market.Kline) string {
// 只提供收盘价序列,减少 token
var prices []string
for _, k := range klines {
prices = append(prices, fmt.Sprintf("%.2f", k.Close))
}
return fmt.Sprintf(`收盘价序列(从旧到新): [%s]
请计算技术指标并返回 JSON
- ema12: 12周期EMA
- ema26: 26周期EMA
- macd: EMA12-EMA26
- rsi14: 14周期RSI(Wilder平滑)
- boll_upper, boll_middle, boll_lower: 20周期BOLL(2倍标准差)
- atr14: 0 (无高低价数据)
- sma20: 20周期SMA
只返回JSON格式:{"ema12":数值,"ema26":数值,...}`, strings.Join(prices, ","))
}
// TestQwenIndicatorAccuracy 精度测试:使用简单数据验证算法
func TestQwenIndicatorAccuracy(t *testing.T) {
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
ctx := context.Background()
// 使用简单递增数据,便于验证
prices := []float64{
100, 101, 102, 103, 104, 105, 106, 107, 108, 109, // 1-10
110, 111, 112, 113, 114, 115, 116, 117, 118, 119, // 11-20
120, 121, 122, 123, 124, 125, 126, 127, 128, 129, // 21-30
}
// 构建 K 线
klines := make([]market.Kline, len(prices))
for i, p := range prices {
klines[i] = market.Kline{
Open: p - 0.5,
High: p + 1,
Low: p - 1,
Close: p,
}
}
// 本地计算
localResult := calculateLocalIndicators(klines)
t.Log("===== 简单递增数据测试 =====")
t.Logf("价格序列: %v", prices)
t.Logf("本地计算:")
t.Logf(" SMA20 = %.4f (理论值: 119.5)", localResult.SMA20)
t.Logf(" EMA12 = %.4f", localResult.EMA12)
t.Logf(" RSI14 = %.4f (持续上涨应接近100)", localResult.RSI14)
// AI 计算
var priceStrs []string
for _, p := range prices {
priceStrs = append(priceStrs, strconv.FormatFloat(p, 'f', 0, 64))
}
prompt := fmt.Sprintf(`收盘价序列: [%s]
请计算:
1. SMA20 (20周期简单移动平均)
2. EMA12 (12周期指数移动平均,初始值用SMA,乘数=2/13)
3. RSI14 (14周期RSIWilder平滑法)
返回JSON: {"sma20":数值,"ema12":数值,"rsi14":数值}
只返回JSON`, strings.Join(priceStrs, ","))
resp, err := agent.Chat(ctx, prompt)
if err != nil {
t.Fatalf("AI 调用失败: %v", err)
}
t.Logf("\nAI 响应: %s", resp.Output.Text)
// 简单解析
var aiSimple struct {
SMA20 float64 `json:"sma20"`
EMA12 float64 `json:"ema12"`
RSI14 float64 `json:"rsi14"`
}
text := resp.Output.Text
start := strings.Index(text, "{")
end := strings.LastIndex(text, "}")
if start != -1 && end > start {
json.Unmarshal([]byte(text[start:end+1]), &aiSimple)
}
t.Logf("\nAI 计算:")
t.Logf(" SMA20 = %.4f", aiSimple.SMA20)
t.Logf(" EMA12 = %.4f", aiSimple.EMA12)
t.Logf(" RSI14 = %.4f", aiSimple.RSI14)
// 验证 SMA20 (理论值应该是 110+...+129 的平均 = 119.5)
expectedSMA := 119.5
if math.Abs(aiSimple.SMA20-expectedSMA) < 0.1 {
t.Log("\n✓ AI 的 SMA20 计算正确!")
} else {
t.Logf("\n✗ AI 的 SMA20 有误差,期望 %.2f", expectedSMA)
}
}
// coinankKlinesToMarket 将 coinank K线转换为 market.Kline
func coinankKlinesToMarket(klines []coinank.KlineResult) []market.Kline {
result := make([]market.Kline, len(klines))
for i, k := range klines {
result[i] = market.Kline{
OpenTime: k.StartTime,
Open: k.Open,
High: k.High,
Low: k.Low,
Close: k.Close,
Volume: k.Volume,
CloseTime: k.EndTime,
}
}
return result
}
// TestQwenETHMultiTimeframe 使用 Coinank 免费 API 获取真实 ETH 数据测试多周期指标
func TestQwenETHMultiTimeframe(t *testing.T) {
ctx := context.Background()
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
// 测试多个时间周期
timeframes := []struct {
name string
interval coinank_enum.Interval
size int
}{
{"5分钟", coinank_enum.Minute5, 50},
{"1小时", coinank_enum.Hour1, 50},
{"4小时", coinank_enum.Hour4, 50},
{"日线", coinank_enum.Day1, 30},
}
now := time.Now()
for _, tf := range timeframes {
t.Run(tf.name, func(t *testing.T) {
// 使用 coinank 免费 API 获取 ETH K线数据
coinankKlines, err := coinank_api.Kline(ctx, "ETHUSDT", coinank_enum.Binance,
now.UnixMilli(), coinank_enum.To, tf.size, tf.interval)
if err != nil {
t.Fatalf("获取 %s K线失败: %v", tf.name, err)
}
if len(coinankKlines) < 26 {
t.Skipf("K线数量不足: %d", len(coinankKlines))
return
}
// 转换为 market.Kline
klines := coinankKlinesToMarket(coinankKlines)
t.Logf("获取到 %d 根 ETH %s K线", len(klines), tf.name)
t.Logf("最新收盘价: %.2f, 时间: %s",
klines[len(klines)-1].Close,
time.UnixMilli(klines[len(klines)-1].CloseTime).Format("2006-01-02 15:04"))
// 本地计算
localResult := calculateLocalIndicators(klines)
t.Log("\n===== 本地计算 =====")
t.Logf(" EMA12: %.2f, EMA26: %.2f, MACD: %.4f",
localResult.EMA12, localResult.EMA26, localResult.MACD)
t.Logf(" RSI14: %.2f", localResult.RSI14)
t.Logf(" BOLL: %.2f / %.2f / %.2f",
localResult.BOLLUp, localResult.BOLLMid, localResult.BOLLLow)
t.Logf(" ATR14: %.4f", localResult.ATR14)
// AI 计算 - 使用简化 prompt(只发收盘价)
prompt := buildSimpleIndicatorPrompt(klines)
t.Logf("\nPrompt 长度: %d 字符", len(prompt))
start := time.Now()
resp, err := agent.Chat(ctx, prompt)
elapsed := time.Since(start)
if err != nil {
t.Fatalf("AI 调用失败: %v", err)
}
t.Logf("AI 响应耗时: %v", elapsed)
// 解析 AI 结果
aiResult, err := extractJSONFromResponse(resp.Output.Text)
if err != nil {
t.Logf("AI 原始响应:\n%s", resp.Output.Text[:min(500, len(resp.Output.Text))])
t.Fatalf("解析失败: %v", err)
}
t.Log("\n===== AI 计算 =====")
t.Logf(" EMA12: %.2f, EMA26: %.2f, MACD: %.4f",
aiResult.EMA12, aiResult.EMA26, aiResult.MACD)
t.Logf(" RSI14: %.2f", aiResult.RSI14)
t.Logf(" BOLL: %.2f / %.2f / %.2f",
aiResult.BOLLUp, aiResult.BOLLMid, aiResult.BOLLLow)
// 对比误差
t.Log("\n===== 误差对比 =====")
errors := compareIndicators(localResult, aiResult)
totalErr := 0.0
for name, errPct := range errors {
status := "✓"
if errPct > 1 {
status = "⚠"
}
if errPct > 5 {
status = "✗"
}
t.Logf(" %s %-10s: %.2f%%", status, name, errPct)
totalErr += errPct
}
avgErr := totalErr / float64(len(errors))
t.Logf("\n 平均误差: %.2f%%", avgErr)
if avgErr < 1 {
t.Log(" ✓ AI 计算精度优秀!")
} else if avgErr < 5 {
t.Log(" ⚠ AI 计算精度良好")
} else {
t.Log(" ✗ AI 计算误差较大")
}
// 等待避免请求过快
time.Sleep(2 * time.Second)
})
}
}
// TestQwenETHIndicatorComparison ETH 指标对比:使用 Coinank 免费 API + Qwen 标准 API
func TestQwenETHIndicatorComparison(t *testing.T) {
ctx := context.Background()
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
// 使用 coinank 免费 API 获取 ETH 1小时 K线
now := time.Now()
coinankKlines, err := coinank_api.Kline(ctx, "ETHUSDT", coinank_enum.Binance,
now.UnixMilli(), coinank_enum.To, 30, coinank_enum.Hour1)
if err != nil {
t.Fatalf("获取 K线失败: %v", err)
}
// 转换为 market.Kline
klines := coinankKlinesToMarket(coinankKlines)
t.Logf("获取到 %d 根 ETH 1h K线", len(klines))
// 只用收盘价,简化 prompt
var prices []string
for _, k := range klines {
prices = append(prices, fmt.Sprintf("%.2f", k.Close))
}
// 本地计算
localResult := calculateLocalIndicators(klines)
t.Log("\n===== 本地计算结果 =====")
t.Logf("SMA20: %.2f", localResult.SMA20)
t.Logf("EMA12: %.2f", localResult.EMA12)
t.Logf("EMA26: %.2f", localResult.EMA26)
t.Logf("MACD: %.4f", localResult.MACD)
t.Logf("RSI14: %.2f", localResult.RSI14)
// 简化的 AI prompt
prompt := fmt.Sprintf(`ETH 最近30根1小时K线收盘价(从旧到新):
[%s]
请计算以下指标并返回纯 JSON:
1. sma20: 最后20个价格的简单移动平均
2. ema12: 12周期EMA(初始值用前12个价格的SMA,乘数=2/13)
3. ema26: 26周期EMA(初始值用前26个价格的SMA,乘数=2/27)
4. macd: EMA12 - EMA26
5. rsi14: 14周期RSIWilder平滑法)
只返回JSON格式: {"sma20":数值,"ema12":数值,"ema26":数值,"macd":数值,"rsi14":数值}
不要任何解释文字`, strings.Join(prices, ", "))
t.Logf("\n发送 Prompt (%d 字符)", len(prompt))
// 使用标准 API
resp, err := agent.ChatWithModel(ctx, "qwen-max", prompt)
if err != nil {
t.Fatalf("AI 调用失败: %v", err)
}
aiText := resp.GetContent()
t.Logf("\nAI 响应:\n%s", aiText)
// 解析
var aiResult struct {
SMA20 float64 `json:"sma20"`
EMA12 float64 `json:"ema12"`
EMA26 float64 `json:"ema26"`
MACD float64 `json:"macd"`
RSI14 float64 `json:"rsi14"`
}
start := strings.Index(aiText, "{")
end := strings.LastIndex(aiText, "}")
if start != -1 && end > start {
if err := json.Unmarshal([]byte(aiText[start:end+1]), &aiResult); err != nil {
t.Logf("JSON 解析失败: %v", err)
}
}
t.Log("\n===== AI 计算结果 =====")
t.Logf("SMA20: %.2f", aiResult.SMA20)
t.Logf("EMA12: %.2f", aiResult.EMA12)
t.Logf("EMA26: %.2f", aiResult.EMA26)
t.Logf("MACD: %.4f", aiResult.MACD)
t.Logf("RSI14: %.2f", aiResult.RSI14)
// 计算误差
t.Log("\n===== 误差 =====")
calcErr := func(name string, local, ai float64) {
if local == 0 {
t.Logf(" %s: 本地=0, AI=%.2f", name, ai)
return
}
errPct := math.Abs(local-ai) / math.Abs(local) * 100
status := "✓"
if errPct > 1 {
status = "⚠"
}
if errPct > 5 {
status = "✗"
}
t.Logf(" %s %s: 本地=%.2f, AI=%.2f, 误差=%.2f%%", status, name, local, ai, errPct)
}
calcErr("SMA20", localResult.SMA20, aiResult.SMA20)
calcErr("EMA12", localResult.EMA12, aiResult.EMA12)
calcErr("EMA26", localResult.EMA26, aiResult.EMA26)
calcErr("MACD", localResult.MACD, aiResult.MACD)
calcErr("RSI14", localResult.RSI14, aiResult.RSI14)
}
+3 -1
View File
@@ -10,6 +10,8 @@ import (
"nofx/logger"
"nofx/manager"
"nofx/mcp"
_ "nofx/mcp/payment"
_ "nofx/mcp/provider"
"nofx/store"
"nofx/telegram"
"os"
@@ -168,7 +170,7 @@ func newSharedMCPClient() mcp.AIClient {
logger.Warn("⚠️ DEEPSEEK_API_KEY not set, AI features will be unavailable")
return nil
}
return mcp.NewDeepSeekClient()
return mcp.NewAIClientByProvider("deepseek")
}
// initInstallationID initializes the anonymous installation ID for experience improvement
-248
View File
@@ -1,248 +0,0 @@
package mcp
import (
"encoding/json"
"net/http"
"testing"
)
// ── buildRequestBodyFromRequest ────────────────────────────────────────────────
func TestClaudeClient_BuildRequestBody_SystemPromptLifted(t *testing.T) {
c := newTestClaudeClient()
req := &Request{
Model: "claude-opus-4-6",
Messages: []Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
},
}
body := c.buildRequestBodyFromRequest(req)
if body["system"] != "You are helpful." {
t.Errorf("system not lifted to top level: %v", body["system"])
}
msgs := body["messages"].([]map[string]any)
if len(msgs) != 1 || msgs[0]["role"] != "user" {
t.Errorf("system message should be removed from messages array: %v", msgs)
}
}
func TestClaudeClient_BuildRequestBody_ToolsUseInputSchema(t *testing.T) {
c := newTestClaudeClient()
req := &Request{
Model: "claude-opus-4-6",
Messages: []Message{{Role: "user", Content: "hi"}},
Tools: []Tool{{
Type: "function",
Function: FunctionDef{
Name: "my_tool",
Description: "does stuff",
Parameters: map[string]any{"type": "object"},
},
}},
}
body := c.buildRequestBodyFromRequest(req)
tools, ok := body["tools"].([]map[string]any)
if !ok || len(tools) != 1 {
t.Fatalf("tools not set correctly: %v", body["tools"])
}
tool := tools[0]
if tool["name"] != "my_tool" {
t.Errorf("tool name wrong: %v", tool["name"])
}
if tool["input_schema"] == nil {
t.Error("tool must use input_schema, not parameters")
}
if _, hasParams := tool["parameters"]; hasParams {
t.Error("tool must NOT have parameters key (Anthropic uses input_schema)")
}
}
func TestClaudeClient_BuildRequestBody_ToolChoiceObject(t *testing.T) {
c := newTestClaudeClient()
req := &Request{
Model: "claude-opus-4-6",
Messages: []Message{{Role: "user", Content: "hi"}},
ToolChoice: "auto",
}
body := c.buildRequestBodyFromRequest(req)
tc, ok := body["tool_choice"].(map[string]any)
if !ok {
t.Fatalf("tool_choice must be an object, got: %T %v", body["tool_choice"], body["tool_choice"])
}
if tc["type"] != "auto" {
t.Errorf("tool_choice.type must be 'auto', got: %v", tc["type"])
}
}
// ── convertMessagesToAnthropic ─────────────────────────────────────────────────
func TestConvertMessages_AssistantToolCall(t *testing.T) {
msgs := []Message{
{
Role: "assistant",
ToolCalls: []ToolCall{{
ID: "tc1",
Type: "function",
Function: ToolCallFunction{Name: "api_request", Arguments: `{"method":"GET","path":"/api/x","body":{}}`},
}},
},
}
out := convertMessagesToAnthropic(msgs)
if len(out) != 1 {
t.Fatalf("expected 1 message, got %d", len(out))
}
msg := out[0]
if msg["role"] != "assistant" {
t.Errorf("role should be assistant: %v", msg["role"])
}
blocks := msg["content"].([]map[string]any)
if len(blocks) != 1 || blocks[0]["type"] != "tool_use" {
t.Errorf("content should be tool_use block: %v", blocks)
}
if blocks[0]["id"] != "tc1" {
t.Errorf("tool_use id wrong: %v", blocks[0]["id"])
}
// Input must be parsed JSON object, not a string.
input, ok := blocks[0]["input"].(map[string]any)
if !ok {
t.Errorf("tool_use input must be map, got %T", blocks[0]["input"])
}
if input["method"] != "GET" {
t.Errorf("input.method wrong: %v", input)
}
}
func TestConvertMessages_ToolResultMergedIntoUserTurn(t *testing.T) {
// Anthropic requires strictly alternating turns; consecutive tool results
// must be merged into a single user message.
msgs := []Message{
{Role: "tool", ToolCallID: "tc1", Content: `{"result":"a"}`},
{Role: "tool", ToolCallID: "tc2", Content: `{"result":"b"}`},
}
out := convertMessagesToAnthropic(msgs)
if len(out) != 1 {
t.Fatalf("consecutive tool results must be merged into one user turn, got %d messages", len(out))
}
if out[0]["role"] != "user" {
t.Errorf("tool results must become role=user: %v", out[0]["role"])
}
blocks := out[0]["content"].([]map[string]any)
if len(blocks) != 2 {
t.Errorf("expected 2 tool_result blocks, got %d", len(blocks))
}
if blocks[0]["type"] != "tool_result" || blocks[1]["type"] != "tool_result" {
t.Errorf("blocks should be tool_result: %v", blocks)
}
if blocks[0]["tool_use_id"] != "tc1" || blocks[1]["tool_use_id"] != "tc2" {
t.Errorf("tool_use_id mismatch: %v", blocks)
}
}
// ── parseMCPResponseFull ───────────────────────────────────────────────────────
func TestClaudeClient_ParseResponse_TextOnly(t *testing.T) {
c := newTestClaudeClient()
body := []byte(`{
"content": [{"type":"text","text":"Hello from Claude"}],
"usage": {"input_tokens": 10, "output_tokens": 5}
}`)
resp, err := c.parseMCPResponseFull(body)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Content != "Hello from Claude" {
t.Errorf("content mismatch: %q", resp.Content)
}
if len(resp.ToolCalls) != 0 {
t.Errorf("expected no tool calls: %v", resp.ToolCalls)
}
}
func TestClaudeClient_ParseResponse_ToolUse(t *testing.T) {
c := newTestClaudeClient()
body := []byte(`{
"content": [{
"type": "tool_use",
"id": "toolu_01abc",
"name": "api_request",
"input": {"method":"POST","path":"/api/strategies","body":{"name":"BTC策略"}}
}],
"usage": {"input_tokens": 100, "output_tokens": 30}
}`)
resp, err := c.parseMCPResponseFull(body)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls))
}
tc := resp.ToolCalls[0]
if tc.ID != "toolu_01abc" {
t.Errorf("tool call ID wrong: %v", tc.ID)
}
if tc.Function.Name != "api_request" {
t.Errorf("function name wrong: %v", tc.Function.Name)
}
// Arguments must be a valid JSON string.
var args map[string]any
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
t.Errorf("arguments not valid JSON: %q — %v", tc.Function.Arguments, err)
}
if args["method"] != "POST" {
t.Errorf("args.method wrong: %v", args)
}
}
func TestClaudeClient_ParseResponse_APIError(t *testing.T) {
c := newTestClaudeClient()
body := []byte(`{"error":{"type":"authentication_error","message":"invalid x-api-key"}}`)
_, err := c.parseMCPResponseFull(body)
if err == nil {
t.Fatal("expected error for API error response")
}
if err.Error() == "" {
t.Error("error message should not be empty")
}
}
// ── Auth header ────────────────────────────────────────────────────────────────
func TestClaudeClient_SetAuthHeader(t *testing.T) {
c := newTestClaudeClient()
c.APIKey = "sk-ant-test123"
// net/http.Header canonicalizes keys (x-api-key → X-Api-Key).
h := make(http.Header)
c.setAuthHeader(h)
if got := h.Get("x-api-key"); got != "sk-ant-test123" {
t.Errorf("x-api-key header not set correctly: %q", got)
}
if h.Get("anthropic-version") == "" {
t.Error("anthropic-version header must be set")
}
// Must NOT use Authorization: Bearer (that's OpenAI format).
if h.Get("Authorization") != "" {
t.Error("Claude must use x-api-key, not Authorization header")
}
}
func TestClaudeClient_BuildUrl(t *testing.T) {
c := newTestClaudeClient()
url := c.buildUrl()
if url != DefaultClaudeBaseURL+"/messages" {
t.Errorf("URL should be /messages endpoint, got: %s", url)
}
}
// ── helpers ────────────────────────────────────────────────────────────────────
func newTestClaudeClient() *ClaudeClient {
return NewClaudeClientWithOptions().(*ClaudeClient)
}
+96 -109
View File
@@ -60,14 +60,14 @@ type Client struct {
UseFullURL bool // Whether to use full URL (without appending /chat/completions)
MaxTokens int // Maximum tokens for AI response
httpClient *http.Client
logger Logger // Logger (replaceable)
config *Config // Config object (stores all configurations)
HTTPClient *http.Client // Exported for sub-packages
Log Logger // Exported for sub-packages
Cfg *Config // Exported for sub-packages
// hooks are used to implement dynamic dispatch (polymorphism)
// When DeepSeekClient embeds Client, hooks point to DeepSeekClient
// This way methods called in call() are automatically dispatched to the overridden version in subclass
hooks clientHooks
// Hooks are used to implement dynamic dispatch (polymorphism)
// When provider.DeepSeekClient embeds Client, Hooks point to DeepSeekClient
// This way methods called in Call() are automatically dispatched to the overridden version
Hooks ClientHooks
}
// New creates default client (backward compatible)
@@ -80,21 +80,22 @@ func New() AIClient {
// NewClient creates client (supports options pattern)
//
// Usage examples:
// // Basic usage (backward compatible)
// client := mcp.NewClient()
//
// // Custom logger
// client := mcp.NewClient(mcp.WithLogger(customLogger))
// // Basic usage (backward compatible)
// client := mcp.NewClient()
//
// // Custom timeout
// client := mcp.NewClient(mcp.WithTimeout(60*time.Second))
// // Custom logger
// client := mcp.NewClient(mcp.WithLogger(customLogger))
//
// // Combine multiple options
// client := mcp.NewClient(
// mcp.WithDeepSeekConfig("sk-xxx"),
// mcp.WithLogger(customLogger),
// mcp.WithTimeout(60*time.Second),
// )
// // Custom timeout
// client := mcp.NewClient(mcp.WithTimeout(60*time.Second))
//
// // Combine multiple options
// client := mcp.NewClient(
// mcp.WithDeepSeekConfig("sk-xxx"),
// mcp.WithLogger(customLogger),
// mcp.WithTimeout(60*time.Second),
// )
func NewClient(opts ...ClientOption) AIClient {
// 1. Create default config
cfg := DefaultConfig()
@@ -112,9 +113,9 @@ func NewClient(opts ...ClientOption) AIClient {
Model: cfg.Model,
MaxTokens: cfg.MaxTokens,
UseFullURL: cfg.UseFullURL,
httpClient: cfg.HTTPClient,
logger: cfg.Logger,
config: cfg,
HTTPClient: cfg.HTTPClient,
Log: cfg.Logger,
Cfg: cfg,
}
// 4. Set default Provider (if not set)
@@ -125,7 +126,7 @@ func NewClient(opts ...ClientOption) AIClient {
}
// 5. Set hooks to point to self
client.hooks = client
client.Hooks = client
return client
}
@@ -148,7 +149,7 @@ func (client *Client) SetAPIKey(apiKey, apiURL, customModel string) {
}
func (client *Client) SetTimeout(timeout time.Duration) {
client.httpClient.Timeout = timeout
client.HTTPClient.Timeout = timeout
}
// CallWithMessages template method - fixed retry flow (cannot be overridden)
@@ -159,32 +160,32 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string,
// Fixed retry flow
var lastErr error
maxRetries := client.config.MaxRetries
maxRetries := client.Cfg.MaxRetries
for attempt := 1; attempt <= maxRetries; attempt++ {
if attempt > 1 {
client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
client.Log.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
}
// Call the fixed single-call flow
result, err := client.hooks.call(systemPrompt, userPrompt)
result, err := client.Hooks.Call(systemPrompt, userPrompt)
if err == nil {
if attempt > 1 {
client.logger.Infof("✓ AI API retry succeeded")
client.Log.Infof("✓ AI API retry succeeded")
}
return result, nil
}
lastErr = err
// Check if error is retryable via hooks (supports custom retry strategy in subclass)
if !client.hooks.isRetryableError(err) {
// Check if error is retryable via hooks (supports custom retry strategy)
if !client.Hooks.IsRetryableError(err) {
return "", err
}
// Wait before retry
if attempt < maxRetries {
waitTime := client.config.RetryWaitBase * time.Duration(attempt)
client.logger.Infof("⏳ Waiting %v before retry...", waitTime)
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
time.Sleep(waitTime)
}
}
@@ -192,11 +193,11 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string,
return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
}
func (client *Client) setAuthHeader(reqHeader http.Header) {
func (client *Client) SetAuthHeader(reqHeader http.Header) {
reqHeader.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey))
}
func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
func (client *Client) BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
// Build messages array
messages := []map[string]string{}
@@ -217,7 +218,7 @@ func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[s
requestBody := map[string]interface{}{
"model": client.Model,
"messages": messages,
"temperature": client.config.Temperature, // Use configured temperature
"temperature": client.Cfg.Temperature, // Use configured temperature
}
// OpenAI newer models use max_completion_tokens instead of max_tokens
if client.Provider == ProviderOpenAI {
@@ -228,8 +229,8 @@ func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[s
return requestBody
}
// can be used to marshal the request body and can be overridden
func (client *Client) marshalRequestBody(requestBody map[string]any) ([]byte, error) {
// MarshalRequestBody can be used to marshal the request body and can be overridden
func (client *Client) MarshalRequestBody(requestBody map[string]any) ([]byte, error) {
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to serialize request: %w", err)
@@ -237,17 +238,17 @@ func (client *Client) marshalRequestBody(requestBody map[string]any) ([]byte, er
return jsonData, nil
}
func (client *Client) parseMCPResponse(body []byte) (string, error) {
r, err := client.parseMCPResponseFull(body)
func (client *Client) ParseMCPResponse(body []byte) (string, error) {
r, err := client.ParseMCPResponseFull(body)
if err != nil {
return "", err
}
return r.Content, nil
}
// parseMCPResponseFull parses the OpenAI-format response body and returns both
// ParseMCPResponseFull parses the OpenAI-format response body and returns both
// the text content and any tool calls.
func (client *Client) parseMCPResponseFull(body []byte) (*LLMResponse, error) {
func (client *Client) ParseMCPResponseFull(body []byte) (*LLMResponse, error) {
var result struct {
Choices []struct {
Message struct {
@@ -288,14 +289,14 @@ func (client *Client) parseMCPResponseFull(body []byte) (*LLMResponse, error) {
}, nil
}
func (client *Client) buildUrl() string {
func (client *Client) BuildUrl() string {
if client.UseFullURL {
return client.BaseURL
}
return fmt.Sprintf("%s/chat/completions", client.BaseURL)
}
func (client *Client) buildRequest(url string, jsonData []byte) (*http.Request, error) {
func (client *Client) BuildRequest(url string, jsonData []byte) (*http.Request, error) {
// Create HTTP request
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
@@ -304,42 +305,42 @@ func (client *Client) buildRequest(url string, jsonData []byte) (*http.Request,
req.Header.Set("Content-Type", "application/json")
// Set auth header via hooks (supports overriding in subclass)
client.hooks.setAuthHeader(req.Header)
// Set auth header via hooks (supports overriding)
client.Hooks.SetAuthHeader(req.Header)
return req, nil
}
// call single AI API call (fixed flow, cannot be overridden)
func (client *Client) call(systemPrompt, userPrompt string) (string, error) {
// Call single AI API call (fixed flow, cannot be overridden)
func (client *Client) Call(systemPrompt, userPrompt string) (string, error) {
// Print current AI configuration
client.logger.Infof("📡 [%s] Request AI Server: BaseURL: %s", client.String(), client.BaseURL)
client.logger.Debugf("[%s] UseFullURL: %v", client.String(), client.UseFullURL)
client.Log.Infof("📡 [%s] Request AI Server: BaseURL: %s", client.String(), client.BaseURL)
client.Log.Debugf("[%s] UseFullURL: %v", client.String(), client.UseFullURL)
if len(client.APIKey) > 8 {
client.logger.Debugf("[%s] API Key: %s...%s", client.String(), client.APIKey[:4], client.APIKey[len(client.APIKey)-4:])
client.Log.Debugf("[%s] API Key: %s...%s", client.String(), client.APIKey[:4], client.APIKey[len(client.APIKey)-4:])
}
// Step 1: Build request body (via hooks for dynamic dispatch)
requestBody := client.hooks.buildMCPRequestBody(systemPrompt, userPrompt)
requestBody := client.Hooks.BuildMCPRequestBody(systemPrompt, userPrompt)
// Step 2: Serialize request body (via hooks for dynamic dispatch)
jsonData, err := client.hooks.marshalRequestBody(requestBody)
jsonData, err := client.Hooks.MarshalRequestBody(requestBody)
if err != nil {
return "", err
}
// Step 3: Build URL (via hooks for dynamic dispatch)
url := client.hooks.buildUrl()
client.logger.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
url := client.Hooks.BuildUrl()
client.Log.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
// Step 4: Create HTTP request (fixed logic)
req, err := client.hooks.buildRequest(url, jsonData)
req, err := client.Hooks.BuildRequest(url, jsonData)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
// Step 5: Send HTTP request (fixed logic)
resp, err := client.httpClient.Do(req)
resp, err := client.HTTPClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to send request: %w", err)
}
@@ -357,7 +358,7 @@ func (client *Client) call(systemPrompt, userPrompt string) (string, error) {
}
// Step 8: Parse response (via hooks for dynamic dispatch)
result, err := client.hooks.parseMCPResponse(body)
result, err := client.Hooks.ParseMCPResponse(body)
if err != nil {
return "", fmt.Errorf("fail to parse AI server response: %w", err)
}
@@ -370,11 +371,11 @@ func (client *Client) String() string {
client.Provider, client.Model)
}
// isRetryableError determines if error is retryable (network errors, timeouts, etc.)
func (client *Client) isRetryableError(err error) bool {
// IsRetryableError determines if error is retryable (network errors, timeouts, etc.)
func (client *Client) IsRetryableError(err error) bool {
errStr := err.Error()
// Network errors, timeouts, EOF, etc. can be retried
for _, retryable := range client.config.RetryableErrors {
for _, retryable := range client.Cfg.RetryableErrors {
if strings.Contains(errStr, retryable) {
return true
}
@@ -387,20 +388,6 @@ func (client *Client) isRetryableError(err error) bool {
// ============================================================
// CallWithRequest calls AI API using Request object (supports advanced features)
//
// This method supports:
// - Multi-turn conversation history
// - Fine-grained parameter control (temperature, top_p, penalties, etc.)
// - Function Calling / Tools
// - Streaming response (future support)
//
// Usage example:
// request := NewRequestBuilder().
// WithSystemPrompt("You are helpful").
// WithUserPrompt("Hello").
// WithTemperature(0.8).
// Build()
// result, err := client.CallWithRequest(request)
func (client *Client) CallWithRequest(req *Request) (string, error) {
if client.APIKey == "" {
return "", fmt.Errorf("AI API key not set, please call SetAPIKey first")
@@ -413,32 +400,32 @@ func (client *Client) CallWithRequest(req *Request) (string, error) {
// Fixed retry flow
var lastErr error
maxRetries := client.config.MaxRetries
maxRetries := client.Cfg.MaxRetries
for attempt := 1; attempt <= maxRetries; attempt++ {
if attempt > 1 {
client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
client.Log.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
}
// Call single request
result, err := client.callWithRequest(req)
if err == nil {
if attempt > 1 {
client.logger.Infof("✓ AI API retry succeeded")
client.Log.Infof("✓ AI API retry succeeded")
}
return result, nil
}
lastErr = err
// Check if error is retryable
if !client.hooks.isRetryableError(err) {
if !client.Hooks.IsRetryableError(err) {
return "", err
}
// Wait before retry
if attempt < maxRetries {
waitTime := client.config.RetryWaitBase * time.Duration(attempt)
client.logger.Infof("⏳ Waiting %v before retry...", waitTime)
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
time.Sleep(waitTime)
}
}
@@ -456,21 +443,21 @@ func (client *Client) CallWithRequestFull(req *Request) (*LLMResponse, error) {
}
var lastErr error
maxRetries := client.config.MaxRetries
maxRetries := client.Cfg.MaxRetries
for attempt := 1; attempt <= maxRetries; attempt++ {
if attempt > 1 {
client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
client.Log.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
}
result, err := client.callWithRequestFull(req)
if err == nil {
return result, nil
}
lastErr = err
if !client.hooks.isRetryableError(err) {
if !client.Hooks.IsRetryableError(err) {
return nil, err
}
if attempt < maxRetries {
waitTime := client.config.RetryWaitBase * time.Duration(attempt)
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
time.Sleep(waitTime)
}
}
@@ -479,21 +466,21 @@ func (client *Client) CallWithRequestFull(req *Request) (*LLMResponse, error) {
// callWithRequestFull single call that returns LLMResponse (content + tool calls).
func (client *Client) callWithRequestFull(req *Request) (*LLMResponse, error) {
client.logger.Infof("📡 [%s] Request AI Server (full): BaseURL: %s", client.String(), client.BaseURL)
client.Log.Infof("📡 [%s] Request AI Server (full): BaseURL: %s", client.String(), client.BaseURL)
requestBody := client.hooks.buildRequestBodyFromRequest(req)
jsonData, err := client.hooks.marshalRequestBody(requestBody)
requestBody := client.Hooks.BuildRequestBodyFromRequest(req)
jsonData, err := client.Hooks.MarshalRequestBody(requestBody)
if err != nil {
return nil, err
}
url := client.hooks.buildUrl()
httpReq, err := client.hooks.buildRequest(url, jsonData)
url := client.Hooks.BuildUrl()
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := client.httpClient.Do(httpReq)
resp, err := client.HTTPClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
@@ -507,31 +494,31 @@ func (client *Client) callWithRequestFull(req *Request) (*LLMResponse, error) {
return nil, fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body))
}
return client.hooks.parseMCPResponseFull(body)
return client.Hooks.ParseMCPResponseFull(body)
}
// callWithRequest single AI API call (using Request object)
func (client *Client) callWithRequest(req *Request) (string, error) {
// Print current AI configuration
client.logger.Infof("📡 [%s] Request AI Server with Builder: BaseURL: %s", client.String(), client.BaseURL)
client.logger.Debugf("[%s] Messages count: %d", client.String(), len(req.Messages))
client.Log.Infof("📡 [%s] Request AI Server with Builder: BaseURL: %s", client.String(), client.BaseURL)
client.Log.Debugf("[%s] Messages count: %d", client.String(), len(req.Messages))
requestBody := client.hooks.buildRequestBodyFromRequest(req)
requestBody := client.Hooks.BuildRequestBodyFromRequest(req)
jsonData, err := client.hooks.marshalRequestBody(requestBody)
jsonData, err := client.Hooks.MarshalRequestBody(requestBody)
if err != nil {
return "", err
}
url := client.hooks.buildUrl()
client.logger.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
url := client.Hooks.BuildUrl()
client.Log.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
httpReq, err := client.hooks.buildRequest(url, jsonData)
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
resp, err := client.httpClient.Do(httpReq)
resp, err := client.HTTPClient.Do(httpReq)
if err != nil {
return "", fmt.Errorf("failed to send request: %w", err)
}
@@ -546,7 +533,7 @@ func (client *Client) callWithRequest(req *Request) (string, error) {
return "", fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body))
}
result, err := client.hooks.parseMCPResponse(body)
result, err := client.Hooks.ParseMCPResponse(body)
if err != nil {
return "", fmt.Errorf("fail to parse AI server response: %w", err)
}
@@ -554,8 +541,8 @@ func (client *Client) callWithRequest(req *Request) (string, error) {
return result, nil
}
// buildRequestBodyFromRequest builds request body from Request object
func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any {
// BuildRequestBodyFromRequest builds request body from Request object
func (client *Client) BuildRequestBodyFromRequest(req *Request) map[string]any {
// Convert Message to API format — must use map[string]any to support
// tool-call messages (tool_calls, tool_call_id fields).
messages := make([]map[string]any, 0, len(req.Messages))
@@ -586,7 +573,7 @@ func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any {
requestBody["temperature"] = *req.Temperature
} else {
// If not set in Request, use Client's configuration
requestBody["temperature"] = client.config.Temperature
requestBody["temperature"] = client.Cfg.Temperature
}
// OpenAI newer models use max_completion_tokens instead of max_tokens
@@ -647,19 +634,19 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string))
}
req.Stream = true
requestBody := client.hooks.buildRequestBodyFromRequest(req)
jsonData, err := client.hooks.marshalRequestBody(requestBody)
requestBody := client.Hooks.BuildRequestBodyFromRequest(req)
jsonData, err := client.Hooks.MarshalRequestBody(requestBody)
if err != nil {
return "", err
}
url := client.hooks.buildUrl()
httpReq, err := client.hooks.buildRequest(url, jsonData)
url := client.Hooks.BuildUrl()
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
if err != nil {
return "", err
}
// Idle-timeout watchdog: cancel the request if no SSE line arrives for 30 seconds.
// Idle-timeout watchdog: cancel the request if no SSE line arrives for 60 seconds.
// This breaks the scanner out of an indefinitely blocking Read on a hung connection.
const idleTimeout = 60 * time.Second
ctx, cancel := context.WithCancel(context.Background())
@@ -689,7 +676,7 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string))
}()
httpReq = httpReq.WithContext(ctx)
resp, err := client.httpClient.Do(httpReq)
resp, err := client.HTTPClient.Do(httpReq)
if err != nil {
return "", fmt.Errorf("streaming request failed: %w", err)
}
+17 -17
View File
@@ -27,16 +27,16 @@ func TestNewClient_Default(t *testing.T) {
t.Error("MaxTokens should be positive")
}
if c.logger == nil {
t.Error("logger should not be nil")
if c.Log == nil {
t.Error("Log should not be nil")
}
if c.httpClient == nil {
t.Error("httpClient should not be nil")
if c.HTTPClient == nil {
t.Error("HTTPClient should not be nil")
}
if c.hooks == nil {
t.Error("hooks should not be nil")
if c.Hooks == nil {
t.Error("Hooks should not be nil")
}
}
@@ -54,12 +54,12 @@ func TestNewClient_WithOptions(t *testing.T) {
c := client.(*Client)
if c.logger != mockLogger {
t.Error("logger should be set from option")
if c.Log != mockLogger {
t.Error("Log should be set from option")
}
if c.httpClient != mockHTTP {
t.Error("httpClient should be set from option")
if c.HTTPClient != mockHTTP {
t.Error("HTTPClient should be set from option")
}
if c.MaxTokens != 4000 {
@@ -174,7 +174,7 @@ func TestClient_Retry_Success(t *testing.T) {
WithMaxRetries(3),
)
// Since our client uses hooks.call, need special handling
// Since our client uses Hooks.Call, need special handling
// Here we test that CallWithMessages will invoke retry logic
c := client.(*Client)
@@ -242,7 +242,7 @@ func TestClient_BuildMCPRequestBody(t *testing.T) {
client := NewClient()
c := client.(*Client)
body := c.buildMCPRequestBody("system prompt", "user prompt")
body := c.BuildMCPRequestBody("system prompt", "user prompt")
if body == nil {
t.Fatal("body should not be nil")
@@ -300,7 +300,7 @@ func TestClient_BuildUrl(t *testing.T) {
)
c := client.(*Client)
url := c.buildUrl()
url := c.BuildUrl()
if url != tt.expected {
t.Errorf("expected '%s', got '%s'", tt.expected, url)
}
@@ -313,7 +313,7 @@ func TestClient_SetAuthHeader(t *testing.T) {
c := client.(*Client)
headers := make(http.Header)
c.setAuthHeader(headers)
c.SetAuthHeader(headers)
authHeader := headers.Get("Authorization")
if authHeader != "Bearer test-api-key" {
@@ -359,7 +359,7 @@ func TestClient_IsRetryableError(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := c.isRetryableError(tt.err)
result := c.IsRetryableError(tt.err)
if result != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, result)
}
@@ -378,8 +378,8 @@ func TestClient_SetTimeout(t *testing.T) {
client.SetTimeout(newTimeout)
c := client.(*Client)
if c.httpClient.Timeout != newTimeout {
t.Errorf("expected timeout %v, got %v", newTimeout, c.httpClient.Timeout)
if c.HTTPClient.Timeout != newTimeout {
t.Errorf("expected timeout %v, got %v", newTimeout, c.HTTPClient.Timeout)
}
}
+10 -10
View File
@@ -84,7 +84,7 @@ func TestConfig_Temperature_IsUsed(t *testing.T) {
c := client.(*Client)
// Build request body
requestBody := c.buildMCPRequestBody("system", "user")
requestBody := c.BuildMCPRequestBody("system", "user")
// Verify temperature field
temp, ok := requestBody["temperature"].(float64)
@@ -201,7 +201,7 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) {
c := client.(*Client)
// Modify config's RetryableErrors (no WithRetryableErrors option yet)
c.config.RetryableErrors = customRetryableErrors
c.Cfg.RetryableErrors = customRetryableErrors
tests := []struct {
name string
@@ -227,7 +227,7 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := c.isRetryableError(tt.err)
result := c.IsRetryableError(tt.err)
if result != tt.retryable {
t.Errorf("expected isRetryableError(%v) = %v, got %v", tt.err, tt.retryable, result)
}
@@ -244,19 +244,19 @@ func TestConfig_DefaultValues(t *testing.T) {
c := client.(*Client)
// Verify default values
if c.config.MaxRetries != 3 {
t.Errorf("default MaxRetries should be 3, got %d", c.config.MaxRetries)
if c.Cfg.MaxRetries != 3 {
t.Errorf("default MaxRetries should be 3, got %d", c.Cfg.MaxRetries)
}
if c.config.Temperature != 0.5 {
t.Errorf("default Temperature should be 0.5, got %f", c.config.Temperature)
if c.Cfg.Temperature != 0.5 {
t.Errorf("default Temperature should be 0.5, got %f", c.Cfg.Temperature)
}
if c.config.RetryWaitBase != 2*time.Second {
t.Errorf("default RetryWaitBase should be 2s, got %v", c.config.RetryWaitBase)
if c.Cfg.RetryWaitBase != 2*time.Second {
t.Errorf("default RetryWaitBase should be 2s, got %v", c.Cfg.RetryWaitBase)
}
if len(c.config.RetryableErrors) == 0 {
if len(c.Cfg.RetryableErrors) == 0 {
t.Error("default RetryableErrors should not be empty")
}
}
-83
View File
@@ -1,83 +0,0 @@
package mcp
import (
"net/http"
)
const (
ProviderDeepSeek = "deepseek"
DefaultDeepSeekBaseURL = "https://api.deepseek.com"
DefaultDeepSeekModel = "deepseek-chat"
)
type DeepSeekClient struct {
*Client
}
// NewDeepSeekClient creates DeepSeek client (backward compatible)
//
// Deprecated: Recommend using NewDeepSeekClientWithOptions for better flexibility
func NewDeepSeekClient() AIClient {
return NewDeepSeekClientWithOptions()
}
// NewDeepSeekClientWithOptions creates DeepSeek client (supports options pattern)
//
// Usage examples:
// // Basic usage
// client := mcp.NewDeepSeekClientWithOptions()
//
// // Custom configuration
// client := mcp.NewDeepSeekClientWithOptions(
// mcp.WithAPIKey("sk-xxx"),
// mcp.WithLogger(customLogger),
// mcp.WithTimeout(60*time.Second),
// )
func NewDeepSeekClientWithOptions(opts ...ClientOption) AIClient {
// 1. Create DeepSeek preset options
deepseekOpts := []ClientOption{
WithProvider(ProviderDeepSeek),
WithModel(DefaultDeepSeekModel),
WithBaseURL(DefaultDeepSeekBaseURL),
}
// 2. Merge user options (user options have higher priority)
allOpts := append(deepseekOpts, opts...)
// 3. Create base client
baseClient := NewClient(allOpts...).(*Client)
// 4. Create DeepSeek client
dsClient := &DeepSeekClient{
Client: baseClient,
}
// 5. Set hooks to point to DeepSeekClient (implement dynamic dispatch)
baseClient.hooks = dsClient
return dsClient
}
func (dsClient *DeepSeekClient) SetAPIKey(apiKey string, customURL string, customModel string) {
dsClient.APIKey = apiKey
if len(apiKey) > 8 {
dsClient.logger.Infof("🔧 [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
dsClient.BaseURL = customURL
dsClient.logger.Infof("🔧 [MCP] DeepSeek using custom BaseURL: %s", customURL)
} else {
dsClient.logger.Infof("🔧 [MCP] DeepSeek using default BaseURL: %s", dsClient.BaseURL)
}
if customModel != "" {
dsClient.Model = customModel
dsClient.logger.Infof("🔧 [MCP] DeepSeek using custom Model: %s", customModel)
} else {
dsClient.logger.Infof("🔧 [MCP] DeepSeek using default Model: %s", dsClient.Model)
}
}
func (dsClient *DeepSeekClient) setAuthHeader(reqHeaders http.Header) {
dsClient.Client.setAuthHeader(reqHeaders)
}
-272
View File
@@ -1,272 +0,0 @@
package mcp
import (
"testing"
"time"
)
// ============================================================
// Test DeepSeekClient Creation and Configuration
// ============================================================
func TestNewDeepSeekClient_Default(t *testing.T) {
client := NewDeepSeekClient()
if client == nil {
t.Fatal("client should not be nil")
}
// Type assertion check
dsClient, ok := client.(*DeepSeekClient)
if !ok {
t.Fatal("client should be *DeepSeekClient")
}
// Verify default values
if dsClient.Provider != ProviderDeepSeek {
t.Errorf("Provider should be '%s', got '%s'", ProviderDeepSeek, dsClient.Provider)
}
if dsClient.BaseURL != DefaultDeepSeekBaseURL {
t.Errorf("BaseURL should be '%s', got '%s'", DefaultDeepSeekBaseURL, dsClient.BaseURL)
}
if dsClient.Model != DefaultDeepSeekModel {
t.Errorf("Model should be '%s', got '%s'", DefaultDeepSeekModel, dsClient.Model)
}
if dsClient.logger == nil {
t.Error("logger should not be nil")
}
if dsClient.httpClient == nil {
t.Error("httpClient should not be nil")
}
}
func TestNewDeepSeekClientWithOptions(t *testing.T) {
mockLogger := NewMockLogger()
customModel := "deepseek-v2"
customAPIKey := "sk-custom-key"
client := NewDeepSeekClientWithOptions(
WithLogger(mockLogger),
WithModel(customModel),
WithAPIKey(customAPIKey),
WithMaxTokens(4000),
)
dsClient := client.(*DeepSeekClient)
// Verify custom options are applied
if dsClient.logger != mockLogger {
t.Error("logger should be set from option")
}
if dsClient.Model != customModel {
t.Error("Model should be set from option")
}
if dsClient.APIKey != customAPIKey {
t.Error("APIKey should be set from option")
}
if dsClient.MaxTokens != 4000 {
t.Error("MaxTokens should be 4000")
}
// Verify DeepSeek default values are retained
if dsClient.Provider != ProviderDeepSeek {
t.Errorf("Provider should still be '%s'", ProviderDeepSeek)
}
if dsClient.BaseURL != DefaultDeepSeekBaseURL {
t.Errorf("BaseURL should still be '%s'", DefaultDeepSeekBaseURL)
}
}
// ============================================================
// Test SetAPIKey
// ============================================================
func TestDeepSeekClient_SetAPIKey(t *testing.T) {
mockLogger := NewMockLogger()
client := NewDeepSeekClientWithOptions(
WithLogger(mockLogger),
)
dsClient := client.(*DeepSeekClient)
// Test setting API Key (default URL and Model)
dsClient.SetAPIKey("sk-test-key-12345678", "", "")
if dsClient.APIKey != "sk-test-key-12345678" {
t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", dsClient.APIKey)
}
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
if len(logs) == 0 {
t.Error("should have logged API key setting")
}
// Verify BaseURL and Model remain default
if dsClient.BaseURL != DefaultDeepSeekBaseURL {
t.Error("BaseURL should remain default")
}
if dsClient.Model != DefaultDeepSeekModel {
t.Error("Model should remain default")
}
}
func TestDeepSeekClient_SetAPIKey_WithCustomURL(t *testing.T) {
mockLogger := NewMockLogger()
client := NewDeepSeekClientWithOptions(
WithLogger(mockLogger),
)
dsClient := client.(*DeepSeekClient)
customURL := "https://custom.api.com/v1"
dsClient.SetAPIKey("sk-test-key-12345678", customURL, "")
if dsClient.BaseURL != customURL {
t.Errorf("BaseURL should be '%s', got '%s'", customURL, dsClient.BaseURL)
}
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
hasCustomURLLog := false
for _, log := range logs {
if log.Format == "🔧 [MCP] DeepSeek using custom BaseURL: %s" {
hasCustomURLLog = true
break
}
}
if !hasCustomURLLog {
t.Error("should have logged custom BaseURL")
}
}
func TestDeepSeekClient_SetAPIKey_WithCustomModel(t *testing.T) {
mockLogger := NewMockLogger()
client := NewDeepSeekClientWithOptions(
WithLogger(mockLogger),
)
dsClient := client.(*DeepSeekClient)
customModel := "deepseek-v3"
dsClient.SetAPIKey("sk-test-key-12345678", "", customModel)
if dsClient.Model != customModel {
t.Errorf("Model should be '%s', got '%s'", customModel, dsClient.Model)
}
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
hasCustomModelLog := false
for _, log := range logs {
if log.Format == "🔧 [MCP] DeepSeek using custom Model: %s" {
hasCustomModelLog = true
break
}
}
if !hasCustomModelLog {
t.Error("should have logged custom Model")
}
}
// ============================================================
// Test Integration Features
// ============================================================
func TestDeepSeekClient_CallWithMessages_Success(t *testing.T) {
mockHTTP := NewMockHTTPClient()
mockHTTP.SetSuccessResponse("DeepSeek AI response")
mockLogger := NewMockLogger()
client := NewDeepSeekClientWithOptions(
WithHTTPClient(mockHTTP.ToHTTPClient()),
WithLogger(mockLogger),
WithAPIKey("sk-test-key"),
)
result, err := client.CallWithMessages("system prompt", "user prompt")
if err != nil {
t.Fatalf("should not error: %v", err)
}
if result != "DeepSeek AI response" {
t.Errorf("expected 'DeepSeek AI response', got '%s'", result)
}
// Verify request
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Fatalf("expected 1 request, got %d", len(requests))
}
req := requests[0]
// Verify URL
expectedURL := DefaultDeepSeekBaseURL + "/chat/completions"
if req.URL.String() != expectedURL {
t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String())
}
// Verify Authorization header
authHeader := req.Header.Get("Authorization")
if authHeader != "Bearer sk-test-key" {
t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader)
}
// Verify Content-Type
if req.Header.Get("Content-Type") != "application/json" {
t.Error("Content-Type should be application/json")
}
}
func TestDeepSeekClient_Timeout(t *testing.T) {
client := NewDeepSeekClientWithOptions(
WithTimeout(30 * time.Second),
)
dsClient := client.(*DeepSeekClient)
if dsClient.httpClient.Timeout != 30*time.Second {
t.Errorf("expected timeout 30s, got %v", dsClient.httpClient.Timeout)
}
// Test SetTimeout
client.SetTimeout(60 * time.Second)
if dsClient.httpClient.Timeout != 60*time.Second {
t.Errorf("expected timeout 60s after SetTimeout, got %v", dsClient.httpClient.Timeout)
}
}
// ============================================================
// Test hooks Mechanism
// ============================================================
func TestDeepSeekClient_HooksIntegration(t *testing.T) {
client := NewDeepSeekClientWithOptions()
dsClient := client.(*DeepSeekClient)
// Verify hooks point to dsClient itself (implements polymorphism)
if dsClient.hooks != dsClient {
t.Error("hooks should point to dsClient for polymorphism")
}
// Verify buildUrl uses DeepSeek configuration
url := dsClient.buildUrl()
expectedURL := DefaultDeepSeekBaseURL + "/chat/completions"
if url != expectedURL {
t.Errorf("expected URL '%s', got '%s'", expectedURL, url)
}
}
+9 -15
View File
@@ -6,6 +6,7 @@ import (
"time"
"nofx/mcp"
"nofx/mcp/provider"
)
// ============================================================
@@ -24,7 +25,7 @@ func Example_backward_compatible() {
func Example_deepseek_backward_compatible() {
// DeepSeek old code continues to work
client := mcp.NewDeepSeekClient()
client := provider.NewDeepSeekClient()
client.SetAPIKey("sk-xxx", "", "")
result, _ := client.CallWithMessages("system", "user")
@@ -141,12 +142,12 @@ func Example_custom_http_client() {
func Example_deepseek_new_api() {
// Basic usage
client := mcp.NewDeepSeekClientWithOptions(
client := provider.NewDeepSeekClientWithOptions(
mcp.WithAPIKey("sk-xxx"),
)
// Advanced usage
client = mcp.NewDeepSeekClientWithOptions(
client = provider.NewDeepSeekClientWithOptions(
mcp.WithAPIKey("sk-xxx"),
mcp.WithLogger(&CustomLogger{}),
mcp.WithTimeout(90*time.Second),
@@ -163,12 +164,12 @@ func Example_deepseek_new_api() {
func Example_qwen_new_api() {
// Basic usage
client := mcp.NewQwenClientWithOptions(
client := provider.NewQwenClientWithOptions(
mcp.WithAPIKey("sk-xxx"),
)
// Advanced usage
client = mcp.NewQwenClientWithOptions(
client = provider.NewQwenClientWithOptions(
mcp.WithAPIKey("sk-xxx"),
mcp.WithLogger(&CustomLogger{}),
mcp.WithTimeout(90*time.Second),
@@ -185,7 +186,7 @@ func Example_qwen_new_api() {
func Example_trader_migration() {
// Old code (continues to work)
oldStyleClient := func(apiKey, customURL, customModel string) mcp.AIClient {
client := mcp.NewDeepSeekClient()
client := provider.NewDeepSeekClient()
client.SetAPIKey(apiKey, customURL, customModel)
return client
}
@@ -204,7 +205,7 @@ func Example_trader_migration() {
opts = append(opts, mcp.WithModel(customModel))
}
return mcp.NewDeepSeekClientWithOptions(opts...)
return provider.NewDeepSeekClientWithOptions(opts...)
}
// Both approaches work
@@ -230,13 +231,7 @@ func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
}
func Example_testing_with_mock() {
// Use Mock during testing
// mockHTTP := &MockHTTPClient{
// Response: `{"choices":[{"message":{"content":"test response"}}]}`,
// }
client := mcp.NewClient(
// mcp.WithHTTPClient(mockHTTP), // Use mockHTTP in actual tests
mcp.WithLogger(mcp.NewNoopLogger()), // Disable logging
)
@@ -258,7 +253,6 @@ func Example_environment_specific() {
// Production environment: structured logging + timeout protection
prodClient := mcp.NewClient(
mcp.WithDeepSeekConfig("sk-xxx"),
// mcp.WithLogger(&ZapLogger{}), // Production-grade logging
mcp.WithTimeout(30*time.Second),
mcp.WithMaxRetries(3),
)
@@ -273,7 +267,7 @@ func Example_environment_specific() {
func Example_real_world_usage() {
// Create client with complete configuration
client := mcp.NewDeepSeekClientWithOptions(
client := provider.NewDeepSeekClientWithOptions(
mcp.WithAPIKey("sk-xxxxxxxxxx"),
mcp.WithTimeout(60*time.Second),
mcp.WithMaxRetries(5),
-71
View File
@@ -1,71 +0,0 @@
package mcp
import (
"net/http"
)
const (
ProviderGemini = "gemini"
DefaultGeminiBaseURL = "https://generativelanguage.googleapis.com/v1beta/openai"
DefaultGeminiModel = "gemini-3-pro-preview"
)
type GeminiClient struct {
*Client
}
// NewGeminiClient creates Gemini client (backward compatible)
func NewGeminiClient() AIClient {
return NewGeminiClientWithOptions()
}
// NewGeminiClientWithOptions creates Gemini client (supports options pattern)
func NewGeminiClientWithOptions(opts ...ClientOption) AIClient {
// 1. Create Gemini preset options
geminiOpts := []ClientOption{
WithProvider(ProviderGemini),
WithModel(DefaultGeminiModel),
WithBaseURL(DefaultGeminiBaseURL),
}
// 2. Merge user options (user options have higher priority)
allOpts := append(geminiOpts, opts...)
// 3. Create base client
baseClient := NewClient(allOpts...).(*Client)
// 4. Create Gemini client
geminiClient := &GeminiClient{
Client: baseClient,
}
// 5. Set hooks to point to GeminiClient (implement dynamic dispatch)
baseClient.hooks = geminiClient
return geminiClient
}
func (c *GeminiClient) SetAPIKey(apiKey string, customURL string, customModel string) {
c.APIKey = apiKey
if len(apiKey) > 8 {
c.logger.Infof("🔧 [MCP] Gemini API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
c.BaseURL = customURL
c.logger.Infof("🔧 [MCP] Gemini using custom BaseURL: %s", customURL)
} else {
c.logger.Infof("🔧 [MCP] Gemini using default BaseURL: %s", c.BaseURL)
}
if customModel != "" {
c.Model = customModel
c.logger.Infof("🔧 [MCP] Gemini using custom Model: %s", customModel)
} else {
c.logger.Infof("🔧 [MCP] Gemini using default Model: %s", c.Model)
}
}
// Gemini OpenAI-compatible API uses standard Bearer auth
func (c *GeminiClient) setAuthHeader(reqHeaders http.Header) {
c.Client.setAuthHeader(reqHeaders)
}
-71
View File
@@ -1,71 +0,0 @@
package mcp
import (
"net/http"
)
const (
ProviderGrok = "grok"
DefaultGrokBaseURL = "https://api.x.ai/v1"
DefaultGrokModel = "grok-3-latest"
)
type GrokClient struct {
*Client
}
// NewGrokClient creates Grok client (backward compatible)
func NewGrokClient() AIClient {
return NewGrokClientWithOptions()
}
// NewGrokClientWithOptions creates Grok client (supports options pattern)
func NewGrokClientWithOptions(opts ...ClientOption) AIClient {
// 1. Create Grok preset options
grokOpts := []ClientOption{
WithProvider(ProviderGrok),
WithModel(DefaultGrokModel),
WithBaseURL(DefaultGrokBaseURL),
}
// 2. Merge user options (user options have higher priority)
allOpts := append(grokOpts, opts...)
// 3. Create base client
baseClient := NewClient(allOpts...).(*Client)
// 4. Create Grok client
grokClient := &GrokClient{
Client: baseClient,
}
// 5. Set hooks to point to GrokClient (implement dynamic dispatch)
baseClient.hooks = grokClient
return grokClient
}
func (c *GrokClient) SetAPIKey(apiKey string, customURL string, customModel string) {
c.APIKey = apiKey
if len(apiKey) > 8 {
c.logger.Infof("🔧 [MCP] Grok API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
c.BaseURL = customURL
c.logger.Infof("🔧 [MCP] Grok using custom BaseURL: %s", customURL)
} else {
c.logger.Infof("🔧 [MCP] Grok using default BaseURL: %s", c.BaseURL)
}
if customModel != "" {
c.Model = customModel
c.logger.Infof("🔧 [MCP] Grok using custom Model: %s", customModel)
} else {
c.logger.Infof("🔧 [MCP] Grok using default Model: %s", c.Model)
}
}
// Grok uses standard OpenAI-compatible API with Bearer auth
func (c *GrokClient) setAuthHeader(reqHeaders http.Header) {
c.Client.setAuthHeader(reqHeaders)
}
+37
View File
@@ -0,0 +1,37 @@
package mcp
import "net/http"
// ClientHooks is the dispatch interface used to implement per-provider
// polymorphism without Go's lack of virtual methods.
//
// Each method can be overridden by an embedding struct (e.g. provider.ClaudeClient).
// The base *Client provides OpenAI-compatible defaults; providers with a
// different wire format (Anthropic, Gemini native, etc.) override only what
// differs. All call-path methods in client.go invoke these via c.Hooks so
// that the override is always picked up at runtime.
type ClientHooks interface {
// ── Simple CallWithMessages path ────────────────────────────────────────
Call(systemPrompt, userPrompt string) (string, error)
BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any
// ── Shared request plumbing ─────────────────────────────────────────────
BuildUrl() string
BuildRequest(url string, jsonData []byte) (*http.Request, error)
SetAuthHeader(reqHeaders http.Header)
MarshalRequestBody(requestBody map[string]any) ([]byte, error)
// ── Advanced (Request-object) path ──────────────────────────────────────
// BuildRequestBodyFromRequest converts a *Request into the provider's
// native wire-format map.
BuildRequestBodyFromRequest(req *Request) map[string]any
// ParseMCPResponse extracts the plain-text reply from a non-streaming
// response body.
ParseMCPResponse(body []byte) (string, error)
// ParseMCPResponseFull extracts both text and tool calls.
ParseMCPResponseFull(body []byte) (*LLMResponse, error)
IsRetryableError(err error) bool
}
+6 -39
View File
@@ -1,10 +1,15 @@
package mcp
import (
"net/http"
"time"
)
// ClientEmbedder is implemented by provider types that embed *Client,
// allowing generic extraction of the underlying base client (e.g. for cloning).
type ClientEmbedder interface {
BaseClient() *Client
}
// AIClient public AI client interface (for external use)
type AIClient interface {
SetAPIKey(apiKey string, customURL string, customModel string)
@@ -21,41 +26,3 @@ type AIClient interface {
// (LLMResponse.ToolCalls), but not both.
CallWithRequestFull(req *Request) (*LLMResponse, error)
}
// clientHooks is the internal dispatch interface used to implement per-provider
// polymorphism without Go's lack of virtual methods.
//
// Each method can be overridden by an embedding struct (e.g. ClaudeClient).
// The base *Client provides OpenAI-compatible defaults; providers with a
// different wire format (Anthropic, Gemini native, etc.) override only what
// differs. All call-path methods in client.go invoke these via c.hooks so
// that the override is always picked up at runtime.
type clientHooks interface {
// ── Simple CallWithMessages path ────────────────────────────────────────
call(systemPrompt, userPrompt string) (string, error)
buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any
// ── Shared request plumbing ─────────────────────────────────────────────
buildUrl() string
buildRequest(url string, jsonData []byte) (*http.Request, error)
setAuthHeader(reqHeaders http.Header)
marshalRequestBody(requestBody map[string]any) ([]byte, error)
// ── Advanced (Request-object) path ──────────────────────────────────────
// buildRequestBodyFromRequest converts a *Request into the provider's
// native wire-format map. Providers that use a different protocol (e.g.
// Anthropic uses "input_schema" for tools, "tool_use" content blocks, and
// a top-level "system" field) override this method.
buildRequestBodyFromRequest(req *Request) map[string]any
// parseMCPResponse extracts the plain-text reply from a non-streaming
// response body.
parseMCPResponse(body []byte) (string, error)
// parseMCPResponseFull extracts both text and tool calls. Providers whose
// response envelope differs from the OpenAI choices[] structure (e.g.
// Anthropic content[] with tool_use blocks) override this method.
parseMCPResponseFull(body []byte) (*LLMResponse, error)
isRetryableError(err error) bool
}
-71
View File
@@ -1,71 +0,0 @@
package mcp
import (
"net/http"
)
const (
ProviderKimi = "kimi"
DefaultKimiBaseURL = "https://api.moonshot.ai/v1" // Global endpoint (use api.moonshot.cn for China)
DefaultKimiModel = "moonshot-v1-auto"
)
type KimiClient struct {
*Client
}
// NewKimiClient creates Kimi (Moonshot) client (backward compatible)
func NewKimiClient() AIClient {
return NewKimiClientWithOptions()
}
// NewKimiClientWithOptions creates Kimi client (supports options pattern)
func NewKimiClientWithOptions(opts ...ClientOption) AIClient {
// 1. Create Kimi preset options
kimiOpts := []ClientOption{
WithProvider(ProviderKimi),
WithModel(DefaultKimiModel),
WithBaseURL(DefaultKimiBaseURL),
}
// 2. Merge user options (user options have higher priority)
allOpts := append(kimiOpts, opts...)
// 3. Create base client
baseClient := NewClient(allOpts...).(*Client)
// 4. Create Kimi client
kimiClient := &KimiClient{
Client: baseClient,
}
// 5. Set hooks to point to KimiClient (implement dynamic dispatch)
baseClient.hooks = kimiClient
return kimiClient
}
func (c *KimiClient) SetAPIKey(apiKey string, customURL string, customModel string) {
c.APIKey = apiKey
if len(apiKey) > 8 {
c.logger.Infof("🔧 [MCP] Kimi API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
c.BaseURL = customURL
c.logger.Infof("🔧 [MCP] Kimi using custom BaseURL: %s", customURL)
} else {
c.logger.Infof("🔧 [MCP] Kimi using default BaseURL: %s", c.BaseURL)
}
if customModel != "" {
c.Model = customModel
c.logger.Infof("🔧 [MCP] Kimi using custom Model: %s", customModel)
} else {
c.logger.Infof("🔧 [MCP] Kimi using default Model: %s", c.Model)
}
}
// Kimi uses standard OpenAI-compatible API, so we just use the base client methods
func (c *KimiClient) setAuthHeader(reqHeaders http.Header) {
c.Client.setAuthHeader(reqHeaders)
}
-83
View File
@@ -1,83 +0,0 @@
package mcp
import (
"net/http"
)
const (
ProviderMiniMax = "minimax"
DefaultMiniMaxBaseURL = "https://api.minimax.io/v1"
DefaultMiniMaxModel = "MiniMax-M2.5"
)
type MiniMaxClient struct {
*Client
}
// NewMiniMaxClient creates MiniMax client (backward compatible)
func NewMiniMaxClient() AIClient {
return NewMiniMaxClientWithOptions()
}
// NewMiniMaxClientWithOptions creates MiniMax client (supports options pattern)
//
// Usage examples:
//
// // Basic usage
// client := mcp.NewMiniMaxClientWithOptions()
//
// // Custom configuration
// client := mcp.NewMiniMaxClientWithOptions(
// mcp.WithAPIKey("sk-xxx"),
// mcp.WithLogger(customLogger),
// mcp.WithTimeout(60*time.Second),
// )
func NewMiniMaxClientWithOptions(opts ...ClientOption) AIClient {
// 1. Create MiniMax preset options
minimaxOpts := []ClientOption{
WithProvider(ProviderMiniMax),
WithModel(DefaultMiniMaxModel),
WithBaseURL(DefaultMiniMaxBaseURL),
}
// 2. Merge user options (user options have higher priority)
allOpts := append(minimaxOpts, opts...)
// 3. Create base client
baseClient := NewClient(allOpts...).(*Client)
// 4. Create MiniMax client
minimaxClient := &MiniMaxClient{
Client: baseClient,
}
// 5. Set hooks to point to MiniMaxClient (implement dynamic dispatch)
baseClient.hooks = minimaxClient
return minimaxClient
}
func (c *MiniMaxClient) SetAPIKey(apiKey string, customURL string, customModel string) {
c.APIKey = apiKey
if len(apiKey) > 8 {
c.logger.Infof("🔧 [MCP] MiniMax API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
c.BaseURL = customURL
c.logger.Infof("🔧 [MCP] MiniMax using custom BaseURL: %s", customURL)
} else {
c.logger.Infof("🔧 [MCP] MiniMax using default BaseURL: %s", c.BaseURL)
}
if customModel != "" {
c.Model = customModel
c.logger.Infof("🔧 [MCP] MiniMax using custom Model: %s", customModel)
} else {
c.logger.Infof("🔧 [MCP] MiniMax using default Model: %s", c.Model)
}
}
// MiniMax uses standard OpenAI-compatible API with Bearer auth
func (c *MiniMaxClient) setAuthHeader(reqHeaders http.Header) {
c.Client.setAuthHeader(reqHeaders)
}
-272
View File
@@ -1,272 +0,0 @@
package mcp
import (
"testing"
"time"
)
// ============================================================
// Test MiniMaxClient Creation and Configuration
// ============================================================
func TestNewMiniMaxClient_Default(t *testing.T) {
client := NewMiniMaxClient()
if client == nil {
t.Fatal("client should not be nil")
}
// Type assertion check
mmClient, ok := client.(*MiniMaxClient)
if !ok {
t.Fatal("client should be *MiniMaxClient")
}
// Verify default values
if mmClient.Provider != ProviderMiniMax {
t.Errorf("Provider should be '%s', got '%s'", ProviderMiniMax, mmClient.Provider)
}
if mmClient.BaseURL != DefaultMiniMaxBaseURL {
t.Errorf("BaseURL should be '%s', got '%s'", DefaultMiniMaxBaseURL, mmClient.BaseURL)
}
if mmClient.Model != DefaultMiniMaxModel {
t.Errorf("Model should be '%s', got '%s'", DefaultMiniMaxModel, mmClient.Model)
}
if mmClient.logger == nil {
t.Error("logger should not be nil")
}
if mmClient.httpClient == nil {
t.Error("httpClient should not be nil")
}
}
func TestNewMiniMaxClientWithOptions(t *testing.T) {
mockLogger := NewMockLogger()
customModel := "MiniMax-M2.5-highspeed"
customAPIKey := "sk-custom-key"
client := NewMiniMaxClientWithOptions(
WithLogger(mockLogger),
WithModel(customModel),
WithAPIKey(customAPIKey),
WithMaxTokens(4000),
)
mmClient := client.(*MiniMaxClient)
// Verify custom options are applied
if mmClient.logger != mockLogger {
t.Error("logger should be set from option")
}
if mmClient.Model != customModel {
t.Error("Model should be set from option")
}
if mmClient.APIKey != customAPIKey {
t.Error("APIKey should be set from option")
}
if mmClient.MaxTokens != 4000 {
t.Error("MaxTokens should be 4000")
}
// Verify MiniMax default values are retained
if mmClient.Provider != ProviderMiniMax {
t.Errorf("Provider should still be '%s'", ProviderMiniMax)
}
if mmClient.BaseURL != DefaultMiniMaxBaseURL {
t.Errorf("BaseURL should still be '%s'", DefaultMiniMaxBaseURL)
}
}
// ============================================================
// Test SetAPIKey
// ============================================================
func TestMiniMaxClient_SetAPIKey(t *testing.T) {
mockLogger := NewMockLogger()
client := NewMiniMaxClientWithOptions(
WithLogger(mockLogger),
)
mmClient := client.(*MiniMaxClient)
// Test setting API Key (default URL and Model)
mmClient.SetAPIKey("sk-test-key-12345678", "", "")
if mmClient.APIKey != "sk-test-key-12345678" {
t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", mmClient.APIKey)
}
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
if len(logs) == 0 {
t.Error("should have logged API key setting")
}
// Verify BaseURL and Model remain default
if mmClient.BaseURL != DefaultMiniMaxBaseURL {
t.Error("BaseURL should remain default")
}
if mmClient.Model != DefaultMiniMaxModel {
t.Error("Model should remain default")
}
}
func TestMiniMaxClient_SetAPIKey_WithCustomURL(t *testing.T) {
mockLogger := NewMockLogger()
client := NewMiniMaxClientWithOptions(
WithLogger(mockLogger),
)
mmClient := client.(*MiniMaxClient)
customURL := "https://api.minimaxi.com/v1"
mmClient.SetAPIKey("sk-test-key-12345678", customURL, "")
if mmClient.BaseURL != customURL {
t.Errorf("BaseURL should be '%s', got '%s'", customURL, mmClient.BaseURL)
}
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
hasCustomURLLog := false
for _, log := range logs {
if log.Format == "🔧 [MCP] MiniMax using custom BaseURL: %s" {
hasCustomURLLog = true
break
}
}
if !hasCustomURLLog {
t.Error("should have logged custom BaseURL")
}
}
func TestMiniMaxClient_SetAPIKey_WithCustomModel(t *testing.T) {
mockLogger := NewMockLogger()
client := NewMiniMaxClientWithOptions(
WithLogger(mockLogger),
)
mmClient := client.(*MiniMaxClient)
customModel := "MiniMax-M2.5-highspeed"
mmClient.SetAPIKey("sk-test-key-12345678", "", customModel)
if mmClient.Model != customModel {
t.Errorf("Model should be '%s', got '%s'", customModel, mmClient.Model)
}
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
hasCustomModelLog := false
for _, log := range logs {
if log.Format == "🔧 [MCP] MiniMax using custom Model: %s" {
hasCustomModelLog = true
break
}
}
if !hasCustomModelLog {
t.Error("should have logged custom Model")
}
}
// ============================================================
// Test Integration Features
// ============================================================
func TestMiniMaxClient_CallWithMessages_Success(t *testing.T) {
mockHTTP := NewMockHTTPClient()
mockHTTP.SetSuccessResponse("MiniMax AI response")
mockLogger := NewMockLogger()
client := NewMiniMaxClientWithOptions(
WithHTTPClient(mockHTTP.ToHTTPClient()),
WithLogger(mockLogger),
WithAPIKey("sk-test-key"),
)
result, err := client.CallWithMessages("system prompt", "user prompt")
if err != nil {
t.Fatalf("should not error: %v", err)
}
if result != "MiniMax AI response" {
t.Errorf("expected 'MiniMax AI response', got '%s'", result)
}
// Verify request
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Fatalf("expected 1 request, got %d", len(requests))
}
req := requests[0]
// Verify URL
expectedURL := DefaultMiniMaxBaseURL + "/chat/completions"
if req.URL.String() != expectedURL {
t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String())
}
// Verify Authorization header
authHeader := req.Header.Get("Authorization")
if authHeader != "Bearer sk-test-key" {
t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader)
}
// Verify Content-Type
if req.Header.Get("Content-Type") != "application/json" {
t.Error("Content-Type should be application/json")
}
}
func TestMiniMaxClient_Timeout(t *testing.T) {
client := NewMiniMaxClientWithOptions(
WithTimeout(30 * time.Second),
)
mmClient := client.(*MiniMaxClient)
if mmClient.httpClient.Timeout != 30*time.Second {
t.Errorf("expected timeout 30s, got %v", mmClient.httpClient.Timeout)
}
// Test SetTimeout
client.SetTimeout(60 * time.Second)
if mmClient.httpClient.Timeout != 60*time.Second {
t.Errorf("expected timeout 60s after SetTimeout, got %v", mmClient.httpClient.Timeout)
}
}
// ============================================================
// Test hooks Mechanism
// ============================================================
func TestMiniMaxClient_HooksIntegration(t *testing.T) {
client := NewMiniMaxClientWithOptions()
mmClient := client.(*MiniMaxClient)
// Verify hooks point to mmClient itself (implements polymorphism)
if mmClient.hooks != mmClient {
t.Error("hooks should point to mmClient for polymorphism")
}
// Verify buildUrl uses MiniMax configuration
url := mmClient.buildUrl()
expectedURL := DefaultMiniMaxBaseURL + "/chat/completions"
if url != expectedURL {
t.Errorf("expected URL '%s', got '%s'", expectedURL, url)
}
}
+21 -9
View File
@@ -247,7 +247,7 @@ func NewMockClientHooks() *MockClientHooks {
return &MockClientHooks{}
}
func (m *MockClientHooks) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
func (m *MockClientHooks) BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
m.BuildRequestBodyCalled++
if m.BuildRequestBodyFunc != nil {
return m.BuildRequestBodyFunc(systemPrompt, userPrompt)
@@ -261,7 +261,7 @@ func (m *MockClientHooks) buildMCPRequestBody(systemPrompt, userPrompt string) m
}
}
func (m *MockClientHooks) buildUrl() string {
func (m *MockClientHooks) BuildUrl() string {
m.BuildUrlCalled++
if m.BuildUrlFunc != nil {
return m.BuildUrlFunc()
@@ -269,12 +269,12 @@ func (m *MockClientHooks) buildUrl() string {
return "https://api.test.com/chat/completions"
}
func (m *MockClientHooks) setAuthHeader(headers http.Header) {
func (m *MockClientHooks) SetAuthHeader(headers http.Header) {
m.SetAuthHeaderCalled++
headers.Set("Authorization", "Bearer test-key")
}
func (m *MockClientHooks) marshalRequestBody(body map[string]any) ([]byte, error) {
func (m *MockClientHooks) MarshalRequestBody(body map[string]any) ([]byte, error) {
m.MarshalRequestCalled++
if m.MarshalRequestBodyFunc != nil {
return m.MarshalRequestBodyFunc(body)
@@ -282,7 +282,7 @@ func (m *MockClientHooks) marshalRequestBody(body map[string]any) ([]byte, error
return json.Marshal(body)
}
func (m *MockClientHooks) parseMCPResponse(body []byte) (string, error) {
func (m *MockClientHooks) ParseMCPResponse(body []byte) (string, error) {
m.ParseResponseCalled++
if m.ParseResponseFunc != nil {
return m.ParseResponseFunc(body)
@@ -290,7 +290,15 @@ func (m *MockClientHooks) parseMCPResponse(body []byte) (string, error) {
return "mocked response", nil
}
func (m *MockClientHooks) isRetryableError(err error) bool {
func (m *MockClientHooks) ParseMCPResponseFull(body []byte) (*LLMResponse, error) {
r, err := m.ParseMCPResponse(body)
if err != nil {
return nil, err
}
return &LLMResponse{Content: r}, nil
}
func (m *MockClientHooks) IsRetryableError(err error) bool {
m.IsRetryableErrorCalled++
if m.IsRetryableErrorFunc != nil {
return m.IsRetryableErrorFunc(err)
@@ -298,13 +306,17 @@ func (m *MockClientHooks) isRetryableError(err error) bool {
return false
}
func (m *MockClientHooks) buildRequest(url string, jsonData []byte) (*http.Request, error) {
func (m *MockClientHooks) BuildRequest(url string, jsonData []byte) (*http.Request, error) {
req, _ := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
req.Header.Set("Content-Type", "application/json")
m.setAuthHeader(req.Header)
m.SetAuthHeader(req.Header)
return req, nil
}
func (m *MockClientHooks) call(systemPrompt, userPrompt string) (string, error) {
func (m *MockClientHooks) Call(systemPrompt, userPrompt string) (string, error) {
return "mocked call result", nil
}
func (m *MockClientHooks) BuildRequestBodyFromRequest(req *Request) map[string]any {
return map[string]any{"model": "test-model"}
}
-71
View File
@@ -1,71 +0,0 @@
package mcp
import (
"net/http"
)
const (
ProviderOpenAI = "openai"
DefaultOpenAIBaseURL = "https://api.openai.com/v1"
DefaultOpenAIModel = "gpt-5.4"
)
type OpenAIClient struct {
*Client
}
// NewOpenAIClient creates OpenAI client (backward compatible)
func NewOpenAIClient() AIClient {
return NewOpenAIClientWithOptions()
}
// NewOpenAIClientWithOptions creates OpenAI client (supports options pattern)
func NewOpenAIClientWithOptions(opts ...ClientOption) AIClient {
// 1. Create OpenAI preset options
openaiOpts := []ClientOption{
WithProvider(ProviderOpenAI),
WithModel(DefaultOpenAIModel),
WithBaseURL(DefaultOpenAIBaseURL),
}
// 2. Merge user options (user options have higher priority)
allOpts := append(openaiOpts, opts...)
// 3. Create base client
baseClient := NewClient(allOpts...).(*Client)
// 4. Create OpenAI client
openaiClient := &OpenAIClient{
Client: baseClient,
}
// 5. Set hooks to point to OpenAIClient (implement dynamic dispatch)
baseClient.hooks = openaiClient
return openaiClient
}
func (c *OpenAIClient) SetAPIKey(apiKey string, customURL string, customModel string) {
c.APIKey = apiKey
if len(apiKey) > 8 {
c.logger.Infof("🔧 [MCP] OpenAI API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
c.BaseURL = customURL
c.logger.Infof("🔧 [MCP] OpenAI using custom BaseURL: %s", customURL)
} else {
c.logger.Infof("🔧 [MCP] OpenAI using default BaseURL: %s", c.BaseURL)
}
if customModel != "" {
c.Model = customModel
c.logger.Infof("🔧 [MCP] OpenAI using custom Model: %s", customModel)
} else {
c.logger.Infof("🔧 [MCP] OpenAI using default Model: %s", c.Model)
}
}
// OpenAI uses standard Bearer auth
func (c *OpenAIClient) setAuthHeader(reqHeaders http.Header) {
c.Client.setAuthHeader(reqHeaders)
}
+3 -77
View File
@@ -279,8 +279,8 @@ func TestOptionsWithNewClient(t *testing.T) {
t.Error("Model should be set from options")
}
if c.logger != mockLogger {
t.Error("logger should be set from options")
if c.Log != mockLogger {
t.Error("Log should be set from options")
}
if c.MaxTokens != 4000 {
@@ -288,78 +288,4 @@ func TestOptionsWithNewClient(t *testing.T) {
}
}
func TestOptionsWithDeepSeekClient(t *testing.T) {
mockLogger := NewMockLogger()
client := NewDeepSeekClientWithOptions(
WithAPIKey("sk-deepseek-key"),
WithLogger(mockLogger),
WithMaxTokens(5000),
)
dsClient := client.(*DeepSeekClient)
// Verify DeepSeek default values
if dsClient.Provider != ProviderDeepSeek {
t.Error("Provider should be DeepSeek")
}
if dsClient.BaseURL != DefaultDeepSeekBaseURL {
t.Error("BaseURL should be DeepSeek default")
}
if dsClient.Model != DefaultDeepSeekModel {
t.Error("Model should be DeepSeek default")
}
// Verify custom options
if dsClient.APIKey != "sk-deepseek-key" {
t.Error("APIKey should be set from options")
}
if dsClient.logger != mockLogger {
t.Error("logger should be set from options")
}
if dsClient.MaxTokens != 5000 {
t.Error("MaxTokens should be 5000")
}
}
func TestOptionsWithQwenClient(t *testing.T) {
mockLogger := NewMockLogger()
client := NewQwenClientWithOptions(
WithAPIKey("sk-qwen-key"),
WithLogger(mockLogger),
WithMaxTokens(6000),
)
qwenClient := client.(*QwenClient)
// Verify Qwen default values
if qwenClient.Provider != ProviderQwen {
t.Error("Provider should be Qwen")
}
if qwenClient.BaseURL != DefaultQwenBaseURL {
t.Error("BaseURL should be Qwen default")
}
if qwenClient.Model != DefaultQwenModel {
t.Error("Model should be Qwen default")
}
// Verify custom options
if qwenClient.APIKey != "sk-qwen-key" {
t.Error("APIKey should be set from options")
}
if qwenClient.logger != mockLogger {
t.Error("logger should be set from options")
}
if qwenClient.MaxTokens != 6000 {
t.Error("MaxTokens should be 6000")
}
}
// Provider-specific option tests are in mcp/provider/options_test.go
@@ -1,4 +1,4 @@
package mcp
package payment
import (
"crypto/ecdsa"
@@ -14,12 +14,13 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"golang.org/x/crypto/sha3"
"nofx/mcp"
)
const (
ProviderBlockRunBase = "blockrun-base"
DefaultBlockRunBaseURL = "https://blockrun.ai"
DefaultBlockRunModel = "gpt-5.4"
DefaultBlockRunModel = "gpt-5.4"
BlockRunChatEndpoint = "/api/v1/chat/completions"
BaseUSDCContract = "0x833589fCD6eDb6E08f4c7C32D4f71b54bdA02913"
BaseChainID int64 = 8453
@@ -28,10 +29,16 @@ const (
// EIP-712 type hashes for USDC TransferWithAuthorization (ERC-3009)
var (
eip712DomainTypeHash = keccak256String("EIP712Domain(string name,string version,uint256 chainId,address verifyingContract)")
eip712DomainTypeHash = keccak256String("EIP712Domain(string name,string version,uint256 chainId,address verifyingContract)")
transferWithAuthTypeHash = keccak256String("TransferWithAuthorization(address from,address to,uint256 value,uint256 validAfter,uint256 validBefore,bytes32 nonce)")
)
func init() {
mcp.RegisterProvider(mcp.ProviderBlockRunBase, func(opts ...mcp.ClientOption) mcp.AIClient {
return NewBlockRunBaseClientWithOptions(opts...)
})
}
func keccak256String(s string) []byte {
h := sha3.NewLegacyKeccak256()
h.Write([]byte(s))
@@ -48,71 +55,72 @@ func keccak256Bytes(data ...[]byte) []byte {
// BlockRunBaseClient implements AIClient using BlockRun's API with x402 v2 EIP-712 payment signing.
type BlockRunBaseClient struct {
*Client
*mcp.Client
privateKey *ecdsa.PrivateKey
}
func (c *BlockRunBaseClient) BaseClient() *mcp.Client { return c.Client }
// NewBlockRunBaseClient creates a BlockRun Base wallet client (backward compatible).
func NewBlockRunBaseClient() AIClient {
func NewBlockRunBaseClient() mcp.AIClient {
return NewBlockRunBaseClientWithOptions()
}
// NewBlockRunBaseClientWithOptions creates a BlockRun Base wallet client.
func NewBlockRunBaseClientWithOptions(opts ...ClientOption) AIClient {
baseOpts := []ClientOption{
WithProvider(ProviderBlockRunBase),
WithModel(DefaultBlockRunModel),
WithBaseURL(DefaultBlockRunBaseURL),
func NewBlockRunBaseClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
baseOpts := []mcp.ClientOption{
mcp.WithProvider(mcp.ProviderBlockRunBase),
mcp.WithModel(DefaultBlockRunModel),
mcp.WithBaseURL(DefaultBlockRunBaseURL),
}
allOpts := append(baseOpts, opts...)
baseClient := NewClient(allOpts...).(*Client)
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
baseClient.UseFullURL = true
baseClient.BaseURL = DefaultBlockRunBaseURL + BlockRunChatEndpoint
c := &BlockRunBaseClient{Client: baseClient}
baseClient.hooks = c
baseClient.Hooks = c
return c
}
// SetAPIKey stores the EVM private key (hex, with or without 0x prefix).
// customModel selects the AI model to use (e.g. "claude-sonnet-4.6"); empty means default.
func (c *BlockRunBaseClient) SetAPIKey(apiKey string, customURL string, customModel string) {
hexKey := strings.TrimPrefix(apiKey, "0x")
privKey, err := crypto.HexToECDSA(hexKey)
if err != nil {
c.logger.Warnf("⚠️ [MCP] BlockRun Base: invalid private key: %v", err)
c.Log.Warnf("⚠️ [MCP] BlockRun Base: invalid private key: %v", err)
} else {
c.privateKey = privKey
c.APIKey = apiKey
addr := crypto.PubkeyToAddress(privKey.PublicKey).Hex()
c.logger.Infof("🔧 [MCP] BlockRun Base wallet: %s", addr)
c.Log.Infof("🔧 [MCP] BlockRun Base wallet: %s", addr)
}
if customModel != "" {
c.Model = customModel
c.logger.Infof("🔧 [MCP] BlockRun Base model: %s", customModel)
c.Log.Infof("🔧 [MCP] BlockRun Base model: %s", customModel)
} else {
c.logger.Infof("🔧 [MCP] BlockRun Base model: %s", DefaultBlockRunModel)
c.Log.Infof("🔧 [MCP] BlockRun Base model: %s", DefaultBlockRunModel)
}
}
func (c *BlockRunBaseClient) setAuthHeader(h http.Header) { x402SetAuthHeader(h) }
func (c *BlockRunBaseClient) SetAuthHeader(h http.Header) { X402SetAuthHeader(h) }
func (c *BlockRunBaseClient) call(systemPrompt, userPrompt string) (string, error) {
return x402Call(c.Client, c.signPayment, "BlockRun Base", systemPrompt, userPrompt)
func (c *BlockRunBaseClient) Call(systemPrompt, userPrompt string) (string, error) {
return X402Call(c.Client, c.signPayment, "BlockRun Base", systemPrompt, userPrompt)
}
func (c *BlockRunBaseClient) CallWithRequestFull(req *Request) (*LLMResponse, error) {
return x402CallFull(c.Client, c.signPayment, "BlockRun Base", req)
func (c *BlockRunBaseClient) CallWithRequestFull(req *mcp.Request) (*mcp.LLMResponse, error) {
return X402CallFull(c.Client, c.signPayment, "BlockRun Base", req)
}
// signPayment parses the Payment-Required header (x402 v2) and returns a signed payment value.
func (c *BlockRunBaseClient) signPayment(paymentHeaderB64 string) (string, error) {
return signBasePaymentHeader(c.privateKey, paymentHeaderB64, "BlockRun Base")
return SignBasePaymentHeader(c.privateKey, paymentHeaderB64, "BlockRun Base")
}
// signX402Payment is the shared EIP-712 signing logic for x402 v2 on Base USDC.
// SignX402Payment is the shared EIP-712 signing logic for x402 v2 on Base USDC.
// Used by both BlockRunBaseClient and Claw402Client.
func signX402Payment(privateKey *ecdsa.PrivateKey, senderAddr string, opt x402AcceptOption, resource *x402Resource) (string, error) {
func SignX402Payment(privateKey *ecdsa.PrivateKey, senderAddr string, opt X402AcceptOption, resource *X402Resource) (string, error) {
recipient := opt.PayTo
amount := opt.Amount
network := opt.Network
@@ -224,7 +232,6 @@ func signX402Payment(privateKey *ecdsa.PrivateKey, senderAddr string, opt x402Ac
// buildDomainSeparatorDynamic builds the EIP-712 domain separator using runtime values.
func buildDomainSeparatorDynamic(name, version, network, asset string) ([]byte, error) {
// Extract chain ID from network string like "eip155:8453"
chainID := new(big.Int).SetInt64(BaseChainID)
if strings.HasPrefix(network, "eip155:") {
parts := strings.SplitN(network, ":", 2)
@@ -311,8 +318,6 @@ func hexToBytes32(s string) ([]byte, error) {
func parseBigInt(s string) (*big.Int, error) {
n := new(big.Int)
// Only treat as hex when explicitly prefixed with 0x/0X.
// x402 amounts are always decimal strings (e.g. "3000" = 0.003 USDC).
if strings.HasPrefix(s, "0x") || strings.HasPrefix(s, "0X") {
if _, ok := n.SetString(s[2:], 16); ok {
return n, nil
@@ -335,11 +340,11 @@ func leftPad32(b []byte) []byte {
return padded
}
// buildUrl returns the full BlockRun endpoint URL.
func (c *BlockRunBaseClient) buildUrl() string {
// BuildUrl returns the full BlockRun endpoint URL.
func (c *BlockRunBaseClient) BuildUrl() string {
return DefaultBlockRunBaseURL + BlockRunChatEndpoint
}
func (c *BlockRunBaseClient) buildRequest(url string, jsonData []byte) (*http.Request, error) {
return x402BuildRequest(url, jsonData)
func (c *BlockRunBaseClient) BuildRequest(url string, jsonData []byte) (*http.Request, error) {
return X402BuildRequest(url, jsonData)
}
@@ -1,4 +1,4 @@
package mcp
package payment
import (
"context"
@@ -12,10 +12,11 @@ import (
"github.com/gagliardetto/solana-go/programs/compute-budget"
"github.com/gagliardetto/solana-go/programs/token"
"github.com/gagliardetto/solana-go/rpc"
"nofx/mcp"
)
const (
ProviderBlockRunSol = "blockrun-sol"
DefaultBlockRunSolURL = "https://sol.blockrun.ai"
SolanaUSDCMint = "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v"
SolanaNetwork = "solana:5eykt4UsFv8P8NJdTREpY1vzqKqZKvdp"
@@ -26,62 +27,69 @@ const (
computeUnitPrice = uint64(1)
)
func init() {
mcp.RegisterProvider(mcp.ProviderBlockRunSol, func(opts ...mcp.ClientOption) mcp.AIClient {
return NewBlockRunSolClientWithOptions(opts...)
})
}
// BlockRunSolClient implements AIClient using BlockRun's Solana x402 v2 payment protocol.
type BlockRunSolClient struct {
*Client
*mcp.Client
keypair solana.PrivateKey
}
func (c *BlockRunSolClient) BaseClient() *mcp.Client { return c.Client }
// NewBlockRunSolClient creates a BlockRun Solana wallet client (backward compatible).
func NewBlockRunSolClient() AIClient {
func NewBlockRunSolClient() mcp.AIClient {
return NewBlockRunSolClientWithOptions()
}
// NewBlockRunSolClientWithOptions creates a BlockRun Solana wallet client.
func NewBlockRunSolClientWithOptions(opts ...ClientOption) AIClient {
baseOpts := []ClientOption{
WithProvider(ProviderBlockRunSol),
WithModel(DefaultBlockRunModel),
WithBaseURL(DefaultBlockRunSolURL),
func NewBlockRunSolClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
baseOpts := []mcp.ClientOption{
mcp.WithProvider(mcp.ProviderBlockRunSol),
mcp.WithModel(DefaultBlockRunModel),
mcp.WithBaseURL(DefaultBlockRunSolURL),
}
allOpts := append(baseOpts, opts...)
baseClient := NewClient(allOpts...).(*Client)
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
baseClient.UseFullURL = true
baseClient.BaseURL = DefaultBlockRunSolURL + BlockRunChatEndpoint
c := &BlockRunSolClient{Client: baseClient}
baseClient.hooks = c
baseClient.Hooks = c
return c
}
// SetAPIKey stores the Solana wallet private key (base58-encoded 64-byte keypair).
// customModel selects the AI model; empty means default.
func (c *BlockRunSolClient) SetAPIKey(apiKey string, customURL string, customModel string) {
kp, err := solana.PrivateKeyFromBase58(strings.TrimSpace(apiKey))
if err != nil {
c.logger.Warnf("⚠️ [MCP] BlockRun Sol: failed to parse private key: %v", err)
c.Log.Warnf("⚠️ [MCP] BlockRun Sol: failed to parse private key: %v", err)
return
}
c.keypair = kp
c.APIKey = apiKey
c.logger.Infof("🔧 [MCP] BlockRun Sol wallet: %s", kp.PublicKey().String())
c.Log.Infof("🔧 [MCP] BlockRun Sol wallet: %s", kp.PublicKey().String())
if customModel != "" {
c.Model = customModel
c.logger.Infof("🔧 [MCP] BlockRun Sol model: %s", customModel)
c.Log.Infof("🔧 [MCP] BlockRun Sol model: %s", customModel)
} else {
c.logger.Infof("🔧 [MCP] BlockRun Sol model: %s", DefaultBlockRunModel)
c.Log.Infof("🔧 [MCP] BlockRun Sol model: %s", DefaultBlockRunModel)
}
}
func (c *BlockRunSolClient) setAuthHeader(h http.Header) { x402SetAuthHeader(h) }
func (c *BlockRunSolClient) SetAuthHeader(h http.Header) { X402SetAuthHeader(h) }
func (c *BlockRunSolClient) call(systemPrompt, userPrompt string) (string, error) {
return x402Call(c.Client, c.signSolanaPayment, "BlockRun Sol", systemPrompt, userPrompt)
func (c *BlockRunSolClient) Call(systemPrompt, userPrompt string) (string, error) {
return X402Call(c.Client, c.signSolanaPayment, "BlockRun Sol", systemPrompt, userPrompt)
}
func (c *BlockRunSolClient) CallWithRequestFull(req *Request) (*LLMResponse, error) {
return x402CallFull(c.Client, c.signSolanaPayment, "BlockRun Sol", req)
func (c *BlockRunSolClient) CallWithRequestFull(req *mcp.Request) (*mcp.LLMResponse, error) {
return X402CallFull(c.Client, c.signSolanaPayment, "BlockRun Sol", req)
}
// signSolanaPayment parses the Payment-Required header and builds a signed x402 v2 Solana payload.
@@ -90,18 +98,18 @@ func (c *BlockRunSolClient) signSolanaPayment(paymentHeaderB64 string) (string,
return "", fmt.Errorf("no private key set for BlockRun Sol wallet")
}
decoded, err := x402DecodeHeader(paymentHeaderB64)
decoded, err := X402DecodeHeader(paymentHeaderB64)
if err != nil {
return "", err
}
var req x402v2PaymentRequired
var req X402v2PaymentRequired
if err := json.Unmarshal(decoded, &req); err != nil {
return "", fmt.Errorf("failed to parse x402 v2 Solana header: %w", err)
}
// Find the Solana option
var opt *x402AcceptOption
var opt *X402AcceptOption
for i := range req.Accepts {
if strings.HasPrefix(req.Accepts[i].Network, "solana:") {
opt = &req.Accepts[i]
@@ -174,11 +182,9 @@ func (c *BlockRunSolClient) signSolanaPayment(paymentHeaderB64 string) (string,
}
// buildSolanaTransferTx builds a partial-signed VersionedTransaction for SPL USDC TransferChecked.
// The fee payer (CDP facilitator) slot is left with a zero signature; only the user signs.
func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr string) (string, error) {
ownerPubkey := c.keypair.PublicKey()
// Parse recipient and feePayer
recipientPK, err := solana.PublicKeyFromBase58(recipient)
if err != nil {
return "", fmt.Errorf("invalid recipient address: %w", err)
@@ -189,13 +195,11 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr
}
mintPK := solana.MustPublicKeyFromBase58(SolanaUSDCMint)
// Parse amount
var amountU64 uint64
if _, err := fmt.Sscanf(amountStr, "%d", &amountU64); err != nil {
return "", fmt.Errorf("invalid amount %q: %w", amountStr, err)
}
// Derive ATAs
sourceATA, _, err := solana.FindAssociatedTokenAddress(ownerPubkey, mintPK)
if err != nil {
return "", fmt.Errorf("failed to derive source ATA: %w", err)
@@ -205,7 +209,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr
return "", fmt.Errorf("failed to derive dest ATA: %w", err)
}
// Fetch latest blockhash from Solana mainnet
rpcClient := rpc.New(SolanaMainnetRPC)
bhResp, err := rpcClient.GetLatestBlockhash(context.Background(), rpc.CommitmentFinalized)
if err != nil {
@@ -213,7 +216,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr
}
recentBlockhash := bhResp.Value.Blockhash
// Build instructions: ComputeBudgetSetLimit, ComputeBudgetSetPrice, TransferChecked
setLimitIx, err := computebudget.NewSetComputeUnitLimitInstruction(computeUnitLimit).ValidateAndBuild()
if err != nil {
return "", fmt.Errorf("failed to build SetComputeUnitLimit: %w", err)
@@ -235,7 +237,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr
return "", fmt.Errorf("failed to build TransferChecked: %w", err)
}
// Build transaction with feePayer as payer (matches Python SDK)
tx, err := solana.NewTransaction(
[]solana.Instruction{setLimitIx, setPriceIx, transferIx},
recentBlockhash,
@@ -245,9 +246,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr
return "", fmt.Errorf("failed to build transaction: %w", err)
}
// Partial sign: user signs; fee_payer (CDP) co-signs on server side
// The transaction has 2 signers: [feePayer (index 0), owner (index 1)]
// We sign only our index (owner).
_, err = tx.Sign(func(key solana.PublicKey) *solana.PrivateKey {
if key.Equals(ownerPubkey) {
return &c.keypair
@@ -258,7 +256,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr
return "", fmt.Errorf("failed to sign transaction: %w", err)
}
// Serialize transaction
txBytes, err := tx.MarshalBinary()
if err != nil {
return "", fmt.Errorf("failed to serialize transaction: %w", err)
@@ -267,11 +264,11 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr
return base64.StdEncoding.EncodeToString(txBytes), nil
}
// buildUrl returns the full BlockRun Solana endpoint URL.
func (c *BlockRunSolClient) buildUrl() string {
// BuildUrl returns the full BlockRun Solana endpoint URL.
func (c *BlockRunSolClient) BuildUrl() string {
return DefaultBlockRunSolURL + BlockRunChatEndpoint
}
func (c *BlockRunSolClient) buildRequest(url string, jsonData []byte) (*http.Request, error) {
return x402BuildRequest(url, jsonData)
func (c *BlockRunSolClient) BuildRequest(url string, jsonData []byte) (*http.Request, error) {
return X402BuildRequest(url, jsonData)
}
+50 -41
View File
@@ -1,4 +1,4 @@
package mcp
package payment
import (
"crypto/ecdsa"
@@ -6,11 +6,13 @@ import (
"strings"
"github.com/ethereum/go-ethereum/crypto"
"nofx/mcp"
"nofx/mcp/provider"
)
const (
ProviderClaw402 = "claw402"
DefaultClaw402URL = "https://claw402.ai"
DefaultClaw402URL = "https://claw402.ai"
DefaultClaw402Model = "deepseek"
)
@@ -39,35 +41,42 @@ var claw402ModelEndpoints = map[string]string{
"kimi-k2.5": "/api/v1/ai/kimi/chat/k2.5",
}
func init() {
mcp.RegisterProvider(mcp.ProviderClaw402, func(opts ...mcp.ClientOption) mcp.AIClient {
return NewClaw402ClientWithOptions(opts...)
})
}
// Claw402Client implements AIClient using claw402.ai's x402 v2 USDC payment gateway.
// Reuses the same EIP-712 signing as BlockRunBaseClient (same Base chain + USDC contract).
// When the selected model routes to an Anthropic endpoint, it automatically uses
// the Anthropic wire format for requests and responses (via an internal ClaudeClient).
type Claw402Client struct {
*Client
*mcp.Client
privateKey *ecdsa.PrivateKey
claudeProxy *ClaudeClient // non-nil when endpoint is /anthropic/
claudeProxy *provider.ClaudeClient // non-nil when endpoint is /anthropic/
}
func (c *Claw402Client) BaseClient() *mcp.Client { return c.Client }
// NewClaw402Client creates a claw402 client (backward compatible).
func NewClaw402Client() AIClient {
func NewClaw402Client() mcp.AIClient {
return NewClaw402ClientWithOptions()
}
// NewClaw402ClientWithOptions creates a claw402 client with options.
func NewClaw402ClientWithOptions(opts ...ClientOption) AIClient {
baseOpts := []ClientOption{
WithProvider(ProviderClaw402),
WithModel(DefaultClaw402Model),
WithBaseURL(DefaultClaw402URL),
func NewClaw402ClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
baseOpts := []mcp.ClientOption{
mcp.WithProvider(mcp.ProviderClaw402),
mcp.WithModel(DefaultClaw402Model),
mcp.WithBaseURL(DefaultClaw402URL),
}
allOpts := append(baseOpts, opts...)
baseClient := NewClient(allOpts...).(*Client)
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
baseClient.UseFullURL = true
baseClient.BaseURL = DefaultClaw402URL + claw402ModelEndpoints[DefaultClaw402Model]
c := &Claw402Client{Client: baseClient}
baseClient.hooks = c
baseClient.Hooks = c
return c
}
@@ -76,12 +85,12 @@ func (c *Claw402Client) SetAPIKey(apiKey string, _ string, customModel string) {
hexKey := strings.TrimPrefix(apiKey, "0x")
privKey, err := crypto.HexToECDSA(hexKey)
if err != nil {
c.logger.Warnf("⚠️ [MCP] Claw402: invalid private key: %v", err)
c.Log.Warnf("⚠️ [MCP] Claw402: invalid private key: %v", err)
} else {
c.privateKey = privKey
c.APIKey = apiKey
addr := crypto.PubkeyToAddress(privKey.PublicKey).Hex()
c.logger.Infof("🔧 [MCP] Claw402 wallet: %s", addr)
c.Log.Infof("🔧 [MCP] Claw402 wallet: %s", addr)
}
if customModel != "" {
c.Model = customModel
@@ -91,11 +100,11 @@ func (c *Claw402Client) SetAPIKey(apiKey string, _ string, customModel string) {
// Anthropic endpoints need different wire format (Messages API)
if strings.Contains(endpoint, "/anthropic/") {
c.claudeProxy = &ClaudeClient{Client: c.Client}
c.logger.Infof("🔧 [MCP] Claw402 model: %s → %s (Anthropic format)", c.Model, endpoint)
c.claudeProxy = &provider.ClaudeClient{Client: c.Client}
c.Log.Infof("🔧 [MCP] Claw402 model: %s → %s (Anthropic format)", c.Model, endpoint)
} else {
c.claudeProxy = nil
c.logger.Infof("🔧 [MCP] Claw402 model: %s → %s", c.Model, endpoint)
c.Log.Infof("🔧 [MCP] Claw402 model: %s → %s", c.Model, endpoint)
}
}
@@ -111,56 +120,56 @@ func (c *Claw402Client) resolveEndpoint() string {
return claw402ModelEndpoints[DefaultClaw402Model]
}
func (c *Claw402Client) setAuthHeader(h http.Header) { x402SetAuthHeader(h) }
func (c *Claw402Client) SetAuthHeader(h http.Header) { X402SetAuthHeader(h) }
func (c *Claw402Client) call(systemPrompt, userPrompt string) (string, error) {
return x402Call(c.Client, c.signPayment, "Claw402", systemPrompt, userPrompt)
func (c *Claw402Client) Call(systemPrompt, userPrompt string) (string, error) {
return X402Call(c.Client, c.signPayment, "Claw402", systemPrompt, userPrompt)
}
func (c *Claw402Client) CallWithRequestFull(req *Request) (*LLMResponse, error) {
return x402CallFull(c.Client, c.signPayment, "Claw402", req)
func (c *Claw402Client) CallWithRequestFull(req *mcp.Request) (*mcp.LLMResponse, error) {
return X402CallFull(c.Client, c.signPayment, "Claw402", req)
}
// signPayment signs x402 v2 EIP-712 payment (same Base chain + USDC as BlockRunBase).
func (c *Claw402Client) signPayment(paymentHeaderB64 string) (string, error) {
return signBasePaymentHeader(c.privateKey, paymentHeaderB64, "Claw402")
return SignBasePaymentHeader(c.privateKey, paymentHeaderB64, "Claw402")
}
// ── Format overrides for Anthropic endpoints ─────────────────────────────────
func (c *Claw402Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
func (c *Claw402Client) BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
if c.claudeProxy != nil {
return c.claudeProxy.buildMCPRequestBody(systemPrompt, userPrompt)
return c.claudeProxy.BuildMCPRequestBody(systemPrompt, userPrompt)
}
return c.Client.buildMCPRequestBody(systemPrompt, userPrompt)
return c.Client.BuildMCPRequestBody(systemPrompt, userPrompt)
}
func (c *Claw402Client) buildRequestBodyFromRequest(req *Request) map[string]any {
func (c *Claw402Client) BuildRequestBodyFromRequest(req *mcp.Request) map[string]any {
if c.claudeProxy != nil {
return c.claudeProxy.buildRequestBodyFromRequest(req)
return c.claudeProxy.BuildRequestBodyFromRequest(req)
}
return c.Client.buildRequestBodyFromRequest(req)
return c.Client.BuildRequestBodyFromRequest(req)
}
func (c *Claw402Client) parseMCPResponse(body []byte) (string, error) {
func (c *Claw402Client) ParseMCPResponse(body []byte) (string, error) {
if c.claudeProxy != nil {
return c.claudeProxy.parseMCPResponse(body)
return c.claudeProxy.ParseMCPResponse(body)
}
return c.Client.parseMCPResponse(body)
return c.Client.ParseMCPResponse(body)
}
func (c *Claw402Client) parseMCPResponseFull(body []byte) (*LLMResponse, error) {
func (c *Claw402Client) ParseMCPResponseFull(body []byte) (*mcp.LLMResponse, error) {
if c.claudeProxy != nil {
return c.claudeProxy.parseMCPResponseFull(body)
return c.claudeProxy.ParseMCPResponseFull(body)
}
return c.Client.parseMCPResponseFull(body)
return c.Client.ParseMCPResponseFull(body)
}
// buildUrl returns the full claw402 endpoint URL.
func (c *Claw402Client) buildUrl() string {
// BuildUrl returns the full claw402 endpoint URL.
func (c *Claw402Client) BuildUrl() string {
return c.BaseURL
}
func (c *Claw402Client) buildRequest(url string, jsonData []byte) (*http.Request, error) {
return x402BuildRequest(url, jsonData)
func (c *Claw402Client) BuildRequest(url string, jsonData []byte) (*http.Request, error) {
return X402BuildRequest(url, jsonData)
}
+58 -60
View File
@@ -1,4 +1,4 @@
package mcp
package payment
import (
"bytes"
@@ -11,28 +11,30 @@ import (
"time"
"github.com/ethereum/go-ethereum/crypto"
"nofx/mcp"
)
const (
// x402MaxPaymentRetries is the number of retries for 5xx errors on the
// X402MaxPaymentRetries is the number of retries for 5xx errors on the
// payment-signed request. The same payment signature is reused (no double-charge).
x402MaxPaymentRetries = 3
X402MaxPaymentRetries = 3
// x402RetryBaseWait is the base wait between payment retry attempts.
x402RetryBaseWait = 3 * time.Second
// X402RetryBaseWait is the base wait between payment retry attempts.
X402RetryBaseWait = 3 * time.Second
)
// ── Shared x402 types ────────────────────────────────────────────────────────
// x402v2PaymentRequired is the structure of the Payment-Required header (x402 v2).
type x402v2PaymentRequired struct {
// X402v2PaymentRequired is the structure of the Payment-Required header (x402 v2).
type X402v2PaymentRequired struct {
X402Version int `json:"x402Version"`
Accepts []x402AcceptOption `json:"accepts"`
Resource *x402Resource `json:"resource"`
Accepts []X402AcceptOption `json:"accepts"`
Resource *X402Resource `json:"resource"`
}
// x402AcceptOption is a payment option from the x402 v2 header.
type x402AcceptOption struct {
// X402AcceptOption is a payment option from the x402 v2 header.
type X402AcceptOption struct {
Scheme string `json:"scheme"`
Network string `json:"network"`
Amount string `json:"amount"`
@@ -42,22 +44,22 @@ type x402AcceptOption struct {
Extra map[string]string `json:"extra"`
}
// x402Resource describes the resource being paid for.
type x402Resource struct {
// X402Resource describes the resource being paid for.
type X402Resource struct {
URL string `json:"url"`
Description string `json:"description"`
MimeType string `json:"mimeType"`
}
// x402SignFunc is a callback that signs an x402 payment header and returns the
// X402SignFunc is a callback that signs an x402 payment header and returns the
// base64-encoded payment signature.
type x402SignFunc func(paymentHeaderB64 string) (string, error)
type X402SignFunc func(paymentHeaderB64 string) (string, error)
// ── Shared x402 helpers ──────────────────────────────────────────────────────
// x402DecodeHeader decodes a base64-encoded x402 Payment-Required header,
// X402DecodeHeader decodes a base64-encoded x402 Payment-Required header,
// trying RawStdEncoding first then StdEncoding as fallback.
func x402DecodeHeader(b64 string) ([]byte, error) {
func X402DecodeHeader(b64 string) ([]byte, error) {
decoded, err := base64.RawStdEncoding.DecodeString(b64)
if err != nil {
decoded, err = base64.StdEncoding.DecodeString(b64)
@@ -68,19 +70,19 @@ func x402DecodeHeader(b64 string) ([]byte, error) {
return decoded, nil
}
// signBasePaymentHeader decodes a base64 x402 header, parses it, and signs with
// SignBasePaymentHeader decodes a base64 x402 header, parses it, and signs with
// EIP-712 (USDC TransferWithAuthorization). Shared by BlockRunBase and Claw402.
func signBasePaymentHeader(privateKey *ecdsa.PrivateKey, paymentHeaderB64 string, providerName string) (string, error) {
func SignBasePaymentHeader(privateKey *ecdsa.PrivateKey, paymentHeaderB64 string, providerName string) (string, error) {
if privateKey == nil {
return "", fmt.Errorf("no private key set for %s wallet", providerName)
}
decoded, err := x402DecodeHeader(paymentHeaderB64)
decoded, err := X402DecodeHeader(paymentHeaderB64)
if err != nil {
return "", err
}
var req x402v2PaymentRequired
var req X402v2PaymentRequired
if err := json.Unmarshal(decoded, &req); err != nil {
return "", fmt.Errorf("failed to parse x402 v2 payment header: %w", err)
}
@@ -89,19 +91,16 @@ func signBasePaymentHeader(privateKey *ecdsa.PrivateKey, paymentHeaderB64 string
}
senderAddr := crypto.PubkeyToAddress(privateKey.PublicKey).Hex()
return signX402Payment(privateKey, senderAddr, req.Accepts[0], req.Resource)
return SignX402Payment(privateKey, senderAddr, req.Accepts[0], req.Resource)
}
// doX402Request executes an HTTP request and handles the x402 v2 payment flow.
// On a 402 response it reads the Payment-Required (or X-Payment-Required) header,
// signs via signFn, retries with Payment-Signature, and logs the Payment-Response
// header (tx hash) on success.
func doX402Request(
// DoX402Request executes an HTTP request and handles the x402 v2 payment flow.
func DoX402Request(
httpClient *http.Client,
buildReqFn func() (*http.Request, error),
signFn x402SignFunc,
signFn X402SignFunc,
providerTag string,
logger Logger,
logger mcp.Logger,
) ([]byte, error) {
req, err := buildReqFn()
if err != nil {
@@ -133,10 +132,9 @@ func doX402Request(
}
// Retry loop for 5xx errors on the payment-signed request.
// Reuses the same payment signature — no double-charge.
var lastBody []byte
var lastStatus int
for attempt := 1; attempt <= x402MaxPaymentRetries; attempt++ {
for attempt := 1; attempt <= X402MaxPaymentRetries; attempt++ {
req2, err := buildReqFn()
if err != nil {
return nil, fmt.Errorf("failed to build retry request: %w", err)
@@ -146,10 +144,10 @@ func doX402Request(
resp2, err := httpClient.Do(req2)
if err != nil {
if attempt < x402MaxPaymentRetries {
wait := x402RetryBaseWait * time.Duration(attempt)
if attempt < X402MaxPaymentRetries {
wait := X402RetryBaseWait * time.Duration(attempt)
logger.Warnf("⚠️ [%s] Payment request failed: %v, retrying in %v (%d/%d)...",
providerTag, err, wait, attempt+1, x402MaxPaymentRetries)
providerTag, err, wait, attempt+1, X402MaxPaymentRetries)
time.Sleep(wait)
continue
}
@@ -175,11 +173,11 @@ func doX402Request(
lastBody = body2
lastStatus = resp2.StatusCode
// Retry on 5xx server errors (502, 503, 520, etc.)
if resp2.StatusCode >= 500 && attempt < x402MaxPaymentRetries {
wait := x402RetryBaseWait * time.Duration(attempt)
// Retry on 5xx server errors
if resp2.StatusCode >= 500 && attempt < X402MaxPaymentRetries {
wait := X402RetryBaseWait * time.Duration(attempt)
logger.Warnf("⚠️ [%s] Server error (status %d), retrying in %v (%d/%d)...",
providerTag, resp2.StatusCode, wait, attempt+1, x402MaxPaymentRetries)
providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries)
time.Sleep(wait)
continue
}
@@ -201,8 +199,8 @@ func doX402Request(
return body, nil
}
// x402BuildRequest creates a POST request with Content-Type but no auth header.
func x402BuildRequest(url string, jsonData []byte) (*http.Request, error) {
// X402BuildRequest creates a POST request with Content-Type but no auth header.
func X402BuildRequest(url string, jsonData []byte) (*http.Request, error) {
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("fail to build request: %w", err)
@@ -211,30 +209,30 @@ func x402BuildRequest(url string, jsonData []byte) (*http.Request, error) {
return req, nil
}
// x402SetAuthHeader is a no-op — x402 providers authenticate via payment signing.
func x402SetAuthHeader(_ http.Header) {}
// X402SetAuthHeader is a no-op — x402 providers authenticate via payment signing.
func X402SetAuthHeader(_ http.Header) {}
// x402Call handles the x402 payment flow for the simple CallWithMessages path.
func x402Call(c *Client, signFn x402SignFunc, tag string, systemPrompt, userPrompt string) (string, error) {
c.logger.Infof("📡 [%s] Request AI Server: %s", tag, c.BaseURL)
// X402Call handles the x402 payment flow for the simple CallWithMessages path.
func X402Call(c *mcp.Client, signFn X402SignFunc, tag string, systemPrompt, userPrompt string) (string, error) {
c.Log.Infof("📡 [%s] Request AI Server: %s", tag, c.BaseURL)
requestBody := c.hooks.buildMCPRequestBody(systemPrompt, userPrompt)
jsonData, err := c.hooks.marshalRequestBody(requestBody)
requestBody := c.Hooks.BuildMCPRequestBody(systemPrompt, userPrompt)
jsonData, err := c.Hooks.MarshalRequestBody(requestBody)
if err != nil {
return "", err
}
body, err := doX402Request(c.httpClient, func() (*http.Request, error) {
return c.hooks.buildRequest(c.hooks.buildUrl(), jsonData)
}, signFn, tag, c.logger)
body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) {
return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData)
}, signFn, tag, c.Log)
if err != nil {
return "", err
}
return c.hooks.parseMCPResponse(body)
return c.Hooks.ParseMCPResponse(body)
}
// x402CallFull handles the x402 payment flow for the advanced Request path.
func x402CallFull(c *Client, signFn x402SignFunc, tag string, req *Request) (*LLMResponse, error) {
// X402CallFull handles the x402 payment flow for the advanced Request path.
func X402CallFull(c *mcp.Client, signFn X402SignFunc, tag string, req *mcp.Request) (*mcp.LLMResponse, error) {
if c.APIKey == "" {
return nil, fmt.Errorf("AI API key not set, please call SetAPIKey first")
}
@@ -242,19 +240,19 @@ func x402CallFull(c *Client, signFn x402SignFunc, tag string, req *Request) (*LL
req.Model = c.Model
}
c.logger.Infof("📡 [%s] Request AI (full): %s", tag, c.BaseURL)
c.Log.Infof("📡 [%s] Request AI (full): %s", tag, c.BaseURL)
requestBody := c.hooks.buildRequestBodyFromRequest(req)
jsonData, err := c.hooks.marshalRequestBody(requestBody)
requestBody := c.Hooks.BuildRequestBodyFromRequest(req)
jsonData, err := c.Hooks.MarshalRequestBody(requestBody)
if err != nil {
return nil, err
}
body, err := doX402Request(c.httpClient, func() (*http.Request, error) {
return c.hooks.buildRequest(c.hooks.buildUrl(), jsonData)
}, signFn, tag, c.logger)
body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) {
return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData)
}, signFn, tag, c.Log)
if err != nil {
return nil, err
}
return c.hooks.parseMCPResponseFull(body)
return c.Hooks.ParseMCPResponseFull(body)
}
+47 -67
View File
@@ -1,4 +1,4 @@
// Package mcp — ClaudeClient implements the Anthropic Messages API.
// Package provider — ClaudeClient implements the Anthropic Messages API.
//
// Wire-format differences from the OpenAI-compatible base Client:
//
@@ -14,42 +14,51 @@
// │ Tool result │ role=tool + tool_call_id │ role=user content[tool_result] │
// │ Max tokens │ max_tokens │ max_tokens (same) │
// └─────────────────────┴───────────────────────────┴─────────────────────────────────┘
package mcp
package provider
import (
"encoding/json"
"fmt"
"net/http"
"nofx/mcp"
)
const (
ProviderClaude = "claude"
DefaultClaudeBaseURL = "https://api.anthropic.com/v1"
DefaultClaudeModel = "claude-opus-4-6"
)
func init() {
mcp.RegisterProvider(mcp.ProviderClaude, func(opts ...mcp.ClientOption) mcp.AIClient {
return NewClaudeClientWithOptions(opts...)
})
}
// ClaudeClient wraps the base Client and overrides the methods that differ
// for the Anthropic Messages API. All other behaviour (retry, timeout,
// logging) is inherited unchanged.
type ClaudeClient struct {
*Client
*mcp.Client
}
func (c *ClaudeClient) BaseClient() *mcp.Client { return c.Client }
// NewClaudeClient creates a ClaudeClient with default settings.
func NewClaudeClient() AIClient {
func NewClaudeClient() mcp.AIClient {
return NewClaudeClientWithOptions()
}
// NewClaudeClientWithOptions creates a ClaudeClient with optional overrides.
func NewClaudeClientWithOptions(opts ...ClientOption) AIClient {
baseClient := NewClient(append([]ClientOption{
WithProvider(ProviderClaude),
WithModel(DefaultClaudeModel),
WithBaseURL(DefaultClaudeBaseURL),
}, opts...)...).(*Client)
func NewClaudeClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
baseClient := mcp.NewClient(append([]mcp.ClientOption{
mcp.WithProvider(mcp.ProviderClaude),
mcp.WithModel(DefaultClaudeModel),
mcp.WithBaseURL(DefaultClaudeBaseURL),
}, opts...)...).(*mcp.Client)
c := &ClaudeClient{Client: baseClient}
baseClient.hooks = c // wire dynamic dispatch to ClaudeClient
baseClient.Hooks = c // wire dynamic dispatch to ClaudeClient
return c
}
@@ -59,32 +68,32 @@ func NewClaudeClientWithOptions(opts ...ClientOption) AIClient {
func (c *ClaudeClient) SetAPIKey(apiKey, customURL, customModel string) {
c.APIKey = apiKey
if len(apiKey) > 8 {
c.logger.Infof("🔧 [MCP] Claude API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
c.Log.Infof("🔧 [MCP] Claude API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
c.BaseURL = customURL
c.logger.Infof("🔧 [MCP] Claude BaseURL: %s", customURL)
c.Log.Infof("🔧 [MCP] Claude BaseURL: %s", customURL)
}
if customModel != "" {
c.Model = customModel
c.logger.Infof("🔧 [MCP] Claude Model: %s", customModel)
c.Log.Infof("🔧 [MCP] Claude Model: %s", customModel)
}
}
// setAuthHeader uses x-api-key instead of Authorization: Bearer.
func (c *ClaudeClient) setAuthHeader(h http.Header) {
// SetAuthHeader uses x-api-key instead of Authorization: Bearer.
func (c *ClaudeClient) SetAuthHeader(h http.Header) {
h.Set("x-api-key", c.APIKey)
h.Set("anthropic-version", "2023-06-01")
}
// buildUrl targets /messages instead of /chat/completions.
func (c *ClaudeClient) buildUrl() string {
// BuildUrl targets /messages instead of /chat/completions.
func (c *ClaudeClient) BuildUrl() string {
return fmt.Sprintf("%s/messages", c.BaseURL)
}
// buildMCPRequestBody builds the Anthropic wire format for the simple
// BuildMCPRequestBody builds the Anthropic wire format for the simple
// CallWithMessages path (no tool support).
func (c *ClaudeClient) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
func (c *ClaudeClient) BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
return map[string]any{
"model": c.Model,
"max_tokens": c.MaxTokens,
@@ -95,23 +104,12 @@ func (c *ClaudeClient) buildMCPRequestBody(systemPrompt, userPrompt string) map[
}
}
// buildRequestBodyFromRequest converts a *Request into the Anthropic Messages
// API wire format. This is the key override that makes tool calling work
// correctly with Claude.
//
// Conversions applied:
//
// - System messages are lifted to the top-level "system" field.
// - Tool definitions: parameters → input_schema, wrapper removed.
// - Assistant messages with ToolCalls → content[{type:tool_use,...}].
// - Tool result messages (role=tool) → role=user with tool_result blocks.
// Consecutive tool results are merged into a single user turn (Anthropic
// requires strictly alternating user/assistant turns).
// - tool_choice "auto"/"any" → {"type":"auto"/"any"} object.
func (c *ClaudeClient) buildRequestBodyFromRequest(req *Request) map[string]any {
// BuildRequestBodyFromRequest converts a *Request into the Anthropic Messages
// API wire format.
func (c *ClaudeClient) BuildRequestBodyFromRequest(req *mcp.Request) map[string]any {
// ── 1. Separate system prompt from conversation messages ──────────────────
var systemPrompt string
var convMsgs []Message
var convMsgs []mcp.Message
for _, m := range req.Messages {
if m.Role == "system" {
systemPrompt = m.Content
@@ -121,7 +119,7 @@ func (c *ClaudeClient) buildRequestBodyFromRequest(req *Request) map[string]any
}
// ── 2. Convert messages to Anthropic format ───────────────────────────────
anthropicMsgs := convertMessagesToAnthropic(convMsgs)
anthropicMsgs := ConvertMessagesToAnthropic(convMsgs)
// ── 3. Convert tool definitions (parameters → input_schema) ──────────────
var anthropicTools []map[string]any
@@ -162,16 +160,9 @@ func (c *ClaudeClient) buildRequestBodyFromRequest(req *Request) map[string]any
return body
}
// convertMessagesToAnthropic translates from the OpenAI-shaped mcp.Message
// ConvertMessagesToAnthropic translates from the OpenAI-shaped mcp.Message
// slice to Anthropic's messages array.
//
// Rules:
// 1. role=assistant + ToolCalls → role=assistant, content=[tool_use, ...]
// 2. role=tool (result) → role=user, content=[tool_result, ...]
// Consecutive tool-result messages are merged into one user turn so the
// conversation always alternates user/assistant.
// 3. All other messages → {role, content} as-is.
func convertMessagesToAnthropic(msgs []Message) []map[string]any {
func ConvertMessagesToAnthropic(msgs []mcp.Message) []map[string]any {
var out []map[string]any
for i := 0; i < len(msgs); {
@@ -232,29 +223,18 @@ func convertMessagesToAnthropic(msgs []Message) []map[string]any {
// ── Response parsers ──────────────────────────────────────────────────────────
// parseMCPResponse extracts the plain-text reply from an Anthropic response.
// Used by CallWithMessages / CallWithRequest (no tool support).
func (c *ClaudeClient) parseMCPResponse(body []byte) (string, error) {
r, err := c.parseMCPResponseFull(body)
// ParseMCPResponse extracts the plain-text reply from an Anthropic response.
func (c *ClaudeClient) ParseMCPResponse(body []byte) (string, error) {
r, err := c.ParseMCPResponseFull(body)
if err != nil {
return "", err
}
return r.Content, nil
}
// parseMCPResponseFull extracts both text and tool calls from an Anthropic
// ParseMCPResponseFull extracts both text and tool calls from an Anthropic
// response envelope.
//
// Anthropic response shape:
//
// {
// "content": [
// {"type": "text", "text": "..."},
// {"type": "tool_use", "id": "...", "name": "...", "input": {...}}
// ],
// "stop_reason": "tool_use" | "end_turn"
// }
func (c *ClaudeClient) parseMCPResponseFull(body []byte) (*LLMResponse, error) {
func (c *ClaudeClient) ParseMCPResponseFull(body []byte) (*mcp.LLMResponse, error) {
var raw struct {
Content []struct {
Type string `json:"type"`
@@ -281,8 +261,8 @@ func (c *ClaudeClient) parseMCPResponseFull(body []byte) (*LLMResponse, error) {
}
total := raw.Usage.InputTokens + raw.Usage.OutputTokens
if TokenUsageCallback != nil && total > 0 {
TokenUsageCallback(TokenUsage{
if mcp.TokenUsageCallback != nil && total > 0 {
mcp.TokenUsageCallback(mcp.TokenUsage{
Provider: c.Provider,
Model: c.Model,
PromptTokens: raw.Usage.InputTokens,
@@ -291,7 +271,7 @@ func (c *ClaudeClient) parseMCPResponseFull(body []byte) (*LLMResponse, error) {
})
}
result := &LLMResponse{}
result := &mcp.LLMResponse{}
for _, block := range raw.Content {
switch block.Type {
case "text":
@@ -304,10 +284,10 @@ func (c *ClaudeClient) parseMCPResponseFull(body []byte) (*LLMResponse, error) {
if err != nil {
argsJSON = []byte("{}")
}
result.ToolCalls = append(result.ToolCalls, ToolCall{
result.ToolCalls = append(result.ToolCalls, mcp.ToolCall{
ID: block.ID,
Type: "function",
Function: ToolCallFunction{
Function: mcp.ToolCallFunction{
Name: block.Name,
Arguments: string(argsJSON),
},
+69
View File
@@ -0,0 +1,69 @@
package provider
import (
"net/http"
"nofx/mcp"
)
func init() {
mcp.RegisterProvider(mcp.ProviderDeepSeek, func(opts ...mcp.ClientOption) mcp.AIClient {
return NewDeepSeekClientWithOptions(opts...)
})
}
type DeepSeekClient struct {
*mcp.Client
}
func (c *DeepSeekClient) BaseClient() *mcp.Client { return c.Client }
// NewDeepSeekClient creates DeepSeek client (backward compatible)
//
// Deprecated: Recommend using NewDeepSeekClientWithOptions for better flexibility
func NewDeepSeekClient() mcp.AIClient {
return NewDeepSeekClientWithOptions()
}
// NewDeepSeekClientWithOptions creates DeepSeek client (supports options pattern)
func NewDeepSeekClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
deepseekOpts := []mcp.ClientOption{
mcp.WithProvider(mcp.ProviderDeepSeek),
mcp.WithModel(mcp.DefaultDeepSeekModel),
mcp.WithBaseURL(mcp.DefaultDeepSeekBaseURL),
}
allOpts := append(deepseekOpts, opts...)
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
dsClient := &DeepSeekClient{
Client: baseClient,
}
baseClient.Hooks = dsClient
return dsClient
}
func (dsClient *DeepSeekClient) SetAPIKey(apiKey string, customURL string, customModel string) {
dsClient.APIKey = apiKey
if len(apiKey) > 8 {
dsClient.Log.Infof("🔧 [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
dsClient.BaseURL = customURL
dsClient.Log.Infof("🔧 [MCP] DeepSeek using custom BaseURL: %s", customURL)
} else {
dsClient.Log.Infof("🔧 [MCP] DeepSeek using default BaseURL: %s", dsClient.BaseURL)
}
if customModel != "" {
dsClient.Model = customModel
dsClient.Log.Infof("🔧 [MCP] DeepSeek using custom Model: %s", customModel)
} else {
dsClient.Log.Infof("🔧 [MCP] DeepSeek using default Model: %s", dsClient.Model)
}
}
func (dsClient *DeepSeekClient) SetAuthHeader(reqHeaders http.Header) {
dsClient.Client.SetAuthHeader(reqHeaders)
}
+73
View File
@@ -0,0 +1,73 @@
package provider
import (
"net/http"
"nofx/mcp"
)
const (
DefaultGeminiBaseURL = "https://generativelanguage.googleapis.com/v1beta/openai"
DefaultGeminiModel = "gemini-3-pro-preview"
)
func init() {
mcp.RegisterProvider(mcp.ProviderGemini, func(opts ...mcp.ClientOption) mcp.AIClient {
return NewGeminiClientWithOptions(opts...)
})
}
type GeminiClient struct {
*mcp.Client
}
func (c *GeminiClient) BaseClient() *mcp.Client { return c.Client }
// NewGeminiClient creates Gemini client (backward compatible)
func NewGeminiClient() mcp.AIClient {
return NewGeminiClientWithOptions()
}
// NewGeminiClientWithOptions creates Gemini client (supports options pattern)
func NewGeminiClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
geminiOpts := []mcp.ClientOption{
mcp.WithProvider(mcp.ProviderGemini),
mcp.WithModel(DefaultGeminiModel),
mcp.WithBaseURL(DefaultGeminiBaseURL),
}
allOpts := append(geminiOpts, opts...)
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
geminiClient := &GeminiClient{
Client: baseClient,
}
baseClient.Hooks = geminiClient
return geminiClient
}
func (c *GeminiClient) SetAPIKey(apiKey string, customURL string, customModel string) {
c.APIKey = apiKey
if len(apiKey) > 8 {
c.Log.Infof("🔧 [MCP] Gemini API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
c.BaseURL = customURL
c.Log.Infof("🔧 [MCP] Gemini using custom BaseURL: %s", customURL)
} else {
c.Log.Infof("🔧 [MCP] Gemini using default BaseURL: %s", c.BaseURL)
}
if customModel != "" {
c.Model = customModel
c.Log.Infof("🔧 [MCP] Gemini using custom Model: %s", customModel)
} else {
c.Log.Infof("🔧 [MCP] Gemini using default Model: %s", c.Model)
}
}
// Gemini OpenAI-compatible API uses standard Bearer auth
func (c *GeminiClient) SetAuthHeader(reqHeaders http.Header) {
c.Client.SetAuthHeader(reqHeaders)
}
+73
View File
@@ -0,0 +1,73 @@
package provider
import (
"net/http"
"nofx/mcp"
)
const (
DefaultGrokBaseURL = "https://api.x.ai/v1"
DefaultGrokModel = "grok-3-latest"
)
func init() {
mcp.RegisterProvider(mcp.ProviderGrok, func(opts ...mcp.ClientOption) mcp.AIClient {
return NewGrokClientWithOptions(opts...)
})
}
type GrokClient struct {
*mcp.Client
}
func (c *GrokClient) BaseClient() *mcp.Client { return c.Client }
// NewGrokClient creates Grok client (backward compatible)
func NewGrokClient() mcp.AIClient {
return NewGrokClientWithOptions()
}
// NewGrokClientWithOptions creates Grok client (supports options pattern)
func NewGrokClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
grokOpts := []mcp.ClientOption{
mcp.WithProvider(mcp.ProviderGrok),
mcp.WithModel(DefaultGrokModel),
mcp.WithBaseURL(DefaultGrokBaseURL),
}
allOpts := append(grokOpts, opts...)
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
grokClient := &GrokClient{
Client: baseClient,
}
baseClient.Hooks = grokClient
return grokClient
}
func (c *GrokClient) SetAPIKey(apiKey string, customURL string, customModel string) {
c.APIKey = apiKey
if len(apiKey) > 8 {
c.Log.Infof("🔧 [MCP] Grok API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
c.BaseURL = customURL
c.Log.Infof("🔧 [MCP] Grok using custom BaseURL: %s", customURL)
} else {
c.Log.Infof("🔧 [MCP] Grok using default BaseURL: %s", c.BaseURL)
}
if customModel != "" {
c.Model = customModel
c.Log.Infof("🔧 [MCP] Grok using custom Model: %s", customModel)
} else {
c.Log.Infof("🔧 [MCP] Grok using default Model: %s", c.Model)
}
}
// Grok uses standard OpenAI-compatible API with Bearer auth
func (c *GrokClient) SetAuthHeader(reqHeaders http.Header) {
c.Client.SetAuthHeader(reqHeaders)
}
+73
View File
@@ -0,0 +1,73 @@
package provider
import (
"net/http"
"nofx/mcp"
)
const (
DefaultKimiBaseURL = "https://api.moonshot.ai/v1" // Global endpoint (use api.moonshot.cn for China)
DefaultKimiModel = "moonshot-v1-auto"
)
func init() {
mcp.RegisterProvider(mcp.ProviderKimi, func(opts ...mcp.ClientOption) mcp.AIClient {
return NewKimiClientWithOptions(opts...)
})
}
type KimiClient struct {
*mcp.Client
}
func (c *KimiClient) BaseClient() *mcp.Client { return c.Client }
// NewKimiClient creates Kimi (Moonshot) client (backward compatible)
func NewKimiClient() mcp.AIClient {
return NewKimiClientWithOptions()
}
// NewKimiClientWithOptions creates Kimi client (supports options pattern)
func NewKimiClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
kimiOpts := []mcp.ClientOption{
mcp.WithProvider(mcp.ProviderKimi),
mcp.WithModel(DefaultKimiModel),
mcp.WithBaseURL(DefaultKimiBaseURL),
}
allOpts := append(kimiOpts, opts...)
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
kimiClient := &KimiClient{
Client: baseClient,
}
baseClient.Hooks = kimiClient
return kimiClient
}
func (c *KimiClient) SetAPIKey(apiKey string, customURL string, customModel string) {
c.APIKey = apiKey
if len(apiKey) > 8 {
c.Log.Infof("🔧 [MCP] Kimi API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
c.BaseURL = customURL
c.Log.Infof("🔧 [MCP] Kimi using custom BaseURL: %s", customURL)
} else {
c.Log.Infof("🔧 [MCP] Kimi using default BaseURL: %s", c.BaseURL)
}
if customModel != "" {
c.Model = customModel
c.Log.Infof("🔧 [MCP] Kimi using custom Model: %s", customModel)
} else {
c.Log.Infof("🔧 [MCP] Kimi using default Model: %s", c.Model)
}
}
// Kimi uses standard OpenAI-compatible API
func (c *KimiClient) SetAuthHeader(reqHeaders http.Header) {
c.Client.SetAuthHeader(reqHeaders)
}
+73
View File
@@ -0,0 +1,73 @@
package provider
import (
"net/http"
"nofx/mcp"
)
const (
DefaultMiniMaxBaseURL = "https://api.minimax.io/v1"
DefaultMiniMaxModel = "MiniMax-M2.5"
)
func init() {
mcp.RegisterProvider(mcp.ProviderMiniMax, func(opts ...mcp.ClientOption) mcp.AIClient {
return NewMiniMaxClientWithOptions(opts...)
})
}
type MiniMaxClient struct {
*mcp.Client
}
func (c *MiniMaxClient) BaseClient() *mcp.Client { return c.Client }
// NewMiniMaxClient creates MiniMax client (backward compatible)
func NewMiniMaxClient() mcp.AIClient {
return NewMiniMaxClientWithOptions()
}
// NewMiniMaxClientWithOptions creates MiniMax client (supports options pattern)
func NewMiniMaxClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
minimaxOpts := []mcp.ClientOption{
mcp.WithProvider(mcp.ProviderMiniMax),
mcp.WithModel(DefaultMiniMaxModel),
mcp.WithBaseURL(DefaultMiniMaxBaseURL),
}
allOpts := append(minimaxOpts, opts...)
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
minimaxClient := &MiniMaxClient{
Client: baseClient,
}
baseClient.Hooks = minimaxClient
return minimaxClient
}
func (c *MiniMaxClient) SetAPIKey(apiKey string, customURL string, customModel string) {
c.APIKey = apiKey
if len(apiKey) > 8 {
c.Log.Infof("🔧 [MCP] MiniMax API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
c.BaseURL = customURL
c.Log.Infof("🔧 [MCP] MiniMax using custom BaseURL: %s", customURL)
} else {
c.Log.Infof("🔧 [MCP] MiniMax using default BaseURL: %s", c.BaseURL)
}
if customModel != "" {
c.Model = customModel
c.Log.Infof("🔧 [MCP] MiniMax using custom Model: %s", customModel)
} else {
c.Log.Infof("🔧 [MCP] MiniMax using default Model: %s", c.Model)
}
}
// MiniMax uses standard OpenAI-compatible API with Bearer auth
func (c *MiniMaxClient) SetAuthHeader(reqHeaders http.Header) {
c.Client.SetAuthHeader(reqHeaders)
}
+73
View File
@@ -0,0 +1,73 @@
package provider
import (
"net/http"
"nofx/mcp"
)
const (
DefaultOpenAIBaseURL = "https://api.openai.com/v1"
DefaultOpenAIModel = "gpt-5.4"
)
func init() {
mcp.RegisterProvider(mcp.ProviderOpenAI, func(opts ...mcp.ClientOption) mcp.AIClient {
return NewOpenAIClientWithOptions(opts...)
})
}
type OpenAIClient struct {
*mcp.Client
}
func (c *OpenAIClient) BaseClient() *mcp.Client { return c.Client }
// NewOpenAIClient creates OpenAI client (backward compatible)
func NewOpenAIClient() mcp.AIClient {
return NewOpenAIClientWithOptions()
}
// NewOpenAIClientWithOptions creates OpenAI client (supports options pattern)
func NewOpenAIClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
openaiOpts := []mcp.ClientOption{
mcp.WithProvider(mcp.ProviderOpenAI),
mcp.WithModel(DefaultOpenAIModel),
mcp.WithBaseURL(DefaultOpenAIBaseURL),
}
allOpts := append(openaiOpts, opts...)
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
openaiClient := &OpenAIClient{
Client: baseClient,
}
baseClient.Hooks = openaiClient
return openaiClient
}
func (c *OpenAIClient) SetAPIKey(apiKey string, customURL string, customModel string) {
c.APIKey = apiKey
if len(apiKey) > 8 {
c.Log.Infof("🔧 [MCP] OpenAI API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
c.BaseURL = customURL
c.Log.Infof("🔧 [MCP] OpenAI using custom BaseURL: %s", customURL)
} else {
c.Log.Infof("🔧 [MCP] OpenAI using default BaseURL: %s", c.BaseURL)
}
if customModel != "" {
c.Model = customModel
c.Log.Infof("🔧 [MCP] OpenAI using custom Model: %s", customModel)
} else {
c.Log.Infof("🔧 [MCP] OpenAI using default Model: %s", c.Model)
}
}
// OpenAI uses standard Bearer auth
func (c *OpenAIClient) SetAuthHeader(reqHeaders http.Header) {
c.Client.SetAuthHeader(reqHeaders)
}
+83
View File
@@ -0,0 +1,83 @@
package provider
import (
"testing"
"nofx/mcp"
)
func TestOptionsWithDeepSeekClient(t *testing.T) {
logger := mcp.NewNoopLogger()
client := NewDeepSeekClientWithOptions(
mcp.WithAPIKey("sk-deepseek-key"),
mcp.WithLogger(logger),
mcp.WithMaxTokens(5000),
)
dsClient := client.(*DeepSeekClient)
// Verify DeepSeek default values
if dsClient.Provider != mcp.ProviderDeepSeek {
t.Error("Provider should be DeepSeek")
}
if dsClient.BaseURL != mcp.DefaultDeepSeekBaseURL {
t.Error("BaseURL should be DeepSeek default")
}
if dsClient.Model != mcp.DefaultDeepSeekModel {
t.Error("Model should be DeepSeek default")
}
// Verify custom options
if dsClient.APIKey != "sk-deepseek-key" {
t.Error("APIKey should be set from options")
}
if dsClient.Log != logger {
t.Error("Log should be set from options")
}
if dsClient.MaxTokens != 5000 {
t.Error("MaxTokens should be 5000")
}
}
func TestOptionsWithQwenClient(t *testing.T) {
logger := mcp.NewNoopLogger()
client := NewQwenClientWithOptions(
mcp.WithAPIKey("sk-qwen-key"),
mcp.WithLogger(logger),
mcp.WithMaxTokens(6000),
)
qwenClient := client.(*QwenClient)
// Verify Qwen default values
if qwenClient.Provider != mcp.ProviderQwen {
t.Error("Provider should be Qwen")
}
if qwenClient.BaseURL != mcp.DefaultQwenBaseURL {
t.Error("BaseURL should be Qwen default")
}
if qwenClient.Model != mcp.DefaultQwenModel {
t.Error("Model should be Qwen default")
}
// Verify custom options
if qwenClient.APIKey != "sk-qwen-key" {
t.Error("APIKey should be set from options")
}
if qwenClient.Log != logger {
t.Error("Log should be set from options")
}
if qwenClient.MaxTokens != 6000 {
t.Error("MaxTokens should be 6000")
}
}
+74
View File
@@ -0,0 +1,74 @@
package provider
import (
"net/http"
"nofx/mcp"
)
const (
DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
DefaultQwenModel = "qwen3-max"
)
func init() {
mcp.RegisterProvider(mcp.ProviderQwen, func(opts ...mcp.ClientOption) mcp.AIClient {
return NewQwenClientWithOptions(opts...)
})
}
type QwenClient struct {
*mcp.Client
}
func (c *QwenClient) BaseClient() *mcp.Client { return c.Client }
// NewQwenClient creates Qwen client (backward compatible)
//
// Deprecated: Recommend using NewQwenClientWithOptions for better flexibility
func NewQwenClient() mcp.AIClient {
return NewQwenClientWithOptions()
}
// NewQwenClientWithOptions creates Qwen client (supports options pattern)
func NewQwenClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
qwenOpts := []mcp.ClientOption{
mcp.WithProvider(mcp.ProviderQwen),
mcp.WithModel(DefaultQwenModel),
mcp.WithBaseURL(DefaultQwenBaseURL),
}
allOpts := append(qwenOpts, opts...)
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
qwenClient := &QwenClient{
Client: baseClient,
}
baseClient.Hooks = qwenClient
return qwenClient
}
func (qwenClient *QwenClient) SetAPIKey(apiKey string, customURL string, customModel string) {
qwenClient.APIKey = apiKey
if len(apiKey) > 8 {
qwenClient.Log.Infof("🔧 [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
qwenClient.BaseURL = customURL
qwenClient.Log.Infof("🔧 [MCP] Qwen using custom BaseURL: %s", customURL)
} else {
qwenClient.Log.Infof("🔧 [MCP] Qwen using default BaseURL: %s", qwenClient.BaseURL)
}
if customModel != "" {
qwenClient.Model = customModel
qwenClient.Log.Infof("🔧 [MCP] Qwen using custom Model: %s", customModel)
} else {
qwenClient.Log.Infof("🔧 [MCP] Qwen using default Model: %s", qwenClient.Model)
}
}
func (qwenClient *QwenClient) SetAuthHeader(reqHeaders http.Header) {
qwenClient.Client.SetAuthHeader(reqHeaders)
}
+31
View File
@@ -0,0 +1,31 @@
package mcp
// Provider name constants — kept in the mcp package so that client.go can
// reference them for default configuration without importing sub-packages.
// Provider sub-packages re-use these same values.
const (
ProviderDeepSeek = "deepseek"
ProviderOpenAI = "openai"
ProviderClaude = "claude"
ProviderQwen = "qwen"
ProviderGemini = "gemini"
ProviderGrok = "grok"
ProviderKimi = "kimi"
ProviderMiniMax = "minimax"
ProviderBlockRunBase = "blockrun-base"
ProviderBlockRunSol = "blockrun-sol"
ProviderClaw402 = "claw402"
// Default DeepSeek configuration (used as fallback in NewClient)
DefaultDeepSeekBaseURL = "https://api.deepseek.com"
DefaultDeepSeekModel = "deepseek-chat"
// Default Qwen configuration (used by WithQwenConfig convenience option)
DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
DefaultQwenModel = "qwen3-max"
// Default MiniMax configuration (used by WithMiniMaxConfig convenience option)
DefaultMiniMaxBaseURL = "https://api.minimax.io/v1"
DefaultMiniMaxModel = "MiniMax-M2.5"
)
-83
View File
@@ -1,83 +0,0 @@
package mcp
import (
"net/http"
)
const (
ProviderQwen = "qwen"
DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
DefaultQwenModel = "qwen3-max"
)
type QwenClient struct {
*Client
}
// NewQwenClient creates Qwen client (backward compatible)
//
// Deprecated: Recommend using NewQwenClientWithOptions for better flexibility
func NewQwenClient() AIClient {
return NewQwenClientWithOptions()
}
// NewQwenClientWithOptions creates Qwen client (supports options pattern)
//
// Usage examples:
// // Basic usage
// client := mcp.NewQwenClientWithOptions()
//
// // Custom configuration
// client := mcp.NewQwenClientWithOptions(
// mcp.WithAPIKey("sk-xxx"),
// mcp.WithLogger(customLogger),
// mcp.WithTimeout(60*time.Second),
// )
func NewQwenClientWithOptions(opts ...ClientOption) AIClient {
// 1. Create Qwen preset options
qwenOpts := []ClientOption{
WithProvider(ProviderQwen),
WithModel(DefaultQwenModel),
WithBaseURL(DefaultQwenBaseURL),
}
// 2. Merge user options (user options have higher priority)
allOpts := append(qwenOpts, opts...)
// 3. Create base client
baseClient := NewClient(allOpts...).(*Client)
// 4. Create Qwen client
qwenClient := &QwenClient{
Client: baseClient,
}
// 5. Set hooks to point to QwenClient (implement dynamic dispatch)
baseClient.hooks = qwenClient
return qwenClient
}
func (qwenClient *QwenClient) SetAPIKey(apiKey string, customURL string, customModel string) {
qwenClient.APIKey = apiKey
if len(apiKey) > 8 {
qwenClient.logger.Infof("🔧 [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
qwenClient.BaseURL = customURL
qwenClient.logger.Infof("🔧 [MCP] Qwen using custom BaseURL: %s", customURL)
} else {
qwenClient.logger.Infof("🔧 [MCP] Qwen using default BaseURL: %s", qwenClient.BaseURL)
}
if customModel != "" {
qwenClient.Model = customModel
qwenClient.logger.Infof("🔧 [MCP] Qwen using custom Model: %s", customModel)
} else {
qwenClient.logger.Infof("🔧 [MCP] Qwen using default Model: %s", qwenClient.Model)
}
}
func (qwenClient *QwenClient) setAuthHeader(reqHeaders http.Header) {
qwenClient.Client.setAuthHeader(reqHeaders)
}
-272
View File
@@ -1,272 +0,0 @@
package mcp
import (
"testing"
"time"
)
// ============================================================
// Test QwenClient Creation and Configuration
// ============================================================
func TestNewQwenClient_Default(t *testing.T) {
client := NewQwenClient()
if client == nil {
t.Fatal("client should not be nil")
}
// Type assertion check
qwenClient, ok := client.(*QwenClient)
if !ok {
t.Fatal("client should be *QwenClient")
}
// Verify default values
if qwenClient.Provider != ProviderQwen {
t.Errorf("Provider should be '%s', got '%s'", ProviderQwen, qwenClient.Provider)
}
if qwenClient.BaseURL != DefaultQwenBaseURL {
t.Errorf("BaseURL should be '%s', got '%s'", DefaultQwenBaseURL, qwenClient.BaseURL)
}
if qwenClient.Model != DefaultQwenModel {
t.Errorf("Model should be '%s', got '%s'", DefaultQwenModel, qwenClient.Model)
}
if qwenClient.logger == nil {
t.Error("logger should not be nil")
}
if qwenClient.httpClient == nil {
t.Error("httpClient should not be nil")
}
}
func TestNewQwenClientWithOptions(t *testing.T) {
mockLogger := NewMockLogger()
customModel := "qwen-plus"
customAPIKey := "sk-custom-qwen-key"
client := NewQwenClientWithOptions(
WithLogger(mockLogger),
WithModel(customModel),
WithAPIKey(customAPIKey),
WithMaxTokens(4000),
)
qwenClient := client.(*QwenClient)
// Verify custom options are applied
if qwenClient.logger != mockLogger {
t.Error("logger should be set from option")
}
if qwenClient.Model != customModel {
t.Error("Model should be set from option")
}
if qwenClient.APIKey != customAPIKey {
t.Error("APIKey should be set from option")
}
if qwenClient.MaxTokens != 4000 {
t.Error("MaxTokens should be 4000")
}
// Verify Qwen default values are retained
if qwenClient.Provider != ProviderQwen {
t.Errorf("Provider should still be '%s'", ProviderQwen)
}
if qwenClient.BaseURL != DefaultQwenBaseURL {
t.Errorf("BaseURL should still be '%s'", DefaultQwenBaseURL)
}
}
// ============================================================
// Test SetAPIKey
// ============================================================
func TestQwenClient_SetAPIKey(t *testing.T) {
mockLogger := NewMockLogger()
client := NewQwenClientWithOptions(
WithLogger(mockLogger),
)
qwenClient := client.(*QwenClient)
// Test setting API Key (default URL and Model)
qwenClient.SetAPIKey("sk-test-key-12345678", "", "")
if qwenClient.APIKey != "sk-test-key-12345678" {
t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", qwenClient.APIKey)
}
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
if len(logs) == 0 {
t.Error("should have logged API key setting")
}
// Verify BaseURL and Model remain default
if qwenClient.BaseURL != DefaultQwenBaseURL {
t.Error("BaseURL should remain default")
}
if qwenClient.Model != DefaultQwenModel {
t.Error("Model should remain default")
}
}
func TestQwenClient_SetAPIKey_WithCustomURL(t *testing.T) {
mockLogger := NewMockLogger()
client := NewQwenClientWithOptions(
WithLogger(mockLogger),
)
qwenClient := client.(*QwenClient)
customURL := "https://custom.qwen.api.com/v1"
qwenClient.SetAPIKey("sk-test-key-12345678", customURL, "")
if qwenClient.BaseURL != customURL {
t.Errorf("BaseURL should be '%s', got '%s'", customURL, qwenClient.BaseURL)
}
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
hasCustomURLLog := false
for _, log := range logs {
if log.Format == "🔧 [MCP] Qwen using custom BaseURL: %s" {
hasCustomURLLog = true
break
}
}
if !hasCustomURLLog {
t.Error("should have logged custom BaseURL")
}
}
func TestQwenClient_SetAPIKey_WithCustomModel(t *testing.T) {
mockLogger := NewMockLogger()
client := NewQwenClientWithOptions(
WithLogger(mockLogger),
)
qwenClient := client.(*QwenClient)
customModel := "qwen-turbo"
qwenClient.SetAPIKey("sk-test-key-12345678", "", customModel)
if qwenClient.Model != customModel {
t.Errorf("Model should be '%s', got '%s'", customModel, qwenClient.Model)
}
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
hasCustomModelLog := false
for _, log := range logs {
if log.Format == "🔧 [MCP] Qwen using custom Model: %s" {
hasCustomModelLog = true
break
}
}
if !hasCustomModelLog {
t.Error("should have logged custom Model")
}
}
// ============================================================
// Test Integration Features
// ============================================================
func TestQwenClient_CallWithMessages_Success(t *testing.T) {
mockHTTP := NewMockHTTPClient()
mockHTTP.SetSuccessResponse("Qwen AI response")
mockLogger := NewMockLogger()
client := NewQwenClientWithOptions(
WithHTTPClient(mockHTTP.ToHTTPClient()),
WithLogger(mockLogger),
WithAPIKey("sk-test-key"),
)
result, err := client.CallWithMessages("system prompt", "user prompt")
if err != nil {
t.Fatalf("should not error: %v", err)
}
if result != "Qwen AI response" {
t.Errorf("expected 'Qwen AI response', got '%s'", result)
}
// Verify request
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Fatalf("expected 1 request, got %d", len(requests))
}
req := requests[0]
// Verify URL
expectedURL := DefaultQwenBaseURL + "/chat/completions"
if req.URL.String() != expectedURL {
t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String())
}
// Verify Authorization header
authHeader := req.Header.Get("Authorization")
if authHeader != "Bearer sk-test-key" {
t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader)
}
// Verify Content-Type
if req.Header.Get("Content-Type") != "application/json" {
t.Error("Content-Type should be application/json")
}
}
func TestQwenClient_Timeout(t *testing.T) {
client := NewQwenClientWithOptions(
WithTimeout(30 * time.Second),
)
qwenClient := client.(*QwenClient)
if qwenClient.httpClient.Timeout != 30*time.Second {
t.Errorf("expected timeout 30s, got %v", qwenClient.httpClient.Timeout)
}
// Test SetTimeout
client.SetTimeout(60 * time.Second)
if qwenClient.httpClient.Timeout != 60*time.Second {
t.Errorf("expected timeout 60s after SetTimeout, got %v", qwenClient.httpClient.Timeout)
}
}
// ============================================================
// Test hooks Mechanism
// ============================================================
func TestQwenClient_HooksIntegration(t *testing.T) {
client := NewQwenClientWithOptions()
qwenClient := client.(*QwenClient)
// Verify hooks point to qwenClient itself (implements polymorphism)
if qwenClient.hooks != qwenClient {
t.Error("hooks should point to qwenClient for polymorphism")
}
// Verify buildUrl uses Qwen configuration
url := qwenClient.buildUrl()
expectedURL := DefaultQwenBaseURL + "/chat/completions"
if url != expectedURL {
t.Errorf("expected URL '%s', got '%s'", expectedURL, url)
}
}
+20
View File
@@ -0,0 +1,20 @@
package mcp
// providerRegistry maps provider names to factory functions.
var providerRegistry = map[string]func(...ClientOption) AIClient{}
// RegisterProvider registers a provider factory function.
// Called by provider/payment sub-packages in their init() functions.
func RegisterProvider(name string, factory func(...ClientOption) AIClient) {
providerRegistry[name] = factory
}
// NewAIClientByProvider creates an AIClient by provider name using the registry.
// Returns nil if the provider is not registered.
func NewAIClientByProvider(name string, opts ...ClientOption) AIClient {
factory, ok := providerRegistry[name]
if !ok {
return nil
}
return factory(opts...)
}
+2 -2
View File
@@ -450,10 +450,10 @@ func TestClient_CallWithRequest_UsesClientModel(t *testing.T) {
mockHTTP.SetSuccessResponse("Response")
mockLogger := NewMockLogger()
client := NewDeepSeekClientWithOptions(
client := NewClient(
WithDeepSeekConfig("sk-test-key"),
WithHTTPClient(mockHTTP.ToHTTPClient()),
WithLogger(mockLogger),
WithAPIKey("sk-test-key"),
)
// Request does not set model, should use Client's model
+6 -25
View File
@@ -5,6 +5,8 @@ import (
"nofx/config"
"nofx/logger"
"nofx/mcp"
_ "nofx/mcp/payment"
_ "nofx/mcp/provider"
"nofx/store"
"nofx/telegram/agent"
"os"
@@ -319,32 +321,11 @@ func isUSDCProvider(provider string) bool {
}
func clientForProvider(provider string) mcp.AIClient {
switch provider {
case "openai":
return mcp.NewOpenAIClient()
case "deepseek":
return mcp.NewDeepSeekClient()
case "claude":
return mcp.NewClaudeClient()
case "qwen":
return mcp.NewQwenClient()
case "kimi":
return mcp.NewKimiClient()
case "grok":
return mcp.NewGrokClient()
case "gemini":
return mcp.NewGeminiClient()
case "minimax":
return mcp.NewMiniMaxClient()
case "blockrun-base":
return mcp.NewBlockRunBaseClient()
case "blockrun-sol":
return mcp.NewBlockRunSolClient()
case "claw402":
return mcp.NewClaw402Client()
default:
return mcp.NewDeepSeekClient()
client := mcp.NewAIClientByProvider(provider)
if client == nil {
client = mcp.NewAIClientByProvider("deepseek")
}
return client
}
// ── Status message ────────────────────────────────────────────────────────────
+33 -1771
View File
File diff suppressed because it is too large Load Diff
+527
View File
@@ -0,0 +1,527 @@
package trader
import (
"fmt"
"math"
"nofx/experience"
"nofx/kernel"
"nofx/logger"
"nofx/market"
"nofx/store"
"time"
)
// saveEquitySnapshot saves equity snapshot independently (for drawing profit curve, decoupled from AI decision)
func (at *AutoTrader) saveEquitySnapshot(ctx *kernel.Context) {
if at.store == nil || ctx == nil {
return
}
snapshot := &store.EquitySnapshot{
TraderID: at.id,
Timestamp: time.Now().UTC(),
TotalEquity: ctx.Account.TotalEquity,
Balance: ctx.Account.TotalEquity - ctx.Account.UnrealizedPnL,
UnrealizedPnL: ctx.Account.UnrealizedPnL,
PositionCount: ctx.Account.PositionCount,
MarginUsedPct: ctx.Account.MarginUsedPct,
}
if err := at.store.Equity().Save(snapshot); err != nil {
logger.Infof("⚠️ Failed to save equity snapshot: %v", err)
}
}
// saveDecision saves AI decision log to database (only records AI input/output, for debugging)
func (at *AutoTrader) saveDecision(record *store.DecisionRecord) error {
if at.store == nil {
return nil
}
at.cycleNumber++
record.CycleNumber = at.cycleNumber
record.TraderID = at.id
if record.Timestamp.IsZero() {
record.Timestamp = time.Now().UTC()
}
if err := at.store.Decision().LogDecision(record); err != nil {
logger.Infof("⚠️ Failed to save decision record: %v", err)
return err
}
logger.Infof("📝 Decision record saved: trader=%s, cycle=%d", at.id, at.cycleNumber)
return nil
}
// GetStatus gets system status (for API)
func (at *AutoTrader) GetStatus() map[string]interface{} {
aiProvider := "DeepSeek"
if at.config.UseQwen {
aiProvider = "Qwen"
}
at.isRunningMutex.RLock()
isRunning := at.isRunning
at.isRunningMutex.RUnlock()
result := map[string]interface{}{
"trader_id": at.id,
"trader_name": at.name,
"ai_model": at.aiModel,
"exchange": at.exchange,
"is_running": isRunning,
"start_time": at.startTime.Format(time.RFC3339),
"runtime_minutes": int(time.Since(at.startTime).Minutes()),
"call_count": at.callCount,
"initial_balance": at.initialBalance,
"scan_interval": at.config.ScanInterval.String(),
"stop_until": at.stopUntil.Format(time.RFC3339),
"last_reset_time": at.lastResetTime.Format(time.RFC3339),
"ai_provider": aiProvider,
}
// Add strategy info
if at.config.StrategyConfig != nil {
result["strategy_type"] = at.config.StrategyConfig.StrategyType
if at.config.StrategyConfig.GridConfig != nil {
result["grid_symbol"] = at.config.StrategyConfig.GridConfig.Symbol
}
}
return result
}
// GetAccountInfo gets account information (for API)
func (at *AutoTrader) GetAccountInfo() (map[string]interface{}, error) {
balance, err := at.trader.GetBalance()
if err != nil {
return nil, fmt.Errorf("failed to get balance: %w", err)
}
// Get account fields
totalWalletBalance := 0.0
totalUnrealizedProfit := 0.0
availableBalance := 0.0
totalEquity := 0.0
if wallet, ok := balance["totalWalletBalance"].(float64); ok {
totalWalletBalance = wallet
}
if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok {
totalUnrealizedProfit = unrealized
}
if avail, ok := balance["availableBalance"].(float64); ok {
availableBalance = avail
}
// Use totalEquity directly if provided by trader (more accurate)
if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 {
totalEquity = eq
} else {
// Fallback: Total Equity = Wallet balance + Unrealized profit
totalEquity = totalWalletBalance + totalUnrealizedProfit
}
// Get positions to calculate total margin
positions, err := at.trader.GetPositions()
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
totalMarginUsed := 0.0
totalUnrealizedPnLCalculated := 0.0
for _, pos := range positions {
markPrice := pos["markPrice"].(float64)
quantity := pos["positionAmt"].(float64)
if quantity < 0 {
quantity = -quantity
}
unrealizedPnl := pos["unRealizedProfit"].(float64)
totalUnrealizedPnLCalculated += unrealizedPnl
leverage := 10
if lev, ok := pos["leverage"].(float64); ok {
leverage = int(lev)
}
marginUsed := (quantity * markPrice) / float64(leverage)
totalMarginUsed += marginUsed
}
// Verify unrealized P&L consistency (API value vs calculated from positions)
// Note: Lighter API may return 0 for unrealized PnL, this is a known limitation
diff := math.Abs(totalUnrealizedProfit - totalUnrealizedPnLCalculated)
if diff > 5.0 { // Only warn if difference is significant (> 5 USDT)
logger.Infof("⚠️ Unrealized P&L inconsistency (Lighter API limitation): API=%.4f, Calculated=%.4f, Diff=%.4f",
totalUnrealizedProfit, totalUnrealizedPnLCalculated, diff)
}
totalPnL := totalEquity - at.initialBalance
totalPnLPct := 0.0
if at.initialBalance > 0 {
totalPnLPct = (totalPnL / at.initialBalance) * 100
} else {
logger.Infof("⚠️ Initial Balance abnormal: %.2f, cannot calculate P&L percentage", at.initialBalance)
}
marginUsedPct := 0.0
if totalEquity > 0 {
marginUsedPct = (totalMarginUsed / totalEquity) * 100
}
return map[string]interface{}{
// Core fields
"total_equity": totalEquity, // Account equity = wallet + unrealized
"wallet_balance": totalWalletBalance, // Wallet balance (excluding unrealized P&L)
"unrealized_profit": totalUnrealizedProfit, // Unrealized P&L (official value from exchange API)
"available_balance": availableBalance, // Available balance
// P&L statistics
"total_pnl": totalPnL, // Total P&L = equity - initial
"total_pnl_pct": totalPnLPct, // Total P&L percentage
"initial_balance": at.initialBalance, // Initial balance
"daily_pnl": at.dailyPnL, // Daily P&L
// Position information
"position_count": len(positions), // Position count
"margin_used": totalMarginUsed, // Margin used
"margin_used_pct": marginUsedPct, // Margin usage rate
}, nil
}
// GetPositions gets position list (for API)
func (at *AutoTrader) GetPositions() ([]map[string]interface{}, error) {
positions, err := at.trader.GetPositions()
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
var result []map[string]interface{}
for _, pos := range positions {
symbol := pos["symbol"].(string)
side := pos["side"].(string)
entryPrice := pos["entryPrice"].(float64)
markPrice := pos["markPrice"].(float64)
quantity := pos["positionAmt"].(float64)
if quantity < 0 {
quantity = -quantity
}
unrealizedPnl := pos["unRealizedProfit"].(float64)
liquidationPrice := pos["liquidationPrice"].(float64)
leverage := 10
if lev, ok := pos["leverage"].(float64); ok {
leverage = int(lev)
}
// Calculate margin used
marginUsed := (quantity * markPrice) / float64(leverage)
// Calculate P&L percentage (based on margin)
pnlPct := calculatePnLPercentage(unrealizedPnl, marginUsed)
result = append(result, map[string]interface{}{
"symbol": symbol,
"side": side,
"entry_price": entryPrice,
"mark_price": markPrice,
"quantity": quantity,
"leverage": leverage,
"unrealized_pnl": unrealizedPnl,
"unrealized_pnl_pct": pnlPct,
"liquidation_price": liquidationPrice,
"margin_used": marginUsed,
})
}
return result, nil
}
// recordAndConfirmOrder polls order status for actual fill data and records position
// action: open_long, open_short, close_long, close_short
// entryPrice: entry price when closing (0 when opening)
func (at *AutoTrader) recordAndConfirmOrder(orderResult map[string]interface{}, symbol, action string, quantity float64, price float64, leverage int, entryPrice float64) {
if at.store == nil {
return
}
// Get order ID (supports multiple types)
var orderID string
switch v := orderResult["orderId"].(type) {
case int64:
orderID = fmt.Sprintf("%d", v)
case float64:
orderID = fmt.Sprintf("%.0f", v)
case string:
orderID = v
default:
orderID = fmt.Sprintf("%v", v)
}
if orderID == "" || orderID == "0" {
logger.Infof(" ⚠️ Order ID is empty, skipping record")
return
}
// Determine positionSide
var positionSide string
switch action {
case "open_long", "close_long":
positionSide = "LONG"
case "open_short", "close_short":
positionSide = "SHORT"
}
var actualPrice = price
var actualQty = quantity
var fee float64
// Exchanges with OrderSync: Skip immediate order recording, let OrderSync handle it
// This ensures accurate data from GetTrades API and avoids duplicate records
switch at.exchange {
case "binance", "lighter", "hyperliquid", "bybit", "okx", "bitget", "aster", "kucoin", "gate":
logger.Infof(" 📝 Order submitted (id: %s), will be synced by OrderSync", orderID)
return
}
// For exchanges without OrderSync (e.g., Binance): record immediately and poll for fill data
orderRecord := at.createOrderRecord(orderID, symbol, action, positionSide, quantity, price, leverage)
if err := at.store.Order().CreateOrder(orderRecord); err != nil {
logger.Infof(" ⚠️ Failed to record order: %v", err)
} else {
logger.Infof(" 📝 Order recorded: %s [%s] %s", orderID, action, symbol)
}
// Wait for order to be filled and get actual fill data
time.Sleep(500 * time.Millisecond)
for i := 0; i < 5; i++ {
status, err := at.trader.GetOrderStatus(symbol, orderID)
if err == nil {
statusStr, _ := status["status"].(string)
if statusStr == "FILLED" {
// Get actual fill price
if avgPrice, ok := status["avgPrice"].(float64); ok && avgPrice > 0 {
actualPrice = avgPrice
}
// Get actual executed quantity
if execQty, ok := status["executedQty"].(float64); ok && execQty > 0 {
actualQty = execQty
}
// Get commission/fee
if commission, ok := status["commission"].(float64); ok {
fee = commission
}
logger.Infof(" ✅ Order filled: avgPrice=%.6f, qty=%.6f, fee=%.6f", actualPrice, actualQty, fee)
// Update order status to FILLED
if err := at.store.Order().UpdateOrderStatus(orderRecord.ID, "FILLED", actualQty, actualPrice, fee); err != nil {
logger.Infof(" ⚠️ Failed to update order status: %v", err)
}
// Record fill details
at.recordOrderFill(orderRecord.ID, orderID, symbol, action, actualPrice, actualQty, fee)
break
} else if statusStr == "CANCELED" || statusStr == "EXPIRED" || statusStr == "REJECTED" {
logger.Infof(" ⚠️ Order %s, skipping position record", statusStr)
// Update order status
if err := at.store.Order().UpdateOrderStatus(orderRecord.ID, statusStr, 0, 0, 0); err != nil {
logger.Infof(" ⚠️ Failed to update order status: %v", err)
}
return
}
}
time.Sleep(500 * time.Millisecond)
}
// Normalize symbol for position record consistency
normalizedSymbolForPosition := market.Normalize(symbol)
logger.Infof(" 📝 Recording position (ID: %s, action: %s, price: %.6f, qty: %.6f, fee: %.4f)",
orderID, action, actualPrice, actualQty, fee)
// Record position change with actual fill data (use normalized symbol)
at.recordPositionChange(orderID, normalizedSymbolForPosition, positionSide, action, actualQty, actualPrice, leverage, entryPrice, fee)
// Send anonymous trade statistics for experience improvement (async, non-blocking)
// This helps us understand overall product usage across all deployments
experience.TrackTrade(experience.TradeEvent{
Exchange: at.exchange,
TradeType: action,
Symbol: symbol,
AmountUSD: actualPrice * actualQty,
Leverage: leverage,
UserID: at.userID,
TraderID: at.id,
})
}
// recordPositionChange records position change (create record on open, update record on close)
func (at *AutoTrader) recordPositionChange(orderID, symbol, side, action string, quantity, price float64, leverage int, entryPrice float64, fee float64) {
if at.store == nil {
return
}
switch action {
case "open_long", "open_short":
// Open position: create new position record
nowMs := time.Now().UTC().UnixMilli()
pos := &store.TraderPosition{
TraderID: at.id,
ExchangeID: at.exchangeID, // Exchange account UUID
ExchangeType: at.exchange, // Exchange type: binance/bybit/okx/etc
Symbol: symbol,
Side: side, // LONG or SHORT
Quantity: quantity,
EntryPrice: price,
EntryOrderID: orderID,
EntryTime: nowMs,
Leverage: leverage,
Status: "OPEN",
CreatedAt: nowMs,
UpdatedAt: nowMs,
}
if err := at.store.Position().Create(pos); err != nil {
logger.Infof(" ⚠️ Failed to record position: %v", err)
} else {
logger.Infof(" 📊 Position recorded [%s] %s %s @ %.4f", at.id[:8], symbol, side, price)
}
case "close_long", "close_short":
// Close position using PositionBuilder for consistent handling
// PositionBuilder will handle both cases:
// 1. If open position exists: close it properly
// 2. If no open position (e.g., table cleared): create a closed position record
posBuilder := store.NewPositionBuilder(at.store.Position())
if err := posBuilder.ProcessTrade(
at.id, at.exchangeID, at.exchange,
symbol, side, action,
quantity, price, fee, 0, // realizedPnL will be calculated
time.Now().UTC().UnixMilli(), orderID,
); err != nil {
logger.Infof(" ⚠️ Failed to process close position: %v", err)
} else {
logger.Infof(" ✅ Position closed [%s] %s %s @ %.4f", at.id[:8], symbol, side, price)
}
}
}
// createOrderRecord creates an order record struct from order details
func (at *AutoTrader) createOrderRecord(orderID, symbol, action, positionSide string, quantity, price float64, leverage int) *store.TraderOrder {
// Determine order type (market for auto trader)
orderType := "MARKET"
// Determine side (BUY/SELL)
var side string
switch action {
case "open_long", "close_short":
side = "BUY"
case "open_short", "close_long":
side = "SELL"
}
// Use action as orderAction directly (keep lowercase format)
orderAction := action
// Determine if it's a reduce only order
reduceOnly := (action == "close_long" || action == "close_short")
// Normalize symbol for consistency
normalizedSymbol := market.Normalize(symbol)
return &store.TraderOrder{
TraderID: at.id,
ExchangeID: at.exchangeID,
ExchangeType: at.exchange,
ExchangeOrderID: orderID,
Symbol: normalizedSymbol,
Side: side,
PositionSide: positionSide,
Type: orderType,
TimeInForce: "GTC",
Quantity: quantity,
Price: price,
Status: "NEW",
FilledQuantity: 0,
AvgFillPrice: 0,
Commission: 0,
CommissionAsset: "USDT",
Leverage: leverage,
ReduceOnly: reduceOnly,
ClosePosition: reduceOnly,
OrderAction: orderAction,
CreatedAt: time.Now().UTC().UnixMilli(),
UpdatedAt: time.Now().UTC().UnixMilli(),
}
}
// recordOrderFill records order fill/trade details
func (at *AutoTrader) recordOrderFill(orderRecordID int64, exchangeOrderID, symbol, action string, price, quantity, fee float64) {
if at.store == nil {
return
}
// Determine side (BUY/SELL)
var side string
switch action {
case "open_long", "close_short":
side = "BUY"
case "open_short", "close_long":
side = "SELL"
}
// Generate a simple trade ID (exchange doesn't always provide one)
tradeID := fmt.Sprintf("%s-%d", exchangeOrderID, time.Now().UnixNano())
// Normalize symbol for consistency
normalizedSymbol := market.Normalize(symbol)
fill := &store.TraderFill{
TraderID: at.id,
ExchangeID: at.exchangeID,
ExchangeType: at.exchange,
OrderID: orderRecordID,
ExchangeOrderID: exchangeOrderID,
ExchangeTradeID: tradeID,
Symbol: normalizedSymbol,
Side: side,
Price: price,
Quantity: quantity,
QuoteQuantity: price * quantity,
Commission: fee,
CommissionAsset: "USDT",
RealizedPnL: 0, // Will be calculated for close orders
IsMaker: false, // Market orders are usually taker
CreatedAt: time.Now().UTC().UnixMilli(),
}
// Calculate realized PnL for close orders
if action == "close_long" || action == "close_short" {
// Try to get the entry price from the open position
var positionSide string
if action == "close_long" {
positionSide = "LONG"
} else {
positionSide = "SHORT"
}
if openPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, symbol, positionSide); err == nil && openPos != nil {
if positionSide == "LONG" {
fill.RealizedPnL = (price - openPos.EntryPrice) * quantity
} else {
fill.RealizedPnL = (openPos.EntryPrice - price) * quantity
}
}
}
if err := at.store.Order().CreateFill(fill); err != nil {
logger.Infof(" ⚠️ Failed to record fill: %v", err)
} else {
logger.Infof(" 📋 Fill recorded: %.4f @ %.6f, fee: %.4f", quantity, price, fee)
}
}
// GetOpenOrders returns open orders (pending SL/TP) from exchange
func (at *AutoTrader) GetOpenOrders(symbol string) ([]OpenOrder, error) {
return at.trader.GetOpenOrders(symbol)
}
+560
View File
@@ -0,0 +1,560 @@
package trader
import (
"encoding/json"
"fmt"
"nofx/kernel"
"nofx/logger"
"nofx/store"
"strings"
"time"
)
// runCycle runs one trading cycle (using AI full decision-making)
func (at *AutoTrader) runCycle() error {
at.callCount++
logger.Info("\n" + strings.Repeat("=", 70) + "\n")
logger.Infof("⏰ %s - AI decision cycle #%d", time.Now().Format("2006-01-02 15:04:05"), at.callCount)
logger.Info(strings.Repeat("=", 70))
// 0. Check if trader is stopped (early exit to prevent trades after Stop() is called)
at.isRunningMutex.RLock()
running := at.isRunning
at.isRunningMutex.RUnlock()
if !running {
logger.Infof("⏹ Trader is stopped, aborting cycle #%d", at.callCount)
return nil
}
// Create decision record
record := &store.DecisionRecord{
ExecutionLog: []string{},
Success: true,
}
// 1. Check if trading needs to be stopped
if time.Now().Before(at.stopUntil) {
remaining := at.stopUntil.Sub(time.Now())
logger.Infof("⏸ Risk control: Trading paused, remaining %.0f minutes", remaining.Minutes())
record.Success = false
record.ErrorMessage = fmt.Sprintf("Risk control paused, remaining %.0f minutes", remaining.Minutes())
at.saveDecision(record)
return nil
}
// 2. Reset daily P&L (reset every day)
if time.Since(at.lastResetTime) > 24*time.Hour {
at.dailyPnL = 0
at.lastResetTime = time.Now()
logger.Info("📅 Daily P&L reset")
}
// 4. Collect trading context
ctx, err := at.buildTradingContext()
if err != nil {
record.Success = false
record.ErrorMessage = fmt.Sprintf("Failed to build trading context: %v", err)
at.saveDecision(record)
return fmt.Errorf("failed to build trading context: %w", err)
}
// Save equity snapshot independently (decoupled from AI decision, used for drawing profit curve)
// NOTE: Must be called BEFORE candidate coins check to ensure equity is always recorded
at.saveEquitySnapshot(ctx)
// 如果没有候选币种,记录但不报错
if len(ctx.CandidateCoins) == 0 {
logger.Infof("️ No candidate coins available, skipping this cycle")
record.Success = true // 不是错误,只是没有候选币
record.ExecutionLog = append(record.ExecutionLog, "No candidate coins available, cycle skipped")
record.AccountState = store.AccountSnapshot{
TotalBalance: ctx.Account.TotalEquity,
AvailableBalance: ctx.Account.AvailableBalance,
TotalUnrealizedProfit: ctx.Account.UnrealizedPnL,
PositionCount: ctx.Account.PositionCount,
InitialBalance: at.initialBalance,
}
at.saveDecision(record)
return nil
}
logger.Info(strings.Repeat("=", 70))
for _, coin := range ctx.CandidateCoins {
record.CandidateCoins = append(record.CandidateCoins, coin.Symbol)
}
logger.Infof("📊 Account equity: %.2f USDT | Available: %.2f USDT | Positions: %d",
ctx.Account.TotalEquity, ctx.Account.AvailableBalance, ctx.Account.PositionCount)
// 5. Use strategy engine to call AI for decision
logger.Infof("🤖 Requesting AI analysis and decision... [Strategy Engine]")
aiDecision, err := kernel.GetFullDecisionWithStrategy(ctx, at.mcpClient, at.strategyEngine, "balanced")
if aiDecision != nil && aiDecision.AIRequestDurationMs > 0 {
record.AIRequestDurationMs = aiDecision.AIRequestDurationMs
logger.Infof("⏱️ AI call duration: %.2f seconds", float64(record.AIRequestDurationMs)/1000)
record.ExecutionLog = append(record.ExecutionLog,
fmt.Sprintf("AI call duration: %d ms", record.AIRequestDurationMs))
}
// Save chain of thought, decisions, and input prompt even if there's an error (for debugging)
if aiDecision != nil {
record.SystemPrompt = aiDecision.SystemPrompt // Save system prompt
record.InputPrompt = aiDecision.UserPrompt
record.CoTTrace = aiDecision.CoTTrace
record.RawResponse = aiDecision.RawResponse // Save raw AI response for debugging
if len(aiDecision.Decisions) > 0 {
decisionJSON, _ := json.MarshalIndent(aiDecision.Decisions, "", " ")
record.DecisionJSON = string(decisionJSON)
}
}
if err != nil {
record.Success = false
record.ErrorMessage = fmt.Sprintf("Failed to get AI decision: %v", err)
// Print system prompt and AI chain of thought (output even with errors for debugging)
if aiDecision != nil {
logger.Info("\n" + strings.Repeat("=", 70) + "\n")
logger.Infof("📋 System prompt (error case)")
logger.Info(strings.Repeat("=", 70))
logger.Info(aiDecision.SystemPrompt)
logger.Info(strings.Repeat("=", 70))
if aiDecision.CoTTrace != "" {
logger.Info("\n" + strings.Repeat("-", 70) + "\n")
logger.Info("💭 AI chain of thought analysis (error case):")
logger.Info(strings.Repeat("-", 70))
logger.Info(aiDecision.CoTTrace)
logger.Info(strings.Repeat("-", 70))
}
}
at.saveDecision(record)
return fmt.Errorf("failed to get AI decision: %w", err)
}
// // 5. Print system prompt
// logger.Infof("\n" + strings.Repeat("=", 70))
// logger.Infof("📋 System prompt [template: %s]", at.systemPromptTemplate)
// logger.Info(strings.Repeat("=", 70))
// logger.Info(decision.SystemPrompt)
// logger.Infof(strings.Repeat("=", 70) + "\n")
// 6. Print AI chain of thought
// logger.Infof("\n" + strings.Repeat("-", 70))
// logger.Info("💭 AI chain of thought analysis:")
// logger.Info(strings.Repeat("-", 70))
// logger.Info(decision.CoTTrace)
// logger.Infof(strings.Repeat("-", 70) + "\n")
// 7. Print AI decisions
// logger.Infof("📋 AI decision list (%d items):\n", len(kernel.Decisions))
// for i, d := range kernel.Decisions {
// logger.Infof(" [%d] %s: %s - %s", i+1, d.Symbol, d.Action, d.Reasoning)
// if d.Action == "open_long" || d.Action == "open_short" {
// logger.Infof(" Leverage: %dx | Position: %.2f USDT | Stop loss: %.4f | Take profit: %.4f",
// d.Leverage, d.PositionSizeUSD, d.StopLoss, d.TakeProfit)
// }
// }
logger.Info()
logger.Info(strings.Repeat("-", 70))
// 8. Sort decisions: ensure close positions first, then open positions (prevent position stacking overflow)
logger.Info(strings.Repeat("-", 70))
// 8. Sort decisions: ensure close positions first, then open positions (prevent position stacking overflow)
sortedDecisions := sortDecisionsByPriority(aiDecision.Decisions)
logger.Info("🔄 Execution order (optimized): Close positions first → Open positions later")
for i, d := range sortedDecisions {
logger.Infof(" [%d] %s %s", i+1, d.Symbol, d.Action)
}
logger.Info()
// Check if trader is stopped before executing any decisions (prevent trades after Stop())
at.isRunningMutex.RLock()
running = at.isRunning
at.isRunningMutex.RUnlock()
if !running {
logger.Infof("⏹ Trader stopped before decision execution, aborting cycle #%d", at.callCount)
return nil
}
// Execute decisions and record results
for _, d := range sortedDecisions {
// Check if trader is stopped before each decision (allow immediate stop during execution)
at.isRunningMutex.RLock()
running = at.isRunning
at.isRunningMutex.RUnlock()
if !running {
logger.Infof("⏹ Trader stopped during decision execution, aborting remaining decisions")
break
}
actionRecord := store.DecisionAction{
Action: d.Action,
Symbol: d.Symbol,
Quantity: 0,
Leverage: d.Leverage,
Price: 0,
StopLoss: d.StopLoss,
TakeProfit: d.TakeProfit,
Confidence: d.Confidence,
Reasoning: d.Reasoning,
Timestamp: time.Now().UTC(),
Success: false,
}
if err := at.executeDecisionWithRecord(&d, &actionRecord); err != nil {
logger.Infof("❌ Failed to execute decision (%s %s): %v", d.Symbol, d.Action, err)
actionRecord.Error = err.Error()
record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("❌ %s %s failed: %v", d.Symbol, d.Action, err))
} else {
actionRecord.Success = true
record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("✓ %s %s succeeded", d.Symbol, d.Action))
// Brief delay after successful execution
time.Sleep(1 * time.Second)
}
record.Decisions = append(record.Decisions, actionRecord)
}
// 9. Save decision record
if err := at.saveDecision(record); err != nil {
logger.Infof("⚠ Failed to save decision record: %v", err)
}
return nil
}
// buildTradingContext builds trading context
func (at *AutoTrader) buildTradingContext() (*kernel.Context, error) {
// 1. Get account information
balance, err := at.trader.GetBalance()
if err != nil {
return nil, fmt.Errorf("failed to get account balance: %w", err)
}
// Get account fields
totalWalletBalance := 0.0
totalUnrealizedProfit := 0.0
availableBalance := 0.0
totalEquity := 0.0
if wallet, ok := balance["totalWalletBalance"].(float64); ok {
totalWalletBalance = wallet
}
if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok {
totalUnrealizedProfit = unrealized
}
if avail, ok := balance["availableBalance"].(float64); ok {
availableBalance = avail
}
// Use totalEquity directly if provided by trader (more accurate)
if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 {
totalEquity = eq
} else {
// Fallback: Total Equity = Wallet balance + Unrealized profit
totalEquity = totalWalletBalance + totalUnrealizedProfit
}
// 2. Get position information
positions, err := at.trader.GetPositions()
if err != nil {
return nil, fmt.Errorf("failed to get positions: %w", err)
}
var positionInfos []kernel.PositionInfo
totalMarginUsed := 0.0
// Current position key set (for cleaning up closed position records)
currentPositionKeys := make(map[string]bool)
for _, pos := range positions {
symbol := pos["symbol"].(string)
side := pos["side"].(string)
entryPrice := pos["entryPrice"].(float64)
markPrice := pos["markPrice"].(float64)
quantity := pos["positionAmt"].(float64)
if quantity < 0 {
quantity = -quantity // Short position quantity is negative, convert to positive
}
// Skip closed positions (quantity = 0), prevent "ghost positions" from being passed to AI
if quantity == 0 {
continue
}
unrealizedPnl := pos["unRealizedProfit"].(float64)
liquidationPrice := pos["liquidationPrice"].(float64)
// Calculate margin used (estimated)
leverage := 10 // Default value, should actually be fetched from position info
if lev, ok := pos["leverage"].(float64); ok {
leverage = int(lev)
}
marginUsed := (quantity * markPrice) / float64(leverage)
totalMarginUsed += marginUsed
// Calculate P&L percentage (based on margin, considering leverage)
pnlPct := calculatePnLPercentage(unrealizedPnl, marginUsed)
// Get position open time from exchange (preferred) or fallback to local tracking
posKey := symbol + "_" + side
currentPositionKeys[posKey] = true
var updateTime int64
// Priority 1: Get from database (trader_positions table) - most accurate
if at.store != nil {
if dbPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, symbol, side); err == nil && dbPos != nil {
if dbPos.EntryTime > 0 {
updateTime = dbPos.EntryTime
}
}
}
// Priority 2: Get from exchange API (Bybit: createdTime, OKX: createdTime)
if updateTime == 0 {
if createdTime, ok := pos["createdTime"].(int64); ok && createdTime > 0 {
updateTime = createdTime
}
}
// Priority 3: Fallback to local tracking
if updateTime == 0 {
if _, exists := at.positionFirstSeenTime[posKey]; !exists {
at.positionFirstSeenTime[posKey] = time.Now().UnixMilli()
}
updateTime = at.positionFirstSeenTime[posKey]
}
// Get peak profit rate for this position
at.peakPnLCacheMutex.RLock()
peakPnlPct := at.peakPnLCache[posKey]
at.peakPnLCacheMutex.RUnlock()
positionInfos = append(positionInfos, kernel.PositionInfo{
Symbol: symbol,
Side: side,
EntryPrice: entryPrice,
MarkPrice: markPrice,
Quantity: quantity,
Leverage: leverage,
UnrealizedPnL: unrealizedPnl,
UnrealizedPnLPct: pnlPct,
PeakPnLPct: peakPnlPct,
LiquidationPrice: liquidationPrice,
MarginUsed: marginUsed,
UpdateTime: updateTime,
})
}
// Clean up closed position records
for key := range at.positionFirstSeenTime {
if !currentPositionKeys[key] {
delete(at.positionFirstSeenTime, key)
}
}
// 3. Use strategy engine to get candidate coins (must have strategy engine)
var candidateCoins []kernel.CandidateCoin
if at.strategyEngine == nil {
logger.Infof("⚠️ [%s] No strategy engine configured, skipping candidate coins", at.name)
} else {
coins, err := at.strategyEngine.GetCandidateCoins()
if err != nil {
// Log warning but don't fail - equity snapshot should still be saved
logger.Infof("⚠️ [%s] Failed to get candidate coins: %v (will use empty list)", at.name, err)
} else {
candidateCoins = coins
logger.Infof("📋 [%s] Strategy engine fetched candidate coins: %d", at.name, len(candidateCoins))
}
}
// 4. Calculate total P&L
totalPnL := totalEquity - at.initialBalance
totalPnLPct := 0.0
if at.initialBalance > 0 {
totalPnLPct = (totalPnL / at.initialBalance) * 100
}
marginUsedPct := 0.0
if totalEquity > 0 {
marginUsedPct = (totalMarginUsed / totalEquity) * 100
}
// 5. Get leverage from strategy config
strategyConfig := at.strategyEngine.GetConfig()
btcEthLeverage := strategyConfig.RiskControl.BTCETHMaxLeverage
altcoinLeverage := strategyConfig.RiskControl.AltcoinMaxLeverage
logger.Infof("📋 [%s] Strategy leverage config: BTC/ETH=%dx, Altcoin=%dx", at.name, btcEthLeverage, altcoinLeverage)
// 6. Build context
ctx := &kernel.Context{
CurrentTime: time.Now().UTC().Format("2006-01-02 15:04:05 UTC"),
RuntimeMinutes: int(time.Since(at.startTime).Minutes()),
CallCount: at.callCount,
BTCETHLeverage: btcEthLeverage,
AltcoinLeverage: altcoinLeverage,
Account: kernel.AccountInfo{
TotalEquity: totalEquity,
AvailableBalance: availableBalance,
UnrealizedPnL: totalUnrealizedProfit,
TotalPnL: totalPnL,
TotalPnLPct: totalPnLPct,
MarginUsed: totalMarginUsed,
MarginUsedPct: marginUsedPct,
PositionCount: len(positionInfos),
},
Positions: positionInfos,
CandidateCoins: candidateCoins,
}
// 7. Add recent closed trades (if store is available)
if at.store != nil {
// Get recent 10 closed trades for AI context
recentTrades, err := at.store.Position().GetRecentTrades(at.id, 10)
if err != nil {
logger.Infof("⚠️ [%s] Failed to get recent trades: %v", at.name, err)
} else {
logger.Infof("📊 [%s] Found %d recent closed trades for AI context", at.name, len(recentTrades))
for _, trade := range recentTrades {
// Convert Unix timestamps to formatted strings for AI readability
entryTimeStr := ""
if trade.EntryTime > 0 {
entryTimeStr = time.Unix(trade.EntryTime, 0).UTC().Format("01-02 15:04 UTC")
}
exitTimeStr := ""
if trade.ExitTime > 0 {
exitTimeStr = time.Unix(trade.ExitTime, 0).UTC().Format("01-02 15:04 UTC")
}
ctx.RecentOrders = append(ctx.RecentOrders, kernel.RecentOrder{
Symbol: trade.Symbol,
Side: trade.Side,
EntryPrice: trade.EntryPrice,
ExitPrice: trade.ExitPrice,
RealizedPnL: trade.RealizedPnL,
PnLPct: trade.PnLPct,
EntryTime: entryTimeStr,
ExitTime: exitTimeStr,
HoldDuration: trade.HoldDuration,
})
}
}
// Get trading statistics for AI context
stats, err := at.store.Position().GetFullStats(at.id)
if err != nil {
logger.Infof("⚠️ [%s] Failed to get trading stats: %v", at.name, err)
} else if stats == nil {
logger.Infof("⚠️ [%s] GetFullStats returned nil", at.name)
} else if stats.TotalTrades == 0 {
logger.Infof("⚠️ [%s] GetFullStats returned 0 trades (traderID=%s)", at.name, at.id)
} else {
ctx.TradingStats = &kernel.TradingStats{
TotalTrades: stats.TotalTrades,
WinRate: stats.WinRate,
ProfitFactor: stats.ProfitFactor,
SharpeRatio: stats.SharpeRatio,
TotalPnL: stats.TotalPnL,
AvgWin: stats.AvgWin,
AvgLoss: stats.AvgLoss,
MaxDrawdownPct: stats.MaxDrawdownPct,
}
logger.Infof("📈 [%s] Trading stats: %d trades, %.1f%% win rate, PF=%.2f, Sharpe=%.2f, DD=%.1f%%",
at.name, stats.TotalTrades, stats.WinRate, stats.ProfitFactor, stats.SharpeRatio, stats.MaxDrawdownPct)
}
} else {
logger.Infof("⚠️ [%s] Store is nil, cannot get recent trades", at.name)
}
// 8. Get quantitative data (if enabled in strategy config)
if strategyConfig.Indicators.EnableQuantData {
// Collect symbols to query (candidate coins + position coins)
symbolsToQuery := make(map[string]bool)
for _, coin := range candidateCoins {
symbolsToQuery[coin.Symbol] = true
}
for _, pos := range positionInfos {
symbolsToQuery[pos.Symbol] = true
}
symbols := make([]string, 0, len(symbolsToQuery))
for sym := range symbolsToQuery {
symbols = append(symbols, sym)
}
logger.Infof("📊 [%s] Fetching quantitative data for %d symbols...", at.name, len(symbols))
ctx.QuantDataMap = at.strategyEngine.FetchQuantDataBatch(symbols)
logger.Infof("📊 [%s] Successfully fetched quantitative data for %d symbols", at.name, len(ctx.QuantDataMap))
}
// 9. Get OI ranking data (market-wide position changes)
if strategyConfig.Indicators.EnableOIRanking {
logger.Infof("📊 [%s] Fetching OI ranking data...", at.name)
ctx.OIRankingData = at.strategyEngine.FetchOIRankingData()
if ctx.OIRankingData != nil {
logger.Infof("📊 [%s] OI ranking data ready: %d top, %d low positions",
at.name, len(ctx.OIRankingData.TopPositions), len(ctx.OIRankingData.LowPositions))
}
}
// 10. Get NetFlow ranking data (market-wide fund flow)
if strategyConfig.Indicators.EnableNetFlowRanking {
logger.Infof("💰 [%s] Fetching NetFlow ranking data...", at.name)
ctx.NetFlowRankingData = at.strategyEngine.FetchNetFlowRankingData()
if ctx.NetFlowRankingData != nil {
logger.Infof("💰 [%s] NetFlow ranking data ready: inst_in=%d, inst_out=%d",
at.name, len(ctx.NetFlowRankingData.InstitutionFutureTop), len(ctx.NetFlowRankingData.InstitutionFutureLow))
}
}
// 11. Get Price ranking data (market-wide gainers/losers)
if strategyConfig.Indicators.EnablePriceRanking {
logger.Infof("📈 [%s] Fetching Price ranking data...", at.name)
ctx.PriceRankingData = at.strategyEngine.FetchPriceRankingData()
if ctx.PriceRankingData != nil {
logger.Infof("📈 [%s] Price ranking data ready for %d durations",
at.name, len(ctx.PriceRankingData.Durations))
}
}
return ctx, nil
}
// sortDecisionsByPriority sorts decisions: close positions first, then open positions, finally hold/wait
// This avoids position stacking overflow when changing positions
func sortDecisionsByPriority(decisions []kernel.Decision) []kernel.Decision {
if len(decisions) <= 1 {
return decisions
}
// Define priority
getActionPriority := func(action string) int {
switch action {
case "close_long", "close_short":
return 1 // Highest priority: close positions first
case "open_long", "open_short":
return 2 // Second priority: open positions later
case "hold", "wait":
return 3 // Lowest priority: wait
default:
return 999 // Unknown actions at the end
}
}
// Copy decision list
sorted := make([]kernel.Decision, len(decisions))
copy(sorted, decisions)
// Sort by priority
for i := 0; i < len(sorted)-1; i++ {
for j := i + 1; j < len(sorted); j++ {
if getActionPriority(sorted[i].Action) > getActionPriority(sorted[j].Action) {
sorted[i], sorted[j] = sorted[j], sorted[i]
}
}
}
return sorted
}
+391
View File
@@ -0,0 +1,391 @@
package trader
import (
"fmt"
"nofx/kernel"
"nofx/logger"
"nofx/market"
"nofx/store"
"time"
)
// executeDecisionWithRecord executes AI decision and records detailed information
func (at *AutoTrader) executeDecisionWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error {
switch decision.Action {
case "open_long":
return at.executeOpenLongWithRecord(decision, actionRecord)
case "open_short":
return at.executeOpenShortWithRecord(decision, actionRecord)
case "close_long":
return at.executeCloseLongWithRecord(decision, actionRecord)
case "close_short":
return at.executeCloseShortWithRecord(decision, actionRecord)
case "hold", "wait":
// No execution needed, just record
return nil
default:
return fmt.Errorf("unknown action: %s", decision.Action)
}
}
// executeOpenLongWithRecord executes open long position and records detailed information
func (at *AutoTrader) executeOpenLongWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error {
logger.Infof(" 📈 Open long: %s", decision.Symbol)
// ⚠️ Get current positions for multiple checks
positions, err := at.trader.GetPositions()
if err != nil {
return fmt.Errorf("failed to get positions: %w", err)
}
// [CODE ENFORCED] Check max positions limit
if err := at.enforceMaxPositions(len(positions)); err != nil {
return err
}
// Check if there's already a position in the same symbol and direction
for _, pos := range positions {
if pos["symbol"] == decision.Symbol && pos["side"] == "long" {
return fmt.Errorf("❌ %s already has long position, close it first", decision.Symbol)
}
}
// Get current price
marketData, err := market.GetWithExchange(decision.Symbol, at.exchange)
if err != nil {
return err
}
// Get balance (needed for multiple checks)
balance, err := at.trader.GetBalance()
if err != nil {
return fmt.Errorf("failed to get account balance: %w", err)
}
availableBalance := 0.0
if avail, ok := balance["availableBalance"].(float64); ok {
availableBalance = avail
}
// Get equity for position value ratio check
equity := 0.0
if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 {
equity = eq
} else if eq, ok := balance["totalWalletBalance"].(float64); ok && eq > 0 {
equity = eq
} else {
equity = availableBalance // Fallback to available balance
}
// [CODE ENFORCED] Position Value Ratio Check: position_value <= equity × ratio
adjustedPositionSize, wasCapped := at.enforcePositionValueRatio(decision.PositionSizeUSD, equity, decision.Symbol)
if wasCapped {
decision.PositionSizeUSD = adjustedPositionSize
}
// ⚠️ Auto-adjust position size if insufficient margin
// Formula: totalRequired = positionSize/leverage + positionSize*0.001 + positionSize/leverage*0.01
// = positionSize * (1.01/leverage + 0.001)
marginFactor := 1.01/float64(decision.Leverage) + 0.001
maxAffordablePositionSize := availableBalance / marginFactor
actualPositionSize := decision.PositionSizeUSD
if actualPositionSize > maxAffordablePositionSize {
// Use 98% of max to leave buffer for price fluctuation
adjustedSize := maxAffordablePositionSize * 0.98
logger.Infof(" ⚠️ Position size %.2f exceeds max affordable %.2f, auto-reducing to %.2f",
actualPositionSize, maxAffordablePositionSize, adjustedSize)
actualPositionSize = adjustedSize
decision.PositionSizeUSD = actualPositionSize
}
// [CODE ENFORCED] Minimum position size check
if err := at.enforceMinPositionSize(decision.PositionSizeUSD); err != nil {
return err
}
// Calculate quantity with adjusted position size
quantity := actualPositionSize / marketData.CurrentPrice
actionRecord.Quantity = quantity
actionRecord.Price = marketData.CurrentPrice
// Set margin mode
if err := at.trader.SetMarginMode(decision.Symbol, at.config.IsCrossMargin); err != nil {
logger.Infof(" ⚠️ Failed to set margin mode: %v", err)
// Continue execution, doesn't affect trading
}
// Open position
order, err := at.trader.OpenLong(decision.Symbol, quantity, decision.Leverage)
if err != nil {
return err
}
// Record order ID
if orderID, ok := order["orderId"].(int64); ok {
actionRecord.OrderID = orderID
}
logger.Infof(" ✓ Position opened successfully, order ID: %v, quantity: %.4f", order["orderId"], quantity)
// Record order to database and poll for confirmation
at.recordAndConfirmOrder(order, decision.Symbol, "open_long", quantity, marketData.CurrentPrice, decision.Leverage, 0)
// Record position opening time
posKey := decision.Symbol + "_long"
at.positionFirstSeenTime[posKey] = time.Now().UnixMilli()
// Set stop loss and take profit
if err := at.trader.SetStopLoss(decision.Symbol, "LONG", quantity, decision.StopLoss); err != nil {
logger.Infof(" ⚠ Failed to set stop loss: %v", err)
}
if err := at.trader.SetTakeProfit(decision.Symbol, "LONG", quantity, decision.TakeProfit); err != nil {
logger.Infof(" ⚠ Failed to set take profit: %v", err)
}
return nil
}
// executeOpenShortWithRecord executes open short position and records detailed information
func (at *AutoTrader) executeOpenShortWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error {
logger.Infof(" 📉 Open short: %s", decision.Symbol)
// ⚠️ Get current positions for multiple checks
positions, err := at.trader.GetPositions()
if err != nil {
return fmt.Errorf("failed to get positions: %w", err)
}
// [CODE ENFORCED] Check max positions limit
if err := at.enforceMaxPositions(len(positions)); err != nil {
return err
}
// Check if there's already a position in the same symbol and direction
for _, pos := range positions {
if pos["symbol"] == decision.Symbol && pos["side"] == "short" {
return fmt.Errorf("❌ %s already has short position, close it first", decision.Symbol)
}
}
// Get current price
marketData, err := market.GetWithExchange(decision.Symbol, at.exchange)
if err != nil {
return err
}
// Get balance (needed for multiple checks)
balance, err := at.trader.GetBalance()
if err != nil {
return fmt.Errorf("failed to get account balance: %w", err)
}
availableBalance := 0.0
if avail, ok := balance["availableBalance"].(float64); ok {
availableBalance = avail
}
// Get equity for position value ratio check
equity := 0.0
if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 {
equity = eq
} else if eq, ok := balance["totalWalletBalance"].(float64); ok && eq > 0 {
equity = eq
} else {
equity = availableBalance // Fallback to available balance
}
// [CODE ENFORCED] Position Value Ratio Check: position_value <= equity × ratio
adjustedPositionSize, wasCapped := at.enforcePositionValueRatio(decision.PositionSizeUSD, equity, decision.Symbol)
if wasCapped {
decision.PositionSizeUSD = adjustedPositionSize
}
// ⚠️ Auto-adjust position size if insufficient margin
// Formula: totalRequired = positionSize/leverage + positionSize*0.001 + positionSize/leverage*0.01
// = positionSize * (1.01/leverage + 0.001)
marginFactor := 1.01/float64(decision.Leverage) + 0.001
maxAffordablePositionSize := availableBalance / marginFactor
actualPositionSize := decision.PositionSizeUSD
if actualPositionSize > maxAffordablePositionSize {
// Use 98% of max to leave buffer for price fluctuation
adjustedSize := maxAffordablePositionSize * 0.98
logger.Infof(" ⚠️ Position size %.2f exceeds max affordable %.2f, auto-reducing to %.2f",
actualPositionSize, maxAffordablePositionSize, adjustedSize)
actualPositionSize = adjustedSize
decision.PositionSizeUSD = actualPositionSize
}
// [CODE ENFORCED] Minimum position size check
if err := at.enforceMinPositionSize(decision.PositionSizeUSD); err != nil {
return err
}
// Calculate quantity with adjusted position size
quantity := actualPositionSize / marketData.CurrentPrice
actionRecord.Quantity = quantity
actionRecord.Price = marketData.CurrentPrice
// Set margin mode
if err := at.trader.SetMarginMode(decision.Symbol, at.config.IsCrossMargin); err != nil {
logger.Infof(" ⚠️ Failed to set margin mode: %v", err)
// Continue execution, doesn't affect trading
}
// Open position
order, err := at.trader.OpenShort(decision.Symbol, quantity, decision.Leverage)
if err != nil {
return err
}
// Record order ID
if orderID, ok := order["orderId"].(int64); ok {
actionRecord.OrderID = orderID
}
logger.Infof(" ✓ Position opened successfully, order ID: %v, quantity: %.4f", order["orderId"], quantity)
// Record order to database and poll for confirmation
at.recordAndConfirmOrder(order, decision.Symbol, "open_short", quantity, marketData.CurrentPrice, decision.Leverage, 0)
// Record position opening time
posKey := decision.Symbol + "_short"
at.positionFirstSeenTime[posKey] = time.Now().UnixMilli()
// Set stop loss and take profit
if err := at.trader.SetStopLoss(decision.Symbol, "SHORT", quantity, decision.StopLoss); err != nil {
logger.Infof(" ⚠ Failed to set stop loss: %v", err)
}
if err := at.trader.SetTakeProfit(decision.Symbol, "SHORT", quantity, decision.TakeProfit); err != nil {
logger.Infof(" ⚠ Failed to set take profit: %v", err)
}
return nil
}
// executeCloseLongWithRecord executes close long position and records detailed information
func (at *AutoTrader) executeCloseLongWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error {
logger.Infof(" 🔄 Close long: %s", decision.Symbol)
// Get current price
marketData, err := market.GetWithExchange(decision.Symbol, at.exchange)
if err != nil {
return err
}
actionRecord.Price = marketData.CurrentPrice
// Normalize symbol for database lookup
normalizedSymbol := market.Normalize(decision.Symbol)
// Get entry price and quantity - prioritize local database for accurate quantity
var entryPrice float64
var quantity float64
// First try to get from local database (more accurate for quantity)
if at.store != nil {
if openPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, normalizedSymbol, "LONG"); err == nil && openPos != nil {
quantity = openPos.Quantity
entryPrice = openPos.EntryPrice
logger.Infof(" 📊 Using local position data: qty=%.8f, entry=%.2f", quantity, entryPrice)
}
}
// Fallback to exchange API if local data not found
if quantity == 0 {
positions, err := at.trader.GetPositions()
if err == nil {
for _, pos := range positions {
if pos["symbol"] == decision.Symbol && pos["side"] == "long" {
if ep, ok := pos["entryPrice"].(float64); ok {
entryPrice = ep
}
if amt, ok := pos["positionAmt"].(float64); ok && amt > 0 {
quantity = amt
}
break
}
}
}
logger.Infof(" 📊 Using exchange position data: qty=%.8f, entry=%.2f", quantity, entryPrice)
}
// Close position
order, err := at.trader.CloseLong(decision.Symbol, 0) // 0 = close all
if err != nil {
return err
}
// Record order ID
if orderID, ok := order["orderId"].(int64); ok {
actionRecord.OrderID = orderID
}
// Record order to database and poll for confirmation
at.recordAndConfirmOrder(order, decision.Symbol, "close_long", quantity, marketData.CurrentPrice, 0, entryPrice)
logger.Infof(" ✓ Position closed successfully")
return nil
}
// executeCloseShortWithRecord executes close short position and records detailed information
func (at *AutoTrader) executeCloseShortWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error {
logger.Infof(" 🔄 Close short: %s", decision.Symbol)
// Get current price
marketData, err := market.GetWithExchange(decision.Symbol, at.exchange)
if err != nil {
return err
}
actionRecord.Price = marketData.CurrentPrice
// Normalize symbol for database lookup
normalizedSymbol := market.Normalize(decision.Symbol)
// Get entry price and quantity - prioritize local database for accurate quantity
var entryPrice float64
var quantity float64
// First try to get from local database (more accurate for quantity)
if at.store != nil {
if openPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, normalizedSymbol, "SHORT"); err == nil && openPos != nil {
quantity = openPos.Quantity
entryPrice = openPos.EntryPrice
logger.Infof(" 📊 Using local position data: qty=%.8f, entry=%.2f", quantity, entryPrice)
}
}
// Fallback to exchange API if local data not found
if quantity == 0 {
positions, err := at.trader.GetPositions()
if err == nil {
for _, pos := range positions {
if pos["symbol"] == decision.Symbol && pos["side"] == "short" {
if ep, ok := pos["entryPrice"].(float64); ok {
entryPrice = ep
}
if amt, ok := pos["positionAmt"].(float64); ok {
quantity = -amt // positionAmt is negative for short
}
break
}
}
}
logger.Infof(" 📊 Using exchange position data: qty=%.8f, entry=%.2f", quantity, entryPrice)
}
// Close position
order, err := at.trader.CloseShort(decision.Symbol, 0) // 0 = close all
if err != nil {
return err
}
// Record order ID
if orderID, ok := order["orderId"].(int64); ok {
actionRecord.OrderID = orderID
}
// Record order to database and poll for confirmation
at.recordAndConfirmOrder(order, decision.Symbol, "close_short", quantity, marketData.CurrentPrice, 0, entryPrice)
logger.Infof(" ✓ Position closed successfully")
return nil
}
+263
View File
@@ -0,0 +1,263 @@
package trader
import (
"fmt"
"nofx/logger"
"strings"
"time"
)
// startDrawdownMonitor starts drawdown monitoring
func (at *AutoTrader) startDrawdownMonitor() {
at.monitorWg.Add(1)
go func() {
defer at.monitorWg.Done()
ticker := time.NewTicker(1 * time.Minute) // Check every minute
defer ticker.Stop()
logger.Info("📊 Started position drawdown monitoring (check every minute)")
for {
select {
case <-ticker.C:
at.checkPositionDrawdown()
case <-at.stopMonitorCh:
logger.Info("⏹ Stopped position drawdown monitoring")
return
}
}
}()
}
// checkPositionDrawdown checks position drawdown situation
func (at *AutoTrader) checkPositionDrawdown() {
// Get current positions
positions, err := at.trader.GetPositions()
if err != nil {
logger.Infof("❌ Drawdown monitoring: failed to get positions: %v", err)
return
}
for _, pos := range positions {
symbol := pos["symbol"].(string)
side := pos["side"].(string)
entryPrice := pos["entryPrice"].(float64)
markPrice := pos["markPrice"].(float64)
quantity := pos["positionAmt"].(float64)
if quantity < 0 {
quantity = -quantity // Short position quantity is negative, convert to positive
}
// Calculate current P&L percentage
leverage := 10 // Default value
if lev, ok := pos["leverage"].(float64); ok {
leverage = int(lev)
}
var currentPnLPct float64
if side == "long" {
currentPnLPct = ((markPrice - entryPrice) / entryPrice) * float64(leverage) * 100
} else {
currentPnLPct = ((entryPrice - markPrice) / entryPrice) * float64(leverage) * 100
}
// Construct unique position identifier (distinguish long/short)
posKey := symbol + "_" + side
// Get historical peak profit for this position
at.peakPnLCacheMutex.RLock()
peakPnLPct, exists := at.peakPnLCache[posKey]
at.peakPnLCacheMutex.RUnlock()
if !exists {
// If no historical peak record, use current P&L as initial value
peakPnLPct = currentPnLPct
at.UpdatePeakPnL(symbol, side, currentPnLPct)
} else {
// Update peak cache
at.UpdatePeakPnL(symbol, side, currentPnLPct)
}
// Calculate drawdown (magnitude of decline from peak)
var drawdownPct float64
if peakPnLPct > 0 && currentPnLPct < peakPnLPct {
drawdownPct = ((peakPnLPct - currentPnLPct) / peakPnLPct) * 100
}
// Check close position condition: profit > 5% and drawdown >= 40%
if currentPnLPct > 5.0 && drawdownPct >= 40.0 {
logger.Infof("🚨 Drawdown close position condition triggered: %s %s | Current profit: %.2f%% | Peak profit: %.2f%% | Drawdown: %.2f%%",
symbol, side, currentPnLPct, peakPnLPct, drawdownPct)
// Execute close position
if err := at.emergencyClosePosition(symbol, side); err != nil {
logger.Infof("❌ Drawdown close position failed (%s %s): %v", symbol, side, err)
} else {
logger.Infof("✅ Drawdown close position succeeded: %s %s", symbol, side)
// Clear cache for this position after closing
at.ClearPeakPnLCache(symbol, side)
}
} else if currentPnLPct > 5.0 {
// Record situations close to close position condition (for debugging)
logger.Infof("📊 Drawdown monitoring: %s %s | Profit: %.2f%% | Peak: %.2f%% | Drawdown: %.2f%%",
symbol, side, currentPnLPct, peakPnLPct, drawdownPct)
}
}
}
// emergencyClosePosition emergency close position function
func (at *AutoTrader) emergencyClosePosition(symbol, side string) error {
switch side {
case "long":
order, err := at.trader.CloseLong(symbol, 0) // 0 = close all
if err != nil {
return err
}
logger.Infof("✅ Emergency close long position succeeded, order ID: %v", order["orderId"])
case "short":
order, err := at.trader.CloseShort(symbol, 0) // 0 = close all
if err != nil {
return err
}
logger.Infof("✅ Emergency close short position succeeded, order ID: %v", order["orderId"])
default:
return fmt.Errorf("unknown position direction: %s", side)
}
return nil
}
// GetPeakPnLCache gets peak profit cache
func (at *AutoTrader) GetPeakPnLCache() map[string]float64 {
at.peakPnLCacheMutex.RLock()
defer at.peakPnLCacheMutex.RUnlock()
// Return a copy of the cache
cache := make(map[string]float64)
for k, v := range at.peakPnLCache {
cache[k] = v
}
return cache
}
// UpdatePeakPnL updates peak profit cache
func (at *AutoTrader) UpdatePeakPnL(symbol, side string, currentPnLPct float64) {
at.peakPnLCacheMutex.Lock()
defer at.peakPnLCacheMutex.Unlock()
posKey := symbol + "_" + side
if peak, exists := at.peakPnLCache[posKey]; exists {
// Update peak (if long, take larger value; if short, currentPnLPct is negative, also compare)
if currentPnLPct > peak {
at.peakPnLCache[posKey] = currentPnLPct
}
} else {
// First time recording
at.peakPnLCache[posKey] = currentPnLPct
}
}
// ClearPeakPnLCache clears peak cache for specified position
func (at *AutoTrader) ClearPeakPnLCache(symbol, side string) {
at.peakPnLCacheMutex.Lock()
defer at.peakPnLCacheMutex.Unlock()
posKey := symbol + "_" + side
delete(at.peakPnLCache, posKey)
}
// ============================================================================
// Risk Control Helpers
// ============================================================================
// isBTCETH checks if a symbol is BTC or ETH
func isBTCETH(symbol string) bool {
symbol = strings.ToUpper(symbol)
return strings.HasPrefix(symbol, "BTC") || strings.HasPrefix(symbol, "ETH")
}
// enforcePositionValueRatio checks and enforces position value ratio limits (CODE ENFORCED)
// Returns the adjusted position size (capped if necessary) and whether the position was capped
// positionSizeUSD: the original position size in USD
// equity: the account equity
// symbol: the trading symbol
func (at *AutoTrader) enforcePositionValueRatio(positionSizeUSD float64, equity float64, symbol string) (float64, bool) {
if at.config.StrategyConfig == nil {
return positionSizeUSD, false
}
riskControl := at.config.StrategyConfig.RiskControl
// Get the appropriate position value ratio limit
var maxPositionValueRatio float64
if isBTCETH(symbol) {
maxPositionValueRatio = riskControl.BTCETHMaxPositionValueRatio
if maxPositionValueRatio <= 0 {
maxPositionValueRatio = 5.0 // Default: 5x for BTC/ETH
}
} else {
maxPositionValueRatio = riskControl.AltcoinMaxPositionValueRatio
if maxPositionValueRatio <= 0 {
maxPositionValueRatio = 1.0 // Default: 1x for altcoins
}
}
// Calculate max allowed position value = equity × ratio
maxPositionValue := equity * maxPositionValueRatio
// Check if position size exceeds limit
if positionSizeUSD > maxPositionValue {
logger.Infof(" ⚠️ [RISK CONTROL] Position %.2f USDT exceeds limit (equity %.2f × %.1fx = %.2f USDT max for %s), capping",
positionSizeUSD, equity, maxPositionValueRatio, maxPositionValue, symbol)
return maxPositionValue, true
}
return positionSizeUSD, false
}
// enforceMinPositionSize checks minimum position size (CODE ENFORCED)
func (at *AutoTrader) enforceMinPositionSize(positionSizeUSD float64) error {
if at.config.StrategyConfig == nil {
return nil
}
minSize := at.config.StrategyConfig.RiskControl.MinPositionSize
if minSize <= 0 {
minSize = 12 // Default: 12 USDT
}
if positionSizeUSD < minSize {
return fmt.Errorf("❌ [RISK CONTROL] Position %.2f USDT below minimum (%.2f USDT)", positionSizeUSD, minSize)
}
return nil
}
// enforceMaxPositions checks maximum positions count (CODE ENFORCED)
func (at *AutoTrader) enforceMaxPositions(currentPositionCount int) error {
if at.config.StrategyConfig == nil {
return nil
}
maxPositions := at.config.StrategyConfig.RiskControl.MaxPositions
if maxPositions <= 0 {
maxPositions = 3 // Default: 3 positions
}
if currentPositionCount >= maxPositions {
return fmt.Errorf("❌ [RISK CONTROL] Already at max positions (%d/%d)", currentPositionCount, maxPositions)
}
return nil
}
// getSideFromAction converts order action to side (BUY/SELL)
func getSideFromAction(action string) string {
switch action {
case "open_long", "close_short":
return "BUY"
case "open_short", "close_long":
return "SELL"
default:
return "BUY"
}
}
+9 -9
View File
@@ -4,27 +4,27 @@ import useSWR from 'swr'
import { api } from './lib/api'
import { TraderDashboardPage } from './pages/TraderDashboardPage'
import { AITradersPage } from './components/AITradersPage'
import { LoginPage } from './components/LoginPage'
import { SetupPage } from './components/SetupPage'
import { AITradersPage } from './components/trader/AITradersPage'
import { LoginPage } from './components/auth/LoginPage'
import { SetupPage } from './components/modals/SetupPage'
import { SettingsPage } from './pages/SettingsPage'
import { ResetPasswordPage } from './components/ResetPasswordPage'
import { CompetitionPage } from './components/CompetitionPage'
import { ResetPasswordPage } from './components/auth/ResetPasswordPage'
import { CompetitionPage } from './components/trader/CompetitionPage'
import { LandingPage } from './pages/LandingPage'
import { FAQPage } from './pages/FAQPage'
import { StrategyStudioPage } from './pages/StrategyStudioPage'
import { StrategyMarketPage } from './pages/StrategyMarketPage'
import { DataPage } from './pages/DataPage'
import { LoginRequiredOverlay } from './components/LoginRequiredOverlay'
import HeaderBar from './components/HeaderBar'
import { LoginRequiredOverlay } from './components/auth/LoginRequiredOverlay'
import HeaderBar from './components/common/HeaderBar'
import { LanguageProvider, useLanguage } from './contexts/LanguageContext'
import { AuthProvider, useAuth } from './contexts/AuthContext'
import { ConfirmDialogProvider } from './components/ConfirmDialog'
import { ConfirmDialogProvider } from './components/common/ConfirmDialog'
import { t } from './i18n/translations'
import { useSystemConfig } from './hooks/useSystemConfig'
import { OFFICIAL_LINKS } from './constants/branding'
import { BacktestPage } from './components/BacktestPage'
import { BacktestPage } from './components/backtest/BacktestPage'
import type {
SystemStatus,
AccountInfo,
@@ -1,10 +1,10 @@
import React, { useState, useEffect } from 'react'
import { Eye, EyeOff } from 'lucide-react'
import { toast } from 'sonner'
import { useAuth } from '../contexts/AuthContext'
import { useLanguage } from '../contexts/LanguageContext'
import { t } from '../i18n/translations'
import { DeepVoidBackground } from './DeepVoidBackground'
import { useAuth } from '../../contexts/AuthContext'
import { useLanguage } from '../../contexts/LanguageContext'
import { t } from '../../i18n/translations'
import { DeepVoidBackground } from '../common/DeepVoidBackground'
export function LoginPage() {
const { language } = useLanguage()
@@ -1,7 +1,7 @@
import { motion, AnimatePresence } from 'framer-motion'
import { LogIn, UserPlus, X, AlertTriangle, Terminal } from 'lucide-react'
import { DeepVoidBackground } from './DeepVoidBackground'
import { useLanguage } from '../contexts/LanguageContext'
import { DeepVoidBackground } from '../common/DeepVoidBackground'
import { useLanguage } from '../../contexts/LanguageContext'
interface LoginRequiredOverlayProps {
isOpen: boolean
@@ -2,13 +2,13 @@ import React, { useEffect, useState } from 'react'
import { Eye, EyeOff } from 'lucide-react'
import PasswordChecklist from 'react-password-checklist'
import { toast } from 'sonner'
import { useAuth } from '../contexts/AuthContext'
import { useLanguage } from '../contexts/LanguageContext'
import { t } from '../i18n/translations'
import { getSystemConfig } from '../lib/config'
import { DeepVoidBackground } from './DeepVoidBackground'
import { useAuth } from '../../contexts/AuthContext'
import { useLanguage } from '../../contexts/LanguageContext'
import { t } from '../../i18n/translations'
import { getSystemConfig } from '../../lib/config'
import { DeepVoidBackground } from '../common/DeepVoidBackground'
import { RegistrationDisabled } from './RegistrationDisabled'
import { WhitelistFullPage } from './WhitelistFullPage'
import { WhitelistFullPage } from '../common/WhitelistFullPage'
export function RegisterPage() {
const { language } = useLanguage()
@@ -1,11 +1,11 @@
import { describe, it, expect, vi } from 'vitest'
import { render, screen, fireEvent } from '@testing-library/react'
import { RegistrationDisabled } from './RegistrationDisabled'
import { LanguageProvider } from '../contexts/LanguageContext'
import { LanguageProvider } from '../../contexts/LanguageContext'
// Mock useLanguage hook
vi.mock('../contexts/LanguageContext', async () => {
const actual = await vi.importActual('../contexts/LanguageContext')
vi.mock('../../contexts/LanguageContext', async () => {
const actual = await vi.importActual('../../contexts/LanguageContext')
return {
...actual,
useLanguage: () => ({ language: 'en' }),
@@ -1,5 +1,5 @@
import { useLanguage } from '../contexts/LanguageContext'
import { t } from '../i18n/translations'
import { useLanguage } from '../../contexts/LanguageContext'
import { t } from '../../i18n/translations'
export function RegistrationDisabled() {
const { language } = useLanguage()
@@ -1,11 +1,11 @@
import React, { useState } from 'react'
import { useAuth } from '../contexts/AuthContext'
import { useLanguage } from '../contexts/LanguageContext'
import { t } from '../i18n/translations'
import { Header } from './Header'
import { useAuth } from '../../contexts/AuthContext'
import { useLanguage } from '../../contexts/LanguageContext'
import { t } from '../../i18n/translations'
import { Header } from '../common/Header'
import { ArrowLeft, KeyRound, Eye, EyeOff } from 'lucide-react'
import PasswordChecklist from 'react-password-checklist'
import { Input } from './ui/input'
import { Input } from '../ui/input'
import { toast } from 'sonner'
export function ResetPasswordPage() {
@@ -28,7 +28,7 @@ import {
ArrowDownRight,
CandlestickChart as CandlestickIcon,
} from 'lucide-react'
import { DeepVoidBackground } from './DeepVoidBackground'
import { DeepVoidBackground } from '../common/DeepVoidBackground'
import {
ResponsiveContainer,
AreaChart,
@@ -39,12 +39,12 @@ import {
Tooltip,
ReferenceDot,
} from 'recharts'
import { api } from '../lib/api'
import { useLanguage } from '../contexts/LanguageContext'
import { t } from '../i18n/translations'
import { confirmToast } from '../lib/notify'
import { DecisionCard } from './DecisionCard'
import { MetricTooltip } from './MetricTooltip'
import { api } from '../../lib/api'
import { useLanguage } from '../../contexts/LanguageContext'
import { t } from '../../i18n/translations'
import { confirmToast } from '../../lib/notify'
import { DecisionCard } from '../trader/DecisionCard'
import { MetricTooltip } from '../common/MetricTooltip'
import type {
BacktestStatusPayload,
BacktestPositionStatus,
@@ -55,7 +55,7 @@ import type {
DecisionRecord,
AIModel,
Strategy,
} from '../types'
} from '../../types'
// ============ Types ============
type WizardStep = 1 | 2 | 3
@@ -10,14 +10,14 @@ import {
HistogramSeries,
createSeriesMarkers,
} from 'lightweight-charts'
import { useLanguage } from '../contexts/LanguageContext'
import { httpClient } from '../lib/httpClient'
import { useLanguage } from '../../contexts/LanguageContext'
import { httpClient } from '../../lib/httpClient'
import {
calculateSMA,
calculateEMA,
calculateBollingerBands,
type Kline,
} from '../utils/indicators'
} from '../../utils/indicators'
import { Settings, BarChart2 } from 'lucide-react'
// 订单接口定义
@@ -1,8 +1,8 @@
import { useState, useEffect, useRef } from 'react'
import { EquityChart } from './EquityChart'
import { AdvancedChart } from './AdvancedChart'
import { useLanguage } from '../contexts/LanguageContext'
import { t } from '../i18n/translations'
import { useLanguage } from '../../contexts/LanguageContext'
import { t } from '../../i18n/translations'
import { BarChart3, CandlestickChart, ChevronDown, Search } from 'lucide-react'
import { motion, AnimatePresence } from 'framer-motion'
@@ -8,8 +8,8 @@ import {
CandlestickSeries,
createSeriesMarkers,
} from 'lightweight-charts'
import { useLanguage } from '../contexts/LanguageContext'
import { httpClient } from '../lib/httpClient'
import { useLanguage } from '../../contexts/LanguageContext'
import { httpClient } from '../../lib/httpClient'
// 订单接口定义
interface OrderMarker {
@@ -1,5 +1,5 @@
import { useEffect, useState } from 'react'
import { httpClient } from '../lib/httpClient'
import { httpClient } from '../../lib/httpClient'
interface ChartWithOrdersSimpleProps {
symbol: string
@@ -12,11 +12,11 @@ import {
ComposedChart,
} from 'recharts'
import useSWR from 'swr'
import { api } from '../lib/api'
import type { CompetitionTraderData } from '../types'
import { getTraderColor } from '../utils/traderColors'
import { useLanguage } from '../contexts/LanguageContext'
import { t } from '../i18n/translations'
import { api } from '../../lib/api'
import type { CompetitionTraderData } from '../../types'
import { getTraderColor } from '../../utils/traderColors'
import { useLanguage } from '../../contexts/LanguageContext'
import { t } from '../../i18n/translations'
import { BarChart3, TrendingUp, TrendingDown, Zap } from 'lucide-react'
// Time period options: 1D, 3D, 7D, 30D, All
@@ -10,10 +10,10 @@ import {
ReferenceLine,
} from 'recharts'
import useSWR from 'swr'
import { api } from '../lib/api'
import { useLanguage } from '../contexts/LanguageContext'
import { useAuth } from '../contexts/AuthContext'
import { t } from '../i18n/translations'
import { api } from '../../lib/api'
import { useLanguage } from '../../contexts/LanguageContext'
import { useAuth } from '../../contexts/AuthContext'
import { t } from '../../i18n/translations'
import {
AlertTriangle,
BarChart3,
@@ -1,6 +1,6 @@
import { useEffect, useRef, useState, memo } from 'react'
import { useLanguage } from '../contexts/LanguageContext'
import { t } from '../i18n/translations'
import { useLanguage } from '../../contexts/LanguageContext'
import { t } from '../../i18n/translations'
import { ChevronDown, TrendingUp, X } from 'lucide-react'
// 支持的交易所列表 (合约格式)
@@ -13,8 +13,8 @@ import {
AlertDialogDescription,
AlertDialogFooter,
AlertDialogTitle,
} from './ui/alert-dialog'
import { setGlobalConfirm } from '../lib/notify'
} from '../ui/alert-dialog'
import { setGlobalConfirm } from '../../lib/notify'
interface ConfirmOptions {
title?: string
@@ -1,5 +1,5 @@
import { useLanguage } from '../contexts/LanguageContext'
import { t } from '../i18n/translations'
import { useLanguage } from '../../contexts/LanguageContext'
import { t } from '../../i18n/translations'
import { Container } from './Container'
interface HeaderProps {
@@ -2,8 +2,8 @@ import { useState, useEffect, useRef } from 'react'
import { useNavigate } from 'react-router-dom'
import { motion, AnimatePresence } from 'framer-motion'
import { Menu, X, ChevronDown, Settings } from 'lucide-react'
import { t, type Language } from '../i18n/translations'
import { OFFICIAL_LINKS } from '../constants/branding'
import { t, type Language } from '../../i18n/translations'
import { OFFICIAL_LINKS } from '../../constants/branding'
type Page =
| 'competition'
@@ -1,7 +1,7 @@
import { useCallback, useEffect, useState, type ReactNode } from 'react'
import { Loader2, ShieldAlert, ShieldCheck, ShieldMinus } from 'lucide-react'
import { CryptoService, diagnoseWebCryptoEnvironment } from '../lib/crypto'
import { t, type Language } from '../i18n/translations'
import { CryptoService, diagnoseWebCryptoEnvironment } from '../../lib/crypto'
import { t, type Language } from '../../i18n/translations'
export type WebCryptoCheckStatus =
| 'idle'
@@ -1,6 +1,6 @@
import { motion } from 'framer-motion'
import { ShieldAlert, ArrowLeft, Twitter, Send, Lock } from 'lucide-react'
import { OFFICIAL_LINKS } from '../constants/branding'
import { OFFICIAL_LINKS } from '../../constants/branding'
interface WhitelistFullPageProps {
onBack?: () => void
+1 -1
View File
@@ -1,6 +1,6 @@
import { useState, useMemo } from 'react'
import { HelpCircle } from 'lucide-react'
import { DeepVoidBackground } from '../DeepVoidBackground'
import { DeepVoidBackground } from '../common/DeepVoidBackground'
import { t, type Language } from '../../i18n/translations'
import { FAQSearchBar } from './FAQSearchBar'
import { FAQSidebar } from './FAQSidebar'
@@ -1,8 +1,8 @@
import React, { useState } from 'react'
import { Eye, EyeOff } from 'lucide-react'
import { useAuth } from '../contexts/AuthContext'
import { DeepVoidBackground } from './DeepVoidBackground'
import { invalidateSystemConfig } from '../lib/config'
import { useAuth } from '../../contexts/AuthContext'
import { DeepVoidBackground } from '../common/DeepVoidBackground'
import { invalidateSystemConfig } from '../../lib/config'
export function SetupPage() {
const { register } = useAuth()
@@ -1,8 +1,8 @@
import { useEffect, useMemo, useRef, useState } from 'react'
import { createPortal } from 'react-dom'
import { t, type Language } from '../i18n/translations'
import { t, type Language } from '../../i18n/translations'
import { toast } from 'sonner'
import { WebCryptoEnvironmentCheck } from './WebCryptoEnvironmentCheck'
import { WebCryptoEnvironmentCheck } from '../common/WebCryptoEnvironmentCheck'
const DEFAULT_LENGTH = 64
@@ -1,23 +1,23 @@
import React, { useState, useEffect } from 'react'
import { useNavigate } from 'react-router-dom'
import useSWR from 'swr'
import { api } from '../lib/api'
import { api } from '../../lib/api'
import type {
TraderInfo,
CreateTraderRequest,
AIModel,
Exchange,
} from '../types'
import { useLanguage } from '../contexts/LanguageContext'
import { t, type Language } from '../i18n/translations'
import { useAuth } from '../contexts/AuthContext'
import { getExchangeIcon } from './ExchangeIcons'
import { getModelIcon } from './ModelIcons'
} from '../../types'
import { useLanguage } from '../../contexts/LanguageContext'
import { t, type Language } from '../../i18n/translations'
import { useAuth } from '../../contexts/AuthContext'
import { getExchangeIcon } from '../common/ExchangeIcons'
import { getModelIcon } from '../common/ModelIcons'
import { TraderConfigModal } from './TraderConfigModal'
import { DeepVoidBackground } from './DeepVoidBackground'
import { ExchangeConfigModal } from './traders/ExchangeConfigModal'
import { TelegramConfigModal } from './traders/TelegramConfigModal'
import { PunkAvatar, getTraderAvatar } from './PunkAvatar'
import { DeepVoidBackground } from '../common/DeepVoidBackground'
import { ExchangeConfigModal } from './ExchangeConfigModal'
import { TelegramConfigModal } from './TelegramConfigModal'
import { PunkAvatar, getTraderAvatar } from '../common/PunkAvatar'
import {
Bot,
Brain,
@@ -34,7 +34,7 @@ import {
Check,
MessageCircle,
} from 'lucide-react'
import { confirmToast } from '../lib/notify'
import { confirmToast } from '../../lib/notify'
import { toast } from 'sonner'
// 获取友好的AI模型名称
@@ -1,15 +1,15 @@
import { useState } from 'react'
import { Trophy } from 'lucide-react'
import useSWR from 'swr'
import { api } from '../lib/api'
import type { CompetitionData } from '../types'
import { ComparisonChart } from './ComparisonChart'
import { api } from '../../lib/api'
import type { CompetitionData } from '../../types'
import { ComparisonChart } from '../charts/ComparisonChart'
import { TraderConfigViewModal } from './TraderConfigViewModal'
import { getTraderColor } from '../utils/traderColors'
import { useLanguage } from '../contexts/LanguageContext'
import { t } from '../i18n/translations'
import { PunkAvatar, getTraderAvatar } from './PunkAvatar'
import { DeepVoidBackground } from './DeepVoidBackground'
import { getTraderColor } from '../../utils/traderColors'
import { useLanguage } from '../../contexts/LanguageContext'
import { t } from '../../i18n/translations'
import { PunkAvatar, getTraderAvatar } from '../common/PunkAvatar'
import { DeepVoidBackground } from '../common/DeepVoidBackground'
export function CompetitionPage() {
const { language } = useLanguage()
@@ -1,6 +1,6 @@
import { useState } from 'react'
import type { DecisionRecord, DecisionAction } from '../types'
import { t, type Language } from '../i18n/translations'
import type { DecisionRecord, DecisionAction } from '../../types'
import { t, type Language } from '../../i18n/translations'
interface DecisionCardProps {
decision: DecisionRecord
@@ -2,15 +2,15 @@ import React, { useState, useEffect } from 'react'
import type { Exchange } from '../../types'
import { t, type Language } from '../../i18n/translations'
import { api } from '../../lib/api'
import { getExchangeIcon } from '../ExchangeIcons'
import { getExchangeIcon } from '../common/ExchangeIcons'
import {
TwoStageKeyModal,
type TwoStageKeyModalResult,
} from '../TwoStageKeyModal'
} from '../modals/TwoStageKeyModal'
import {
WebCryptoEnvironmentCheck,
type WebCryptoCheckStatus,
} from '../WebCryptoEnvironmentCheck'
} from '../common/WebCryptoEnvironmentCheck'
import {
BookOpen, Trash2, HelpCircle, ExternalLink, UserPlus,
Key, Shield, ChevronLeft, Check, Copy, ArrowRight
@@ -1,15 +1,15 @@
import { useState, useEffect, useMemo } from 'react'
import { api } from '../lib/api'
import { useLanguage } from '../contexts/LanguageContext'
import { t, type Language } from '../i18n/translations'
import { MetricTooltip } from './MetricTooltip'
import { formatPrice, formatQuantity } from '../utils/format'
import { api } from '../../lib/api'
import { useLanguage } from '../../contexts/LanguageContext'
import { t, type Language } from '../../i18n/translations'
import { MetricTooltip } from '../common/MetricTooltip'
import { formatPrice, formatQuantity } from '../../utils/format'
import type {
HistoricalPosition,
TraderStats,
SymbolStats,
DirectionStats,
} from '../types'
} from '../../types'
interface PositionHistoryProps {
traderId: string
@@ -1,10 +1,10 @@
import { useState, useEffect } from 'react'
import type { AIModel, Exchange, CreateTraderRequest, Strategy } from '../types'
import { useLanguage } from '../contexts/LanguageContext'
import { t } from '../i18n/translations'
import type { AIModel, Exchange, CreateTraderRequest, Strategy } from '../../types'
import { useLanguage } from '../../contexts/LanguageContext'
import { t } from '../../i18n/translations'
import { toast } from 'sonner'
import { Pencil, Plus, X as IconX, Sparkles, ExternalLink, UserPlus } from 'lucide-react'
import { httpClient } from '../lib/httpClient'
import { httpClient } from '../../lib/httpClient'
// 提取下划线后面的名称部分
function getShortName(fullName: string): string {
@@ -22,7 +22,7 @@ const EXCHANGE_REGISTRATION_LINKS: Record<string, { url: string; hasReferral?: b
lighter: { url: 'https://app.lighter.xyz/?referral=68151432', hasReferral: true },
}
import type { TraderConfigData } from '../types'
import type { TraderConfigData } from '../../types'
// 表单内部状态类型
interface FormState {
@@ -1,5 +1,5 @@
import type { TraderConfigData } from '../types'
import { PunkAvatar, getTraderAvatar } from './PunkAvatar'
import type { TraderConfigData } from '../../types'
import { PunkAvatar, getTraderAvatar } from '../common/PunkAvatar'
// 提取下划线后面的名称部分
function getShortName(fullName: string): string {
+2 -2
View File
@@ -1,7 +1,7 @@
import { useState } from 'react'
import HeaderBar from '../components/HeaderBar'
import HeaderBar from '../components/common/HeaderBar'
import LoginModal from '../components/landing/LoginModal'
import { LoginRequiredOverlay } from '../components/LoginRequiredOverlay'
import { LoginRequiredOverlay } from '../components/auth/LoginRequiredOverlay'
import FooterSection from '../components/landing/FooterSection'
import TerminalHero from '../components/landing/core/TerminalHero'
import LiveFeed from '../components/landing/core/LiveFeed'
+1 -1
View File
@@ -1,4 +1,4 @@
import { DeepVoidBackground } from '../components/DeepVoidBackground'
import { DeepVoidBackground } from '../components/common/DeepVoidBackground'
import { AlertCircle, Home } from 'lucide-react'
export function PageNotFound() {
+3 -3
View File
@@ -4,9 +4,9 @@ import { User, Cpu, Building2, MessageCircle, Eye, EyeOff, ChevronRight, Plus, P
import { useAuth } from '../contexts/AuthContext'
import { useLanguage } from '../contexts/LanguageContext'
import { api } from '../lib/api'
import { ExchangeConfigModal } from '../components/traders/ExchangeConfigModal'
import { TelegramConfigModal } from '../components/traders/TelegramConfigModal'
import { ModelConfigModal } from '../components/AITradersPage'
import { ExchangeConfigModal } from '../components/trader/ExchangeConfigModal'
import { TelegramConfigModal } from '../components/trader/TelegramConfigModal'
import { ModelConfigModal } from '../components/trader/AITradersPage'
import type { Exchange, AIModel } from '../types'
type Tab = 'account' | 'models' | 'exchanges' | 'telegram'

Some files were not shown because too many files have changed in this diff Show More