mirror of
https://github.com/laoxong/nofx.git
synced 2026-06-04 01:48:22 +08:00
refactor: restructure project directories for better modularity
- Delete llm/ dead code (3 files, zero references) - Split mcp/ into sub-packages: mcp/provider/ (8 providers) and mcp/payment/ (4 payment clients) with registry pattern - Export Client internal fields and ClientHooks interface for sub-package access - Split api/server.go (3892 lines) into 8 domain-specific handler files - Split trader/auto_trader.go (2296 lines) into 5 focused files - Reorganize web/src/components/ flat files into auth/, charts/, trader/, common/, modals/, backtest/ subdirectories - Update all consumer imports to use registry-based provider creation
This commit is contained in:
@@ -0,0 +1,211 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"nofx/config"
|
||||
"nofx/crypto"
|
||||
"nofx/logger"
|
||||
"nofx/security"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ModelConfig struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"apiKey,omitempty"`
|
||||
CustomAPIURL string `json:"customApiUrl,omitempty"`
|
||||
}
|
||||
|
||||
// SafeModelConfig Safe model configuration structure (does not contain sensitive information)
|
||||
type SafeModelConfig struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CustomAPIURL string `json:"customApiUrl"` // Custom API URL (usually not sensitive)
|
||||
CustomModelName string `json:"customModelName"` // Custom model name (not sensitive)
|
||||
}
|
||||
|
||||
type UpdateModelConfigRequest struct {
|
||||
Models map[string]struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
CustomAPIURL string `json:"custom_api_url"`
|
||||
CustomModelName string `json:"custom_model_name"`
|
||||
} `json:"models"`
|
||||
}
|
||||
|
||||
// handleGetModelConfigs Get AI model configurations
|
||||
func (s *Server) handleGetModelConfigs(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
logger.Infof("🔍 Querying AI model configs for user %s", userID)
|
||||
models, err := s.store.AIModel().List(userID)
|
||||
if err != nil {
|
||||
logger.Infof("❌ Failed to get AI model configs: %v", err)
|
||||
SafeInternalError(c, "Failed to get AI model configs", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If no models in database, return default models
|
||||
if len(models) == 0 {
|
||||
logger.Infof("⚠️ No AI models in database, returning defaults")
|
||||
defaultModels := []SafeModelConfig{
|
||||
{ID: "deepseek", Name: "DeepSeek AI", Provider: "deepseek", Enabled: false},
|
||||
{ID: "qwen", Name: "Qwen AI", Provider: "qwen", Enabled: false},
|
||||
{ID: "openai", Name: "OpenAI", Provider: "openai", Enabled: false},
|
||||
{ID: "claude", Name: "Claude AI", Provider: "claude", Enabled: false},
|
||||
{ID: "gemini", Name: "Gemini AI", Provider: "gemini", Enabled: false},
|
||||
{ID: "grok", Name: "Grok AI", Provider: "grok", Enabled: false},
|
||||
{ID: "kimi", Name: "Kimi AI", Provider: "kimi", Enabled: false},
|
||||
{ID: "minimax", Name: "MiniMax AI", Provider: "minimax", Enabled: false},
|
||||
}
|
||||
c.JSON(http.StatusOK, defaultModels)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("✅ Found %d AI model configs", len(models))
|
||||
|
||||
// Convert to safe response structure, remove sensitive information
|
||||
safeModels := make([]SafeModelConfig, len(models))
|
||||
for i, model := range models {
|
||||
safeModels[i] = SafeModelConfig{
|
||||
ID: model.ID,
|
||||
Name: model.Name,
|
||||
Provider: model.Provider,
|
||||
Enabled: model.Enabled,
|
||||
CustomAPIURL: model.CustomAPIURL,
|
||||
CustomModelName: model.CustomModelName,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, safeModels)
|
||||
}
|
||||
|
||||
// handleUpdateModelConfigs Update AI model configurations (supports both encrypted and plain text based on config)
|
||||
func (s *Server) handleUpdateModelConfigs(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
cfg := config.Get()
|
||||
|
||||
// Read raw request body
|
||||
bodyBytes, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateModelConfigRequest
|
||||
|
||||
// Check if transport encryption is enabled
|
||||
if !cfg.TransportEncryption {
|
||||
// Transport encryption disabled, accept plain JSON
|
||||
if err := json.Unmarshal(bodyBytes, &req); err != nil {
|
||||
logger.Infof("❌ Failed to parse plain JSON request: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
|
||||
return
|
||||
}
|
||||
logger.Infof("📝 Received plain text model config (UserID: %s)", userID)
|
||||
} else {
|
||||
// Transport encryption enabled, require encrypted payload
|
||||
var encryptedPayload crypto.EncryptedPayload
|
||||
if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil {
|
||||
logger.Infof("❌ Failed to parse encrypted payload: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format, encrypted transmission required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify encrypted data
|
||||
if encryptedPayload.WrappedKey == "" {
|
||||
logger.Infof("❌ Detected unencrypted request (UserID: %s)", userID)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "This endpoint only supports encrypted transmission, please use encrypted client",
|
||||
"code": "ENCRYPTION_REQUIRED",
|
||||
"message": "Encrypted transmission is required for security reasons",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt data
|
||||
decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload)
|
||||
if err != nil {
|
||||
logger.Infof("❌ Failed to decrypt model config (UserID: %s): %v", userID, err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to decrypt data"})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse decrypted data
|
||||
if err := json.Unmarshal([]byte(decrypted), &req); err != nil {
|
||||
logger.Infof("❌ Failed to parse decrypted data: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse decrypted data"})
|
||||
return
|
||||
}
|
||||
logger.Infof("🔓 Decrypted model config data (UserID: %s)", userID)
|
||||
}
|
||||
|
||||
// Update each model's configuration and track traders that need reload
|
||||
tradersToReload := make(map[string]bool)
|
||||
for modelID, modelData := range req.Models {
|
||||
// SSRF protection: validate custom_api_url before storing
|
||||
if modelData.CustomAPIURL != "" {
|
||||
cleanURL := strings.TrimSuffix(modelData.CustomAPIURL, "#")
|
||||
if err := security.ValidateURL(cleanURL); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid custom_api_url for model %s: %s", modelID, err.Error())})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Find traders using this AI model BEFORE updating
|
||||
traders, _ := s.store.Trader().ListByAIModelID(userID, modelID)
|
||||
for _, t := range traders {
|
||||
tradersToReload[t.ID] = true
|
||||
}
|
||||
|
||||
err := s.store.AIModel().Update(userID, modelID, modelData.Enabled, modelData.APIKey, modelData.CustomAPIURL, modelData.CustomModelName)
|
||||
if err != nil {
|
||||
SafeInternalError(c, fmt.Sprintf("Update model %s", modelID), err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Remove affected traders from memory BEFORE reloading to pick up new config
|
||||
for traderID := range tradersToReload {
|
||||
logger.Infof("🔄 Removing trader %s from memory to reload with new AI model config", traderID)
|
||||
s.traderManager.RemoveTrader(traderID)
|
||||
}
|
||||
|
||||
// Reload all traders for this user to make new config take effect immediately
|
||||
err = s.traderManager.LoadUserTradersFromStore(s.store, userID)
|
||||
if err != nil {
|
||||
logger.Infof("⚠️ Failed to reload user traders into memory: %v", err)
|
||||
// Don't return error here since model config was successfully updated to database
|
||||
}
|
||||
|
||||
logger.Infof("✓ AI model config updated: %+v", req.Models)
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Model configuration updated"})
|
||||
}
|
||||
|
||||
// handleGetSupportedModels Get list of AI models supported by the system
|
||||
func (s *Server) handleGetSupportedModels(c *gin.Context) {
|
||||
// Return static list of supported AI models with default versions
|
||||
supportedModels := []map[string]interface{}{
|
||||
{"id": "deepseek", "name": "DeepSeek", "provider": "deepseek", "defaultModel": "deepseek-chat"},
|
||||
{"id": "qwen", "name": "Qwen", "provider": "qwen", "defaultModel": "qwen3-max"},
|
||||
{"id": "openai", "name": "OpenAI", "provider": "openai", "defaultModel": "gpt-5.1"},
|
||||
{"id": "claude", "name": "Claude", "provider": "claude", "defaultModel": "claude-opus-4-6"},
|
||||
{"id": "gemini", "name": "Google Gemini", "provider": "gemini", "defaultModel": "gemini-3-pro-preview"},
|
||||
{"id": "grok", "name": "Grok (xAI)", "provider": "grok", "defaultModel": "grok-3-latest"},
|
||||
{"id": "kimi", "name": "Kimi (Moonshot)", "provider": "kimi", "defaultModel": "moonshot-v1-auto"},
|
||||
{"id": "minimax", "name": "MiniMax", "provider": "minimax", "defaultModel": "MiniMax-M2.5"},
|
||||
{"id": "blockrun-base", "name": "BlockRun (Base Wallet)", "provider": "blockrun-base", "defaultModel": "auto"},
|
||||
{"id": "blockrun-sol", "name": "BlockRun (Solana Wallet)", "provider": "blockrun-sol", "defaultModel": "auto"},
|
||||
{"id": "claw402", "name": "Claw402 (Base USDC)", "provider": "claw402", "defaultModel": "deepseek"},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, supportedModels)
|
||||
}
|
||||
@@ -0,0 +1,469 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nofx/logger"
|
||||
"nofx/store"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// handleDecisions Decision log list
|
||||
func (s *Server) handleDecisions(c *gin.Context) {
|
||||
_, traderID, err := s.getTraderFromQuery(c)
|
||||
if err != nil {
|
||||
SafeBadRequest(c, "Invalid trader ID")
|
||||
return
|
||||
}
|
||||
|
||||
trader, err := s.traderManager.GetTrader(traderID)
|
||||
if err != nil {
|
||||
SafeNotFound(c, "Trader")
|
||||
return
|
||||
}
|
||||
|
||||
// Get all historical decision records (unlimited)
|
||||
records, err := trader.GetStore().Decision().GetLatestRecords(trader.GetID(), 10000)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get decision log", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, records)
|
||||
}
|
||||
|
||||
// handleLatestDecisions Latest decision logs (newest first, supports limit parameter)
|
||||
func (s *Server) handleLatestDecisions(c *gin.Context) {
|
||||
_, traderID, err := s.getTraderFromQuery(c)
|
||||
if err != nil {
|
||||
SafeBadRequest(c, "Invalid trader ID")
|
||||
return
|
||||
}
|
||||
|
||||
trader, err := s.traderManager.GetTrader(traderID)
|
||||
if err != nil {
|
||||
SafeNotFound(c, "Trader")
|
||||
return
|
||||
}
|
||||
|
||||
// Get limit from query parameter, default to 5
|
||||
limit := 5
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 {
|
||||
limit = parsedLimit
|
||||
if limit > 100 {
|
||||
limit = 100 // Max 100 to prevent abuse
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
records, err := trader.GetStore().Decision().GetLatestRecords(trader.GetID(), limit)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get decision log", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Reverse array to put newest first (for list display)
|
||||
// GetLatestRecords returns oldest to newest (for charts), here we need newest to oldest
|
||||
for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 {
|
||||
records[i], records[j] = records[j], records[i]
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, records)
|
||||
}
|
||||
|
||||
// handleStatistics Statistics information
|
||||
func (s *Server) handleStatistics(c *gin.Context) {
|
||||
_, traderID, err := s.getTraderFromQuery(c)
|
||||
if err != nil {
|
||||
SafeBadRequest(c, "Invalid trader ID")
|
||||
return
|
||||
}
|
||||
|
||||
trader, err := s.traderManager.GetTrader(traderID)
|
||||
if err != nil {
|
||||
SafeNotFound(c, "Trader")
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := trader.GetStore().Decision().GetStatistics(trader.GetID())
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get statistics", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// handleCompetition Competition overview (compare all traders)
|
||||
func (s *Server) handleCompetition(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
|
||||
// Ensure user's traders are loaded into memory
|
||||
err := s.traderManager.LoadUserTradersFromStore(s.store, userID)
|
||||
if err != nil {
|
||||
logger.Infof("⚠️ Failed to load traders for user %s: %v", userID, err)
|
||||
}
|
||||
|
||||
competition, err := s.traderManager.GetCompetitionData()
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get competition data", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, competition)
|
||||
}
|
||||
|
||||
// handleEquityHistory Return rate historical data
|
||||
// Query directly from database, not dependent on trader in memory (so historical data can be retrieved after restart)
|
||||
func (s *Server) handleEquityHistory(c *gin.Context) {
|
||||
_, traderID, err := s.getTraderFromQuery(c)
|
||||
if err != nil {
|
||||
SafeBadRequest(c, "Invalid trader ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get equity historical data from new equity table
|
||||
// Every 3 minutes per cycle: 10000 records = about 20 days of data
|
||||
snapshots, err := s.store.Equity().GetLatest(traderID, 10000)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get historical data", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(snapshots) == 0 {
|
||||
c.JSON(http.StatusOK, []interface{}{})
|
||||
return
|
||||
}
|
||||
|
||||
// Build return rate historical data points
|
||||
type EquityPoint struct {
|
||||
Timestamp string `json:"timestamp"`
|
||||
TotalEquity float64 `json:"total_equity"` // Account equity (wallet + unrealized)
|
||||
AvailableBalance float64 `json:"available_balance"` // Available balance
|
||||
TotalPnL float64 `json:"total_pnl"` // Total PnL (unrealized PnL)
|
||||
TotalPnLPct float64 `json:"total_pnl_pct"` // Total PnL percentage
|
||||
PositionCount int `json:"position_count"` // Position count
|
||||
MarginUsedPct float64 `json:"margin_used_pct"` // Margin used percentage
|
||||
}
|
||||
|
||||
// Use the balance of the first record as initial balance to calculate return rate
|
||||
initialBalance := snapshots[0].Balance
|
||||
if initialBalance == 0 {
|
||||
initialBalance = 1 // Avoid division by zero
|
||||
}
|
||||
|
||||
var history []EquityPoint
|
||||
for _, snap := range snapshots {
|
||||
// Calculate PnL percentage
|
||||
totalPnLPct := 0.0
|
||||
if initialBalance > 0 {
|
||||
totalPnLPct = (snap.UnrealizedPnL / initialBalance) * 100
|
||||
}
|
||||
|
||||
history = append(history, EquityPoint{
|
||||
Timestamp: snap.Timestamp.Format("2006-01-02 15:04:05"),
|
||||
TotalEquity: snap.TotalEquity,
|
||||
AvailableBalance: snap.Balance,
|
||||
TotalPnL: snap.UnrealizedPnL,
|
||||
TotalPnLPct: totalPnLPct,
|
||||
PositionCount: snap.PositionCount,
|
||||
MarginUsedPct: snap.MarginUsedPct,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, history)
|
||||
}
|
||||
|
||||
// handlePublicTraderList Get public trader list (no authentication required)
|
||||
func (s *Server) handlePublicTraderList(c *gin.Context) {
|
||||
// Get trader information from all users
|
||||
competition, err := s.traderManager.GetCompetitionData()
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get trader list", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Get traders array
|
||||
tradersData, exists := competition["traders"]
|
||||
if !exists {
|
||||
c.JSON(http.StatusOK, []map[string]interface{}{})
|
||||
return
|
||||
}
|
||||
|
||||
traders, ok := tradersData.([]map[string]interface{})
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Trader data format error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Return trader basic information, filter sensitive information
|
||||
result := make([]map[string]interface{}, 0, len(traders))
|
||||
for _, trader := range traders {
|
||||
result = append(result, map[string]interface{}{
|
||||
"trader_id": trader["trader_id"],
|
||||
"trader_name": trader["trader_name"],
|
||||
"ai_model": trader["ai_model"],
|
||||
"exchange": trader["exchange"],
|
||||
"is_running": trader["is_running"],
|
||||
"total_equity": trader["total_equity"],
|
||||
"total_pnl": trader["total_pnl"],
|
||||
"total_pnl_pct": trader["total_pnl_pct"],
|
||||
"position_count": trader["position_count"],
|
||||
"margin_used_pct": trader["margin_used_pct"],
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// handlePublicCompetition Get public competition data (no authentication required)
|
||||
func (s *Server) handlePublicCompetition(c *gin.Context) {
|
||||
competition, err := s.traderManager.GetCompetitionData()
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get competition data", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, competition)
|
||||
}
|
||||
|
||||
// handleTopTraders Get top 5 trader data (no authentication required, for performance comparison)
|
||||
func (s *Server) handleTopTraders(c *gin.Context) {
|
||||
topTraders, err := s.traderManager.GetTopTradersData()
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get top traders data", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, topTraders)
|
||||
}
|
||||
|
||||
// handleEquityHistoryBatch Batch get return rate historical data for multiple traders (no authentication required, for performance comparison)
|
||||
// Supports optional 'hours' parameter to filter data by time range (e.g., hours=24 for last 24 hours)
|
||||
func (s *Server) handleEquityHistoryBatch(c *gin.Context) {
|
||||
var requestBody struct {
|
||||
TraderIDs []string `json:"trader_ids"`
|
||||
Hours int `json:"hours"` // Optional: filter by last N hours (0 = all data)
|
||||
}
|
||||
|
||||
// Try to parse POST request JSON body
|
||||
if err := c.ShouldBindJSON(&requestBody); err != nil {
|
||||
// If JSON parse fails, try to get from query parameters (compatible with GET request)
|
||||
traderIDsParam := c.Query("trader_ids")
|
||||
if traderIDsParam == "" {
|
||||
// If no trader_ids specified, return historical data for top 5
|
||||
topTraders, err := s.traderManager.GetTopTradersData()
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get top traders", err)
|
||||
return
|
||||
}
|
||||
|
||||
traders, ok := topTraders["traders"].([]map[string]interface{})
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Trader data format error"})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract trader IDs
|
||||
traderIDs := make([]string, 0, len(traders))
|
||||
for _, trader := range traders {
|
||||
if traderID, ok := trader["trader_id"].(string); ok {
|
||||
traderIDs = append(traderIDs, traderID)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse hours parameter from query
|
||||
hoursParam := c.Query("hours")
|
||||
hours := 0
|
||||
if hoursParam != "" {
|
||||
fmt.Sscanf(hoursParam, "%d", &hours)
|
||||
}
|
||||
|
||||
result := s.getEquityHistoryForTraders(traderIDs, hours)
|
||||
c.JSON(http.StatusOK, result)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse comma-separated trader IDs
|
||||
requestBody.TraderIDs = strings.Split(traderIDsParam, ",")
|
||||
for i := range requestBody.TraderIDs {
|
||||
requestBody.TraderIDs[i] = strings.TrimSpace(requestBody.TraderIDs[i])
|
||||
}
|
||||
|
||||
// Parse hours parameter from query
|
||||
hoursParam := c.Query("hours")
|
||||
if hoursParam != "" {
|
||||
fmt.Sscanf(hoursParam, "%d", &requestBody.Hours)
|
||||
}
|
||||
}
|
||||
|
||||
// Limit to maximum 20 traders to prevent oversized requests
|
||||
if len(requestBody.TraderIDs) > 20 {
|
||||
requestBody.TraderIDs = requestBody.TraderIDs[:20]
|
||||
}
|
||||
|
||||
result := s.getEquityHistoryForTraders(requestBody.TraderIDs, requestBody.Hours)
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// getEquityHistoryForTraders Get historical data for multiple traders
|
||||
// Query directly from database, not dependent on trader in memory (so historical data can be retrieved after restart)
|
||||
// Also appends current real-time data point to ensure chart matches leaderboard
|
||||
// hours: filter by last N hours (0 = use default limit of 500 records)
|
||||
func (s *Server) getEquityHistoryForTraders(traderIDs []string, hours int) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
histories := make(map[string]interface{})
|
||||
errors := make(map[string]string)
|
||||
|
||||
// Use a single consistent timestamp for all real-time data points
|
||||
now := time.Now()
|
||||
|
||||
// Pre-fetch initial balances for all traders
|
||||
initialBalances := make(map[string]float64)
|
||||
for _, traderID := range traderIDs {
|
||||
if traderID == "" {
|
||||
continue
|
||||
}
|
||||
// Get trader's initial balance from database (use GetByID which doesn't require userID)
|
||||
trader, err := s.store.Trader().GetByID(traderID)
|
||||
if err == nil && trader != nil && trader.InitialBalance > 0 {
|
||||
initialBalances[traderID] = trader.InitialBalance
|
||||
}
|
||||
}
|
||||
|
||||
for _, traderID := range traderIDs {
|
||||
if traderID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get equity historical data from new equity table
|
||||
var snapshots []*store.EquitySnapshot
|
||||
var err error
|
||||
|
||||
if hours > 0 {
|
||||
// Filter by time range
|
||||
startTime := now.Add(-time.Duration(hours) * time.Hour)
|
||||
snapshots, err = s.store.Equity().GetByTimeRange(traderID, startTime, now)
|
||||
} else {
|
||||
// Default: get latest 500 records
|
||||
snapshots, err = s.store.Equity().GetLatest(traderID, 500)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Errorf("[API] Failed to get equity history for %s: %v", traderID, err)
|
||||
errors[traderID] = "Failed to get historical data"
|
||||
continue
|
||||
}
|
||||
|
||||
// Get initial balance for calculating PnL percentage
|
||||
initialBalance := initialBalances[traderID]
|
||||
if initialBalance <= 0 && len(snapshots) > 0 {
|
||||
// If no initial balance configured, use the first snapshot's equity as baseline
|
||||
initialBalance = snapshots[0].TotalEquity
|
||||
}
|
||||
|
||||
// Build return rate historical data with PnL percentage
|
||||
history := make([]map[string]interface{}, 0, len(snapshots)+1)
|
||||
var lastSnapshotTime time.Time
|
||||
for _, snap := range snapshots {
|
||||
// Calculate PnL percentage: (current_equity - initial_balance) / initial_balance * 100
|
||||
pnlPct := 0.0
|
||||
if initialBalance > 0 {
|
||||
pnlPct = (snap.TotalEquity - initialBalance) / initialBalance * 100
|
||||
}
|
||||
|
||||
history = append(history, map[string]interface{}{
|
||||
"timestamp": snap.Timestamp,
|
||||
"total_equity": snap.TotalEquity,
|
||||
"total_pnl": snap.UnrealizedPnL,
|
||||
"total_pnl_pct": pnlPct,
|
||||
"balance": snap.Balance,
|
||||
})
|
||||
if snap.Timestamp.After(lastSnapshotTime) {
|
||||
lastSnapshotTime = snap.Timestamp
|
||||
}
|
||||
}
|
||||
|
||||
// Append current real-time data point to ensure chart matches leaderboard
|
||||
// This ensures the latest point is always current, not from a potentially stale snapshot
|
||||
if trader, err := s.traderManager.GetTrader(traderID); err == nil {
|
||||
if accountInfo, err := trader.GetAccountInfo(); err == nil {
|
||||
// Only append if it's been more than 30 seconds since last snapshot
|
||||
if now.Sub(lastSnapshotTime) > 30*time.Second {
|
||||
totalEquity := 0.0
|
||||
if v, ok := accountInfo["total_equity"].(float64); ok {
|
||||
totalEquity = v
|
||||
}
|
||||
totalPnL := 0.0
|
||||
if v, ok := accountInfo["total_pnl"].(float64); ok {
|
||||
totalPnL = v
|
||||
}
|
||||
walletBalance := 0.0
|
||||
if v, ok := accountInfo["wallet_balance"].(float64); ok {
|
||||
walletBalance = v
|
||||
}
|
||||
pnlPct := 0.0
|
||||
if initialBalance > 0 {
|
||||
pnlPct = (totalEquity - initialBalance) / initialBalance * 100
|
||||
}
|
||||
|
||||
history = append(history, map[string]interface{}{
|
||||
"timestamp": now,
|
||||
"total_equity": totalEquity,
|
||||
"total_pnl": totalPnL,
|
||||
"total_pnl_pct": pnlPct,
|
||||
"balance": walletBalance,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
histories[traderID] = history
|
||||
}
|
||||
|
||||
result["histories"] = histories
|
||||
result["count"] = len(histories)
|
||||
if len(errors) > 0 {
|
||||
result["errors"] = errors
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// handleGetPublicTraderConfig Get public trader configuration information (no authentication required, does not include sensitive information)
|
||||
func (s *Server) handleGetPublicTraderConfig(c *gin.Context) {
|
||||
traderID := c.Param("id")
|
||||
if traderID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Trader ID cannot be empty"})
|
||||
return
|
||||
}
|
||||
|
||||
trader, err := s.traderManager.GetTrader(traderID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Trader does not exist"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get trader status information
|
||||
status := trader.GetStatus()
|
||||
|
||||
// Only return public configuration information, not including sensitive data like API keys
|
||||
result := map[string]interface{}{
|
||||
"trader_id": trader.GetID(),
|
||||
"trader_name": trader.GetName(),
|
||||
"ai_model": trader.GetAIModel(),
|
||||
"exchange": trader.GetExchange(),
|
||||
"is_running": status["is_running"],
|
||||
"ai_provider": status["ai_provider"],
|
||||
"start_time": status["start_time"],
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
@@ -0,0 +1,353 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"nofx/config"
|
||||
"nofx/crypto"
|
||||
"nofx/logger"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ExchangeConfig struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"` // "cex" or "dex"
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"apiKey,omitempty"`
|
||||
SecretKey string `json:"secretKey,omitempty"`
|
||||
Testnet bool `json:"testnet,omitempty"`
|
||||
}
|
||||
|
||||
// SafeExchangeConfig Safe exchange configuration structure (does not contain sensitive information)
|
||||
type SafeExchangeConfig struct {
|
||||
ID string `json:"id"` // UUID
|
||||
ExchangeType string `json:"exchange_type"` // "binance", "bybit", "okx", "hyperliquid", "aster", "lighter"
|
||||
AccountName string `json:"account_name"` // User-defined account name
|
||||
Name string `json:"name"` // Display name
|
||||
Type string `json:"type"` // "cex" or "dex"
|
||||
Enabled bool `json:"enabled"`
|
||||
Testnet bool `json:"testnet,omitempty"`
|
||||
HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"` // Hyperliquid wallet address (not sensitive)
|
||||
AsterUser string `json:"asterUser"` // Aster username (not sensitive)
|
||||
AsterSigner string `json:"asterSigner"` // Aster signer (not sensitive)
|
||||
LighterWalletAddr string `json:"lighterWalletAddr"` // LIGHTER wallet address (not sensitive)
|
||||
}
|
||||
|
||||
type UpdateExchangeConfigRequest struct {
|
||||
Exchanges map[string]struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
Passphrase string `json:"passphrase"` // OKX specific
|
||||
Testnet bool `json:"testnet"`
|
||||
HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"`
|
||||
HyperliquidUnifiedAcct bool `json:"hyperliquid_unified_account"` // Unified Account mode
|
||||
AsterUser string `json:"aster_user"`
|
||||
AsterSigner string `json:"aster_signer"`
|
||||
AsterPrivateKey string `json:"aster_private_key"`
|
||||
LighterWalletAddr string `json:"lighter_wallet_addr"`
|
||||
LighterPrivateKey string `json:"lighter_private_key"`
|
||||
LighterAPIKeyPrivateKey string `json:"lighter_api_key_private_key"`
|
||||
LighterAPIKeyIndex int `json:"lighter_api_key_index"`
|
||||
} `json:"exchanges"`
|
||||
}
|
||||
|
||||
// CreateExchangeRequest request structure for creating a new exchange account
|
||||
type CreateExchangeRequest struct {
|
||||
ExchangeType string `json:"exchange_type" binding:"required"` // "binance", "bybit", "okx", "hyperliquid", "aster", "lighter"
|
||||
AccountName string `json:"account_name"` // User-defined account name
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
SecretKey string `json:"secret_key"`
|
||||
Passphrase string `json:"passphrase"`
|
||||
Testnet bool `json:"testnet"`
|
||||
HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr"`
|
||||
HyperliquidUnifiedAcct bool `json:"hyperliquid_unified_account"` // Unified Account mode: Spot as Perp collateral
|
||||
AsterUser string `json:"aster_user"`
|
||||
AsterSigner string `json:"aster_signer"`
|
||||
AsterPrivateKey string `json:"aster_private_key"`
|
||||
LighterWalletAddr string `json:"lighter_wallet_addr"`
|
||||
LighterPrivateKey string `json:"lighter_private_key"`
|
||||
LighterAPIKeyPrivateKey string `json:"lighter_api_key_private_key"`
|
||||
LighterAPIKeyIndex int `json:"lighter_api_key_index"`
|
||||
}
|
||||
|
||||
// handleGetExchangeConfigs Get exchange configurations
|
||||
func (s *Server) handleGetExchangeConfigs(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
logger.Infof("🔍 Querying exchange configs for user %s", userID)
|
||||
exchanges, err := s.store.Exchange().List(userID)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Failed to get exchange configs", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If no exchanges in database, return empty array (user needs to create accounts)
|
||||
if len(exchanges) == 0 {
|
||||
logger.Infof("⚠️ No exchanges in database for user %s", userID)
|
||||
c.JSON(http.StatusOK, []SafeExchangeConfig{})
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("✅ Found %d exchange configs", len(exchanges))
|
||||
|
||||
// Convert to safe response structure, remove sensitive information
|
||||
safeExchanges := make([]SafeExchangeConfig, len(exchanges))
|
||||
for i, exchange := range exchanges {
|
||||
safeExchanges[i] = SafeExchangeConfig{
|
||||
ID: exchange.ID,
|
||||
ExchangeType: exchange.ExchangeType,
|
||||
AccountName: exchange.AccountName,
|
||||
Name: exchange.Name,
|
||||
Type: exchange.Type,
|
||||
Enabled: exchange.Enabled,
|
||||
Testnet: exchange.Testnet,
|
||||
HyperliquidWalletAddr: exchange.HyperliquidWalletAddr,
|
||||
AsterUser: exchange.AsterUser,
|
||||
AsterSigner: exchange.AsterSigner,
|
||||
LighterWalletAddr: exchange.LighterWalletAddr,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, safeExchanges)
|
||||
}
|
||||
|
||||
// handleUpdateExchangeConfigs Update exchange configurations (supports both encrypted and plain text based on config)
|
||||
func (s *Server) handleUpdateExchangeConfigs(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
cfg := config.Get()
|
||||
|
||||
// Read raw request body
|
||||
bodyBytes, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateExchangeConfigRequest
|
||||
|
||||
// Check if transport encryption is enabled
|
||||
if !cfg.TransportEncryption {
|
||||
// Transport encryption disabled, accept plain JSON
|
||||
if err := json.Unmarshal(bodyBytes, &req); err != nil {
|
||||
logger.Infof("❌ Failed to parse plain JSON request: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
|
||||
return
|
||||
}
|
||||
logger.Infof("📝 Received plain text exchange config (UserID: %s)", userID)
|
||||
} else {
|
||||
// Transport encryption enabled, require encrypted payload
|
||||
var encryptedPayload crypto.EncryptedPayload
|
||||
if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil {
|
||||
logger.Infof("❌ Failed to parse encrypted payload: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format, encrypted transmission required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify encrypted data
|
||||
if encryptedPayload.WrappedKey == "" {
|
||||
logger.Infof("❌ Detected unencrypted request (UserID: %s)", userID)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "This endpoint only supports encrypted transmission, please use encrypted client",
|
||||
"code": "ENCRYPTION_REQUIRED",
|
||||
"message": "Encrypted transmission is required for security reasons",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt data
|
||||
decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload)
|
||||
if err != nil {
|
||||
logger.Infof("❌ Failed to decrypt exchange config (UserID: %s): %v", userID, err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to decrypt data"})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse decrypted data
|
||||
if err := json.Unmarshal([]byte(decrypted), &req); err != nil {
|
||||
logger.Infof("❌ Failed to parse decrypted data: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse decrypted data"})
|
||||
return
|
||||
}
|
||||
logger.Infof("🔓 Decrypted exchange config data (UserID: %s)", userID)
|
||||
}
|
||||
|
||||
// Update each exchange's configuration and track traders that need reload
|
||||
tradersToReload := make(map[string]bool)
|
||||
for exchangeID, exchangeData := range req.Exchanges {
|
||||
// Find traders using this exchange BEFORE updating
|
||||
traders, _ := s.store.Trader().ListByExchangeID(userID, exchangeID)
|
||||
for _, t := range traders {
|
||||
tradersToReload[t.ID] = true
|
||||
}
|
||||
|
||||
err := s.store.Exchange().Update(userID, exchangeID, exchangeData.Enabled, exchangeData.APIKey, exchangeData.SecretKey, exchangeData.Passphrase, exchangeData.Testnet, exchangeData.HyperliquidWalletAddr, exchangeData.HyperliquidUnifiedAcct, exchangeData.AsterUser, exchangeData.AsterSigner, exchangeData.AsterPrivateKey, exchangeData.LighterWalletAddr, exchangeData.LighterPrivateKey, exchangeData.LighterAPIKeyPrivateKey, exchangeData.LighterAPIKeyIndex)
|
||||
if err != nil {
|
||||
SafeInternalError(c, fmt.Sprintf("Update exchange %s", exchangeID), err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Remove affected traders from memory BEFORE reloading to pick up new config
|
||||
for traderID := range tradersToReload {
|
||||
logger.Infof("🔄 Removing trader %s from memory to reload with new exchange config", traderID)
|
||||
s.traderManager.RemoveTrader(traderID)
|
||||
}
|
||||
|
||||
// Reload all traders for this user to make new config take effect immediately
|
||||
err = s.traderManager.LoadUserTradersFromStore(s.store, userID)
|
||||
if err != nil {
|
||||
logger.Infof("⚠️ Failed to reload user traders into memory: %v", err)
|
||||
// Don't return error here since exchange config was successfully updated to database
|
||||
}
|
||||
|
||||
logger.Infof("✓ Exchange config updated: %+v", req.Exchanges)
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Exchange configuration updated"})
|
||||
}
|
||||
|
||||
// handleCreateExchange Create a new exchange account
|
||||
func (s *Server) handleCreateExchange(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
cfg := config.Get()
|
||||
|
||||
// Read raw request body
|
||||
bodyBytes, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to read request body"})
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateExchangeRequest
|
||||
|
||||
// Check if transport encryption is enabled
|
||||
if !cfg.TransportEncryption {
|
||||
// Transport encryption disabled, accept plain JSON
|
||||
if err := json.Unmarshal(bodyBytes, &req); err != nil {
|
||||
logger.Infof("❌ Failed to parse plain JSON request: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Transport encryption enabled, require encrypted payload
|
||||
var encryptedPayload crypto.EncryptedPayload
|
||||
if err := json.Unmarshal(bodyBytes, &encryptedPayload); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format, encrypted transmission required"})
|
||||
return
|
||||
}
|
||||
|
||||
if encryptedPayload.WrappedKey == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "This endpoint only supports encrypted transmission",
|
||||
"code": "ENCRYPTION_REQUIRED",
|
||||
"message": "Encrypted transmission is required for security reasons",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
decrypted, err := s.cryptoHandler.cryptoService.DecryptSensitiveData(&encryptedPayload)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to decrypt data"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(decrypted), &req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse decrypted data"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Validate exchange type
|
||||
validTypes := map[string]bool{
|
||||
"binance": true, "bybit": true, "okx": true, "bitget": true,
|
||||
"hyperliquid": true, "aster": true, "lighter": true, "gate": true, "kucoin": true, "indodax": true,
|
||||
}
|
||||
if !validTypes[req.ExchangeType] {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid exchange type: %s", req.ExchangeType)})
|
||||
return
|
||||
}
|
||||
|
||||
// Create new exchange account
|
||||
id, err := s.store.Exchange().Create(
|
||||
userID, req.ExchangeType, req.AccountName, req.Enabled,
|
||||
req.APIKey, req.SecretKey, req.Passphrase, req.Testnet,
|
||||
req.HyperliquidWalletAddr, req.HyperliquidUnifiedAcct,
|
||||
req.AsterUser, req.AsterSigner, req.AsterPrivateKey,
|
||||
req.LighterWalletAddr, req.LighterPrivateKey, req.LighterAPIKeyPrivateKey, req.LighterAPIKeyIndex,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Infof("❌ Failed to create exchange account: %v", err)
|
||||
SafeInternalError(c, "Failed to create exchange account", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("✓ Created exchange account: type=%s, name=%s, id=%s", req.ExchangeType, req.AccountName, id)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "Exchange account created",
|
||||
"id": id,
|
||||
})
|
||||
}
|
||||
|
||||
// handleDeleteExchange Delete an exchange account
|
||||
func (s *Server) handleDeleteExchange(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
exchangeID := c.Param("id")
|
||||
|
||||
if exchangeID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Exchange ID is required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if any traders are using this exchange
|
||||
traders, err := s.store.Trader().List(userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check traders"})
|
||||
return
|
||||
}
|
||||
|
||||
for _, trader := range traders {
|
||||
if trader.ExchangeID == exchangeID {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Cannot delete exchange account that is in use by traders",
|
||||
"trader_id": trader.ID,
|
||||
"trader_name": trader.Name,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Delete exchange account
|
||||
err = s.store.Exchange().Delete(userID, exchangeID)
|
||||
if err != nil {
|
||||
logger.Infof("❌ Failed to delete exchange account: %v", err)
|
||||
SafeInternalError(c, "Failed to delete exchange account", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("✓ Deleted exchange account: id=%s", exchangeID)
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Exchange account deleted"})
|
||||
}
|
||||
|
||||
// handleGetSupportedExchanges Get list of exchanges supported by the system
|
||||
func (s *Server) handleGetSupportedExchanges(c *gin.Context) {
|
||||
// Return static list of supported exchange types
|
||||
// Note: ID is empty for supported exchanges (they are templates, not actual accounts)
|
||||
supportedExchanges := []SafeExchangeConfig{
|
||||
{ExchangeType: "binance", Name: "Binance Futures", Type: "cex"},
|
||||
{ExchangeType: "bybit", Name: "Bybit Futures", Type: "cex"},
|
||||
{ExchangeType: "okx", Name: "OKX Futures", Type: "cex"},
|
||||
{ExchangeType: "gate", Name: "Gate.io Futures", Type: "cex"},
|
||||
{ExchangeType: "kucoin", Name: "KuCoin Futures", Type: "cex"},
|
||||
{ExchangeType: "hyperliquid", Name: "Hyperliquid", Type: "dex"},
|
||||
{ExchangeType: "aster", Name: "Aster DEX", Type: "dex"},
|
||||
{ExchangeType: "lighter", Name: "LIGHTER DEX", Type: "dex"},
|
||||
{ExchangeType: "alpaca", Name: "Alpaca (US Stocks)", Type: "stock"},
|
||||
{ExchangeType: "forex", Name: "Forex (TwelveData)", Type: "forex"},
|
||||
{ExchangeType: "metals", Name: "Metals (TwelveData)", Type: "metals"},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, supportedExchanges)
|
||||
}
|
||||
@@ -0,0 +1,392 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nofx/logger"
|
||||
"nofx/market"
|
||||
"nofx/provider/alpaca"
|
||||
"nofx/provider/coinank/coinank_api"
|
||||
"nofx/provider/coinank/coinank_enum"
|
||||
"nofx/provider/hyperliquid"
|
||||
"nofx/provider/twelvedata"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// handleKlines K-line data (supports multiple exchanges via coinank)
|
||||
func (s *Server) handleKlines(c *gin.Context) {
|
||||
// Get query parameters
|
||||
symbol := c.Query("symbol")
|
||||
if symbol == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "symbol parameter is required"})
|
||||
return
|
||||
}
|
||||
|
||||
interval := c.DefaultQuery("interval", "5m")
|
||||
exchange := c.DefaultQuery("exchange", "binance") // Default to binance for backward compatibility
|
||||
limitStr := c.DefaultQuery("limit", "1000")
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err != nil || limit <= 0 {
|
||||
limit = 1000
|
||||
}
|
||||
|
||||
// Coinank API has a maximum limit of 1500 klines per request
|
||||
if limit > 1500 {
|
||||
limit = 1500
|
||||
}
|
||||
|
||||
var klines []market.Kline
|
||||
exchangeLower := strings.ToLower(exchange)
|
||||
|
||||
// Route to appropriate data source based on exchange type
|
||||
switch exchangeLower {
|
||||
case "alpaca":
|
||||
// US Stocks via Alpaca
|
||||
klines, err = s.getKlinesFromAlpaca(symbol, interval, limit)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get klines from Alpaca", err)
|
||||
return
|
||||
}
|
||||
case "forex", "metals":
|
||||
// Forex and Metals via Twelve Data
|
||||
klines, err = s.getKlinesFromTwelveData(symbol, interval, limit)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get klines from TwelveData", err)
|
||||
return
|
||||
}
|
||||
case "hyperliquid", "hyperliquid-xyz", "xyz":
|
||||
// Hyperliquid native API - supports both crypto perps and stock perps (xyz dex)
|
||||
klines, err = s.getKlinesFromHyperliquid(symbol, interval, limit)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get klines from Hyperliquid", err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
// Crypto exchanges via CoinAnk
|
||||
symbol = market.Normalize(symbol)
|
||||
klines, err = s.getKlinesFromCoinank(symbol, interval, exchange, limit)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get klines from CoinAnk", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, klines)
|
||||
}
|
||||
|
||||
// getKlinesFromCoinank fetches kline data from coinank free/open API for multiple exchanges
|
||||
func (s *Server) getKlinesFromCoinank(symbol, interval, exchange string, limit int) ([]market.Kline, error) {
|
||||
// Map exchange string to coinank enum
|
||||
var coinankExchange coinank_enum.Exchange
|
||||
switch strings.ToLower(exchange) {
|
||||
case "binance":
|
||||
coinankExchange = coinank_enum.Binance
|
||||
case "bybit":
|
||||
coinankExchange = coinank_enum.Bybit
|
||||
case "okx":
|
||||
coinankExchange = coinank_enum.Okex
|
||||
case "bitget":
|
||||
coinankExchange = coinank_enum.Bitget
|
||||
case "gate":
|
||||
coinankExchange = coinank_enum.Gate
|
||||
case "aster":
|
||||
coinankExchange = coinank_enum.Aster
|
||||
case "lighter":
|
||||
// Lighter doesn't have direct CoinAnk support, use Binance data as fallback
|
||||
coinankExchange = coinank_enum.Binance
|
||||
case "kucoin":
|
||||
// KuCoin doesn't have direct CoinAnk support, use Binance data as fallback
|
||||
coinankExchange = coinank_enum.Binance
|
||||
default:
|
||||
// For any unknown exchange, default to Binance
|
||||
logger.Warnf("⚠️ Unknown exchange '%s', defaulting to Binance for CoinAnk", exchange)
|
||||
coinankExchange = coinank_enum.Binance
|
||||
}
|
||||
|
||||
// Map interval string to coinank enum
|
||||
var coinankInterval coinank_enum.Interval
|
||||
switch interval {
|
||||
case "1s":
|
||||
coinankInterval = coinank_enum.Second1
|
||||
case "5s":
|
||||
coinankInterval = coinank_enum.Second5
|
||||
case "10s":
|
||||
coinankInterval = coinank_enum.Second10
|
||||
case "30s":
|
||||
coinankInterval = coinank_enum.Second30
|
||||
case "1m":
|
||||
coinankInterval = coinank_enum.Minute1
|
||||
case "3m":
|
||||
coinankInterval = coinank_enum.Minute3
|
||||
case "5m":
|
||||
coinankInterval = coinank_enum.Minute5
|
||||
case "10m":
|
||||
coinankInterval = coinank_enum.Minute10
|
||||
case "15m":
|
||||
coinankInterval = coinank_enum.Minute15
|
||||
case "30m":
|
||||
coinankInterval = coinank_enum.Minute30
|
||||
case "1h":
|
||||
coinankInterval = coinank_enum.Hour1
|
||||
case "2h":
|
||||
coinankInterval = coinank_enum.Hour2
|
||||
case "4h":
|
||||
coinankInterval = coinank_enum.Hour4
|
||||
case "6h":
|
||||
coinankInterval = coinank_enum.Hour6
|
||||
case "8h":
|
||||
coinankInterval = coinank_enum.Hour8
|
||||
case "12h":
|
||||
coinankInterval = coinank_enum.Hour12
|
||||
case "1d":
|
||||
coinankInterval = coinank_enum.Day1
|
||||
case "3d":
|
||||
coinankInterval = coinank_enum.Day3
|
||||
case "1w":
|
||||
coinankInterval = coinank_enum.Week1
|
||||
case "1M":
|
||||
coinankInterval = coinank_enum.Month1
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported interval for coinank: %s", interval)
|
||||
}
|
||||
|
||||
// Convert symbol format for different exchanges
|
||||
// OKX uses "BTC-USDT-SWAP" format instead of "BTCUSDT"
|
||||
apiSymbol := symbol
|
||||
if coinankExchange == coinank_enum.Okex {
|
||||
// Convert BTCUSDT -> BTC-USDT-SWAP
|
||||
if strings.HasSuffix(symbol, "USDT") {
|
||||
base := strings.TrimSuffix(symbol, "USDT")
|
||||
apiSymbol = fmt.Sprintf("%s-USDT-SWAP", base)
|
||||
}
|
||||
}
|
||||
|
||||
// Call coinank free/open API (no authentication required)
|
||||
ctx := context.Background()
|
||||
ts := time.Now().UnixMilli()
|
||||
// Use "To" side to search backward from current time (get historical klines)
|
||||
coinankKlines, err := coinank_api.Kline(ctx, apiSymbol, coinankExchange, ts, coinank_enum.To, limit, coinankInterval)
|
||||
if err != nil {
|
||||
// Free API doesn't support all exchanges (e.g., OKX, Bitget)
|
||||
// Fallback to Binance data as reference
|
||||
if coinankExchange != coinank_enum.Binance {
|
||||
logger.Warnf("⚠️ CoinAnk free API doesn't support %s, falling back to Binance data", coinankExchange)
|
||||
coinankKlines, err = coinank_api.Kline(ctx, symbol, coinank_enum.Binance, ts, coinank_enum.To, limit, coinankInterval)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("coinank API error (fallback): %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("coinank API error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert coinank kline format to market.Kline format
|
||||
// Coinank: Volume = BTC 数量, Quantity = USDT 成交额
|
||||
klines := make([]market.Kline, len(coinankKlines))
|
||||
for i, ck := range coinankKlines {
|
||||
klines[i] = market.Kline{
|
||||
OpenTime: ck.StartTime,
|
||||
Open: ck.Open,
|
||||
High: ck.High,
|
||||
Low: ck.Low,
|
||||
Close: ck.Close,
|
||||
Volume: ck.Volume, // BTC 数量
|
||||
QuoteVolume: ck.Quantity, // USDT 成交额
|
||||
CloseTime: ck.EndTime,
|
||||
}
|
||||
}
|
||||
|
||||
return klines, nil
|
||||
}
|
||||
|
||||
// getKlinesFromAlpaca fetches kline data from Alpaca API for US stocks
|
||||
func (s *Server) getKlinesFromAlpaca(symbol, interval string, limit int) ([]market.Kline, error) {
|
||||
// Create Alpaca client
|
||||
client := alpaca.NewClient()
|
||||
|
||||
// Map interval to Alpaca timeframe format
|
||||
timeframe := alpaca.MapTimeframe(interval)
|
||||
|
||||
// Fetch bars from Alpaca
|
||||
ctx := context.Background()
|
||||
bars, err := client.GetBars(ctx, symbol, timeframe, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alpaca API error: %w", err)
|
||||
}
|
||||
|
||||
// Convert Alpaca bars to market.Kline format
|
||||
klines := make([]market.Kline, len(bars))
|
||||
for i, bar := range bars {
|
||||
klines[i] = market.Kline{
|
||||
OpenTime: bar.Timestamp.UnixMilli(),
|
||||
Open: bar.Open,
|
||||
High: bar.High,
|
||||
Low: bar.Low,
|
||||
Close: bar.Close,
|
||||
Volume: float64(bar.Volume), // 股数
|
||||
QuoteVolume: float64(bar.Volume) * bar.Close, // 成交额 = 股数 * 收盘价 (USD)
|
||||
CloseTime: bar.Timestamp.UnixMilli(),
|
||||
}
|
||||
}
|
||||
|
||||
return klines, nil
|
||||
}
|
||||
|
||||
// getKlinesFromTwelveData fetches kline data from Twelve Data API for forex and metals
|
||||
func (s *Server) getKlinesFromTwelveData(symbol, interval string, limit int) ([]market.Kline, error) {
|
||||
// Create Twelve Data client
|
||||
client := twelvedata.NewClient()
|
||||
|
||||
// Map interval to Twelve Data timeframe format
|
||||
timeframe := twelvedata.MapTimeframe(interval)
|
||||
|
||||
// Fetch time series from Twelve Data
|
||||
ctx := context.Background()
|
||||
result, err := client.GetTimeSeries(ctx, symbol, timeframe, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("twelvedata API error: %w", err)
|
||||
}
|
||||
|
||||
// Convert Twelve Data bars to market.Kline format
|
||||
// Note: Twelve Data returns bars in reverse order (newest first)
|
||||
klines := make([]market.Kline, len(result.Values))
|
||||
for i, bar := range result.Values {
|
||||
open, high, low, close, volume, timestamp, err := twelvedata.ParseBar(bar)
|
||||
if err != nil {
|
||||
logger.Warnf("⚠️ Failed to parse TwelveData bar: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Reverse order: put oldest first
|
||||
idx := len(result.Values) - 1 - i
|
||||
klines[idx] = market.Kline{
|
||||
OpenTime: timestamp,
|
||||
Open: open,
|
||||
High: high,
|
||||
Low: low,
|
||||
Close: close,
|
||||
Volume: volume,
|
||||
CloseTime: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
return klines, nil
|
||||
}
|
||||
|
||||
// getKlinesFromHyperliquid fetches kline data from Hyperliquid API
|
||||
// Supports both crypto perps (default dex) and stock perps/forex/commodities (xyz dex)
|
||||
func (s *Server) getKlinesFromHyperliquid(symbol, interval string, limit int) ([]market.Kline, error) {
|
||||
// Create Hyperliquid client
|
||||
client := hyperliquid.NewClient()
|
||||
|
||||
// Map interval to Hyperliquid format
|
||||
timeframe := hyperliquid.MapTimeframe(interval)
|
||||
|
||||
// Fetch candles from Hyperliquid
|
||||
// FormatCoinForAPI will automatically add xyz: prefix for stock perps
|
||||
ctx := context.Background()
|
||||
candles, err := client.GetCandles(ctx, symbol, timeframe, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hyperliquid API error: %w", err)
|
||||
}
|
||||
|
||||
// Convert Hyperliquid candles to market.Kline format
|
||||
klines := make([]market.Kline, len(candles))
|
||||
for i, candle := range candles {
|
||||
open, _ := strconv.ParseFloat(candle.Open, 64)
|
||||
high, _ := strconv.ParseFloat(candle.High, 64)
|
||||
low, _ := strconv.ParseFloat(candle.Low, 64)
|
||||
close, _ := strconv.ParseFloat(candle.Close, 64)
|
||||
volume, _ := strconv.ParseFloat(candle.Volume, 64)
|
||||
|
||||
klines[i] = market.Kline{
|
||||
OpenTime: candle.OpenTime,
|
||||
Open: open,
|
||||
High: high,
|
||||
Low: low,
|
||||
Close: close,
|
||||
Volume: volume, // 合约数量
|
||||
QuoteVolume: volume * close, // 成交额 (USD)
|
||||
CloseTime: candle.CloseTime,
|
||||
}
|
||||
}
|
||||
|
||||
return klines, nil
|
||||
}
|
||||
|
||||
// handleSymbols returns available symbols for a given exchange
|
||||
func (s *Server) handleSymbols(c *gin.Context) {
|
||||
exchange := c.DefaultQuery("exchange", "hyperliquid")
|
||||
|
||||
type SymbolInfo struct {
|
||||
Symbol string `json:"symbol"`
|
||||
Name string `json:"name"`
|
||||
Category string `json:"category"` // crypto, stock, forex, commodity, index
|
||||
MaxLeverage int `json:"maxLeverage,omitempty"`
|
||||
}
|
||||
|
||||
var symbols []SymbolInfo
|
||||
|
||||
switch strings.ToLower(exchange) {
|
||||
case "hyperliquid", "hyperliquid-xyz", "xyz":
|
||||
// Fetch symbols from Hyperliquid
|
||||
client := hyperliquid.NewClient()
|
||||
ctx := context.Background()
|
||||
|
||||
// Get crypto perps from default dex
|
||||
if exchange == "hyperliquid" || exchange == "hyperliquid-xyz" {
|
||||
mids, err := client.GetAllMids(ctx)
|
||||
if err == nil {
|
||||
for symbol := range mids {
|
||||
// Skip spot tokens (start with @)
|
||||
if strings.HasPrefix(symbol, "@") {
|
||||
continue
|
||||
}
|
||||
symbols = append(symbols, SymbolInfo{
|
||||
Symbol: symbol,
|
||||
Name: symbol,
|
||||
Category: "crypto",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get xyz dex symbols (stocks, forex, commodities)
|
||||
xyzMids, err := client.GetAllMidsXYZ(ctx)
|
||||
if err == nil {
|
||||
for symbol := range xyzMids {
|
||||
// Remove xyz: prefix for display
|
||||
displaySymbol := strings.TrimPrefix(symbol, "xyz:")
|
||||
category := "stock"
|
||||
if displaySymbol == "GOLD" || displaySymbol == "SILVER" {
|
||||
category = "commodity"
|
||||
} else if displaySymbol == "EUR" || displaySymbol == "JPY" {
|
||||
category = "forex"
|
||||
} else if displaySymbol == "XYZ100" {
|
||||
category = "index"
|
||||
}
|
||||
symbols = append(symbols, SymbolInfo{
|
||||
Symbol: displaySymbol,
|
||||
Name: displaySymbol,
|
||||
Category: category,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Unsupported exchange for symbol listing"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"exchange": exchange,
|
||||
"symbols": symbols,
|
||||
"count": len(symbols),
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,402 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"nofx/logger"
|
||||
"nofx/market"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// handleTraderList Trader list
|
||||
func (s *Server) handleTraderList(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
traders, err := s.store.Trader().List(userID)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Failed to get trader list", err)
|
||||
return
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(traders))
|
||||
for _, trader := range traders {
|
||||
// Get real-time running status
|
||||
isRunning := trader.IsRunning
|
||||
if at, err := s.traderManager.GetTrader(trader.ID); err == nil {
|
||||
status := at.GetStatus()
|
||||
if running, ok := status["is_running"].(bool); ok {
|
||||
isRunning = running
|
||||
}
|
||||
}
|
||||
|
||||
// Get strategy name if strategy_id is set
|
||||
var strategyName string
|
||||
if trader.StrategyID != "" {
|
||||
if strategy, err := s.store.Strategy().Get(userID, trader.StrategyID); err == nil {
|
||||
strategyName = strategy.Name
|
||||
}
|
||||
}
|
||||
|
||||
// Return complete AIModelID (e.g. "admin_deepseek"), don't truncate
|
||||
// Frontend needs complete ID to verify model exists (consistent with handleGetTraderConfig)
|
||||
result = append(result, map[string]interface{}{
|
||||
"trader_id": trader.ID,
|
||||
"trader_name": trader.Name,
|
||||
"ai_model": trader.AIModelID, // Use complete ID
|
||||
"exchange_id": trader.ExchangeID,
|
||||
"is_running": isRunning,
|
||||
"show_in_competition": trader.ShowInCompetition,
|
||||
"initial_balance": trader.InitialBalance,
|
||||
"strategy_id": trader.StrategyID,
|
||||
"strategy_name": strategyName,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// handleGetTraderConfig Get trader detailed configuration
|
||||
func (s *Server) handleGetTraderConfig(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
traderID := c.Param("id")
|
||||
|
||||
if traderID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Trader ID cannot be empty"})
|
||||
return
|
||||
}
|
||||
|
||||
fullCfg, err := s.store.Trader().GetFullConfig(userID, traderID)
|
||||
if err != nil {
|
||||
SafeNotFound(c, "Trader config")
|
||||
return
|
||||
}
|
||||
traderConfig := fullCfg.Trader
|
||||
|
||||
// Get real-time running status
|
||||
isRunning := traderConfig.IsRunning
|
||||
if at, err := s.traderManager.GetTrader(traderID); err == nil {
|
||||
status := at.GetStatus()
|
||||
if running, ok := status["is_running"].(bool); ok {
|
||||
isRunning = running
|
||||
}
|
||||
}
|
||||
|
||||
// Return complete model ID without conversion, consistent with frontend model list
|
||||
aiModelID := traderConfig.AIModelID
|
||||
|
||||
result := map[string]interface{}{
|
||||
"trader_id": traderConfig.ID,
|
||||
"trader_name": traderConfig.Name,
|
||||
"ai_model": aiModelID,
|
||||
"exchange_id": traderConfig.ExchangeID,
|
||||
"strategy_id": traderConfig.StrategyID,
|
||||
"initial_balance": traderConfig.InitialBalance,
|
||||
"scan_interval_minutes": traderConfig.ScanIntervalMinutes,
|
||||
"btc_eth_leverage": traderConfig.BTCETHLeverage,
|
||||
"altcoin_leverage": traderConfig.AltcoinLeverage,
|
||||
"trading_symbols": traderConfig.TradingSymbols,
|
||||
"custom_prompt": traderConfig.CustomPrompt,
|
||||
"override_base_prompt": traderConfig.OverrideBasePrompt,
|
||||
"is_cross_margin": traderConfig.IsCrossMargin,
|
||||
"use_ai500": traderConfig.UseAI500,
|
||||
"use_oi_top": traderConfig.UseOITop,
|
||||
"is_running": isRunning,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// handleStatus System status
|
||||
func (s *Server) handleStatus(c *gin.Context) {
|
||||
_, traderID, err := s.getTraderFromQuery(c)
|
||||
if err != nil {
|
||||
SafeBadRequest(c, "Invalid trader ID")
|
||||
return
|
||||
}
|
||||
|
||||
trader, err := s.traderManager.GetTrader(traderID)
|
||||
if err != nil {
|
||||
SafeNotFound(c, "Trader")
|
||||
return
|
||||
}
|
||||
|
||||
status := trader.GetStatus()
|
||||
c.JSON(http.StatusOK, status)
|
||||
}
|
||||
|
||||
// handleAccount Account information
|
||||
func (s *Server) handleAccount(c *gin.Context) {
|
||||
_, traderID, err := s.getTraderFromQuery(c)
|
||||
if err != nil {
|
||||
SafeBadRequest(c, "Invalid trader ID")
|
||||
return
|
||||
}
|
||||
|
||||
trader, err := s.traderManager.GetTrader(traderID)
|
||||
if err != nil {
|
||||
SafeNotFound(c, "Trader")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("📊 Received account info request [%s]", trader.GetName())
|
||||
account, err := trader.GetAccountInfo()
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get account info", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("✓ Returning account info [%s]: equity=%.2f, available=%.2f, pnl=%.2f (%.2f%%)",
|
||||
trader.GetName(),
|
||||
account["total_equity"],
|
||||
account["available_balance"],
|
||||
account["total_pnl"],
|
||||
account["total_pnl_pct"])
|
||||
c.JSON(http.StatusOK, account)
|
||||
}
|
||||
|
||||
// handlePositions Position list
|
||||
func (s *Server) handlePositions(c *gin.Context) {
|
||||
_, traderID, err := s.getTraderFromQuery(c)
|
||||
if err != nil {
|
||||
SafeBadRequest(c, "Invalid trader ID")
|
||||
return
|
||||
}
|
||||
|
||||
trader, err := s.traderManager.GetTrader(traderID)
|
||||
if err != nil {
|
||||
SafeNotFound(c, "Trader")
|
||||
return
|
||||
}
|
||||
|
||||
positions, err := trader.GetPositions()
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get positions", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, positions)
|
||||
}
|
||||
|
||||
// handlePositionHistory Historical closed positions with statistics
|
||||
func (s *Server) handlePositionHistory(c *gin.Context) {
|
||||
_, traderID, err := s.getTraderFromQuery(c)
|
||||
if err != nil {
|
||||
SafeBadRequest(c, "Invalid trader ID")
|
||||
return
|
||||
}
|
||||
|
||||
trader, err := s.traderManager.GetTrader(traderID)
|
||||
if err != nil {
|
||||
SafeNotFound(c, "Trader")
|
||||
return
|
||||
}
|
||||
|
||||
// Get optional query parameters
|
||||
limitStr := c.DefaultQuery("limit", "100")
|
||||
limit := 100
|
||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 500 {
|
||||
limit = l
|
||||
}
|
||||
|
||||
// Get store
|
||||
store := trader.GetStore()
|
||||
if store == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get closed positions
|
||||
positions, err := store.Position().GetClosedPositions(trader.GetID(), limit)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get position history", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Get statistics
|
||||
stats, _ := store.Position().GetFullStats(trader.GetID())
|
||||
|
||||
// Get symbol stats
|
||||
symbolStats, _ := store.Position().GetSymbolStats(trader.GetID(), 10)
|
||||
|
||||
// Get direction stats
|
||||
directionStats, _ := store.Position().GetDirectionStats(trader.GetID())
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"positions": positions,
|
||||
"stats": stats,
|
||||
"symbol_stats": symbolStats,
|
||||
"direction_stats": directionStats,
|
||||
})
|
||||
}
|
||||
|
||||
// handleTrades Historical trades list
|
||||
func (s *Server) handleTrades(c *gin.Context) {
|
||||
_, traderID, err := s.getTraderFromQuery(c)
|
||||
if err != nil {
|
||||
SafeBadRequest(c, "Invalid trader ID")
|
||||
return
|
||||
}
|
||||
|
||||
trader, err := s.traderManager.GetTrader(traderID)
|
||||
if err != nil {
|
||||
SafeNotFound(c, "Trader")
|
||||
return
|
||||
}
|
||||
|
||||
// Get optional query parameters
|
||||
symbol := c.Query("symbol")
|
||||
limitStr := c.DefaultQuery("limit", "100")
|
||||
limit := 100
|
||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
||||
limit = l
|
||||
}
|
||||
|
||||
// Normalize symbol (add USDT suffix if not present)
|
||||
if symbol != "" {
|
||||
symbol = market.Normalize(symbol)
|
||||
}
|
||||
|
||||
// Get trades from store
|
||||
store := trader.GetStore()
|
||||
if store == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"})
|
||||
return
|
||||
}
|
||||
|
||||
allTrades, err := store.Position().GetRecentTrades(trader.GetID(), limit)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get trades", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Filter by symbol if specified
|
||||
if symbol != "" {
|
||||
var result []interface{}
|
||||
for _, trade := range allTrades {
|
||||
if trade.Symbol == symbol {
|
||||
result = append(result, trade)
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, result)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, allTrades)
|
||||
}
|
||||
|
||||
// handleOrders Order list (all orders including open, close, stop loss, take profit, etc.)
|
||||
func (s *Server) handleOrders(c *gin.Context) {
|
||||
_, traderID, err := s.getTraderFromQuery(c)
|
||||
if err != nil {
|
||||
SafeBadRequest(c, "Invalid trader ID")
|
||||
return
|
||||
}
|
||||
|
||||
trader, err := s.traderManager.GetTrader(traderID)
|
||||
if err != nil {
|
||||
SafeNotFound(c, "Trader")
|
||||
return
|
||||
}
|
||||
|
||||
// Get optional query parameters
|
||||
symbol := c.Query("symbol")
|
||||
statusFilter := c.Query("status") // NEW, FILLED, CANCELED, etc.
|
||||
limitStr := c.DefaultQuery("limit", "100")
|
||||
limit := 100
|
||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 {
|
||||
limit = l
|
||||
}
|
||||
|
||||
// Normalize symbol (add USDT suffix if not present)
|
||||
if symbol != "" {
|
||||
symbol = market.Normalize(symbol)
|
||||
}
|
||||
|
||||
// Get orders from store
|
||||
store := trader.GetStore()
|
||||
if store == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get orders with filters applied at database level
|
||||
orders, err := store.Order().GetTraderOrdersFiltered(trader.GetID(), symbol, statusFilter, limit)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get orders", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, orders)
|
||||
}
|
||||
|
||||
// handleOrderFills Order fill details (all fills for a specific order)
|
||||
func (s *Server) handleOrderFills(c *gin.Context) {
|
||||
orderIDStr := c.Param("id")
|
||||
orderID, err := strconv.ParseInt(orderIDStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid order ID"})
|
||||
return
|
||||
}
|
||||
|
||||
_, traderID, err := s.getTraderFromQuery(c)
|
||||
if err != nil {
|
||||
SafeBadRequest(c, "Invalid trader ID")
|
||||
return
|
||||
}
|
||||
|
||||
trader, err := s.traderManager.GetTrader(traderID)
|
||||
if err != nil {
|
||||
SafeNotFound(c, "Trader")
|
||||
return
|
||||
}
|
||||
|
||||
store := trader.GetStore()
|
||||
if store == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Store not available"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get fills for this order
|
||||
fills, err := store.Order().GetOrderFills(orderID)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get order fills", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, fills)
|
||||
}
|
||||
|
||||
// handleOpenOrders Get open orders (pending SL/TP) from exchange
|
||||
func (s *Server) handleOpenOrders(c *gin.Context) {
|
||||
_, traderID, err := s.getTraderFromQuery(c)
|
||||
if err != nil {
|
||||
SafeBadRequest(c, "Invalid trader ID")
|
||||
return
|
||||
}
|
||||
|
||||
trader, err := s.traderManager.GetTrader(traderID)
|
||||
if err != nil {
|
||||
SafeNotFound(c, "Trader")
|
||||
return
|
||||
}
|
||||
|
||||
// Get symbol parameter (required for exchange query)
|
||||
symbol := c.Query("symbol")
|
||||
if symbol == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "symbol parameter is required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Normalize symbol
|
||||
symbol = market.Normalize(symbol)
|
||||
|
||||
// Get open orders from exchange
|
||||
openOrders, err := trader.GetOpenOrders(symbol)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Get open orders", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, openOrders)
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// handleGetTelegramConfig returns current Telegram bot configuration and binding status
|
||||
func (s *Server) handleGetTelegramConfig(c *gin.Context) {
|
||||
cfg, err := s.store.TelegramConfig().Get()
|
||||
if err != nil {
|
||||
// Not configured yet - return empty state
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"configured": false,
|
||||
"is_bound": false,
|
||||
"token_masked": "",
|
||||
"username": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Mask bot token for security (show only last 6 chars)
|
||||
tokenMasked := ""
|
||||
if cfg.BotToken != "" {
|
||||
if len(cfg.BotToken) > 6 {
|
||||
tokenMasked = "***" + cfg.BotToken[len(cfg.BotToken)-6:]
|
||||
} else {
|
||||
tokenMasked = "***"
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"configured": cfg.BotToken != "",
|
||||
"is_bound": cfg.ChatID != 0,
|
||||
"username": cfg.Username,
|
||||
"bound_at": cfg.BoundAt,
|
||||
"token_masked": tokenMasked,
|
||||
"model_id": cfg.ModelID,
|
||||
})
|
||||
}
|
||||
|
||||
// handleUpdateTelegramConfig saves bot token (+ optional model ID) and triggers bot hot-reload
|
||||
func (s *Server) handleUpdateTelegramConfig(c *gin.Context) {
|
||||
var req struct {
|
||||
BotToken string `json:"bot_token"`
|
||||
ModelID string `json:"model_id"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
||||
return
|
||||
}
|
||||
if req.BotToken == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "bot_token is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.store.TelegramConfig().Save(req.BotToken, req.ModelID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save config"})
|
||||
return
|
||||
}
|
||||
|
||||
// Signal bot hot-reload if channel is available
|
||||
if s.telegramReloadCh != nil {
|
||||
select {
|
||||
case s.telegramReloadCh <- struct{}{}:
|
||||
default: // non-blocking
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "Bot token saved. Bot will reload automatically."})
|
||||
}
|
||||
|
||||
// handleUnbindTelegram removes Telegram user binding
|
||||
func (s *Server) handleUnbindTelegram(c *gin.Context) {
|
||||
if err := s.store.TelegramConfig().Unbind(); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to unbind"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "Telegram binding removed"})
|
||||
}
|
||||
|
||||
// handleUpdateTelegramModel updates only the AI model used for Telegram replies (no token re-entry needed)
|
||||
func (s *Server) handleUpdateTelegramModel(c *gin.Context) {
|
||||
var req struct {
|
||||
ModelID string `json:"model_id"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := s.store.TelegramConfig().Get()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "no Telegram config found, save a bot token first"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.store.TelegramConfig().Save(cfg.BotToken, req.ModelID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save model config"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "model_id": req.ModelID})
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,223 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nofx/auth"
|
||||
"nofx/logger"
|
||||
"nofx/store"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// handleLogout Add current token to blacklist
|
||||
func (s *Server) handleLogout(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing Authorization header"})
|
||||
return
|
||||
}
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"})
|
||||
return
|
||||
}
|
||||
tokenString := parts[1]
|
||||
claims, err := auth.ValidateJWT(tokenString)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||
return
|
||||
}
|
||||
var exp time.Time
|
||||
if claims.ExpiresAt != nil {
|
||||
exp = claims.ExpiresAt.Time
|
||||
} else {
|
||||
exp = time.Now().Add(24 * time.Hour)
|
||||
}
|
||||
auth.BlacklistToken(tokenString, exp)
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Logged out"})
|
||||
}
|
||||
|
||||
// handleRegister Handle user registration request.
|
||||
// handleRegister allows registration only when no users exist yet (first-time setup).
|
||||
// This is a single-user system; subsequent registrations are permanently closed.
|
||||
func (s *Server) handleRegister(c *gin.Context) {
|
||||
userCount, err := s.store.User().Count()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check user count"})
|
||||
return
|
||||
}
|
||||
|
||||
if userCount > 0 {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "System already initialized"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
SafeBadRequest(c, "Invalid request parameters")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if email already exists
|
||||
_, err = s.store.User().GetByEmail(req.Email)
|
||||
if err == nil {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "Email already registered"})
|
||||
return
|
||||
}
|
||||
|
||||
// Generate password hash
|
||||
passwordHash, err := auth.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Password processing failed"})
|
||||
return
|
||||
}
|
||||
|
||||
// Create user
|
||||
userID := uuid.New().String()
|
||||
user := &store.User{
|
||||
ID: userID,
|
||||
Email: req.Email,
|
||||
PasswordHash: passwordHash,
|
||||
}
|
||||
|
||||
err = s.store.User().Create(user)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Failed to create user", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
token, err := auth.GenerateJWT(user.ID, user.Email)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize default model and exchange configs for user
|
||||
err = s.initUserDefaultConfigs(user.ID)
|
||||
if err != nil {
|
||||
logger.Infof("Failed to initialize user default configs: %v", err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": token,
|
||||
"user_id": user.ID,
|
||||
"email": user.Email,
|
||||
"message": "Registration successful",
|
||||
})
|
||||
}
|
||||
|
||||
// handleLogin Handle user login request
|
||||
func (s *Server) handleLogin(c *gin.Context) {
|
||||
var req struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
SafeBadRequest(c, "Invalid request parameters")
|
||||
return
|
||||
}
|
||||
|
||||
// Get user information
|
||||
user, err := s.store.User().GetByEmail(req.Email)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Email or password incorrect"})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify password
|
||||
if !auth.CheckPassword(req.Password, user.PasswordHash) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Email or password incorrect"})
|
||||
return
|
||||
}
|
||||
|
||||
// Issue token directly after password verification.
|
||||
token, err := auth.GenerateJWT(user.ID, user.Email)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": token,
|
||||
"user_id": user.ID,
|
||||
"email": user.Email,
|
||||
"message": "Login successful",
|
||||
})
|
||||
}
|
||||
|
||||
// handleChangePassword changes the password for the currently authenticated user.
|
||||
func (s *Server) handleChangePassword(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
var req struct {
|
||||
NewPassword string `json:"new_password" binding:"required,min=8"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
SafeBadRequest(c, "new_password is required (min 8 chars)")
|
||||
return
|
||||
}
|
||||
hash, err := auth.HashPassword(req.NewPassword)
|
||||
if err != nil {
|
||||
SafeInternalError(c, "Password processing failed", err)
|
||||
return
|
||||
}
|
||||
if err := s.store.User().UpdatePassword(userID, hash); err != nil {
|
||||
SafeInternalError(c, "Failed to update password", err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Password updated"})
|
||||
}
|
||||
|
||||
// handleResetPassword Reset password via email and new password
|
||||
func (s *Server) handleResetPassword(c *gin.Context) {
|
||||
var req struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
SafeBadRequest(c, "Invalid request parameters")
|
||||
return
|
||||
}
|
||||
|
||||
// Query user
|
||||
user, err := s.store.User().GetByEmail(req.Email)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Email does not exist"})
|
||||
return
|
||||
}
|
||||
|
||||
// Generate new password hash
|
||||
newPasswordHash, err := auth.HashPassword(req.NewPassword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Password processing failed"})
|
||||
return
|
||||
}
|
||||
|
||||
// Update password
|
||||
err = s.store.User().UpdatePassword(user.ID, newPasswordHash)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Password update failed"})
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("✓ User %s password has been reset", user.Email)
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Password reset successful, please login with new password"})
|
||||
}
|
||||
|
||||
// initUserDefaultConfigs Initialize default model and exchange configs for new user
|
||||
func (s *Server) initUserDefaultConfigs(userID string) error {
|
||||
// Commented out auto-creation of default configs, let users add manually
|
||||
// This way new users won't have config items automatically after registration
|
||||
logger.Infof("User %s registration completed, waiting for manual AI model and exchange configuration", userID)
|
||||
return nil
|
||||
}
|
||||
-3271
File diff suppressed because it is too large
Load Diff
+11
-38
@@ -8,6 +8,8 @@ import (
|
||||
"nofx/logger"
|
||||
"nofx/market"
|
||||
"nofx/mcp"
|
||||
_ "nofx/mcp/payment"
|
||||
_ "nofx/mcp/provider"
|
||||
"nofx/store"
|
||||
"time"
|
||||
|
||||
@@ -637,49 +639,20 @@ func (s *Server) runRealAITest(userID, modelID, systemPrompt, userPrompt string)
|
||||
return "", fmt.Errorf("AI model %s is missing API Key", model.Name)
|
||||
}
|
||||
|
||||
// Create AI client
|
||||
var aiClient mcp.AIClient
|
||||
// Create AI client via registry
|
||||
provider := model.Provider
|
||||
|
||||
// Convert EncryptedString to string for API key
|
||||
apiKey := string(model.APIKey)
|
||||
|
||||
aiClient := mcp.NewAIClientByProvider(provider)
|
||||
if aiClient == nil {
|
||||
aiClient = mcp.NewClient()
|
||||
}
|
||||
|
||||
// Payment providers ignore custom URL
|
||||
switch provider {
|
||||
case "qwen":
|
||||
aiClient = mcp.NewQwenClient()
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
case "deepseek":
|
||||
aiClient = mcp.NewDeepSeekClient()
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
case "claude":
|
||||
aiClient = mcp.NewClaudeClient()
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
case "kimi":
|
||||
aiClient = mcp.NewKimiClient()
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
case "gemini":
|
||||
aiClient = mcp.NewGeminiClient()
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
case "grok":
|
||||
aiClient = mcp.NewGrokClient()
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
case "openai":
|
||||
aiClient = mcp.NewOpenAIClient()
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
case "minimax":
|
||||
aiClient = mcp.NewMiniMaxClient()
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
case "blockrun-base":
|
||||
aiClient = mcp.NewBlockRunBaseClient()
|
||||
aiClient.SetAPIKey(apiKey, "", model.CustomModelName)
|
||||
case "blockrun-sol":
|
||||
aiClient = mcp.NewBlockRunSolClient()
|
||||
aiClient.SetAPIKey(apiKey, "", model.CustomModelName)
|
||||
case "claw402":
|
||||
aiClient = mcp.NewClaw402Client()
|
||||
case "blockrun-base", "blockrun-sol", "claw402":
|
||||
aiClient.SetAPIKey(apiKey, "", model.CustomModelName)
|
||||
default:
|
||||
// Use generic client
|
||||
aiClient = mcp.NewClient()
|
||||
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||
}
|
||||
|
||||
|
||||
+34
-126
@@ -5,15 +5,17 @@ import (
|
||||
"strings"
|
||||
|
||||
"nofx/mcp"
|
||||
_ "nofx/mcp/payment"
|
||||
_ "nofx/mcp/provider"
|
||||
)
|
||||
|
||||
// configureMCPClient creates/clones an MCP client based on configuration (returns mcp.AIClient interface).
|
||||
// Note: mcp.New() returns an interface type; here we convert to concrete implementation before copying to avoid concurrent shared state.
|
||||
func configureMCPClient(cfg BacktestConfig, base mcp.AIClient) (mcp.AIClient, error) {
|
||||
provider := strings.ToLower(strings.TrimSpace(cfg.AICfg.Provider))
|
||||
providerName := strings.ToLower(strings.TrimSpace(cfg.AICfg.Provider))
|
||||
|
||||
// DeepSeek
|
||||
if provider == "" || provider == "inherit" || provider == "default" {
|
||||
// Inherit base client
|
||||
if providerName == "" || providerName == "inherit" || providerName == "default" {
|
||||
client := cloneBaseClient(base)
|
||||
if cfg.AICfg.APIKey != "" || cfg.AICfg.BaseURL != "" || cfg.AICfg.Model != "" {
|
||||
client.SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
|
||||
@@ -21,143 +23,49 @@ func configureMCPClient(cfg BacktestConfig, base mcp.AIClient) (mcp.AIClient, er
|
||||
return client, nil
|
||||
}
|
||||
|
||||
switch provider {
|
||||
case "deepseek":
|
||||
if cfg.AICfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("deepseek provider requires api key")
|
||||
}
|
||||
ds := mcp.NewDeepSeekClientWithOptions()
|
||||
ds.(*mcp.DeepSeekClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
|
||||
return ds, nil
|
||||
case "qwen":
|
||||
if cfg.AICfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("qwen provider requires api key")
|
||||
}
|
||||
qc := mcp.NewQwenClientWithOptions()
|
||||
qc.(*mcp.QwenClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
|
||||
return qc, nil
|
||||
case "claude":
|
||||
if cfg.AICfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("claude provider requires api key")
|
||||
}
|
||||
cc := mcp.NewClaudeClientWithOptions()
|
||||
cc.(*mcp.ClaudeClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
|
||||
return cc, nil
|
||||
case "kimi":
|
||||
if cfg.AICfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("kimi provider requires api key")
|
||||
}
|
||||
kc := mcp.NewKimiClientWithOptions()
|
||||
kc.(*mcp.KimiClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
|
||||
return kc, nil
|
||||
case "gemini":
|
||||
if cfg.AICfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("gemini provider requires api key")
|
||||
}
|
||||
gc := mcp.NewGeminiClientWithOptions()
|
||||
gc.(*mcp.GeminiClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
|
||||
return gc, nil
|
||||
case "grok":
|
||||
if cfg.AICfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("grok provider requires api key")
|
||||
}
|
||||
grokC := mcp.NewGrokClientWithOptions()
|
||||
grokC.(*mcp.GrokClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
|
||||
return grokC, nil
|
||||
case "openai":
|
||||
if cfg.AICfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("openai provider requires api key")
|
||||
}
|
||||
oaiC := mcp.NewOpenAIClientWithOptions()
|
||||
oaiC.(*mcp.OpenAIClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
|
||||
return oaiC, nil
|
||||
case "minimax":
|
||||
if cfg.AICfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("minimax provider requires api key")
|
||||
}
|
||||
mmC := mcp.NewMiniMaxClientWithOptions()
|
||||
mmC.(*mcp.MiniMaxClient).SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
|
||||
return mmC, nil
|
||||
case "blockrun-base":
|
||||
if cfg.AICfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("blockrun-base provider requires wallet private key")
|
||||
}
|
||||
brBase := mcp.NewBlockRunBaseClient()
|
||||
brBase.SetAPIKey(cfg.AICfg.APIKey, "", cfg.AICfg.Model)
|
||||
return brBase, nil
|
||||
case "blockrun-sol":
|
||||
if cfg.AICfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("blockrun-sol provider requires wallet keypair")
|
||||
}
|
||||
brSol := mcp.NewBlockRunSolClient()
|
||||
brSol.SetAPIKey(cfg.AICfg.APIKey, "", cfg.AICfg.Model)
|
||||
return brSol, nil
|
||||
case "claw402":
|
||||
if cfg.AICfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("claw402 provider requires wallet private key")
|
||||
}
|
||||
claw := mcp.NewClaw402Client()
|
||||
claw.SetAPIKey(cfg.AICfg.APIKey, "", cfg.AICfg.Model)
|
||||
return claw, nil
|
||||
case "custom":
|
||||
// Custom provider uses cloned base
|
||||
if providerName == "custom" {
|
||||
if cfg.AICfg.BaseURL == "" || cfg.AICfg.APIKey == "" || cfg.AICfg.Model == "" {
|
||||
return nil, fmt.Errorf("custom provider requires base_url, api key and model")
|
||||
}
|
||||
client := cloneBaseClient(base)
|
||||
client.SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
|
||||
return client, nil
|
||||
default:
|
||||
}
|
||||
|
||||
// Create client via registry
|
||||
client := mcp.NewAIClientByProvider(providerName)
|
||||
if client == nil {
|
||||
return nil, fmt.Errorf("unsupported ai provider %s", cfg.AICfg.Provider)
|
||||
}
|
||||
|
||||
if cfg.AICfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("%s provider requires api key", providerName)
|
||||
}
|
||||
|
||||
// Payment providers ignore custom URL
|
||||
switch providerName {
|
||||
case "blockrun-base", "blockrun-sol", "claw402":
|
||||
client.SetAPIKey(cfg.AICfg.APIKey, "", cfg.AICfg.Model)
|
||||
default:
|
||||
client.SetAPIKey(cfg.AICfg.APIKey, cfg.AICfg.BaseURL, cfg.AICfg.Model)
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// cloneBaseClient copies the base client to avoid shared mutable state.
|
||||
// Uses the ClientEmbedder interface to extract the underlying *mcp.Client
|
||||
// from any provider type that embeds it.
|
||||
func cloneBaseClient(base mcp.AIClient) *mcp.Client {
|
||||
// Prefer to reuse the passed-in base client (deep copy)
|
||||
switch c := base.(type) {
|
||||
case *mcp.Client:
|
||||
if embedder, ok := base.(mcp.ClientEmbedder); ok {
|
||||
if inner := embedder.BaseClient(); inner != nil {
|
||||
cp := *inner
|
||||
return &cp
|
||||
}
|
||||
}
|
||||
if c, ok := base.(*mcp.Client); ok {
|
||||
cp := *c
|
||||
return &cp
|
||||
case *mcp.DeepSeekClient:
|
||||
if c != nil && c.Client != nil {
|
||||
cp := *c.Client
|
||||
return &cp
|
||||
}
|
||||
case *mcp.QwenClient:
|
||||
if c != nil && c.Client != nil {
|
||||
cp := *c.Client
|
||||
return &cp
|
||||
}
|
||||
case *mcp.ClaudeClient:
|
||||
if c != nil && c.Client != nil {
|
||||
cp := *c.Client
|
||||
return &cp
|
||||
}
|
||||
case *mcp.KimiClient:
|
||||
if c != nil && c.Client != nil {
|
||||
cp := *c.Client
|
||||
return &cp
|
||||
}
|
||||
case *mcp.GeminiClient:
|
||||
if c != nil && c.Client != nil {
|
||||
cp := *c.Client
|
||||
return &cp
|
||||
}
|
||||
case *mcp.GrokClient:
|
||||
if c != nil && c.Client != nil {
|
||||
cp := *c.Client
|
||||
return &cp
|
||||
}
|
||||
case *mcp.OpenAIClient:
|
||||
if c != nil && c.Client != nil {
|
||||
cp := *c.Client
|
||||
return &cp
|
||||
}
|
||||
case *mcp.MiniMaxClient:
|
||||
if c != nil && c.Client != nil {
|
||||
cp := *c.Client
|
||||
return &cp
|
||||
}
|
||||
}
|
||||
// Fall back to a new default client
|
||||
return mcp.NewClient().(*mcp.Client)
|
||||
|
||||
@@ -1,351 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 阿里云 API 配置
|
||||
const (
|
||||
DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/api/v1/apps"
|
||||
// 标准 OpenAI 兼容模式 API
|
||||
QwenCompatibleURL = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
|
||||
)
|
||||
|
||||
// QwenAgent 阿里云百炼智能体客户端
|
||||
type QwenAgent struct {
|
||||
AppID string
|
||||
APIKey string
|
||||
BaseURL string
|
||||
SessionID string
|
||||
Client *http.Client
|
||||
}
|
||||
|
||||
// QwenRequest 请求结构
|
||||
type QwenRequest struct {
|
||||
Input QwenInput `json:"input"`
|
||||
Parameters QwenParameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
// QwenInput 输入结构
|
||||
type QwenInput struct {
|
||||
Prompt string `json:"prompt"`
|
||||
BizParams map[string]interface{} `json:"biz_params,omitempty"`
|
||||
}
|
||||
|
||||
// QwenParameters 参数结构
|
||||
type QwenParameters struct {
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
IncrementalOutput bool `json:"incremental_output,omitempty"`
|
||||
}
|
||||
|
||||
// QwenResponse 响应结构
|
||||
type QwenResponse struct {
|
||||
Output QwenOutput `json:"output"`
|
||||
Usage QwenUsage `json:"usage,omitempty"`
|
||||
RequestID string `json:"request_id"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// QwenOutput 输出结构
|
||||
type QwenOutput struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
}
|
||||
|
||||
// QwenUsage 用量统计
|
||||
type QwenUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// NewQwenAgent 创建新的智能体客户端
|
||||
func NewQwenAgent(appID, apiKey string) *QwenAgent {
|
||||
return &QwenAgent{
|
||||
AppID: appID,
|
||||
APIKey: apiKey,
|
||||
BaseURL: DefaultQwenBaseURL,
|
||||
Client: &http.Client{
|
||||
Timeout: 180 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Chat 同步对话
|
||||
func (a *QwenAgent) Chat(ctx context.Context, prompt string) (*QwenResponse, error) {
|
||||
reqBody := QwenRequest{
|
||||
Input: QwenInput{
|
||||
Prompt: prompt,
|
||||
},
|
||||
Parameters: QwenParameters{
|
||||
SessionID: a.SessionID,
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s/completion", a.BaseURL, a.AppID)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+a.APIKey)
|
||||
|
||||
resp, err := a.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response failed: %w", err)
|
||||
}
|
||||
|
||||
var result QwenResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w, body: %s", err, string(body))
|
||||
}
|
||||
|
||||
// 更新 session_id 用于多轮对话
|
||||
if result.Output.SessionID != "" {
|
||||
a.SessionID = result.Output.SessionID
|
||||
}
|
||||
|
||||
// 检查 API 错误
|
||||
if result.Code != "" {
|
||||
return &result, fmt.Errorf("API error: code=%s, message=%s", result.Code, result.Message)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ChatStream 流式对话
|
||||
func (a *QwenAgent) ChatStream(ctx context.Context, prompt string, callback func(chunk string)) error {
|
||||
reqBody := QwenRequest{
|
||||
Input: QwenInput{
|
||||
Prompt: prompt,
|
||||
},
|
||||
Parameters: QwenParameters{
|
||||
SessionID: a.SessionID,
|
||||
IncrementalOutput: true,
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s/completion", a.BaseURL, a.AppID)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+a.APIKey)
|
||||
req.Header.Set("X-DashScope-SSE", "enable")
|
||||
|
||||
resp, err := a.Client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return fmt.Errorf("read stream failed: %w", err)
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data:")
|
||||
var chunk QwenResponse
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新 session_id
|
||||
if chunk.Output.SessionID != "" {
|
||||
a.SessionID = chunk.Output.SessionID
|
||||
}
|
||||
|
||||
// 回调输出文本
|
||||
if chunk.Output.Text != "" {
|
||||
callback(chunk.Output.Text)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChatWithBizParams 带业务参数的对话
|
||||
func (a *QwenAgent) ChatWithBizParams(ctx context.Context, prompt string, bizParams map[string]interface{}) (*QwenResponse, error) {
|
||||
reqBody := QwenRequest{
|
||||
Input: QwenInput{
|
||||
Prompt: prompt,
|
||||
BizParams: bizParams,
|
||||
},
|
||||
Parameters: QwenParameters{
|
||||
SessionID: a.SessionID,
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s/completion", a.BaseURL, a.AppID)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+a.APIKey)
|
||||
|
||||
resp, err := a.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response failed: %w", err)
|
||||
}
|
||||
|
||||
var result QwenResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w, body: %s", err, string(body))
|
||||
}
|
||||
|
||||
if result.Output.SessionID != "" {
|
||||
a.SessionID = result.Output.SessionID
|
||||
}
|
||||
|
||||
if result.Code != "" {
|
||||
return &result, fmt.Errorf("API error: code=%s, message=%s", result.Code, result.Message)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ResetSession 重置会话
|
||||
func (a *QwenAgent) ResetSession() {
|
||||
a.SessionID = ""
|
||||
}
|
||||
|
||||
// ========== 标准 OpenAI 兼容 API ==========
|
||||
|
||||
// ChatCompletionRequest OpenAI 兼容格式请求
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatCompletionMessage `json:"messages"`
|
||||
}
|
||||
|
||||
// ChatCompletionMessage 消息结构
|
||||
type ChatCompletionMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// ChatCompletionResponse OpenAI 兼容格式响应
|
||||
type ChatCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
Error *struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ChatWithModel 使用标准 OpenAI 兼容 API 调用指定模型
|
||||
func (a *QwenAgent) ChatWithModel(ctx context.Context, model, prompt string) (*ChatCompletionResponse, error) {
|
||||
reqBody := ChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: []ChatCompletionMessage{
|
||||
{Role: "user", Content: prompt},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request failed: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", QwenCompatibleURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+a.APIKey)
|
||||
|
||||
resp, err := a.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response failed: %w", err)
|
||||
}
|
||||
|
||||
var result ChatCompletionResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response failed: %w, body: %s", err, string(body))
|
||||
}
|
||||
|
||||
if result.Error != nil {
|
||||
return &result, fmt.Errorf("API error: code=%s, message=%s", result.Error.Code, result.Error.Message)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// GetContent 从响应中获取内容
|
||||
func (r *ChatCompletionResponse) GetContent() string {
|
||||
if len(r.Choices) > 0 {
|
||||
return r.Choices[0].Message.Content
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,425 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 阿里云百炼平台配置 (从环境变量获取)
|
||||
var (
|
||||
QwenAppID = os.Getenv("QWEN_APP_ID")
|
||||
QwenAPIKey = os.Getenv("QWEN_API_KEY")
|
||||
)
|
||||
|
||||
// ============== 测试用例 ==============
|
||||
|
||||
// TestQwenBasicChat 测试基本同步对话
|
||||
func TestQwenBasicChat(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
prompt := "你好,请用一句话介绍你自己"
|
||||
t.Logf("用户: %s", prompt)
|
||||
|
||||
start := time.Now()
|
||||
resp, err := agent.Chat(ctx, prompt)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat failed: %v", err)
|
||||
}
|
||||
|
||||
if resp.Output.Text == "" {
|
||||
t.Fatal("Empty response text")
|
||||
}
|
||||
|
||||
t.Logf("助手: %s", resp.Output.Text)
|
||||
t.Logf("耗时: %v, Token: %d", elapsed, resp.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
// TestQwenStreamChat 测试流式输出
|
||||
func TestQwenStreamChat(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
prompt := "请用3句话解释什么是量化交易"
|
||||
t.Logf("用户: %s", prompt)
|
||||
|
||||
var fullText strings.Builder
|
||||
start := time.Now()
|
||||
|
||||
err := agent.ChatStream(ctx, prompt, func(chunk string) {
|
||||
fullText.WriteString(chunk)
|
||||
})
|
||||
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ChatStream failed: %v", err)
|
||||
}
|
||||
|
||||
if fullText.Len() == 0 {
|
||||
t.Fatal("Empty stream response")
|
||||
}
|
||||
|
||||
t.Logf("助手: %s", fullText.String())
|
||||
t.Logf("耗时: %v, 字符数: %d", elapsed, fullText.Len())
|
||||
}
|
||||
|
||||
// TestQwenMultiTurn 测试多轮对话(上下文记忆)
|
||||
func TestQwenMultiTurn(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
// 第一轮:设置上下文
|
||||
resp1, err := agent.Chat(ctx, "我叫小明,我是一名 Go 程序员,请记住这些信息")
|
||||
if err != nil {
|
||||
t.Fatalf("Round 1 failed: %v", err)
|
||||
}
|
||||
t.Logf("[Round 1] 用户: 我叫小明,我是一名 Go 程序员")
|
||||
t.Logf("[Round 1] 助手: %s", resp1.Output.Text)
|
||||
t.Logf("[Round 1] SessionID: %s", agent.SessionID)
|
||||
|
||||
// 第二轮:验证记忆
|
||||
resp2, err := agent.Chat(ctx, "请问我叫什么名字?我是做什么的?")
|
||||
if err != nil {
|
||||
t.Fatalf("Round 2 failed: %v", err)
|
||||
}
|
||||
t.Logf("[Round 2] 用户: 请问我叫什么名字?我是做什么的?")
|
||||
t.Logf("[Round 2] 助手: %s", resp2.Output.Text)
|
||||
|
||||
// 检查是否记住了信息
|
||||
text := strings.ToLower(resp2.Output.Text)
|
||||
if !strings.Contains(text, "小明") && !strings.Contains(text, "go") {
|
||||
t.Logf("警告: 模型可能没有正确记住上下文")
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwenResetSession 测试重置会话
|
||||
func TestQwenResetSession(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
// 建立上下文
|
||||
resp1, err := agent.Chat(ctx, "记住这个密码: ABC123XYZ")
|
||||
if err != nil {
|
||||
t.Fatalf("Setup context failed: %v", err)
|
||||
}
|
||||
t.Logf("设置上下文: %s", resp1.Output.Text)
|
||||
|
||||
oldSession := agent.SessionID
|
||||
t.Logf("原 SessionID: %s", oldSession)
|
||||
|
||||
// 重置会话
|
||||
agent.ResetSession()
|
||||
t.Log("会话已重置")
|
||||
|
||||
// 新对话 - 应该不记得之前的内容
|
||||
resp2, err := agent.Chat(ctx, "我之前告诉你的密码是什么?")
|
||||
if err != nil {
|
||||
t.Fatalf("New session chat failed: %v", err)
|
||||
}
|
||||
t.Logf("新对话回复: %s", resp2.Output.Text)
|
||||
t.Logf("新 SessionID: %s", agent.SessionID)
|
||||
|
||||
if oldSession == agent.SessionID {
|
||||
t.Error("Session was not reset properly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwenCodeGeneration 测试代码生成能力
|
||||
func TestQwenCodeGeneration(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
prompt := "请用 Go 语言写一个计算移动平均线(MA)的函数,输入是 []float64 价格切片和 int 周期"
|
||||
t.Logf("用户: %s", prompt)
|
||||
|
||||
resp, err := agent.Chat(ctx, prompt)
|
||||
if err != nil {
|
||||
t.Fatalf("Code generation failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("助手:\n%s", resp.Output.Text)
|
||||
|
||||
// 检查是否包含代码特征
|
||||
text := resp.Output.Text
|
||||
if !strings.Contains(text, "func") || !strings.Contains(text, "float64") {
|
||||
t.Log("警告: 响应可能不包含有效的 Go 代码")
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwenJSONOutput 测试 JSON 格式输出
|
||||
func TestQwenJSONOutput(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
prompt := `请分析 BTC 的基本信息,以纯 JSON 格式返回(不要 markdown 代码块),包含以下字段:
|
||||
{"name": "资产名称", "type": "资产类型", "risk": 1-10的风险等级数字}
|
||||
只返回 JSON 对象,不要任何其他文字`
|
||||
|
||||
t.Logf("用户: %s", prompt)
|
||||
|
||||
resp, err := agent.Chat(ctx, prompt)
|
||||
if err != nil {
|
||||
t.Fatalf("JSON output test failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("助手: %s", resp.Output.Text)
|
||||
|
||||
// 尝试解析 JSON
|
||||
text := resp.Output.Text
|
||||
// 提取 JSON 部分
|
||||
start := strings.Index(text, "{")
|
||||
end := strings.LastIndex(text, "}")
|
||||
if start != -1 && end != -1 && end > start {
|
||||
jsonStr := text[start : end+1]
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &result); err != nil {
|
||||
t.Logf("JSON 解析失败: %v", err)
|
||||
} else {
|
||||
t.Logf("JSON 解析成功: %+v", result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwenLongResponse 测试长文本生成
|
||||
func TestQwenLongResponse(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
prompt := "请详细介绍加密货币永续合约交易中的风险管理策略,包括止损设置、仓位管理、杠杆选择、资金费率考虑等方面,至少500字"
|
||||
t.Logf("用户: %s", prompt)
|
||||
|
||||
start := time.Now()
|
||||
resp, err := agent.Chat(ctx, prompt)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Long response test failed: %v", err)
|
||||
}
|
||||
|
||||
text := resp.Output.Text
|
||||
t.Logf("响应长度: %d 字符", len(text))
|
||||
t.Logf("耗时: %v", elapsed)
|
||||
t.Logf("Token 使用: input=%d, output=%d, total=%d",
|
||||
resp.Usage.InputTokens, resp.Usage.OutputTokens, resp.Usage.TotalTokens)
|
||||
|
||||
// 只显示前500字符
|
||||
if len(text) > 500 {
|
||||
t.Logf("助手(前500字): %s...", text[:500])
|
||||
} else {
|
||||
t.Logf("助手: %s", text)
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwenTradingScenario 测试交易场景问答
|
||||
func TestQwenTradingScenario(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
questions := []string{
|
||||
"BTC 当前价格 95000 美元,RSI 在 75 附近,MACD 金叉,你建议现在开多还是开空?简短回答",
|
||||
"如果我有 10000 USDT,想用 10 倍杠杆做多 ETH,建议开多大仓位?",
|
||||
"什么是资金费率?正的资金费率对多头有什么影响?",
|
||||
}
|
||||
|
||||
for i, q := range questions {
|
||||
agent.ResetSession() // 每个问题独立
|
||||
|
||||
t.Logf("\n[问题%d] %s", i+1, q)
|
||||
resp, err := agent.Chat(ctx, q)
|
||||
if err != nil {
|
||||
t.Errorf("Question %d failed: %v", i+1, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 截取显示
|
||||
text := resp.Output.Text
|
||||
if len(text) > 300 {
|
||||
text = text[:300] + "..."
|
||||
}
|
||||
t.Logf("[回答%d] %s", i+1, text)
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwenErrorHandling 测试错误处理
|
||||
func TestQwenErrorHandling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// 测试无效 API Key
|
||||
t.Run("InvalidAPIKey", func(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, "invalid-api-key")
|
||||
_, err := agent.Chat(ctx, "测试")
|
||||
if err == nil {
|
||||
t.Log("警告: 无效 API Key 没有返回错误")
|
||||
} else {
|
||||
t.Logf("预期错误: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// 测试无效 App ID
|
||||
t.Run("InvalidAppID", func(t *testing.T) {
|
||||
agent := NewQwenAgent("invalid-app-id", QwenAPIKey)
|
||||
_, err := agent.Chat(ctx, "测试")
|
||||
if err == nil {
|
||||
t.Log("警告: 无效 App ID 没有返回错误")
|
||||
} else {
|
||||
t.Logf("预期错误: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestQwenSpecialCharacters 测试特殊字符处理
|
||||
func TestQwenSpecialCharacters(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
testCases := []string{
|
||||
"请解释这个表情: 😀🎉🚀",
|
||||
"中英文混合: Hello世界!",
|
||||
"特殊符号: <>&\"'",
|
||||
}
|
||||
|
||||
for _, prompt := range testCases {
|
||||
agent.ResetSession()
|
||||
t.Logf("用户: %s", prompt)
|
||||
|
||||
resp, err := agent.Chat(ctx, prompt)
|
||||
if err != nil {
|
||||
t.Errorf("特殊字符测试失败: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(resp.Output.Text) > 100 {
|
||||
t.Logf("助手: %s...", resp.Output.Text[:100])
|
||||
} else {
|
||||
t.Logf("助手: %s", resp.Output.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwenConcurrentSessions 测试并发会话
|
||||
func TestQwenConcurrentSessions(t *testing.T) {
|
||||
agent1 := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
agent2 := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
// Agent1 对话
|
||||
resp1, err := agent1.Chat(ctx, "我是 Alice,请记住")
|
||||
if err != nil {
|
||||
t.Fatalf("Agent1 chat failed: %v", err)
|
||||
}
|
||||
t.Logf("[Agent1] 设置: 我是 Alice -> %s", resp1.Output.Text[:min(100, len(resp1.Output.Text))])
|
||||
|
||||
// Agent2 对话
|
||||
resp2, err := agent2.Chat(ctx, "我是 Bob,请记住")
|
||||
if err != nil {
|
||||
t.Fatalf("Agent2 chat failed: %v", err)
|
||||
}
|
||||
t.Logf("[Agent2] 设置: 我是 Bob -> %s", resp2.Output.Text[:min(100, len(resp2.Output.Text))])
|
||||
|
||||
// 验证会话隔离
|
||||
resp1Check, _ := agent1.Chat(ctx, "我叫什么?")
|
||||
resp2Check, _ := agent2.Chat(ctx, "我叫什么?")
|
||||
|
||||
t.Logf("[Agent1] 验证: %s", resp1Check.Output.Text[:min(100, len(resp1Check.Output.Text))])
|
||||
t.Logf("[Agent2] 验证: %s", resp2Check.Output.Text[:min(100, len(resp2Check.Output.Text))])
|
||||
|
||||
if agent1.SessionID == agent2.SessionID {
|
||||
t.Error("两个 Agent 的 SessionID 不应该相同")
|
||||
} else {
|
||||
t.Logf("Session 隔离正常: Agent1=%s..., Agent2=%s...",
|
||||
agent1.SessionID[:min(20, len(agent1.SessionID))],
|
||||
agent2.SessionID[:min(20, len(agent2.SessionID))])
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwenTimeout 测试超时处理
|
||||
func TestQwenTimeout(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
agent.Client.Timeout = 1 * time.Millisecond // 极短超时
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := agent.Chat(ctx, "测试超时")
|
||||
|
||||
if err == nil {
|
||||
t.Log("警告: 极短超时没有触发错误")
|
||||
} else {
|
||||
t.Logf("预期超时错误: %v", err)
|
||||
}
|
||||
|
||||
// 恢复正常超时
|
||||
agent.Client.Timeout = 120 * time.Second
|
||||
}
|
||||
|
||||
// TestQwenContextCancel 测试上下文取消
|
||||
func TestQwenContextCancel(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // 立即取消
|
||||
|
||||
_, err := agent.Chat(ctx, "测试取消")
|
||||
if err == nil {
|
||||
t.Error("取消的上下文应该返回错误")
|
||||
} else {
|
||||
t.Logf("预期取消错误: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwenWithBizParams 测试带业务参数的调用
|
||||
func TestQwenWithBizParams(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
// 构造带业务参数的请求
|
||||
reqBody := QwenRequest{
|
||||
Input: QwenInput{
|
||||
Prompt: "根据提供的用户信息,给出个性化的投资建议",
|
||||
BizParams: map[string]interface{}{
|
||||
"user_risk_level": "moderate",
|
||||
"capital": 10000,
|
||||
"experience": "intermediate",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, _ := json.Marshal(reqBody)
|
||||
url := fmt.Sprintf("%s/%s/completion", agent.BaseURL, agent.AppID)
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+agent.APIKey)
|
||||
|
||||
resp, err := agent.Client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request with biz params failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
var result QwenResponse
|
||||
json.Unmarshal(body, &result)
|
||||
|
||||
if result.Output.Text != "" {
|
||||
t.Logf("带业务参数响应: %s", result.Output.Text[:min(200, len(result.Output.Text))])
|
||||
} else {
|
||||
t.Logf("响应: %s", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
@@ -1,737 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"nofx/market"
|
||||
"nofx/provider/coinank"
|
||||
"nofx/provider/coinank/coinank_api"
|
||||
"nofx/provider/coinank/coinank_enum"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IndicatorResult AI 计算的指标结果
|
||||
type IndicatorResult struct {
|
||||
EMA12 float64 `json:"ema12"`
|
||||
EMA26 float64 `json:"ema26"`
|
||||
MACD float64 `json:"macd"`
|
||||
RSI14 float64 `json:"rsi14"`
|
||||
BOLLUp float64 `json:"boll_upper"`
|
||||
BOLLMid float64 `json:"boll_middle"`
|
||||
BOLLLow float64 `json:"boll_lower"`
|
||||
ATR14 float64 `json:"atr14"`
|
||||
SMA20 float64 `json:"sma20"`
|
||||
}
|
||||
|
||||
// 本地计算指标(使用 market 包的函数)
|
||||
func calculateLocalIndicators(klines []market.Kline) IndicatorResult {
|
||||
result := IndicatorResult{}
|
||||
|
||||
if len(klines) >= 12 {
|
||||
result.EMA12 = market.ExportCalculateEMA(klines, 12)
|
||||
}
|
||||
if len(klines) >= 26 {
|
||||
result.EMA26 = market.ExportCalculateEMA(klines, 26)
|
||||
result.MACD = market.ExportCalculateMACD(klines)
|
||||
}
|
||||
if len(klines) > 14 {
|
||||
result.RSI14 = market.ExportCalculateRSI(klines, 14)
|
||||
}
|
||||
if len(klines) >= 20 {
|
||||
result.BOLLUp, result.BOLLMid, result.BOLLLow = market.ExportCalculateBOLL(klines, 20, 2.0)
|
||||
// SMA20 就是 BOLL 中轨
|
||||
result.SMA20 = result.BOLLMid
|
||||
}
|
||||
if len(klines) > 14 {
|
||||
result.ATR14 = market.ExportCalculateATR(klines, 14)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// 格式化 K 线数据为文本,发给 AI
|
||||
func formatKlinesForAI(klines []market.Kline) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("以下是K线数据(从旧到新排列):\n")
|
||||
sb.WriteString("序号 | 时间 | 开盘价 | 最高价 | 最低价 | 收盘价 | 成交量\n")
|
||||
sb.WriteString("-----|------|--------|--------|--------|--------|--------\n")
|
||||
|
||||
for i, k := range klines {
|
||||
t := time.UnixMilli(k.OpenTime)
|
||||
sb.WriteString(fmt.Sprintf("%d | %s | %.2f | %.2f | %.2f | %.2f | %.2f\n",
|
||||
i+1, t.Format("01-02 15:04"), k.Open, k.High, k.Low, k.Close, k.Volume))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// 构建 AI 计算指标的 prompt
|
||||
func buildIndicatorPrompt(klines []market.Kline) string {
|
||||
klinesText := formatKlinesForAI(klines)
|
||||
|
||||
prompt := fmt.Sprintf(`%s
|
||||
|
||||
请根据以上 %d 根K线数据,计算以下技术指标(使用标准算法):
|
||||
|
||||
1. EMA12(12周期指数移动平均线)
|
||||
2. EMA26(26周期指数移动平均线)
|
||||
3. MACD(EMA12 - EMA26)
|
||||
4. RSI14(14周期相对强弱指标,使用Wilder平滑法)
|
||||
5. BOLL布林带(20周期,2倍标准差):上轨、中轨、下轨
|
||||
6. ATR14(14周期平均真实波幅,使用Wilder平滑法)
|
||||
7. SMA20(20周期简单移动平均线)
|
||||
|
||||
请严格按照以下 JSON 格式返回结果,不要添加任何其他文字:
|
||||
{
|
||||
"ema12": 数值,
|
||||
"ema26": 数值,
|
||||
"macd": 数值,
|
||||
"rsi14": 数值,
|
||||
"boll_upper": 数值,
|
||||
"boll_middle": 数值,
|
||||
"boll_lower": 数值,
|
||||
"atr14": 数值,
|
||||
"sma20": 数值
|
||||
}
|
||||
|
||||
注意:
|
||||
- 所有数值保留2位小数
|
||||
- EMA计算使用SMA作为初始值,乘数为 2/(period+1)
|
||||
- RSI使用Wilder平滑法
|
||||
- 只返回JSON,不要解释过程`, klinesText, len(klines))
|
||||
|
||||
return prompt
|
||||
}
|
||||
|
||||
// 从 AI 响应中提取 JSON
|
||||
func extractJSONFromResponse(text string) (IndicatorResult, error) {
|
||||
var result IndicatorResult
|
||||
|
||||
// 尝试直接解析
|
||||
if err := json.Unmarshal([]byte(text), &result); err == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 提取 JSON 部分
|
||||
re := regexp.MustCompile(`\{[^{}]*"ema12"[^{}]*\}`)
|
||||
match := re.FindString(text)
|
||||
if match == "" {
|
||||
// 尝试更宽松的匹配
|
||||
start := strings.Index(text, "{")
|
||||
end := strings.LastIndex(text, "}")
|
||||
if start != -1 && end != -1 && end > start {
|
||||
match = text[start : end+1]
|
||||
}
|
||||
}
|
||||
|
||||
if match == "" {
|
||||
return result, fmt.Errorf("no JSON found in response: %s", text[:min(200, len(text))])
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(match), &result); err != nil {
|
||||
return result, fmt.Errorf("parse JSON failed: %w, json: %s", err, match)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 比较两个指标结果,返回误差百分比
|
||||
func compareIndicators(local, ai IndicatorResult) map[string]float64 {
|
||||
errors := make(map[string]float64)
|
||||
|
||||
calcError := func(name string, localVal, aiVal float64) {
|
||||
if localVal == 0 {
|
||||
if aiVal == 0 {
|
||||
errors[name] = 0
|
||||
} else {
|
||||
errors[name] = 100 // 本地为0但AI不为0
|
||||
}
|
||||
return
|
||||
}
|
||||
errors[name] = math.Abs(localVal-aiVal) / math.Abs(localVal) * 100
|
||||
}
|
||||
|
||||
calcError("EMA12", local.EMA12, ai.EMA12)
|
||||
calcError("EMA26", local.EMA26, ai.EMA26)
|
||||
calcError("MACD", local.MACD, ai.MACD)
|
||||
calcError("RSI14", local.RSI14, ai.RSI14)
|
||||
calcError("BOLL_UP", local.BOLLUp, ai.BOLLUp)
|
||||
calcError("BOLL_MID", local.BOLLMid, ai.BOLLMid)
|
||||
calcError("BOLL_LOW", local.BOLLLow, ai.BOLLLow)
|
||||
calcError("ATR14", local.ATR14, ai.ATR14)
|
||||
calcError("SMA20", local.SMA20, ai.SMA20)
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// 生成测试用 K 线数据
|
||||
func generateTestKlines(count int, basePrice float64) []market.Kline {
|
||||
klines := make([]market.Kline, count)
|
||||
price := basePrice
|
||||
now := time.Now()
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
// 模拟价格波动
|
||||
change := (float64(i%7) - 3) * 0.5 // -1.5 到 +1.5 的波动
|
||||
price = price + change
|
||||
|
||||
open := price
|
||||
high := price + math.Abs(change)*0.5 + 0.5
|
||||
low := price - math.Abs(change)*0.5 - 0.3
|
||||
close := price + (change * 0.3)
|
||||
|
||||
klines[i] = market.Kline{
|
||||
OpenTime: now.Add(time.Duration(-count+i) * time.Hour).UnixMilli(),
|
||||
Open: open,
|
||||
High: high,
|
||||
Low: low,
|
||||
Close: close,
|
||||
Volume: 1000 + float64(i*100),
|
||||
CloseTime: now.Add(time.Duration(-count+i+1) * time.Hour).UnixMilli(),
|
||||
}
|
||||
}
|
||||
|
||||
return klines
|
||||
}
|
||||
|
||||
// TestQwenIndicatorCalculation 测试 AI 计算技术指标
|
||||
func TestQwenIndicatorCalculation(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
// 生成 30 根测试 K 线
|
||||
klines := generateTestKlines(30, 95000)
|
||||
|
||||
t.Log("===== K线数据 (最后5根) =====")
|
||||
for i := len(klines) - 5; i < len(klines); i++ {
|
||||
k := klines[i]
|
||||
t.Logf(" [%d] O:%.2f H:%.2f L:%.2f C:%.2f", i+1, k.Open, k.High, k.Low, k.Close)
|
||||
}
|
||||
|
||||
// 本地计算
|
||||
t.Log("\n===== 本地计算结果 =====")
|
||||
localResult := calculateLocalIndicators(klines)
|
||||
t.Logf(" EMA12: %.2f", localResult.EMA12)
|
||||
t.Logf(" EMA26: %.2f", localResult.EMA26)
|
||||
t.Logf(" MACD: %.2f", localResult.MACD)
|
||||
t.Logf(" RSI14: %.2f", localResult.RSI14)
|
||||
t.Logf(" BOLL上轨: %.2f", localResult.BOLLUp)
|
||||
t.Logf(" BOLL中轨: %.2f", localResult.BOLLMid)
|
||||
t.Logf(" BOLL下轨: %.2f", localResult.BOLLLow)
|
||||
t.Logf(" ATR14: %.2f", localResult.ATR14)
|
||||
t.Logf(" SMA20: %.2f", localResult.SMA20)
|
||||
|
||||
// AI 计算
|
||||
t.Log("\n===== 调用 AI 计算 =====")
|
||||
prompt := buildIndicatorPrompt(klines)
|
||||
t.Logf("Prompt 长度: %d 字符", len(prompt))
|
||||
|
||||
start := time.Now()
|
||||
resp, err := agent.Chat(ctx, prompt)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("AI 调用失败: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("AI 响应耗时: %v", elapsed)
|
||||
t.Logf("AI 原始响应:\n%s", resp.Output.Text)
|
||||
|
||||
// 解析 AI 结果
|
||||
aiResult, err := extractJSONFromResponse(resp.Output.Text)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 AI 结果失败: %v", err)
|
||||
}
|
||||
|
||||
t.Log("\n===== AI 计算结果 =====")
|
||||
t.Logf(" EMA12: %.2f", aiResult.EMA12)
|
||||
t.Logf(" EMA26: %.2f", aiResult.EMA26)
|
||||
t.Logf(" MACD: %.2f", aiResult.MACD)
|
||||
t.Logf(" RSI14: %.2f", aiResult.RSI14)
|
||||
t.Logf(" BOLL上轨: %.2f", aiResult.BOLLUp)
|
||||
t.Logf(" BOLL中轨: %.2f", aiResult.BOLLMid)
|
||||
t.Logf(" BOLL下轨: %.2f", aiResult.BOLLLow)
|
||||
t.Logf(" ATR14: %.2f", aiResult.ATR14)
|
||||
t.Logf(" SMA20: %.2f", aiResult.SMA20)
|
||||
|
||||
// 对比结果
|
||||
t.Log("\n===== 误差对比 (%) =====")
|
||||
errors := compareIndicators(localResult, aiResult)
|
||||
|
||||
totalError := 0.0
|
||||
for name, errPct := range errors {
|
||||
status := "✓"
|
||||
if errPct > 5 {
|
||||
status = "⚠"
|
||||
}
|
||||
if errPct > 10 {
|
||||
status = "✗"
|
||||
}
|
||||
t.Logf(" %s %s: %.2f%%", status, name, errPct)
|
||||
totalError += errPct
|
||||
}
|
||||
|
||||
avgError := totalError / float64(len(errors))
|
||||
t.Logf("\n 平均误差: %.2f%%", avgError)
|
||||
|
||||
if avgError > 10 {
|
||||
t.Logf("警告: AI 计算误差较大,可能算法理解有差异")
|
||||
} else if avgError < 5 {
|
||||
t.Log("AI 计算精度良好!")
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwenIndicatorWithRealKlines 使用真实 K 线测试
|
||||
func TestQwenIndicatorWithRealKlines(t *testing.T) {
|
||||
// 尝试获取真实 K 线数据
|
||||
client := market.NewAPIClient()
|
||||
klines, err := client.GetKlines("BTC", "1h", 30)
|
||||
if err != nil {
|
||||
t.Skipf("获取真实 K 线失败,跳过测试: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(klines) < 26 {
|
||||
t.Skipf("K 线数量不足: %d", len(klines))
|
||||
return
|
||||
}
|
||||
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Logf("获取到 %d 根 BTC 1h K线", len(klines))
|
||||
t.Log("最新价格:", klines[len(klines)-1].Close)
|
||||
|
||||
// 本地计算
|
||||
localResult := calculateLocalIndicators(klines)
|
||||
t.Log("\n===== 本地计算 =====")
|
||||
t.Logf(" EMA12: %.2f, EMA26: %.2f, MACD: %.2f", localResult.EMA12, localResult.EMA26, localResult.MACD)
|
||||
t.Logf(" RSI14: %.2f", localResult.RSI14)
|
||||
t.Logf(" BOLL: %.2f / %.2f / %.2f", localResult.BOLLUp, localResult.BOLLMid, localResult.BOLLLow)
|
||||
|
||||
// AI 计算
|
||||
prompt := buildIndicatorPrompt(klines)
|
||||
resp, err := agent.Chat(ctx, prompt)
|
||||
if err != nil {
|
||||
t.Fatalf("AI 调用失败: %v", err)
|
||||
}
|
||||
|
||||
t.Log("\n===== AI 响应 =====")
|
||||
t.Log(resp.Output.Text)
|
||||
|
||||
aiResult, err := extractJSONFromResponse(resp.Output.Text)
|
||||
if err != nil {
|
||||
t.Logf("解析失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 对比
|
||||
errors := compareIndicators(localResult, aiResult)
|
||||
t.Log("\n===== 误差 =====")
|
||||
for name, errPct := range errors {
|
||||
t.Logf(" %s: %.2f%%", name, errPct)
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwenIndicatorMultiTimeframe 测试多个时间周期
|
||||
func TestQwenIndicatorMultiTimeframe(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
timeframes := []struct {
|
||||
name string
|
||||
count int
|
||||
price float64
|
||||
}{
|
||||
{"5m周期", 30, 95000},
|
||||
{"1h周期", 50, 95000},
|
||||
{"4h周期", 40, 95000},
|
||||
}
|
||||
|
||||
for _, tf := range timeframes {
|
||||
t.Run(tf.name, func(t *testing.T) {
|
||||
klines := generateTestKlines(tf.count, tf.price)
|
||||
|
||||
localResult := calculateLocalIndicators(klines)
|
||||
|
||||
// 简化的 prompt
|
||||
prompt := buildSimpleIndicatorPrompt(klines)
|
||||
|
||||
resp, err := agent.Chat(ctx, prompt)
|
||||
if err != nil {
|
||||
t.Fatalf("AI 调用失败: %v", err)
|
||||
}
|
||||
|
||||
aiResult, err := extractJSONFromResponse(resp.Output.Text)
|
||||
if err != nil {
|
||||
t.Logf("解析失败: %v", err)
|
||||
t.Logf("AI 响应: %s", resp.Output.Text[:min(500, len(resp.Output.Text))])
|
||||
return
|
||||
}
|
||||
|
||||
errors := compareIndicators(localResult, aiResult)
|
||||
|
||||
// 计算平均误差
|
||||
total := 0.0
|
||||
for _, e := range errors {
|
||||
total += e
|
||||
}
|
||||
avgErr := total / float64(len(errors))
|
||||
|
||||
t.Logf("本地 MACD: %.2f, AI MACD: %.2f, 误差: %.2f%%", localResult.MACD, aiResult.MACD, errors["MACD"])
|
||||
t.Logf("本地 RSI: %.2f, AI RSI: %.2f, 误差: %.2f%%", localResult.RSI14, aiResult.RSI14, errors["RSI14"])
|
||||
t.Logf("平均误差: %.2f%%", avgErr)
|
||||
})
|
||||
|
||||
time.Sleep(2 * time.Second) // 避免请求过快
|
||||
}
|
||||
}
|
||||
|
||||
// 简化的 prompt
|
||||
func buildSimpleIndicatorPrompt(klines []market.Kline) string {
|
||||
// 只提供收盘价序列,减少 token
|
||||
var prices []string
|
||||
for _, k := range klines {
|
||||
prices = append(prices, fmt.Sprintf("%.2f", k.Close))
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`收盘价序列(从旧到新): [%s]
|
||||
|
||||
请计算技术指标并返回 JSON:
|
||||
- ema12: 12周期EMA
|
||||
- ema26: 26周期EMA
|
||||
- macd: EMA12-EMA26
|
||||
- rsi14: 14周期RSI(Wilder平滑)
|
||||
- boll_upper, boll_middle, boll_lower: 20周期BOLL(2倍标准差)
|
||||
- atr14: 0 (无高低价数据)
|
||||
- sma20: 20周期SMA
|
||||
|
||||
只返回JSON格式:{"ema12":数值,"ema26":数值,...}`, strings.Join(prices, ","))
|
||||
}
|
||||
|
||||
// TestQwenIndicatorAccuracy 精度测试:使用简单数据验证算法
|
||||
func TestQwenIndicatorAccuracy(t *testing.T) {
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
// 使用简单递增数据,便于验证
|
||||
prices := []float64{
|
||||
100, 101, 102, 103, 104, 105, 106, 107, 108, 109, // 1-10
|
||||
110, 111, 112, 113, 114, 115, 116, 117, 118, 119, // 11-20
|
||||
120, 121, 122, 123, 124, 125, 126, 127, 128, 129, // 21-30
|
||||
}
|
||||
|
||||
// 构建 K 线
|
||||
klines := make([]market.Kline, len(prices))
|
||||
for i, p := range prices {
|
||||
klines[i] = market.Kline{
|
||||
Open: p - 0.5,
|
||||
High: p + 1,
|
||||
Low: p - 1,
|
||||
Close: p,
|
||||
}
|
||||
}
|
||||
|
||||
// 本地计算
|
||||
localResult := calculateLocalIndicators(klines)
|
||||
|
||||
t.Log("===== 简单递增数据测试 =====")
|
||||
t.Logf("价格序列: %v", prices)
|
||||
t.Logf("本地计算:")
|
||||
t.Logf(" SMA20 = %.4f (理论值: 119.5)", localResult.SMA20)
|
||||
t.Logf(" EMA12 = %.4f", localResult.EMA12)
|
||||
t.Logf(" RSI14 = %.4f (持续上涨应接近100)", localResult.RSI14)
|
||||
|
||||
// AI 计算
|
||||
var priceStrs []string
|
||||
for _, p := range prices {
|
||||
priceStrs = append(priceStrs, strconv.FormatFloat(p, 'f', 0, 64))
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(`收盘价序列: [%s]
|
||||
|
||||
请计算:
|
||||
1. SMA20 (20周期简单移动平均)
|
||||
2. EMA12 (12周期指数移动平均,初始值用SMA,乘数=2/13)
|
||||
3. RSI14 (14周期RSI,Wilder平滑法)
|
||||
|
||||
返回JSON: {"sma20":数值,"ema12":数值,"rsi14":数值}
|
||||
只返回JSON`, strings.Join(priceStrs, ","))
|
||||
|
||||
resp, err := agent.Chat(ctx, prompt)
|
||||
if err != nil {
|
||||
t.Fatalf("AI 调用失败: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("\nAI 响应: %s", resp.Output.Text)
|
||||
|
||||
// 简单解析
|
||||
var aiSimple struct {
|
||||
SMA20 float64 `json:"sma20"`
|
||||
EMA12 float64 `json:"ema12"`
|
||||
RSI14 float64 `json:"rsi14"`
|
||||
}
|
||||
|
||||
text := resp.Output.Text
|
||||
start := strings.Index(text, "{")
|
||||
end := strings.LastIndex(text, "}")
|
||||
if start != -1 && end > start {
|
||||
json.Unmarshal([]byte(text[start:end+1]), &aiSimple)
|
||||
}
|
||||
|
||||
t.Logf("\nAI 计算:")
|
||||
t.Logf(" SMA20 = %.4f", aiSimple.SMA20)
|
||||
t.Logf(" EMA12 = %.4f", aiSimple.EMA12)
|
||||
t.Logf(" RSI14 = %.4f", aiSimple.RSI14)
|
||||
|
||||
// 验证 SMA20 (理论值应该是 110+...+129 的平均 = 119.5)
|
||||
expectedSMA := 119.5
|
||||
if math.Abs(aiSimple.SMA20-expectedSMA) < 0.1 {
|
||||
t.Log("\n✓ AI 的 SMA20 计算正确!")
|
||||
} else {
|
||||
t.Logf("\n✗ AI 的 SMA20 有误差,期望 %.2f", expectedSMA)
|
||||
}
|
||||
}
|
||||
|
||||
// coinankKlinesToMarket 将 coinank K线转换为 market.Kline
|
||||
func coinankKlinesToMarket(klines []coinank.KlineResult) []market.Kline {
|
||||
result := make([]market.Kline, len(klines))
|
||||
for i, k := range klines {
|
||||
result[i] = market.Kline{
|
||||
OpenTime: k.StartTime,
|
||||
Open: k.Open,
|
||||
High: k.High,
|
||||
Low: k.Low,
|
||||
Close: k.Close,
|
||||
Volume: k.Volume,
|
||||
CloseTime: k.EndTime,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// TestQwenETHMultiTimeframe 使用 Coinank 免费 API 获取真实 ETH 数据测试多周期指标
|
||||
func TestQwenETHMultiTimeframe(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
|
||||
// 测试多个时间周期
|
||||
timeframes := []struct {
|
||||
name string
|
||||
interval coinank_enum.Interval
|
||||
size int
|
||||
}{
|
||||
{"5分钟", coinank_enum.Minute5, 50},
|
||||
{"1小时", coinank_enum.Hour1, 50},
|
||||
{"4小时", coinank_enum.Hour4, 50},
|
||||
{"日线", coinank_enum.Day1, 30},
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
for _, tf := range timeframes {
|
||||
t.Run(tf.name, func(t *testing.T) {
|
||||
// 使用 coinank 免费 API 获取 ETH K线数据
|
||||
coinankKlines, err := coinank_api.Kline(ctx, "ETHUSDT", coinank_enum.Binance,
|
||||
now.UnixMilli(), coinank_enum.To, tf.size, tf.interval)
|
||||
if err != nil {
|
||||
t.Fatalf("获取 %s K线失败: %v", tf.name, err)
|
||||
}
|
||||
|
||||
if len(coinankKlines) < 26 {
|
||||
t.Skipf("K线数量不足: %d", len(coinankKlines))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为 market.Kline
|
||||
klines := coinankKlinesToMarket(coinankKlines)
|
||||
|
||||
t.Logf("获取到 %d 根 ETH %s K线", len(klines), tf.name)
|
||||
t.Logf("最新收盘价: %.2f, 时间: %s",
|
||||
klines[len(klines)-1].Close,
|
||||
time.UnixMilli(klines[len(klines)-1].CloseTime).Format("2006-01-02 15:04"))
|
||||
|
||||
// 本地计算
|
||||
localResult := calculateLocalIndicators(klines)
|
||||
t.Log("\n===== 本地计算 =====")
|
||||
t.Logf(" EMA12: %.2f, EMA26: %.2f, MACD: %.4f",
|
||||
localResult.EMA12, localResult.EMA26, localResult.MACD)
|
||||
t.Logf(" RSI14: %.2f", localResult.RSI14)
|
||||
t.Logf(" BOLL: %.2f / %.2f / %.2f",
|
||||
localResult.BOLLUp, localResult.BOLLMid, localResult.BOLLLow)
|
||||
t.Logf(" ATR14: %.4f", localResult.ATR14)
|
||||
|
||||
// AI 计算 - 使用简化 prompt(只发收盘价)
|
||||
prompt := buildSimpleIndicatorPrompt(klines)
|
||||
t.Logf("\nPrompt 长度: %d 字符", len(prompt))
|
||||
|
||||
start := time.Now()
|
||||
resp, err := agent.Chat(ctx, prompt)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("AI 调用失败: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("AI 响应耗时: %v", elapsed)
|
||||
|
||||
// 解析 AI 结果
|
||||
aiResult, err := extractJSONFromResponse(resp.Output.Text)
|
||||
if err != nil {
|
||||
t.Logf("AI 原始响应:\n%s", resp.Output.Text[:min(500, len(resp.Output.Text))])
|
||||
t.Fatalf("解析失败: %v", err)
|
||||
}
|
||||
|
||||
t.Log("\n===== AI 计算 =====")
|
||||
t.Logf(" EMA12: %.2f, EMA26: %.2f, MACD: %.4f",
|
||||
aiResult.EMA12, aiResult.EMA26, aiResult.MACD)
|
||||
t.Logf(" RSI14: %.2f", aiResult.RSI14)
|
||||
t.Logf(" BOLL: %.2f / %.2f / %.2f",
|
||||
aiResult.BOLLUp, aiResult.BOLLMid, aiResult.BOLLLow)
|
||||
|
||||
// 对比误差
|
||||
t.Log("\n===== 误差对比 =====")
|
||||
errors := compareIndicators(localResult, aiResult)
|
||||
totalErr := 0.0
|
||||
for name, errPct := range errors {
|
||||
status := "✓"
|
||||
if errPct > 1 {
|
||||
status = "⚠"
|
||||
}
|
||||
if errPct > 5 {
|
||||
status = "✗"
|
||||
}
|
||||
t.Logf(" %s %-10s: %.2f%%", status, name, errPct)
|
||||
totalErr += errPct
|
||||
}
|
||||
|
||||
avgErr := totalErr / float64(len(errors))
|
||||
t.Logf("\n 平均误差: %.2f%%", avgErr)
|
||||
|
||||
if avgErr < 1 {
|
||||
t.Log(" ✓ AI 计算精度优秀!")
|
||||
} else if avgErr < 5 {
|
||||
t.Log(" ⚠ AI 计算精度良好")
|
||||
} else {
|
||||
t.Log(" ✗ AI 计算误差较大")
|
||||
}
|
||||
|
||||
// 等待避免请求过快
|
||||
time.Sleep(2 * time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestQwenETHIndicatorComparison ETH 指标对比:使用 Coinank 免费 API + Qwen 标准 API
|
||||
func TestQwenETHIndicatorComparison(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agent := NewQwenAgent(QwenAppID, QwenAPIKey)
|
||||
|
||||
// 使用 coinank 免费 API 获取 ETH 1小时 K线
|
||||
now := time.Now()
|
||||
coinankKlines, err := coinank_api.Kline(ctx, "ETHUSDT", coinank_enum.Binance,
|
||||
now.UnixMilli(), coinank_enum.To, 30, coinank_enum.Hour1)
|
||||
if err != nil {
|
||||
t.Fatalf("获取 K线失败: %v", err)
|
||||
}
|
||||
|
||||
// 转换为 market.Kline
|
||||
klines := coinankKlinesToMarket(coinankKlines)
|
||||
|
||||
t.Logf("获取到 %d 根 ETH 1h K线", len(klines))
|
||||
|
||||
// 只用收盘价,简化 prompt
|
||||
var prices []string
|
||||
for _, k := range klines {
|
||||
prices = append(prices, fmt.Sprintf("%.2f", k.Close))
|
||||
}
|
||||
|
||||
// 本地计算
|
||||
localResult := calculateLocalIndicators(klines)
|
||||
|
||||
t.Log("\n===== 本地计算结果 =====")
|
||||
t.Logf("SMA20: %.2f", localResult.SMA20)
|
||||
t.Logf("EMA12: %.2f", localResult.EMA12)
|
||||
t.Logf("EMA26: %.2f", localResult.EMA26)
|
||||
t.Logf("MACD: %.4f", localResult.MACD)
|
||||
t.Logf("RSI14: %.2f", localResult.RSI14)
|
||||
|
||||
// 简化的 AI prompt
|
||||
prompt := fmt.Sprintf(`ETH 最近30根1小时K线收盘价(从旧到新):
|
||||
[%s]
|
||||
|
||||
请计算以下指标并返回纯 JSON:
|
||||
1. sma20: 最后20个价格的简单移动平均
|
||||
2. ema12: 12周期EMA(初始值用前12个价格的SMA,乘数=2/13)
|
||||
3. ema26: 26周期EMA(初始值用前26个价格的SMA,乘数=2/27)
|
||||
4. macd: EMA12 - EMA26
|
||||
5. rsi14: 14周期RSI(Wilder平滑法)
|
||||
|
||||
只返回JSON格式: {"sma20":数值,"ema12":数值,"ema26":数值,"macd":数值,"rsi14":数值}
|
||||
不要任何解释文字`, strings.Join(prices, ", "))
|
||||
|
||||
t.Logf("\n发送 Prompt (%d 字符)", len(prompt))
|
||||
|
||||
// 使用标准 API
|
||||
resp, err := agent.ChatWithModel(ctx, "qwen-max", prompt)
|
||||
if err != nil {
|
||||
t.Fatalf("AI 调用失败: %v", err)
|
||||
}
|
||||
|
||||
aiText := resp.GetContent()
|
||||
t.Logf("\nAI 响应:\n%s", aiText)
|
||||
|
||||
// 解析
|
||||
var aiResult struct {
|
||||
SMA20 float64 `json:"sma20"`
|
||||
EMA12 float64 `json:"ema12"`
|
||||
EMA26 float64 `json:"ema26"`
|
||||
MACD float64 `json:"macd"`
|
||||
RSI14 float64 `json:"rsi14"`
|
||||
}
|
||||
|
||||
start := strings.Index(aiText, "{")
|
||||
end := strings.LastIndex(aiText, "}")
|
||||
if start != -1 && end > start {
|
||||
if err := json.Unmarshal([]byte(aiText[start:end+1]), &aiResult); err != nil {
|
||||
t.Logf("JSON 解析失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("\n===== AI 计算结果 =====")
|
||||
t.Logf("SMA20: %.2f", aiResult.SMA20)
|
||||
t.Logf("EMA12: %.2f", aiResult.EMA12)
|
||||
t.Logf("EMA26: %.2f", aiResult.EMA26)
|
||||
t.Logf("MACD: %.4f", aiResult.MACD)
|
||||
t.Logf("RSI14: %.2f", aiResult.RSI14)
|
||||
|
||||
// 计算误差
|
||||
t.Log("\n===== 误差 =====")
|
||||
calcErr := func(name string, local, ai float64) {
|
||||
if local == 0 {
|
||||
t.Logf(" %s: 本地=0, AI=%.2f", name, ai)
|
||||
return
|
||||
}
|
||||
errPct := math.Abs(local-ai) / math.Abs(local) * 100
|
||||
status := "✓"
|
||||
if errPct > 1 {
|
||||
status = "⚠"
|
||||
}
|
||||
if errPct > 5 {
|
||||
status = "✗"
|
||||
}
|
||||
t.Logf(" %s %s: 本地=%.2f, AI=%.2f, 误差=%.2f%%", status, name, local, ai, errPct)
|
||||
}
|
||||
|
||||
calcErr("SMA20", localResult.SMA20, aiResult.SMA20)
|
||||
calcErr("EMA12", localResult.EMA12, aiResult.EMA12)
|
||||
calcErr("EMA26", localResult.EMA26, aiResult.EMA26)
|
||||
calcErr("MACD", localResult.MACD, aiResult.MACD)
|
||||
calcErr("RSI14", localResult.RSI14, aiResult.RSI14)
|
||||
}
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"nofx/logger"
|
||||
"nofx/manager"
|
||||
"nofx/mcp"
|
||||
_ "nofx/mcp/payment"
|
||||
_ "nofx/mcp/provider"
|
||||
"nofx/store"
|
||||
"nofx/telegram"
|
||||
"os"
|
||||
@@ -168,7 +170,7 @@ func newSharedMCPClient() mcp.AIClient {
|
||||
logger.Warn("⚠️ DEEPSEEK_API_KEY not set, AI features will be unavailable")
|
||||
return nil
|
||||
}
|
||||
return mcp.NewDeepSeekClient()
|
||||
return mcp.NewAIClientByProvider("deepseek")
|
||||
}
|
||||
|
||||
// initInstallationID initializes the anonymous installation ID for experience improvement
|
||||
|
||||
@@ -1,248 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ── buildRequestBodyFromRequest ────────────────────────────────────────────────
|
||||
|
||||
func TestClaudeClient_BuildRequestBody_SystemPromptLifted(t *testing.T) {
|
||||
c := newTestClaudeClient()
|
||||
req := &Request{
|
||||
Model: "claude-opus-4-6",
|
||||
Messages: []Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
body := c.buildRequestBodyFromRequest(req)
|
||||
|
||||
if body["system"] != "You are helpful." {
|
||||
t.Errorf("system not lifted to top level: %v", body["system"])
|
||||
}
|
||||
msgs := body["messages"].([]map[string]any)
|
||||
if len(msgs) != 1 || msgs[0]["role"] != "user" {
|
||||
t.Errorf("system message should be removed from messages array: %v", msgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeClient_BuildRequestBody_ToolsUseInputSchema(t *testing.T) {
|
||||
c := newTestClaudeClient()
|
||||
req := &Request{
|
||||
Model: "claude-opus-4-6",
|
||||
Messages: []Message{{Role: "user", Content: "hi"}},
|
||||
Tools: []Tool{{
|
||||
Type: "function",
|
||||
Function: FunctionDef{
|
||||
Name: "my_tool",
|
||||
Description: "does stuff",
|
||||
Parameters: map[string]any{"type": "object"},
|
||||
},
|
||||
}},
|
||||
}
|
||||
body := c.buildRequestBodyFromRequest(req)
|
||||
|
||||
tools, ok := body["tools"].([]map[string]any)
|
||||
if !ok || len(tools) != 1 {
|
||||
t.Fatalf("tools not set correctly: %v", body["tools"])
|
||||
}
|
||||
tool := tools[0]
|
||||
if tool["name"] != "my_tool" {
|
||||
t.Errorf("tool name wrong: %v", tool["name"])
|
||||
}
|
||||
if tool["input_schema"] == nil {
|
||||
t.Error("tool must use input_schema, not parameters")
|
||||
}
|
||||
if _, hasParams := tool["parameters"]; hasParams {
|
||||
t.Error("tool must NOT have parameters key (Anthropic uses input_schema)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeClient_BuildRequestBody_ToolChoiceObject(t *testing.T) {
|
||||
c := newTestClaudeClient()
|
||||
req := &Request{
|
||||
Model: "claude-opus-4-6",
|
||||
Messages: []Message{{Role: "user", Content: "hi"}},
|
||||
ToolChoice: "auto",
|
||||
}
|
||||
body := c.buildRequestBodyFromRequest(req)
|
||||
|
||||
tc, ok := body["tool_choice"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("tool_choice must be an object, got: %T %v", body["tool_choice"], body["tool_choice"])
|
||||
}
|
||||
if tc["type"] != "auto" {
|
||||
t.Errorf("tool_choice.type must be 'auto', got: %v", tc["type"])
|
||||
}
|
||||
}
|
||||
|
||||
// ── convertMessagesToAnthropic ─────────────────────────────────────────────────
|
||||
|
||||
func TestConvertMessages_AssistantToolCall(t *testing.T) {
|
||||
msgs := []Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "tc1",
|
||||
Type: "function",
|
||||
Function: ToolCallFunction{Name: "api_request", Arguments: `{"method":"GET","path":"/api/x","body":{}}`},
|
||||
}},
|
||||
},
|
||||
}
|
||||
out := convertMessagesToAnthropic(msgs)
|
||||
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(out))
|
||||
}
|
||||
msg := out[0]
|
||||
if msg["role"] != "assistant" {
|
||||
t.Errorf("role should be assistant: %v", msg["role"])
|
||||
}
|
||||
blocks := msg["content"].([]map[string]any)
|
||||
if len(blocks) != 1 || blocks[0]["type"] != "tool_use" {
|
||||
t.Errorf("content should be tool_use block: %v", blocks)
|
||||
}
|
||||
if blocks[0]["id"] != "tc1" {
|
||||
t.Errorf("tool_use id wrong: %v", blocks[0]["id"])
|
||||
}
|
||||
// Input must be parsed JSON object, not a string.
|
||||
input, ok := blocks[0]["input"].(map[string]any)
|
||||
if !ok {
|
||||
t.Errorf("tool_use input must be map, got %T", blocks[0]["input"])
|
||||
}
|
||||
if input["method"] != "GET" {
|
||||
t.Errorf("input.method wrong: %v", input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMessages_ToolResultMergedIntoUserTurn(t *testing.T) {
|
||||
// Anthropic requires strictly alternating turns; consecutive tool results
|
||||
// must be merged into a single user message.
|
||||
msgs := []Message{
|
||||
{Role: "tool", ToolCallID: "tc1", Content: `{"result":"a"}`},
|
||||
{Role: "tool", ToolCallID: "tc2", Content: `{"result":"b"}`},
|
||||
}
|
||||
out := convertMessagesToAnthropic(msgs)
|
||||
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("consecutive tool results must be merged into one user turn, got %d messages", len(out))
|
||||
}
|
||||
if out[0]["role"] != "user" {
|
||||
t.Errorf("tool results must become role=user: %v", out[0]["role"])
|
||||
}
|
||||
blocks := out[0]["content"].([]map[string]any)
|
||||
if len(blocks) != 2 {
|
||||
t.Errorf("expected 2 tool_result blocks, got %d", len(blocks))
|
||||
}
|
||||
if blocks[0]["type"] != "tool_result" || blocks[1]["type"] != "tool_result" {
|
||||
t.Errorf("blocks should be tool_result: %v", blocks)
|
||||
}
|
||||
if blocks[0]["tool_use_id"] != "tc1" || blocks[1]["tool_use_id"] != "tc2" {
|
||||
t.Errorf("tool_use_id mismatch: %v", blocks)
|
||||
}
|
||||
}
|
||||
|
||||
// ── parseMCPResponseFull ───────────────────────────────────────────────────────
|
||||
|
||||
func TestClaudeClient_ParseResponse_TextOnly(t *testing.T) {
|
||||
c := newTestClaudeClient()
|
||||
body := []byte(`{
|
||||
"content": [{"type":"text","text":"Hello from Claude"}],
|
||||
"usage": {"input_tokens": 10, "output_tokens": 5}
|
||||
}`)
|
||||
resp, err := c.parseMCPResponseFull(body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hello from Claude" {
|
||||
t.Errorf("content mismatch: %q", resp.Content)
|
||||
}
|
||||
if len(resp.ToolCalls) != 0 {
|
||||
t.Errorf("expected no tool calls: %v", resp.ToolCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeClient_ParseResponse_ToolUse(t *testing.T) {
|
||||
c := newTestClaudeClient()
|
||||
body := []byte(`{
|
||||
"content": [{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_01abc",
|
||||
"name": "api_request",
|
||||
"input": {"method":"POST","path":"/api/strategies","body":{"name":"BTC策略"}}
|
||||
}],
|
||||
"usage": {"input_tokens": 100, "output_tokens": 30}
|
||||
}`)
|
||||
resp, err := c.parseMCPResponseFull(body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls))
|
||||
}
|
||||
tc := resp.ToolCalls[0]
|
||||
if tc.ID != "toolu_01abc" {
|
||||
t.Errorf("tool call ID wrong: %v", tc.ID)
|
||||
}
|
||||
if tc.Function.Name != "api_request" {
|
||||
t.Errorf("function name wrong: %v", tc.Function.Name)
|
||||
}
|
||||
// Arguments must be a valid JSON string.
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
|
||||
t.Errorf("arguments not valid JSON: %q — %v", tc.Function.Arguments, err)
|
||||
}
|
||||
if args["method"] != "POST" {
|
||||
t.Errorf("args.method wrong: %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeClient_ParseResponse_APIError(t *testing.T) {
|
||||
c := newTestClaudeClient()
|
||||
body := []byte(`{"error":{"type":"authentication_error","message":"invalid x-api-key"}}`)
|
||||
_, err := c.parseMCPResponseFull(body)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for API error response")
|
||||
}
|
||||
if err.Error() == "" {
|
||||
t.Error("error message should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// ── Auth header ────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestClaudeClient_SetAuthHeader(t *testing.T) {
|
||||
c := newTestClaudeClient()
|
||||
c.APIKey = "sk-ant-test123"
|
||||
|
||||
// net/http.Header canonicalizes keys (x-api-key → X-Api-Key).
|
||||
h := make(http.Header)
|
||||
c.setAuthHeader(h)
|
||||
|
||||
if got := h.Get("x-api-key"); got != "sk-ant-test123" {
|
||||
t.Errorf("x-api-key header not set correctly: %q", got)
|
||||
}
|
||||
if h.Get("anthropic-version") == "" {
|
||||
t.Error("anthropic-version header must be set")
|
||||
}
|
||||
// Must NOT use Authorization: Bearer (that's OpenAI format).
|
||||
if h.Get("Authorization") != "" {
|
||||
t.Error("Claude must use x-api-key, not Authorization header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeClient_BuildUrl(t *testing.T) {
|
||||
c := newTestClaudeClient()
|
||||
url := c.buildUrl()
|
||||
if url != DefaultClaudeBaseURL+"/messages" {
|
||||
t.Errorf("URL should be /messages endpoint, got: %s", url)
|
||||
}
|
||||
}
|
||||
|
||||
// ── helpers ────────────────────────────────────────────────────────────────────
|
||||
|
||||
func newTestClaudeClient() *ClaudeClient {
|
||||
return NewClaudeClientWithOptions().(*ClaudeClient)
|
||||
}
|
||||
+96
-109
@@ -60,14 +60,14 @@ type Client struct {
|
||||
UseFullURL bool // Whether to use full URL (without appending /chat/completions)
|
||||
MaxTokens int // Maximum tokens for AI response
|
||||
|
||||
httpClient *http.Client
|
||||
logger Logger // Logger (replaceable)
|
||||
config *Config // Config object (stores all configurations)
|
||||
HTTPClient *http.Client // Exported for sub-packages
|
||||
Log Logger // Exported for sub-packages
|
||||
Cfg *Config // Exported for sub-packages
|
||||
|
||||
// hooks are used to implement dynamic dispatch (polymorphism)
|
||||
// When DeepSeekClient embeds Client, hooks point to DeepSeekClient
|
||||
// This way methods called in call() are automatically dispatched to the overridden version in subclass
|
||||
hooks clientHooks
|
||||
// Hooks are used to implement dynamic dispatch (polymorphism)
|
||||
// When provider.DeepSeekClient embeds Client, Hooks point to DeepSeekClient
|
||||
// This way methods called in Call() are automatically dispatched to the overridden version
|
||||
Hooks ClientHooks
|
||||
}
|
||||
|
||||
// New creates default client (backward compatible)
|
||||
@@ -80,21 +80,22 @@ func New() AIClient {
|
||||
// NewClient creates client (supports options pattern)
|
||||
//
|
||||
// Usage examples:
|
||||
// // Basic usage (backward compatible)
|
||||
// client := mcp.NewClient()
|
||||
//
|
||||
// // Custom logger
|
||||
// client := mcp.NewClient(mcp.WithLogger(customLogger))
|
||||
// // Basic usage (backward compatible)
|
||||
// client := mcp.NewClient()
|
||||
//
|
||||
// // Custom timeout
|
||||
// client := mcp.NewClient(mcp.WithTimeout(60*time.Second))
|
||||
// // Custom logger
|
||||
// client := mcp.NewClient(mcp.WithLogger(customLogger))
|
||||
//
|
||||
// // Combine multiple options
|
||||
// client := mcp.NewClient(
|
||||
// mcp.WithDeepSeekConfig("sk-xxx"),
|
||||
// mcp.WithLogger(customLogger),
|
||||
// mcp.WithTimeout(60*time.Second),
|
||||
// )
|
||||
// // Custom timeout
|
||||
// client := mcp.NewClient(mcp.WithTimeout(60*time.Second))
|
||||
//
|
||||
// // Combine multiple options
|
||||
// client := mcp.NewClient(
|
||||
// mcp.WithDeepSeekConfig("sk-xxx"),
|
||||
// mcp.WithLogger(customLogger),
|
||||
// mcp.WithTimeout(60*time.Second),
|
||||
// )
|
||||
func NewClient(opts ...ClientOption) AIClient {
|
||||
// 1. Create default config
|
||||
cfg := DefaultConfig()
|
||||
@@ -112,9 +113,9 @@ func NewClient(opts ...ClientOption) AIClient {
|
||||
Model: cfg.Model,
|
||||
MaxTokens: cfg.MaxTokens,
|
||||
UseFullURL: cfg.UseFullURL,
|
||||
httpClient: cfg.HTTPClient,
|
||||
logger: cfg.Logger,
|
||||
config: cfg,
|
||||
HTTPClient: cfg.HTTPClient,
|
||||
Log: cfg.Logger,
|
||||
Cfg: cfg,
|
||||
}
|
||||
|
||||
// 4. Set default Provider (if not set)
|
||||
@@ -125,7 +126,7 @@ func NewClient(opts ...ClientOption) AIClient {
|
||||
}
|
||||
|
||||
// 5. Set hooks to point to self
|
||||
client.hooks = client
|
||||
client.Hooks = client
|
||||
|
||||
return client
|
||||
}
|
||||
@@ -148,7 +149,7 @@ func (client *Client) SetAPIKey(apiKey, apiURL, customModel string) {
|
||||
}
|
||||
|
||||
func (client *Client) SetTimeout(timeout time.Duration) {
|
||||
client.httpClient.Timeout = timeout
|
||||
client.HTTPClient.Timeout = timeout
|
||||
}
|
||||
|
||||
// CallWithMessages template method - fixed retry flow (cannot be overridden)
|
||||
@@ -159,32 +160,32 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string,
|
||||
|
||||
// Fixed retry flow
|
||||
var lastErr error
|
||||
maxRetries := client.config.MaxRetries
|
||||
maxRetries := client.Cfg.MaxRetries
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
if attempt > 1 {
|
||||
client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
|
||||
client.Log.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
|
||||
}
|
||||
|
||||
// Call the fixed single-call flow
|
||||
result, err := client.hooks.call(systemPrompt, userPrompt)
|
||||
result, err := client.Hooks.Call(systemPrompt, userPrompt)
|
||||
if err == nil {
|
||||
if attempt > 1 {
|
||||
client.logger.Infof("✓ AI API retry succeeded")
|
||||
client.Log.Infof("✓ AI API retry succeeded")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
// Check if error is retryable via hooks (supports custom retry strategy in subclass)
|
||||
if !client.hooks.isRetryableError(err) {
|
||||
// Check if error is retryable via hooks (supports custom retry strategy)
|
||||
if !client.Hooks.IsRetryableError(err) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Wait before retry
|
||||
if attempt < maxRetries {
|
||||
waitTime := client.config.RetryWaitBase * time.Duration(attempt)
|
||||
client.logger.Infof("⏳ Waiting %v before retry...", waitTime)
|
||||
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
|
||||
client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
}
|
||||
@@ -192,11 +193,11 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string,
|
||||
return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func (client *Client) setAuthHeader(reqHeader http.Header) {
|
||||
func (client *Client) SetAuthHeader(reqHeader http.Header) {
|
||||
reqHeader.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey))
|
||||
}
|
||||
|
||||
func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
|
||||
func (client *Client) BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
|
||||
// Build messages array
|
||||
messages := []map[string]string{}
|
||||
|
||||
@@ -217,7 +218,7 @@ func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[s
|
||||
requestBody := map[string]interface{}{
|
||||
"model": client.Model,
|
||||
"messages": messages,
|
||||
"temperature": client.config.Temperature, // Use configured temperature
|
||||
"temperature": client.Cfg.Temperature, // Use configured temperature
|
||||
}
|
||||
// OpenAI newer models use max_completion_tokens instead of max_tokens
|
||||
if client.Provider == ProviderOpenAI {
|
||||
@@ -228,8 +229,8 @@ func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[s
|
||||
return requestBody
|
||||
}
|
||||
|
||||
// can be used to marshal the request body and can be overridden
|
||||
func (client *Client) marshalRequestBody(requestBody map[string]any) ([]byte, error) {
|
||||
// MarshalRequestBody can be used to marshal the request body and can be overridden
|
||||
func (client *Client) MarshalRequestBody(requestBody map[string]any) ([]byte, error) {
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to serialize request: %w", err)
|
||||
@@ -237,17 +238,17 @@ func (client *Client) marshalRequestBody(requestBody map[string]any) ([]byte, er
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
func (client *Client) parseMCPResponse(body []byte) (string, error) {
|
||||
r, err := client.parseMCPResponseFull(body)
|
||||
func (client *Client) ParseMCPResponse(body []byte) (string, error) {
|
||||
r, err := client.ParseMCPResponseFull(body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return r.Content, nil
|
||||
}
|
||||
|
||||
// parseMCPResponseFull parses the OpenAI-format response body and returns both
|
||||
// ParseMCPResponseFull parses the OpenAI-format response body and returns both
|
||||
// the text content and any tool calls.
|
||||
func (client *Client) parseMCPResponseFull(body []byte) (*LLMResponse, error) {
|
||||
func (client *Client) ParseMCPResponseFull(body []byte) (*LLMResponse, error) {
|
||||
var result struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
@@ -288,14 +289,14 @@ func (client *Client) parseMCPResponseFull(body []byte) (*LLMResponse, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (client *Client) buildUrl() string {
|
||||
func (client *Client) BuildUrl() string {
|
||||
if client.UseFullURL {
|
||||
return client.BaseURL
|
||||
}
|
||||
return fmt.Sprintf("%s/chat/completions", client.BaseURL)
|
||||
}
|
||||
|
||||
func (client *Client) buildRequest(url string, jsonData []byte) (*http.Request, error) {
|
||||
func (client *Client) BuildRequest(url string, jsonData []byte) (*http.Request, error) {
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
@@ -304,42 +305,42 @@ func (client *Client) buildRequest(url string, jsonData []byte) (*http.Request,
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Set auth header via hooks (supports overriding in subclass)
|
||||
client.hooks.setAuthHeader(req.Header)
|
||||
// Set auth header via hooks (supports overriding)
|
||||
client.Hooks.SetAuthHeader(req.Header)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// call single AI API call (fixed flow, cannot be overridden)
|
||||
func (client *Client) call(systemPrompt, userPrompt string) (string, error) {
|
||||
// Call single AI API call (fixed flow, cannot be overridden)
|
||||
func (client *Client) Call(systemPrompt, userPrompt string) (string, error) {
|
||||
// Print current AI configuration
|
||||
client.logger.Infof("📡 [%s] Request AI Server: BaseURL: %s", client.String(), client.BaseURL)
|
||||
client.logger.Debugf("[%s] UseFullURL: %v", client.String(), client.UseFullURL)
|
||||
client.Log.Infof("📡 [%s] Request AI Server: BaseURL: %s", client.String(), client.BaseURL)
|
||||
client.Log.Debugf("[%s] UseFullURL: %v", client.String(), client.UseFullURL)
|
||||
if len(client.APIKey) > 8 {
|
||||
client.logger.Debugf("[%s] API Key: %s...%s", client.String(), client.APIKey[:4], client.APIKey[len(client.APIKey)-4:])
|
||||
client.Log.Debugf("[%s] API Key: %s...%s", client.String(), client.APIKey[:4], client.APIKey[len(client.APIKey)-4:])
|
||||
}
|
||||
|
||||
// Step 1: Build request body (via hooks for dynamic dispatch)
|
||||
requestBody := client.hooks.buildMCPRequestBody(systemPrompt, userPrompt)
|
||||
requestBody := client.Hooks.BuildMCPRequestBody(systemPrompt, userPrompt)
|
||||
|
||||
// Step 2: Serialize request body (via hooks for dynamic dispatch)
|
||||
jsonData, err := client.hooks.marshalRequestBody(requestBody)
|
||||
jsonData, err := client.Hooks.MarshalRequestBody(requestBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Step 3: Build URL (via hooks for dynamic dispatch)
|
||||
url := client.hooks.buildUrl()
|
||||
client.logger.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
|
||||
url := client.Hooks.BuildUrl()
|
||||
client.Log.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
|
||||
|
||||
// Step 4: Create HTTP request (fixed logic)
|
||||
req, err := client.hooks.buildRequest(url, jsonData)
|
||||
req, err := client.Hooks.BuildRequest(url, jsonData)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Step 5: Send HTTP request (fixed logic)
|
||||
resp, err := client.httpClient.Do(req)
|
||||
resp, err := client.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
@@ -357,7 +358,7 @@ func (client *Client) call(systemPrompt, userPrompt string) (string, error) {
|
||||
}
|
||||
|
||||
// Step 8: Parse response (via hooks for dynamic dispatch)
|
||||
result, err := client.hooks.parseMCPResponse(body)
|
||||
result, err := client.Hooks.ParseMCPResponse(body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fail to parse AI server response: %w", err)
|
||||
}
|
||||
@@ -370,11 +371,11 @@ func (client *Client) String() string {
|
||||
client.Provider, client.Model)
|
||||
}
|
||||
|
||||
// isRetryableError determines if error is retryable (network errors, timeouts, etc.)
|
||||
func (client *Client) isRetryableError(err error) bool {
|
||||
// IsRetryableError determines if error is retryable (network errors, timeouts, etc.)
|
||||
func (client *Client) IsRetryableError(err error) bool {
|
||||
errStr := err.Error()
|
||||
// Network errors, timeouts, EOF, etc. can be retried
|
||||
for _, retryable := range client.config.RetryableErrors {
|
||||
for _, retryable := range client.Cfg.RetryableErrors {
|
||||
if strings.Contains(errStr, retryable) {
|
||||
return true
|
||||
}
|
||||
@@ -387,20 +388,6 @@ func (client *Client) isRetryableError(err error) bool {
|
||||
// ============================================================
|
||||
|
||||
// CallWithRequest calls AI API using Request object (supports advanced features)
|
||||
//
|
||||
// This method supports:
|
||||
// - Multi-turn conversation history
|
||||
// - Fine-grained parameter control (temperature, top_p, penalties, etc.)
|
||||
// - Function Calling / Tools
|
||||
// - Streaming response (future support)
|
||||
//
|
||||
// Usage example:
|
||||
// request := NewRequestBuilder().
|
||||
// WithSystemPrompt("You are helpful").
|
||||
// WithUserPrompt("Hello").
|
||||
// WithTemperature(0.8).
|
||||
// Build()
|
||||
// result, err := client.CallWithRequest(request)
|
||||
func (client *Client) CallWithRequest(req *Request) (string, error) {
|
||||
if client.APIKey == "" {
|
||||
return "", fmt.Errorf("AI API key not set, please call SetAPIKey first")
|
||||
@@ -413,32 +400,32 @@ func (client *Client) CallWithRequest(req *Request) (string, error) {
|
||||
|
||||
// Fixed retry flow
|
||||
var lastErr error
|
||||
maxRetries := client.config.MaxRetries
|
||||
maxRetries := client.Cfg.MaxRetries
|
||||
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
if attempt > 1 {
|
||||
client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
|
||||
client.Log.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
|
||||
}
|
||||
|
||||
// Call single request
|
||||
result, err := client.callWithRequest(req)
|
||||
if err == nil {
|
||||
if attempt > 1 {
|
||||
client.logger.Infof("✓ AI API retry succeeded")
|
||||
client.Log.Infof("✓ AI API retry succeeded")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
// Check if error is retryable
|
||||
if !client.hooks.isRetryableError(err) {
|
||||
if !client.Hooks.IsRetryableError(err) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Wait before retry
|
||||
if attempt < maxRetries {
|
||||
waitTime := client.config.RetryWaitBase * time.Duration(attempt)
|
||||
client.logger.Infof("⏳ Waiting %v before retry...", waitTime)
|
||||
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
|
||||
client.Log.Infof("⏳ Waiting %v before retry...", waitTime)
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
}
|
||||
@@ -456,21 +443,21 @@ func (client *Client) CallWithRequestFull(req *Request) (*LLMResponse, error) {
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
maxRetries := client.config.MaxRetries
|
||||
maxRetries := client.Cfg.MaxRetries
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
if attempt > 1 {
|
||||
client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
|
||||
client.Log.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
|
||||
}
|
||||
result, err := client.callWithRequestFull(req)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
lastErr = err
|
||||
if !client.hooks.isRetryableError(err) {
|
||||
if !client.Hooks.IsRetryableError(err) {
|
||||
return nil, err
|
||||
}
|
||||
if attempt < maxRetries {
|
||||
waitTime := client.config.RetryWaitBase * time.Duration(attempt)
|
||||
waitTime := client.Cfg.RetryWaitBase * time.Duration(attempt)
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
}
|
||||
@@ -479,21 +466,21 @@ func (client *Client) CallWithRequestFull(req *Request) (*LLMResponse, error) {
|
||||
|
||||
// callWithRequestFull single call that returns LLMResponse (content + tool calls).
|
||||
func (client *Client) callWithRequestFull(req *Request) (*LLMResponse, error) {
|
||||
client.logger.Infof("📡 [%s] Request AI Server (full): BaseURL: %s", client.String(), client.BaseURL)
|
||||
client.Log.Infof("📡 [%s] Request AI Server (full): BaseURL: %s", client.String(), client.BaseURL)
|
||||
|
||||
requestBody := client.hooks.buildRequestBodyFromRequest(req)
|
||||
jsonData, err := client.hooks.marshalRequestBody(requestBody)
|
||||
requestBody := client.Hooks.BuildRequestBodyFromRequest(req)
|
||||
jsonData, err := client.Hooks.MarshalRequestBody(requestBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := client.hooks.buildUrl()
|
||||
httpReq, err := client.hooks.buildRequest(url, jsonData)
|
||||
url := client.Hooks.BuildUrl()
|
||||
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.httpClient.Do(httpReq)
|
||||
resp, err := client.HTTPClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
@@ -507,31 +494,31 @@ func (client *Client) callWithRequestFull(req *Request) (*LLMResponse, error) {
|
||||
return nil, fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return client.hooks.parseMCPResponseFull(body)
|
||||
return client.Hooks.ParseMCPResponseFull(body)
|
||||
}
|
||||
|
||||
// callWithRequest single AI API call (using Request object)
|
||||
func (client *Client) callWithRequest(req *Request) (string, error) {
|
||||
// Print current AI configuration
|
||||
client.logger.Infof("📡 [%s] Request AI Server with Builder: BaseURL: %s", client.String(), client.BaseURL)
|
||||
client.logger.Debugf("[%s] Messages count: %d", client.String(), len(req.Messages))
|
||||
client.Log.Infof("📡 [%s] Request AI Server with Builder: BaseURL: %s", client.String(), client.BaseURL)
|
||||
client.Log.Debugf("[%s] Messages count: %d", client.String(), len(req.Messages))
|
||||
|
||||
requestBody := client.hooks.buildRequestBodyFromRequest(req)
|
||||
requestBody := client.Hooks.BuildRequestBodyFromRequest(req)
|
||||
|
||||
jsonData, err := client.hooks.marshalRequestBody(requestBody)
|
||||
jsonData, err := client.Hooks.MarshalRequestBody(requestBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
url := client.hooks.buildUrl()
|
||||
client.logger.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
|
||||
url := client.Hooks.BuildUrl()
|
||||
client.Log.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
|
||||
|
||||
httpReq, err := client.hooks.buildRequest(url, jsonData)
|
||||
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.httpClient.Do(httpReq)
|
||||
resp, err := client.HTTPClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
@@ -546,7 +533,7 @@ func (client *Client) callWithRequest(req *Request) (string, error) {
|
||||
return "", fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
result, err := client.hooks.parseMCPResponse(body)
|
||||
result, err := client.Hooks.ParseMCPResponse(body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fail to parse AI server response: %w", err)
|
||||
}
|
||||
@@ -554,8 +541,8 @@ func (client *Client) callWithRequest(req *Request) (string, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// buildRequestBodyFromRequest builds request body from Request object
|
||||
func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any {
|
||||
// BuildRequestBodyFromRequest builds request body from Request object
|
||||
func (client *Client) BuildRequestBodyFromRequest(req *Request) map[string]any {
|
||||
// Convert Message to API format — must use map[string]any to support
|
||||
// tool-call messages (tool_calls, tool_call_id fields).
|
||||
messages := make([]map[string]any, 0, len(req.Messages))
|
||||
@@ -586,7 +573,7 @@ func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any {
|
||||
requestBody["temperature"] = *req.Temperature
|
||||
} else {
|
||||
// If not set in Request, use Client's configuration
|
||||
requestBody["temperature"] = client.config.Temperature
|
||||
requestBody["temperature"] = client.Cfg.Temperature
|
||||
}
|
||||
|
||||
// OpenAI newer models use max_completion_tokens instead of max_tokens
|
||||
@@ -647,19 +634,19 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string))
|
||||
}
|
||||
req.Stream = true
|
||||
|
||||
requestBody := client.hooks.buildRequestBodyFromRequest(req)
|
||||
jsonData, err := client.hooks.marshalRequestBody(requestBody)
|
||||
requestBody := client.Hooks.BuildRequestBodyFromRequest(req)
|
||||
jsonData, err := client.Hooks.MarshalRequestBody(requestBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
url := client.hooks.buildUrl()
|
||||
httpReq, err := client.hooks.buildRequest(url, jsonData)
|
||||
url := client.Hooks.BuildUrl()
|
||||
httpReq, err := client.Hooks.BuildRequest(url, jsonData)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Idle-timeout watchdog: cancel the request if no SSE line arrives for 30 seconds.
|
||||
// Idle-timeout watchdog: cancel the request if no SSE line arrives for 60 seconds.
|
||||
// This breaks the scanner out of an indefinitely blocking Read on a hung connection.
|
||||
const idleTimeout = 60 * time.Second
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -689,7 +676,7 @@ func (client *Client) CallWithRequestStream(req *Request, onChunk func(string))
|
||||
}()
|
||||
|
||||
httpReq = httpReq.WithContext(ctx)
|
||||
resp, err := client.httpClient.Do(httpReq)
|
||||
resp, err := client.HTTPClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("streaming request failed: %w", err)
|
||||
}
|
||||
|
||||
+17
-17
@@ -27,16 +27,16 @@ func TestNewClient_Default(t *testing.T) {
|
||||
t.Error("MaxTokens should be positive")
|
||||
}
|
||||
|
||||
if c.logger == nil {
|
||||
t.Error("logger should not be nil")
|
||||
if c.Log == nil {
|
||||
t.Error("Log should not be nil")
|
||||
}
|
||||
|
||||
if c.httpClient == nil {
|
||||
t.Error("httpClient should not be nil")
|
||||
if c.HTTPClient == nil {
|
||||
t.Error("HTTPClient should not be nil")
|
||||
}
|
||||
|
||||
if c.hooks == nil {
|
||||
t.Error("hooks should not be nil")
|
||||
if c.Hooks == nil {
|
||||
t.Error("Hooks should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,12 +54,12 @@ func TestNewClient_WithOptions(t *testing.T) {
|
||||
|
||||
c := client.(*Client)
|
||||
|
||||
if c.logger != mockLogger {
|
||||
t.Error("logger should be set from option")
|
||||
if c.Log != mockLogger {
|
||||
t.Error("Log should be set from option")
|
||||
}
|
||||
|
||||
if c.httpClient != mockHTTP {
|
||||
t.Error("httpClient should be set from option")
|
||||
if c.HTTPClient != mockHTTP {
|
||||
t.Error("HTTPClient should be set from option")
|
||||
}
|
||||
|
||||
if c.MaxTokens != 4000 {
|
||||
@@ -174,7 +174,7 @@ func TestClient_Retry_Success(t *testing.T) {
|
||||
WithMaxRetries(3),
|
||||
)
|
||||
|
||||
// Since our client uses hooks.call, need special handling
|
||||
// Since our client uses Hooks.Call, need special handling
|
||||
// Here we test that CallWithMessages will invoke retry logic
|
||||
c := client.(*Client)
|
||||
|
||||
@@ -242,7 +242,7 @@ func TestClient_BuildMCPRequestBody(t *testing.T) {
|
||||
client := NewClient()
|
||||
c := client.(*Client)
|
||||
|
||||
body := c.buildMCPRequestBody("system prompt", "user prompt")
|
||||
body := c.BuildMCPRequestBody("system prompt", "user prompt")
|
||||
|
||||
if body == nil {
|
||||
t.Fatal("body should not be nil")
|
||||
@@ -300,7 +300,7 @@ func TestClient_BuildUrl(t *testing.T) {
|
||||
)
|
||||
c := client.(*Client)
|
||||
|
||||
url := c.buildUrl()
|
||||
url := c.BuildUrl()
|
||||
if url != tt.expected {
|
||||
t.Errorf("expected '%s', got '%s'", tt.expected, url)
|
||||
}
|
||||
@@ -313,7 +313,7 @@ func TestClient_SetAuthHeader(t *testing.T) {
|
||||
c := client.(*Client)
|
||||
|
||||
headers := make(http.Header)
|
||||
c.setAuthHeader(headers)
|
||||
c.SetAuthHeader(headers)
|
||||
|
||||
authHeader := headers.Get("Authorization")
|
||||
if authHeader != "Bearer test-api-key" {
|
||||
@@ -359,7 +359,7 @@ func TestClient_IsRetryableError(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := c.isRetryableError(tt.err)
|
||||
result := c.IsRetryableError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
@@ -378,8 +378,8 @@ func TestClient_SetTimeout(t *testing.T) {
|
||||
client.SetTimeout(newTimeout)
|
||||
|
||||
c := client.(*Client)
|
||||
if c.httpClient.Timeout != newTimeout {
|
||||
t.Errorf("expected timeout %v, got %v", newTimeout, c.httpClient.Timeout)
|
||||
if c.HTTPClient.Timeout != newTimeout {
|
||||
t.Errorf("expected timeout %v, got %v", newTimeout, c.HTTPClient.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+10
-10
@@ -84,7 +84,7 @@ func TestConfig_Temperature_IsUsed(t *testing.T) {
|
||||
c := client.(*Client)
|
||||
|
||||
// Build request body
|
||||
requestBody := c.buildMCPRequestBody("system", "user")
|
||||
requestBody := c.BuildMCPRequestBody("system", "user")
|
||||
|
||||
// Verify temperature field
|
||||
temp, ok := requestBody["temperature"].(float64)
|
||||
@@ -201,7 +201,7 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) {
|
||||
c := client.(*Client)
|
||||
|
||||
// Modify config's RetryableErrors (no WithRetryableErrors option yet)
|
||||
c.config.RetryableErrors = customRetryableErrors
|
||||
c.Cfg.RetryableErrors = customRetryableErrors
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -227,7 +227,7 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := c.isRetryableError(tt.err)
|
||||
result := c.IsRetryableError(tt.err)
|
||||
if result != tt.retryable {
|
||||
t.Errorf("expected isRetryableError(%v) = %v, got %v", tt.err, tt.retryable, result)
|
||||
}
|
||||
@@ -244,19 +244,19 @@ func TestConfig_DefaultValues(t *testing.T) {
|
||||
c := client.(*Client)
|
||||
|
||||
// Verify default values
|
||||
if c.config.MaxRetries != 3 {
|
||||
t.Errorf("default MaxRetries should be 3, got %d", c.config.MaxRetries)
|
||||
if c.Cfg.MaxRetries != 3 {
|
||||
t.Errorf("default MaxRetries should be 3, got %d", c.Cfg.MaxRetries)
|
||||
}
|
||||
|
||||
if c.config.Temperature != 0.5 {
|
||||
t.Errorf("default Temperature should be 0.5, got %f", c.config.Temperature)
|
||||
if c.Cfg.Temperature != 0.5 {
|
||||
t.Errorf("default Temperature should be 0.5, got %f", c.Cfg.Temperature)
|
||||
}
|
||||
|
||||
if c.config.RetryWaitBase != 2*time.Second {
|
||||
t.Errorf("default RetryWaitBase should be 2s, got %v", c.config.RetryWaitBase)
|
||||
if c.Cfg.RetryWaitBase != 2*time.Second {
|
||||
t.Errorf("default RetryWaitBase should be 2s, got %v", c.Cfg.RetryWaitBase)
|
||||
}
|
||||
|
||||
if len(c.config.RetryableErrors) == 0 {
|
||||
if len(c.Cfg.RetryableErrors) == 0 {
|
||||
t.Error("default RetryableErrors should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderDeepSeek = "deepseek"
|
||||
DefaultDeepSeekBaseURL = "https://api.deepseek.com"
|
||||
DefaultDeepSeekModel = "deepseek-chat"
|
||||
)
|
||||
|
||||
type DeepSeekClient struct {
|
||||
*Client
|
||||
}
|
||||
|
||||
// NewDeepSeekClient creates DeepSeek client (backward compatible)
|
||||
//
|
||||
// Deprecated: Recommend using NewDeepSeekClientWithOptions for better flexibility
|
||||
func NewDeepSeekClient() AIClient {
|
||||
return NewDeepSeekClientWithOptions()
|
||||
}
|
||||
|
||||
// NewDeepSeekClientWithOptions creates DeepSeek client (supports options pattern)
|
||||
//
|
||||
// Usage examples:
|
||||
// // Basic usage
|
||||
// client := mcp.NewDeepSeekClientWithOptions()
|
||||
//
|
||||
// // Custom configuration
|
||||
// client := mcp.NewDeepSeekClientWithOptions(
|
||||
// mcp.WithAPIKey("sk-xxx"),
|
||||
// mcp.WithLogger(customLogger),
|
||||
// mcp.WithTimeout(60*time.Second),
|
||||
// )
|
||||
func NewDeepSeekClientWithOptions(opts ...ClientOption) AIClient {
|
||||
// 1. Create DeepSeek preset options
|
||||
deepseekOpts := []ClientOption{
|
||||
WithProvider(ProviderDeepSeek),
|
||||
WithModel(DefaultDeepSeekModel),
|
||||
WithBaseURL(DefaultDeepSeekBaseURL),
|
||||
}
|
||||
|
||||
// 2. Merge user options (user options have higher priority)
|
||||
allOpts := append(deepseekOpts, opts...)
|
||||
|
||||
// 3. Create base client
|
||||
baseClient := NewClient(allOpts...).(*Client)
|
||||
|
||||
// 4. Create DeepSeek client
|
||||
dsClient := &DeepSeekClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
// 5. Set hooks to point to DeepSeekClient (implement dynamic dispatch)
|
||||
baseClient.hooks = dsClient
|
||||
|
||||
return dsClient
|
||||
}
|
||||
|
||||
func (dsClient *DeepSeekClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
dsClient.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
dsClient.logger.Infof("🔧 [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
dsClient.BaseURL = customURL
|
||||
dsClient.logger.Infof("🔧 [MCP] DeepSeek using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
dsClient.logger.Infof("🔧 [MCP] DeepSeek using default BaseURL: %s", dsClient.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
dsClient.Model = customModel
|
||||
dsClient.logger.Infof("🔧 [MCP] DeepSeek using custom Model: %s", customModel)
|
||||
} else {
|
||||
dsClient.logger.Infof("🔧 [MCP] DeepSeek using default Model: %s", dsClient.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func (dsClient *DeepSeekClient) setAuthHeader(reqHeaders http.Header) {
|
||||
dsClient.Client.setAuthHeader(reqHeaders)
|
||||
}
|
||||
@@ -1,272 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// Test DeepSeekClient Creation and Configuration
|
||||
// ============================================================
|
||||
|
||||
func TestNewDeepSeekClient_Default(t *testing.T) {
|
||||
client := NewDeepSeekClient()
|
||||
|
||||
if client == nil {
|
||||
t.Fatal("client should not be nil")
|
||||
}
|
||||
|
||||
// Type assertion check
|
||||
dsClient, ok := client.(*DeepSeekClient)
|
||||
if !ok {
|
||||
t.Fatal("client should be *DeepSeekClient")
|
||||
}
|
||||
|
||||
// Verify default values
|
||||
if dsClient.Provider != ProviderDeepSeek {
|
||||
t.Errorf("Provider should be '%s', got '%s'", ProviderDeepSeek, dsClient.Provider)
|
||||
}
|
||||
|
||||
if dsClient.BaseURL != DefaultDeepSeekBaseURL {
|
||||
t.Errorf("BaseURL should be '%s', got '%s'", DefaultDeepSeekBaseURL, dsClient.BaseURL)
|
||||
}
|
||||
|
||||
if dsClient.Model != DefaultDeepSeekModel {
|
||||
t.Errorf("Model should be '%s', got '%s'", DefaultDeepSeekModel, dsClient.Model)
|
||||
}
|
||||
|
||||
if dsClient.logger == nil {
|
||||
t.Error("logger should not be nil")
|
||||
}
|
||||
|
||||
if dsClient.httpClient == nil {
|
||||
t.Error("httpClient should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDeepSeekClientWithOptions(t *testing.T) {
|
||||
mockLogger := NewMockLogger()
|
||||
customModel := "deepseek-v2"
|
||||
customAPIKey := "sk-custom-key"
|
||||
|
||||
client := NewDeepSeekClientWithOptions(
|
||||
WithLogger(mockLogger),
|
||||
WithModel(customModel),
|
||||
WithAPIKey(customAPIKey),
|
||||
WithMaxTokens(4000),
|
||||
)
|
||||
|
||||
dsClient := client.(*DeepSeekClient)
|
||||
|
||||
// Verify custom options are applied
|
||||
if dsClient.logger != mockLogger {
|
||||
t.Error("logger should be set from option")
|
||||
}
|
||||
|
||||
if dsClient.Model != customModel {
|
||||
t.Error("Model should be set from option")
|
||||
}
|
||||
|
||||
if dsClient.APIKey != customAPIKey {
|
||||
t.Error("APIKey should be set from option")
|
||||
}
|
||||
|
||||
if dsClient.MaxTokens != 4000 {
|
||||
t.Error("MaxTokens should be 4000")
|
||||
}
|
||||
|
||||
// Verify DeepSeek default values are retained
|
||||
if dsClient.Provider != ProviderDeepSeek {
|
||||
t.Errorf("Provider should still be '%s'", ProviderDeepSeek)
|
||||
}
|
||||
|
||||
if dsClient.BaseURL != DefaultDeepSeekBaseURL {
|
||||
t.Errorf("BaseURL should still be '%s'", DefaultDeepSeekBaseURL)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Test SetAPIKey
|
||||
// ============================================================
|
||||
|
||||
func TestDeepSeekClient_SetAPIKey(t *testing.T) {
|
||||
mockLogger := NewMockLogger()
|
||||
client := NewDeepSeekClientWithOptions(
|
||||
WithLogger(mockLogger),
|
||||
)
|
||||
|
||||
dsClient := client.(*DeepSeekClient)
|
||||
|
||||
// Test setting API Key (default URL and Model)
|
||||
dsClient.SetAPIKey("sk-test-key-12345678", "", "")
|
||||
|
||||
if dsClient.APIKey != "sk-test-key-12345678" {
|
||||
t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", dsClient.APIKey)
|
||||
}
|
||||
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
if len(logs) == 0 {
|
||||
t.Error("should have logged API key setting")
|
||||
}
|
||||
|
||||
// Verify BaseURL and Model remain default
|
||||
if dsClient.BaseURL != DefaultDeepSeekBaseURL {
|
||||
t.Error("BaseURL should remain default")
|
||||
}
|
||||
|
||||
if dsClient.Model != DefaultDeepSeekModel {
|
||||
t.Error("Model should remain default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepSeekClient_SetAPIKey_WithCustomURL(t *testing.T) {
|
||||
mockLogger := NewMockLogger()
|
||||
client := NewDeepSeekClientWithOptions(
|
||||
WithLogger(mockLogger),
|
||||
)
|
||||
|
||||
dsClient := client.(*DeepSeekClient)
|
||||
|
||||
customURL := "https://custom.api.com/v1"
|
||||
dsClient.SetAPIKey("sk-test-key-12345678", customURL, "")
|
||||
|
||||
if dsClient.BaseURL != customURL {
|
||||
t.Errorf("BaseURL should be '%s', got '%s'", customURL, dsClient.BaseURL)
|
||||
}
|
||||
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
hasCustomURLLog := false
|
||||
for _, log := range logs {
|
||||
if log.Format == "🔧 [MCP] DeepSeek using custom BaseURL: %s" {
|
||||
hasCustomURLLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasCustomURLLog {
|
||||
t.Error("should have logged custom BaseURL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepSeekClient_SetAPIKey_WithCustomModel(t *testing.T) {
|
||||
mockLogger := NewMockLogger()
|
||||
client := NewDeepSeekClientWithOptions(
|
||||
WithLogger(mockLogger),
|
||||
)
|
||||
|
||||
dsClient := client.(*DeepSeekClient)
|
||||
|
||||
customModel := "deepseek-v3"
|
||||
dsClient.SetAPIKey("sk-test-key-12345678", "", customModel)
|
||||
|
||||
if dsClient.Model != customModel {
|
||||
t.Errorf("Model should be '%s', got '%s'", customModel, dsClient.Model)
|
||||
}
|
||||
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
hasCustomModelLog := false
|
||||
for _, log := range logs {
|
||||
if log.Format == "🔧 [MCP] DeepSeek using custom Model: %s" {
|
||||
hasCustomModelLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasCustomModelLog {
|
||||
t.Error("should have logged custom Model")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Test Integration Features
|
||||
// ============================================================
|
||||
|
||||
func TestDeepSeekClient_CallWithMessages_Success(t *testing.T) {
|
||||
mockHTTP := NewMockHTTPClient()
|
||||
mockHTTP.SetSuccessResponse("DeepSeek AI response")
|
||||
mockLogger := NewMockLogger()
|
||||
|
||||
client := NewDeepSeekClientWithOptions(
|
||||
WithHTTPClient(mockHTTP.ToHTTPClient()),
|
||||
WithLogger(mockLogger),
|
||||
WithAPIKey("sk-test-key"),
|
||||
)
|
||||
|
||||
result, err := client.CallWithMessages("system prompt", "user prompt")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("should not error: %v", err)
|
||||
}
|
||||
|
||||
if result != "DeepSeek AI response" {
|
||||
t.Errorf("expected 'DeepSeek AI response', got '%s'", result)
|
||||
}
|
||||
|
||||
// Verify request
|
||||
requests := mockHTTP.GetRequests()
|
||||
if len(requests) != 1 {
|
||||
t.Fatalf("expected 1 request, got %d", len(requests))
|
||||
}
|
||||
|
||||
req := requests[0]
|
||||
|
||||
// Verify URL
|
||||
expectedURL := DefaultDeepSeekBaseURL + "/chat/completions"
|
||||
if req.URL.String() != expectedURL {
|
||||
t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String())
|
||||
}
|
||||
|
||||
// Verify Authorization header
|
||||
authHeader := req.Header.Get("Authorization")
|
||||
if authHeader != "Bearer sk-test-key" {
|
||||
t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader)
|
||||
}
|
||||
|
||||
// Verify Content-Type
|
||||
if req.Header.Get("Content-Type") != "application/json" {
|
||||
t.Error("Content-Type should be application/json")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeepSeekClient_Timeout(t *testing.T) {
|
||||
client := NewDeepSeekClientWithOptions(
|
||||
WithTimeout(30 * time.Second),
|
||||
)
|
||||
|
||||
dsClient := client.(*DeepSeekClient)
|
||||
|
||||
if dsClient.httpClient.Timeout != 30*time.Second {
|
||||
t.Errorf("expected timeout 30s, got %v", dsClient.httpClient.Timeout)
|
||||
}
|
||||
|
||||
// Test SetTimeout
|
||||
client.SetTimeout(60 * time.Second)
|
||||
|
||||
if dsClient.httpClient.Timeout != 60*time.Second {
|
||||
t.Errorf("expected timeout 60s after SetTimeout, got %v", dsClient.httpClient.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Test hooks Mechanism
|
||||
// ============================================================
|
||||
|
||||
func TestDeepSeekClient_HooksIntegration(t *testing.T) {
|
||||
client := NewDeepSeekClientWithOptions()
|
||||
dsClient := client.(*DeepSeekClient)
|
||||
|
||||
// Verify hooks point to dsClient itself (implements polymorphism)
|
||||
if dsClient.hooks != dsClient {
|
||||
t.Error("hooks should point to dsClient for polymorphism")
|
||||
}
|
||||
|
||||
// Verify buildUrl uses DeepSeek configuration
|
||||
url := dsClient.buildUrl()
|
||||
expectedURL := DefaultDeepSeekBaseURL + "/chat/completions"
|
||||
if url != expectedURL {
|
||||
t.Errorf("expected URL '%s', got '%s'", expectedURL, url)
|
||||
}
|
||||
}
|
||||
+9
-15
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"nofx/mcp"
|
||||
"nofx/mcp/provider"
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
@@ -24,7 +25,7 @@ func Example_backward_compatible() {
|
||||
|
||||
func Example_deepseek_backward_compatible() {
|
||||
// DeepSeek old code continues to work
|
||||
client := mcp.NewDeepSeekClient()
|
||||
client := provider.NewDeepSeekClient()
|
||||
client.SetAPIKey("sk-xxx", "", "")
|
||||
|
||||
result, _ := client.CallWithMessages("system", "user")
|
||||
@@ -141,12 +142,12 @@ func Example_custom_http_client() {
|
||||
|
||||
func Example_deepseek_new_api() {
|
||||
// Basic usage
|
||||
client := mcp.NewDeepSeekClientWithOptions(
|
||||
client := provider.NewDeepSeekClientWithOptions(
|
||||
mcp.WithAPIKey("sk-xxx"),
|
||||
)
|
||||
|
||||
// Advanced usage
|
||||
client = mcp.NewDeepSeekClientWithOptions(
|
||||
client = provider.NewDeepSeekClientWithOptions(
|
||||
mcp.WithAPIKey("sk-xxx"),
|
||||
mcp.WithLogger(&CustomLogger{}),
|
||||
mcp.WithTimeout(90*time.Second),
|
||||
@@ -163,12 +164,12 @@ func Example_deepseek_new_api() {
|
||||
|
||||
func Example_qwen_new_api() {
|
||||
// Basic usage
|
||||
client := mcp.NewQwenClientWithOptions(
|
||||
client := provider.NewQwenClientWithOptions(
|
||||
mcp.WithAPIKey("sk-xxx"),
|
||||
)
|
||||
|
||||
// Advanced usage
|
||||
client = mcp.NewQwenClientWithOptions(
|
||||
client = provider.NewQwenClientWithOptions(
|
||||
mcp.WithAPIKey("sk-xxx"),
|
||||
mcp.WithLogger(&CustomLogger{}),
|
||||
mcp.WithTimeout(90*time.Second),
|
||||
@@ -185,7 +186,7 @@ func Example_qwen_new_api() {
|
||||
func Example_trader_migration() {
|
||||
// Old code (continues to work)
|
||||
oldStyleClient := func(apiKey, customURL, customModel string) mcp.AIClient {
|
||||
client := mcp.NewDeepSeekClient()
|
||||
client := provider.NewDeepSeekClient()
|
||||
client.SetAPIKey(apiKey, customURL, customModel)
|
||||
return client
|
||||
}
|
||||
@@ -204,7 +205,7 @@ func Example_trader_migration() {
|
||||
opts = append(opts, mcp.WithModel(customModel))
|
||||
}
|
||||
|
||||
return mcp.NewDeepSeekClientWithOptions(opts...)
|
||||
return provider.NewDeepSeekClientWithOptions(opts...)
|
||||
}
|
||||
|
||||
// Both approaches work
|
||||
@@ -230,13 +231,7 @@ func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
|
||||
func Example_testing_with_mock() {
|
||||
// Use Mock during testing
|
||||
// mockHTTP := &MockHTTPClient{
|
||||
// Response: `{"choices":[{"message":{"content":"test response"}}]}`,
|
||||
// }
|
||||
|
||||
client := mcp.NewClient(
|
||||
// mcp.WithHTTPClient(mockHTTP), // Use mockHTTP in actual tests
|
||||
mcp.WithLogger(mcp.NewNoopLogger()), // Disable logging
|
||||
)
|
||||
|
||||
@@ -258,7 +253,6 @@ func Example_environment_specific() {
|
||||
// Production environment: structured logging + timeout protection
|
||||
prodClient := mcp.NewClient(
|
||||
mcp.WithDeepSeekConfig("sk-xxx"),
|
||||
// mcp.WithLogger(&ZapLogger{}), // Production-grade logging
|
||||
mcp.WithTimeout(30*time.Second),
|
||||
mcp.WithMaxRetries(3),
|
||||
)
|
||||
@@ -273,7 +267,7 @@ func Example_environment_specific() {
|
||||
|
||||
func Example_real_world_usage() {
|
||||
// Create client with complete configuration
|
||||
client := mcp.NewDeepSeekClientWithOptions(
|
||||
client := provider.NewDeepSeekClientWithOptions(
|
||||
mcp.WithAPIKey("sk-xxxxxxxxxx"),
|
||||
mcp.WithTimeout(60*time.Second),
|
||||
mcp.WithMaxRetries(5),
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderGemini = "gemini"
|
||||
DefaultGeminiBaseURL = "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
DefaultGeminiModel = "gemini-3-pro-preview"
|
||||
)
|
||||
|
||||
type GeminiClient struct {
|
||||
*Client
|
||||
}
|
||||
|
||||
// NewGeminiClient creates Gemini client (backward compatible)
|
||||
func NewGeminiClient() AIClient {
|
||||
return NewGeminiClientWithOptions()
|
||||
}
|
||||
|
||||
// NewGeminiClientWithOptions creates Gemini client (supports options pattern)
|
||||
func NewGeminiClientWithOptions(opts ...ClientOption) AIClient {
|
||||
// 1. Create Gemini preset options
|
||||
geminiOpts := []ClientOption{
|
||||
WithProvider(ProviderGemini),
|
||||
WithModel(DefaultGeminiModel),
|
||||
WithBaseURL(DefaultGeminiBaseURL),
|
||||
}
|
||||
|
||||
// 2. Merge user options (user options have higher priority)
|
||||
allOpts := append(geminiOpts, opts...)
|
||||
|
||||
// 3. Create base client
|
||||
baseClient := NewClient(allOpts...).(*Client)
|
||||
|
||||
// 4. Create Gemini client
|
||||
geminiClient := &GeminiClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
// 5. Set hooks to point to GeminiClient (implement dynamic dispatch)
|
||||
baseClient.hooks = geminiClient
|
||||
|
||||
return geminiClient
|
||||
}
|
||||
|
||||
func (c *GeminiClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
c.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
c.logger.Infof("🔧 [MCP] Gemini API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
c.BaseURL = customURL
|
||||
c.logger.Infof("🔧 [MCP] Gemini using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
c.logger.Infof("🔧 [MCP] Gemini using default BaseURL: %s", c.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
c.Model = customModel
|
||||
c.logger.Infof("🔧 [MCP] Gemini using custom Model: %s", customModel)
|
||||
} else {
|
||||
c.logger.Infof("🔧 [MCP] Gemini using default Model: %s", c.Model)
|
||||
}
|
||||
}
|
||||
|
||||
// Gemini OpenAI-compatible API uses standard Bearer auth
|
||||
func (c *GeminiClient) setAuthHeader(reqHeaders http.Header) {
|
||||
c.Client.setAuthHeader(reqHeaders)
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderGrok = "grok"
|
||||
DefaultGrokBaseURL = "https://api.x.ai/v1"
|
||||
DefaultGrokModel = "grok-3-latest"
|
||||
)
|
||||
|
||||
type GrokClient struct {
|
||||
*Client
|
||||
}
|
||||
|
||||
// NewGrokClient creates Grok client (backward compatible)
|
||||
func NewGrokClient() AIClient {
|
||||
return NewGrokClientWithOptions()
|
||||
}
|
||||
|
||||
// NewGrokClientWithOptions creates Grok client (supports options pattern)
|
||||
func NewGrokClientWithOptions(opts ...ClientOption) AIClient {
|
||||
// 1. Create Grok preset options
|
||||
grokOpts := []ClientOption{
|
||||
WithProvider(ProviderGrok),
|
||||
WithModel(DefaultGrokModel),
|
||||
WithBaseURL(DefaultGrokBaseURL),
|
||||
}
|
||||
|
||||
// 2. Merge user options (user options have higher priority)
|
||||
allOpts := append(grokOpts, opts...)
|
||||
|
||||
// 3. Create base client
|
||||
baseClient := NewClient(allOpts...).(*Client)
|
||||
|
||||
// 4. Create Grok client
|
||||
grokClient := &GrokClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
// 5. Set hooks to point to GrokClient (implement dynamic dispatch)
|
||||
baseClient.hooks = grokClient
|
||||
|
||||
return grokClient
|
||||
}
|
||||
|
||||
func (c *GrokClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
c.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
c.logger.Infof("🔧 [MCP] Grok API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
c.BaseURL = customURL
|
||||
c.logger.Infof("🔧 [MCP] Grok using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
c.logger.Infof("🔧 [MCP] Grok using default BaseURL: %s", c.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
c.Model = customModel
|
||||
c.logger.Infof("🔧 [MCP] Grok using custom Model: %s", customModel)
|
||||
} else {
|
||||
c.logger.Infof("🔧 [MCP] Grok using default Model: %s", c.Model)
|
||||
}
|
||||
}
|
||||
|
||||
// Grok uses standard OpenAI-compatible API with Bearer auth
|
||||
func (c *GrokClient) setAuthHeader(reqHeaders http.Header) {
|
||||
c.Client.setAuthHeader(reqHeaders)
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package mcp
|
||||
|
||||
import "net/http"
|
||||
|
||||
// ClientHooks is the dispatch interface used to implement per-provider
|
||||
// polymorphism without Go's lack of virtual methods.
|
||||
//
|
||||
// Each method can be overridden by an embedding struct (e.g. provider.ClaudeClient).
|
||||
// The base *Client provides OpenAI-compatible defaults; providers with a
|
||||
// different wire format (Anthropic, Gemini native, etc.) override only what
|
||||
// differs. All call-path methods in client.go invoke these via c.Hooks so
|
||||
// that the override is always picked up at runtime.
|
||||
type ClientHooks interface {
|
||||
// ── Simple CallWithMessages path ────────────────────────────────────────
|
||||
Call(systemPrompt, userPrompt string) (string, error)
|
||||
BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any
|
||||
|
||||
// ── Shared request plumbing ─────────────────────────────────────────────
|
||||
BuildUrl() string
|
||||
BuildRequest(url string, jsonData []byte) (*http.Request, error)
|
||||
SetAuthHeader(reqHeaders http.Header)
|
||||
MarshalRequestBody(requestBody map[string]any) ([]byte, error)
|
||||
|
||||
// ── Advanced (Request-object) path ──────────────────────────────────────
|
||||
// BuildRequestBodyFromRequest converts a *Request into the provider's
|
||||
// native wire-format map.
|
||||
BuildRequestBodyFromRequest(req *Request) map[string]any
|
||||
|
||||
// ParseMCPResponse extracts the plain-text reply from a non-streaming
|
||||
// response body.
|
||||
ParseMCPResponse(body []byte) (string, error)
|
||||
|
||||
// ParseMCPResponseFull extracts both text and tool calls.
|
||||
ParseMCPResponseFull(body []byte) (*LLMResponse, error)
|
||||
|
||||
IsRetryableError(err error) bool
|
||||
}
|
||||
+6
-39
@@ -1,10 +1,15 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientEmbedder is implemented by provider types that embed *Client,
|
||||
// allowing generic extraction of the underlying base client (e.g. for cloning).
|
||||
type ClientEmbedder interface {
|
||||
BaseClient() *Client
|
||||
}
|
||||
|
||||
// AIClient public AI client interface (for external use)
|
||||
type AIClient interface {
|
||||
SetAPIKey(apiKey string, customURL string, customModel string)
|
||||
@@ -21,41 +26,3 @@ type AIClient interface {
|
||||
// (LLMResponse.ToolCalls), but not both.
|
||||
CallWithRequestFull(req *Request) (*LLMResponse, error)
|
||||
}
|
||||
|
||||
// clientHooks is the internal dispatch interface used to implement per-provider
|
||||
// polymorphism without Go's lack of virtual methods.
|
||||
//
|
||||
// Each method can be overridden by an embedding struct (e.g. ClaudeClient).
|
||||
// The base *Client provides OpenAI-compatible defaults; providers with a
|
||||
// different wire format (Anthropic, Gemini native, etc.) override only what
|
||||
// differs. All call-path methods in client.go invoke these via c.hooks so
|
||||
// that the override is always picked up at runtime.
|
||||
type clientHooks interface {
|
||||
// ── Simple CallWithMessages path ────────────────────────────────────────
|
||||
call(systemPrompt, userPrompt string) (string, error)
|
||||
buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any
|
||||
|
||||
// ── Shared request plumbing ─────────────────────────────────────────────
|
||||
buildUrl() string
|
||||
buildRequest(url string, jsonData []byte) (*http.Request, error)
|
||||
setAuthHeader(reqHeaders http.Header)
|
||||
marshalRequestBody(requestBody map[string]any) ([]byte, error)
|
||||
|
||||
// ── Advanced (Request-object) path ──────────────────────────────────────
|
||||
// buildRequestBodyFromRequest converts a *Request into the provider's
|
||||
// native wire-format map. Providers that use a different protocol (e.g.
|
||||
// Anthropic uses "input_schema" for tools, "tool_use" content blocks, and
|
||||
// a top-level "system" field) override this method.
|
||||
buildRequestBodyFromRequest(req *Request) map[string]any
|
||||
|
||||
// parseMCPResponse extracts the plain-text reply from a non-streaming
|
||||
// response body.
|
||||
parseMCPResponse(body []byte) (string, error)
|
||||
|
||||
// parseMCPResponseFull extracts both text and tool calls. Providers whose
|
||||
// response envelope differs from the OpenAI choices[] structure (e.g.
|
||||
// Anthropic content[] with tool_use blocks) override this method.
|
||||
parseMCPResponseFull(body []byte) (*LLMResponse, error)
|
||||
|
||||
isRetryableError(err error) bool
|
||||
}
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderKimi = "kimi"
|
||||
DefaultKimiBaseURL = "https://api.moonshot.ai/v1" // Global endpoint (use api.moonshot.cn for China)
|
||||
DefaultKimiModel = "moonshot-v1-auto"
|
||||
)
|
||||
|
||||
type KimiClient struct {
|
||||
*Client
|
||||
}
|
||||
|
||||
// NewKimiClient creates Kimi (Moonshot) client (backward compatible)
|
||||
func NewKimiClient() AIClient {
|
||||
return NewKimiClientWithOptions()
|
||||
}
|
||||
|
||||
// NewKimiClientWithOptions creates Kimi client (supports options pattern)
|
||||
func NewKimiClientWithOptions(opts ...ClientOption) AIClient {
|
||||
// 1. Create Kimi preset options
|
||||
kimiOpts := []ClientOption{
|
||||
WithProvider(ProviderKimi),
|
||||
WithModel(DefaultKimiModel),
|
||||
WithBaseURL(DefaultKimiBaseURL),
|
||||
}
|
||||
|
||||
// 2. Merge user options (user options have higher priority)
|
||||
allOpts := append(kimiOpts, opts...)
|
||||
|
||||
// 3. Create base client
|
||||
baseClient := NewClient(allOpts...).(*Client)
|
||||
|
||||
// 4. Create Kimi client
|
||||
kimiClient := &KimiClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
// 5. Set hooks to point to KimiClient (implement dynamic dispatch)
|
||||
baseClient.hooks = kimiClient
|
||||
|
||||
return kimiClient
|
||||
}
|
||||
|
||||
func (c *KimiClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
c.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
c.logger.Infof("🔧 [MCP] Kimi API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
c.BaseURL = customURL
|
||||
c.logger.Infof("🔧 [MCP] Kimi using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
c.logger.Infof("🔧 [MCP] Kimi using default BaseURL: %s", c.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
c.Model = customModel
|
||||
c.logger.Infof("🔧 [MCP] Kimi using custom Model: %s", customModel)
|
||||
} else {
|
||||
c.logger.Infof("🔧 [MCP] Kimi using default Model: %s", c.Model)
|
||||
}
|
||||
}
|
||||
|
||||
// Kimi uses standard OpenAI-compatible API, so we just use the base client methods
|
||||
func (c *KimiClient) setAuthHeader(reqHeaders http.Header) {
|
||||
c.Client.setAuthHeader(reqHeaders)
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderMiniMax = "minimax"
|
||||
DefaultMiniMaxBaseURL = "https://api.minimax.io/v1"
|
||||
DefaultMiniMaxModel = "MiniMax-M2.5"
|
||||
)
|
||||
|
||||
type MiniMaxClient struct {
|
||||
*Client
|
||||
}
|
||||
|
||||
// NewMiniMaxClient creates MiniMax client (backward compatible)
|
||||
func NewMiniMaxClient() AIClient {
|
||||
return NewMiniMaxClientWithOptions()
|
||||
}
|
||||
|
||||
// NewMiniMaxClientWithOptions creates MiniMax client (supports options pattern)
|
||||
//
|
||||
// Usage examples:
|
||||
//
|
||||
// // Basic usage
|
||||
// client := mcp.NewMiniMaxClientWithOptions()
|
||||
//
|
||||
// // Custom configuration
|
||||
// client := mcp.NewMiniMaxClientWithOptions(
|
||||
// mcp.WithAPIKey("sk-xxx"),
|
||||
// mcp.WithLogger(customLogger),
|
||||
// mcp.WithTimeout(60*time.Second),
|
||||
// )
|
||||
func NewMiniMaxClientWithOptions(opts ...ClientOption) AIClient {
|
||||
// 1. Create MiniMax preset options
|
||||
minimaxOpts := []ClientOption{
|
||||
WithProvider(ProviderMiniMax),
|
||||
WithModel(DefaultMiniMaxModel),
|
||||
WithBaseURL(DefaultMiniMaxBaseURL),
|
||||
}
|
||||
|
||||
// 2. Merge user options (user options have higher priority)
|
||||
allOpts := append(minimaxOpts, opts...)
|
||||
|
||||
// 3. Create base client
|
||||
baseClient := NewClient(allOpts...).(*Client)
|
||||
|
||||
// 4. Create MiniMax client
|
||||
minimaxClient := &MiniMaxClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
// 5. Set hooks to point to MiniMaxClient (implement dynamic dispatch)
|
||||
baseClient.hooks = minimaxClient
|
||||
|
||||
return minimaxClient
|
||||
}
|
||||
|
||||
func (c *MiniMaxClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
c.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
c.logger.Infof("🔧 [MCP] MiniMax API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
c.BaseURL = customURL
|
||||
c.logger.Infof("🔧 [MCP] MiniMax using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
c.logger.Infof("🔧 [MCP] MiniMax using default BaseURL: %s", c.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
c.Model = customModel
|
||||
c.logger.Infof("🔧 [MCP] MiniMax using custom Model: %s", customModel)
|
||||
} else {
|
||||
c.logger.Infof("🔧 [MCP] MiniMax using default Model: %s", c.Model)
|
||||
}
|
||||
}
|
||||
|
||||
// MiniMax uses standard OpenAI-compatible API with Bearer auth
|
||||
func (c *MiniMaxClient) setAuthHeader(reqHeaders http.Header) {
|
||||
c.Client.setAuthHeader(reqHeaders)
|
||||
}
|
||||
@@ -1,272 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// Test MiniMaxClient Creation and Configuration
|
||||
// ============================================================
|
||||
|
||||
func TestNewMiniMaxClient_Default(t *testing.T) {
|
||||
client := NewMiniMaxClient()
|
||||
|
||||
if client == nil {
|
||||
t.Fatal("client should not be nil")
|
||||
}
|
||||
|
||||
// Type assertion check
|
||||
mmClient, ok := client.(*MiniMaxClient)
|
||||
if !ok {
|
||||
t.Fatal("client should be *MiniMaxClient")
|
||||
}
|
||||
|
||||
// Verify default values
|
||||
if mmClient.Provider != ProviderMiniMax {
|
||||
t.Errorf("Provider should be '%s', got '%s'", ProviderMiniMax, mmClient.Provider)
|
||||
}
|
||||
|
||||
if mmClient.BaseURL != DefaultMiniMaxBaseURL {
|
||||
t.Errorf("BaseURL should be '%s', got '%s'", DefaultMiniMaxBaseURL, mmClient.BaseURL)
|
||||
}
|
||||
|
||||
if mmClient.Model != DefaultMiniMaxModel {
|
||||
t.Errorf("Model should be '%s', got '%s'", DefaultMiniMaxModel, mmClient.Model)
|
||||
}
|
||||
|
||||
if mmClient.logger == nil {
|
||||
t.Error("logger should not be nil")
|
||||
}
|
||||
|
||||
if mmClient.httpClient == nil {
|
||||
t.Error("httpClient should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMiniMaxClientWithOptions(t *testing.T) {
|
||||
mockLogger := NewMockLogger()
|
||||
customModel := "MiniMax-M2.5-highspeed"
|
||||
customAPIKey := "sk-custom-key"
|
||||
|
||||
client := NewMiniMaxClientWithOptions(
|
||||
WithLogger(mockLogger),
|
||||
WithModel(customModel),
|
||||
WithAPIKey(customAPIKey),
|
||||
WithMaxTokens(4000),
|
||||
)
|
||||
|
||||
mmClient := client.(*MiniMaxClient)
|
||||
|
||||
// Verify custom options are applied
|
||||
if mmClient.logger != mockLogger {
|
||||
t.Error("logger should be set from option")
|
||||
}
|
||||
|
||||
if mmClient.Model != customModel {
|
||||
t.Error("Model should be set from option")
|
||||
}
|
||||
|
||||
if mmClient.APIKey != customAPIKey {
|
||||
t.Error("APIKey should be set from option")
|
||||
}
|
||||
|
||||
if mmClient.MaxTokens != 4000 {
|
||||
t.Error("MaxTokens should be 4000")
|
||||
}
|
||||
|
||||
// Verify MiniMax default values are retained
|
||||
if mmClient.Provider != ProviderMiniMax {
|
||||
t.Errorf("Provider should still be '%s'", ProviderMiniMax)
|
||||
}
|
||||
|
||||
if mmClient.BaseURL != DefaultMiniMaxBaseURL {
|
||||
t.Errorf("BaseURL should still be '%s'", DefaultMiniMaxBaseURL)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Test SetAPIKey
|
||||
// ============================================================
|
||||
|
||||
func TestMiniMaxClient_SetAPIKey(t *testing.T) {
|
||||
mockLogger := NewMockLogger()
|
||||
client := NewMiniMaxClientWithOptions(
|
||||
WithLogger(mockLogger),
|
||||
)
|
||||
|
||||
mmClient := client.(*MiniMaxClient)
|
||||
|
||||
// Test setting API Key (default URL and Model)
|
||||
mmClient.SetAPIKey("sk-test-key-12345678", "", "")
|
||||
|
||||
if mmClient.APIKey != "sk-test-key-12345678" {
|
||||
t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", mmClient.APIKey)
|
||||
}
|
||||
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
if len(logs) == 0 {
|
||||
t.Error("should have logged API key setting")
|
||||
}
|
||||
|
||||
// Verify BaseURL and Model remain default
|
||||
if mmClient.BaseURL != DefaultMiniMaxBaseURL {
|
||||
t.Error("BaseURL should remain default")
|
||||
}
|
||||
|
||||
if mmClient.Model != DefaultMiniMaxModel {
|
||||
t.Error("Model should remain default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiniMaxClient_SetAPIKey_WithCustomURL(t *testing.T) {
|
||||
mockLogger := NewMockLogger()
|
||||
client := NewMiniMaxClientWithOptions(
|
||||
WithLogger(mockLogger),
|
||||
)
|
||||
|
||||
mmClient := client.(*MiniMaxClient)
|
||||
|
||||
customURL := "https://api.minimaxi.com/v1"
|
||||
mmClient.SetAPIKey("sk-test-key-12345678", customURL, "")
|
||||
|
||||
if mmClient.BaseURL != customURL {
|
||||
t.Errorf("BaseURL should be '%s', got '%s'", customURL, mmClient.BaseURL)
|
||||
}
|
||||
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
hasCustomURLLog := false
|
||||
for _, log := range logs {
|
||||
if log.Format == "🔧 [MCP] MiniMax using custom BaseURL: %s" {
|
||||
hasCustomURLLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasCustomURLLog {
|
||||
t.Error("should have logged custom BaseURL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiniMaxClient_SetAPIKey_WithCustomModel(t *testing.T) {
|
||||
mockLogger := NewMockLogger()
|
||||
client := NewMiniMaxClientWithOptions(
|
||||
WithLogger(mockLogger),
|
||||
)
|
||||
|
||||
mmClient := client.(*MiniMaxClient)
|
||||
|
||||
customModel := "MiniMax-M2.5-highspeed"
|
||||
mmClient.SetAPIKey("sk-test-key-12345678", "", customModel)
|
||||
|
||||
if mmClient.Model != customModel {
|
||||
t.Errorf("Model should be '%s', got '%s'", customModel, mmClient.Model)
|
||||
}
|
||||
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
hasCustomModelLog := false
|
||||
for _, log := range logs {
|
||||
if log.Format == "🔧 [MCP] MiniMax using custom Model: %s" {
|
||||
hasCustomModelLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasCustomModelLog {
|
||||
t.Error("should have logged custom Model")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Test Integration Features
|
||||
// ============================================================
|
||||
|
||||
func TestMiniMaxClient_CallWithMessages_Success(t *testing.T) {
|
||||
mockHTTP := NewMockHTTPClient()
|
||||
mockHTTP.SetSuccessResponse("MiniMax AI response")
|
||||
mockLogger := NewMockLogger()
|
||||
|
||||
client := NewMiniMaxClientWithOptions(
|
||||
WithHTTPClient(mockHTTP.ToHTTPClient()),
|
||||
WithLogger(mockLogger),
|
||||
WithAPIKey("sk-test-key"),
|
||||
)
|
||||
|
||||
result, err := client.CallWithMessages("system prompt", "user prompt")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("should not error: %v", err)
|
||||
}
|
||||
|
||||
if result != "MiniMax AI response" {
|
||||
t.Errorf("expected 'MiniMax AI response', got '%s'", result)
|
||||
}
|
||||
|
||||
// Verify request
|
||||
requests := mockHTTP.GetRequests()
|
||||
if len(requests) != 1 {
|
||||
t.Fatalf("expected 1 request, got %d", len(requests))
|
||||
}
|
||||
|
||||
req := requests[0]
|
||||
|
||||
// Verify URL
|
||||
expectedURL := DefaultMiniMaxBaseURL + "/chat/completions"
|
||||
if req.URL.String() != expectedURL {
|
||||
t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String())
|
||||
}
|
||||
|
||||
// Verify Authorization header
|
||||
authHeader := req.Header.Get("Authorization")
|
||||
if authHeader != "Bearer sk-test-key" {
|
||||
t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader)
|
||||
}
|
||||
|
||||
// Verify Content-Type
|
||||
if req.Header.Get("Content-Type") != "application/json" {
|
||||
t.Error("Content-Type should be application/json")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiniMaxClient_Timeout(t *testing.T) {
|
||||
client := NewMiniMaxClientWithOptions(
|
||||
WithTimeout(30 * time.Second),
|
||||
)
|
||||
|
||||
mmClient := client.(*MiniMaxClient)
|
||||
|
||||
if mmClient.httpClient.Timeout != 30*time.Second {
|
||||
t.Errorf("expected timeout 30s, got %v", mmClient.httpClient.Timeout)
|
||||
}
|
||||
|
||||
// Test SetTimeout
|
||||
client.SetTimeout(60 * time.Second)
|
||||
|
||||
if mmClient.httpClient.Timeout != 60*time.Second {
|
||||
t.Errorf("expected timeout 60s after SetTimeout, got %v", mmClient.httpClient.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Test hooks Mechanism
|
||||
// ============================================================
|
||||
|
||||
func TestMiniMaxClient_HooksIntegration(t *testing.T) {
|
||||
client := NewMiniMaxClientWithOptions()
|
||||
mmClient := client.(*MiniMaxClient)
|
||||
|
||||
// Verify hooks point to mmClient itself (implements polymorphism)
|
||||
if mmClient.hooks != mmClient {
|
||||
t.Error("hooks should point to mmClient for polymorphism")
|
||||
}
|
||||
|
||||
// Verify buildUrl uses MiniMax configuration
|
||||
url := mmClient.buildUrl()
|
||||
expectedURL := DefaultMiniMaxBaseURL + "/chat/completions"
|
||||
if url != expectedURL {
|
||||
t.Errorf("expected URL '%s', got '%s'", expectedURL, url)
|
||||
}
|
||||
}
|
||||
+21
-9
@@ -247,7 +247,7 @@ func NewMockClientHooks() *MockClientHooks {
|
||||
return &MockClientHooks{}
|
||||
}
|
||||
|
||||
func (m *MockClientHooks) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
|
||||
func (m *MockClientHooks) BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
|
||||
m.BuildRequestBodyCalled++
|
||||
if m.BuildRequestBodyFunc != nil {
|
||||
return m.BuildRequestBodyFunc(systemPrompt, userPrompt)
|
||||
@@ -261,7 +261,7 @@ func (m *MockClientHooks) buildMCPRequestBody(systemPrompt, userPrompt string) m
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockClientHooks) buildUrl() string {
|
||||
func (m *MockClientHooks) BuildUrl() string {
|
||||
m.BuildUrlCalled++
|
||||
if m.BuildUrlFunc != nil {
|
||||
return m.BuildUrlFunc()
|
||||
@@ -269,12 +269,12 @@ func (m *MockClientHooks) buildUrl() string {
|
||||
return "https://api.test.com/chat/completions"
|
||||
}
|
||||
|
||||
func (m *MockClientHooks) setAuthHeader(headers http.Header) {
|
||||
func (m *MockClientHooks) SetAuthHeader(headers http.Header) {
|
||||
m.SetAuthHeaderCalled++
|
||||
headers.Set("Authorization", "Bearer test-key")
|
||||
}
|
||||
|
||||
func (m *MockClientHooks) marshalRequestBody(body map[string]any) ([]byte, error) {
|
||||
func (m *MockClientHooks) MarshalRequestBody(body map[string]any) ([]byte, error) {
|
||||
m.MarshalRequestCalled++
|
||||
if m.MarshalRequestBodyFunc != nil {
|
||||
return m.MarshalRequestBodyFunc(body)
|
||||
@@ -282,7 +282,7 @@ func (m *MockClientHooks) marshalRequestBody(body map[string]any) ([]byte, error
|
||||
return json.Marshal(body)
|
||||
}
|
||||
|
||||
func (m *MockClientHooks) parseMCPResponse(body []byte) (string, error) {
|
||||
func (m *MockClientHooks) ParseMCPResponse(body []byte) (string, error) {
|
||||
m.ParseResponseCalled++
|
||||
if m.ParseResponseFunc != nil {
|
||||
return m.ParseResponseFunc(body)
|
||||
@@ -290,7 +290,15 @@ func (m *MockClientHooks) parseMCPResponse(body []byte) (string, error) {
|
||||
return "mocked response", nil
|
||||
}
|
||||
|
||||
func (m *MockClientHooks) isRetryableError(err error) bool {
|
||||
func (m *MockClientHooks) ParseMCPResponseFull(body []byte) (*LLMResponse, error) {
|
||||
r, err := m.ParseMCPResponse(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &LLMResponse{Content: r}, nil
|
||||
}
|
||||
|
||||
func (m *MockClientHooks) IsRetryableError(err error) bool {
|
||||
m.IsRetryableErrorCalled++
|
||||
if m.IsRetryableErrorFunc != nil {
|
||||
return m.IsRetryableErrorFunc(err)
|
||||
@@ -298,13 +306,17 @@ func (m *MockClientHooks) isRetryableError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MockClientHooks) buildRequest(url string, jsonData []byte) (*http.Request, error) {
|
||||
func (m *MockClientHooks) BuildRequest(url string, jsonData []byte) (*http.Request, error) {
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
m.setAuthHeader(req.Header)
|
||||
m.SetAuthHeader(req.Header)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (m *MockClientHooks) call(systemPrompt, userPrompt string) (string, error) {
|
||||
func (m *MockClientHooks) Call(systemPrompt, userPrompt string) (string, error) {
|
||||
return "mocked call result", nil
|
||||
}
|
||||
|
||||
func (m *MockClientHooks) BuildRequestBodyFromRequest(req *Request) map[string]any {
|
||||
return map[string]any{"model": "test-model"}
|
||||
}
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderOpenAI = "openai"
|
||||
DefaultOpenAIBaseURL = "https://api.openai.com/v1"
|
||||
DefaultOpenAIModel = "gpt-5.4"
|
||||
)
|
||||
|
||||
type OpenAIClient struct {
|
||||
*Client
|
||||
}
|
||||
|
||||
// NewOpenAIClient creates OpenAI client (backward compatible)
|
||||
func NewOpenAIClient() AIClient {
|
||||
return NewOpenAIClientWithOptions()
|
||||
}
|
||||
|
||||
// NewOpenAIClientWithOptions creates OpenAI client (supports options pattern)
|
||||
func NewOpenAIClientWithOptions(opts ...ClientOption) AIClient {
|
||||
// 1. Create OpenAI preset options
|
||||
openaiOpts := []ClientOption{
|
||||
WithProvider(ProviderOpenAI),
|
||||
WithModel(DefaultOpenAIModel),
|
||||
WithBaseURL(DefaultOpenAIBaseURL),
|
||||
}
|
||||
|
||||
// 2. Merge user options (user options have higher priority)
|
||||
allOpts := append(openaiOpts, opts...)
|
||||
|
||||
// 3. Create base client
|
||||
baseClient := NewClient(allOpts...).(*Client)
|
||||
|
||||
// 4. Create OpenAI client
|
||||
openaiClient := &OpenAIClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
// 5. Set hooks to point to OpenAIClient (implement dynamic dispatch)
|
||||
baseClient.hooks = openaiClient
|
||||
|
||||
return openaiClient
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
c.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
c.logger.Infof("🔧 [MCP] OpenAI API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
c.BaseURL = customURL
|
||||
c.logger.Infof("🔧 [MCP] OpenAI using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
c.logger.Infof("🔧 [MCP] OpenAI using default BaseURL: %s", c.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
c.Model = customModel
|
||||
c.logger.Infof("🔧 [MCP] OpenAI using custom Model: %s", customModel)
|
||||
} else {
|
||||
c.logger.Infof("🔧 [MCP] OpenAI using default Model: %s", c.Model)
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI uses standard Bearer auth
|
||||
func (c *OpenAIClient) setAuthHeader(reqHeaders http.Header) {
|
||||
c.Client.setAuthHeader(reqHeaders)
|
||||
}
|
||||
+3
-77
@@ -279,8 +279,8 @@ func TestOptionsWithNewClient(t *testing.T) {
|
||||
t.Error("Model should be set from options")
|
||||
}
|
||||
|
||||
if c.logger != mockLogger {
|
||||
t.Error("logger should be set from options")
|
||||
if c.Log != mockLogger {
|
||||
t.Error("Log should be set from options")
|
||||
}
|
||||
|
||||
if c.MaxTokens != 4000 {
|
||||
@@ -288,78 +288,4 @@ func TestOptionsWithNewClient(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionsWithDeepSeekClient(t *testing.T) {
|
||||
mockLogger := NewMockLogger()
|
||||
|
||||
client := NewDeepSeekClientWithOptions(
|
||||
WithAPIKey("sk-deepseek-key"),
|
||||
WithLogger(mockLogger),
|
||||
WithMaxTokens(5000),
|
||||
)
|
||||
|
||||
dsClient := client.(*DeepSeekClient)
|
||||
|
||||
// Verify DeepSeek default values
|
||||
if dsClient.Provider != ProviderDeepSeek {
|
||||
t.Error("Provider should be DeepSeek")
|
||||
}
|
||||
|
||||
if dsClient.BaseURL != DefaultDeepSeekBaseURL {
|
||||
t.Error("BaseURL should be DeepSeek default")
|
||||
}
|
||||
|
||||
if dsClient.Model != DefaultDeepSeekModel {
|
||||
t.Error("Model should be DeepSeek default")
|
||||
}
|
||||
|
||||
// Verify custom options
|
||||
if dsClient.APIKey != "sk-deepseek-key" {
|
||||
t.Error("APIKey should be set from options")
|
||||
}
|
||||
|
||||
if dsClient.logger != mockLogger {
|
||||
t.Error("logger should be set from options")
|
||||
}
|
||||
|
||||
if dsClient.MaxTokens != 5000 {
|
||||
t.Error("MaxTokens should be 5000")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionsWithQwenClient(t *testing.T) {
|
||||
mockLogger := NewMockLogger()
|
||||
|
||||
client := NewQwenClientWithOptions(
|
||||
WithAPIKey("sk-qwen-key"),
|
||||
WithLogger(mockLogger),
|
||||
WithMaxTokens(6000),
|
||||
)
|
||||
|
||||
qwenClient := client.(*QwenClient)
|
||||
|
||||
// Verify Qwen default values
|
||||
if qwenClient.Provider != ProviderQwen {
|
||||
t.Error("Provider should be Qwen")
|
||||
}
|
||||
|
||||
if qwenClient.BaseURL != DefaultQwenBaseURL {
|
||||
t.Error("BaseURL should be Qwen default")
|
||||
}
|
||||
|
||||
if qwenClient.Model != DefaultQwenModel {
|
||||
t.Error("Model should be Qwen default")
|
||||
}
|
||||
|
||||
// Verify custom options
|
||||
if qwenClient.APIKey != "sk-qwen-key" {
|
||||
t.Error("APIKey should be set from options")
|
||||
}
|
||||
|
||||
if qwenClient.logger != mockLogger {
|
||||
t.Error("logger should be set from options")
|
||||
}
|
||||
|
||||
if qwenClient.MaxTokens != 6000 {
|
||||
t.Error("MaxTokens should be 6000")
|
||||
}
|
||||
}
|
||||
// Provider-specific option tests are in mcp/provider/options_test.go
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package mcp
|
||||
package payment
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
@@ -14,12 +14,13 @@ import (
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"golang.org/x/crypto/sha3"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderBlockRunBase = "blockrun-base"
|
||||
DefaultBlockRunBaseURL = "https://blockrun.ai"
|
||||
DefaultBlockRunModel = "gpt-5.4"
|
||||
DefaultBlockRunModel = "gpt-5.4"
|
||||
BlockRunChatEndpoint = "/api/v1/chat/completions"
|
||||
BaseUSDCContract = "0x833589fCD6eDb6E08f4c7C32D4f71b54bdA02913"
|
||||
BaseChainID int64 = 8453
|
||||
@@ -28,10 +29,16 @@ const (
|
||||
|
||||
// EIP-712 type hashes for USDC TransferWithAuthorization (ERC-3009)
|
||||
var (
|
||||
eip712DomainTypeHash = keccak256String("EIP712Domain(string name,string version,uint256 chainId,address verifyingContract)")
|
||||
eip712DomainTypeHash = keccak256String("EIP712Domain(string name,string version,uint256 chainId,address verifyingContract)")
|
||||
transferWithAuthTypeHash = keccak256String("TransferWithAuthorization(address from,address to,uint256 value,uint256 validAfter,uint256 validBefore,bytes32 nonce)")
|
||||
)
|
||||
|
||||
func init() {
|
||||
mcp.RegisterProvider(mcp.ProviderBlockRunBase, func(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
return NewBlockRunBaseClientWithOptions(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
func keccak256String(s string) []byte {
|
||||
h := sha3.NewLegacyKeccak256()
|
||||
h.Write([]byte(s))
|
||||
@@ -48,71 +55,72 @@ func keccak256Bytes(data ...[]byte) []byte {
|
||||
|
||||
// BlockRunBaseClient implements AIClient using BlockRun's API with x402 v2 EIP-712 payment signing.
|
||||
type BlockRunBaseClient struct {
|
||||
*Client
|
||||
*mcp.Client
|
||||
privateKey *ecdsa.PrivateKey
|
||||
}
|
||||
|
||||
func (c *BlockRunBaseClient) BaseClient() *mcp.Client { return c.Client }
|
||||
|
||||
// NewBlockRunBaseClient creates a BlockRun Base wallet client (backward compatible).
|
||||
func NewBlockRunBaseClient() AIClient {
|
||||
func NewBlockRunBaseClient() mcp.AIClient {
|
||||
return NewBlockRunBaseClientWithOptions()
|
||||
}
|
||||
|
||||
// NewBlockRunBaseClientWithOptions creates a BlockRun Base wallet client.
|
||||
func NewBlockRunBaseClientWithOptions(opts ...ClientOption) AIClient {
|
||||
baseOpts := []ClientOption{
|
||||
WithProvider(ProviderBlockRunBase),
|
||||
WithModel(DefaultBlockRunModel),
|
||||
WithBaseURL(DefaultBlockRunBaseURL),
|
||||
func NewBlockRunBaseClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
baseOpts := []mcp.ClientOption{
|
||||
mcp.WithProvider(mcp.ProviderBlockRunBase),
|
||||
mcp.WithModel(DefaultBlockRunModel),
|
||||
mcp.WithBaseURL(DefaultBlockRunBaseURL),
|
||||
}
|
||||
allOpts := append(baseOpts, opts...)
|
||||
baseClient := NewClient(allOpts...).(*Client)
|
||||
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
|
||||
baseClient.UseFullURL = true
|
||||
baseClient.BaseURL = DefaultBlockRunBaseURL + BlockRunChatEndpoint
|
||||
|
||||
c := &BlockRunBaseClient{Client: baseClient}
|
||||
baseClient.hooks = c
|
||||
baseClient.Hooks = c
|
||||
return c
|
||||
}
|
||||
|
||||
// SetAPIKey stores the EVM private key (hex, with or without 0x prefix).
|
||||
// customModel selects the AI model to use (e.g. "claude-sonnet-4.6"); empty means default.
|
||||
func (c *BlockRunBaseClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
hexKey := strings.TrimPrefix(apiKey, "0x")
|
||||
privKey, err := crypto.HexToECDSA(hexKey)
|
||||
if err != nil {
|
||||
c.logger.Warnf("⚠️ [MCP] BlockRun Base: invalid private key: %v", err)
|
||||
c.Log.Warnf("⚠️ [MCP] BlockRun Base: invalid private key: %v", err)
|
||||
} else {
|
||||
c.privateKey = privKey
|
||||
c.APIKey = apiKey
|
||||
addr := crypto.PubkeyToAddress(privKey.PublicKey).Hex()
|
||||
c.logger.Infof("🔧 [MCP] BlockRun Base wallet: %s", addr)
|
||||
c.Log.Infof("🔧 [MCP] BlockRun Base wallet: %s", addr)
|
||||
}
|
||||
if customModel != "" {
|
||||
c.Model = customModel
|
||||
c.logger.Infof("🔧 [MCP] BlockRun Base model: %s", customModel)
|
||||
c.Log.Infof("🔧 [MCP] BlockRun Base model: %s", customModel)
|
||||
} else {
|
||||
c.logger.Infof("🔧 [MCP] BlockRun Base model: %s", DefaultBlockRunModel)
|
||||
c.Log.Infof("🔧 [MCP] BlockRun Base model: %s", DefaultBlockRunModel)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *BlockRunBaseClient) setAuthHeader(h http.Header) { x402SetAuthHeader(h) }
|
||||
func (c *BlockRunBaseClient) SetAuthHeader(h http.Header) { X402SetAuthHeader(h) }
|
||||
|
||||
func (c *BlockRunBaseClient) call(systemPrompt, userPrompt string) (string, error) {
|
||||
return x402Call(c.Client, c.signPayment, "BlockRun Base", systemPrompt, userPrompt)
|
||||
func (c *BlockRunBaseClient) Call(systemPrompt, userPrompt string) (string, error) {
|
||||
return X402Call(c.Client, c.signPayment, "BlockRun Base", systemPrompt, userPrompt)
|
||||
}
|
||||
|
||||
func (c *BlockRunBaseClient) CallWithRequestFull(req *Request) (*LLMResponse, error) {
|
||||
return x402CallFull(c.Client, c.signPayment, "BlockRun Base", req)
|
||||
func (c *BlockRunBaseClient) CallWithRequestFull(req *mcp.Request) (*mcp.LLMResponse, error) {
|
||||
return X402CallFull(c.Client, c.signPayment, "BlockRun Base", req)
|
||||
}
|
||||
|
||||
// signPayment parses the Payment-Required header (x402 v2) and returns a signed payment value.
|
||||
func (c *BlockRunBaseClient) signPayment(paymentHeaderB64 string) (string, error) {
|
||||
return signBasePaymentHeader(c.privateKey, paymentHeaderB64, "BlockRun Base")
|
||||
return SignBasePaymentHeader(c.privateKey, paymentHeaderB64, "BlockRun Base")
|
||||
}
|
||||
|
||||
// signX402Payment is the shared EIP-712 signing logic for x402 v2 on Base USDC.
|
||||
// SignX402Payment is the shared EIP-712 signing logic for x402 v2 on Base USDC.
|
||||
// Used by both BlockRunBaseClient and Claw402Client.
|
||||
func signX402Payment(privateKey *ecdsa.PrivateKey, senderAddr string, opt x402AcceptOption, resource *x402Resource) (string, error) {
|
||||
func SignX402Payment(privateKey *ecdsa.PrivateKey, senderAddr string, opt X402AcceptOption, resource *X402Resource) (string, error) {
|
||||
recipient := opt.PayTo
|
||||
amount := opt.Amount
|
||||
network := opt.Network
|
||||
@@ -224,7 +232,6 @@ func signX402Payment(privateKey *ecdsa.PrivateKey, senderAddr string, opt x402Ac
|
||||
|
||||
// buildDomainSeparatorDynamic builds the EIP-712 domain separator using runtime values.
|
||||
func buildDomainSeparatorDynamic(name, version, network, asset string) ([]byte, error) {
|
||||
// Extract chain ID from network string like "eip155:8453"
|
||||
chainID := new(big.Int).SetInt64(BaseChainID)
|
||||
if strings.HasPrefix(network, "eip155:") {
|
||||
parts := strings.SplitN(network, ":", 2)
|
||||
@@ -311,8 +318,6 @@ func hexToBytes32(s string) ([]byte, error) {
|
||||
|
||||
func parseBigInt(s string) (*big.Int, error) {
|
||||
n := new(big.Int)
|
||||
// Only treat as hex when explicitly prefixed with 0x/0X.
|
||||
// x402 amounts are always decimal strings (e.g. "3000" = 0.003 USDC).
|
||||
if strings.HasPrefix(s, "0x") || strings.HasPrefix(s, "0X") {
|
||||
if _, ok := n.SetString(s[2:], 16); ok {
|
||||
return n, nil
|
||||
@@ -335,11 +340,11 @@ func leftPad32(b []byte) []byte {
|
||||
return padded
|
||||
}
|
||||
|
||||
// buildUrl returns the full BlockRun endpoint URL.
|
||||
func (c *BlockRunBaseClient) buildUrl() string {
|
||||
// BuildUrl returns the full BlockRun endpoint URL.
|
||||
func (c *BlockRunBaseClient) BuildUrl() string {
|
||||
return DefaultBlockRunBaseURL + BlockRunChatEndpoint
|
||||
}
|
||||
|
||||
func (c *BlockRunBaseClient) buildRequest(url string, jsonData []byte) (*http.Request, error) {
|
||||
return x402BuildRequest(url, jsonData)
|
||||
func (c *BlockRunBaseClient) BuildRequest(url string, jsonData []byte) (*http.Request, error) {
|
||||
return X402BuildRequest(url, jsonData)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package mcp
|
||||
package payment
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -12,10 +12,11 @@ import (
|
||||
"github.com/gagliardetto/solana-go/programs/compute-budget"
|
||||
"github.com/gagliardetto/solana-go/programs/token"
|
||||
"github.com/gagliardetto/solana-go/rpc"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderBlockRunSol = "blockrun-sol"
|
||||
DefaultBlockRunSolURL = "https://sol.blockrun.ai"
|
||||
SolanaUSDCMint = "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v"
|
||||
SolanaNetwork = "solana:5eykt4UsFv8P8NJdTREpY1vzqKqZKvdp"
|
||||
@@ -26,62 +27,69 @@ const (
|
||||
computeUnitPrice = uint64(1)
|
||||
)
|
||||
|
||||
func init() {
|
||||
mcp.RegisterProvider(mcp.ProviderBlockRunSol, func(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
return NewBlockRunSolClientWithOptions(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
// BlockRunSolClient implements AIClient using BlockRun's Solana x402 v2 payment protocol.
|
||||
type BlockRunSolClient struct {
|
||||
*Client
|
||||
*mcp.Client
|
||||
keypair solana.PrivateKey
|
||||
}
|
||||
|
||||
func (c *BlockRunSolClient) BaseClient() *mcp.Client { return c.Client }
|
||||
|
||||
// NewBlockRunSolClient creates a BlockRun Solana wallet client (backward compatible).
|
||||
func NewBlockRunSolClient() AIClient {
|
||||
func NewBlockRunSolClient() mcp.AIClient {
|
||||
return NewBlockRunSolClientWithOptions()
|
||||
}
|
||||
|
||||
// NewBlockRunSolClientWithOptions creates a BlockRun Solana wallet client.
|
||||
func NewBlockRunSolClientWithOptions(opts ...ClientOption) AIClient {
|
||||
baseOpts := []ClientOption{
|
||||
WithProvider(ProviderBlockRunSol),
|
||||
WithModel(DefaultBlockRunModel),
|
||||
WithBaseURL(DefaultBlockRunSolURL),
|
||||
func NewBlockRunSolClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
baseOpts := []mcp.ClientOption{
|
||||
mcp.WithProvider(mcp.ProviderBlockRunSol),
|
||||
mcp.WithModel(DefaultBlockRunModel),
|
||||
mcp.WithBaseURL(DefaultBlockRunSolURL),
|
||||
}
|
||||
allOpts := append(baseOpts, opts...)
|
||||
baseClient := NewClient(allOpts...).(*Client)
|
||||
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
|
||||
baseClient.UseFullURL = true
|
||||
baseClient.BaseURL = DefaultBlockRunSolURL + BlockRunChatEndpoint
|
||||
|
||||
c := &BlockRunSolClient{Client: baseClient}
|
||||
baseClient.hooks = c
|
||||
baseClient.Hooks = c
|
||||
return c
|
||||
}
|
||||
|
||||
// SetAPIKey stores the Solana wallet private key (base58-encoded 64-byte keypair).
|
||||
// customModel selects the AI model; empty means default.
|
||||
func (c *BlockRunSolClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
kp, err := solana.PrivateKeyFromBase58(strings.TrimSpace(apiKey))
|
||||
if err != nil {
|
||||
c.logger.Warnf("⚠️ [MCP] BlockRun Sol: failed to parse private key: %v", err)
|
||||
c.Log.Warnf("⚠️ [MCP] BlockRun Sol: failed to parse private key: %v", err)
|
||||
return
|
||||
}
|
||||
c.keypair = kp
|
||||
c.APIKey = apiKey
|
||||
c.logger.Infof("🔧 [MCP] BlockRun Sol wallet: %s", kp.PublicKey().String())
|
||||
c.Log.Infof("🔧 [MCP] BlockRun Sol wallet: %s", kp.PublicKey().String())
|
||||
|
||||
if customModel != "" {
|
||||
c.Model = customModel
|
||||
c.logger.Infof("🔧 [MCP] BlockRun Sol model: %s", customModel)
|
||||
c.Log.Infof("🔧 [MCP] BlockRun Sol model: %s", customModel)
|
||||
} else {
|
||||
c.logger.Infof("🔧 [MCP] BlockRun Sol model: %s", DefaultBlockRunModel)
|
||||
c.Log.Infof("🔧 [MCP] BlockRun Sol model: %s", DefaultBlockRunModel)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *BlockRunSolClient) setAuthHeader(h http.Header) { x402SetAuthHeader(h) }
|
||||
func (c *BlockRunSolClient) SetAuthHeader(h http.Header) { X402SetAuthHeader(h) }
|
||||
|
||||
func (c *BlockRunSolClient) call(systemPrompt, userPrompt string) (string, error) {
|
||||
return x402Call(c.Client, c.signSolanaPayment, "BlockRun Sol", systemPrompt, userPrompt)
|
||||
func (c *BlockRunSolClient) Call(systemPrompt, userPrompt string) (string, error) {
|
||||
return X402Call(c.Client, c.signSolanaPayment, "BlockRun Sol", systemPrompt, userPrompt)
|
||||
}
|
||||
|
||||
func (c *BlockRunSolClient) CallWithRequestFull(req *Request) (*LLMResponse, error) {
|
||||
return x402CallFull(c.Client, c.signSolanaPayment, "BlockRun Sol", req)
|
||||
func (c *BlockRunSolClient) CallWithRequestFull(req *mcp.Request) (*mcp.LLMResponse, error) {
|
||||
return X402CallFull(c.Client, c.signSolanaPayment, "BlockRun Sol", req)
|
||||
}
|
||||
|
||||
// signSolanaPayment parses the Payment-Required header and builds a signed x402 v2 Solana payload.
|
||||
@@ -90,18 +98,18 @@ func (c *BlockRunSolClient) signSolanaPayment(paymentHeaderB64 string) (string,
|
||||
return "", fmt.Errorf("no private key set for BlockRun Sol wallet")
|
||||
}
|
||||
|
||||
decoded, err := x402DecodeHeader(paymentHeaderB64)
|
||||
decoded, err := X402DecodeHeader(paymentHeaderB64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var req x402v2PaymentRequired
|
||||
var req X402v2PaymentRequired
|
||||
if err := json.Unmarshal(decoded, &req); err != nil {
|
||||
return "", fmt.Errorf("failed to parse x402 v2 Solana header: %w", err)
|
||||
}
|
||||
|
||||
// Find the Solana option
|
||||
var opt *x402AcceptOption
|
||||
var opt *X402AcceptOption
|
||||
for i := range req.Accepts {
|
||||
if strings.HasPrefix(req.Accepts[i].Network, "solana:") {
|
||||
opt = &req.Accepts[i]
|
||||
@@ -174,11 +182,9 @@ func (c *BlockRunSolClient) signSolanaPayment(paymentHeaderB64 string) (string,
|
||||
}
|
||||
|
||||
// buildSolanaTransferTx builds a partial-signed VersionedTransaction for SPL USDC TransferChecked.
|
||||
// The fee payer (CDP facilitator) slot is left with a zero signature; only the user signs.
|
||||
func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr string) (string, error) {
|
||||
ownerPubkey := c.keypair.PublicKey()
|
||||
|
||||
// Parse recipient and feePayer
|
||||
recipientPK, err := solana.PublicKeyFromBase58(recipient)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid recipient address: %w", err)
|
||||
@@ -189,13 +195,11 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr
|
||||
}
|
||||
mintPK := solana.MustPublicKeyFromBase58(SolanaUSDCMint)
|
||||
|
||||
// Parse amount
|
||||
var amountU64 uint64
|
||||
if _, err := fmt.Sscanf(amountStr, "%d", &amountU64); err != nil {
|
||||
return "", fmt.Errorf("invalid amount %q: %w", amountStr, err)
|
||||
}
|
||||
|
||||
// Derive ATAs
|
||||
sourceATA, _, err := solana.FindAssociatedTokenAddress(ownerPubkey, mintPK)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to derive source ATA: %w", err)
|
||||
@@ -205,7 +209,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr
|
||||
return "", fmt.Errorf("failed to derive dest ATA: %w", err)
|
||||
}
|
||||
|
||||
// Fetch latest blockhash from Solana mainnet
|
||||
rpcClient := rpc.New(SolanaMainnetRPC)
|
||||
bhResp, err := rpcClient.GetLatestBlockhash(context.Background(), rpc.CommitmentFinalized)
|
||||
if err != nil {
|
||||
@@ -213,7 +216,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr
|
||||
}
|
||||
recentBlockhash := bhResp.Value.Blockhash
|
||||
|
||||
// Build instructions: ComputeBudgetSetLimit, ComputeBudgetSetPrice, TransferChecked
|
||||
setLimitIx, err := computebudget.NewSetComputeUnitLimitInstruction(computeUnitLimit).ValidateAndBuild()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build SetComputeUnitLimit: %w", err)
|
||||
@@ -235,7 +237,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr
|
||||
return "", fmt.Errorf("failed to build TransferChecked: %w", err)
|
||||
}
|
||||
|
||||
// Build transaction with feePayer as payer (matches Python SDK)
|
||||
tx, err := solana.NewTransaction(
|
||||
[]solana.Instruction{setLimitIx, setPriceIx, transferIx},
|
||||
recentBlockhash,
|
||||
@@ -245,9 +246,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr
|
||||
return "", fmt.Errorf("failed to build transaction: %w", err)
|
||||
}
|
||||
|
||||
// Partial sign: user signs; fee_payer (CDP) co-signs on server side
|
||||
// The transaction has 2 signers: [feePayer (index 0), owner (index 1)]
|
||||
// We sign only our index (owner).
|
||||
_, err = tx.Sign(func(key solana.PublicKey) *solana.PrivateKey {
|
||||
if key.Equals(ownerPubkey) {
|
||||
return &c.keypair
|
||||
@@ -258,7 +256,6 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr
|
||||
return "", fmt.Errorf("failed to sign transaction: %w", err)
|
||||
}
|
||||
|
||||
// Serialize transaction
|
||||
txBytes, err := tx.MarshalBinary()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to serialize transaction: %w", err)
|
||||
@@ -267,11 +264,11 @@ func (c *BlockRunSolClient) buildSolanaTransferTx(recipient, feePayer, amountStr
|
||||
return base64.StdEncoding.EncodeToString(txBytes), nil
|
||||
}
|
||||
|
||||
// buildUrl returns the full BlockRun Solana endpoint URL.
|
||||
func (c *BlockRunSolClient) buildUrl() string {
|
||||
// BuildUrl returns the full BlockRun Solana endpoint URL.
|
||||
func (c *BlockRunSolClient) BuildUrl() string {
|
||||
return DefaultBlockRunSolURL + BlockRunChatEndpoint
|
||||
}
|
||||
|
||||
func (c *BlockRunSolClient) buildRequest(url string, jsonData []byte) (*http.Request, error) {
|
||||
return x402BuildRequest(url, jsonData)
|
||||
func (c *BlockRunSolClient) BuildRequest(url string, jsonData []byte) (*http.Request, error) {
|
||||
return X402BuildRequest(url, jsonData)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package mcp
|
||||
package payment
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
@@ -6,11 +6,13 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
|
||||
"nofx/mcp"
|
||||
"nofx/mcp/provider"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderClaw402 = "claw402"
|
||||
DefaultClaw402URL = "https://claw402.ai"
|
||||
DefaultClaw402URL = "https://claw402.ai"
|
||||
DefaultClaw402Model = "deepseek"
|
||||
)
|
||||
|
||||
@@ -39,35 +41,42 @@ var claw402ModelEndpoints = map[string]string{
|
||||
"kimi-k2.5": "/api/v1/ai/kimi/chat/k2.5",
|
||||
}
|
||||
|
||||
func init() {
|
||||
mcp.RegisterProvider(mcp.ProviderClaw402, func(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
return NewClaw402ClientWithOptions(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
// Claw402Client implements AIClient using claw402.ai's x402 v2 USDC payment gateway.
|
||||
// Reuses the same EIP-712 signing as BlockRunBaseClient (same Base chain + USDC contract).
|
||||
// When the selected model routes to an Anthropic endpoint, it automatically uses
|
||||
// the Anthropic wire format for requests and responses (via an internal ClaudeClient).
|
||||
type Claw402Client struct {
|
||||
*Client
|
||||
*mcp.Client
|
||||
privateKey *ecdsa.PrivateKey
|
||||
claudeProxy *ClaudeClient // non-nil when endpoint is /anthropic/
|
||||
claudeProxy *provider.ClaudeClient // non-nil when endpoint is /anthropic/
|
||||
}
|
||||
|
||||
func (c *Claw402Client) BaseClient() *mcp.Client { return c.Client }
|
||||
|
||||
// NewClaw402Client creates a claw402 client (backward compatible).
|
||||
func NewClaw402Client() AIClient {
|
||||
func NewClaw402Client() mcp.AIClient {
|
||||
return NewClaw402ClientWithOptions()
|
||||
}
|
||||
|
||||
// NewClaw402ClientWithOptions creates a claw402 client with options.
|
||||
func NewClaw402ClientWithOptions(opts ...ClientOption) AIClient {
|
||||
baseOpts := []ClientOption{
|
||||
WithProvider(ProviderClaw402),
|
||||
WithModel(DefaultClaw402Model),
|
||||
WithBaseURL(DefaultClaw402URL),
|
||||
func NewClaw402ClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
baseOpts := []mcp.ClientOption{
|
||||
mcp.WithProvider(mcp.ProviderClaw402),
|
||||
mcp.WithModel(DefaultClaw402Model),
|
||||
mcp.WithBaseURL(DefaultClaw402URL),
|
||||
}
|
||||
allOpts := append(baseOpts, opts...)
|
||||
baseClient := NewClient(allOpts...).(*Client)
|
||||
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
|
||||
baseClient.UseFullURL = true
|
||||
baseClient.BaseURL = DefaultClaw402URL + claw402ModelEndpoints[DefaultClaw402Model]
|
||||
|
||||
c := &Claw402Client{Client: baseClient}
|
||||
baseClient.hooks = c
|
||||
baseClient.Hooks = c
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -76,12 +85,12 @@ func (c *Claw402Client) SetAPIKey(apiKey string, _ string, customModel string) {
|
||||
hexKey := strings.TrimPrefix(apiKey, "0x")
|
||||
privKey, err := crypto.HexToECDSA(hexKey)
|
||||
if err != nil {
|
||||
c.logger.Warnf("⚠️ [MCP] Claw402: invalid private key: %v", err)
|
||||
c.Log.Warnf("⚠️ [MCP] Claw402: invalid private key: %v", err)
|
||||
} else {
|
||||
c.privateKey = privKey
|
||||
c.APIKey = apiKey
|
||||
addr := crypto.PubkeyToAddress(privKey.PublicKey).Hex()
|
||||
c.logger.Infof("🔧 [MCP] Claw402 wallet: %s", addr)
|
||||
c.Log.Infof("🔧 [MCP] Claw402 wallet: %s", addr)
|
||||
}
|
||||
if customModel != "" {
|
||||
c.Model = customModel
|
||||
@@ -91,11 +100,11 @@ func (c *Claw402Client) SetAPIKey(apiKey string, _ string, customModel string) {
|
||||
|
||||
// Anthropic endpoints need different wire format (Messages API)
|
||||
if strings.Contains(endpoint, "/anthropic/") {
|
||||
c.claudeProxy = &ClaudeClient{Client: c.Client}
|
||||
c.logger.Infof("🔧 [MCP] Claw402 model: %s → %s (Anthropic format)", c.Model, endpoint)
|
||||
c.claudeProxy = &provider.ClaudeClient{Client: c.Client}
|
||||
c.Log.Infof("🔧 [MCP] Claw402 model: %s → %s (Anthropic format)", c.Model, endpoint)
|
||||
} else {
|
||||
c.claudeProxy = nil
|
||||
c.logger.Infof("🔧 [MCP] Claw402 model: %s → %s", c.Model, endpoint)
|
||||
c.Log.Infof("🔧 [MCP] Claw402 model: %s → %s", c.Model, endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,56 +120,56 @@ func (c *Claw402Client) resolveEndpoint() string {
|
||||
return claw402ModelEndpoints[DefaultClaw402Model]
|
||||
}
|
||||
|
||||
func (c *Claw402Client) setAuthHeader(h http.Header) { x402SetAuthHeader(h) }
|
||||
func (c *Claw402Client) SetAuthHeader(h http.Header) { X402SetAuthHeader(h) }
|
||||
|
||||
func (c *Claw402Client) call(systemPrompt, userPrompt string) (string, error) {
|
||||
return x402Call(c.Client, c.signPayment, "Claw402", systemPrompt, userPrompt)
|
||||
func (c *Claw402Client) Call(systemPrompt, userPrompt string) (string, error) {
|
||||
return X402Call(c.Client, c.signPayment, "Claw402", systemPrompt, userPrompt)
|
||||
}
|
||||
|
||||
func (c *Claw402Client) CallWithRequestFull(req *Request) (*LLMResponse, error) {
|
||||
return x402CallFull(c.Client, c.signPayment, "Claw402", req)
|
||||
func (c *Claw402Client) CallWithRequestFull(req *mcp.Request) (*mcp.LLMResponse, error) {
|
||||
return X402CallFull(c.Client, c.signPayment, "Claw402", req)
|
||||
}
|
||||
|
||||
// signPayment signs x402 v2 EIP-712 payment (same Base chain + USDC as BlockRunBase).
|
||||
func (c *Claw402Client) signPayment(paymentHeaderB64 string) (string, error) {
|
||||
return signBasePaymentHeader(c.privateKey, paymentHeaderB64, "Claw402")
|
||||
return SignBasePaymentHeader(c.privateKey, paymentHeaderB64, "Claw402")
|
||||
}
|
||||
|
||||
// ── Format overrides for Anthropic endpoints ─────────────────────────────────
|
||||
|
||||
func (c *Claw402Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
|
||||
func (c *Claw402Client) BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
|
||||
if c.claudeProxy != nil {
|
||||
return c.claudeProxy.buildMCPRequestBody(systemPrompt, userPrompt)
|
||||
return c.claudeProxy.BuildMCPRequestBody(systemPrompt, userPrompt)
|
||||
}
|
||||
return c.Client.buildMCPRequestBody(systemPrompt, userPrompt)
|
||||
return c.Client.BuildMCPRequestBody(systemPrompt, userPrompt)
|
||||
}
|
||||
|
||||
func (c *Claw402Client) buildRequestBodyFromRequest(req *Request) map[string]any {
|
||||
func (c *Claw402Client) BuildRequestBodyFromRequest(req *mcp.Request) map[string]any {
|
||||
if c.claudeProxy != nil {
|
||||
return c.claudeProxy.buildRequestBodyFromRequest(req)
|
||||
return c.claudeProxy.BuildRequestBodyFromRequest(req)
|
||||
}
|
||||
return c.Client.buildRequestBodyFromRequest(req)
|
||||
return c.Client.BuildRequestBodyFromRequest(req)
|
||||
}
|
||||
|
||||
func (c *Claw402Client) parseMCPResponse(body []byte) (string, error) {
|
||||
func (c *Claw402Client) ParseMCPResponse(body []byte) (string, error) {
|
||||
if c.claudeProxy != nil {
|
||||
return c.claudeProxy.parseMCPResponse(body)
|
||||
return c.claudeProxy.ParseMCPResponse(body)
|
||||
}
|
||||
return c.Client.parseMCPResponse(body)
|
||||
return c.Client.ParseMCPResponse(body)
|
||||
}
|
||||
|
||||
func (c *Claw402Client) parseMCPResponseFull(body []byte) (*LLMResponse, error) {
|
||||
func (c *Claw402Client) ParseMCPResponseFull(body []byte) (*mcp.LLMResponse, error) {
|
||||
if c.claudeProxy != nil {
|
||||
return c.claudeProxy.parseMCPResponseFull(body)
|
||||
return c.claudeProxy.ParseMCPResponseFull(body)
|
||||
}
|
||||
return c.Client.parseMCPResponseFull(body)
|
||||
return c.Client.ParseMCPResponseFull(body)
|
||||
}
|
||||
|
||||
// buildUrl returns the full claw402 endpoint URL.
|
||||
func (c *Claw402Client) buildUrl() string {
|
||||
// BuildUrl returns the full claw402 endpoint URL.
|
||||
func (c *Claw402Client) BuildUrl() string {
|
||||
return c.BaseURL
|
||||
}
|
||||
|
||||
func (c *Claw402Client) buildRequest(url string, jsonData []byte) (*http.Request, error) {
|
||||
return x402BuildRequest(url, jsonData)
|
||||
func (c *Claw402Client) BuildRequest(url string, jsonData []byte) (*http.Request, error) {
|
||||
return X402BuildRequest(url, jsonData)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package mcp
|
||||
package payment
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -11,28 +11,30 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
// x402MaxPaymentRetries is the number of retries for 5xx errors on the
|
||||
// X402MaxPaymentRetries is the number of retries for 5xx errors on the
|
||||
// payment-signed request. The same payment signature is reused (no double-charge).
|
||||
x402MaxPaymentRetries = 3
|
||||
X402MaxPaymentRetries = 3
|
||||
|
||||
// x402RetryBaseWait is the base wait between payment retry attempts.
|
||||
x402RetryBaseWait = 3 * time.Second
|
||||
// X402RetryBaseWait is the base wait between payment retry attempts.
|
||||
X402RetryBaseWait = 3 * time.Second
|
||||
)
|
||||
|
||||
// ── Shared x402 types ────────────────────────────────────────────────────────
|
||||
|
||||
// x402v2PaymentRequired is the structure of the Payment-Required header (x402 v2).
|
||||
type x402v2PaymentRequired struct {
|
||||
// X402v2PaymentRequired is the structure of the Payment-Required header (x402 v2).
|
||||
type X402v2PaymentRequired struct {
|
||||
X402Version int `json:"x402Version"`
|
||||
Accepts []x402AcceptOption `json:"accepts"`
|
||||
Resource *x402Resource `json:"resource"`
|
||||
Accepts []X402AcceptOption `json:"accepts"`
|
||||
Resource *X402Resource `json:"resource"`
|
||||
}
|
||||
|
||||
// x402AcceptOption is a payment option from the x402 v2 header.
|
||||
type x402AcceptOption struct {
|
||||
// X402AcceptOption is a payment option from the x402 v2 header.
|
||||
type X402AcceptOption struct {
|
||||
Scheme string `json:"scheme"`
|
||||
Network string `json:"network"`
|
||||
Amount string `json:"amount"`
|
||||
@@ -42,22 +44,22 @@ type x402AcceptOption struct {
|
||||
Extra map[string]string `json:"extra"`
|
||||
}
|
||||
|
||||
// x402Resource describes the resource being paid for.
|
||||
type x402Resource struct {
|
||||
// X402Resource describes the resource being paid for.
|
||||
type X402Resource struct {
|
||||
URL string `json:"url"`
|
||||
Description string `json:"description"`
|
||||
MimeType string `json:"mimeType"`
|
||||
}
|
||||
|
||||
// x402SignFunc is a callback that signs an x402 payment header and returns the
|
||||
// X402SignFunc is a callback that signs an x402 payment header and returns the
|
||||
// base64-encoded payment signature.
|
||||
type x402SignFunc func(paymentHeaderB64 string) (string, error)
|
||||
type X402SignFunc func(paymentHeaderB64 string) (string, error)
|
||||
|
||||
// ── Shared x402 helpers ──────────────────────────────────────────────────────
|
||||
|
||||
// x402DecodeHeader decodes a base64-encoded x402 Payment-Required header,
|
||||
// X402DecodeHeader decodes a base64-encoded x402 Payment-Required header,
|
||||
// trying RawStdEncoding first then StdEncoding as fallback.
|
||||
func x402DecodeHeader(b64 string) ([]byte, error) {
|
||||
func X402DecodeHeader(b64 string) ([]byte, error) {
|
||||
decoded, err := base64.RawStdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
decoded, err = base64.StdEncoding.DecodeString(b64)
|
||||
@@ -68,19 +70,19 @@ func x402DecodeHeader(b64 string) ([]byte, error) {
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// signBasePaymentHeader decodes a base64 x402 header, parses it, and signs with
|
||||
// SignBasePaymentHeader decodes a base64 x402 header, parses it, and signs with
|
||||
// EIP-712 (USDC TransferWithAuthorization). Shared by BlockRunBase and Claw402.
|
||||
func signBasePaymentHeader(privateKey *ecdsa.PrivateKey, paymentHeaderB64 string, providerName string) (string, error) {
|
||||
func SignBasePaymentHeader(privateKey *ecdsa.PrivateKey, paymentHeaderB64 string, providerName string) (string, error) {
|
||||
if privateKey == nil {
|
||||
return "", fmt.Errorf("no private key set for %s wallet", providerName)
|
||||
}
|
||||
|
||||
decoded, err := x402DecodeHeader(paymentHeaderB64)
|
||||
decoded, err := X402DecodeHeader(paymentHeaderB64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var req x402v2PaymentRequired
|
||||
var req X402v2PaymentRequired
|
||||
if err := json.Unmarshal(decoded, &req); err != nil {
|
||||
return "", fmt.Errorf("failed to parse x402 v2 payment header: %w", err)
|
||||
}
|
||||
@@ -89,19 +91,16 @@ func signBasePaymentHeader(privateKey *ecdsa.PrivateKey, paymentHeaderB64 string
|
||||
}
|
||||
|
||||
senderAddr := crypto.PubkeyToAddress(privateKey.PublicKey).Hex()
|
||||
return signX402Payment(privateKey, senderAddr, req.Accepts[0], req.Resource)
|
||||
return SignX402Payment(privateKey, senderAddr, req.Accepts[0], req.Resource)
|
||||
}
|
||||
|
||||
// doX402Request executes an HTTP request and handles the x402 v2 payment flow.
|
||||
// On a 402 response it reads the Payment-Required (or X-Payment-Required) header,
|
||||
// signs via signFn, retries with Payment-Signature, and logs the Payment-Response
|
||||
// header (tx hash) on success.
|
||||
func doX402Request(
|
||||
// DoX402Request executes an HTTP request and handles the x402 v2 payment flow.
|
||||
func DoX402Request(
|
||||
httpClient *http.Client,
|
||||
buildReqFn func() (*http.Request, error),
|
||||
signFn x402SignFunc,
|
||||
signFn X402SignFunc,
|
||||
providerTag string,
|
||||
logger Logger,
|
||||
logger mcp.Logger,
|
||||
) ([]byte, error) {
|
||||
req, err := buildReqFn()
|
||||
if err != nil {
|
||||
@@ -133,10 +132,9 @@ func doX402Request(
|
||||
}
|
||||
|
||||
// Retry loop for 5xx errors on the payment-signed request.
|
||||
// Reuses the same payment signature — no double-charge.
|
||||
var lastBody []byte
|
||||
var lastStatus int
|
||||
for attempt := 1; attempt <= x402MaxPaymentRetries; attempt++ {
|
||||
for attempt := 1; attempt <= X402MaxPaymentRetries; attempt++ {
|
||||
req2, err := buildReqFn()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build retry request: %w", err)
|
||||
@@ -146,10 +144,10 @@ func doX402Request(
|
||||
|
||||
resp2, err := httpClient.Do(req2)
|
||||
if err != nil {
|
||||
if attempt < x402MaxPaymentRetries {
|
||||
wait := x402RetryBaseWait * time.Duration(attempt)
|
||||
if attempt < X402MaxPaymentRetries {
|
||||
wait := X402RetryBaseWait * time.Duration(attempt)
|
||||
logger.Warnf("⚠️ [%s] Payment request failed: %v, retrying in %v (%d/%d)...",
|
||||
providerTag, err, wait, attempt+1, x402MaxPaymentRetries)
|
||||
providerTag, err, wait, attempt+1, X402MaxPaymentRetries)
|
||||
time.Sleep(wait)
|
||||
continue
|
||||
}
|
||||
@@ -175,11 +173,11 @@ func doX402Request(
|
||||
lastBody = body2
|
||||
lastStatus = resp2.StatusCode
|
||||
|
||||
// Retry on 5xx server errors (502, 503, 520, etc.)
|
||||
if resp2.StatusCode >= 500 && attempt < x402MaxPaymentRetries {
|
||||
wait := x402RetryBaseWait * time.Duration(attempt)
|
||||
// Retry on 5xx server errors
|
||||
if resp2.StatusCode >= 500 && attempt < X402MaxPaymentRetries {
|
||||
wait := X402RetryBaseWait * time.Duration(attempt)
|
||||
logger.Warnf("⚠️ [%s] Server error (status %d), retrying in %v (%d/%d)...",
|
||||
providerTag, resp2.StatusCode, wait, attempt+1, x402MaxPaymentRetries)
|
||||
providerTag, resp2.StatusCode, wait, attempt+1, X402MaxPaymentRetries)
|
||||
time.Sleep(wait)
|
||||
continue
|
||||
}
|
||||
@@ -201,8 +199,8 @@ func doX402Request(
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// x402BuildRequest creates a POST request with Content-Type but no auth header.
|
||||
func x402BuildRequest(url string, jsonData []byte) (*http.Request, error) {
|
||||
// X402BuildRequest creates a POST request with Content-Type but no auth header.
|
||||
func X402BuildRequest(url string, jsonData []byte) (*http.Request, error) {
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fail to build request: %w", err)
|
||||
@@ -211,30 +209,30 @@ func x402BuildRequest(url string, jsonData []byte) (*http.Request, error) {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// x402SetAuthHeader is a no-op — x402 providers authenticate via payment signing.
|
||||
func x402SetAuthHeader(_ http.Header) {}
|
||||
// X402SetAuthHeader is a no-op — x402 providers authenticate via payment signing.
|
||||
func X402SetAuthHeader(_ http.Header) {}
|
||||
|
||||
// x402Call handles the x402 payment flow for the simple CallWithMessages path.
|
||||
func x402Call(c *Client, signFn x402SignFunc, tag string, systemPrompt, userPrompt string) (string, error) {
|
||||
c.logger.Infof("📡 [%s] Request AI Server: %s", tag, c.BaseURL)
|
||||
// X402Call handles the x402 payment flow for the simple CallWithMessages path.
|
||||
func X402Call(c *mcp.Client, signFn X402SignFunc, tag string, systemPrompt, userPrompt string) (string, error) {
|
||||
c.Log.Infof("📡 [%s] Request AI Server: %s", tag, c.BaseURL)
|
||||
|
||||
requestBody := c.hooks.buildMCPRequestBody(systemPrompt, userPrompt)
|
||||
jsonData, err := c.hooks.marshalRequestBody(requestBody)
|
||||
requestBody := c.Hooks.BuildMCPRequestBody(systemPrompt, userPrompt)
|
||||
jsonData, err := c.Hooks.MarshalRequestBody(requestBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
body, err := doX402Request(c.httpClient, func() (*http.Request, error) {
|
||||
return c.hooks.buildRequest(c.hooks.buildUrl(), jsonData)
|
||||
}, signFn, tag, c.logger)
|
||||
body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) {
|
||||
return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData)
|
||||
}, signFn, tag, c.Log)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return c.hooks.parseMCPResponse(body)
|
||||
return c.Hooks.ParseMCPResponse(body)
|
||||
}
|
||||
|
||||
// x402CallFull handles the x402 payment flow for the advanced Request path.
|
||||
func x402CallFull(c *Client, signFn x402SignFunc, tag string, req *Request) (*LLMResponse, error) {
|
||||
// X402CallFull handles the x402 payment flow for the advanced Request path.
|
||||
func X402CallFull(c *mcp.Client, signFn X402SignFunc, tag string, req *mcp.Request) (*mcp.LLMResponse, error) {
|
||||
if c.APIKey == "" {
|
||||
return nil, fmt.Errorf("AI API key not set, please call SetAPIKey first")
|
||||
}
|
||||
@@ -242,19 +240,19 @@ func x402CallFull(c *Client, signFn x402SignFunc, tag string, req *Request) (*LL
|
||||
req.Model = c.Model
|
||||
}
|
||||
|
||||
c.logger.Infof("📡 [%s] Request AI (full): %s", tag, c.BaseURL)
|
||||
c.Log.Infof("📡 [%s] Request AI (full): %s", tag, c.BaseURL)
|
||||
|
||||
requestBody := c.hooks.buildRequestBodyFromRequest(req)
|
||||
jsonData, err := c.hooks.marshalRequestBody(requestBody)
|
||||
requestBody := c.Hooks.BuildRequestBodyFromRequest(req)
|
||||
jsonData, err := c.Hooks.MarshalRequestBody(requestBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := doX402Request(c.httpClient, func() (*http.Request, error) {
|
||||
return c.hooks.buildRequest(c.hooks.buildUrl(), jsonData)
|
||||
}, signFn, tag, c.logger)
|
||||
body, err := DoX402Request(c.HTTPClient, func() (*http.Request, error) {
|
||||
return c.Hooks.BuildRequest(c.Hooks.BuildUrl(), jsonData)
|
||||
}, signFn, tag, c.Log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.hooks.parseMCPResponseFull(body)
|
||||
return c.Hooks.ParseMCPResponseFull(body)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// Package mcp — ClaudeClient implements the Anthropic Messages API.
|
||||
// Package provider — ClaudeClient implements the Anthropic Messages API.
|
||||
//
|
||||
// Wire-format differences from the OpenAI-compatible base Client:
|
||||
//
|
||||
@@ -14,42 +14,51 @@
|
||||
// │ Tool result │ role=tool + tool_call_id │ role=user content[tool_result] │
|
||||
// │ Max tokens │ max_tokens │ max_tokens (same) │
|
||||
// └─────────────────────┴───────────────────────────┴─────────────────────────────────┘
|
||||
package mcp
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderClaude = "claude"
|
||||
DefaultClaudeBaseURL = "https://api.anthropic.com/v1"
|
||||
DefaultClaudeModel = "claude-opus-4-6"
|
||||
)
|
||||
|
||||
func init() {
|
||||
mcp.RegisterProvider(mcp.ProviderClaude, func(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
return NewClaudeClientWithOptions(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
// ClaudeClient wraps the base Client and overrides the methods that differ
|
||||
// for the Anthropic Messages API. All other behaviour (retry, timeout,
|
||||
// logging) is inherited unchanged.
|
||||
type ClaudeClient struct {
|
||||
*Client
|
||||
*mcp.Client
|
||||
}
|
||||
|
||||
func (c *ClaudeClient) BaseClient() *mcp.Client { return c.Client }
|
||||
|
||||
// NewClaudeClient creates a ClaudeClient with default settings.
|
||||
func NewClaudeClient() AIClient {
|
||||
func NewClaudeClient() mcp.AIClient {
|
||||
return NewClaudeClientWithOptions()
|
||||
}
|
||||
|
||||
// NewClaudeClientWithOptions creates a ClaudeClient with optional overrides.
|
||||
func NewClaudeClientWithOptions(opts ...ClientOption) AIClient {
|
||||
baseClient := NewClient(append([]ClientOption{
|
||||
WithProvider(ProviderClaude),
|
||||
WithModel(DefaultClaudeModel),
|
||||
WithBaseURL(DefaultClaudeBaseURL),
|
||||
}, opts...)...).(*Client)
|
||||
func NewClaudeClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
baseClient := mcp.NewClient(append([]mcp.ClientOption{
|
||||
mcp.WithProvider(mcp.ProviderClaude),
|
||||
mcp.WithModel(DefaultClaudeModel),
|
||||
mcp.WithBaseURL(DefaultClaudeBaseURL),
|
||||
}, opts...)...).(*mcp.Client)
|
||||
|
||||
c := &ClaudeClient{Client: baseClient}
|
||||
baseClient.hooks = c // wire dynamic dispatch to ClaudeClient
|
||||
baseClient.Hooks = c // wire dynamic dispatch to ClaudeClient
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -59,32 +68,32 @@ func NewClaudeClientWithOptions(opts ...ClientOption) AIClient {
|
||||
func (c *ClaudeClient) SetAPIKey(apiKey, customURL, customModel string) {
|
||||
c.APIKey = apiKey
|
||||
if len(apiKey) > 8 {
|
||||
c.logger.Infof("🔧 [MCP] Claude API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
c.Log.Infof("🔧 [MCP] Claude API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
c.BaseURL = customURL
|
||||
c.logger.Infof("🔧 [MCP] Claude BaseURL: %s", customURL)
|
||||
c.Log.Infof("🔧 [MCP] Claude BaseURL: %s", customURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
c.Model = customModel
|
||||
c.logger.Infof("🔧 [MCP] Claude Model: %s", customModel)
|
||||
c.Log.Infof("🔧 [MCP] Claude Model: %s", customModel)
|
||||
}
|
||||
}
|
||||
|
||||
// setAuthHeader uses x-api-key instead of Authorization: Bearer.
|
||||
func (c *ClaudeClient) setAuthHeader(h http.Header) {
|
||||
// SetAuthHeader uses x-api-key instead of Authorization: Bearer.
|
||||
func (c *ClaudeClient) SetAuthHeader(h http.Header) {
|
||||
h.Set("x-api-key", c.APIKey)
|
||||
h.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
|
||||
// buildUrl targets /messages instead of /chat/completions.
|
||||
func (c *ClaudeClient) buildUrl() string {
|
||||
// BuildUrl targets /messages instead of /chat/completions.
|
||||
func (c *ClaudeClient) BuildUrl() string {
|
||||
return fmt.Sprintf("%s/messages", c.BaseURL)
|
||||
}
|
||||
|
||||
// buildMCPRequestBody builds the Anthropic wire format for the simple
|
||||
// BuildMCPRequestBody builds the Anthropic wire format for the simple
|
||||
// CallWithMessages path (no tool support).
|
||||
func (c *ClaudeClient) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
|
||||
func (c *ClaudeClient) BuildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
|
||||
return map[string]any{
|
||||
"model": c.Model,
|
||||
"max_tokens": c.MaxTokens,
|
||||
@@ -95,23 +104,12 @@ func (c *ClaudeClient) buildMCPRequestBody(systemPrompt, userPrompt string) map[
|
||||
}
|
||||
}
|
||||
|
||||
// buildRequestBodyFromRequest converts a *Request into the Anthropic Messages
|
||||
// API wire format. This is the key override that makes tool calling work
|
||||
// correctly with Claude.
|
||||
//
|
||||
// Conversions applied:
|
||||
//
|
||||
// - System messages are lifted to the top-level "system" field.
|
||||
// - Tool definitions: parameters → input_schema, wrapper removed.
|
||||
// - Assistant messages with ToolCalls → content[{type:tool_use,...}].
|
||||
// - Tool result messages (role=tool) → role=user with tool_result blocks.
|
||||
// Consecutive tool results are merged into a single user turn (Anthropic
|
||||
// requires strictly alternating user/assistant turns).
|
||||
// - tool_choice "auto"/"any" → {"type":"auto"/"any"} object.
|
||||
func (c *ClaudeClient) buildRequestBodyFromRequest(req *Request) map[string]any {
|
||||
// BuildRequestBodyFromRequest converts a *Request into the Anthropic Messages
|
||||
// API wire format.
|
||||
func (c *ClaudeClient) BuildRequestBodyFromRequest(req *mcp.Request) map[string]any {
|
||||
// ── 1. Separate system prompt from conversation messages ──────────────────
|
||||
var systemPrompt string
|
||||
var convMsgs []Message
|
||||
var convMsgs []mcp.Message
|
||||
for _, m := range req.Messages {
|
||||
if m.Role == "system" {
|
||||
systemPrompt = m.Content
|
||||
@@ -121,7 +119,7 @@ func (c *ClaudeClient) buildRequestBodyFromRequest(req *Request) map[string]any
|
||||
}
|
||||
|
||||
// ── 2. Convert messages to Anthropic format ───────────────────────────────
|
||||
anthropicMsgs := convertMessagesToAnthropic(convMsgs)
|
||||
anthropicMsgs := ConvertMessagesToAnthropic(convMsgs)
|
||||
|
||||
// ── 3. Convert tool definitions (parameters → input_schema) ──────────────
|
||||
var anthropicTools []map[string]any
|
||||
@@ -162,16 +160,9 @@ func (c *ClaudeClient) buildRequestBodyFromRequest(req *Request) map[string]any
|
||||
return body
|
||||
}
|
||||
|
||||
// convertMessagesToAnthropic translates from the OpenAI-shaped mcp.Message
|
||||
// ConvertMessagesToAnthropic translates from the OpenAI-shaped mcp.Message
|
||||
// slice to Anthropic's messages array.
|
||||
//
|
||||
// Rules:
|
||||
// 1. role=assistant + ToolCalls → role=assistant, content=[tool_use, ...]
|
||||
// 2. role=tool (result) → role=user, content=[tool_result, ...]
|
||||
// Consecutive tool-result messages are merged into one user turn so the
|
||||
// conversation always alternates user/assistant.
|
||||
// 3. All other messages → {role, content} as-is.
|
||||
func convertMessagesToAnthropic(msgs []Message) []map[string]any {
|
||||
func ConvertMessagesToAnthropic(msgs []mcp.Message) []map[string]any {
|
||||
var out []map[string]any
|
||||
|
||||
for i := 0; i < len(msgs); {
|
||||
@@ -232,29 +223,18 @@ func convertMessagesToAnthropic(msgs []Message) []map[string]any {
|
||||
|
||||
// ── Response parsers ──────────────────────────────────────────────────────────
|
||||
|
||||
// parseMCPResponse extracts the plain-text reply from an Anthropic response.
|
||||
// Used by CallWithMessages / CallWithRequest (no tool support).
|
||||
func (c *ClaudeClient) parseMCPResponse(body []byte) (string, error) {
|
||||
r, err := c.parseMCPResponseFull(body)
|
||||
// ParseMCPResponse extracts the plain-text reply from an Anthropic response.
|
||||
func (c *ClaudeClient) ParseMCPResponse(body []byte) (string, error) {
|
||||
r, err := c.ParseMCPResponseFull(body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return r.Content, nil
|
||||
}
|
||||
|
||||
// parseMCPResponseFull extracts both text and tool calls from an Anthropic
|
||||
// ParseMCPResponseFull extracts both text and tool calls from an Anthropic
|
||||
// response envelope.
|
||||
//
|
||||
// Anthropic response shape:
|
||||
//
|
||||
// {
|
||||
// "content": [
|
||||
// {"type": "text", "text": "..."},
|
||||
// {"type": "tool_use", "id": "...", "name": "...", "input": {...}}
|
||||
// ],
|
||||
// "stop_reason": "tool_use" | "end_turn"
|
||||
// }
|
||||
func (c *ClaudeClient) parseMCPResponseFull(body []byte) (*LLMResponse, error) {
|
||||
func (c *ClaudeClient) ParseMCPResponseFull(body []byte) (*mcp.LLMResponse, error) {
|
||||
var raw struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
@@ -281,8 +261,8 @@ func (c *ClaudeClient) parseMCPResponseFull(body []byte) (*LLMResponse, error) {
|
||||
}
|
||||
|
||||
total := raw.Usage.InputTokens + raw.Usage.OutputTokens
|
||||
if TokenUsageCallback != nil && total > 0 {
|
||||
TokenUsageCallback(TokenUsage{
|
||||
if mcp.TokenUsageCallback != nil && total > 0 {
|
||||
mcp.TokenUsageCallback(mcp.TokenUsage{
|
||||
Provider: c.Provider,
|
||||
Model: c.Model,
|
||||
PromptTokens: raw.Usage.InputTokens,
|
||||
@@ -291,7 +271,7 @@ func (c *ClaudeClient) parseMCPResponseFull(body []byte) (*LLMResponse, error) {
|
||||
})
|
||||
}
|
||||
|
||||
result := &LLMResponse{}
|
||||
result := &mcp.LLMResponse{}
|
||||
for _, block := range raw.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
@@ -304,10 +284,10 @@ func (c *ClaudeClient) parseMCPResponseFull(body []byte) (*LLMResponse, error) {
|
||||
if err != nil {
|
||||
argsJSON = []byte("{}")
|
||||
}
|
||||
result.ToolCalls = append(result.ToolCalls, ToolCall{
|
||||
result.ToolCalls = append(result.ToolCalls, mcp.ToolCall{
|
||||
ID: block.ID,
|
||||
Type: "function",
|
||||
Function: ToolCallFunction{
|
||||
Function: mcp.ToolCallFunction{
|
||||
Name: block.Name,
|
||||
Arguments: string(argsJSON),
|
||||
},
|
||||
@@ -0,0 +1,69 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
func init() {
|
||||
mcp.RegisterProvider(mcp.ProviderDeepSeek, func(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
return NewDeepSeekClientWithOptions(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
type DeepSeekClient struct {
|
||||
*mcp.Client
|
||||
}
|
||||
|
||||
func (c *DeepSeekClient) BaseClient() *mcp.Client { return c.Client }
|
||||
|
||||
// NewDeepSeekClient creates DeepSeek client (backward compatible)
|
||||
//
|
||||
// Deprecated: Recommend using NewDeepSeekClientWithOptions for better flexibility
|
||||
func NewDeepSeekClient() mcp.AIClient {
|
||||
return NewDeepSeekClientWithOptions()
|
||||
}
|
||||
|
||||
// NewDeepSeekClientWithOptions creates DeepSeek client (supports options pattern)
|
||||
func NewDeepSeekClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
deepseekOpts := []mcp.ClientOption{
|
||||
mcp.WithProvider(mcp.ProviderDeepSeek),
|
||||
mcp.WithModel(mcp.DefaultDeepSeekModel),
|
||||
mcp.WithBaseURL(mcp.DefaultDeepSeekBaseURL),
|
||||
}
|
||||
|
||||
allOpts := append(deepseekOpts, opts...)
|
||||
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
|
||||
|
||||
dsClient := &DeepSeekClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
baseClient.Hooks = dsClient
|
||||
return dsClient
|
||||
}
|
||||
|
||||
func (dsClient *DeepSeekClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
dsClient.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
dsClient.Log.Infof("🔧 [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
dsClient.BaseURL = customURL
|
||||
dsClient.Log.Infof("🔧 [MCP] DeepSeek using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
dsClient.Log.Infof("🔧 [MCP] DeepSeek using default BaseURL: %s", dsClient.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
dsClient.Model = customModel
|
||||
dsClient.Log.Infof("🔧 [MCP] DeepSeek using custom Model: %s", customModel)
|
||||
} else {
|
||||
dsClient.Log.Infof("🔧 [MCP] DeepSeek using default Model: %s", dsClient.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func (dsClient *DeepSeekClient) SetAuthHeader(reqHeaders http.Header) {
|
||||
dsClient.Client.SetAuthHeader(reqHeaders)
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultGeminiBaseURL = "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
DefaultGeminiModel = "gemini-3-pro-preview"
|
||||
)
|
||||
|
||||
func init() {
|
||||
mcp.RegisterProvider(mcp.ProviderGemini, func(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
return NewGeminiClientWithOptions(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
type GeminiClient struct {
|
||||
*mcp.Client
|
||||
}
|
||||
|
||||
func (c *GeminiClient) BaseClient() *mcp.Client { return c.Client }
|
||||
|
||||
// NewGeminiClient creates Gemini client (backward compatible)
|
||||
func NewGeminiClient() mcp.AIClient {
|
||||
return NewGeminiClientWithOptions()
|
||||
}
|
||||
|
||||
// NewGeminiClientWithOptions creates Gemini client (supports options pattern)
|
||||
func NewGeminiClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
geminiOpts := []mcp.ClientOption{
|
||||
mcp.WithProvider(mcp.ProviderGemini),
|
||||
mcp.WithModel(DefaultGeminiModel),
|
||||
mcp.WithBaseURL(DefaultGeminiBaseURL),
|
||||
}
|
||||
|
||||
allOpts := append(geminiOpts, opts...)
|
||||
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
|
||||
|
||||
geminiClient := &GeminiClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
baseClient.Hooks = geminiClient
|
||||
return geminiClient
|
||||
}
|
||||
|
||||
func (c *GeminiClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
c.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
c.Log.Infof("🔧 [MCP] Gemini API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
c.BaseURL = customURL
|
||||
c.Log.Infof("🔧 [MCP] Gemini using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
c.Log.Infof("🔧 [MCP] Gemini using default BaseURL: %s", c.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
c.Model = customModel
|
||||
c.Log.Infof("🔧 [MCP] Gemini using custom Model: %s", customModel)
|
||||
} else {
|
||||
c.Log.Infof("🔧 [MCP] Gemini using default Model: %s", c.Model)
|
||||
}
|
||||
}
|
||||
|
||||
// Gemini OpenAI-compatible API uses standard Bearer auth
|
||||
func (c *GeminiClient) SetAuthHeader(reqHeaders http.Header) {
|
||||
c.Client.SetAuthHeader(reqHeaders)
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultGrokBaseURL = "https://api.x.ai/v1"
|
||||
DefaultGrokModel = "grok-3-latest"
|
||||
)
|
||||
|
||||
func init() {
|
||||
mcp.RegisterProvider(mcp.ProviderGrok, func(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
return NewGrokClientWithOptions(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
type GrokClient struct {
|
||||
*mcp.Client
|
||||
}
|
||||
|
||||
func (c *GrokClient) BaseClient() *mcp.Client { return c.Client }
|
||||
|
||||
// NewGrokClient creates Grok client (backward compatible)
|
||||
func NewGrokClient() mcp.AIClient {
|
||||
return NewGrokClientWithOptions()
|
||||
}
|
||||
|
||||
// NewGrokClientWithOptions creates Grok client (supports options pattern)
|
||||
func NewGrokClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
grokOpts := []mcp.ClientOption{
|
||||
mcp.WithProvider(mcp.ProviderGrok),
|
||||
mcp.WithModel(DefaultGrokModel),
|
||||
mcp.WithBaseURL(DefaultGrokBaseURL),
|
||||
}
|
||||
|
||||
allOpts := append(grokOpts, opts...)
|
||||
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
|
||||
|
||||
grokClient := &GrokClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
baseClient.Hooks = grokClient
|
||||
return grokClient
|
||||
}
|
||||
|
||||
func (c *GrokClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
c.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
c.Log.Infof("🔧 [MCP] Grok API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
c.BaseURL = customURL
|
||||
c.Log.Infof("🔧 [MCP] Grok using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
c.Log.Infof("🔧 [MCP] Grok using default BaseURL: %s", c.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
c.Model = customModel
|
||||
c.Log.Infof("🔧 [MCP] Grok using custom Model: %s", customModel)
|
||||
} else {
|
||||
c.Log.Infof("🔧 [MCP] Grok using default Model: %s", c.Model)
|
||||
}
|
||||
}
|
||||
|
||||
// Grok uses standard OpenAI-compatible API with Bearer auth
|
||||
func (c *GrokClient) SetAuthHeader(reqHeaders http.Header) {
|
||||
c.Client.SetAuthHeader(reqHeaders)
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultKimiBaseURL = "https://api.moonshot.ai/v1" // Global endpoint (use api.moonshot.cn for China)
|
||||
DefaultKimiModel = "moonshot-v1-auto"
|
||||
)
|
||||
|
||||
func init() {
|
||||
mcp.RegisterProvider(mcp.ProviderKimi, func(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
return NewKimiClientWithOptions(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
type KimiClient struct {
|
||||
*mcp.Client
|
||||
}
|
||||
|
||||
func (c *KimiClient) BaseClient() *mcp.Client { return c.Client }
|
||||
|
||||
// NewKimiClient creates Kimi (Moonshot) client (backward compatible)
|
||||
func NewKimiClient() mcp.AIClient {
|
||||
return NewKimiClientWithOptions()
|
||||
}
|
||||
|
||||
// NewKimiClientWithOptions creates Kimi client (supports options pattern)
|
||||
func NewKimiClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
kimiOpts := []mcp.ClientOption{
|
||||
mcp.WithProvider(mcp.ProviderKimi),
|
||||
mcp.WithModel(DefaultKimiModel),
|
||||
mcp.WithBaseURL(DefaultKimiBaseURL),
|
||||
}
|
||||
|
||||
allOpts := append(kimiOpts, opts...)
|
||||
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
|
||||
|
||||
kimiClient := &KimiClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
baseClient.Hooks = kimiClient
|
||||
return kimiClient
|
||||
}
|
||||
|
||||
func (c *KimiClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
c.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
c.Log.Infof("🔧 [MCP] Kimi API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
c.BaseURL = customURL
|
||||
c.Log.Infof("🔧 [MCP] Kimi using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
c.Log.Infof("🔧 [MCP] Kimi using default BaseURL: %s", c.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
c.Model = customModel
|
||||
c.Log.Infof("🔧 [MCP] Kimi using custom Model: %s", customModel)
|
||||
} else {
|
||||
c.Log.Infof("🔧 [MCP] Kimi using default Model: %s", c.Model)
|
||||
}
|
||||
}
|
||||
|
||||
// Kimi uses standard OpenAI-compatible API
|
||||
func (c *KimiClient) SetAuthHeader(reqHeaders http.Header) {
|
||||
c.Client.SetAuthHeader(reqHeaders)
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMiniMaxBaseURL = "https://api.minimax.io/v1"
|
||||
DefaultMiniMaxModel = "MiniMax-M2.5"
|
||||
)
|
||||
|
||||
func init() {
|
||||
mcp.RegisterProvider(mcp.ProviderMiniMax, func(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
return NewMiniMaxClientWithOptions(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
type MiniMaxClient struct {
|
||||
*mcp.Client
|
||||
}
|
||||
|
||||
func (c *MiniMaxClient) BaseClient() *mcp.Client { return c.Client }
|
||||
|
||||
// NewMiniMaxClient creates MiniMax client (backward compatible)
|
||||
func NewMiniMaxClient() mcp.AIClient {
|
||||
return NewMiniMaxClientWithOptions()
|
||||
}
|
||||
|
||||
// NewMiniMaxClientWithOptions creates MiniMax client (supports options pattern)
|
||||
func NewMiniMaxClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
minimaxOpts := []mcp.ClientOption{
|
||||
mcp.WithProvider(mcp.ProviderMiniMax),
|
||||
mcp.WithModel(DefaultMiniMaxModel),
|
||||
mcp.WithBaseURL(DefaultMiniMaxBaseURL),
|
||||
}
|
||||
|
||||
allOpts := append(minimaxOpts, opts...)
|
||||
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
|
||||
|
||||
minimaxClient := &MiniMaxClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
baseClient.Hooks = minimaxClient
|
||||
return minimaxClient
|
||||
}
|
||||
|
||||
func (c *MiniMaxClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
c.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
c.Log.Infof("🔧 [MCP] MiniMax API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
c.BaseURL = customURL
|
||||
c.Log.Infof("🔧 [MCP] MiniMax using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
c.Log.Infof("🔧 [MCP] MiniMax using default BaseURL: %s", c.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
c.Model = customModel
|
||||
c.Log.Infof("🔧 [MCP] MiniMax using custom Model: %s", customModel)
|
||||
} else {
|
||||
c.Log.Infof("🔧 [MCP] MiniMax using default Model: %s", c.Model)
|
||||
}
|
||||
}
|
||||
|
||||
// MiniMax uses standard OpenAI-compatible API with Bearer auth
|
||||
func (c *MiniMaxClient) SetAuthHeader(reqHeaders http.Header) {
|
||||
c.Client.SetAuthHeader(reqHeaders)
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultOpenAIBaseURL = "https://api.openai.com/v1"
|
||||
DefaultOpenAIModel = "gpt-5.4"
|
||||
)
|
||||
|
||||
func init() {
|
||||
mcp.RegisterProvider(mcp.ProviderOpenAI, func(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
return NewOpenAIClientWithOptions(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
type OpenAIClient struct {
|
||||
*mcp.Client
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) BaseClient() *mcp.Client { return c.Client }
|
||||
|
||||
// NewOpenAIClient creates OpenAI client (backward compatible)
|
||||
func NewOpenAIClient() mcp.AIClient {
|
||||
return NewOpenAIClientWithOptions()
|
||||
}
|
||||
|
||||
// NewOpenAIClientWithOptions creates OpenAI client (supports options pattern)
|
||||
func NewOpenAIClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
openaiOpts := []mcp.ClientOption{
|
||||
mcp.WithProvider(mcp.ProviderOpenAI),
|
||||
mcp.WithModel(DefaultOpenAIModel),
|
||||
mcp.WithBaseURL(DefaultOpenAIBaseURL),
|
||||
}
|
||||
|
||||
allOpts := append(openaiOpts, opts...)
|
||||
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
|
||||
|
||||
openaiClient := &OpenAIClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
baseClient.Hooks = openaiClient
|
||||
return openaiClient
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
c.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
c.Log.Infof("🔧 [MCP] OpenAI API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
c.BaseURL = customURL
|
||||
c.Log.Infof("🔧 [MCP] OpenAI using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
c.Log.Infof("🔧 [MCP] OpenAI using default BaseURL: %s", c.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
c.Model = customModel
|
||||
c.Log.Infof("🔧 [MCP] OpenAI using custom Model: %s", customModel)
|
||||
} else {
|
||||
c.Log.Infof("🔧 [MCP] OpenAI using default Model: %s", c.Model)
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI uses standard Bearer auth
|
||||
func (c *OpenAIClient) SetAuthHeader(reqHeaders http.Header) {
|
||||
c.Client.SetAuthHeader(reqHeaders)
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
func TestOptionsWithDeepSeekClient(t *testing.T) {
|
||||
logger := mcp.NewNoopLogger()
|
||||
|
||||
client := NewDeepSeekClientWithOptions(
|
||||
mcp.WithAPIKey("sk-deepseek-key"),
|
||||
mcp.WithLogger(logger),
|
||||
mcp.WithMaxTokens(5000),
|
||||
)
|
||||
|
||||
dsClient := client.(*DeepSeekClient)
|
||||
|
||||
// Verify DeepSeek default values
|
||||
if dsClient.Provider != mcp.ProviderDeepSeek {
|
||||
t.Error("Provider should be DeepSeek")
|
||||
}
|
||||
|
||||
if dsClient.BaseURL != mcp.DefaultDeepSeekBaseURL {
|
||||
t.Error("BaseURL should be DeepSeek default")
|
||||
}
|
||||
|
||||
if dsClient.Model != mcp.DefaultDeepSeekModel {
|
||||
t.Error("Model should be DeepSeek default")
|
||||
}
|
||||
|
||||
// Verify custom options
|
||||
if dsClient.APIKey != "sk-deepseek-key" {
|
||||
t.Error("APIKey should be set from options")
|
||||
}
|
||||
|
||||
if dsClient.Log != logger {
|
||||
t.Error("Log should be set from options")
|
||||
}
|
||||
|
||||
if dsClient.MaxTokens != 5000 {
|
||||
t.Error("MaxTokens should be 5000")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionsWithQwenClient(t *testing.T) {
|
||||
logger := mcp.NewNoopLogger()
|
||||
|
||||
client := NewQwenClientWithOptions(
|
||||
mcp.WithAPIKey("sk-qwen-key"),
|
||||
mcp.WithLogger(logger),
|
||||
mcp.WithMaxTokens(6000),
|
||||
)
|
||||
|
||||
qwenClient := client.(*QwenClient)
|
||||
|
||||
// Verify Qwen default values
|
||||
if qwenClient.Provider != mcp.ProviderQwen {
|
||||
t.Error("Provider should be Qwen")
|
||||
}
|
||||
|
||||
if qwenClient.BaseURL != mcp.DefaultQwenBaseURL {
|
||||
t.Error("BaseURL should be Qwen default")
|
||||
}
|
||||
|
||||
if qwenClient.Model != mcp.DefaultQwenModel {
|
||||
t.Error("Model should be Qwen default")
|
||||
}
|
||||
|
||||
// Verify custom options
|
||||
if qwenClient.APIKey != "sk-qwen-key" {
|
||||
t.Error("APIKey should be set from options")
|
||||
}
|
||||
|
||||
if qwenClient.Log != logger {
|
||||
t.Error("Log should be set from options")
|
||||
}
|
||||
|
||||
if qwenClient.MaxTokens != 6000 {
|
||||
t.Error("MaxTokens should be 6000")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"nofx/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
DefaultQwenModel = "qwen3-max"
|
||||
)
|
||||
|
||||
func init() {
|
||||
mcp.RegisterProvider(mcp.ProviderQwen, func(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
return NewQwenClientWithOptions(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
type QwenClient struct {
|
||||
*mcp.Client
|
||||
}
|
||||
|
||||
func (c *QwenClient) BaseClient() *mcp.Client { return c.Client }
|
||||
|
||||
// NewQwenClient creates Qwen client (backward compatible)
|
||||
//
|
||||
// Deprecated: Recommend using NewQwenClientWithOptions for better flexibility
|
||||
func NewQwenClient() mcp.AIClient {
|
||||
return NewQwenClientWithOptions()
|
||||
}
|
||||
|
||||
// NewQwenClientWithOptions creates Qwen client (supports options pattern)
|
||||
func NewQwenClientWithOptions(opts ...mcp.ClientOption) mcp.AIClient {
|
||||
qwenOpts := []mcp.ClientOption{
|
||||
mcp.WithProvider(mcp.ProviderQwen),
|
||||
mcp.WithModel(DefaultQwenModel),
|
||||
mcp.WithBaseURL(DefaultQwenBaseURL),
|
||||
}
|
||||
|
||||
allOpts := append(qwenOpts, opts...)
|
||||
baseClient := mcp.NewClient(allOpts...).(*mcp.Client)
|
||||
|
||||
qwenClient := &QwenClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
baseClient.Hooks = qwenClient
|
||||
return qwenClient
|
||||
}
|
||||
|
||||
func (qwenClient *QwenClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
qwenClient.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
qwenClient.Log.Infof("🔧 [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
qwenClient.BaseURL = customURL
|
||||
qwenClient.Log.Infof("🔧 [MCP] Qwen using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
qwenClient.Log.Infof("🔧 [MCP] Qwen using default BaseURL: %s", qwenClient.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
qwenClient.Model = customModel
|
||||
qwenClient.Log.Infof("🔧 [MCP] Qwen using custom Model: %s", customModel)
|
||||
} else {
|
||||
qwenClient.Log.Infof("🔧 [MCP] Qwen using default Model: %s", qwenClient.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func (qwenClient *QwenClient) SetAuthHeader(reqHeaders http.Header) {
|
||||
qwenClient.Client.SetAuthHeader(reqHeaders)
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package mcp
|
||||
|
||||
// Provider name constants — kept in the mcp package so that client.go can
|
||||
// reference them for default configuration without importing sub-packages.
|
||||
// Provider sub-packages re-use these same values.
|
||||
const (
|
||||
ProviderDeepSeek = "deepseek"
|
||||
ProviderOpenAI = "openai"
|
||||
ProviderClaude = "claude"
|
||||
ProviderQwen = "qwen"
|
||||
ProviderGemini = "gemini"
|
||||
ProviderGrok = "grok"
|
||||
ProviderKimi = "kimi"
|
||||
ProviderMiniMax = "minimax"
|
||||
|
||||
ProviderBlockRunBase = "blockrun-base"
|
||||
ProviderBlockRunSol = "blockrun-sol"
|
||||
ProviderClaw402 = "claw402"
|
||||
|
||||
// Default DeepSeek configuration (used as fallback in NewClient)
|
||||
DefaultDeepSeekBaseURL = "https://api.deepseek.com"
|
||||
DefaultDeepSeekModel = "deepseek-chat"
|
||||
|
||||
// Default Qwen configuration (used by WithQwenConfig convenience option)
|
||||
DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
DefaultQwenModel = "qwen3-max"
|
||||
|
||||
// Default MiniMax configuration (used by WithMiniMaxConfig convenience option)
|
||||
DefaultMiniMaxBaseURL = "https://api.minimax.io/v1"
|
||||
DefaultMiniMaxModel = "MiniMax-M2.5"
|
||||
)
|
||||
@@ -1,83 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderQwen = "qwen"
|
||||
DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
DefaultQwenModel = "qwen3-max"
|
||||
)
|
||||
|
||||
type QwenClient struct {
|
||||
*Client
|
||||
}
|
||||
|
||||
// NewQwenClient creates Qwen client (backward compatible)
|
||||
//
|
||||
// Deprecated: Recommend using NewQwenClientWithOptions for better flexibility
|
||||
func NewQwenClient() AIClient {
|
||||
return NewQwenClientWithOptions()
|
||||
}
|
||||
|
||||
// NewQwenClientWithOptions creates Qwen client (supports options pattern)
|
||||
//
|
||||
// Usage examples:
|
||||
// // Basic usage
|
||||
// client := mcp.NewQwenClientWithOptions()
|
||||
//
|
||||
// // Custom configuration
|
||||
// client := mcp.NewQwenClientWithOptions(
|
||||
// mcp.WithAPIKey("sk-xxx"),
|
||||
// mcp.WithLogger(customLogger),
|
||||
// mcp.WithTimeout(60*time.Second),
|
||||
// )
|
||||
func NewQwenClientWithOptions(opts ...ClientOption) AIClient {
|
||||
// 1. Create Qwen preset options
|
||||
qwenOpts := []ClientOption{
|
||||
WithProvider(ProviderQwen),
|
||||
WithModel(DefaultQwenModel),
|
||||
WithBaseURL(DefaultQwenBaseURL),
|
||||
}
|
||||
|
||||
// 2. Merge user options (user options have higher priority)
|
||||
allOpts := append(qwenOpts, opts...)
|
||||
|
||||
// 3. Create base client
|
||||
baseClient := NewClient(allOpts...).(*Client)
|
||||
|
||||
// 4. Create Qwen client
|
||||
qwenClient := &QwenClient{
|
||||
Client: baseClient,
|
||||
}
|
||||
|
||||
// 5. Set hooks to point to QwenClient (implement dynamic dispatch)
|
||||
baseClient.hooks = qwenClient
|
||||
|
||||
return qwenClient
|
||||
}
|
||||
|
||||
func (qwenClient *QwenClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
qwenClient.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
qwenClient.logger.Infof("🔧 [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
qwenClient.BaseURL = customURL
|
||||
qwenClient.logger.Infof("🔧 [MCP] Qwen using custom BaseURL: %s", customURL)
|
||||
} else {
|
||||
qwenClient.logger.Infof("🔧 [MCP] Qwen using default BaseURL: %s", qwenClient.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
qwenClient.Model = customModel
|
||||
qwenClient.logger.Infof("🔧 [MCP] Qwen using custom Model: %s", customModel)
|
||||
} else {
|
||||
qwenClient.logger.Infof("🔧 [MCP] Qwen using default Model: %s", qwenClient.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func (qwenClient *QwenClient) setAuthHeader(reqHeaders http.Header) {
|
||||
qwenClient.Client.setAuthHeader(reqHeaders)
|
||||
}
|
||||
@@ -1,272 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// Test QwenClient Creation and Configuration
|
||||
// ============================================================
|
||||
|
||||
func TestNewQwenClient_Default(t *testing.T) {
|
||||
client := NewQwenClient()
|
||||
|
||||
if client == nil {
|
||||
t.Fatal("client should not be nil")
|
||||
}
|
||||
|
||||
// Type assertion check
|
||||
qwenClient, ok := client.(*QwenClient)
|
||||
if !ok {
|
||||
t.Fatal("client should be *QwenClient")
|
||||
}
|
||||
|
||||
// Verify default values
|
||||
if qwenClient.Provider != ProviderQwen {
|
||||
t.Errorf("Provider should be '%s', got '%s'", ProviderQwen, qwenClient.Provider)
|
||||
}
|
||||
|
||||
if qwenClient.BaseURL != DefaultQwenBaseURL {
|
||||
t.Errorf("BaseURL should be '%s', got '%s'", DefaultQwenBaseURL, qwenClient.BaseURL)
|
||||
}
|
||||
|
||||
if qwenClient.Model != DefaultQwenModel {
|
||||
t.Errorf("Model should be '%s', got '%s'", DefaultQwenModel, qwenClient.Model)
|
||||
}
|
||||
|
||||
if qwenClient.logger == nil {
|
||||
t.Error("logger should not be nil")
|
||||
}
|
||||
|
||||
if qwenClient.httpClient == nil {
|
||||
t.Error("httpClient should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewQwenClientWithOptions(t *testing.T) {
|
||||
mockLogger := NewMockLogger()
|
||||
customModel := "qwen-plus"
|
||||
customAPIKey := "sk-custom-qwen-key"
|
||||
|
||||
client := NewQwenClientWithOptions(
|
||||
WithLogger(mockLogger),
|
||||
WithModel(customModel),
|
||||
WithAPIKey(customAPIKey),
|
||||
WithMaxTokens(4000),
|
||||
)
|
||||
|
||||
qwenClient := client.(*QwenClient)
|
||||
|
||||
// Verify custom options are applied
|
||||
if qwenClient.logger != mockLogger {
|
||||
t.Error("logger should be set from option")
|
||||
}
|
||||
|
||||
if qwenClient.Model != customModel {
|
||||
t.Error("Model should be set from option")
|
||||
}
|
||||
|
||||
if qwenClient.APIKey != customAPIKey {
|
||||
t.Error("APIKey should be set from option")
|
||||
}
|
||||
|
||||
if qwenClient.MaxTokens != 4000 {
|
||||
t.Error("MaxTokens should be 4000")
|
||||
}
|
||||
|
||||
// Verify Qwen default values are retained
|
||||
if qwenClient.Provider != ProviderQwen {
|
||||
t.Errorf("Provider should still be '%s'", ProviderQwen)
|
||||
}
|
||||
|
||||
if qwenClient.BaseURL != DefaultQwenBaseURL {
|
||||
t.Errorf("BaseURL should still be '%s'", DefaultQwenBaseURL)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Test SetAPIKey
|
||||
// ============================================================
|
||||
|
||||
func TestQwenClient_SetAPIKey(t *testing.T) {
|
||||
mockLogger := NewMockLogger()
|
||||
client := NewQwenClientWithOptions(
|
||||
WithLogger(mockLogger),
|
||||
)
|
||||
|
||||
qwenClient := client.(*QwenClient)
|
||||
|
||||
// Test setting API Key (default URL and Model)
|
||||
qwenClient.SetAPIKey("sk-test-key-12345678", "", "")
|
||||
|
||||
if qwenClient.APIKey != "sk-test-key-12345678" {
|
||||
t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", qwenClient.APIKey)
|
||||
}
|
||||
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
if len(logs) == 0 {
|
||||
t.Error("should have logged API key setting")
|
||||
}
|
||||
|
||||
// Verify BaseURL and Model remain default
|
||||
if qwenClient.BaseURL != DefaultQwenBaseURL {
|
||||
t.Error("BaseURL should remain default")
|
||||
}
|
||||
|
||||
if qwenClient.Model != DefaultQwenModel {
|
||||
t.Error("Model should remain default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenClient_SetAPIKey_WithCustomURL(t *testing.T) {
|
||||
mockLogger := NewMockLogger()
|
||||
client := NewQwenClientWithOptions(
|
||||
WithLogger(mockLogger),
|
||||
)
|
||||
|
||||
qwenClient := client.(*QwenClient)
|
||||
|
||||
customURL := "https://custom.qwen.api.com/v1"
|
||||
qwenClient.SetAPIKey("sk-test-key-12345678", customURL, "")
|
||||
|
||||
if qwenClient.BaseURL != customURL {
|
||||
t.Errorf("BaseURL should be '%s', got '%s'", customURL, qwenClient.BaseURL)
|
||||
}
|
||||
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
hasCustomURLLog := false
|
||||
for _, log := range logs {
|
||||
if log.Format == "🔧 [MCP] Qwen using custom BaseURL: %s" {
|
||||
hasCustomURLLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasCustomURLLog {
|
||||
t.Error("should have logged custom BaseURL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenClient_SetAPIKey_WithCustomModel(t *testing.T) {
|
||||
mockLogger := NewMockLogger()
|
||||
client := NewQwenClientWithOptions(
|
||||
WithLogger(mockLogger),
|
||||
)
|
||||
|
||||
qwenClient := client.(*QwenClient)
|
||||
|
||||
customModel := "qwen-turbo"
|
||||
qwenClient.SetAPIKey("sk-test-key-12345678", "", customModel)
|
||||
|
||||
if qwenClient.Model != customModel {
|
||||
t.Errorf("Model should be '%s', got '%s'", customModel, qwenClient.Model)
|
||||
}
|
||||
|
||||
// Verify logging
|
||||
logs := mockLogger.GetLogsByLevel("INFO")
|
||||
hasCustomModelLog := false
|
||||
for _, log := range logs {
|
||||
if log.Format == "🔧 [MCP] Qwen using custom Model: %s" {
|
||||
hasCustomModelLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasCustomModelLog {
|
||||
t.Error("should have logged custom Model")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Test Integration Features
|
||||
// ============================================================
|
||||
|
||||
func TestQwenClient_CallWithMessages_Success(t *testing.T) {
|
||||
mockHTTP := NewMockHTTPClient()
|
||||
mockHTTP.SetSuccessResponse("Qwen AI response")
|
||||
mockLogger := NewMockLogger()
|
||||
|
||||
client := NewQwenClientWithOptions(
|
||||
WithHTTPClient(mockHTTP.ToHTTPClient()),
|
||||
WithLogger(mockLogger),
|
||||
WithAPIKey("sk-test-key"),
|
||||
)
|
||||
|
||||
result, err := client.CallWithMessages("system prompt", "user prompt")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("should not error: %v", err)
|
||||
}
|
||||
|
||||
if result != "Qwen AI response" {
|
||||
t.Errorf("expected 'Qwen AI response', got '%s'", result)
|
||||
}
|
||||
|
||||
// Verify request
|
||||
requests := mockHTTP.GetRequests()
|
||||
if len(requests) != 1 {
|
||||
t.Fatalf("expected 1 request, got %d", len(requests))
|
||||
}
|
||||
|
||||
req := requests[0]
|
||||
|
||||
// Verify URL
|
||||
expectedURL := DefaultQwenBaseURL + "/chat/completions"
|
||||
if req.URL.String() != expectedURL {
|
||||
t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String())
|
||||
}
|
||||
|
||||
// Verify Authorization header
|
||||
authHeader := req.Header.Get("Authorization")
|
||||
if authHeader != "Bearer sk-test-key" {
|
||||
t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader)
|
||||
}
|
||||
|
||||
// Verify Content-Type
|
||||
if req.Header.Get("Content-Type") != "application/json" {
|
||||
t.Error("Content-Type should be application/json")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwenClient_Timeout(t *testing.T) {
|
||||
client := NewQwenClientWithOptions(
|
||||
WithTimeout(30 * time.Second),
|
||||
)
|
||||
|
||||
qwenClient := client.(*QwenClient)
|
||||
|
||||
if qwenClient.httpClient.Timeout != 30*time.Second {
|
||||
t.Errorf("expected timeout 30s, got %v", qwenClient.httpClient.Timeout)
|
||||
}
|
||||
|
||||
// Test SetTimeout
|
||||
client.SetTimeout(60 * time.Second)
|
||||
|
||||
if qwenClient.httpClient.Timeout != 60*time.Second {
|
||||
t.Errorf("expected timeout 60s after SetTimeout, got %v", qwenClient.httpClient.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Test hooks Mechanism
|
||||
// ============================================================
|
||||
|
||||
func TestQwenClient_HooksIntegration(t *testing.T) {
|
||||
client := NewQwenClientWithOptions()
|
||||
qwenClient := client.(*QwenClient)
|
||||
|
||||
// Verify hooks point to qwenClient itself (implements polymorphism)
|
||||
if qwenClient.hooks != qwenClient {
|
||||
t.Error("hooks should point to qwenClient for polymorphism")
|
||||
}
|
||||
|
||||
// Verify buildUrl uses Qwen configuration
|
||||
url := qwenClient.buildUrl()
|
||||
expectedURL := DefaultQwenBaseURL + "/chat/completions"
|
||||
if url != expectedURL {
|
||||
t.Errorf("expected URL '%s', got '%s'", expectedURL, url)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package mcp
|
||||
|
||||
// providerRegistry maps provider names to factory functions.
|
||||
var providerRegistry = map[string]func(...ClientOption) AIClient{}
|
||||
|
||||
// RegisterProvider registers a provider factory function.
|
||||
// Called by provider/payment sub-packages in their init() functions.
|
||||
func RegisterProvider(name string, factory func(...ClientOption) AIClient) {
|
||||
providerRegistry[name] = factory
|
||||
}
|
||||
|
||||
// NewAIClientByProvider creates an AIClient by provider name using the registry.
|
||||
// Returns nil if the provider is not registered.
|
||||
func NewAIClientByProvider(name string, opts ...ClientOption) AIClient {
|
||||
factory, ok := providerRegistry[name]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return factory(opts...)
|
||||
}
|
||||
@@ -450,10 +450,10 @@ func TestClient_CallWithRequest_UsesClientModel(t *testing.T) {
|
||||
mockHTTP.SetSuccessResponse("Response")
|
||||
mockLogger := NewMockLogger()
|
||||
|
||||
client := NewDeepSeekClientWithOptions(
|
||||
client := NewClient(
|
||||
WithDeepSeekConfig("sk-test-key"),
|
||||
WithHTTPClient(mockHTTP.ToHTTPClient()),
|
||||
WithLogger(mockLogger),
|
||||
WithAPIKey("sk-test-key"),
|
||||
)
|
||||
|
||||
// Request does not set model, should use Client's model
|
||||
|
||||
+6
-25
@@ -5,6 +5,8 @@ import (
|
||||
"nofx/config"
|
||||
"nofx/logger"
|
||||
"nofx/mcp"
|
||||
_ "nofx/mcp/payment"
|
||||
_ "nofx/mcp/provider"
|
||||
"nofx/store"
|
||||
"nofx/telegram/agent"
|
||||
"os"
|
||||
@@ -319,32 +321,11 @@ func isUSDCProvider(provider string) bool {
|
||||
}
|
||||
|
||||
func clientForProvider(provider string) mcp.AIClient {
|
||||
switch provider {
|
||||
case "openai":
|
||||
return mcp.NewOpenAIClient()
|
||||
case "deepseek":
|
||||
return mcp.NewDeepSeekClient()
|
||||
case "claude":
|
||||
return mcp.NewClaudeClient()
|
||||
case "qwen":
|
||||
return mcp.NewQwenClient()
|
||||
case "kimi":
|
||||
return mcp.NewKimiClient()
|
||||
case "grok":
|
||||
return mcp.NewGrokClient()
|
||||
case "gemini":
|
||||
return mcp.NewGeminiClient()
|
||||
case "minimax":
|
||||
return mcp.NewMiniMaxClient()
|
||||
case "blockrun-base":
|
||||
return mcp.NewBlockRunBaseClient()
|
||||
case "blockrun-sol":
|
||||
return mcp.NewBlockRunSolClient()
|
||||
case "claw402":
|
||||
return mcp.NewClaw402Client()
|
||||
default:
|
||||
return mcp.NewDeepSeekClient()
|
||||
client := mcp.NewAIClientByProvider(provider)
|
||||
if client == nil {
|
||||
client = mcp.NewAIClientByProvider("deepseek")
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// ── Status message ────────────────────────────────────────────────────────────
|
||||
|
||||
+33
-1771
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,527 @@
|
||||
package trader
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"nofx/experience"
|
||||
"nofx/kernel"
|
||||
"nofx/logger"
|
||||
"nofx/market"
|
||||
"nofx/store"
|
||||
"time"
|
||||
)
|
||||
|
||||
// saveEquitySnapshot saves equity snapshot independently (for drawing profit curve, decoupled from AI decision)
|
||||
func (at *AutoTrader) saveEquitySnapshot(ctx *kernel.Context) {
|
||||
if at.store == nil || ctx == nil {
|
||||
return
|
||||
}
|
||||
|
||||
snapshot := &store.EquitySnapshot{
|
||||
TraderID: at.id,
|
||||
Timestamp: time.Now().UTC(),
|
||||
TotalEquity: ctx.Account.TotalEquity,
|
||||
Balance: ctx.Account.TotalEquity - ctx.Account.UnrealizedPnL,
|
||||
UnrealizedPnL: ctx.Account.UnrealizedPnL,
|
||||
PositionCount: ctx.Account.PositionCount,
|
||||
MarginUsedPct: ctx.Account.MarginUsedPct,
|
||||
}
|
||||
|
||||
if err := at.store.Equity().Save(snapshot); err != nil {
|
||||
logger.Infof("⚠️ Failed to save equity snapshot: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// saveDecision saves AI decision log to database (only records AI input/output, for debugging)
|
||||
func (at *AutoTrader) saveDecision(record *store.DecisionRecord) error {
|
||||
if at.store == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
at.cycleNumber++
|
||||
record.CycleNumber = at.cycleNumber
|
||||
record.TraderID = at.id
|
||||
|
||||
if record.Timestamp.IsZero() {
|
||||
record.Timestamp = time.Now().UTC()
|
||||
}
|
||||
|
||||
if err := at.store.Decision().LogDecision(record); err != nil {
|
||||
logger.Infof("⚠️ Failed to save decision record: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Infof("📝 Decision record saved: trader=%s, cycle=%d", at.id, at.cycleNumber)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStatus gets system status (for API)
|
||||
func (at *AutoTrader) GetStatus() map[string]interface{} {
|
||||
aiProvider := "DeepSeek"
|
||||
if at.config.UseQwen {
|
||||
aiProvider = "Qwen"
|
||||
}
|
||||
|
||||
at.isRunningMutex.RLock()
|
||||
isRunning := at.isRunning
|
||||
at.isRunningMutex.RUnlock()
|
||||
|
||||
result := map[string]interface{}{
|
||||
"trader_id": at.id,
|
||||
"trader_name": at.name,
|
||||
"ai_model": at.aiModel,
|
||||
"exchange": at.exchange,
|
||||
"is_running": isRunning,
|
||||
"start_time": at.startTime.Format(time.RFC3339),
|
||||
"runtime_minutes": int(time.Since(at.startTime).Minutes()),
|
||||
"call_count": at.callCount,
|
||||
"initial_balance": at.initialBalance,
|
||||
"scan_interval": at.config.ScanInterval.String(),
|
||||
"stop_until": at.stopUntil.Format(time.RFC3339),
|
||||
"last_reset_time": at.lastResetTime.Format(time.RFC3339),
|
||||
"ai_provider": aiProvider,
|
||||
}
|
||||
|
||||
// Add strategy info
|
||||
if at.config.StrategyConfig != nil {
|
||||
result["strategy_type"] = at.config.StrategyConfig.StrategyType
|
||||
if at.config.StrategyConfig.GridConfig != nil {
|
||||
result["grid_symbol"] = at.config.StrategyConfig.GridConfig.Symbol
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetAccountInfo gets account information (for API)
|
||||
func (at *AutoTrader) GetAccountInfo() (map[string]interface{}, error) {
|
||||
balance, err := at.trader.GetBalance()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get balance: %w", err)
|
||||
}
|
||||
|
||||
// Get account fields
|
||||
totalWalletBalance := 0.0
|
||||
totalUnrealizedProfit := 0.0
|
||||
availableBalance := 0.0
|
||||
totalEquity := 0.0
|
||||
|
||||
if wallet, ok := balance["totalWalletBalance"].(float64); ok {
|
||||
totalWalletBalance = wallet
|
||||
}
|
||||
if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok {
|
||||
totalUnrealizedProfit = unrealized
|
||||
}
|
||||
if avail, ok := balance["availableBalance"].(float64); ok {
|
||||
availableBalance = avail
|
||||
}
|
||||
|
||||
// Use totalEquity directly if provided by trader (more accurate)
|
||||
if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 {
|
||||
totalEquity = eq
|
||||
} else {
|
||||
// Fallback: Total Equity = Wallet balance + Unrealized profit
|
||||
totalEquity = totalWalletBalance + totalUnrealizedProfit
|
||||
}
|
||||
|
||||
// Get positions to calculate total margin
|
||||
positions, err := at.trader.GetPositions()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get positions: %w", err)
|
||||
}
|
||||
|
||||
totalMarginUsed := 0.0
|
||||
totalUnrealizedPnLCalculated := 0.0
|
||||
for _, pos := range positions {
|
||||
markPrice := pos["markPrice"].(float64)
|
||||
quantity := pos["positionAmt"].(float64)
|
||||
if quantity < 0 {
|
||||
quantity = -quantity
|
||||
}
|
||||
unrealizedPnl := pos["unRealizedProfit"].(float64)
|
||||
totalUnrealizedPnLCalculated += unrealizedPnl
|
||||
|
||||
leverage := 10
|
||||
if lev, ok := pos["leverage"].(float64); ok {
|
||||
leverage = int(lev)
|
||||
}
|
||||
marginUsed := (quantity * markPrice) / float64(leverage)
|
||||
totalMarginUsed += marginUsed
|
||||
}
|
||||
|
||||
// Verify unrealized P&L consistency (API value vs calculated from positions)
|
||||
// Note: Lighter API may return 0 for unrealized PnL, this is a known limitation
|
||||
diff := math.Abs(totalUnrealizedProfit - totalUnrealizedPnLCalculated)
|
||||
if diff > 5.0 { // Only warn if difference is significant (> 5 USDT)
|
||||
logger.Infof("⚠️ Unrealized P&L inconsistency (Lighter API limitation): API=%.4f, Calculated=%.4f, Diff=%.4f",
|
||||
totalUnrealizedProfit, totalUnrealizedPnLCalculated, diff)
|
||||
}
|
||||
|
||||
totalPnL := totalEquity - at.initialBalance
|
||||
totalPnLPct := 0.0
|
||||
if at.initialBalance > 0 {
|
||||
totalPnLPct = (totalPnL / at.initialBalance) * 100
|
||||
} else {
|
||||
logger.Infof("⚠️ Initial Balance abnormal: %.2f, cannot calculate P&L percentage", at.initialBalance)
|
||||
}
|
||||
|
||||
marginUsedPct := 0.0
|
||||
if totalEquity > 0 {
|
||||
marginUsedPct = (totalMarginUsed / totalEquity) * 100
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
// Core fields
|
||||
"total_equity": totalEquity, // Account equity = wallet + unrealized
|
||||
"wallet_balance": totalWalletBalance, // Wallet balance (excluding unrealized P&L)
|
||||
"unrealized_profit": totalUnrealizedProfit, // Unrealized P&L (official value from exchange API)
|
||||
"available_balance": availableBalance, // Available balance
|
||||
|
||||
// P&L statistics
|
||||
"total_pnl": totalPnL, // Total P&L = equity - initial
|
||||
"total_pnl_pct": totalPnLPct, // Total P&L percentage
|
||||
"initial_balance": at.initialBalance, // Initial balance
|
||||
"daily_pnl": at.dailyPnL, // Daily P&L
|
||||
|
||||
// Position information
|
||||
"position_count": len(positions), // Position count
|
||||
"margin_used": totalMarginUsed, // Margin used
|
||||
"margin_used_pct": marginUsedPct, // Margin usage rate
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetPositions gets position list (for API)
|
||||
func (at *AutoTrader) GetPositions() ([]map[string]interface{}, error) {
|
||||
positions, err := at.trader.GetPositions()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get positions: %w", err)
|
||||
}
|
||||
|
||||
var result []map[string]interface{}
|
||||
for _, pos := range positions {
|
||||
symbol := pos["symbol"].(string)
|
||||
side := pos["side"].(string)
|
||||
entryPrice := pos["entryPrice"].(float64)
|
||||
markPrice := pos["markPrice"].(float64)
|
||||
quantity := pos["positionAmt"].(float64)
|
||||
if quantity < 0 {
|
||||
quantity = -quantity
|
||||
}
|
||||
unrealizedPnl := pos["unRealizedProfit"].(float64)
|
||||
liquidationPrice := pos["liquidationPrice"].(float64)
|
||||
|
||||
leverage := 10
|
||||
if lev, ok := pos["leverage"].(float64); ok {
|
||||
leverage = int(lev)
|
||||
}
|
||||
|
||||
// Calculate margin used
|
||||
marginUsed := (quantity * markPrice) / float64(leverage)
|
||||
|
||||
// Calculate P&L percentage (based on margin)
|
||||
pnlPct := calculatePnLPercentage(unrealizedPnl, marginUsed)
|
||||
|
||||
result = append(result, map[string]interface{}{
|
||||
"symbol": symbol,
|
||||
"side": side,
|
||||
"entry_price": entryPrice,
|
||||
"mark_price": markPrice,
|
||||
"quantity": quantity,
|
||||
"leverage": leverage,
|
||||
"unrealized_pnl": unrealizedPnl,
|
||||
"unrealized_pnl_pct": pnlPct,
|
||||
"liquidation_price": liquidationPrice,
|
||||
"margin_used": marginUsed,
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// recordAndConfirmOrder polls order status for actual fill data and records position
|
||||
// action: open_long, open_short, close_long, close_short
|
||||
// entryPrice: entry price when closing (0 when opening)
|
||||
func (at *AutoTrader) recordAndConfirmOrder(orderResult map[string]interface{}, symbol, action string, quantity float64, price float64, leverage int, entryPrice float64) {
|
||||
if at.store == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get order ID (supports multiple types)
|
||||
var orderID string
|
||||
switch v := orderResult["orderId"].(type) {
|
||||
case int64:
|
||||
orderID = fmt.Sprintf("%d", v)
|
||||
case float64:
|
||||
orderID = fmt.Sprintf("%.0f", v)
|
||||
case string:
|
||||
orderID = v
|
||||
default:
|
||||
orderID = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
if orderID == "" || orderID == "0" {
|
||||
logger.Infof(" ⚠️ Order ID is empty, skipping record")
|
||||
return
|
||||
}
|
||||
|
||||
// Determine positionSide
|
||||
var positionSide string
|
||||
switch action {
|
||||
case "open_long", "close_long":
|
||||
positionSide = "LONG"
|
||||
case "open_short", "close_short":
|
||||
positionSide = "SHORT"
|
||||
}
|
||||
|
||||
var actualPrice = price
|
||||
var actualQty = quantity
|
||||
var fee float64
|
||||
|
||||
// Exchanges with OrderSync: Skip immediate order recording, let OrderSync handle it
|
||||
// This ensures accurate data from GetTrades API and avoids duplicate records
|
||||
switch at.exchange {
|
||||
case "binance", "lighter", "hyperliquid", "bybit", "okx", "bitget", "aster", "kucoin", "gate":
|
||||
logger.Infof(" 📝 Order submitted (id: %s), will be synced by OrderSync", orderID)
|
||||
return
|
||||
}
|
||||
|
||||
// For exchanges without OrderSync (e.g., Binance): record immediately and poll for fill data
|
||||
orderRecord := at.createOrderRecord(orderID, symbol, action, positionSide, quantity, price, leverage)
|
||||
if err := at.store.Order().CreateOrder(orderRecord); err != nil {
|
||||
logger.Infof(" ⚠️ Failed to record order: %v", err)
|
||||
} else {
|
||||
logger.Infof(" 📝 Order recorded: %s [%s] %s", orderID, action, symbol)
|
||||
}
|
||||
|
||||
// Wait for order to be filled and get actual fill data
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
for i := 0; i < 5; i++ {
|
||||
status, err := at.trader.GetOrderStatus(symbol, orderID)
|
||||
if err == nil {
|
||||
statusStr, _ := status["status"].(string)
|
||||
if statusStr == "FILLED" {
|
||||
// Get actual fill price
|
||||
if avgPrice, ok := status["avgPrice"].(float64); ok && avgPrice > 0 {
|
||||
actualPrice = avgPrice
|
||||
}
|
||||
// Get actual executed quantity
|
||||
if execQty, ok := status["executedQty"].(float64); ok && execQty > 0 {
|
||||
actualQty = execQty
|
||||
}
|
||||
// Get commission/fee
|
||||
if commission, ok := status["commission"].(float64); ok {
|
||||
fee = commission
|
||||
}
|
||||
logger.Infof(" ✅ Order filled: avgPrice=%.6f, qty=%.6f, fee=%.6f", actualPrice, actualQty, fee)
|
||||
|
||||
// Update order status to FILLED
|
||||
if err := at.store.Order().UpdateOrderStatus(orderRecord.ID, "FILLED", actualQty, actualPrice, fee); err != nil {
|
||||
logger.Infof(" ⚠️ Failed to update order status: %v", err)
|
||||
}
|
||||
|
||||
// Record fill details
|
||||
at.recordOrderFill(orderRecord.ID, orderID, symbol, action, actualPrice, actualQty, fee)
|
||||
break
|
||||
} else if statusStr == "CANCELED" || statusStr == "EXPIRED" || statusStr == "REJECTED" {
|
||||
logger.Infof(" ⚠️ Order %s, skipping position record", statusStr)
|
||||
// Update order status
|
||||
if err := at.store.Order().UpdateOrderStatus(orderRecord.ID, statusStr, 0, 0, 0); err != nil {
|
||||
logger.Infof(" ⚠️ Failed to update order status: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Normalize symbol for position record consistency
|
||||
normalizedSymbolForPosition := market.Normalize(symbol)
|
||||
|
||||
logger.Infof(" 📝 Recording position (ID: %s, action: %s, price: %.6f, qty: %.6f, fee: %.4f)",
|
||||
orderID, action, actualPrice, actualQty, fee)
|
||||
|
||||
// Record position change with actual fill data (use normalized symbol)
|
||||
at.recordPositionChange(orderID, normalizedSymbolForPosition, positionSide, action, actualQty, actualPrice, leverage, entryPrice, fee)
|
||||
|
||||
// Send anonymous trade statistics for experience improvement (async, non-blocking)
|
||||
// This helps us understand overall product usage across all deployments
|
||||
experience.TrackTrade(experience.TradeEvent{
|
||||
Exchange: at.exchange,
|
||||
TradeType: action,
|
||||
Symbol: symbol,
|
||||
AmountUSD: actualPrice * actualQty,
|
||||
Leverage: leverage,
|
||||
UserID: at.userID,
|
||||
TraderID: at.id,
|
||||
})
|
||||
}
|
||||
|
||||
// recordPositionChange records position change (create record on open, update record on close)
|
||||
func (at *AutoTrader) recordPositionChange(orderID, symbol, side, action string, quantity, price float64, leverage int, entryPrice float64, fee float64) {
|
||||
if at.store == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "open_long", "open_short":
|
||||
// Open position: create new position record
|
||||
nowMs := time.Now().UTC().UnixMilli()
|
||||
pos := &store.TraderPosition{
|
||||
TraderID: at.id,
|
||||
ExchangeID: at.exchangeID, // Exchange account UUID
|
||||
ExchangeType: at.exchange, // Exchange type: binance/bybit/okx/etc
|
||||
Symbol: symbol,
|
||||
Side: side, // LONG or SHORT
|
||||
Quantity: quantity,
|
||||
EntryPrice: price,
|
||||
EntryOrderID: orderID,
|
||||
EntryTime: nowMs,
|
||||
Leverage: leverage,
|
||||
Status: "OPEN",
|
||||
CreatedAt: nowMs,
|
||||
UpdatedAt: nowMs,
|
||||
}
|
||||
if err := at.store.Position().Create(pos); err != nil {
|
||||
logger.Infof(" ⚠️ Failed to record position: %v", err)
|
||||
} else {
|
||||
logger.Infof(" 📊 Position recorded [%s] %s %s @ %.4f", at.id[:8], symbol, side, price)
|
||||
}
|
||||
|
||||
case "close_long", "close_short":
|
||||
// Close position using PositionBuilder for consistent handling
|
||||
// PositionBuilder will handle both cases:
|
||||
// 1. If open position exists: close it properly
|
||||
// 2. If no open position (e.g., table cleared): create a closed position record
|
||||
posBuilder := store.NewPositionBuilder(at.store.Position())
|
||||
if err := posBuilder.ProcessTrade(
|
||||
at.id, at.exchangeID, at.exchange,
|
||||
symbol, side, action,
|
||||
quantity, price, fee, 0, // realizedPnL will be calculated
|
||||
time.Now().UTC().UnixMilli(), orderID,
|
||||
); err != nil {
|
||||
logger.Infof(" ⚠️ Failed to process close position: %v", err)
|
||||
} else {
|
||||
logger.Infof(" ✅ Position closed [%s] %s %s @ %.4f", at.id[:8], symbol, side, price)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createOrderRecord creates an order record struct from order details
|
||||
func (at *AutoTrader) createOrderRecord(orderID, symbol, action, positionSide string, quantity, price float64, leverage int) *store.TraderOrder {
|
||||
// Determine order type (market for auto trader)
|
||||
orderType := "MARKET"
|
||||
|
||||
// Determine side (BUY/SELL)
|
||||
var side string
|
||||
switch action {
|
||||
case "open_long", "close_short":
|
||||
side = "BUY"
|
||||
case "open_short", "close_long":
|
||||
side = "SELL"
|
||||
}
|
||||
|
||||
// Use action as orderAction directly (keep lowercase format)
|
||||
orderAction := action
|
||||
|
||||
// Determine if it's a reduce only order
|
||||
reduceOnly := (action == "close_long" || action == "close_short")
|
||||
|
||||
// Normalize symbol for consistency
|
||||
normalizedSymbol := market.Normalize(symbol)
|
||||
|
||||
return &store.TraderOrder{
|
||||
TraderID: at.id,
|
||||
ExchangeID: at.exchangeID,
|
||||
ExchangeType: at.exchange,
|
||||
ExchangeOrderID: orderID,
|
||||
Symbol: normalizedSymbol,
|
||||
Side: side,
|
||||
PositionSide: positionSide,
|
||||
Type: orderType,
|
||||
TimeInForce: "GTC",
|
||||
Quantity: quantity,
|
||||
Price: price,
|
||||
Status: "NEW",
|
||||
FilledQuantity: 0,
|
||||
AvgFillPrice: 0,
|
||||
Commission: 0,
|
||||
CommissionAsset: "USDT",
|
||||
Leverage: leverage,
|
||||
ReduceOnly: reduceOnly,
|
||||
ClosePosition: reduceOnly,
|
||||
OrderAction: orderAction,
|
||||
CreatedAt: time.Now().UTC().UnixMilli(),
|
||||
UpdatedAt: time.Now().UTC().UnixMilli(),
|
||||
}
|
||||
}
|
||||
|
||||
// recordOrderFill records order fill/trade details
|
||||
func (at *AutoTrader) recordOrderFill(orderRecordID int64, exchangeOrderID, symbol, action string, price, quantity, fee float64) {
|
||||
if at.store == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Determine side (BUY/SELL)
|
||||
var side string
|
||||
switch action {
|
||||
case "open_long", "close_short":
|
||||
side = "BUY"
|
||||
case "open_short", "close_long":
|
||||
side = "SELL"
|
||||
}
|
||||
|
||||
// Generate a simple trade ID (exchange doesn't always provide one)
|
||||
tradeID := fmt.Sprintf("%s-%d", exchangeOrderID, time.Now().UnixNano())
|
||||
|
||||
// Normalize symbol for consistency
|
||||
normalizedSymbol := market.Normalize(symbol)
|
||||
|
||||
fill := &store.TraderFill{
|
||||
TraderID: at.id,
|
||||
ExchangeID: at.exchangeID,
|
||||
ExchangeType: at.exchange,
|
||||
OrderID: orderRecordID,
|
||||
ExchangeOrderID: exchangeOrderID,
|
||||
ExchangeTradeID: tradeID,
|
||||
Symbol: normalizedSymbol,
|
||||
Side: side,
|
||||
Price: price,
|
||||
Quantity: quantity,
|
||||
QuoteQuantity: price * quantity,
|
||||
Commission: fee,
|
||||
CommissionAsset: "USDT",
|
||||
RealizedPnL: 0, // Will be calculated for close orders
|
||||
IsMaker: false, // Market orders are usually taker
|
||||
CreatedAt: time.Now().UTC().UnixMilli(),
|
||||
}
|
||||
|
||||
// Calculate realized PnL for close orders
|
||||
if action == "close_long" || action == "close_short" {
|
||||
// Try to get the entry price from the open position
|
||||
var positionSide string
|
||||
if action == "close_long" {
|
||||
positionSide = "LONG"
|
||||
} else {
|
||||
positionSide = "SHORT"
|
||||
}
|
||||
|
||||
if openPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, symbol, positionSide); err == nil && openPos != nil {
|
||||
if positionSide == "LONG" {
|
||||
fill.RealizedPnL = (price - openPos.EntryPrice) * quantity
|
||||
} else {
|
||||
fill.RealizedPnL = (openPos.EntryPrice - price) * quantity
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := at.store.Order().CreateFill(fill); err != nil {
|
||||
logger.Infof(" ⚠️ Failed to record fill: %v", err)
|
||||
} else {
|
||||
logger.Infof(" 📋 Fill recorded: %.4f @ %.6f, fee: %.4f", quantity, price, fee)
|
||||
}
|
||||
}
|
||||
|
||||
// GetOpenOrders returns open orders (pending SL/TP) from exchange
|
||||
func (at *AutoTrader) GetOpenOrders(symbol string) ([]OpenOrder, error) {
|
||||
return at.trader.GetOpenOrders(symbol)
|
||||
}
|
||||
@@ -0,0 +1,560 @@
|
||||
package trader
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"nofx/kernel"
|
||||
"nofx/logger"
|
||||
"nofx/store"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// runCycle runs one trading cycle (using AI full decision-making)
|
||||
func (at *AutoTrader) runCycle() error {
|
||||
at.callCount++
|
||||
|
||||
logger.Info("\n" + strings.Repeat("=", 70) + "\n")
|
||||
logger.Infof("⏰ %s - AI decision cycle #%d", time.Now().Format("2006-01-02 15:04:05"), at.callCount)
|
||||
logger.Info(strings.Repeat("=", 70))
|
||||
|
||||
// 0. Check if trader is stopped (early exit to prevent trades after Stop() is called)
|
||||
at.isRunningMutex.RLock()
|
||||
running := at.isRunning
|
||||
at.isRunningMutex.RUnlock()
|
||||
if !running {
|
||||
logger.Infof("⏹ Trader is stopped, aborting cycle #%d", at.callCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create decision record
|
||||
record := &store.DecisionRecord{
|
||||
ExecutionLog: []string{},
|
||||
Success: true,
|
||||
}
|
||||
|
||||
// 1. Check if trading needs to be stopped
|
||||
if time.Now().Before(at.stopUntil) {
|
||||
remaining := at.stopUntil.Sub(time.Now())
|
||||
logger.Infof("⏸ Risk control: Trading paused, remaining %.0f minutes", remaining.Minutes())
|
||||
record.Success = false
|
||||
record.ErrorMessage = fmt.Sprintf("Risk control paused, remaining %.0f minutes", remaining.Minutes())
|
||||
at.saveDecision(record)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 2. Reset daily P&L (reset every day)
|
||||
if time.Since(at.lastResetTime) > 24*time.Hour {
|
||||
at.dailyPnL = 0
|
||||
at.lastResetTime = time.Now()
|
||||
logger.Info("📅 Daily P&L reset")
|
||||
}
|
||||
|
||||
// 4. Collect trading context
|
||||
ctx, err := at.buildTradingContext()
|
||||
if err != nil {
|
||||
record.Success = false
|
||||
record.ErrorMessage = fmt.Sprintf("Failed to build trading context: %v", err)
|
||||
at.saveDecision(record)
|
||||
return fmt.Errorf("failed to build trading context: %w", err)
|
||||
}
|
||||
|
||||
// Save equity snapshot independently (decoupled from AI decision, used for drawing profit curve)
|
||||
// NOTE: Must be called BEFORE candidate coins check to ensure equity is always recorded
|
||||
at.saveEquitySnapshot(ctx)
|
||||
|
||||
// 如果没有候选币种,记录但不报错
|
||||
if len(ctx.CandidateCoins) == 0 {
|
||||
logger.Infof("ℹ️ No candidate coins available, skipping this cycle")
|
||||
record.Success = true // 不是错误,只是没有候选币
|
||||
record.ExecutionLog = append(record.ExecutionLog, "No candidate coins available, cycle skipped")
|
||||
record.AccountState = store.AccountSnapshot{
|
||||
TotalBalance: ctx.Account.TotalEquity,
|
||||
AvailableBalance: ctx.Account.AvailableBalance,
|
||||
TotalUnrealizedProfit: ctx.Account.UnrealizedPnL,
|
||||
PositionCount: ctx.Account.PositionCount,
|
||||
InitialBalance: at.initialBalance,
|
||||
}
|
||||
at.saveDecision(record)
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info(strings.Repeat("=", 70))
|
||||
for _, coin := range ctx.CandidateCoins {
|
||||
record.CandidateCoins = append(record.CandidateCoins, coin.Symbol)
|
||||
}
|
||||
|
||||
logger.Infof("📊 Account equity: %.2f USDT | Available: %.2f USDT | Positions: %d",
|
||||
ctx.Account.TotalEquity, ctx.Account.AvailableBalance, ctx.Account.PositionCount)
|
||||
|
||||
// 5. Use strategy engine to call AI for decision
|
||||
logger.Infof("🤖 Requesting AI analysis and decision... [Strategy Engine]")
|
||||
aiDecision, err := kernel.GetFullDecisionWithStrategy(ctx, at.mcpClient, at.strategyEngine, "balanced")
|
||||
|
||||
if aiDecision != nil && aiDecision.AIRequestDurationMs > 0 {
|
||||
record.AIRequestDurationMs = aiDecision.AIRequestDurationMs
|
||||
logger.Infof("⏱️ AI call duration: %.2f seconds", float64(record.AIRequestDurationMs)/1000)
|
||||
record.ExecutionLog = append(record.ExecutionLog,
|
||||
fmt.Sprintf("AI call duration: %d ms", record.AIRequestDurationMs))
|
||||
}
|
||||
|
||||
// Save chain of thought, decisions, and input prompt even if there's an error (for debugging)
|
||||
if aiDecision != nil {
|
||||
record.SystemPrompt = aiDecision.SystemPrompt // Save system prompt
|
||||
record.InputPrompt = aiDecision.UserPrompt
|
||||
record.CoTTrace = aiDecision.CoTTrace
|
||||
record.RawResponse = aiDecision.RawResponse // Save raw AI response for debugging
|
||||
if len(aiDecision.Decisions) > 0 {
|
||||
decisionJSON, _ := json.MarshalIndent(aiDecision.Decisions, "", " ")
|
||||
record.DecisionJSON = string(decisionJSON)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
record.Success = false
|
||||
record.ErrorMessage = fmt.Sprintf("Failed to get AI decision: %v", err)
|
||||
|
||||
// Print system prompt and AI chain of thought (output even with errors for debugging)
|
||||
if aiDecision != nil {
|
||||
logger.Info("\n" + strings.Repeat("=", 70) + "\n")
|
||||
logger.Infof("📋 System prompt (error case)")
|
||||
logger.Info(strings.Repeat("=", 70))
|
||||
logger.Info(aiDecision.SystemPrompt)
|
||||
logger.Info(strings.Repeat("=", 70))
|
||||
|
||||
if aiDecision.CoTTrace != "" {
|
||||
logger.Info("\n" + strings.Repeat("-", 70) + "\n")
|
||||
logger.Info("💭 AI chain of thought analysis (error case):")
|
||||
logger.Info(strings.Repeat("-", 70))
|
||||
logger.Info(aiDecision.CoTTrace)
|
||||
logger.Info(strings.Repeat("-", 70))
|
||||
}
|
||||
}
|
||||
|
||||
at.saveDecision(record)
|
||||
return fmt.Errorf("failed to get AI decision: %w", err)
|
||||
}
|
||||
|
||||
// // 5. Print system prompt
|
||||
// logger.Infof("\n" + strings.Repeat("=", 70))
|
||||
// logger.Infof("📋 System prompt [template: %s]", at.systemPromptTemplate)
|
||||
// logger.Info(strings.Repeat("=", 70))
|
||||
// logger.Info(decision.SystemPrompt)
|
||||
// logger.Infof(strings.Repeat("=", 70) + "\n")
|
||||
|
||||
// 6. Print AI chain of thought
|
||||
// logger.Infof("\n" + strings.Repeat("-", 70))
|
||||
// logger.Info("💭 AI chain of thought analysis:")
|
||||
// logger.Info(strings.Repeat("-", 70))
|
||||
// logger.Info(decision.CoTTrace)
|
||||
// logger.Infof(strings.Repeat("-", 70) + "\n")
|
||||
|
||||
// 7. Print AI decisions
|
||||
// logger.Infof("📋 AI decision list (%d items):\n", len(kernel.Decisions))
|
||||
// for i, d := range kernel.Decisions {
|
||||
// logger.Infof(" [%d] %s: %s - %s", i+1, d.Symbol, d.Action, d.Reasoning)
|
||||
// if d.Action == "open_long" || d.Action == "open_short" {
|
||||
// logger.Infof(" Leverage: %dx | Position: %.2f USDT | Stop loss: %.4f | Take profit: %.4f",
|
||||
// d.Leverage, d.PositionSizeUSD, d.StopLoss, d.TakeProfit)
|
||||
// }
|
||||
// }
|
||||
logger.Info()
|
||||
logger.Info(strings.Repeat("-", 70))
|
||||
// 8. Sort decisions: ensure close positions first, then open positions (prevent position stacking overflow)
|
||||
logger.Info(strings.Repeat("-", 70))
|
||||
|
||||
// 8. Sort decisions: ensure close positions first, then open positions (prevent position stacking overflow)
|
||||
sortedDecisions := sortDecisionsByPriority(aiDecision.Decisions)
|
||||
|
||||
logger.Info("🔄 Execution order (optimized): Close positions first → Open positions later")
|
||||
for i, d := range sortedDecisions {
|
||||
logger.Infof(" [%d] %s %s", i+1, d.Symbol, d.Action)
|
||||
}
|
||||
logger.Info()
|
||||
|
||||
// Check if trader is stopped before executing any decisions (prevent trades after Stop())
|
||||
at.isRunningMutex.RLock()
|
||||
running = at.isRunning
|
||||
at.isRunningMutex.RUnlock()
|
||||
if !running {
|
||||
logger.Infof("⏹ Trader stopped before decision execution, aborting cycle #%d", at.callCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute decisions and record results
|
||||
for _, d := range sortedDecisions {
|
||||
// Check if trader is stopped before each decision (allow immediate stop during execution)
|
||||
at.isRunningMutex.RLock()
|
||||
running = at.isRunning
|
||||
at.isRunningMutex.RUnlock()
|
||||
if !running {
|
||||
logger.Infof("⏹ Trader stopped during decision execution, aborting remaining decisions")
|
||||
break
|
||||
}
|
||||
|
||||
actionRecord := store.DecisionAction{
|
||||
Action: d.Action,
|
||||
Symbol: d.Symbol,
|
||||
Quantity: 0,
|
||||
Leverage: d.Leverage,
|
||||
Price: 0,
|
||||
StopLoss: d.StopLoss,
|
||||
TakeProfit: d.TakeProfit,
|
||||
Confidence: d.Confidence,
|
||||
Reasoning: d.Reasoning,
|
||||
Timestamp: time.Now().UTC(),
|
||||
Success: false,
|
||||
}
|
||||
|
||||
if err := at.executeDecisionWithRecord(&d, &actionRecord); err != nil {
|
||||
logger.Infof("❌ Failed to execute decision (%s %s): %v", d.Symbol, d.Action, err)
|
||||
actionRecord.Error = err.Error()
|
||||
record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("❌ %s %s failed: %v", d.Symbol, d.Action, err))
|
||||
} else {
|
||||
actionRecord.Success = true
|
||||
record.ExecutionLog = append(record.ExecutionLog, fmt.Sprintf("✓ %s %s succeeded", d.Symbol, d.Action))
|
||||
// Brief delay after successful execution
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
record.Decisions = append(record.Decisions, actionRecord)
|
||||
}
|
||||
|
||||
// 9. Save decision record
|
||||
if err := at.saveDecision(record); err != nil {
|
||||
logger.Infof("⚠ Failed to save decision record: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildTradingContext builds trading context
|
||||
func (at *AutoTrader) buildTradingContext() (*kernel.Context, error) {
|
||||
// 1. Get account information
|
||||
balance, err := at.trader.GetBalance()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account balance: %w", err)
|
||||
}
|
||||
|
||||
// Get account fields
|
||||
totalWalletBalance := 0.0
|
||||
totalUnrealizedProfit := 0.0
|
||||
availableBalance := 0.0
|
||||
totalEquity := 0.0
|
||||
|
||||
if wallet, ok := balance["totalWalletBalance"].(float64); ok {
|
||||
totalWalletBalance = wallet
|
||||
}
|
||||
if unrealized, ok := balance["totalUnrealizedProfit"].(float64); ok {
|
||||
totalUnrealizedProfit = unrealized
|
||||
}
|
||||
if avail, ok := balance["availableBalance"].(float64); ok {
|
||||
availableBalance = avail
|
||||
}
|
||||
|
||||
// Use totalEquity directly if provided by trader (more accurate)
|
||||
if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 {
|
||||
totalEquity = eq
|
||||
} else {
|
||||
// Fallback: Total Equity = Wallet balance + Unrealized profit
|
||||
totalEquity = totalWalletBalance + totalUnrealizedProfit
|
||||
}
|
||||
|
||||
// 2. Get position information
|
||||
positions, err := at.trader.GetPositions()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get positions: %w", err)
|
||||
}
|
||||
|
||||
var positionInfos []kernel.PositionInfo
|
||||
totalMarginUsed := 0.0
|
||||
|
||||
// Current position key set (for cleaning up closed position records)
|
||||
currentPositionKeys := make(map[string]bool)
|
||||
|
||||
for _, pos := range positions {
|
||||
symbol := pos["symbol"].(string)
|
||||
side := pos["side"].(string)
|
||||
entryPrice := pos["entryPrice"].(float64)
|
||||
markPrice := pos["markPrice"].(float64)
|
||||
quantity := pos["positionAmt"].(float64)
|
||||
if quantity < 0 {
|
||||
quantity = -quantity // Short position quantity is negative, convert to positive
|
||||
}
|
||||
|
||||
// Skip closed positions (quantity = 0), prevent "ghost positions" from being passed to AI
|
||||
if quantity == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
unrealizedPnl := pos["unRealizedProfit"].(float64)
|
||||
liquidationPrice := pos["liquidationPrice"].(float64)
|
||||
|
||||
// Calculate margin used (estimated)
|
||||
leverage := 10 // Default value, should actually be fetched from position info
|
||||
if lev, ok := pos["leverage"].(float64); ok {
|
||||
leverage = int(lev)
|
||||
}
|
||||
marginUsed := (quantity * markPrice) / float64(leverage)
|
||||
totalMarginUsed += marginUsed
|
||||
|
||||
// Calculate P&L percentage (based on margin, considering leverage)
|
||||
pnlPct := calculatePnLPercentage(unrealizedPnl, marginUsed)
|
||||
|
||||
// Get position open time from exchange (preferred) or fallback to local tracking
|
||||
posKey := symbol + "_" + side
|
||||
currentPositionKeys[posKey] = true
|
||||
|
||||
var updateTime int64
|
||||
// Priority 1: Get from database (trader_positions table) - most accurate
|
||||
if at.store != nil {
|
||||
if dbPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, symbol, side); err == nil && dbPos != nil {
|
||||
if dbPos.EntryTime > 0 {
|
||||
updateTime = dbPos.EntryTime
|
||||
}
|
||||
}
|
||||
}
|
||||
// Priority 2: Get from exchange API (Bybit: createdTime, OKX: createdTime)
|
||||
if updateTime == 0 {
|
||||
if createdTime, ok := pos["createdTime"].(int64); ok && createdTime > 0 {
|
||||
updateTime = createdTime
|
||||
}
|
||||
}
|
||||
// Priority 3: Fallback to local tracking
|
||||
if updateTime == 0 {
|
||||
if _, exists := at.positionFirstSeenTime[posKey]; !exists {
|
||||
at.positionFirstSeenTime[posKey] = time.Now().UnixMilli()
|
||||
}
|
||||
updateTime = at.positionFirstSeenTime[posKey]
|
||||
}
|
||||
|
||||
// Get peak profit rate for this position
|
||||
at.peakPnLCacheMutex.RLock()
|
||||
peakPnlPct := at.peakPnLCache[posKey]
|
||||
at.peakPnLCacheMutex.RUnlock()
|
||||
|
||||
positionInfos = append(positionInfos, kernel.PositionInfo{
|
||||
Symbol: symbol,
|
||||
Side: side,
|
||||
EntryPrice: entryPrice,
|
||||
MarkPrice: markPrice,
|
||||
Quantity: quantity,
|
||||
Leverage: leverage,
|
||||
UnrealizedPnL: unrealizedPnl,
|
||||
UnrealizedPnLPct: pnlPct,
|
||||
PeakPnLPct: peakPnlPct,
|
||||
LiquidationPrice: liquidationPrice,
|
||||
MarginUsed: marginUsed,
|
||||
UpdateTime: updateTime,
|
||||
})
|
||||
}
|
||||
|
||||
// Clean up closed position records
|
||||
for key := range at.positionFirstSeenTime {
|
||||
if !currentPositionKeys[key] {
|
||||
delete(at.positionFirstSeenTime, key)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Use strategy engine to get candidate coins (must have strategy engine)
|
||||
var candidateCoins []kernel.CandidateCoin
|
||||
if at.strategyEngine == nil {
|
||||
logger.Infof("⚠️ [%s] No strategy engine configured, skipping candidate coins", at.name)
|
||||
} else {
|
||||
coins, err := at.strategyEngine.GetCandidateCoins()
|
||||
if err != nil {
|
||||
// Log warning but don't fail - equity snapshot should still be saved
|
||||
logger.Infof("⚠️ [%s] Failed to get candidate coins: %v (will use empty list)", at.name, err)
|
||||
} else {
|
||||
candidateCoins = coins
|
||||
logger.Infof("📋 [%s] Strategy engine fetched candidate coins: %d", at.name, len(candidateCoins))
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Calculate total P&L
|
||||
totalPnL := totalEquity - at.initialBalance
|
||||
totalPnLPct := 0.0
|
||||
if at.initialBalance > 0 {
|
||||
totalPnLPct = (totalPnL / at.initialBalance) * 100
|
||||
}
|
||||
|
||||
marginUsedPct := 0.0
|
||||
if totalEquity > 0 {
|
||||
marginUsedPct = (totalMarginUsed / totalEquity) * 100
|
||||
}
|
||||
|
||||
// 5. Get leverage from strategy config
|
||||
strategyConfig := at.strategyEngine.GetConfig()
|
||||
btcEthLeverage := strategyConfig.RiskControl.BTCETHMaxLeverage
|
||||
altcoinLeverage := strategyConfig.RiskControl.AltcoinMaxLeverage
|
||||
logger.Infof("📋 [%s] Strategy leverage config: BTC/ETH=%dx, Altcoin=%dx", at.name, btcEthLeverage, altcoinLeverage)
|
||||
|
||||
// 6. Build context
|
||||
ctx := &kernel.Context{
|
||||
CurrentTime: time.Now().UTC().Format("2006-01-02 15:04:05 UTC"),
|
||||
RuntimeMinutes: int(time.Since(at.startTime).Minutes()),
|
||||
CallCount: at.callCount,
|
||||
BTCETHLeverage: btcEthLeverage,
|
||||
AltcoinLeverage: altcoinLeverage,
|
||||
Account: kernel.AccountInfo{
|
||||
TotalEquity: totalEquity,
|
||||
AvailableBalance: availableBalance,
|
||||
UnrealizedPnL: totalUnrealizedProfit,
|
||||
TotalPnL: totalPnL,
|
||||
TotalPnLPct: totalPnLPct,
|
||||
MarginUsed: totalMarginUsed,
|
||||
MarginUsedPct: marginUsedPct,
|
||||
PositionCount: len(positionInfos),
|
||||
},
|
||||
Positions: positionInfos,
|
||||
CandidateCoins: candidateCoins,
|
||||
}
|
||||
|
||||
// 7. Add recent closed trades (if store is available)
|
||||
if at.store != nil {
|
||||
// Get recent 10 closed trades for AI context
|
||||
recentTrades, err := at.store.Position().GetRecentTrades(at.id, 10)
|
||||
if err != nil {
|
||||
logger.Infof("⚠️ [%s] Failed to get recent trades: %v", at.name, err)
|
||||
} else {
|
||||
logger.Infof("📊 [%s] Found %d recent closed trades for AI context", at.name, len(recentTrades))
|
||||
for _, trade := range recentTrades {
|
||||
// Convert Unix timestamps to formatted strings for AI readability
|
||||
entryTimeStr := ""
|
||||
if trade.EntryTime > 0 {
|
||||
entryTimeStr = time.Unix(trade.EntryTime, 0).UTC().Format("01-02 15:04 UTC")
|
||||
}
|
||||
exitTimeStr := ""
|
||||
if trade.ExitTime > 0 {
|
||||
exitTimeStr = time.Unix(trade.ExitTime, 0).UTC().Format("01-02 15:04 UTC")
|
||||
}
|
||||
|
||||
ctx.RecentOrders = append(ctx.RecentOrders, kernel.RecentOrder{
|
||||
Symbol: trade.Symbol,
|
||||
Side: trade.Side,
|
||||
EntryPrice: trade.EntryPrice,
|
||||
ExitPrice: trade.ExitPrice,
|
||||
RealizedPnL: trade.RealizedPnL,
|
||||
PnLPct: trade.PnLPct,
|
||||
EntryTime: entryTimeStr,
|
||||
ExitTime: exitTimeStr,
|
||||
HoldDuration: trade.HoldDuration,
|
||||
})
|
||||
}
|
||||
}
|
||||
// Get trading statistics for AI context
|
||||
stats, err := at.store.Position().GetFullStats(at.id)
|
||||
if err != nil {
|
||||
logger.Infof("⚠️ [%s] Failed to get trading stats: %v", at.name, err)
|
||||
} else if stats == nil {
|
||||
logger.Infof("⚠️ [%s] GetFullStats returned nil", at.name)
|
||||
} else if stats.TotalTrades == 0 {
|
||||
logger.Infof("⚠️ [%s] GetFullStats returned 0 trades (traderID=%s)", at.name, at.id)
|
||||
} else {
|
||||
ctx.TradingStats = &kernel.TradingStats{
|
||||
TotalTrades: stats.TotalTrades,
|
||||
WinRate: stats.WinRate,
|
||||
ProfitFactor: stats.ProfitFactor,
|
||||
SharpeRatio: stats.SharpeRatio,
|
||||
TotalPnL: stats.TotalPnL,
|
||||
AvgWin: stats.AvgWin,
|
||||
AvgLoss: stats.AvgLoss,
|
||||
MaxDrawdownPct: stats.MaxDrawdownPct,
|
||||
}
|
||||
logger.Infof("📈 [%s] Trading stats: %d trades, %.1f%% win rate, PF=%.2f, Sharpe=%.2f, DD=%.1f%%",
|
||||
at.name, stats.TotalTrades, stats.WinRate, stats.ProfitFactor, stats.SharpeRatio, stats.MaxDrawdownPct)
|
||||
}
|
||||
} else {
|
||||
logger.Infof("⚠️ [%s] Store is nil, cannot get recent trades", at.name)
|
||||
}
|
||||
|
||||
// 8. Get quantitative data (if enabled in strategy config)
|
||||
if strategyConfig.Indicators.EnableQuantData {
|
||||
// Collect symbols to query (candidate coins + position coins)
|
||||
symbolsToQuery := make(map[string]bool)
|
||||
for _, coin := range candidateCoins {
|
||||
symbolsToQuery[coin.Symbol] = true
|
||||
}
|
||||
for _, pos := range positionInfos {
|
||||
symbolsToQuery[pos.Symbol] = true
|
||||
}
|
||||
|
||||
symbols := make([]string, 0, len(symbolsToQuery))
|
||||
for sym := range symbolsToQuery {
|
||||
symbols = append(symbols, sym)
|
||||
}
|
||||
|
||||
logger.Infof("📊 [%s] Fetching quantitative data for %d symbols...", at.name, len(symbols))
|
||||
ctx.QuantDataMap = at.strategyEngine.FetchQuantDataBatch(symbols)
|
||||
logger.Infof("📊 [%s] Successfully fetched quantitative data for %d symbols", at.name, len(ctx.QuantDataMap))
|
||||
}
|
||||
|
||||
// 9. Get OI ranking data (market-wide position changes)
|
||||
if strategyConfig.Indicators.EnableOIRanking {
|
||||
logger.Infof("📊 [%s] Fetching OI ranking data...", at.name)
|
||||
ctx.OIRankingData = at.strategyEngine.FetchOIRankingData()
|
||||
if ctx.OIRankingData != nil {
|
||||
logger.Infof("📊 [%s] OI ranking data ready: %d top, %d low positions",
|
||||
at.name, len(ctx.OIRankingData.TopPositions), len(ctx.OIRankingData.LowPositions))
|
||||
}
|
||||
}
|
||||
|
||||
// 10. Get NetFlow ranking data (market-wide fund flow)
|
||||
if strategyConfig.Indicators.EnableNetFlowRanking {
|
||||
logger.Infof("💰 [%s] Fetching NetFlow ranking data...", at.name)
|
||||
ctx.NetFlowRankingData = at.strategyEngine.FetchNetFlowRankingData()
|
||||
if ctx.NetFlowRankingData != nil {
|
||||
logger.Infof("💰 [%s] NetFlow ranking data ready: inst_in=%d, inst_out=%d",
|
||||
at.name, len(ctx.NetFlowRankingData.InstitutionFutureTop), len(ctx.NetFlowRankingData.InstitutionFutureLow))
|
||||
}
|
||||
}
|
||||
|
||||
// 11. Get Price ranking data (market-wide gainers/losers)
|
||||
if strategyConfig.Indicators.EnablePriceRanking {
|
||||
logger.Infof("📈 [%s] Fetching Price ranking data...", at.name)
|
||||
ctx.PriceRankingData = at.strategyEngine.FetchPriceRankingData()
|
||||
if ctx.PriceRankingData != nil {
|
||||
logger.Infof("📈 [%s] Price ranking data ready for %d durations",
|
||||
at.name, len(ctx.PriceRankingData.Durations))
|
||||
}
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
// sortDecisionsByPriority sorts decisions: close positions first, then open positions, finally hold/wait
|
||||
// This avoids position stacking overflow when changing positions
|
||||
func sortDecisionsByPriority(decisions []kernel.Decision) []kernel.Decision {
|
||||
if len(decisions) <= 1 {
|
||||
return decisions
|
||||
}
|
||||
|
||||
// Define priority
|
||||
getActionPriority := func(action string) int {
|
||||
switch action {
|
||||
case "close_long", "close_short":
|
||||
return 1 // Highest priority: close positions first
|
||||
case "open_long", "open_short":
|
||||
return 2 // Second priority: open positions later
|
||||
case "hold", "wait":
|
||||
return 3 // Lowest priority: wait
|
||||
default:
|
||||
return 999 // Unknown actions at the end
|
||||
}
|
||||
}
|
||||
|
||||
// Copy decision list
|
||||
sorted := make([]kernel.Decision, len(decisions))
|
||||
copy(sorted, decisions)
|
||||
|
||||
// Sort by priority
|
||||
for i := 0; i < len(sorted)-1; i++ {
|
||||
for j := i + 1; j < len(sorted); j++ {
|
||||
if getActionPriority(sorted[i].Action) > getActionPriority(sorted[j].Action) {
|
||||
sorted[i], sorted[j] = sorted[j], sorted[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sorted
|
||||
}
|
||||
@@ -0,0 +1,391 @@
|
||||
package trader
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"nofx/kernel"
|
||||
"nofx/logger"
|
||||
"nofx/market"
|
||||
"nofx/store"
|
||||
"time"
|
||||
)
|
||||
|
||||
// executeDecisionWithRecord executes AI decision and records detailed information
|
||||
func (at *AutoTrader) executeDecisionWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error {
|
||||
switch decision.Action {
|
||||
case "open_long":
|
||||
return at.executeOpenLongWithRecord(decision, actionRecord)
|
||||
case "open_short":
|
||||
return at.executeOpenShortWithRecord(decision, actionRecord)
|
||||
case "close_long":
|
||||
return at.executeCloseLongWithRecord(decision, actionRecord)
|
||||
case "close_short":
|
||||
return at.executeCloseShortWithRecord(decision, actionRecord)
|
||||
case "hold", "wait":
|
||||
// No execution needed, just record
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unknown action: %s", decision.Action)
|
||||
}
|
||||
}
|
||||
|
||||
// executeOpenLongWithRecord executes open long position and records detailed information
|
||||
func (at *AutoTrader) executeOpenLongWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error {
|
||||
logger.Infof(" 📈 Open long: %s", decision.Symbol)
|
||||
|
||||
// ⚠️ Get current positions for multiple checks
|
||||
positions, err := at.trader.GetPositions()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get positions: %w", err)
|
||||
}
|
||||
|
||||
// [CODE ENFORCED] Check max positions limit
|
||||
if err := at.enforceMaxPositions(len(positions)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if there's already a position in the same symbol and direction
|
||||
for _, pos := range positions {
|
||||
if pos["symbol"] == decision.Symbol && pos["side"] == "long" {
|
||||
return fmt.Errorf("❌ %s already has long position, close it first", decision.Symbol)
|
||||
}
|
||||
}
|
||||
|
||||
// Get current price
|
||||
marketData, err := market.GetWithExchange(decision.Symbol, at.exchange)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get balance (needed for multiple checks)
|
||||
balance, err := at.trader.GetBalance()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account balance: %w", err)
|
||||
}
|
||||
availableBalance := 0.0
|
||||
if avail, ok := balance["availableBalance"].(float64); ok {
|
||||
availableBalance = avail
|
||||
}
|
||||
|
||||
// Get equity for position value ratio check
|
||||
equity := 0.0
|
||||
if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 {
|
||||
equity = eq
|
||||
} else if eq, ok := balance["totalWalletBalance"].(float64); ok && eq > 0 {
|
||||
equity = eq
|
||||
} else {
|
||||
equity = availableBalance // Fallback to available balance
|
||||
}
|
||||
|
||||
// [CODE ENFORCED] Position Value Ratio Check: position_value <= equity × ratio
|
||||
adjustedPositionSize, wasCapped := at.enforcePositionValueRatio(decision.PositionSizeUSD, equity, decision.Symbol)
|
||||
if wasCapped {
|
||||
decision.PositionSizeUSD = adjustedPositionSize
|
||||
}
|
||||
|
||||
// ⚠️ Auto-adjust position size if insufficient margin
|
||||
// Formula: totalRequired = positionSize/leverage + positionSize*0.001 + positionSize/leverage*0.01
|
||||
// = positionSize * (1.01/leverage + 0.001)
|
||||
marginFactor := 1.01/float64(decision.Leverage) + 0.001
|
||||
maxAffordablePositionSize := availableBalance / marginFactor
|
||||
|
||||
actualPositionSize := decision.PositionSizeUSD
|
||||
if actualPositionSize > maxAffordablePositionSize {
|
||||
// Use 98% of max to leave buffer for price fluctuation
|
||||
adjustedSize := maxAffordablePositionSize * 0.98
|
||||
logger.Infof(" ⚠️ Position size %.2f exceeds max affordable %.2f, auto-reducing to %.2f",
|
||||
actualPositionSize, maxAffordablePositionSize, adjustedSize)
|
||||
actualPositionSize = adjustedSize
|
||||
decision.PositionSizeUSD = actualPositionSize
|
||||
}
|
||||
|
||||
// [CODE ENFORCED] Minimum position size check
|
||||
if err := at.enforceMinPositionSize(decision.PositionSizeUSD); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Calculate quantity with adjusted position size
|
||||
quantity := actualPositionSize / marketData.CurrentPrice
|
||||
actionRecord.Quantity = quantity
|
||||
actionRecord.Price = marketData.CurrentPrice
|
||||
|
||||
// Set margin mode
|
||||
if err := at.trader.SetMarginMode(decision.Symbol, at.config.IsCrossMargin); err != nil {
|
||||
logger.Infof(" ⚠️ Failed to set margin mode: %v", err)
|
||||
// Continue execution, doesn't affect trading
|
||||
}
|
||||
|
||||
// Open position
|
||||
order, err := at.trader.OpenLong(decision.Symbol, quantity, decision.Leverage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Record order ID
|
||||
if orderID, ok := order["orderId"].(int64); ok {
|
||||
actionRecord.OrderID = orderID
|
||||
}
|
||||
|
||||
logger.Infof(" ✓ Position opened successfully, order ID: %v, quantity: %.4f", order["orderId"], quantity)
|
||||
|
||||
// Record order to database and poll for confirmation
|
||||
at.recordAndConfirmOrder(order, decision.Symbol, "open_long", quantity, marketData.CurrentPrice, decision.Leverage, 0)
|
||||
|
||||
// Record position opening time
|
||||
posKey := decision.Symbol + "_long"
|
||||
at.positionFirstSeenTime[posKey] = time.Now().UnixMilli()
|
||||
|
||||
// Set stop loss and take profit
|
||||
if err := at.trader.SetStopLoss(decision.Symbol, "LONG", quantity, decision.StopLoss); err != nil {
|
||||
logger.Infof(" ⚠ Failed to set stop loss: %v", err)
|
||||
}
|
||||
if err := at.trader.SetTakeProfit(decision.Symbol, "LONG", quantity, decision.TakeProfit); err != nil {
|
||||
logger.Infof(" ⚠ Failed to set take profit: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeOpenShortWithRecord executes open short position and records detailed information
|
||||
func (at *AutoTrader) executeOpenShortWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error {
|
||||
logger.Infof(" 📉 Open short: %s", decision.Symbol)
|
||||
|
||||
// ⚠️ Get current positions for multiple checks
|
||||
positions, err := at.trader.GetPositions()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get positions: %w", err)
|
||||
}
|
||||
|
||||
// [CODE ENFORCED] Check max positions limit
|
||||
if err := at.enforceMaxPositions(len(positions)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if there's already a position in the same symbol and direction
|
||||
for _, pos := range positions {
|
||||
if pos["symbol"] == decision.Symbol && pos["side"] == "short" {
|
||||
return fmt.Errorf("❌ %s already has short position, close it first", decision.Symbol)
|
||||
}
|
||||
}
|
||||
|
||||
// Get current price
|
||||
marketData, err := market.GetWithExchange(decision.Symbol, at.exchange)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get balance (needed for multiple checks)
|
||||
balance, err := at.trader.GetBalance()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account balance: %w", err)
|
||||
}
|
||||
availableBalance := 0.0
|
||||
if avail, ok := balance["availableBalance"].(float64); ok {
|
||||
availableBalance = avail
|
||||
}
|
||||
|
||||
// Get equity for position value ratio check
|
||||
equity := 0.0
|
||||
if eq, ok := balance["totalEquity"].(float64); ok && eq > 0 {
|
||||
equity = eq
|
||||
} else if eq, ok := balance["totalWalletBalance"].(float64); ok && eq > 0 {
|
||||
equity = eq
|
||||
} else {
|
||||
equity = availableBalance // Fallback to available balance
|
||||
}
|
||||
|
||||
// [CODE ENFORCED] Position Value Ratio Check: position_value <= equity × ratio
|
||||
adjustedPositionSize, wasCapped := at.enforcePositionValueRatio(decision.PositionSizeUSD, equity, decision.Symbol)
|
||||
if wasCapped {
|
||||
decision.PositionSizeUSD = adjustedPositionSize
|
||||
}
|
||||
|
||||
// ⚠️ Auto-adjust position size if insufficient margin
|
||||
// Formula: totalRequired = positionSize/leverage + positionSize*0.001 + positionSize/leverage*0.01
|
||||
// = positionSize * (1.01/leverage + 0.001)
|
||||
marginFactor := 1.01/float64(decision.Leverage) + 0.001
|
||||
maxAffordablePositionSize := availableBalance / marginFactor
|
||||
|
||||
actualPositionSize := decision.PositionSizeUSD
|
||||
if actualPositionSize > maxAffordablePositionSize {
|
||||
// Use 98% of max to leave buffer for price fluctuation
|
||||
adjustedSize := maxAffordablePositionSize * 0.98
|
||||
logger.Infof(" ⚠️ Position size %.2f exceeds max affordable %.2f, auto-reducing to %.2f",
|
||||
actualPositionSize, maxAffordablePositionSize, adjustedSize)
|
||||
actualPositionSize = adjustedSize
|
||||
decision.PositionSizeUSD = actualPositionSize
|
||||
}
|
||||
|
||||
// [CODE ENFORCED] Minimum position size check
|
||||
if err := at.enforceMinPositionSize(decision.PositionSizeUSD); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Calculate quantity with adjusted position size
|
||||
quantity := actualPositionSize / marketData.CurrentPrice
|
||||
actionRecord.Quantity = quantity
|
||||
actionRecord.Price = marketData.CurrentPrice
|
||||
|
||||
// Set margin mode
|
||||
if err := at.trader.SetMarginMode(decision.Symbol, at.config.IsCrossMargin); err != nil {
|
||||
logger.Infof(" ⚠️ Failed to set margin mode: %v", err)
|
||||
// Continue execution, doesn't affect trading
|
||||
}
|
||||
|
||||
// Open position
|
||||
order, err := at.trader.OpenShort(decision.Symbol, quantity, decision.Leverage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Record order ID
|
||||
if orderID, ok := order["orderId"].(int64); ok {
|
||||
actionRecord.OrderID = orderID
|
||||
}
|
||||
|
||||
logger.Infof(" ✓ Position opened successfully, order ID: %v, quantity: %.4f", order["orderId"], quantity)
|
||||
|
||||
// Record order to database and poll for confirmation
|
||||
at.recordAndConfirmOrder(order, decision.Symbol, "open_short", quantity, marketData.CurrentPrice, decision.Leverage, 0)
|
||||
|
||||
// Record position opening time
|
||||
posKey := decision.Symbol + "_short"
|
||||
at.positionFirstSeenTime[posKey] = time.Now().UnixMilli()
|
||||
|
||||
// Set stop loss and take profit
|
||||
if err := at.trader.SetStopLoss(decision.Symbol, "SHORT", quantity, decision.StopLoss); err != nil {
|
||||
logger.Infof(" ⚠ Failed to set stop loss: %v", err)
|
||||
}
|
||||
if err := at.trader.SetTakeProfit(decision.Symbol, "SHORT", quantity, decision.TakeProfit); err != nil {
|
||||
logger.Infof(" ⚠ Failed to set take profit: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeCloseLongWithRecord executes close long position and records detailed information
|
||||
func (at *AutoTrader) executeCloseLongWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error {
|
||||
logger.Infof(" 🔄 Close long: %s", decision.Symbol)
|
||||
|
||||
// Get current price
|
||||
marketData, err := market.GetWithExchange(decision.Symbol, at.exchange)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
actionRecord.Price = marketData.CurrentPrice
|
||||
|
||||
// Normalize symbol for database lookup
|
||||
normalizedSymbol := market.Normalize(decision.Symbol)
|
||||
|
||||
// Get entry price and quantity - prioritize local database for accurate quantity
|
||||
var entryPrice float64
|
||||
var quantity float64
|
||||
|
||||
// First try to get from local database (more accurate for quantity)
|
||||
if at.store != nil {
|
||||
if openPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, normalizedSymbol, "LONG"); err == nil && openPos != nil {
|
||||
quantity = openPos.Quantity
|
||||
entryPrice = openPos.EntryPrice
|
||||
logger.Infof(" 📊 Using local position data: qty=%.8f, entry=%.2f", quantity, entryPrice)
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to exchange API if local data not found
|
||||
if quantity == 0 {
|
||||
positions, err := at.trader.GetPositions()
|
||||
if err == nil {
|
||||
for _, pos := range positions {
|
||||
if pos["symbol"] == decision.Symbol && pos["side"] == "long" {
|
||||
if ep, ok := pos["entryPrice"].(float64); ok {
|
||||
entryPrice = ep
|
||||
}
|
||||
if amt, ok := pos["positionAmt"].(float64); ok && amt > 0 {
|
||||
quantity = amt
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.Infof(" 📊 Using exchange position data: qty=%.8f, entry=%.2f", quantity, entryPrice)
|
||||
}
|
||||
|
||||
// Close position
|
||||
order, err := at.trader.CloseLong(decision.Symbol, 0) // 0 = close all
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Record order ID
|
||||
if orderID, ok := order["orderId"].(int64); ok {
|
||||
actionRecord.OrderID = orderID
|
||||
}
|
||||
|
||||
// Record order to database and poll for confirmation
|
||||
at.recordAndConfirmOrder(order, decision.Symbol, "close_long", quantity, marketData.CurrentPrice, 0, entryPrice)
|
||||
|
||||
logger.Infof(" ✓ Position closed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeCloseShortWithRecord executes close short position and records detailed information
|
||||
func (at *AutoTrader) executeCloseShortWithRecord(decision *kernel.Decision, actionRecord *store.DecisionAction) error {
|
||||
logger.Infof(" 🔄 Close short: %s", decision.Symbol)
|
||||
|
||||
// Get current price
|
||||
marketData, err := market.GetWithExchange(decision.Symbol, at.exchange)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
actionRecord.Price = marketData.CurrentPrice
|
||||
|
||||
// Normalize symbol for database lookup
|
||||
normalizedSymbol := market.Normalize(decision.Symbol)
|
||||
|
||||
// Get entry price and quantity - prioritize local database for accurate quantity
|
||||
var entryPrice float64
|
||||
var quantity float64
|
||||
|
||||
// First try to get from local database (more accurate for quantity)
|
||||
if at.store != nil {
|
||||
if openPos, err := at.store.Position().GetOpenPositionBySymbol(at.id, normalizedSymbol, "SHORT"); err == nil && openPos != nil {
|
||||
quantity = openPos.Quantity
|
||||
entryPrice = openPos.EntryPrice
|
||||
logger.Infof(" 📊 Using local position data: qty=%.8f, entry=%.2f", quantity, entryPrice)
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to exchange API if local data not found
|
||||
if quantity == 0 {
|
||||
positions, err := at.trader.GetPositions()
|
||||
if err == nil {
|
||||
for _, pos := range positions {
|
||||
if pos["symbol"] == decision.Symbol && pos["side"] == "short" {
|
||||
if ep, ok := pos["entryPrice"].(float64); ok {
|
||||
entryPrice = ep
|
||||
}
|
||||
if amt, ok := pos["positionAmt"].(float64); ok {
|
||||
quantity = -amt // positionAmt is negative for short
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.Infof(" 📊 Using exchange position data: qty=%.8f, entry=%.2f", quantity, entryPrice)
|
||||
}
|
||||
|
||||
// Close position
|
||||
order, err := at.trader.CloseShort(decision.Symbol, 0) // 0 = close all
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Record order ID
|
||||
if orderID, ok := order["orderId"].(int64); ok {
|
||||
actionRecord.OrderID = orderID
|
||||
}
|
||||
|
||||
// Record order to database and poll for confirmation
|
||||
at.recordAndConfirmOrder(order, decision.Symbol, "close_short", quantity, marketData.CurrentPrice, 0, entryPrice)
|
||||
|
||||
logger.Infof(" ✓ Position closed successfully")
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,263 @@
|
||||
package trader
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"nofx/logger"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// startDrawdownMonitor starts drawdown monitoring
|
||||
func (at *AutoTrader) startDrawdownMonitor() {
|
||||
at.monitorWg.Add(1)
|
||||
go func() {
|
||||
defer at.monitorWg.Done()
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute) // Check every minute
|
||||
defer ticker.Stop()
|
||||
|
||||
logger.Info("📊 Started position drawdown monitoring (check every minute)")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
at.checkPositionDrawdown()
|
||||
case <-at.stopMonitorCh:
|
||||
logger.Info("⏹ Stopped position drawdown monitoring")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// checkPositionDrawdown checks position drawdown situation
|
||||
func (at *AutoTrader) checkPositionDrawdown() {
|
||||
// Get current positions
|
||||
positions, err := at.trader.GetPositions()
|
||||
if err != nil {
|
||||
logger.Infof("❌ Drawdown monitoring: failed to get positions: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, pos := range positions {
|
||||
symbol := pos["symbol"].(string)
|
||||
side := pos["side"].(string)
|
||||
entryPrice := pos["entryPrice"].(float64)
|
||||
markPrice := pos["markPrice"].(float64)
|
||||
quantity := pos["positionAmt"].(float64)
|
||||
if quantity < 0 {
|
||||
quantity = -quantity // Short position quantity is negative, convert to positive
|
||||
}
|
||||
|
||||
// Calculate current P&L percentage
|
||||
leverage := 10 // Default value
|
||||
if lev, ok := pos["leverage"].(float64); ok {
|
||||
leverage = int(lev)
|
||||
}
|
||||
|
||||
var currentPnLPct float64
|
||||
if side == "long" {
|
||||
currentPnLPct = ((markPrice - entryPrice) / entryPrice) * float64(leverage) * 100
|
||||
} else {
|
||||
currentPnLPct = ((entryPrice - markPrice) / entryPrice) * float64(leverage) * 100
|
||||
}
|
||||
|
||||
// Construct unique position identifier (distinguish long/short)
|
||||
posKey := symbol + "_" + side
|
||||
|
||||
// Get historical peak profit for this position
|
||||
at.peakPnLCacheMutex.RLock()
|
||||
peakPnLPct, exists := at.peakPnLCache[posKey]
|
||||
at.peakPnLCacheMutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// If no historical peak record, use current P&L as initial value
|
||||
peakPnLPct = currentPnLPct
|
||||
at.UpdatePeakPnL(symbol, side, currentPnLPct)
|
||||
} else {
|
||||
// Update peak cache
|
||||
at.UpdatePeakPnL(symbol, side, currentPnLPct)
|
||||
}
|
||||
|
||||
// Calculate drawdown (magnitude of decline from peak)
|
||||
var drawdownPct float64
|
||||
if peakPnLPct > 0 && currentPnLPct < peakPnLPct {
|
||||
drawdownPct = ((peakPnLPct - currentPnLPct) / peakPnLPct) * 100
|
||||
}
|
||||
|
||||
// Check close position condition: profit > 5% and drawdown >= 40%
|
||||
if currentPnLPct > 5.0 && drawdownPct >= 40.0 {
|
||||
logger.Infof("🚨 Drawdown close position condition triggered: %s %s | Current profit: %.2f%% | Peak profit: %.2f%% | Drawdown: %.2f%%",
|
||||
symbol, side, currentPnLPct, peakPnLPct, drawdownPct)
|
||||
|
||||
// Execute close position
|
||||
if err := at.emergencyClosePosition(symbol, side); err != nil {
|
||||
logger.Infof("❌ Drawdown close position failed (%s %s): %v", symbol, side, err)
|
||||
} else {
|
||||
logger.Infof("✅ Drawdown close position succeeded: %s %s", symbol, side)
|
||||
// Clear cache for this position after closing
|
||||
at.ClearPeakPnLCache(symbol, side)
|
||||
}
|
||||
} else if currentPnLPct > 5.0 {
|
||||
// Record situations close to close position condition (for debugging)
|
||||
logger.Infof("📊 Drawdown monitoring: %s %s | Profit: %.2f%% | Peak: %.2f%% | Drawdown: %.2f%%",
|
||||
symbol, side, currentPnLPct, peakPnLPct, drawdownPct)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// emergencyClosePosition emergency close position function
|
||||
func (at *AutoTrader) emergencyClosePosition(symbol, side string) error {
|
||||
switch side {
|
||||
case "long":
|
||||
order, err := at.trader.CloseLong(symbol, 0) // 0 = close all
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Infof("✅ Emergency close long position succeeded, order ID: %v", order["orderId"])
|
||||
case "short":
|
||||
order, err := at.trader.CloseShort(symbol, 0) // 0 = close all
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Infof("✅ Emergency close short position succeeded, order ID: %v", order["orderId"])
|
||||
default:
|
||||
return fmt.Errorf("unknown position direction: %s", side)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeakPnLCache gets peak profit cache
|
||||
func (at *AutoTrader) GetPeakPnLCache() map[string]float64 {
|
||||
at.peakPnLCacheMutex.RLock()
|
||||
defer at.peakPnLCacheMutex.RUnlock()
|
||||
|
||||
// Return a copy of the cache
|
||||
cache := make(map[string]float64)
|
||||
for k, v := range at.peakPnLCache {
|
||||
cache[k] = v
|
||||
}
|
||||
return cache
|
||||
}
|
||||
|
||||
// UpdatePeakPnL updates peak profit cache
|
||||
func (at *AutoTrader) UpdatePeakPnL(symbol, side string, currentPnLPct float64) {
|
||||
at.peakPnLCacheMutex.Lock()
|
||||
defer at.peakPnLCacheMutex.Unlock()
|
||||
|
||||
posKey := symbol + "_" + side
|
||||
if peak, exists := at.peakPnLCache[posKey]; exists {
|
||||
// Update peak (if long, take larger value; if short, currentPnLPct is negative, also compare)
|
||||
if currentPnLPct > peak {
|
||||
at.peakPnLCache[posKey] = currentPnLPct
|
||||
}
|
||||
} else {
|
||||
// First time recording
|
||||
at.peakPnLCache[posKey] = currentPnLPct
|
||||
}
|
||||
}
|
||||
|
||||
// ClearPeakPnLCache clears peak cache for specified position
|
||||
func (at *AutoTrader) ClearPeakPnLCache(symbol, side string) {
|
||||
at.peakPnLCacheMutex.Lock()
|
||||
defer at.peakPnLCacheMutex.Unlock()
|
||||
|
||||
posKey := symbol + "_" + side
|
||||
delete(at.peakPnLCache, posKey)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Risk Control Helpers
|
||||
// ============================================================================
|
||||
|
||||
// isBTCETH checks if a symbol is BTC or ETH
|
||||
func isBTCETH(symbol string) bool {
|
||||
symbol = strings.ToUpper(symbol)
|
||||
return strings.HasPrefix(symbol, "BTC") || strings.HasPrefix(symbol, "ETH")
|
||||
}
|
||||
|
||||
// enforcePositionValueRatio checks and enforces position value ratio limits (CODE ENFORCED)
|
||||
// Returns the adjusted position size (capped if necessary) and whether the position was capped
|
||||
// positionSizeUSD: the original position size in USD
|
||||
// equity: the account equity
|
||||
// symbol: the trading symbol
|
||||
func (at *AutoTrader) enforcePositionValueRatio(positionSizeUSD float64, equity float64, symbol string) (float64, bool) {
|
||||
if at.config.StrategyConfig == nil {
|
||||
return positionSizeUSD, false
|
||||
}
|
||||
|
||||
riskControl := at.config.StrategyConfig.RiskControl
|
||||
|
||||
// Get the appropriate position value ratio limit
|
||||
var maxPositionValueRatio float64
|
||||
if isBTCETH(symbol) {
|
||||
maxPositionValueRatio = riskControl.BTCETHMaxPositionValueRatio
|
||||
if maxPositionValueRatio <= 0 {
|
||||
maxPositionValueRatio = 5.0 // Default: 5x for BTC/ETH
|
||||
}
|
||||
} else {
|
||||
maxPositionValueRatio = riskControl.AltcoinMaxPositionValueRatio
|
||||
if maxPositionValueRatio <= 0 {
|
||||
maxPositionValueRatio = 1.0 // Default: 1x for altcoins
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate max allowed position value = equity × ratio
|
||||
maxPositionValue := equity * maxPositionValueRatio
|
||||
|
||||
// Check if position size exceeds limit
|
||||
if positionSizeUSD > maxPositionValue {
|
||||
logger.Infof(" ⚠️ [RISK CONTROL] Position %.2f USDT exceeds limit (equity %.2f × %.1fx = %.2f USDT max for %s), capping",
|
||||
positionSizeUSD, equity, maxPositionValueRatio, maxPositionValue, symbol)
|
||||
return maxPositionValue, true
|
||||
}
|
||||
|
||||
return positionSizeUSD, false
|
||||
}
|
||||
|
||||
// enforceMinPositionSize checks minimum position size (CODE ENFORCED)
|
||||
func (at *AutoTrader) enforceMinPositionSize(positionSizeUSD float64) error {
|
||||
if at.config.StrategyConfig == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
minSize := at.config.StrategyConfig.RiskControl.MinPositionSize
|
||||
if minSize <= 0 {
|
||||
minSize = 12 // Default: 12 USDT
|
||||
}
|
||||
|
||||
if positionSizeUSD < minSize {
|
||||
return fmt.Errorf("❌ [RISK CONTROL] Position %.2f USDT below minimum (%.2f USDT)", positionSizeUSD, minSize)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// enforceMaxPositions checks maximum positions count (CODE ENFORCED)
|
||||
func (at *AutoTrader) enforceMaxPositions(currentPositionCount int) error {
|
||||
if at.config.StrategyConfig == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
maxPositions := at.config.StrategyConfig.RiskControl.MaxPositions
|
||||
if maxPositions <= 0 {
|
||||
maxPositions = 3 // Default: 3 positions
|
||||
}
|
||||
|
||||
if currentPositionCount >= maxPositions {
|
||||
return fmt.Errorf("❌ [RISK CONTROL] Already at max positions (%d/%d)", currentPositionCount, maxPositions)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getSideFromAction converts order action to side (BUY/SELL)
|
||||
func getSideFromAction(action string) string {
|
||||
switch action {
|
||||
case "open_long", "close_short":
|
||||
return "BUY"
|
||||
case "open_short", "close_long":
|
||||
return "SELL"
|
||||
default:
|
||||
return "BUY"
|
||||
}
|
||||
}
|
||||
+9
-9
@@ -4,27 +4,27 @@ import useSWR from 'swr'
|
||||
import { api } from './lib/api'
|
||||
import { TraderDashboardPage } from './pages/TraderDashboardPage'
|
||||
|
||||
import { AITradersPage } from './components/AITradersPage'
|
||||
import { LoginPage } from './components/LoginPage'
|
||||
import { SetupPage } from './components/SetupPage'
|
||||
import { AITradersPage } from './components/trader/AITradersPage'
|
||||
import { LoginPage } from './components/auth/LoginPage'
|
||||
import { SetupPage } from './components/modals/SetupPage'
|
||||
import { SettingsPage } from './pages/SettingsPage'
|
||||
import { ResetPasswordPage } from './components/ResetPasswordPage'
|
||||
import { CompetitionPage } from './components/CompetitionPage'
|
||||
import { ResetPasswordPage } from './components/auth/ResetPasswordPage'
|
||||
import { CompetitionPage } from './components/trader/CompetitionPage'
|
||||
import { LandingPage } from './pages/LandingPage'
|
||||
import { FAQPage } from './pages/FAQPage'
|
||||
import { StrategyStudioPage } from './pages/StrategyStudioPage'
|
||||
import { StrategyMarketPage } from './pages/StrategyMarketPage'
|
||||
import { DataPage } from './pages/DataPage'
|
||||
import { LoginRequiredOverlay } from './components/LoginRequiredOverlay'
|
||||
import HeaderBar from './components/HeaderBar'
|
||||
import { LoginRequiredOverlay } from './components/auth/LoginRequiredOverlay'
|
||||
import HeaderBar from './components/common/HeaderBar'
|
||||
import { LanguageProvider, useLanguage } from './contexts/LanguageContext'
|
||||
import { AuthProvider, useAuth } from './contexts/AuthContext'
|
||||
import { ConfirmDialogProvider } from './components/ConfirmDialog'
|
||||
import { ConfirmDialogProvider } from './components/common/ConfirmDialog'
|
||||
import { t } from './i18n/translations'
|
||||
import { useSystemConfig } from './hooks/useSystemConfig'
|
||||
|
||||
import { OFFICIAL_LINKS } from './constants/branding'
|
||||
import { BacktestPage } from './components/BacktestPage'
|
||||
import { BacktestPage } from './components/backtest/BacktestPage'
|
||||
import type {
|
||||
SystemStatus,
|
||||
AccountInfo,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import React, { useState, useEffect } from 'react'
|
||||
import { Eye, EyeOff } from 'lucide-react'
|
||||
import { toast } from 'sonner'
|
||||
import { useAuth } from '../contexts/AuthContext'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
import { DeepVoidBackground } from './DeepVoidBackground'
|
||||
import { useAuth } from '../../contexts/AuthContext'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { t } from '../../i18n/translations'
|
||||
import { DeepVoidBackground } from '../common/DeepVoidBackground'
|
||||
|
||||
export function LoginPage() {
|
||||
const { language } = useLanguage()
|
||||
+2
-2
@@ -1,7 +1,7 @@
|
||||
import { motion, AnimatePresence } from 'framer-motion'
|
||||
import { LogIn, UserPlus, X, AlertTriangle, Terminal } from 'lucide-react'
|
||||
import { DeepVoidBackground } from './DeepVoidBackground'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { DeepVoidBackground } from '../common/DeepVoidBackground'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
|
||||
interface LoginRequiredOverlayProps {
|
||||
isOpen: boolean
|
||||
@@ -2,13 +2,13 @@ import React, { useEffect, useState } from 'react'
|
||||
import { Eye, EyeOff } from 'lucide-react'
|
||||
import PasswordChecklist from 'react-password-checklist'
|
||||
import { toast } from 'sonner'
|
||||
import { useAuth } from '../contexts/AuthContext'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
import { getSystemConfig } from '../lib/config'
|
||||
import { DeepVoidBackground } from './DeepVoidBackground'
|
||||
import { useAuth } from '../../contexts/AuthContext'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { t } from '../../i18n/translations'
|
||||
import { getSystemConfig } from '../../lib/config'
|
||||
import { DeepVoidBackground } from '../common/DeepVoidBackground'
|
||||
import { RegistrationDisabled } from './RegistrationDisabled'
|
||||
import { WhitelistFullPage } from './WhitelistFullPage'
|
||||
import { WhitelistFullPage } from '../common/WhitelistFullPage'
|
||||
|
||||
export function RegisterPage() {
|
||||
const { language } = useLanguage()
|
||||
+3
-3
@@ -1,11 +1,11 @@
|
||||
import { describe, it, expect, vi } from 'vitest'
|
||||
import { render, screen, fireEvent } from '@testing-library/react'
|
||||
import { RegistrationDisabled } from './RegistrationDisabled'
|
||||
import { LanguageProvider } from '../contexts/LanguageContext'
|
||||
import { LanguageProvider } from '../../contexts/LanguageContext'
|
||||
|
||||
// Mock useLanguage hook
|
||||
vi.mock('../contexts/LanguageContext', async () => {
|
||||
const actual = await vi.importActual('../contexts/LanguageContext')
|
||||
vi.mock('../../contexts/LanguageContext', async () => {
|
||||
const actual = await vi.importActual('../../contexts/LanguageContext')
|
||||
return {
|
||||
...actual,
|
||||
useLanguage: () => ({ language: 'en' }),
|
||||
+2
-2
@@ -1,5 +1,5 @@
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { t } from '../../i18n/translations'
|
||||
|
||||
export function RegistrationDisabled() {
|
||||
const { language } = useLanguage()
|
||||
+5
-5
@@ -1,11 +1,11 @@
|
||||
import React, { useState } from 'react'
|
||||
import { useAuth } from '../contexts/AuthContext'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
import { Header } from './Header'
|
||||
import { useAuth } from '../../contexts/AuthContext'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { t } from '../../i18n/translations'
|
||||
import { Header } from '../common/Header'
|
||||
import { ArrowLeft, KeyRound, Eye, EyeOff } from 'lucide-react'
|
||||
import PasswordChecklist from 'react-password-checklist'
|
||||
import { Input } from './ui/input'
|
||||
import { Input } from '../ui/input'
|
||||
import { toast } from 'sonner'
|
||||
|
||||
export function ResetPasswordPage() {
|
||||
@@ -28,7 +28,7 @@ import {
|
||||
ArrowDownRight,
|
||||
CandlestickChart as CandlestickIcon,
|
||||
} from 'lucide-react'
|
||||
import { DeepVoidBackground } from './DeepVoidBackground'
|
||||
import { DeepVoidBackground } from '../common/DeepVoidBackground'
|
||||
import {
|
||||
ResponsiveContainer,
|
||||
AreaChart,
|
||||
@@ -39,12 +39,12 @@ import {
|
||||
Tooltip,
|
||||
ReferenceDot,
|
||||
} from 'recharts'
|
||||
import { api } from '../lib/api'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
import { confirmToast } from '../lib/notify'
|
||||
import { DecisionCard } from './DecisionCard'
|
||||
import { MetricTooltip } from './MetricTooltip'
|
||||
import { api } from '../../lib/api'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { t } from '../../i18n/translations'
|
||||
import { confirmToast } from '../../lib/notify'
|
||||
import { DecisionCard } from '../trader/DecisionCard'
|
||||
import { MetricTooltip } from '../common/MetricTooltip'
|
||||
import type {
|
||||
BacktestStatusPayload,
|
||||
BacktestPositionStatus,
|
||||
@@ -55,7 +55,7 @@ import type {
|
||||
DecisionRecord,
|
||||
AIModel,
|
||||
Strategy,
|
||||
} from '../types'
|
||||
} from '../../types'
|
||||
|
||||
// ============ Types ============
|
||||
type WizardStep = 1 | 2 | 3
|
||||
@@ -10,14 +10,14 @@ import {
|
||||
HistogramSeries,
|
||||
createSeriesMarkers,
|
||||
} from 'lightweight-charts'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { httpClient } from '../lib/httpClient'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { httpClient } from '../../lib/httpClient'
|
||||
import {
|
||||
calculateSMA,
|
||||
calculateEMA,
|
||||
calculateBollingerBands,
|
||||
type Kline,
|
||||
} from '../utils/indicators'
|
||||
} from '../../utils/indicators'
|
||||
import { Settings, BarChart2 } from 'lucide-react'
|
||||
|
||||
// 订单接口定义
|
||||
@@ -1,8 +1,8 @@
|
||||
import { useState, useEffect, useRef } from 'react'
|
||||
import { EquityChart } from './EquityChart'
|
||||
import { AdvancedChart } from './AdvancedChart'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { t } from '../../i18n/translations'
|
||||
import { BarChart3, CandlestickChart, ChevronDown, Search } from 'lucide-react'
|
||||
import { motion, AnimatePresence } from 'framer-motion'
|
||||
|
||||
+2
-2
@@ -8,8 +8,8 @@ import {
|
||||
CandlestickSeries,
|
||||
createSeriesMarkers,
|
||||
} from 'lightweight-charts'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { httpClient } from '../lib/httpClient'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { httpClient } from '../../lib/httpClient'
|
||||
|
||||
// 订单接口定义
|
||||
interface OrderMarker {
|
||||
+1
-1
@@ -1,5 +1,5 @@
|
||||
import { useEffect, useState } from 'react'
|
||||
import { httpClient } from '../lib/httpClient'
|
||||
import { httpClient } from '../../lib/httpClient'
|
||||
|
||||
interface ChartWithOrdersSimpleProps {
|
||||
symbol: string
|
||||
+5
-5
@@ -12,11 +12,11 @@ import {
|
||||
ComposedChart,
|
||||
} from 'recharts'
|
||||
import useSWR from 'swr'
|
||||
import { api } from '../lib/api'
|
||||
import type { CompetitionTraderData } from '../types'
|
||||
import { getTraderColor } from '../utils/traderColors'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
import { api } from '../../lib/api'
|
||||
import type { CompetitionTraderData } from '../../types'
|
||||
import { getTraderColor } from '../../utils/traderColors'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { t } from '../../i18n/translations'
|
||||
import { BarChart3, TrendingUp, TrendingDown, Zap } from 'lucide-react'
|
||||
|
||||
// Time period options: 1D, 3D, 7D, 30D, All
|
||||
@@ -10,10 +10,10 @@ import {
|
||||
ReferenceLine,
|
||||
} from 'recharts'
|
||||
import useSWR from 'swr'
|
||||
import { api } from '../lib/api'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { useAuth } from '../contexts/AuthContext'
|
||||
import { t } from '../i18n/translations'
|
||||
import { api } from '../../lib/api'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { useAuth } from '../../contexts/AuthContext'
|
||||
import { t } from '../../i18n/translations'
|
||||
import {
|
||||
AlertTriangle,
|
||||
BarChart3,
|
||||
+2
-2
@@ -1,6 +1,6 @@
|
||||
import { useEffect, useRef, useState, memo } from 'react'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { t } from '../../i18n/translations'
|
||||
import { ChevronDown, TrendingUp, X } from 'lucide-react'
|
||||
|
||||
// 支持的交易所列表 (合约格式)
|
||||
@@ -13,8 +13,8 @@ import {
|
||||
AlertDialogDescription,
|
||||
AlertDialogFooter,
|
||||
AlertDialogTitle,
|
||||
} from './ui/alert-dialog'
|
||||
import { setGlobalConfirm } from '../lib/notify'
|
||||
} from '../ui/alert-dialog'
|
||||
import { setGlobalConfirm } from '../../lib/notify'
|
||||
|
||||
interface ConfirmOptions {
|
||||
title?: string
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { t } from '../../i18n/translations'
|
||||
import { Container } from './Container'
|
||||
|
||||
interface HeaderProps {
|
||||
@@ -2,8 +2,8 @@ import { useState, useEffect, useRef } from 'react'
|
||||
import { useNavigate } from 'react-router-dom'
|
||||
import { motion, AnimatePresence } from 'framer-motion'
|
||||
import { Menu, X, ChevronDown, Settings } from 'lucide-react'
|
||||
import { t, type Language } from '../i18n/translations'
|
||||
import { OFFICIAL_LINKS } from '../constants/branding'
|
||||
import { t, type Language } from '../../i18n/translations'
|
||||
import { OFFICIAL_LINKS } from '../../constants/branding'
|
||||
|
||||
type Page =
|
||||
| 'competition'
|
||||
+2
-2
@@ -1,7 +1,7 @@
|
||||
import { useCallback, useEffect, useState, type ReactNode } from 'react'
|
||||
import { Loader2, ShieldAlert, ShieldCheck, ShieldMinus } from 'lucide-react'
|
||||
import { CryptoService, diagnoseWebCryptoEnvironment } from '../lib/crypto'
|
||||
import { t, type Language } from '../i18n/translations'
|
||||
import { CryptoService, diagnoseWebCryptoEnvironment } from '../../lib/crypto'
|
||||
import { t, type Language } from '../../i18n/translations'
|
||||
|
||||
export type WebCryptoCheckStatus =
|
||||
| 'idle'
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
import { motion } from 'framer-motion'
|
||||
import { ShieldAlert, ArrowLeft, Twitter, Send, Lock } from 'lucide-react'
|
||||
import { OFFICIAL_LINKS } from '../constants/branding'
|
||||
import { OFFICIAL_LINKS } from '../../constants/branding'
|
||||
|
||||
interface WhitelistFullPageProps {
|
||||
onBack?: () => void
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useState, useMemo } from 'react'
|
||||
import { HelpCircle } from 'lucide-react'
|
||||
import { DeepVoidBackground } from '../DeepVoidBackground'
|
||||
import { DeepVoidBackground } from '../common/DeepVoidBackground'
|
||||
import { t, type Language } from '../../i18n/translations'
|
||||
import { FAQSearchBar } from './FAQSearchBar'
|
||||
import { FAQSidebar } from './FAQSidebar'
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import React, { useState } from 'react'
|
||||
import { Eye, EyeOff } from 'lucide-react'
|
||||
import { useAuth } from '../contexts/AuthContext'
|
||||
import { DeepVoidBackground } from './DeepVoidBackground'
|
||||
import { invalidateSystemConfig } from '../lib/config'
|
||||
import { useAuth } from '../../contexts/AuthContext'
|
||||
import { DeepVoidBackground } from '../common/DeepVoidBackground'
|
||||
import { invalidateSystemConfig } from '../../lib/config'
|
||||
|
||||
export function SetupPage() {
|
||||
const { register } = useAuth()
|
||||
+2
-2
@@ -1,8 +1,8 @@
|
||||
import { useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { createPortal } from 'react-dom'
|
||||
import { t, type Language } from '../i18n/translations'
|
||||
import { t, type Language } from '../../i18n/translations'
|
||||
import { toast } from 'sonner'
|
||||
import { WebCryptoEnvironmentCheck } from './WebCryptoEnvironmentCheck'
|
||||
import { WebCryptoEnvironmentCheck } from '../common/WebCryptoEnvironmentCheck'
|
||||
|
||||
const DEFAULT_LENGTH = 64
|
||||
|
||||
+12
-12
@@ -1,23 +1,23 @@
|
||||
import React, { useState, useEffect } from 'react'
|
||||
import { useNavigate } from 'react-router-dom'
|
||||
import useSWR from 'swr'
|
||||
import { api } from '../lib/api'
|
||||
import { api } from '../../lib/api'
|
||||
import type {
|
||||
TraderInfo,
|
||||
CreateTraderRequest,
|
||||
AIModel,
|
||||
Exchange,
|
||||
} from '../types'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t, type Language } from '../i18n/translations'
|
||||
import { useAuth } from '../contexts/AuthContext'
|
||||
import { getExchangeIcon } from './ExchangeIcons'
|
||||
import { getModelIcon } from './ModelIcons'
|
||||
} from '../../types'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { t, type Language } from '../../i18n/translations'
|
||||
import { useAuth } from '../../contexts/AuthContext'
|
||||
import { getExchangeIcon } from '../common/ExchangeIcons'
|
||||
import { getModelIcon } from '../common/ModelIcons'
|
||||
import { TraderConfigModal } from './TraderConfigModal'
|
||||
import { DeepVoidBackground } from './DeepVoidBackground'
|
||||
import { ExchangeConfigModal } from './traders/ExchangeConfigModal'
|
||||
import { TelegramConfigModal } from './traders/TelegramConfigModal'
|
||||
import { PunkAvatar, getTraderAvatar } from './PunkAvatar'
|
||||
import { DeepVoidBackground } from '../common/DeepVoidBackground'
|
||||
import { ExchangeConfigModal } from './ExchangeConfigModal'
|
||||
import { TelegramConfigModal } from './TelegramConfigModal'
|
||||
import { PunkAvatar, getTraderAvatar } from '../common/PunkAvatar'
|
||||
import {
|
||||
Bot,
|
||||
Brain,
|
||||
@@ -34,7 +34,7 @@ import {
|
||||
Check,
|
||||
MessageCircle,
|
||||
} from 'lucide-react'
|
||||
import { confirmToast } from '../lib/notify'
|
||||
import { confirmToast } from '../../lib/notify'
|
||||
import { toast } from 'sonner'
|
||||
|
||||
// 获取友好的AI模型名称
|
||||
+8
-8
@@ -1,15 +1,15 @@
|
||||
import { useState } from 'react'
|
||||
import { Trophy } from 'lucide-react'
|
||||
import useSWR from 'swr'
|
||||
import { api } from '../lib/api'
|
||||
import type { CompetitionData } from '../types'
|
||||
import { ComparisonChart } from './ComparisonChart'
|
||||
import { api } from '../../lib/api'
|
||||
import type { CompetitionData } from '../../types'
|
||||
import { ComparisonChart } from '../charts/ComparisonChart'
|
||||
import { TraderConfigViewModal } from './TraderConfigViewModal'
|
||||
import { getTraderColor } from '../utils/traderColors'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
import { PunkAvatar, getTraderAvatar } from './PunkAvatar'
|
||||
import { DeepVoidBackground } from './DeepVoidBackground'
|
||||
import { getTraderColor } from '../../utils/traderColors'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { t } from '../../i18n/translations'
|
||||
import { PunkAvatar, getTraderAvatar } from '../common/PunkAvatar'
|
||||
import { DeepVoidBackground } from '../common/DeepVoidBackground'
|
||||
|
||||
export function CompetitionPage() {
|
||||
const { language } = useLanguage()
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useState } from 'react'
|
||||
import type { DecisionRecord, DecisionAction } from '../types'
|
||||
import { t, type Language } from '../i18n/translations'
|
||||
import type { DecisionRecord, DecisionAction } from '../../types'
|
||||
import { t, type Language } from '../../i18n/translations'
|
||||
|
||||
interface DecisionCardProps {
|
||||
decision: DecisionRecord
|
||||
+3
-3
@@ -2,15 +2,15 @@ import React, { useState, useEffect } from 'react'
|
||||
import type { Exchange } from '../../types'
|
||||
import { t, type Language } from '../../i18n/translations'
|
||||
import { api } from '../../lib/api'
|
||||
import { getExchangeIcon } from '../ExchangeIcons'
|
||||
import { getExchangeIcon } from '../common/ExchangeIcons'
|
||||
import {
|
||||
TwoStageKeyModal,
|
||||
type TwoStageKeyModalResult,
|
||||
} from '../TwoStageKeyModal'
|
||||
} from '../modals/TwoStageKeyModal'
|
||||
import {
|
||||
WebCryptoEnvironmentCheck,
|
||||
type WebCryptoCheckStatus,
|
||||
} from '../WebCryptoEnvironmentCheck'
|
||||
} from '../common/WebCryptoEnvironmentCheck'
|
||||
import {
|
||||
BookOpen, Trash2, HelpCircle, ExternalLink, UserPlus,
|
||||
Key, Shield, ChevronLeft, Check, Copy, ArrowRight
|
||||
+6
-6
@@ -1,15 +1,15 @@
|
||||
import { useState, useEffect, useMemo } from 'react'
|
||||
import { api } from '../lib/api'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t, type Language } from '../i18n/translations'
|
||||
import { MetricTooltip } from './MetricTooltip'
|
||||
import { formatPrice, formatQuantity } from '../utils/format'
|
||||
import { api } from '../../lib/api'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { t, type Language } from '../../i18n/translations'
|
||||
import { MetricTooltip } from '../common/MetricTooltip'
|
||||
import { formatPrice, formatQuantity } from '../../utils/format'
|
||||
import type {
|
||||
HistoricalPosition,
|
||||
TraderStats,
|
||||
SymbolStats,
|
||||
DirectionStats,
|
||||
} from '../types'
|
||||
} from '../../types'
|
||||
|
||||
interface PositionHistoryProps {
|
||||
traderId: string
|
||||
+5
-5
@@ -1,10 +1,10 @@
|
||||
import { useState, useEffect } from 'react'
|
||||
import type { AIModel, Exchange, CreateTraderRequest, Strategy } from '../types'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
import type { AIModel, Exchange, CreateTraderRequest, Strategy } from '../../types'
|
||||
import { useLanguage } from '../../contexts/LanguageContext'
|
||||
import { t } from '../../i18n/translations'
|
||||
import { toast } from 'sonner'
|
||||
import { Pencil, Plus, X as IconX, Sparkles, ExternalLink, UserPlus } from 'lucide-react'
|
||||
import { httpClient } from '../lib/httpClient'
|
||||
import { httpClient } from '../../lib/httpClient'
|
||||
|
||||
// 提取下划线后面的名称部分
|
||||
function getShortName(fullName: string): string {
|
||||
@@ -22,7 +22,7 @@ const EXCHANGE_REGISTRATION_LINKS: Record<string, { url: string; hasReferral?: b
|
||||
lighter: { url: 'https://app.lighter.xyz/?referral=68151432', hasReferral: true },
|
||||
}
|
||||
|
||||
import type { TraderConfigData } from '../types'
|
||||
import type { TraderConfigData } from '../../types'
|
||||
|
||||
// 表单内部状态类型
|
||||
interface FormState {
|
||||
+2
-2
@@ -1,5 +1,5 @@
|
||||
import type { TraderConfigData } from '../types'
|
||||
import { PunkAvatar, getTraderAvatar } from './PunkAvatar'
|
||||
import type { TraderConfigData } from '../../types'
|
||||
import { PunkAvatar, getTraderAvatar } from '../common/PunkAvatar'
|
||||
|
||||
// 提取下划线后面的名称部分
|
||||
function getShortName(fullName: string): string {
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useState } from 'react'
|
||||
import HeaderBar from '../components/HeaderBar'
|
||||
import HeaderBar from '../components/common/HeaderBar'
|
||||
import LoginModal from '../components/landing/LoginModal'
|
||||
import { LoginRequiredOverlay } from '../components/LoginRequiredOverlay'
|
||||
import { LoginRequiredOverlay } from '../components/auth/LoginRequiredOverlay'
|
||||
import FooterSection from '../components/landing/FooterSection'
|
||||
import TerminalHero from '../components/landing/core/TerminalHero'
|
||||
import LiveFeed from '../components/landing/core/LiveFeed'
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { DeepVoidBackground } from '../components/DeepVoidBackground'
|
||||
import { DeepVoidBackground } from '../components/common/DeepVoidBackground'
|
||||
import { AlertCircle, Home } from 'lucide-react'
|
||||
|
||||
export function PageNotFound() {
|
||||
|
||||
@@ -4,9 +4,9 @@ import { User, Cpu, Building2, MessageCircle, Eye, EyeOff, ChevronRight, Plus, P
|
||||
import { useAuth } from '../contexts/AuthContext'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { api } from '../lib/api'
|
||||
import { ExchangeConfigModal } from '../components/traders/ExchangeConfigModal'
|
||||
import { TelegramConfigModal } from '../components/traders/TelegramConfigModal'
|
||||
import { ModelConfigModal } from '../components/AITradersPage'
|
||||
import { ExchangeConfigModal } from '../components/trader/ExchangeConfigModal'
|
||||
import { TelegramConfigModal } from '../components/trader/TelegramConfigModal'
|
||||
import { ModelConfigModal } from '../components/trader/AITradersPage'
|
||||
import type { Exchange, AIModel } from '../types'
|
||||
|
||||
type Tab = 'account' | 'models' | 'exchanges' | 'telegram'
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user