refactor: standardize code comments

This commit is contained in:
tinkle-community
2025-12-08 01:40:48 +08:00
parent 0636ced476
commit a12c0ae8c9
103 changed files with 5466 additions and 5468 deletions
+11 -11
View File
@@ -69,7 +69,7 @@ func (s *Server) handleBacktestStart(c *gin.Context) {
cfg.PromptTemplate = "default"
}
if _, err := decision.GetPromptTemplate(cfg.PromptTemplate); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("提示词模板不存在: %s", cfg.PromptTemplate)})
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Prompt template does not exist: %s", cfg.PromptTemplate)})
return
}
cfg.CustomPrompt = strings.TrimSpace(cfg.CustomPrompt)
@@ -498,9 +498,9 @@ func writeBacktestAccessError(c *gin.Context, err error) bool {
}
switch {
case errors.Is(err, errBacktestForbidden):
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问该回测任务"})
c.JSON(http.StatusForbidden, gin.H{"error": "No permission to access this backtest task"})
case errors.Is(err, os.ErrNotExist), errors.Is(err, sql.ErrNoRows):
c.JSON(http.StatusNotFound, gin.H{"error": "回测任务不存在"})
c.JSON(http.StatusNotFound, gin.H{"error": "Backtest task does not exist"})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
@@ -512,7 +512,7 @@ func (s *Server) resolveBacktestAIConfig(cfg *backtest.BacktestConfig, userID st
return fmt.Errorf("config is nil")
}
if s.store == nil {
return fmt.Errorf("系统数据库未就绪,无法加载AI模型配置")
return fmt.Errorf("System database not ready, cannot load AI model configuration")
}
cfg.UserID = normalizeUserID(userID)
@@ -525,7 +525,7 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error {
return fmt.Errorf("config is nil")
}
if s.store == nil {
return fmt.Errorf("系统数据库未就绪,无法加载AI模型配置")
return fmt.Errorf("System database not ready, cannot load AI model configuration")
}
cfg.UserID = normalizeUserID(cfg.UserID)
@@ -539,23 +539,23 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error {
if modelID != "" {
model, err = s.store.AIModel().Get(cfg.UserID, modelID)
if err != nil {
return fmt.Errorf("加载AI模型失败: %w", err)
return fmt.Errorf("Failed to load AI model: %w", err)
}
} else {
model, err = s.store.AIModel().GetDefault(cfg.UserID)
if err != nil {
return fmt.Errorf("未找到可用的AI模型: %w", err)
return fmt.Errorf("No available AI model found: %w", err)
}
cfg.AIModelID = model.ID
}
if !model.Enabled {
return fmt.Errorf("AI模型 %s 尚未启用", model.Name)
return fmt.Errorf("AI model %s is not enabled yet", model.Name)
}
apiKey := strings.TrimSpace(model.APIKey)
if apiKey == "" {
return fmt.Errorf("AI模型 %s 缺少API Key,请先在系统中配置", model.Name)
return fmt.Errorf("AI model %s is missing API Key, please configure it in the system first", model.Name)
}
cfg.AICfg.Provider = strings.ToLower(model.Provider)
@@ -569,10 +569,10 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error {
if cfg.AICfg.Provider == "custom" {
if cfg.AICfg.BaseURL == "" {
return fmt.Errorf("自定义AI模型需要配置 API 地址")
return fmt.Errorf("Custom AI model requires API URL configuration")
}
if cfg.AICfg.Model == "" {
return fmt.Errorf("自定义AI模型需要配置模型名称")
return fmt.Errorf("Custom AI model requires model name configuration")
}
}
+14 -14
View File
@@ -8,21 +8,21 @@ import (
"github.com/gin-gonic/gin"
)
// CryptoHandler 加密 API 處理器
// CryptoHandler Encryption API handler
type CryptoHandler struct {
cryptoService *crypto.CryptoService
}
// NewCryptoHandler 創建加密處理器
// NewCryptoHandler Creates encryption handler
func NewCryptoHandler(cryptoService *crypto.CryptoService) *CryptoHandler {
return &CryptoHandler{
cryptoService: cryptoService,
}
}
// ==================== 公鑰端點 ====================
// ==================== Public Key Endpoint ====================
// HandleGetPublicKey 獲取伺服器公鑰
// HandleGetPublicKey Get server public key
func (h *CryptoHandler) HandleGetPublicKey(c *gin.Context) {
publicKey := h.cryptoService.GetPublicKeyPEM()
@@ -32,9 +32,9 @@ func (h *CryptoHandler) HandleGetPublicKey(c *gin.Context) {
})
}
// ==================== 加密數據解密端點 ====================
// ==================== Encrypted Data Decryption Endpoint ====================
// HandleDecryptSensitiveData 解密客戶端傳送的加密数据
// HandleDecryptSensitiveData Decrypt encrypted data sent from client
func (h *CryptoHandler) HandleDecryptSensitiveData(c *gin.Context) {
var payload crypto.EncryptedPayload
if err := c.ShouldBindJSON(&payload); err != nil {
@@ -42,10 +42,10 @@ func (h *CryptoHandler) HandleDecryptSensitiveData(c *gin.Context) {
return
}
// 解密
// Decrypt
decrypted, err := h.cryptoService.DecryptSensitiveData(&payload)
if err != nil {
log.Printf("❌ 解密失敗: %v", err)
log.Printf("❌ Decryption failed: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Decryption failed"})
return
}
@@ -55,18 +55,18 @@ func (h *CryptoHandler) HandleDecryptSensitiveData(c *gin.Context) {
})
}
// ==================== 審計日誌查詢端點 ====================
// ==================== Audit Log Query Endpoint ====================
// 删除审计日志相关功能,在当前简化的实现中不需要
// Audit log functionality removed, not needed in current simplified implementation
// ==================== 工具函數 ====================
// ==================== Utility Functions ====================
// isValidPrivateKey 驗證私鑰格式
// isValidPrivateKey Validate private key format
func isValidPrivateKey(key string) bool {
// EVM 私鑰: 64 位十六進制 (可選 0x 前綴)
// EVM private key: 64 hex characters (optional 0x prefix)
if len(key) == 64 || (len(key) == 66 && key[:2] == "0x") {
return true
}
// TODO: 添加其他鏈的驗證
// TODO: Add validation for other chains
return false
}
+53 -53
View File
@@ -4,7 +4,7 @@ import (
"testing"
)
// MockUser 模擬用戶結構
// MockUser Mock user structure
type MockUser struct {
ID int
Email string
@@ -12,7 +12,7 @@ type MockUser struct {
OTPVerified bool
}
// TestOTPRefetchLogic 測試 OTP 重新獲取邏輯
// TestOTPRefetchLogic Test OTP refetch logic
func TestOTPRefetchLogic(t *testing.T) {
tests := []struct {
name string
@@ -22,14 +22,14 @@ func TestOTPRefetchLogic(t *testing.T) {
expectedMessage string
}{
{
name: "新用戶註冊_郵箱不存在",
name: "New user registration - email does not exist",
existingUser: nil,
userExists: false,
expectedAction: "create_new",
expectedMessage: "創建新用戶",
expectedMessage: "Create new user",
},
{
name: "未完成OTP驗證_允許重新獲取",
name: "Incomplete OTP verification - allow refetch",
existingUser: &MockUser{
ID: 1,
Email: "test@example.com",
@@ -38,10 +38,10 @@ func TestOTPRefetchLogic(t *testing.T) {
},
userExists: true,
expectedAction: "allow_refetch",
expectedMessage: "检测到未完成的注册,请继续完成OTP设置",
expectedMessage: "Incomplete registration detected, please continue OTP setup",
},
{
name: "已完成OTP驗證_拒絕重複註冊",
name: "Completed OTP verification - reject duplicate registration",
existingUser: &MockUser{
ID: 2,
Email: "verified@example.com",
@@ -50,45 +50,45 @@ func TestOTPRefetchLogic(t *testing.T) {
},
userExists: true,
expectedAction: "reject_duplicate",
expectedMessage: "邮箱已被注册",
expectedMessage: "Email already registered",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模擬邏輯處理流程
// Simulate logic processing flow
var actualAction string
var actualMessage string
if !tt.userExists {
// 用戶不存在,創建新用戶
// User does not exist, create new user
actualAction = "create_new"
actualMessage = "創建新用戶"
actualMessage = "Create new user"
} else {
// 用戶已存在,檢查 OTP 驗證狀態
// User exists, check OTP verification status
if !tt.existingUser.OTPVerified {
// 未完成 OTP 驗證,允許重新獲取
// OTP verification incomplete, allow refetch
actualAction = "allow_refetch"
actualMessage = "检测到未完成的注册,请继续完成OTP设置"
actualMessage = "Incomplete registration detected, please continue OTP setup"
} else {
// 已完成驗證,拒絕重複註冊
// Verification completed, reject duplicate registration
actualAction = "reject_duplicate"
actualMessage = "邮箱已被注册"
actualMessage = "Email already registered"
}
}
// 驗證結果
// Verify results
if actualAction != tt.expectedAction {
t.Errorf("Action 不符: got %s, want %s", actualAction, tt.expectedAction)
t.Errorf("Action mismatch: got %s, want %s", actualAction, tt.expectedAction)
}
if actualMessage != tt.expectedMessage {
t.Errorf("Message 不符: got %s, want %s", actualMessage, tt.expectedMessage)
t.Errorf("Message mismatch: got %s, want %s", actualMessage, tt.expectedMessage)
}
})
}
}
// TestOTPVerificationStates 測試 OTP 驗證狀態判斷
// TestOTPVerificationStates Test OTP verification state determination
func TestOTPVerificationStates(t *testing.T) {
tests := []struct {
name string
@@ -96,12 +96,12 @@ func TestOTPVerificationStates(t *testing.T) {
shouldAllowRefetch bool
}{
{
name: "OTP已驗證_不允許重新獲取",
name: "OTP verified - disallow refetch",
otpVerified: true,
shouldAllowRefetch: false,
},
{
name: "OTP未驗證_允許重新獲取",
name: "OTP not verified - allow refetch",
otpVerified: false,
shouldAllowRefetch: true,
},
@@ -109,7 +109,7 @@ func TestOTPVerificationStates(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模擬驗證邏輯
// Simulate verification logic
allowRefetch := !tt.otpVerified
if allowRefetch != tt.shouldAllowRefetch {
@@ -120,72 +120,72 @@ func TestOTPVerificationStates(t *testing.T) {
}
}
// TestRegistrationFlow 測試完整註冊流程的邏輯分支
// TestRegistrationFlow Test complete registration flow logic branches
func TestRegistrationFlow(t *testing.T) {
tests := []struct {
name string
scenario string
userExists bool
otpVerified bool
expectHTTPCode int // 模擬的 HTTP 狀態碼
expectHTTPCode int // Simulated HTTP status code
expectResponse string
}{
{
name: "場景1_新用戶首次註冊",
scenario: "新用戶首次訪問註冊接口",
name: "Scenario 1: New user first registration",
scenario: "New user first accesses registration endpoint",
userExists: false,
otpVerified: false,
expectHTTPCode: 200,
expectResponse: "創建用戶並返回 OTP 設置信息",
expectResponse: "Create user and return OTP setup information",
},
{
name: "場景2_用戶中斷註冊後重新訪問",
scenario: "用戶之前註冊但未完成 OTP 設置,現在重新訪問",
name: "Scenario 2: User re-accesses after interrupting registration",
scenario: "User registered previously but did not complete OTP setup, now re-accessing",
userExists: true,
otpVerified: false,
expectHTTPCode: 200,
expectResponse: "返回現有用戶的 OTP 信息,允許繼續完成",
expectResponse: "Return existing user's OTP information, allow continuation",
},
{
name: "場景3_已註冊用戶嘗試重複註冊",
scenario: "用戶已完成註冊,嘗試用同一郵箱再次註冊",
name: "Scenario 3: Registered user attempts duplicate registration",
scenario: "User already completed registration, attempts to register again with same email",
userExists: true,
otpVerified: true,
expectHTTPCode: 409, // Conflict
expectResponse: "邮箱已被注册",
expectResponse: "Email already registered",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模擬註冊流程邏輯
// Simulate registration flow logic
var actualHTTPCode int
var actualResponse string
if !tt.userExists {
// 新用戶,創建並返回 OTP 信息
// New user, create and return OTP information
actualHTTPCode = 200
actualResponse = "創建用戶並返回 OTP 設置信息"
actualResponse = "Create user and return OTP setup information"
} else {
// 用戶已存在
// User exists
if !tt.otpVerified {
// 未完成 OTP 驗證,允許重新獲取
// OTP verification incomplete, allow refetch
actualHTTPCode = 200
actualResponse = "返回現有用戶的 OTP 信息,允許繼續完成"
actualResponse = "Return existing user's OTP information, allow continuation"
} else {
// 已完成驗證,拒絕重複註冊
// Verification completed, reject duplicate registration
actualHTTPCode = 409
actualResponse = "邮箱已被注册"
actualResponse = "Email already registered"
}
}
// 驗證
// Verify
if actualHTTPCode != tt.expectHTTPCode {
t.Errorf("HTTP code 不符: got %d, want %d (scenario: %s)",
t.Errorf("HTTP code mismatch: got %d, want %d (scenario: %s)",
actualHTTPCode, tt.expectHTTPCode, tt.scenario)
}
if actualResponse != tt.expectResponse {
t.Errorf("Response 不符: got %s, want %s (scenario: %s)",
t.Errorf("Response mismatch: got %s, want %s (scenario: %s)",
actualResponse, tt.expectResponse, tt.scenario)
}
@@ -194,7 +194,7 @@ func TestRegistrationFlow(t *testing.T) {
}
}
// TestEdgeCases 測試邊界情況
// TestEdgeCases Test edge cases
func TestEdgeCases(t *testing.T) {
tests := []struct {
name string
@@ -203,17 +203,17 @@ func TestEdgeCases(t *testing.T) {
description string
}{
{
name: "用戶ID為0_視為新用戶",
name: "User ID is 0 - treated as new user",
user: &MockUser{
ID: 0,
Email: "new@example.com",
OTPVerified: false,
},
expectAllow: true,
description: "ID為0通常表示用戶還未創建",
description: "ID of 0 usually indicates user has not been created yet",
},
{
name: "OTPSecret為空_仍可重新獲取",
name: "OTPSecret is empty - still can refetch",
user: &MockUser{
ID: 1,
Email: "test@example.com",
@@ -221,10 +221,10 @@ func TestEdgeCases(t *testing.T) {
OTPVerified: false,
},
expectAllow: true,
description: "即使 OTPSecret 為空,只要未驗證就允許重新獲取",
description: "Even if OTPSecret is empty, as long as not verified, refetch is allowed",
},
{
name: "OTPSecret存在但已驗證_不允許",
name: "OTPSecret exists but already verified - not allowed",
user: &MockUser{
ID: 2,
Email: "verified@example.com",
@@ -232,13 +232,13 @@ func TestEdgeCases(t *testing.T) {
OTPVerified: true,
},
expectAllow: false,
description: "OTP 已驗證的用戶不能重新獲取",
description: "Users with verified OTP cannot refetch",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 核心邏輯:只要 OTPVerified false,就允許重新獲取
// Core logic: as long as OTPVerified is false, refetch is allowed
allowRefetch := !tt.user.OTPVerified
if allowRefetch != tt.expectAllow {
+487 -487
View File
File diff suppressed because it is too large Load Diff
+26 -26
View File
@@ -7,7 +7,7 @@ import (
"nofx/store"
)
// TestUpdateTraderRequest_SystemPromptTemplate 测试更新交易员时 SystemPromptTemplate 字段是否存在
// TestUpdateTraderRequest_SystemPromptTemplate Test whether SystemPromptTemplate field exists when updating trader
func TestUpdateTraderRequest_SystemPromptTemplate(t *testing.T) {
tests := []struct {
name string
@@ -15,7 +15,7 @@ func TestUpdateTraderRequest_SystemPromptTemplate(t *testing.T) {
expectedPromptTemplate string
}{
{
name: "更新时应该能接收 system_prompt_template=nof1",
name: "Should accept system_prompt_template=nof1 during update",
requestJSON: `{
"name": "Test Trader",
"ai_model_id": "gpt-4",
@@ -33,7 +33,7 @@ func TestUpdateTraderRequest_SystemPromptTemplate(t *testing.T) {
expectedPromptTemplate: "nof1",
},
{
name: "更新时应该能接收 system_prompt_template=default",
name: "Should accept system_prompt_template=default during update",
requestJSON: `{
"name": "Test Trader",
"ai_model_id": "gpt-4",
@@ -51,7 +51,7 @@ func TestUpdateTraderRequest_SystemPromptTemplate(t *testing.T) {
expectedPromptTemplate: "default",
},
{
name: "更新时应该能接收 system_prompt_template=custom",
name: "Should accept system_prompt_template=custom during update",
requestJSON: `{
"name": "Test Trader",
"ai_model_id": "gpt-4",
@@ -72,20 +72,20 @@ func TestUpdateTraderRequest_SystemPromptTemplate(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 测试 UpdateTraderRequest 结构体是否能正确解析 system_prompt_template 字段
// Test whether UpdateTraderRequest struct can correctly parse system_prompt_template field
var req UpdateTraderRequest
err := json.Unmarshal([]byte(tt.requestJSON), &req)
if err != nil {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
// ✅ 验证 SystemPromptTemplate 字段是否被正确读取
// Verify SystemPromptTemplate field is correctly read
if req.SystemPromptTemplate != tt.expectedPromptTemplate {
t.Errorf("Expected SystemPromptTemplate=%q, got %q",
tt.expectedPromptTemplate, req.SystemPromptTemplate)
}
// 验证其他字段也被正确解析
// Verify other fields are also correctly parsed
if req.Name != "Test Trader" {
t.Errorf("Name not parsed correctly")
}
@@ -96,7 +96,7 @@ func TestUpdateTraderRequest_SystemPromptTemplate(t *testing.T) {
}
}
// TestGetTraderConfigResponse_SystemPromptTemplate 测试获取交易员配置时返回值是否包含 system_prompt_template
// TestGetTraderConfigResponse_SystemPromptTemplate Test whether return value contains system_prompt_template when getting trader config
func TestGetTraderConfigResponse_SystemPromptTemplate(t *testing.T) {
tests := []struct {
name string
@@ -104,7 +104,7 @@ func TestGetTraderConfigResponse_SystemPromptTemplate(t *testing.T) {
expectedTemplate string
}{
{
name: "获取配置应该返回 system_prompt_template=nof1",
name: "Get config should return system_prompt_template=nof1",
traderConfig: &store.Trader{
ID: "trader-123",
UserID: "user-1",
@@ -125,7 +125,7 @@ func TestGetTraderConfigResponse_SystemPromptTemplate(t *testing.T) {
expectedTemplate: "nof1",
},
{
name: "获取配置应该返回 system_prompt_template=default",
name: "Get config should return system_prompt_template=default",
traderConfig: &store.Trader{
ID: "trader-456",
UserID: "user-1",
@@ -149,7 +149,7 @@ func TestGetTraderConfigResponse_SystemPromptTemplate(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 handleGetTraderConfig 的返回值构造逻辑(修复后的实现)
// Simulate handleGetTraderConfig return value construction logic (fixed implementation)
result := map[string]interface{}{
"trader_id": tt.traderConfig.ID,
"trader_name": tt.traderConfig.Name,
@@ -167,7 +167,7 @@ func TestGetTraderConfigResponse_SystemPromptTemplate(t *testing.T) {
"is_running": tt.traderConfig.IsRunning,
}
// ✅ 检查响应中是否包含 system_prompt_template
// Check if response contains system_prompt_template
if _, exists := result["system_prompt_template"]; !exists {
t.Errorf("Response is missing 'system_prompt_template' field")
} else {
@@ -178,7 +178,7 @@ func TestGetTraderConfigResponse_SystemPromptTemplate(t *testing.T) {
}
}
// 验证其他字段是否正确
// Verify other fields are correct
if result["trader_id"] != tt.traderConfig.ID {
t.Errorf("trader_id mismatch")
}
@@ -189,7 +189,7 @@ func TestGetTraderConfigResponse_SystemPromptTemplate(t *testing.T) {
}
}
// TestUpdateTraderRequest_CompleteFields 验证 UpdateTraderRequest 结构体定义完整性
// TestUpdateTraderRequest_CompleteFields Verify UpdateTraderRequest struct definition completeness
func TestUpdateTraderRequest_CompleteFields(t *testing.T) {
jsonData := `{
"name": "Test Trader",
@@ -212,7 +212,7 @@ func TestUpdateTraderRequest_CompleteFields(t *testing.T) {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
// 验证基本字段是否正确解析
// Verify basic fields are correctly parsed
if req.Name != "Test Trader" {
t.Errorf("Name mismatch: got %q", req.Name)
}
@@ -220,15 +220,15 @@ func TestUpdateTraderRequest_CompleteFields(t *testing.T) {
t.Errorf("AIModelID mismatch: got %q", req.AIModelID)
}
// ✅ 验证 SystemPromptTemplate 字段已正确添加到结构体
// Verify SystemPromptTemplate field has been correctly added to struct
if req.SystemPromptTemplate != "nof1" {
t.Errorf("SystemPromptTemplate mismatch: expected %q, got %q", "nof1", req.SystemPromptTemplate)
}
}
// TestTraderListResponse_SystemPromptTemplate 测试 handleTraderList API 返回的 trader 对象是否包含 system_prompt_template 字段
// TestTraderListResponse_SystemPromptTemplate Test whether trader object returned by handleTraderList API contains system_prompt_template field
func TestTraderListResponse_SystemPromptTemplate(t *testing.T) {
// 模拟 handleTraderList 中的 trader 对象构造
// Simulate trader object construction in handleTraderList
trader := &store.Trader{
ID: "trader-001",
UserID: "user-1",
@@ -240,7 +240,7 @@ func TestTraderListResponse_SystemPromptTemplate(t *testing.T) {
IsRunning: true,
}
// 构造 API 响应对象(与 api/server.go 中的逻辑一致)
// Construct API response object (consistent with logic in api/server.go)
response := map[string]interface{}{
"trader_id": trader.ID,
"trader_name": trader.Name,
@@ -251,20 +251,20 @@ func TestTraderListResponse_SystemPromptTemplate(t *testing.T) {
"system_prompt_template": trader.SystemPromptTemplate,
}
// ✅ 验证 system_prompt_template 字段存在
// Verify system_prompt_template field exists
if _, exists := response["system_prompt_template"]; !exists {
t.Errorf("Trader list response is missing 'system_prompt_template' field")
}
// ✅ 验证 system_prompt_template 值正确
// Verify system_prompt_template value is correct
if response["system_prompt_template"] != "nof1" {
t.Errorf("Expected system_prompt_template='nof1', got %v", response["system_prompt_template"])
}
}
// TestPublicTraderListResponse_SystemPromptTemplate 测试 handlePublicTraderList API 返回的 trader 对象是否包含 system_prompt_template 字段
// TestPublicTraderListResponse_SystemPromptTemplate Test whether trader object returned by handlePublicTraderList API contains system_prompt_template field
func TestPublicTraderListResponse_SystemPromptTemplate(t *testing.T) {
// 模拟 getConcurrentTraderData 返回的 trader 数据
// Simulate trader data returned by getConcurrentTraderData
traderData := map[string]interface{}{
"trader_id": "trader-002",
"trader_name": "Public Trader",
@@ -279,7 +279,7 @@ func TestPublicTraderListResponse_SystemPromptTemplate(t *testing.T) {
"system_prompt_template": "default",
}
// 构造 API 响应对象(与 api/server.go handlePublicTraderList 中的逻辑一致)
// Construct API response object (consistent with logic in api/server.go handlePublicTraderList)
response := map[string]interface{}{
"trader_id": traderData["trader_id"],
"trader_name": traderData["trader_name"],
@@ -293,12 +293,12 @@ func TestPublicTraderListResponse_SystemPromptTemplate(t *testing.T) {
"system_prompt_template": traderData["system_prompt_template"],
}
// ✅ 验证 system_prompt_template 字段存在
// Verify system_prompt_template field exists
if _, exists := response["system_prompt_template"]; !exists {
t.Errorf("Public trader list response is missing 'system_prompt_template' field")
}
// ✅ 验证 system_prompt_template 值正确
// Verify system_prompt_template value is correct
if response["system_prompt_template"] != "default" {
t.Errorf("Expected system_prompt_template='default', got %v", response["system_prompt_template"])
}
+102 -102
View File
@@ -14,21 +14,21 @@ import (
"github.com/google/uuid"
)
// handleGetStrategies 获取策略列表
// handleGetStrategies Get strategy list
func (s *Server) handleGetStrategies(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
strategies, err := s.store.Strategy().List(userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取策略列表失败: " + err.Error()})
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get strategy list: " + err.Error()})
return
}
// 转换为前端格式
// Convert to frontend format
result := make([]gin.H, 0, len(strategies))
for _, st := range strategies {
var config store.StrategyConfig
@@ -51,19 +51,19 @@ func (s *Server) handleGetStrategies(c *gin.Context) {
})
}
// handleGetStrategy 获取单个策略
// handleGetStrategy Get single strategy
func (s *Server) handleGetStrategy(c *gin.Context) {
userID := c.GetString("user_id")
strategyID := c.Param("id")
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
strategy, err := s.store.Strategy().Get(userID, strategyID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "策略不存在"})
c.JSON(http.StatusNotFound, gin.H{"error": "Strategy not found"})
return
}
@@ -82,11 +82,11 @@ func (s *Server) handleGetStrategy(c *gin.Context) {
})
}
// handleCreateStrategy 创建策略
// handleCreateStrategy Create strategy
func (s *Server) handleCreateStrategy(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
@@ -97,14 +97,14 @@ func (s *Server) handleCreateStrategy(c *gin.Context) {
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request parameters: " + err.Error()})
return
}
// 序列化配置
// Serialize configuration
configJSON, err := json.Marshal(req.Config)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "序列化配置失败"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to serialize configuration"})
return
}
@@ -119,34 +119,34 @@ func (s *Server) handleCreateStrategy(c *gin.Context) {
}
if err := s.store.Strategy().Create(strategy); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建策略失败: " + err.Error()})
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create strategy: " + err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"id": strategy.ID,
"message": "策略创建成功",
"message": "Strategy created successfully",
})
}
// handleUpdateStrategy 更新策略
// handleUpdateStrategy Update strategy
func (s *Server) handleUpdateStrategy(c *gin.Context) {
userID := c.GetString("user_id")
strategyID := c.Param("id")
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
// 检查是否是系统默认策略
// Check if it's a system default strategy
existing, err := s.store.Strategy().Get(userID, strategyID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "策略不存在"})
c.JSON(http.StatusNotFound, gin.H{"error": "Strategy not found"})
return
}
if existing.IsDefault {
c.JSON(http.StatusForbidden, gin.H{"error": "不能修改系统默认策略"})
c.JSON(http.StatusForbidden, gin.H{"error": "Cannot modify system default strategy"})
return
}
@@ -157,14 +157,14 @@ func (s *Server) handleUpdateStrategy(c *gin.Context) {
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request parameters: " + err.Error()})
return
}
// 序列化配置
// Serialize configuration
configJSON, err := json.Marshal(req.Config)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "序列化配置失败"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to serialize configuration"})
return
}
@@ -177,56 +177,56 @@ func (s *Server) handleUpdateStrategy(c *gin.Context) {
}
if err := s.store.Strategy().Update(strategy); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新策略失败: " + err.Error()})
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update strategy: " + err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "策略更新成功"})
c.JSON(http.StatusOK, gin.H{"message": "Strategy updated successfully"})
}
// handleDeleteStrategy 删除策略
// handleDeleteStrategy Delete strategy
func (s *Server) handleDeleteStrategy(c *gin.Context) {
userID := c.GetString("user_id")
strategyID := c.Param("id")
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
if err := s.store.Strategy().Delete(userID, strategyID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除策略失败: " + err.Error()})
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete strategy: " + err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "策略删除成功"})
c.JSON(http.StatusOK, gin.H{"message": "Strategy deleted successfully"})
}
// handleActivateStrategy 激活策略
// handleActivateStrategy Activate strategy
func (s *Server) handleActivateStrategy(c *gin.Context) {
userID := c.GetString("user_id")
strategyID := c.Param("id")
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
if err := s.store.Strategy().SetActive(userID, strategyID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "激活策略失败: " + err.Error()})
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to activate strategy: " + err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "策略激活成功"})
c.JSON(http.StatusOK, gin.H{"message": "Strategy activated successfully"})
}
// handleDuplicateStrategy 复制策略
// handleDuplicateStrategy Duplicate strategy
func (s *Server) handleDuplicateStrategy(c *gin.Context) {
userID := c.GetString("user_id")
sourceID := c.Param("id")
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
@@ -235,34 +235,34 @@ func (s *Server) handleDuplicateStrategy(c *gin.Context) {
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request parameters: " + err.Error()})
return
}
newID := uuid.New().String()
if err := s.store.Strategy().Duplicate(userID, sourceID, newID, req.Name); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "复制策略失败: " + err.Error()})
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to duplicate strategy: " + err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"id": newID,
"message": "策略复制成功",
"message": "Strategy duplicated successfully",
})
}
// handleGetActiveStrategy 获取当前激活的策略
// handleGetActiveStrategy Get currently active strategy
func (s *Server) handleGetActiveStrategy(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
strategy, err := s.store.Strategy().GetActive(userID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "没有激活的策略"})
c.JSON(http.StatusNotFound, gin.H{"error": "No active strategy"})
return
}
@@ -281,9 +281,9 @@ func (s *Server) handleGetActiveStrategy(c *gin.Context) {
})
}
// handleGetDefaultStrategyConfig 获取默认策略配置模板
// handleGetDefaultStrategyConfig Get default strategy configuration template
func (s *Server) handleGetDefaultStrategyConfig(c *gin.Context) {
// 返回默认配置结构,供前端创建新策略时使用
// Return default configuration structure for frontend to use when creating new strategies
defaultConfig := store.StrategyConfig{
CoinSource: store.CoinSourceConfig{
SourceType: "coinpool",
@@ -324,42 +324,42 @@ func (s *Server) handleGetDefaultStrategyConfig(c *gin.Context) {
MinConfidence: 75,
},
PromptSections: store.PromptSectionsConfig{
RoleDefinition: `# 你是专业的加密货币交易AI
RoleDefinition: `# You are a professional cryptocurrency trading AI
你专注于技术分析和风险管理,基于市场数据做出理性的交易决策。
你的目标是在控制风险的前提下,捕捉高概率的交易机会。`,
TradingFrequency: `# ⏱️ 交易频率认知
You focus on technical analysis and risk management, making rational trading decisions based on market data.
Your goal is to capture high-probability trading opportunities while controlling risk.`,
TradingFrequency: `# ⏱️ Trading Frequency Awareness
- 优秀交易员:每天2-4笔 ≈ 每小时0.1-0.2笔
- 每小时>2笔 = 过度交易
- 单笔持仓时间≥30-60分钟
如果你发现自己每个周期都在交易 → 标准过低;若持仓<30分钟就平仓 → 过于急躁。`,
EntryStandards: `# 🎯 开仓标准(严格)
- Excellent traders: 2-4 trades per day ≈ 0.1-0.2 trades per hour
- >2 trades per hour = overtrading
- Single position holding time ≥30-60 minutes
If you find yourself trading every cycle → standards too low; if closing positions <30 minutes → too impatient.`,
EntryStandards: `# 🎯 Entry Standards (Strict)
只在多重信号共振时开仓:
- 趋势方向明确(EMA排列、价格位置)
- 动量确认(MACDRSI协同)
- 波动率适中(ATR合理范围)
- 量价配合(成交量支持方向)
Only enter when multiple signals align:
- Clear trend direction (EMA alignment, price position)
- Momentum confirmation (MACD, RSI cooperation)
- Moderate volatility (ATR reasonable range)
- Volume-price coordination (volume supports direction)
避免:单一指标、信号矛盾、横盘震荡、刚平仓即重启。`,
DecisionProcess: `# 📋 决策流程
Avoid: single indicator, conflicting signals, sideways consolidation, reopening immediately after closing.`,
DecisionProcess: `# 📋 Decision Process
1. 检查持仓 → 是否该止盈/止损
2. 扫描候选币 + 多时间框 → 是否存在强信号
3. 评估风险回报比 → 是否满足最小要求
4. 先写思维链,再输出结构化JSON`,
1. Check positions → Should take profit/stop loss
2. Scan candidate coins + multiple timeframes → Are there strong signals
3. Evaluate risk-reward ratio → Does it meet minimum requirements
4. Write chain of thought first, then output structured JSON`,
},
}
c.JSON(http.StatusOK, defaultConfig)
}
// handlePreviewPrompt 预览策略生成的 Prompt
// handlePreviewPrompt Preview prompt generated by strategy
func (s *Server) handlePreviewPrompt(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
@@ -370,28 +370,28 @@ func (s *Server) handlePreviewPrompt(c *gin.Context) {
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request parameters: " + err.Error()})
return
}
// 使用默认值
// Use default values
if req.AccountEquity <= 0 {
req.AccountEquity = 1000.0 // 默认模拟账户净值
req.AccountEquity = 1000.0 // Default simulated account equity
}
if req.PromptVariant == "" {
req.PromptVariant = "balanced"
}
// 创建策略引擎来构建 prompt
// Create strategy engine to build prompt
engine := decision.NewStrategyEngine(&req.Config)
// 构建系统 prompt(使用策略引擎内置的方法)
// Build system prompt (using built-in method from strategy engine)
systemPrompt := engine.BuildSystemPrompt(
req.AccountEquity,
req.PromptVariant,
)
// 获取可用的 prompt 模板列表
// Get list of available prompt templates
templateNames := decision.GetAllPromptTemplateNames()
c.JSON(http.StatusOK, gin.H{
@@ -408,11 +408,11 @@ func (s *Server) handlePreviewPrompt(c *gin.Context) {
})
}
// handleStrategyTestRun AI 测试运行(不执行交易,只返回 AI 分析结果)
// handleStrategyTestRun AI test run (does not execute trades, only returns AI analysis results)
func (s *Server) handleStrategyTestRun(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
@@ -424,7 +424,7 @@ func (s *Server) handleStrategyTestRun(c *gin.Context) {
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request parameters: " + err.Error()})
return
}
@@ -432,27 +432,27 @@ func (s *Server) handleStrategyTestRun(c *gin.Context) {
req.PromptVariant = "balanced"
}
// 创建策略引擎来构建 prompt
// Create strategy engine to build prompt
engine := decision.NewStrategyEngine(&req.Config)
// 获取候选币种
// Get candidate coins
candidates, err := engine.GetCandidateCoins()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "获取候选币种失败: " + err.Error(),
"error": "Failed to get candidate coins: " + err.Error(),
"ai_response": "",
})
return
}
// 获取时间周期配置
// Get timeframe configuration
timeframes := req.Config.Indicators.Klines.SelectedTimeframes
primaryTimeframe := req.Config.Indicators.Klines.PrimaryTimeframe
klineCount := req.Config.Indicators.Klines.PrimaryCount
// 如果没有选择时间周期,使用默认值
// If no timeframes selected, use default values
if len(timeframes) == 0 {
// 兼容旧配置:使用主周期和长周期
// Backward compatibility: use primary and longer timeframes
if primaryTimeframe != "" {
timeframes = append(timeframes, primaryTimeframe)
} else {
@@ -469,21 +469,21 @@ func (s *Server) handleStrategyTestRun(c *gin.Context) {
klineCount = 30
}
fmt.Printf("📊 使用时间周期: %v, 主周期: %s, K线数量: %d\n", timeframes, primaryTimeframe, klineCount)
fmt.Printf("📊 Using timeframes: %v, primary: %s, kline count: %d\n", timeframes, primaryTimeframe, klineCount)
// 获取真实市场数据(使用多时间周期)
// Get real market data (using multiple timeframes)
marketDataMap := make(map[string]*market.Data)
for _, coin := range candidates {
data, err := market.GetWithTimeframes(coin.Symbol, timeframes, primaryTimeframe, klineCount)
if err != nil {
// 如果获取某个币种数据失败,记录日志但继续
fmt.Printf("⚠️ 获取 %s 市场数据失败: %v\n", coin.Symbol, err)
// If getting data for a coin fails, log but continue
fmt.Printf("⚠️ Failed to get market data for %s: %v\n", coin.Symbol, err)
continue
}
marketDataMap[coin.Symbol] = data
}
// 构建真实的上下文(用于生成 User Prompt
// Build real context (for generating User Prompt)
testContext := &decision.Context{
CurrentTime: time.Now().Format("2006-01-02 15:04:05"),
RuntimeMinutes: 0,
@@ -504,13 +504,13 @@ func (s *Server) handleStrategyTestRun(c *gin.Context) {
MarketDataMap: marketDataMap,
}
// 构建 System Prompt
// Build System Prompt
systemPrompt := engine.BuildSystemPrompt(1000.0, req.PromptVariant)
// 构建 User Prompt(使用真实市场数据)
// Build User Prompt (using real market data)
userPrompt := engine.BuildUserPrompt(testContext)
// 如果请求真实 AI 调用
// If requesting real AI call
if req.RunRealAI && req.AIModelID != "" {
aiResponse, aiErr := s.runRealAITest(userID, req.AIModelID, systemPrompt, userPrompt)
if aiErr != nil {
@@ -520,9 +520,9 @@ func (s *Server) handleStrategyTestRun(c *gin.Context) {
"candidate_count": len(candidates),
"candidates": candidates,
"prompt_variant": req.PromptVariant,
"ai_response": fmt.Sprintf("❌ AI 调用失败: %s", aiErr.Error()),
"ai_response": fmt.Sprintf("❌ AI call failed: %s", aiErr.Error()),
"ai_error": aiErr.Error(),
"note": "AI 调用出错",
"note": "AI call error",
})
return
}
@@ -534,40 +534,40 @@ func (s *Server) handleStrategyTestRun(c *gin.Context) {
"candidates": candidates,
"prompt_variant": req.PromptVariant,
"ai_response": aiResponse,
"note": "✅ 真实 AI 测试运行成功",
"note": "✅ Real AI test run successful",
})
return
}
// 返回结果(不实际调用 AI,只返回构建的 prompt
// Return result (without actually calling AI, only return built prompt)
c.JSON(http.StatusOK, gin.H{
"system_prompt": systemPrompt,
"user_prompt": userPrompt,
"candidate_count": len(candidates),
"candidates": candidates,
"prompt_variant": req.PromptVariant,
"ai_response": "请选择 AI 模型并点击「运行测试」来执行真实的 AI 分析。",
"note": "未选择 AI 模型或未启用真实 AI 调用",
"ai_response": "Please select an AI model and click 'Run Test' to perform real AI analysis.",
"note": "AI model not selected or real AI call not enabled",
})
}
// runRealAITest 执行真实的 AI 测试调用
// runRealAITest Execute real AI test call
func (s *Server) runRealAITest(userID, modelID, systemPrompt, userPrompt string) (string, error) {
// 获取 AI 模型配置
// Get AI model configuration
model, err := s.store.AIModel().Get(userID, modelID)
if err != nil {
return "", fmt.Errorf("获取 AI 模型失败: %w", err)
return "", fmt.Errorf("failed to get AI model: %w", err)
}
if !model.Enabled {
return "", fmt.Errorf("AI 模型 %s 尚未启用", model.Name)
return "", fmt.Errorf("AI model %s is not enabled", model.Name)
}
if model.APIKey == "" {
return "", fmt.Errorf("AI 模型 %s 缺少 API Key", model.Name)
return "", fmt.Errorf("AI model %s is missing API Key", model.Name)
}
// 创建 AI 客户端
// Create AI client
var aiClient mcp.AIClient
provider := model.Provider
@@ -579,15 +579,15 @@ func (s *Server) runRealAITest(userID, modelID, systemPrompt, userPrompt string)
aiClient = mcp.NewDeepSeekClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
default:
// 使用通用客户端
// Use generic client
aiClient = mcp.NewClient()
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
}
// 调用 AI API
// Call AI API
response, err := aiClient.CallWithMessages(systemPrompt, userPrompt)
if err != nil {
return "", fmt.Errorf("AI API 调用失败: %w", err)
return "", fmt.Errorf("AI API call failed: %w", err)
}
return response, nil
+22 -22
View File
@@ -8,67 +8,67 @@ import (
"github.com/google/uuid"
)
// TestTraderIDUniqueness 测试 traderID 的唯一性(修复 Issue #893
// 验证即使在相同的 exchange AI model 下,也能生成唯一的 traderID
// TestTraderIDUniqueness Test traderID uniqueness (fixes Issue #893)
// Verify that unique traderIDs can be generated even with the same exchange and AI model
func TestTraderIDUniqueness(t *testing.T) {
exchangeID := "binance"
aiModelID := "gpt-4"
// 模拟同时创建 100 trader(相同参数)
// Simulate creating 100 traders simultaneously (with same parameters)
traderIDs := make(map[string]bool)
const numTraders = 100
for i := 0; i < numTraders; i++ {
// 模拟 api/server.go:497 的 traderID 生成逻辑
// Simulate traderID generation logic from api/server.go:497
traderID := generateTraderID(exchangeID, aiModelID)
// ✅ 检查是否重复
// Check for duplicates
if traderIDs[traderID] {
t.Errorf("Duplicate traderID detected: %s", traderID)
}
traderIDs[traderID] = true
// ✅ 验证格式:应该是 "exchange_model_uuid"
// Verify format: should be "exchange_model_uuid"
if !isValidTraderIDFormat(traderID, exchangeID, aiModelID) {
t.Errorf("Invalid traderID format: %s", traderID)
}
}
// ✅ 验证生成了预期数量的唯一 ID
// Verify expected number of unique IDs were generated
if len(traderIDs) != numTraders {
t.Errorf("Expected %d unique traderIDs, got %d", numTraders, len(traderIDs))
}
}
// generateTraderID 辅助函数,模拟 api/server.go 中的 traderID 生成逻辑
// generateTraderID Helper function that simulates traderID generation logic from api/server.go
func generateTraderID(exchangeID, aiModelID string) string {
return fmt.Sprintf("%s_%s_%s", exchangeID, aiModelID, uuid.New().String())
}
// isValidTraderIDFormat 验证 traderID 格式是否符合预期
// isValidTraderIDFormat Verify traderID format matches expected format
func isValidTraderIDFormat(traderID, expectedExchange, expectedModel string) bool {
// 格式:exchange_model_uuid
// 例如:binance_gpt-4_a1b2c3d4-e5f6-7890-abcd-ef1234567890
// Format: exchange_model_uuid
// Example: binance_gpt-4_a1b2c3d4-e5f6-7890-abcd-ef1234567890
parts := strings.Split(traderID, "_")
if len(parts) < 3 {
return false
}
// 验证前缀
// Verify prefix
if parts[0] != expectedExchange {
return false
}
// AI model 可能包含连字符(如 gpt-4),所以需要重组
// 最后一部分应该是 UUID
// AI model may contain hyphens (e.g. gpt-4), so need to reconstruct
// Last part should be UUID
uuidPart := parts[len(parts)-1]
// 验证 UUID 格式(36 个字符,包含 4 个连字符)
// Verify UUID format (36 characters, containing 4 hyphens)
_, err := uuid.Parse(uuidPart)
return err == nil
}
// TestTraderIDFormat 测试 traderID 格式的正确性
// TestTraderIDFormat Test traderID format correctness
func TestTraderIDFormat(t *testing.T) {
tests := []struct {
name string
@@ -84,18 +84,18 @@ func TestTraderIDFormat(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
traderID := generateTraderID(tt.exchangeID, tt.aiModelID)
// ✅ 验证包含正确的前缀
// Verify correct prefix
if !strings.HasPrefix(traderID, tt.exchangeID+"_"+tt.aiModelID+"_") {
t.Errorf("traderID does not have correct prefix. Got: %s", traderID)
}
// ✅ 验证格式有效
// Verify format is valid
if !isValidTraderIDFormat(traderID, tt.exchangeID, tt.aiModelID) {
t.Errorf("Invalid traderID format: %s", traderID)
}
// ✅ 验证长度合理(至少应该有 exchange + model + "_" + UUID(36) 的长度)
minLength := len(tt.exchangeID) + len(tt.aiModelID) + 2 + 36 // 2个下划线 + 36字符UUID
// Verify reasonable length (should be at least exchange + model + "_" + UUID(36))
minLength := len(tt.exchangeID) + len(tt.aiModelID) + 2 + 36 // 2 underscores + 36 character UUID
if len(traderID) < minLength {
t.Errorf("traderID too short: expected at least %d chars, got %d", minLength, len(traderID))
}
@@ -103,12 +103,12 @@ func TestTraderIDFormat(t *testing.T) {
}
}
// TestTraderIDNoCollision 测试在高并发场景下不会产生碰撞
// TestTraderIDNoCollision Test that no collisions occur in high concurrency scenarios
func TestTraderIDNoCollision(t *testing.T) {
const iterations = 1000
uniqueIDs := make(map[string]bool, iterations)
// 模拟高并发场景
// Simulate high concurrency scenario
for i := 0; i < iterations; i++ {
id := generateTraderID("binance", "gpt-4")
if uniqueIDs[id] {
+9 -9
View File
@@ -2,20 +2,20 @@ package api
import "strings"
// MaskSensitiveString 脱敏敏感字符串,只显示前4位和后4位
// 用于脱敏 API KeySecret KeyPrivate Key 等敏感信息
// MaskSensitiveString Mask sensitive strings, showing only first 4 and last 4 characters
// Used to mask API Key, Secret Key, Private Key and other sensitive information
func MaskSensitiveString(s string) string {
if s == "" {
return ""
}
length := len(s)
if length <= 8 {
return "****" // 字符串太短,全部隐藏
return "****" // String too short, hide everything
}
return s[:4] + "****" + s[length-4:]
}
// SanitizeModelConfigForLog 脱敏模型配置用于日志输出
// SanitizeModelConfigForLog Sanitize model configuration for log output
func SanitizeModelConfigForLog(models map[string]struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
@@ -34,7 +34,7 @@ func SanitizeModelConfigForLog(models map[string]struct {
return safe
}
// SanitizeExchangeConfigForLog 脱敏交易所配置用于日志输出
// SanitizeExchangeConfigForLog Sanitize exchange configuration for log output
func SanitizeExchangeConfigForLog(exchanges map[string]struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
@@ -54,7 +54,7 @@ func SanitizeExchangeConfigForLog(exchanges map[string]struct {
"testnet": cfg.Testnet,
}
// 只在有值时才添加脱敏后的敏感字段
// Only add masked sensitive fields when they have values
if cfg.APIKey != "" {
safeExchange["api_key"] = MaskSensitiveString(cfg.APIKey)
}
@@ -68,7 +68,7 @@ func SanitizeExchangeConfigForLog(exchanges map[string]struct {
safeExchange["lighter_private_key"] = MaskSensitiveString(cfg.LighterPrivateKey)
}
// 非敏感字段直接添加
// Add non-sensitive fields directly
if cfg.HyperliquidWalletAddr != "" {
safeExchange["hyperliquid_wallet_addr"] = cfg.HyperliquidWalletAddr
}
@@ -87,14 +87,14 @@ func SanitizeExchangeConfigForLog(exchanges map[string]struct {
return safe
}
// MaskEmail 脱敏邮箱地址,保留前2位和@后部分
// MaskEmail Mask email address, keeping first 2 characters and domain part
func MaskEmail(email string) string {
if email == "" {
return ""
}
parts := strings.Split(email, "@")
if len(parts) != 2 {
return "****" // 格式不正确
return "****" // Incorrect format
}
username := parts[0]
domain := parts[1]
+12 -12
View File
@@ -11,27 +11,27 @@ func TestMaskSensitiveString(t *testing.T) {
expected string
}{
{
name: "空字符串",
name: "Empty string",
input: "",
expected: "",
},
{
name: "短字符串(小于等于8位)",
name: "Short string (8 characters or less)",
input: "short",
expected: "****",
},
{
name: "正常API key",
name: "Normal API key",
input: "sk-1234567890abcdefghijklmnopqrstuvwxyz",
expected: "sk-1****wxyz",
},
{
name: "正常私钥",
name: "Normal private key",
input: "0x1234567890abcdef1234567890abcdef12345678",
expected: "0x12****5678",
},
{
name: "刚好9位",
name: "Exactly 9 characters",
input: "123456789",
expected: "1234****6789",
},
@@ -119,7 +119,7 @@ func TestSanitizeExchangeConfigForLog(t *testing.T) {
result := SanitizeExchangeConfigForLog(exchanges)
// 检查币安配置
// Check Binance configuration
binanceConfig, ok := result["binance"].(map[string]interface{})
if !ok {
t.Fatal("binance config not found or wrong type")
@@ -143,7 +143,7 @@ func TestSanitizeExchangeConfigForLog(t *testing.T) {
t.Errorf("expected masked secret_key='bina****cdef', got %q", maskedSecretKey)
}
// 检查 Hyperliquid 配置
// Check Hyperliquid configuration
hlConfig, ok := result["hyperliquid"].(map[string]interface{})
if !ok {
t.Fatal("hyperliquid config not found or wrong type")
@@ -154,7 +154,7 @@ func TestSanitizeExchangeConfigForLog(t *testing.T) {
t.Fatal("hyperliquid_wallet_addr not found or wrong type")
}
// 钱包地址不应该被脱敏
// Wallet address should not be masked
if walletAddr != "0x1234567890abcdef1234567890abcdef12345678" {
t.Errorf("wallet address should not be masked, got %q", walletAddr)
}
@@ -167,22 +167,22 @@ func TestMaskEmail(t *testing.T) {
expected string
}{
{
name: "空邮箱",
name: "Empty email",
input: "",
expected: "",
},
{
name: "格式错误",
name: "Invalid format",
input: "notanemail",
expected: "****",
},
{
name: "正常邮箱",
name: "Normal email",
input: "user@example.com",
expected: "us****@example.com",
},
{
name: "短用户名",
name: "Short username",
input: "a@example.com",
expected: "**@example.com",
},
+19 -19
View File
@@ -13,33 +13,33 @@ import (
"golang.org/x/crypto/bcrypt"
)
// JWTSecret JWT密钥,将从配置中动态设置
// JWTSecret is the JWT secret key, will be dynamically set from config
var JWTSecret []byte
// tokenBlacklist 用于登出后的token黑名单(仅内存,按过期时间清理)
// tokenBlacklist for logged out tokens (memory only, cleaned by expiration time)
var tokenBlacklist = struct {
sync.RWMutex
items map[string]time.Time
}{items: make(map[string]time.Time)}
// maxBlacklistEntries 黑名单最大容量阈值
// maxBlacklistEntries is the maximum capacity threshold for blacklist
const maxBlacklistEntries = 100_000
// OTPIssuer OTP发行者名称
// OTPIssuer is the OTP issuer name
const OTPIssuer = "nofxAI"
// SetJWTSecret 设置JWT密钥
// SetJWTSecret sets the JWT secret key
func SetJWTSecret(secret string) {
JWTSecret = []byte(secret)
}
// BlacklistToken 将token加入黑名单直到过期
// BlacklistToken adds token to blacklist until expiration
func BlacklistToken(token string, exp time.Time) {
tokenBlacklist.Lock()
defer tokenBlacklist.Unlock()
tokenBlacklist.items[token] = exp
// 如果超过容量阈值,则进行一次过期清理;若仍超限,记录警告日志
// If exceeds capacity threshold, perform expired cleanup; if still over limit, log warning
if len(tokenBlacklist.items) > maxBlacklistEntries {
now := time.Now()
for t, e := range tokenBlacklist.items {
@@ -54,7 +54,7 @@ func BlacklistToken(token string, exp time.Time) {
}
}
// IsTokenBlacklisted 检查token是否在黑名单中(过期自动清理)
// IsTokenBlacklisted checks if token is in blacklist (auto cleanup on expiration)
func IsTokenBlacklisted(token string) bool {
tokenBlacklist.Lock()
defer tokenBlacklist.Unlock()
@@ -68,26 +68,26 @@ func IsTokenBlacklisted(token string) bool {
return false
}
// Claims JWT声明
// Claims represents JWT claims
type Claims struct {
UserID string `json:"user_id"`
Email string `json:"email"`
jwt.RegisteredClaims
}
// HashPassword 哈希密码
// HashPassword hashes the password
func HashPassword(password string) (string, error) {
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return string(bytes), err
}
// CheckPassword 验证密码
// CheckPassword verifies the password
func CheckPassword(password, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
}
// GenerateOTPSecret 生成OTP密钥
// GenerateOTPSecret generates OTP secret
func GenerateOTPSecret() (string, error) {
secret := make([]byte, 20)
_, err := rand.Read(secret)
@@ -106,18 +106,18 @@ func GenerateOTPSecret() (string, error) {
return key.Secret(), nil
}
// VerifyOTP 验证OTP码
// VerifyOTP verifies OTP code
func VerifyOTP(secret, code string) bool {
return totp.Validate(code, secret)
}
// GenerateJWT 生成JWT token
// GenerateJWT generates JWT token
func GenerateJWT(userID, email string) (string, error) {
claims := Claims{
UserID: userID,
Email: email,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), // 24小时过期
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), // Expires in 24 hours
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Issuer: "nofxAI",
@@ -128,11 +128,11 @@ func GenerateJWT(userID, email string) (string, error) {
return token.SignedString(JWTSecret)
}
// ValidateJWT 验证JWT token
// ValidateJWT validates JWT token
func ValidateJWT(tokenString string) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("意外的签名方法: %v", token.Header["alg"])
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return JWTSecret, nil
})
@@ -145,10 +145,10 @@ func ValidateJWT(tokenString string) (*Claims, error) {
return claims, nil
}
return nil, fmt.Errorf("无效的token")
return nil, fmt.Errorf("invalid token")
}
// GetOTPQRCodeURL 获取OTP二维码URL
// GetOTPQRCodeURL gets OTP QR code URL
func GetOTPQRCodeURL(secret, email string) string {
return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s", OTPIssuer, email, secret, OTPIssuer)
}
+2 -2
View File
@@ -89,7 +89,7 @@ func (acc *BacktestAccount) Open(symbol, side string, quantity float64, leverage
pos.LiquidationPrice = computeLiquidation(execPrice, leverage, side)
} else {
if leverage != pos.Leverage {
// 采用权重平均杠杆(近似)
// Use weighted average leverage (approximate)
weightedMargin := pos.Margin + margin
pos.Leverage = int(math.Round((pos.Notional + notional) / weightedMargin))
}
@@ -227,7 +227,7 @@ func (acc *BacktestAccount) RealizedPnL() float64 {
return acc.realizedPnL
}
// RestoreFromSnapshots 用于从检查点恢复账户状态。
// RestoreFromSnapshots restores account state from checkpoint.
func (acc *BacktestAccount) RestoreFromSnapshots(cash float64, realized float64, snaps []PositionSnapshot) {
acc.cash = cash
acc.realizedPnL = realized
+5 -5
View File
@@ -7,8 +7,8 @@ import (
"nofx/mcp"
)
// configureMCPClient 根据配置创建/克隆 MCP 客户端(返回 mcp.AIClient 接口)。
// 说明:mcp.New() 返回接口类型,这里统一转为具体实现再做拷贝,避免并发共享状态。
// configureMCPClient creates/clones an MCP client based on configuration (returns mcp.AIClient interface).
// Note: mcp.New() returns an interface type; here we convert to concrete implementation before copying to avoid concurrent shared state.
func configureMCPClient(cfg BacktestConfig, base mcp.AIClient) (mcp.AIClient, error) {
provider := strings.ToLower(strings.TrimSpace(cfg.AICfg.Provider))
@@ -48,9 +48,9 @@ func configureMCPClient(cfg BacktestConfig, base mcp.AIClient) (mcp.AIClient, er
}
}
// cloneBaseClient 复制基础客户端以避免共享可变状态。
// cloneBaseClient copies the base client to avoid shared mutable state.
func cloneBaseClient(base mcp.AIClient) *mcp.Client {
// 优先尝试复用传入的基础客户端(深拷贝)
// Prefer to reuse the passed-in base client (deep copy)
switch c := base.(type) {
case *mcp.Client:
cp := *c
@@ -66,6 +66,6 @@ func cloneBaseClient(base mcp.AIClient) *mcp.Client {
return &cp
}
}
// 回退到新的默认客户端
// Fall back to a new default client
return mcp.NewClient().(*mcp.Client)
}
+1 -1
View File
@@ -20,7 +20,7 @@ type cachedDecision struct {
Decision *decision.FullDecision `json:"decision"`
}
// AICache 持久化 AI 决策,便于重复回测或重放。
// AICache persists AI decisions for repeated backtesting or replay.
type AICache struct {
mu sync.RWMutex
path string
+7 -7
View File
@@ -8,7 +8,7 @@ import (
"nofx/market"
)
// AIConfig 定义回测中使用的 AI 客户端配置。
// AIConfig defines the AI client configuration used in backtesting.
type AIConfig struct {
Provider string `json:"provider"`
Model string `json:"model"`
@@ -23,7 +23,7 @@ type LeverageConfig struct {
AltcoinLeverage int `json:"altcoin_leverage"`
}
// BacktestConfig 描述一次回测运行的输入配置。
// BacktestConfig describes the input configuration for a backtest run.
type BacktestConfig struct {
RunID string `json:"run_id"`
UserID string `json:"user_id,omitempty"`
@@ -54,7 +54,7 @@ type BacktestConfig struct {
ReplayDecisionDir string `json:"replay_decision_dir,omitempty"`
}
// Validate 对配置进行合法性检查并填充默认值。
// Validate performs validity checks on the configuration and fills in default values.
func (cfg *BacktestConfig) Validate() error {
if cfg == nil {
return fmt.Errorf("config is nil")
@@ -151,7 +151,7 @@ func (cfg *BacktestConfig) Validate() error {
return nil
}
// Duration 返回回测区间时长。
// Duration returns the backtest interval duration.
func (cfg *BacktestConfig) Duration() time.Duration {
if cfg == nil {
return 0
@@ -160,11 +160,11 @@ func (cfg *BacktestConfig) Duration() time.Duration {
}
const (
// FillPolicyNextOpen 使用下一根 K 线的开盘价成交。
// FillPolicyNextOpen uses the open price of the next bar for execution.
FillPolicyNextOpen = "next_open"
// FillPolicyBarVWAP 采用当前 K 线的近似 VWAP 成交。
// FillPolicyBarVWAP uses the approximate VWAP of the current bar for execution.
FillPolicyBarVWAP = "bar_vwap"
// FillPolicyMidPrice 采用 (high+low)/2 的中间价成交。
// FillPolicyMidPrice uses the mid-price (high+low)/2 for execution.
FillPolicyMidPrice = "mid"
)
+4 -4
View File
@@ -17,7 +17,7 @@ type symbolSeries struct {
byTF map[string]*timeframeSeries
}
// DataFeed 管理历史K线数据,为回测提供按时间推进的快照。
// DataFeed manages historical kline data and provides time-progressive snapshots for backtesting.
type DataFeed struct {
cfg BacktestConfig
symbols []string
@@ -49,7 +49,7 @@ func (df *DataFeed) loadAll() error {
start := time.Unix(df.cfg.StartTS, 0)
end := time.Unix(df.cfg.EndTS, 0)
// longest timeframe用于辅助指标
// longest timeframe used for auxiliary indicators
var longestDur time.Duration
for _, tf := range df.timeframes {
dur, err := market.TFDuration(tf)
@@ -93,7 +93,7 @@ func (df *DataFeed) loadAll() error {
df.symbolSeries[symbol] = ss
}
// 以第一个符号的主周期生成回测进度时间轴
// Generate backtest progress timeline using the primary timeframe of the first symbol
firstSymbol := df.symbols[0]
primarySeries := df.symbolSeries[firstSymbol].byTF[df.primaryTF]
startMs := start.UnixMilli()
@@ -106,7 +106,7 @@ func (df *DataFeed) loadAll() error {
break
}
df.decisionTimes = append(df.decisionTimes, ts)
// 对齐其他符号,如果缺数据则提前报错
// Align other symbols; report error early if data is missing
for _, symbol := range df.symbols[1:] {
if _, ok := df.symbolSeries[symbol].byTF[df.primaryTF]; !ok {
return fmt.Errorf("symbol %s missing timeframe %s", symbol, df.primaryTF)
+4 -4
View File
@@ -7,7 +7,7 @@ import (
"nofx/market"
)
// ResampleEquity 根据时间周期重采样资金曲线。
// ResampleEquity resamples equity curve based on timeframe.
func ResampleEquity(points []EquityPoint, timeframe string) ([]EquityPoint, error) {
if timeframe == "" {
return points, nil
@@ -49,7 +49,7 @@ func ResampleEquity(points []EquityPoint, timeframe string) ([]EquityPoint, erro
return resampled, nil
}
// LimitEquityPoints 将数据点数量限制在给定范围内(均匀抽样)。
// LimitEquityPoints limits the number of data points within a given range (uniform sampling).
func LimitEquityPoints(points []EquityPoint, limit int) []EquityPoint {
if limit <= 0 || len(points) <= limit {
return points
@@ -68,7 +68,7 @@ func LimitEquityPoints(points []EquityPoint, limit int) []EquityPoint {
return result
}
// LimitTradeEvents 同样对交易事件按均匀抽样。
// LimitTradeEvents applies uniform sampling to trade events.
func LimitTradeEvents(events []TradeEvent, limit int) []TradeEvent {
if limit <= 0 || len(events) <= limit {
return events
@@ -86,7 +86,7 @@ func LimitTradeEvents(events []TradeEvent, limit int) []TradeEvent {
return result
}
// AlignEquityTimestamps 确保时间戳按升序排列。
// AlignEquityTimestamps ensures timestamps are sorted in ascending order.
func AlignEquityTimestamps(points []EquityPoint) []EquityPoint {
sort.Slice(points, func(i, j int) bool {
return points[i].Timestamp < points[j].Timestamp
+1 -1
View File
@@ -15,7 +15,7 @@ const (
lockStaleAfter = 10 * time.Second
)
// RunLockInfo 表示回测运行的锁文件结构。
// RunLockInfo represents the lock file structure for a backtest run.
type RunLockInfo struct {
RunID string `json:"run_id"`
PID int `json:"pid"`
+3 -3
View File
@@ -438,7 +438,7 @@ func (m *Manager) resolveAIConfig(cfg *BacktestConfig) error {
m.mu.RUnlock()
if resolver == nil {
if apiKey == "" {
return fmt.Errorf("AI配置缺少密钥且未配置解析器")
return fmt.Errorf("AI configuration missing key and no resolver configured")
}
return nil
}
@@ -453,7 +453,7 @@ func (m *Manager) ExportRun(runID string) (string, error) {
return CreateRunExport(runID)
}
// RestoreRunsFromDisk 扫描 backtests 目录并恢复现有 run 的元数据(服务重启场景)。
// RestoreRuns scans the backtests directory and restores metadata for existing runs (service restart scenario).
func (m *Manager) RestoreRuns() error {
runIDs, err := LoadRunIDs()
if err != nil {
@@ -487,7 +487,7 @@ func (m *Manager) RestoreRuns() error {
return nil
}
// RestoreRunsFromDisk 保留旧方法名,兼容历史调用。
// RestoreRunsFromDisk retains the old method name for backward compatibility.
func (m *Manager) RestoreRunsFromDisk() error {
return m.RestoreRuns()
}
+1 -1
View File
@@ -6,7 +6,7 @@ import (
"strings"
)
// CalculateMetrics 读取已有日志并计算汇总指标。state 可选,用于补充尚未落盘的信息。
// CalculateMetrics reads existing logs and calculates summary metrics. state is optional, used to supplement information not yet persisted.
func CalculateMetrics(runID string, cfg *BacktestConfig, state *BacktestState) (*Metrics, error) {
if cfg == nil {
return nil, fmt.Errorf("config is nil")
+13 -13
View File
@@ -29,7 +29,7 @@ const (
aiDecisionMaxRetries = 3
)
// Runner 封装单次回测运行的生命周期。
// Runner encapsulates the lifecycle of a single backtest run.
type Runner struct {
cfg BacktestConfig
feed *DataFeed
@@ -63,7 +63,7 @@ type Runner struct {
lockStop chan struct{}
}
// NewRunner 构建回测运行器。
// NewRunner constructs a backtest runner.
func NewRunner(cfg BacktestConfig, mcpClient mcp.AIClient) (*Runner, error) {
if err := ensureRunDir(cfg.RunID); err != nil {
return nil, err
@@ -179,7 +179,7 @@ func (r *Runner) releaseLock() {
r.lockInfo = nil
}
// Start 启动回测循环。
// Start launches the backtest loop.
func (r *Runner) Start(ctx context.Context) error {
r.statusMu.Lock()
if r.status != RunStateCreated && r.status != RunStatePaused {
@@ -193,7 +193,7 @@ func (r *Runner) Start(ctx context.Context) error {
return nil
}
// PersistMetadata 将当前快照写入 run.json
// PersistMetadata writes the current snapshot to run.json.
func (r *Runner) PersistMetadata() {
r.persistMetadata()
}
@@ -214,7 +214,7 @@ func (r *Runner) lastErrorString() string {
return r.lastError
}
// CurrentMetadata 返回当前内存状态对应的元数据。
// CurrentMetadata returns the metadata corresponding to the current in-memory state.
func (r *Runner) CurrentMetadata() *RunMetadata {
state := r.snapshotState()
meta := r.buildMetadata(state, r.Status())
@@ -292,7 +292,7 @@ func (r *Runner) stepOnce() error {
ctx, rec, err := r.buildDecisionContext(ts, marketData, multiTF, priceMap, callCount)
if err != nil {
rec.Success = false
rec.ErrorMessage = fmt.Sprintf("构建交易上下文失败: %v", err)
rec.ErrorMessage = fmt.Sprintf("failed to build trading context: %v", err)
_ = r.logDecision(rec)
return err
}
@@ -312,7 +312,7 @@ func (r *Runner) stepOnce() error {
} else if r.cfg.ReplayOnly {
decisionErr := fmt.Errorf("replay_only enabled but cache miss at %d", ts)
record.Success = false
record.ErrorMessage = fmt.Sprintf("没有找到 ts=%d 的缓存决策", ts)
record.ErrorMessage = fmt.Sprintf("cached decision not found for ts=%d", ts)
_ = r.logDecision(record)
return decisionErr
}
@@ -327,8 +327,8 @@ func (r *Runner) stepOnce() error {
decisionAttempted = true
hadError = true
record.Success = false
record.ErrorMessage = fmt.Sprintf("AI决策失败: %v", err)
execLog = append(execLog, fmt.Sprintf("⚠️ AI决策失败: %v", err))
record.ErrorMessage = fmt.Sprintf("AI decision failed: %v", err)
execLog = append(execLog, fmt.Sprintf("⚠️ AI decision failed: %v", err))
r.setLastError(err)
} else {
fullDecision = fd
@@ -392,7 +392,7 @@ func (r *Runner) stepOnce() error {
hadError = true
tradeEvents = append(tradeEvents, liquidationEvents...)
if record != nil {
execLog = append(execLog, fmt.Sprintf("⚠️ 强制平仓: %s", liquidationNote))
execLog = append(execLog, fmt.Sprintf("⚠️ Forced liquidation: %s", liquidationNote))
}
}
@@ -690,7 +690,7 @@ func (r *Runner) executeDecision(dec decision.Decision, priceMap map[string]floa
return actionRecord, []TradeEvent{trade}, "", nil
case "hold", "wait":
return actionRecord, nil, fmt.Sprintf("保持仓位: %s", dec.Action), nil
return actionRecord, nil, fmt.Sprintf("hold position: %s", dec.Action), nil
default:
return actionRecord, nil, "", fmt.Errorf("unsupported action %s", dec.Action)
}
@@ -1078,14 +1078,14 @@ func (r *Runner) Wait() error {
return r.err
}
// Status 返回当前运行状态。
// Status returns the current run state.
func (r *Runner) Status() RunState {
r.statusMu.RLock()
defer r.statusMu.RUnlock()
return r.status
}
// StatusPayload 构建用于 API 的状态响应。
// StatusPayload builds the status response for the API.
func (r *Runner) StatusPayload() StatusPayload {
snapshot := r.snapshotState()
progress := progressPercent(snapshot, r.cfg)
+4 -4
View File
@@ -132,7 +132,7 @@ func appendJSONLine(path string, payload any) error {
return f.Sync()
}
// SaveCheckpoint 将检查点写入磁盘。
// SaveCheckpoint writes the checkpoint to disk.
func SaveCheckpoint(runID string, ckpt *Checkpoint) error {
if ckpt == nil {
return fmt.Errorf("checkpoint is nil")
@@ -143,7 +143,7 @@ func SaveCheckpoint(runID string, ckpt *Checkpoint) error {
return writeJSONAtomic(checkpointPath(runID), ckpt)
}
// LoadCheckpoint 读取最近一次检查点。
// LoadCheckpoint reads the most recent checkpoint.
func LoadCheckpoint(runID string) (*Checkpoint, error) {
if usingDB() {
return loadCheckpointDB(runID)
@@ -160,7 +160,7 @@ func LoadCheckpoint(runID string) (*Checkpoint, error) {
return &ckpt, nil
}
// SaveRunMetadata 写入 run.json
// SaveRunMetadata writes to run.json.
func SaveRunMetadata(meta *RunMetadata) error {
if meta == nil {
return fmt.Errorf("run metadata is nil")
@@ -178,7 +178,7 @@ func SaveRunMetadata(meta *RunMetadata) error {
return writeJSONAtomic(runMetadataPath(meta.RunID), meta)
}
// LoadRunMetadata 读取 run.json
// LoadRunMetadata reads run.json.
func LoadRunMetadata(runID string) (*RunMetadata, error) {
if usingDB() {
return loadRunMetadataDB(runID)
+11 -11
View File
@@ -2,7 +2,7 @@ package backtest
import "time"
// RunState 表示回测运行当前状态。
// RunState represents the current state of a backtest run.
type RunState string
const (
@@ -15,7 +15,7 @@ const (
RunStateLiquidated RunState = "liquidated"
)
// PositionSnapshot 表示当前持仓的核心数据,用于回测状态与持久化。
// PositionSnapshot represents core position data for backtest state and persistence.
type PositionSnapshot struct {
Symbol string `json:"symbol"`
Side string `json:"side"`
@@ -27,7 +27,7 @@ type PositionSnapshot struct {
OpenTime int64 `json:"open_time"`
}
// BacktestState 表示执行过程中的实时状态(内存态)。
// BacktestState represents the real-time state during execution (in-memory state).
type BacktestState struct {
BarIndex int
BarTimestamp int64
@@ -46,7 +46,7 @@ type BacktestState struct {
LiquidationNote string
}
// EquityPoint 表示资金曲线中的单个节点。
// EquityPoint represents a single point on the equity curve.
type EquityPoint struct {
Timestamp int64 `json:"ts"`
Equity float64 `json:"equity"`
@@ -57,7 +57,7 @@ type EquityPoint struct {
Cycle int `json:"cycle"`
}
// TradeEvent 记录一次交易执行结果或特殊事件(如爆仓)。
// TradeEvent records a trade execution result or special event (such as liquidation).
type TradeEvent struct {
Timestamp int64 `json:"ts"`
Symbol string `json:"symbol"`
@@ -76,7 +76,7 @@ type TradeEvent struct {
Note string `json:"note,omitempty"`
}
// Metrics 汇总回测表现指标。
// Metrics summarizes backtest performance metrics.
type Metrics struct {
TotalReturnPct float64 `json:"total_return_pct"`
MaxDrawdownPct float64 `json:"max_drawdown_pct"`
@@ -92,7 +92,7 @@ type Metrics struct {
Liquidated bool `json:"liquidated"`
}
// SymbolMetrics 记录单个标的的表现。
// SymbolMetrics records performance for a single symbol.
type SymbolMetrics struct {
TotalTrades int `json:"total_trades"`
WinningTrades int `json:"winning_trades"`
@@ -102,7 +102,7 @@ type SymbolMetrics struct {
WinRate float64 `json:"win_rate"`
}
// Checkpoint 表示磁盘保存的检查点信息,用于暂停、恢复与崩溃恢复。
// Checkpoint represents checkpoint information saved to disk for pause, resume, and crash recovery.
type Checkpoint struct {
BarIndex int `json:"bar_index"`
BarTimestamp int64 `json:"bar_ts"`
@@ -122,7 +122,7 @@ type Checkpoint struct {
LiquidationNote string `json:"liquidation_note,omitempty"`
}
// RunMetadata 记录 run.json 所需摘要。
// RunMetadata records the summary required for run.json.
type RunMetadata struct {
RunID string `json:"run_id"`
Label string `json:"label,omitempty"`
@@ -135,7 +135,7 @@ type RunMetadata struct {
Summary RunSummary `json:"summary"`
}
// RunSummary 为 run.json 中的 summary 字段。
// RunSummary represents the summary field in run.json.
type RunSummary struct {
SymbolCount int `json:"symbol_count"`
DecisionTF string `json:"decision_tf"`
@@ -147,7 +147,7 @@ type RunSummary struct {
LiquidationNote string `json:"liquidation_note,omitempty"`
}
// StatusPayload 用于 /status API 的响应。
// StatusPayload is used for /status API responses.
type StatusPayload struct {
RunID string `json:"run_id"`
State RunState `json:"state"`
+7 -7
View File
@@ -6,26 +6,26 @@ import (
"strings"
)
// 全局配置实例
// Global configuration instance
var global *Config
// Config 全局配置(从 .env 加载)
// 只包含真正的全局配置,交易相关配置在 trader/策略 级别
// Config is the global configuration (loaded from .env)
// Only contains truly global config, trading related config is at trader/strategy level
type Config struct {
// 服务配置
// Service configuration
APIServerPort int
JWTSecret string
RegistrationEnabled bool
}
// Init 初始化全局配置(从 .env 加载)
// Init initializes global configuration (from .env)
func Init() {
cfg := &Config{
APIServerPort: 8080,
RegistrationEnabled: true,
}
// 从环境变量加载
// Load from environment variables
if v := os.Getenv("JWT_SECRET"); v != "" {
cfg.JWTSecret = strings.TrimSpace(v)
}
@@ -46,7 +46,7 @@ func Init() {
global = cfg
}
// Get 获取全局配置
// Get returns the global configuration
func Get() *Config {
if global == nil {
Init()
+50 -50
View File
@@ -23,10 +23,10 @@ const (
storageDelimiter = ":"
)
// 环境变量名称
// Environment variable names
const (
EnvDataEncryptionKey = "DATA_ENCRYPTION_KEY" // AES 数据加密密钥 (Base64)
EnvRSAPrivateKey = "RSA_PRIVATE_KEY" // RSA 私钥 (PEM 格式,换行用 \n)
EnvDataEncryptionKey = "DATA_ENCRYPTION_KEY" // AES data encryption key (Base64)
EnvRSAPrivateKey = "RSA_PRIVATE_KEY" // RSA private key (PEM format, use \n for newlines)
)
type EncryptedPayload struct {
@@ -51,18 +51,18 @@ type CryptoService struct {
dataKey []byte
}
// NewCryptoService 创建加密服务(从环境变量加载密钥)
// NewCryptoService creates crypto service (loads keys from environment variables)
func NewCryptoService() (*CryptoService, error) {
// 1. 加载 RSA 私钥
// 1. Load RSA private key
privateKey, err := loadRSAPrivateKeyFromEnv()
if err != nil {
return nil, fmt.Errorf("RSA 私钥加载失败: %w", err)
return nil, fmt.Errorf("failed to load RSA private key: %w", err)
}
// 2. 加载 AES 数据加密密钥
// 2. Load AES data encryption key
dataKey, err := loadDataKeyFromEnv()
if err != nil {
return nil, fmt.Errorf("数据加密密钥加载失败: %w", err)
return nil, fmt.Errorf("failed to load data encryption key: %w", err)
}
return &CryptoService{
@@ -72,43 +72,43 @@ func NewCryptoService() (*CryptoService, error) {
}, nil
}
// loadRSAPrivateKeyFromEnv 从环境变量加载 RSA 私钥
// loadRSAPrivateKeyFromEnv loads RSA private key from environment variable
func loadRSAPrivateKeyFromEnv() (*rsa.PrivateKey, error) {
keyPEM := os.Getenv(EnvRSAPrivateKey)
if keyPEM == "" {
return nil, fmt.Errorf("环境变量 %s 未设置,请在 .env 中配置 RSA 私钥", EnvRSAPrivateKey)
return nil, fmt.Errorf("environment variable %s not set, please configure RSA private key in .env", EnvRSAPrivateKey)
}
// 处理环境变量中的换行符(\n -> 实际换行)
// Handle newlines in environment variable (\n -> actual newline)
keyPEM = strings.ReplaceAll(keyPEM, "\\n", "\n")
return ParseRSAPrivateKeyFromPEM([]byte(keyPEM))
}
// loadDataKeyFromEnv 从环境变量加载 AES 数据加密密钥
// loadDataKeyFromEnv loads AES data encryption key from environment variable
func loadDataKeyFromEnv() ([]byte, error) {
keyStr := strings.TrimSpace(os.Getenv(EnvDataEncryptionKey))
if keyStr == "" {
return nil, fmt.Errorf("环境变量 %s 未设置,请在 .env 中配置数据加密密钥", EnvDataEncryptionKey)
return nil, fmt.Errorf("environment variable %s not set, please configure data encryption key in .env", EnvDataEncryptionKey)
}
// 尝试解码
// Try to decode
if key, ok := decodePossibleKey(keyStr); ok {
return key, nil
}
// 如果无法解码,使用 SHA256 哈希作为密钥
// If decoding fails, use SHA256 hash as key
sum := sha256.Sum256([]byte(keyStr))
key := make([]byte, len(sum))
copy(key, sum[:])
return key, nil
}
// ParseRSAPrivateKeyFromPEM 解析 PEM 格式的 RSA 私钥
// ParseRSAPrivateKeyFromPEM parses RSA private key from PEM format
func ParseRSAPrivateKeyFromPEM(pemBytes []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(pemBytes)
if block == nil {
return nil, errors.New("无效的 PEM 格式")
return nil, errors.New("invalid PEM format")
}
switch block.Type {
@@ -121,15 +121,15 @@ func ParseRSAPrivateKeyFromPEM(pemBytes []byte) (*rsa.PrivateKey, error) {
}
rsaKey, ok := key.(*rsa.PrivateKey)
if !ok {
return nil, errors.New("不是 RSA 密钥")
return nil, errors.New("not an RSA key")
}
return rsaKey, nil
default:
return nil, errors.New("不支持的密钥类型: " + block.Type)
return nil, errors.New("unsupported key type: " + block.Type)
}
}
// decodePossibleKey 尝试用多种编码方式解码密钥
// decodePossibleKey tries to decode key using multiple encoding methods
func decodePossibleKey(value string) ([]byte, bool) {
decoders := []func(string) ([]byte, error){
base64.StdEncoding.DecodeString,
@@ -148,7 +148,7 @@ func decodePossibleKey(value string) ([]byte, bool) {
return nil, false
}
// normalizeAESKey 标准化 AES 密钥长度
// normalizeAESKey normalizes AES key length
func normalizeAESKey(raw []byte) ([]byte, bool) {
switch len(raw) {
case 16, 24, 32:
@@ -186,7 +186,7 @@ func (cs *CryptoService) EncryptForStorage(plaintext string, aadParts ...string)
return "", nil
}
if !cs.HasDataKey() {
return "", errors.New("数据加密密钥未配置")
return "", errors.New("data encryption key not configured")
}
if isEncryptedStorageValue(plaintext) {
return plaintext, nil
@@ -220,26 +220,26 @@ func (cs *CryptoService) DecryptFromStorage(value string, aadParts ...string) (s
return "", nil
}
if !cs.HasDataKey() {
return "", errors.New("数据加密密钥未配置")
return "", errors.New("data encryption key not configured")
}
if !isEncryptedStorageValue(value) {
return "", errors.New("数据未加密")
return "", errors.New("data not encrypted")
}
payload := strings.TrimPrefix(value, storagePrefix)
parts := strings.SplitN(payload, storageDelimiter, 2)
if len(parts) != 2 {
return "", errors.New("无效的加密数据格式")
return "", errors.New("invalid encrypted data format")
}
nonce, err := base64.StdEncoding.DecodeString(parts[0])
if err != nil {
return "", fmt.Errorf("解码 nonce 失败: %w", err)
return "", fmt.Errorf("failed to decode nonce: %w", err)
}
ciphertext, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
return "", fmt.Errorf("解码密文失败: %w", err)
return "", fmt.Errorf("failed to decode ciphertext: %w", err)
}
block, err := aes.NewCipher(cs.dataKey)
@@ -253,13 +253,13 @@ func (cs *CryptoService) DecryptFromStorage(value string, aadParts ...string) (s
}
if len(nonce) != gcm.NonceSize() {
return "", fmt.Errorf("无效的 nonce 长度: 期望 %d, 实际 %d", gcm.NonceSize(), len(nonce))
return "", fmt.Errorf("invalid nonce length: expected %d, got %d", gcm.NonceSize(), len(nonce))
}
aad := composeAAD(aadParts)
plaintext, err := gcm.Open(nil, nonce, ciphertext, aad)
if err != nil {
return "", fmt.Errorf("解密失败: %w", err)
return "", fmt.Errorf("decryption failed: %w", err)
}
return string(plaintext), nil
@@ -281,67 +281,67 @@ func isEncryptedStorageValue(value string) bool {
}
func (cs *CryptoService) DecryptPayload(payload *EncryptedPayload) ([]byte, error) {
// 1. 验证时间戳(防止重放攻击)
// 1. Validate timestamp (prevent replay attacks)
if payload.TS != 0 {
elapsed := time.Since(time.Unix(payload.TS, 0))
if elapsed > 5*time.Minute || elapsed < -1*time.Minute {
return nil, errors.New("时间戳无效或已过期")
return nil, errors.New("timestamp invalid or expired")
}
}
// 2. 解码 base64url
// 2. Decode base64url
wrappedKey, err := base64.RawURLEncoding.DecodeString(payload.WrappedKey)
if err != nil {
return nil, fmt.Errorf("解码 wrapped key 失败: %w", err)
return nil, fmt.Errorf("failed to decode wrapped key: %w", err)
}
iv, err := base64.RawURLEncoding.DecodeString(payload.IV)
if err != nil {
return nil, fmt.Errorf("解码 IV 失败: %w", err)
return nil, fmt.Errorf("failed to decode IV: %w", err)
}
ciphertext, err := base64.RawURLEncoding.DecodeString(payload.Ciphertext)
if err != nil {
return nil, fmt.Errorf("解码密文失败: %w", err)
return nil, fmt.Errorf("failed to decode ciphertext: %w", err)
}
var aad []byte
if payload.AAD != "" {
aad, err = base64.RawURLEncoding.DecodeString(payload.AAD)
if err != nil {
return nil, fmt.Errorf("解码 AAD 失败: %w", err)
return nil, fmt.Errorf("failed to decode AAD: %w", err)
}
var aadData AADData
if err := json.Unmarshal(aad, &aadData); err == nil {
// 可以在这里添加额外的验证逻辑
// Additional validation logic can be added here
}
}
// 3. 使用 RSA-OAEP 解密 AES 密钥
// 3. Decrypt AES key using RSA-OAEP
aesKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, cs.privateKey, wrappedKey, nil)
if err != nil {
return nil, fmt.Errorf("RSA 解密失败: %w", err)
return nil, fmt.Errorf("RSA decryption failed: %w", err)
}
// 4. 使用 AES-GCM 解密数据
// 4. Decrypt data using AES-GCM
block, err := aes.NewCipher(aesKey)
if err != nil {
return nil, fmt.Errorf("创建 AES cipher 失败: %w", err)
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("创建 GCM 失败: %w", err)
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
if len(iv) != gcm.NonceSize() {
return nil, fmt.Errorf("无效的 IV 长度: 期望 %d, 实际 %d", gcm.NonceSize(), len(iv))
return nil, fmt.Errorf("invalid IV length: expected %d, got %d", gcm.NonceSize(), len(iv))
}
plaintext, err := gcm.Open(nil, iv, ciphertext, aad)
if err != nil {
return nil, fmt.Errorf("解密验证失败: %w", err)
return nil, fmt.Errorf("decryption verification failed: %w", err)
}
return plaintext, nil
@@ -355,21 +355,21 @@ func (cs *CryptoService) DecryptSensitiveData(payload *EncryptedPayload) (string
return string(plaintext), nil
}
// GenerateKeyPair 生成 RSA 密钥对(用于初始化时生成密钥)
// 返回 PEM 格式的私钥和公钥
// GenerateKeyPair generates RSA key pair (for key generation during initialization)
// Returns PEM format private key and public key
func GenerateKeyPair() (privateKeyPEM, publicKeyPEM string, err error) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return "", "", err
}
// 编码私钥
// Encode private key
privPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
// 编码公钥
// Encode public key
publicKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
if err != nil {
return "", "", err
@@ -383,8 +383,8 @@ func GenerateKeyPair() (privateKeyPEM, publicKeyPEM string, err error) {
return string(privPEM), string(pubPEM), nil
}
// GenerateDataKey 生成 AES 数据加密密钥
// 返回 Base64 编码的 32 字节密钥
// GenerateDataKey generates AES data encryption key
// Returns Base64 encoded 32-byte key
func GenerateDataKey() (string, error) {
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
+329 -329
View File
File diff suppressed because it is too large Load Diff
+32 -32
View File
@@ -9,102 +9,102 @@ import (
"sync"
)
// PromptTemplate 系统提示词模板
// PromptTemplate system prompt template
type PromptTemplate struct {
Name string // 模板名称(文件名,不含扩展名)
Content string // 模板内容
Name string // Template name (filename without extension)
Content string // Template content
}
// PromptManager 提示词管理器
// PromptManager prompt manager
type PromptManager struct {
templates map[string]*PromptTemplate
mu sync.RWMutex
}
var (
// globalPromptManager 全局提示词管理器
// globalPromptManager global prompt manager
globalPromptManager *PromptManager
// promptsDir 提示词文件夹路径
// promptsDir prompt folder path
promptsDir = "prompts"
)
// init 包初始化时加载所有提示词模板
// init loads all prompt templates during package initialization
func init() {
globalPromptManager = NewPromptManager()
if err := globalPromptManager.LoadTemplates(promptsDir); err != nil {
log.Printf("⚠️ 加载提示词模板失败: %v", err)
log.Printf("⚠️ Failed to load prompt templates: %v", err)
} else {
log.Printf("✓ 已加载 %d 个系统提示词模板", len(globalPromptManager.templates))
log.Printf("✓ Loaded %d system prompt templates", len(globalPromptManager.templates))
}
}
// NewPromptManager 创建提示词管理器
// NewPromptManager creates a prompt manager
func NewPromptManager() *PromptManager {
return &PromptManager{
templates: make(map[string]*PromptTemplate),
}
}
// LoadTemplates 从指定目录加载所有提示词模板
// LoadTemplates loads all prompt templates from specified directory
func (pm *PromptManager) LoadTemplates(dir string) error {
pm.mu.Lock()
defer pm.mu.Unlock()
// 检查目录是否存在
// Check if directory exists
if _, err := os.Stat(dir); os.IsNotExist(err) {
return fmt.Errorf("提示词目录不存在: %s", dir)
return fmt.Errorf("prompt directory does not exist: %s", dir)
}
// 扫描目录中的所有 .txt 文件
// Scan all .txt files in directory
files, err := filepath.Glob(filepath.Join(dir, "*.txt"))
if err != nil {
return fmt.Errorf("扫描提示词目录失败: %w", err)
return fmt.Errorf("failed to scan prompt directory: %w", err)
}
if len(files) == 0 {
log.Printf("⚠️ 提示词目录 %s 中没有找到 .txt 文件", dir)
log.Printf("⚠️ No .txt files found in prompt directory %s", dir)
return nil
}
// 加载每个模板文件
// Load each template file
for _, file := range files {
// 读取文件内容
// Read file content
content, err := os.ReadFile(file)
if err != nil {
log.Printf("⚠️ 读取提示词文件失败 %s: %v", file, err)
log.Printf("⚠️ Failed to read prompt file %s: %v", file, err)
continue
}
// 提取文件名(不含扩展名)作为模板名称
// Extract filename (without extension) as template name
fileName := filepath.Base(file)
templateName := strings.TrimSuffix(fileName, filepath.Ext(fileName))
// 存储模板
// Store template
pm.templates[templateName] = &PromptTemplate{
Name: templateName,
Content: string(content),
}
log.Printf(" 📄 加载提示词模板: %s (%s)", templateName, fileName)
log.Printf(" 📄 Loaded prompt template: %s (%s)", templateName, fileName)
}
return nil
}
// GetTemplate 获取指定名称的提示词模板
// GetTemplate gets prompt template by name
func (pm *PromptManager) GetTemplate(name string) (*PromptTemplate, error) {
pm.mu.RLock()
defer pm.mu.RUnlock()
template, exists := pm.templates[name]
if !exists {
return nil, fmt.Errorf("提示词模板不存在: %s", name)
return nil, fmt.Errorf("prompt template does not exist: %s", name)
}
return template, nil
}
// GetAllTemplateNames 获取所有模板名称列表
// GetAllTemplateNames gets all template names list
func (pm *PromptManager) GetAllTemplateNames() []string {
pm.mu.RLock()
defer pm.mu.RUnlock()
@@ -117,7 +117,7 @@ func (pm *PromptManager) GetAllTemplateNames() []string {
return names
}
// GetAllTemplates 获取所有模板
// GetAllTemplates gets all templates
func (pm *PromptManager) GetAllTemplates() []*PromptTemplate {
pm.mu.RLock()
defer pm.mu.RUnlock()
@@ -130,7 +130,7 @@ func (pm *PromptManager) GetAllTemplates() []*PromptTemplate {
return templates
}
// ReloadTemplates 重新加载所有模板
// ReloadTemplates reloads all templates
func (pm *PromptManager) ReloadTemplates(dir string) error {
pm.mu.Lock()
pm.templates = make(map[string]*PromptTemplate)
@@ -139,24 +139,24 @@ func (pm *PromptManager) ReloadTemplates(dir string) error {
return pm.LoadTemplates(dir)
}
// === 全局函数(供外部调用)===
// === Global functions (for external calls) ===
// GetPromptTemplate 获取指定名称的提示词模板(全局函数)
// GetPromptTemplate gets prompt template by name (global function)
func GetPromptTemplate(name string) (*PromptTemplate, error) {
return globalPromptManager.GetTemplate(name)
}
// GetAllPromptTemplateNames 获取所有模板名称(全局函数)
// GetAllPromptTemplateNames gets all template names (global function)
func GetAllPromptTemplateNames() []string {
return globalPromptManager.GetAllTemplateNames()
}
// GetAllPromptTemplates 获取所有模板(全局函数)
// GetAllPromptTemplates gets all templates (global function)
func GetAllPromptTemplates() []*PromptTemplate {
return globalPromptManager.GetAllTemplates()
}
// ReloadPromptTemplates 重新加载所有模板(全局函数)
// ReloadPromptTemplates reloads all templates (global function)
func ReloadPromptTemplates() error {
return globalPromptManager.ReloadTemplates(promptsDir)
}
+77 -77
View File
@@ -7,49 +7,49 @@ import (
)
func TestPromptManager_LoadTemplates(t *testing.T) {
// 创建临时目录用于测试
// Create temporary directory for testing
tempDir := t.TempDir()
tests := []struct {
name string
setupFiles map[string]string // 文件名 -> 内容
setupFiles map[string]string // filename -> content
expectedCount int
expectedNames []string
shouldError bool
}{
{
name: "加载单个模板文件",
name: "Load single template file",
setupFiles: map[string]string{
"default.txt": "你是专业的加密货币交易AI",
"default.txt": "You are a professional cryptocurrency trading AI.",
},
expectedCount: 1,
expectedNames: []string{"default"},
shouldError: false,
},
{
name: "加载多个模板文件",
name: "Load multiple template files",
setupFiles: map[string]string{
"default.txt": "默认策略",
"conservative.txt": "保守策略",
"aggressive.txt": "激进策略",
"default.txt": "Default strategy",
"conservative.txt": "Conservative strategy",
"aggressive.txt": "Aggressive strategy",
},
expectedCount: 3,
expectedNames: []string{"default", "conservative", "aggressive"},
shouldError: false,
},
{
name: "空目录",
name: "Empty directory",
setupFiles: map[string]string{},
expectedCount: 0,
expectedNames: []string{},
shouldError: false,
},
{
name: "忽略非.txt文件",
name: "Ignore non-.txt files",
setupFiles: map[string]string{
"default.txt": "正确的模板",
"readme.md": "应该被忽略",
"config.json": "应该被忽略",
"default.txt": "Correct template",
"readme.md": "Should be ignored",
"config.json": "Should be ignored",
},
expectedCount: 1,
expectedNames: []string{"default"},
@@ -59,57 +59,57 @@ func TestPromptManager_LoadTemplates(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 为每个测试用例创建独立的子目录
// Create independent subdirectory for each test case
testDir := filepath.Join(tempDir, tt.name)
if err := os.MkdirAll(testDir, 0755); err != nil {
t.Fatalf("创建测试目录失败: %v", err)
t.Fatalf("Failed to create test directory: %v", err)
}
// 设置测试文件
// Setup test files
for filename, content := range tt.setupFiles {
filePath := filepath.Join(testDir, filename)
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
t.Fatalf("创建测试文件失败 %s: %v", filename, err)
t.Fatalf("Failed to create test file %s: %v", filename, err)
}
}
// 创建新的 PromptManager
// Create new PromptManager
pm := NewPromptManager()
// 执行测试
// Execute test
err := pm.LoadTemplates(testDir)
// 检查错误
// Check error
if (err != nil) != tt.shouldError {
t.Errorf("LoadTemplates() error = %v, shouldError %v", err, tt.shouldError)
return
}
// 检查加载的模板数量
// Check loaded template count
if len(pm.templates) != tt.expectedCount {
t.Errorf("加载的模板数量 = %d, 期望 %d", len(pm.templates), tt.expectedCount)
t.Errorf("Loaded template count = %d, expected %d", len(pm.templates), tt.expectedCount)
}
// 检查模板名称
// Check template names
for _, expectedName := range tt.expectedNames {
if _, exists := pm.templates[expectedName]; !exists {
t.Errorf("缺少预期的模板: %s", expectedName)
t.Errorf("Missing expected template: %s", expectedName)
}
}
// 验证模板内容
// Verify template content
for filename, expectedContent := range tt.setupFiles {
if filepath.Ext(filename) != ".txt" {
continue
}
templateName := filename[:len(filename)-4] // 去掉 .txt
templateName := filename[:len(filename)-4] // Remove .txt
template, err := pm.GetTemplate(templateName)
if err != nil {
t.Errorf("获取模板 %s 失败: %v", templateName, err)
t.Errorf("Failed to get template %s: %v", templateName, err)
continue
}
if template.Content != expectedContent {
t.Errorf("模板内容不匹配\n期望: %s\n实际: %s", expectedContent, template.Content)
t.Errorf("Template content mismatch\nExpected: %s\nActual: %s", expectedContent, template.Content)
}
}
})
@@ -121,11 +121,11 @@ func TestPromptManager_GetTemplate(t *testing.T) {
pm.templates = map[string]*PromptTemplate{
"default": {
Name: "default",
Content: "默认策略内容",
Content: "Default strategy content",
},
"aggressive": {
Name: "aggressive",
Content: "激进策略内容",
Content: "Aggressive strategy content",
},
}
@@ -136,13 +136,13 @@ func TestPromptManager_GetTemplate(t *testing.T) {
expectedContent string
}{
{
name: "获取存在的模板",
name: "Get existing template",
templateName: "default",
expectError: false,
expectedContent: "默认策略内容",
expectedContent: "Default strategy content",
},
{
name: "获取不存在的模板",
name: "Get non-existent template",
templateName: "nonexistent",
expectError: true,
},
@@ -158,7 +158,7 @@ func TestPromptManager_GetTemplate(t *testing.T) {
}
if !tt.expectError && template.Content != tt.expectedContent {
t.Errorf("模板内容 = %s, 期望 %s", template.Content, tt.expectedContent)
t.Errorf("Template content = %s, expected %s", template.Content, tt.expectedContent)
}
})
}
@@ -167,76 +167,76 @@ func TestPromptManager_GetTemplate(t *testing.T) {
func TestPromptManager_ReloadTemplates(t *testing.T) {
tempDir := t.TempDir()
// 初始文件
if err := os.WriteFile(filepath.Join(tempDir, "default.txt"), []byte("初始内容"), 0644); err != nil {
t.Fatalf("创建初始文件失败: %v", err)
// Initial file
if err := os.WriteFile(filepath.Join(tempDir, "default.txt"), []byte("Initial content"), 0644); err != nil {
t.Fatalf("Failed to create initial file: %v", err)
}
pm := NewPromptManager()
if err := pm.LoadTemplates(tempDir); err != nil {
t.Fatalf("初始加载失败: %v", err)
t.Fatalf("Initial load failed: %v", err)
}
// 验证初始内容
// Verify initial content
template, _ := pm.GetTemplate("default")
if template.Content != "初始内容" {
t.Errorf("初始内容不正确: %s", template.Content)
if template.Content != "Initial content" {
t.Errorf("Initial content incorrect: %s", template.Content)
}
// 修改文件内容
if err := os.WriteFile(filepath.Join(tempDir, "default.txt"), []byte("更新后内容"), 0644); err != nil {
t.Fatalf("更新文件失败: %v", err)
// Modify file content
if err := os.WriteFile(filepath.Join(tempDir, "default.txt"), []byte("Updated content"), 0644); err != nil {
t.Fatalf("Failed to update file: %v", err)
}
// 添加新文件
if err := os.WriteFile(filepath.Join(tempDir, "new.txt"), []byte("新模板内容"), 0644); err != nil {
t.Fatalf("创建新文件失败: %v", err)
// Add new file
if err := os.WriteFile(filepath.Join(tempDir, "new.txt"), []byte("New template content"), 0644); err != nil {
t.Fatalf("Failed to create new file: %v", err)
}
// 重新加载
// Reload
if err := pm.ReloadTemplates(tempDir); err != nil {
t.Fatalf("重新加载失败: %v", err)
t.Fatalf("Reload failed: %v", err)
}
// 验证更新后的内容
// Verify updated content
template, err := pm.GetTemplate("default")
if err != nil {
t.Fatalf("获取 default 模板失败: %v", err)
t.Fatalf("Failed to get default template: %v", err)
}
if template.Content != "更新后内容" {
t.Errorf("重新加载后内容不正确: got %s, want '更新后内容'", template.Content)
if template.Content != "Updated content" {
t.Errorf("Content after reload incorrect: got %s, want 'Updated content'", template.Content)
}
// 验证新模板
// Verify new template
newTemplate, err := pm.GetTemplate("new")
if err != nil {
t.Fatalf("获取 new 模板失败: %v", err)
t.Fatalf("Failed to get new template: %v", err)
}
if newTemplate.Content != "新模板内容" {
t.Errorf("新模板内容不正确: %s", newTemplate.Content)
if newTemplate.Content != "New template content" {
t.Errorf("New template content incorrect: %s", newTemplate.Content)
}
// 验证模板数量
// Verify template count
if len(pm.templates) != 2 {
t.Errorf("重新加载后模板数量 = %d, 期望 2", len(pm.templates))
t.Errorf("Template count after reload = %d, expected 2", len(pm.templates))
}
}
func TestPromptManager_GetAllTemplateNames(t *testing.T) {
pm := NewPromptManager()
pm.templates = map[string]*PromptTemplate{
"default": {Name: "default", Content: "默认策略"},
"conservative": {Name: "conservative", Content: "保守策略"},
"aggressive": {Name: "aggressive", Content: "激进策略"},
"default": {Name: "default", Content: "Default strategy"},
"conservative": {Name: "conservative", Content: "Conservative strategy"},
"aggressive": {Name: "aggressive", Content: "Aggressive strategy"},
}
names := pm.GetAllTemplateNames()
if len(names) != 3 {
t.Errorf("GetAllTemplateNames() 返回数量 = %d, 期望 3", len(names))
t.Errorf("GetAllTemplateNames() returned count = %d, expected 3", len(names))
}
// 验证所有名称都存在
// Verify all names exist
nameMap := make(map[string]bool)
for _, name := range names {
nameMap[name] = true
@@ -245,41 +245,41 @@ func TestPromptManager_GetAllTemplateNames(t *testing.T) {
expectedNames := []string{"default", "conservative", "aggressive"}
for _, expectedName := range expectedNames {
if !nameMap[expectedName] {
t.Errorf("缺少预期的模板名称: %s", expectedName)
t.Errorf("Missing expected template name: %s", expectedName)
}
}
}
func TestReloadPromptTemplates_GlobalFunction(t *testing.T) {
// 保存原始的 promptsDir
// Save original promptsDir
originalDir := promptsDir
defer func() {
promptsDir = originalDir
// 恢复原始模板
// Restore original templates
globalPromptManager.ReloadTemplates(originalDir)
}()
// 创建临时目录
// Create temporary directory
tempDir := t.TempDir()
promptsDir = tempDir
// 创建测试文件
if err := os.WriteFile(filepath.Join(tempDir, "test.txt"), []byte("测试内容"), 0644); err != nil {
t.Fatalf("创建测试文件失败: %v", err)
// Create test file
if err := os.WriteFile(filepath.Join(tempDir, "test.txt"), []byte("Test content"), 0644); err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
// 调用全局重新加载函数
// Call global reload function
if err := ReloadPromptTemplates(); err != nil {
t.Fatalf("ReloadPromptTemplates() 失败: %v", err)
t.Fatalf("ReloadPromptTemplates() failed: %v", err)
}
// 验证全局管理器已更新
// Verify global manager has been updated
template, err := GetPromptTemplate("test")
if err != nil {
t.Fatalf("获取模板失败: %v", err)
t.Fatalf("Failed to get template: %v", err)
}
if template.Content != "测试内容" {
t.Errorf("模板内容不正确: got %s, want '测试内容'", template.Content)
if template.Content != "Test content" {
t.Errorf("Template content incorrect: got %s, want 'Test content'", template.Content)
}
}
+77 -77
View File
@@ -7,205 +7,205 @@ import (
"testing"
)
// TestPromptReloadEndToEnd 端到端测试:验证从文件修改到决策引擎使用的完整流程
// TestPromptReloadEndToEnd end-to-end test: verify complete flow from file modification to decision engine usage
func TestPromptReloadEndToEnd(t *testing.T) {
// 保存原始的 promptsDir
// Save original promptsDir
originalDir := promptsDir
defer func() {
promptsDir = originalDir
// 恢复原始模板
// Restore original templates
globalPromptManager.ReloadTemplates(originalDir)
}()
// 创建临时目录模拟 prompts/ 目录
// Create temporary directory to simulate prompts/ directory
tempDir := t.TempDir()
promptsDir = tempDir
// 步骤1: 创建初始 prompt 文件
initialContent := "# 初始交易策略\n你是一个保守的交易AI"
// Step 1: Create initial prompt file
initialContent := "# Initial Trading Strategy\nYou are a conservative trading AI."
if err := os.WriteFile(filepath.Join(tempDir, "test_strategy.txt"), []byte(initialContent), 0644); err != nil {
t.Fatalf("创建初始文件失败: %v", err)
t.Fatalf("Failed to create initial file: %v", err)
}
// 步骤2: 首次加载(模拟系统启动)
// Step 2: First load (simulate system startup)
if err := ReloadPromptTemplates(); err != nil {
t.Fatalf("首次加载失败: %v", err)
t.Fatalf("First load failed: %v", err)
}
// 步骤3: 验证初始内容
// Step 3: Verify initial content
template, err := GetPromptTemplate("test_strategy")
if err != nil {
t.Fatalf("获取初始模板失败: %v", err)
t.Fatalf("Failed to get initial template: %v", err)
}
if template.Content != initialContent {
t.Errorf("初始内容不匹配\n期望: %s\n实际: %s", initialContent, template.Content)
t.Errorf("Initial content mismatch\nExpected: %s\nActual: %s", initialContent, template.Content)
}
// 步骤4: 使用 buildSystemPrompt 验证模板被正确使用
// Step 4: Use buildSystemPrompt to verify template is correctly used
systemPrompt := buildSystemPrompt(10000.0, 10, 5, "test_strategy", "")
if !strings.Contains(systemPrompt, initialContent) {
t.Errorf("buildSystemPrompt 未包含模板内容\n生成的 prompt:\n%s", systemPrompt)
t.Errorf("buildSystemPrompt doesn't contain template content\nGenerated prompt:\n%s", systemPrompt)
}
// 步骤5: 模拟用户修改文件(这是用户在硬盘上修改 prompt)
updatedContent := "# 更新的交易策略\n你是一个激进的交易AI,追求高风险高收益。"
// Step 5: Simulate user modifying file (user modifies prompt on disk)
updatedContent := "# Updated Trading Strategy\nYou are an aggressive trading AI seeking high risk and high reward."
if err := os.WriteFile(filepath.Join(tempDir, "test_strategy.txt"), []byte(updatedContent), 0644); err != nil {
t.Fatalf("更新文件失败: %v", err)
t.Fatalf("Failed to update file: %v", err)
}
// 步骤6: 模拟交易员启动时调用 ReloadPromptTemplates()
t.Log("模拟交易员启动,调用 ReloadPromptTemplates()...")
// Step 6: Simulate trader startup calling ReloadPromptTemplates()
t.Log("Simulating trader startup, calling ReloadPromptTemplates()...")
if err := ReloadPromptTemplates(); err != nil {
t.Fatalf("重新加载失败: %v", err)
t.Fatalf("Reload failed: %v", err)
}
// 步骤7: 验证新内容已生效
// Step 7: Verify new content has taken effect
reloadedTemplate, err := GetPromptTemplate("test_strategy")
if err != nil {
t.Fatalf("获取重新加载的模板失败: %v", err)
t.Fatalf("Failed to get reloaded template: %v", err)
}
if reloadedTemplate.Content != updatedContent {
t.Errorf("重新加载后内容不匹配\n期望: %s\n实际: %s", updatedContent, reloadedTemplate.Content)
t.Errorf("Content mismatch after reload\nExpected: %s\nActual: %s", updatedContent, reloadedTemplate.Content)
}
// 步骤8: 验证 buildSystemPrompt 使用了新内容
// Step 8: Verify buildSystemPrompt uses new content
newSystemPrompt := buildSystemPrompt(10000.0, 10, 5, "test_strategy", "")
if !strings.Contains(newSystemPrompt, updatedContent) {
t.Errorf("buildSystemPrompt 未包含更新后的模板内容\n生成的 prompt:\n%s", newSystemPrompt)
t.Errorf("buildSystemPrompt doesn't contain updated template content\nGenerated prompt:\n%s", newSystemPrompt)
}
// 步骤9: 验证旧内容不再存在
if strings.Contains(newSystemPrompt, "保守的交易AI") {
t.Errorf("buildSystemPrompt 仍包含旧的模板内容")
// Step 9: Verify old content no longer exists
if strings.Contains(newSystemPrompt, "conservative trading AI") {
t.Errorf("buildSystemPrompt still contains old template content")
}
t.Log("✅ 端到端测试通过:文件修改 -> 重新加载 -> 决策引擎使用新内容")
t.Log("✅ End-to-end test passed: file modification -> reload -> decision engine uses new content")
}
// TestPromptReloadWithCustomPrompt 测试自定义 prompt 与模板重新加载的交互
// TestPromptReloadWithCustomPrompt tests interaction between custom prompt and template reload
func TestPromptReloadWithCustomPrompt(t *testing.T) {
// 保存原始的 promptsDir
// Save original promptsDir
originalDir := promptsDir
defer func() {
promptsDir = originalDir
globalPromptManager.ReloadTemplates(originalDir)
}()
// 创建临时目录
// Create temporary directory
tempDir := t.TempDir()
promptsDir = tempDir
// 创建基础模板
baseContent := "基础策略:稳健交易"
// Create base template
baseContent := "Base strategy: Stable trading"
if err := os.WriteFile(filepath.Join(tempDir, "base.txt"), []byte(baseContent), 0644); err != nil {
t.Fatalf("创建文件失败: %v", err)
t.Fatalf("Failed to create file: %v", err)
}
// 加载模板
// Load templates
if err := ReloadPromptTemplates(); err != nil {
t.Fatalf("加载失败: %v", err)
t.Fatalf("Load failed: %v", err)
}
// 测试1: 基础模板 + 自定义 prompt(不覆盖)
customPrompt := "个性化规则:只交易 BTC"
// Test 1: Base template + custom prompt (no override)
customPrompt := "Personalized rule: Only trade BTC"
result := buildSystemPromptWithCustom(10000.0, 10, 5, customPrompt, false, "base", "")
if !strings.Contains(result, baseContent) {
t.Errorf("未包含基础模板内容")
t.Errorf("Doesn't contain base template content")
}
if !strings.Contains(result, customPrompt) {
t.Errorf("未包含自定义 prompt")
t.Errorf("Doesn't contain custom prompt")
}
// 测试2: 覆盖基础 prompt
// Test 2: Override base prompt
result = buildSystemPromptWithCustom(10000.0, 10, 5, customPrompt, true, "base", "")
if strings.Contains(result, baseContent) {
t.Errorf("覆盖模式下仍包含基础模板内容")
t.Errorf("Override mode still contains base template content")
}
if !strings.Contains(result, customPrompt) {
t.Errorf("覆盖模式下未包含自定义 prompt")
t.Errorf("Override mode doesn't contain custom prompt")
}
// 测试3: 重新加载后效果
updatedBase := "更新的基础策略:激进交易"
// Test 3: Effect after reload
updatedBase := "Updated base strategy: Aggressive trading"
if err := os.WriteFile(filepath.Join(tempDir, "base.txt"), []byte(updatedBase), 0644); err != nil {
t.Fatalf("更新文件失败: %v", err)
t.Fatalf("Failed to update file: %v", err)
}
if err := ReloadPromptTemplates(); err != nil {
t.Fatalf("重新加载失败: %v", err)
t.Fatalf("Reload failed: %v", err)
}
result = buildSystemPromptWithCustom(10000.0, 10, 5, customPrompt, false, "base", "")
if !strings.Contains(result, updatedBase) {
t.Errorf("重新加载后未包含更新的基础模板内容")
t.Errorf("After reload doesn't contain updated base template content")
}
if strings.Contains(result, baseContent) {
t.Errorf("重新加载后仍包含旧的基础模板内容")
t.Errorf("After reload still contains old base template content")
}
}
// TestPromptReloadFallback 测试模板不存在时的降级机制
// TestPromptReloadFallback tests fallback mechanism when template doesn't exist
func TestPromptReloadFallback(t *testing.T) {
// 保存原始的 promptsDir
// Save original promptsDir
originalDir := promptsDir
defer func() {
promptsDir = originalDir
globalPromptManager.ReloadTemplates(originalDir)
}()
// 创建临时目录
// Create temporary directory
tempDir := t.TempDir()
promptsDir = tempDir
// 只创建 default 模板
defaultContent := "默认策略"
// Only create default template
defaultContent := "Default strategy"
if err := os.WriteFile(filepath.Join(tempDir, "default.txt"), []byte(defaultContent), 0644); err != nil {
t.Fatalf("创建文件失败: %v", err)
t.Fatalf("Failed to create file: %v", err)
}
if err := ReloadPromptTemplates(); err != nil {
t.Fatalf("加载失败: %v", err)
t.Fatalf("Load failed: %v", err)
}
// 测试1: 请求不存在的模板,应该降级到 default
// Test 1: Request non-existent template, should fall back to default
result := buildSystemPrompt(10000.0, 10, 5, "nonexistent", "")
if !strings.Contains(result, defaultContent) {
t.Errorf("请求不存在的模板时,未降级到 default")
t.Errorf("When requesting non-existent template, didn't fall back to default")
}
// 测试2: 空模板名,应该使用 default
// Test 2: Empty template name, should use default
result = buildSystemPrompt(10000.0, 10, 5, "", "")
if !strings.Contains(result, defaultContent) {
t.Errorf("空模板名时,未使用 default")
t.Errorf("With empty template name, didn't use default")
}
}
// TestConcurrentPromptReload 测试并发场景下的 prompt 重新加载
// TestConcurrentPromptReload tests prompt reload in concurrent scenarios
func TestConcurrentPromptReload(t *testing.T) {
// 保存原始的 promptsDir
// Save original promptsDir
originalDir := promptsDir
defer func() {
promptsDir = originalDir
globalPromptManager.ReloadTemplates(originalDir)
}()
// 创建临时目录
// Create temporary directory
tempDir := t.TempDir()
promptsDir = tempDir
// 创建测试文件
if err := os.WriteFile(filepath.Join(tempDir, "test.txt"), []byte("测试内容"), 0644); err != nil {
t.Fatalf("创建文件失败: %v", err)
// Create test file
if err := os.WriteFile(filepath.Join(tempDir, "test.txt"), []byte("Test content"), 0644); err != nil {
t.Fatalf("Failed to create file: %v", err)
}
if err := ReloadPromptTemplates(); err != nil {
t.Fatalf("初始加载失败: %v", err)
t.Fatalf("Initial load failed: %v", err)
}
// 并发测试:同时读取和重新加载
// Concurrent test: read and reload simultaneously
done := make(chan bool)
// 启动多个读取 goroutine
// Start multiple read goroutines
for i := 0; i < 10; i++ {
go func() {
for j := 0; j < 100; j++ {
@@ -215,7 +215,7 @@ func TestConcurrentPromptReload(t *testing.T) {
}()
}
// 启动多个重新加载 goroutine
// Start multiple reload goroutines
for i := 0; i < 3; i++ {
go func() {
for j := 0; j < 10; j++ {
@@ -225,19 +225,19 @@ func TestConcurrentPromptReload(t *testing.T) {
}()
}
// 等待所有 goroutine 完成
// Wait for all goroutines to complete
for i := 0; i < 13; i++ {
<-done
}
// 验证最终状态正确
// Verify final state is correct
template, err := GetPromptTemplate("test")
if err != nil {
t.Errorf("并发测试后获取模板失败: %v", err)
t.Errorf("Failed to get template after concurrent test: %v", err)
}
if template.Content != "测试内容" {
t.Errorf("并发测试后模板内容错误: %s", template.Content)
if template.Content != "Test content" {
t.Errorf("Template content error after concurrent test: %s", template.Content)
}
t.Log("✅ 并发测试通过:多个 goroutine 同时读取和重新加载模板,无数据竞争")
t.Log("✅ Concurrent test passed: multiple goroutines reading and reloading templates simultaneously, no data race")
}
+5 -5
View File
@@ -5,9 +5,9 @@ import (
"testing"
)
// TestBuildSystemPrompt_ContainsAllValidActions 测试 prompt 是否包含所有有效的 action
// TestBuildSystemPrompt_ContainsAllValidActions tests whether prompt contains all valid actions
func TestBuildSystemPrompt_ContainsAllValidActions(t *testing.T) {
// 这是系统中定义的所有有效 action(来自 validateDecision
// These are all valid actions defined in the system (from validateDecision)
validActions := []string{
"open_long",
"open_short",
@@ -17,13 +17,13 @@ func TestBuildSystemPrompt_ContainsAllValidActions(t *testing.T) {
"wait",
}
// 构建 prompt
// Build prompt
prompt := buildSystemPrompt(1000.0, 10, 5, "default", "")
// 验证每个有效 action 都在 prompt 中出现
// Verify each valid action appears in prompt
for _, action := range validActions {
if !strings.Contains(prompt, action) {
t.Errorf("Prompt 缺少有效的 action: %s", action)
t.Errorf("Prompt missing valid action: %s", action)
}
}
}
+159 -159
View File
@@ -13,37 +13,37 @@ import (
"time"
)
// StrategyEngine 策略执行引擎
// 负责基于策略配置动态获取数据和组装 Prompt
// StrategyEngine strategy execution engine
// Responsible for dynamically fetching data and assembling prompts based on strategy configuration
type StrategyEngine struct {
config *store.StrategyConfig
}
// NewStrategyEngine 创建策略执行引擎
// NewStrategyEngine creates strategy execution engine
func NewStrategyEngine(config *store.StrategyConfig) *StrategyEngine {
return &StrategyEngine{config: config}
}
// GetCandidateCoins 根据策略配置获取候选币种
// GetCandidateCoins gets candidate coins based on strategy configuration
func (e *StrategyEngine) GetCandidateCoins() ([]CandidateCoin, error) {
var candidates []CandidateCoin
symbolSources := make(map[string][]string)
coinSource := e.config.CoinSource
// 设置自定义的 API URL(如果配置了)
// Set custom API URL (if configured)
if coinSource.CoinPoolAPIURL != "" {
pool.SetCoinPoolAPI(coinSource.CoinPoolAPIURL)
logger.Infof("✓ 使用策略配置的 AI500 API URL: %s", coinSource.CoinPoolAPIURL)
logger.Infof("✓ Using strategy-configured AI500 API URL: %s", coinSource.CoinPoolAPIURL)
}
if coinSource.OITopAPIURL != "" {
pool.SetOITopAPI(coinSource.OITopAPIURL)
logger.Infof("✓ 使用策略配置的 OI Top API URL: %s", coinSource.OITopAPIURL)
logger.Infof("✓ Using strategy-configured OI Top API URL: %s", coinSource.OITopAPIURL)
}
switch coinSource.SourceType {
case "static":
// 静态币种列表
// Static coin list
for _, symbol := range coinSource.StaticCoins {
symbol = market.Normalize(symbol)
candidates = append(candidates, CandidateCoin{
@@ -54,19 +54,19 @@ func (e *StrategyEngine) GetCandidateCoins() ([]CandidateCoin, error) {
return candidates, nil
case "coinpool":
// 仅使用 AI500 币种池
// Use AI500 coin pool only
return e.getCoinPoolCoins(coinSource.CoinPoolLimit)
case "oi_top":
// 仅使用 OI Top
// Use OI Top only
return e.getOITopCoins(coinSource.OITopLimit)
case "mixed":
// 混合模式:AI500 + OI Top
// Mixed mode: AI500 + OI Top
if coinSource.UseCoinPool {
poolCoins, err := e.getCoinPoolCoins(coinSource.CoinPoolLimit)
if err != nil {
logger.Infof("⚠️ 获取 AI500 币种池失败: %v", err)
logger.Infof("⚠️ Failed to get AI500 coin pool: %v", err)
} else {
for _, coin := range poolCoins {
symbolSources[coin.Symbol] = append(symbolSources[coin.Symbol], "ai500")
@@ -77,7 +77,7 @@ func (e *StrategyEngine) GetCandidateCoins() ([]CandidateCoin, error) {
if coinSource.UseOITop {
oiCoins, err := e.getOITopCoins(coinSource.OITopLimit)
if err != nil {
logger.Infof("⚠️ 获取 OI Top 失败: %v", err)
logger.Infof("⚠️ Failed to get OI Top: %v", err)
} else {
for _, coin := range oiCoins {
symbolSources[coin.Symbol] = append(symbolSources[coin.Symbol], "oi_top")
@@ -85,7 +85,7 @@ func (e *StrategyEngine) GetCandidateCoins() ([]CandidateCoin, error) {
}
}
// 添加静态币种(如果有)
// Add static coins (if any)
for _, symbol := range coinSource.StaticCoins {
symbol = market.Normalize(symbol)
if _, exists := symbolSources[symbol]; !exists {
@@ -95,7 +95,7 @@ func (e *StrategyEngine) GetCandidateCoins() ([]CandidateCoin, error) {
}
}
// 转换为候选币种列表
// Convert to candidate coin list
for symbol, sources := range symbolSources {
candidates = append(candidates, CandidateCoin{
Symbol: symbol,
@@ -105,11 +105,11 @@ func (e *StrategyEngine) GetCandidateCoins() ([]CandidateCoin, error) {
return candidates, nil
default:
return nil, fmt.Errorf("未知的币种来源类型: %s", coinSource.SourceType)
return nil, fmt.Errorf("unknown coin source type: %s", coinSource.SourceType)
}
}
// getCoinPoolCoins 获取 AI500 币种池
// getCoinPoolCoins gets AI500 coin pool
func (e *StrategyEngine) getCoinPoolCoins(limit int) ([]CandidateCoin, error) {
if limit <= 0 {
limit = 30
@@ -130,7 +130,7 @@ func (e *StrategyEngine) getCoinPoolCoins(limit int) ([]CandidateCoin, error) {
return candidates, nil
}
// getOITopCoins 获取 OI Top 币种
// getOITopCoins gets OI Top coins
func (e *StrategyEngine) getOITopCoins(limit int) ([]CandidateCoin, error) {
if limit <= 0 {
limit = 20
@@ -155,20 +155,20 @@ func (e *StrategyEngine) getOITopCoins(limit int) ([]CandidateCoin, error) {
return candidates, nil
}
// FetchMarketData 根据策略配置获取市场数据
// FetchMarketData fetches market data based on strategy configuration
func (e *StrategyEngine) FetchMarketData(symbol string) (*market.Data, error) {
// 目前使用现有的 market.Get,后续可以根据策略配置自定义
// Currently using existing market.Get, can be customized based on strategy configuration later
return market.Get(symbol)
}
// FetchExternalData 获取外部数据源
// FetchExternalData fetches external data sources
func (e *StrategyEngine) FetchExternalData() (map[string]interface{}, error) {
externalData := make(map[string]interface{})
for _, source := range e.config.Indicators.ExternalDataSources {
data, err := e.fetchSingleExternalSource(source)
if err != nil {
logger.Infof("⚠️ 获取外部数据源 [%s] 失败: %v", source.Name, err)
logger.Infof("⚠️ Failed to fetch external data source [%s]: %v", source.Name, err)
continue
}
externalData[source.Name] = data
@@ -177,7 +177,7 @@ func (e *StrategyEngine) FetchExternalData() (map[string]interface{}, error) {
return externalData, nil
}
// QuantData 量化数据结构(资金流向、持仓变化、价格变化)
// QuantData quantitative data structure (fund flow, position changes, price changes)
type QuantData struct {
Symbol string `json:"symbol"`
Price float64 `json:"price"`
@@ -209,49 +209,49 @@ type OIDeltaData struct {
OIDeltaPercent float64 `json:"oi_delta_percent"`
}
// FetchQuantData 获取单个币种的量化数据
// FetchQuantData fetches quantitative data for a single coin
func (e *StrategyEngine) FetchQuantData(symbol string) (*QuantData, error) {
if !e.config.Indicators.EnableQuantData || e.config.Indicators.QuantDataAPIURL == "" {
return nil, nil
}
// 替换 {symbol} 占位符
// Replace {symbol} placeholder
url := strings.Replace(e.config.Indicators.QuantDataAPIURL, "{symbol}", symbol, -1)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Get(url)
if err != nil {
return nil, fmt.Errorf("请求失败: %w", err)
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP状态码: %d", resp.StatusCode)
return nil, fmt.Errorf("HTTP status code: %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
return nil, fmt.Errorf("failed to read response: %w", err)
}
// 解析响应
// Parse response
var apiResp struct {
Code int `json:"code"`
Data *QuantData `json:"data"`
}
if err := json.Unmarshal(body, &apiResp); err != nil {
return nil, fmt.Errorf("解析JSON失败: %w", err)
return nil, fmt.Errorf("failed to parse JSON: %w", err)
}
if apiResp.Code != 0 {
return nil, fmt.Errorf("API返回错误码: %d", apiResp.Code)
return nil, fmt.Errorf("API returned error code: %d", apiResp.Code)
}
return apiResp.Data, nil
}
// FetchQuantDataBatch 批量获取量化数据
// FetchQuantDataBatch batch fetches quantitative data
func (e *StrategyEngine) FetchQuantDataBatch(symbols []string) map[string]*QuantData {
result := make(map[string]*QuantData)
@@ -262,7 +262,7 @@ func (e *StrategyEngine) FetchQuantDataBatch(symbols []string) map[string]*Quant
for _, symbol := range symbols {
data, err := e.FetchQuantData(symbol)
if err != nil {
logger.Infof("⚠️ 获取 %s 量化数据失败: %v", symbol, err)
logger.Infof("⚠️ Failed to fetch quantitative data for %s: %v", symbol, err)
continue
}
if data != nil {
@@ -273,18 +273,18 @@ func (e *StrategyEngine) FetchQuantDataBatch(symbols []string) map[string]*Quant
return result
}
// formatQuantData 格式化量化数据
// formatQuantData formats quantitative data
func (e *StrategyEngine) formatQuantData(data *QuantData) string {
if data == nil {
return ""
}
var sb strings.Builder
sb.WriteString("📊 量化数据:\n")
sb.WriteString("📊 Quantitative Data:\n")
// 价格变化
// Price changes
if len(data.PriceChange) > 0 {
sb.WriteString("价格变化: ")
sb.WriteString("Price Change: ")
timeframes := []string{"5m", "15m", "1h", "4h", "24h"}
parts := []string{}
for _, tf := range timeframes {
@@ -296,14 +296,14 @@ func (e *StrategyEngine) formatQuantData(data *QuantData) string {
sb.WriteString("\n")
}
// 资金流向
// Fund flow
if data.Netflow != nil {
sb.WriteString("资金流向(USDT):\n")
sb.WriteString("Fund Flow (USDT):\n")
// 机构资金
// Institutional funds
if data.Netflow.Institution != nil {
if data.Netflow.Institution.Future != nil {
sb.WriteString(" 机构合约: ")
sb.WriteString(" Institutional Futures: ")
parts := []string{}
for _, tf := range []string{"1h", "4h", "24h"} {
if v, ok := data.Netflow.Institution.Future[tf]; ok {
@@ -314,7 +314,7 @@ func (e *StrategyEngine) formatQuantData(data *QuantData) string {
sb.WriteString("\n")
}
if data.Netflow.Institution.Spot != nil {
sb.WriteString(" 机构现货: ")
sb.WriteString(" Institutional Spot: ")
parts := []string{}
for _, tf := range []string{"1h", "4h", "24h"} {
if v, ok := data.Netflow.Institution.Spot[tf]; ok {
@@ -326,10 +326,10 @@ func (e *StrategyEngine) formatQuantData(data *QuantData) string {
}
}
// 散户资金
// Retail funds
if data.Netflow.Personal != nil {
if data.Netflow.Personal.Future != nil {
sb.WriteString(" 散户合约: ")
sb.WriteString(" Retail Futures: ")
parts := []string{}
for _, tf := range []string{"1h", "4h", "24h"} {
if v, ok := data.Netflow.Personal.Future[tf]; ok {
@@ -342,13 +342,13 @@ func (e *StrategyEngine) formatQuantData(data *QuantData) string {
}
}
// 持仓数据
// Position data
if len(data.OI) > 0 {
for exchange, oiData := range data.OI {
sb.WriteString(fmt.Sprintf("持仓(%s): 当前%.2f | 多%.2f 空%.2f\n",
sb.WriteString(fmt.Sprintf("Open Interest (%s): Current %.2f | Long %.2f Short %.2f\n",
exchange, oiData.CurrentOI, oiData.NetLong, oiData.NetShort))
if len(oiData.Delta) > 0 {
sb.WriteString(" 持仓变化: ")
sb.WriteString(" OI Change: ")
parts := []string{}
for _, tf := range []string{"1h", "4h", "24h"} {
if d, ok := oiData.Delta[tf]; ok {
@@ -364,7 +364,7 @@ func (e *StrategyEngine) formatQuantData(data *QuantData) string {
return sb.String()
}
// fetchSingleExternalSource 获取单个外部数据源
// fetchSingleExternalSource fetches a single external data source
func (e *StrategyEngine) fetchSingleExternalSource(source store.ExternalDataSource) (interface{}, error) {
client := &http.Client{
Timeout: time.Duration(source.RefreshSecs) * time.Second,
@@ -379,7 +379,7 @@ func (e *StrategyEngine) fetchSingleExternalSource(source store.ExternalDataSour
return nil, err
}
// 添加请求头
// Add request headers
for k, v := range source.Headers {
req.Header.Set(k, v)
}
@@ -400,7 +400,7 @@ func (e *StrategyEngine) fetchSingleExternalSource(source store.ExternalDataSour
return nil, err
}
// 如果指定了数据路径,提取指定路径的数据
// If data path is specified, extract data at specified path
if source.DataPath != "" {
result = extractJSONPath(result, source.DataPath)
}
@@ -408,7 +408,7 @@ func (e *StrategyEngine) fetchSingleExternalSource(source store.ExternalDataSour
return result, nil
}
// extractJSONPath 提取 JSON 路径数据(简单实现)
// extractJSONPath extracts JSON path data (simple implementation)
func extractJSONPath(data interface{}, path string) interface{} {
parts := strings.Split(path, ".")
current := data
@@ -424,23 +424,23 @@ func extractJSONPath(data interface{}, path string) interface{} {
return current
}
// BuildUserPrompt 根据策略配置构建 User Prompt
// BuildUserPrompt builds User Prompt based on strategy configuration
func (e *StrategyEngine) BuildUserPrompt(ctx *Context) string {
var sb strings.Builder
// 系统状态
sb.WriteString(fmt.Sprintf("时间: %s | 周期: #%d | 运行: %d分钟\n\n",
// System status
sb.WriteString(fmt.Sprintf("Time: %s | Period: #%d | Runtime: %d minutes\n\n",
ctx.CurrentTime, ctx.CallCount, ctx.RuntimeMinutes))
// BTC 市场(如果配置了)
// BTC market (if configured)
if btcData, hasBTC := ctx.MarketDataMap["BTCUSDT"]; hasBTC {
sb.WriteString(fmt.Sprintf("BTC: %.2f (1h: %+.2f%%, 4h: %+.2f%%) | MACD: %.4f | RSI: %.2f\n\n",
btcData.CurrentPrice, btcData.PriceChange1h, btcData.PriceChange4h,
btcData.CurrentMACD, btcData.CurrentRSI7))
}
// 账户信息
sb.WriteString(fmt.Sprintf("账户: 净值%.2f | 余额%.2f (%.1f%%) | 盈亏%+.2f%% | 保证金%.1f%% | 持仓%d\n\n",
// Account information
sb.WriteString(fmt.Sprintf("Account: Equity %.2f | Balance %.2f (%.1f%%) | PnL %+.2f%% | Margin %.1f%% | Positions %d\n\n",
ctx.Account.TotalEquity,
ctx.Account.AvailableBalance,
(ctx.Account.AvailableBalance/ctx.Account.TotalEquity)*100,
@@ -448,40 +448,40 @@ func (e *StrategyEngine) BuildUserPrompt(ctx *Context) string {
ctx.Account.MarginUsedPct,
ctx.Account.PositionCount))
// 持仓信息
// Position information
if len(ctx.Positions) > 0 {
sb.WriteString("## 当前持仓\n")
sb.WriteString("## Current Positions\n")
for i, pos := range ctx.Positions {
sb.WriteString(e.formatPositionInfo(i+1, pos, ctx))
}
} else {
sb.WriteString("当前持仓: 无\n\n")
sb.WriteString("Current Positions: None\n\n")
}
// 交易统计
// Trading statistics
if ctx.TradingStats != nil && ctx.TradingStats.TotalTrades > 0 {
sb.WriteString("## 历史交易统计\n")
sb.WriteString(fmt.Sprintf("总交易数: %d | 胜率: %.1f%% | 盈亏比: %.2f | 夏普比: %.2f\n",
sb.WriteString("## Historical Trading Statistics\n")
sb.WriteString(fmt.Sprintf("Total Trades: %d | Win Rate: %.1f%% | Profit Factor: %.2f | Sharpe Ratio: %.2f\n",
ctx.TradingStats.TotalTrades,
ctx.TradingStats.WinRate,
ctx.TradingStats.ProfitFactor,
ctx.TradingStats.SharpeRatio))
sb.WriteString(fmt.Sprintf("总盈亏: %.2f USDT | 平均盈利: %.2f | 平均亏损: %.2f | 最大回撤: %.1f%%\n\n",
sb.WriteString(fmt.Sprintf("Total P&L: %.2f USDT | Avg Win: %.2f | Avg Loss: %.2f | Max Drawdown: %.1f%%\n\n",
ctx.TradingStats.TotalPnL,
ctx.TradingStats.AvgWin,
ctx.TradingStats.AvgLoss,
ctx.TradingStats.MaxDrawdownPct))
}
// 最近完成的订单
// Recently completed orders
if len(ctx.RecentOrders) > 0 {
sb.WriteString("## 最近完成的交易\n")
sb.WriteString("## Recent Completed Trades\n")
for i, order := range ctx.RecentOrders {
resultStr := "盈利"
resultStr := "Profit"
if order.RealizedPnL < 0 {
resultStr = "亏损"
resultStr = "Loss"
}
sb.WriteString(fmt.Sprintf("%d. %s %s | 入场%.4f 出场%.4f | %s: %+.2f USDT (%+.2f%%) | %s\n",
sb.WriteString(fmt.Sprintf("%d. %s %s | Entry %.4f Exit %.4f | %s: %+.2f USDT (%+.2f%%) | %s\n",
i+1, order.Symbol, order.Side,
order.EntryPrice, order.ExitPrice,
resultStr, order.RealizedPnL, order.PnLPct,
@@ -490,8 +490,8 @@ func (e *StrategyEngine) BuildUserPrompt(ctx *Context) string {
sb.WriteString("\n")
}
// 候选币种
sb.WriteString(fmt.Sprintf("## 候选币种 (%d个)\n\n", len(ctx.MarketDataMap)))
// Candidate coins
sb.WriteString(fmt.Sprintf("## Candidate Coins (%d coins)\n\n", len(ctx.MarketDataMap)))
displayedCount := 0
for _, coin := range ctx.CandidateCoins {
marketData, hasData := ctx.MarketDataMap[coin.Symbol]
@@ -504,7 +504,7 @@ func (e *StrategyEngine) BuildUserPrompt(ctx *Context) string {
sb.WriteString(fmt.Sprintf("### %d. %s%s\n\n", displayedCount, coin.Symbol, sourceTags))
sb.WriteString(e.formatMarketData(marketData))
// 添加量化数据(如果有)
// Add quantitative data if available
if ctx.QuantDataMap != nil {
if quantData, hasQuant := ctx.QuantDataMap[coin.Symbol]; hasQuant {
sb.WriteString(e.formatQuantData(quantData))
@@ -515,45 +515,45 @@ func (e *StrategyEngine) BuildUserPrompt(ctx *Context) string {
sb.WriteString("\n")
sb.WriteString("---\n\n")
sb.WriteString("现在请分析并输出决策(思维链 + JSON\n")
sb.WriteString("Now please analyze and output your decision (Chain of Thought + JSON)\n")
return sb.String()
}
// formatPositionInfo 格式化持仓信息
// formatPositionInfo formats position information
func (e *StrategyEngine) formatPositionInfo(index int, pos PositionInfo, ctx *Context) string {
var sb strings.Builder
// 计算持仓时长
// Calculate holding duration
holdingDuration := ""
if pos.UpdateTime > 0 {
durationMs := time.Now().UnixMilli() - pos.UpdateTime
durationMin := durationMs / (1000 * 60)
if durationMin < 60 {
holdingDuration = fmt.Sprintf(" | 持仓时长%d分钟", durationMin)
holdingDuration = fmt.Sprintf(" | Holding Duration %d min", durationMin)
} else {
durationHour := durationMin / 60
durationMinRemainder := durationMin % 60
holdingDuration = fmt.Sprintf(" | 持仓时长%d小时%d分钟", durationHour, durationMinRemainder)
holdingDuration = fmt.Sprintf(" | Holding Duration %dh %dm", durationHour, durationMinRemainder)
}
}
// 计算仓位价值
// Calculate position value
positionValue := pos.Quantity * pos.MarkPrice
if positionValue < 0 {
positionValue = -positionValue
}
sb.WriteString(fmt.Sprintf("%d. %s %s | 入场价%.4f 当前价%.4f | 数量%.4f | 仓位价值%.2f USDT | 盈亏%+.2f%% | 盈亏金额%+.2f USDT | 最高收益率%.2f%% | 杠杆%dx | 保证金%.0f | 强平价%.4f%s\n\n",
sb.WriteString(fmt.Sprintf("%d. %s %s | Entry %.4f Current %.4f | Qty %.4f | Position Value %.2f USDT | PnL%+.2f%% | PnL Amount%+.2f USDT | Peak PnL%.2f%% | Leverage %dx | Margin %.0f | Liq Price %.4f%s\n\n",
index, pos.Symbol, strings.ToUpper(pos.Side),
pos.EntryPrice, pos.MarkPrice, pos.Quantity, positionValue, pos.UnrealizedPnLPct, pos.UnrealizedPnL, pos.PeakPnLPct,
pos.Leverage, pos.MarginUsed, pos.LiquidationPrice, holdingDuration))
// 使用策略配置的指标输出市场数据
// Output market data using strategy configured indicators
if marketData, ok := ctx.MarketDataMap[pos.Symbol]; ok {
sb.WriteString(e.formatMarketData(marketData))
// 添加量化数据(如果有)
// Add quantitative data if available
if ctx.QuantDataMap != nil {
if quantData, hasQuant := ctx.QuantDataMap[pos.Symbol]; hasQuant {
sb.WriteString(e.formatQuantData(quantData))
@@ -565,29 +565,29 @@ func (e *StrategyEngine) formatPositionInfo(index int, pos PositionInfo, ctx *Co
return sb.String()
}
// formatCoinSourceTag 格式化币种来源标签
// formatCoinSourceTag formats coin source tag
func (e *StrategyEngine) formatCoinSourceTag(sources []string) string {
if len(sources) > 1 {
return " (AI500+OI_Top双重信号)"
return " (AI500+OI_Top dual signal)"
} else if len(sources) == 1 {
switch sources[0] {
case "ai500":
return " (AI500)"
case "oi_top":
return " (OI_Top持仓增长)"
return " (OI_Top position growth)"
case "static":
return " (手动选择)"
return " (Manual selection)"
}
}
return ""
}
// formatMarketData 根据策略配置格式化市场数据
// formatMarketData formats market data according to strategy configuration
func (e *StrategyEngine) formatMarketData(data *market.Data) string {
var sb strings.Builder
indicators := e.config.Indicators
// 当前价格(总是显示)
// Current price (always display)
sb.WriteString(fmt.Sprintf("current_price = %.4f", data.CurrentPrice))
// EMA
@@ -607,7 +607,7 @@ func (e *StrategyEngine) formatMarketData(data *market.Data) string {
sb.WriteString("\n\n")
// OI Funding Rate
// OI and Funding Rate
if indicators.EnableOI || indicators.EnableFundingRate {
sb.WriteString(fmt.Sprintf("Additional data for %s:\n\n", data.Symbol))
@@ -621,9 +621,9 @@ func (e *StrategyEngine) formatMarketData(data *market.Data) string {
}
}
// 优先使用多时间周期数据(新增)
// Prefer using multi-timeframe data (new addition)
if len(data.TimeframeData) > 0 {
// 按时间周期排序输出
// Output in timeframe order
timeframeOrder := []string{"1m", "3m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "8h", "12h", "1d", "3d", "1w"}
for _, tf := range timeframeOrder {
if tfData, ok := data.TimeframeData[tf]; ok {
@@ -632,8 +632,8 @@ func (e *StrategyEngine) formatMarketData(data *market.Data) string {
}
}
} else {
// 兼容旧的数据格式
// 日内数据
// Compatible with old data format
// Intraday data
if data.IntradaySeries != nil {
klineConfig := indicators.Klines
sb.WriteString(fmt.Sprintf("Intraday series (%s intervals, oldest → latest):\n\n", klineConfig.PrimaryTimeframe))
@@ -668,7 +668,7 @@ func (e *StrategyEngine) formatMarketData(data *market.Data) string {
}
}
// 长周期数据
// Longer-term data
if data.LongerTermContext != nil && indicators.Klines.EnableMultiTimeframe {
sb.WriteString(fmt.Sprintf("Longer-term context (%s timeframe):\n\n", indicators.Klines.LongerTimeframe))
@@ -700,7 +700,7 @@ func (e *StrategyEngine) formatMarketData(data *market.Data) string {
return sb.String()
}
// formatTimeframeSeriesData 格式化单个时间周期的序列数据
// formatTimeframeSeriesData formats series data for a single timeframe
func (e *StrategyEngine) formatTimeframeSeriesData(sb *strings.Builder, data *market.TimeframeSeriesData, indicators store.IndicatorConfig) {
if len(data.MidPrices) > 0 {
sb.WriteString(fmt.Sprintf("Mid prices: %s\n\n", formatFloatSlice(data.MidPrices)))
@@ -737,7 +737,7 @@ func (e *StrategyEngine) formatTimeframeSeriesData(sb *strings.Builder, data *ma
}
}
// formatFloatSlice 格式化浮点数切片
// formatFloatSlice formats float slice
func formatFloatSlice(values []float64) string {
strValues := make([]string, len(values))
for i, v := range values {
@@ -746,179 +746,179 @@ func formatFloatSlice(values []float64) string {
return "[" + strings.Join(strValues, ", ") + "]"
}
// BuildSystemPrompt 根据策略配置构建 System Prompt
// BuildSystemPrompt builds System Prompt according to strategy configuration
func (e *StrategyEngine) BuildSystemPrompt(accountEquity float64, variant string) string {
var sb strings.Builder
riskControl := e.config.RiskControl
promptSections := e.config.PromptSections
// 1. 角色定义(可编辑)
// 1. Role definition (editable)
if promptSections.RoleDefinition != "" {
sb.WriteString(promptSections.RoleDefinition)
sb.WriteString("\n\n")
} else {
sb.WriteString("# 你是专业的加密货币交易AI\n\n")
sb.WriteString("你的任务是根据提供的市场数据做出交易决策。\n\n")
sb.WriteString("# You are a professional cryptocurrency trading AI\n\n")
sb.WriteString("Your task is to make trading decisions based on provided market data.\n\n")
}
// 2. 交易模式变体
// 2. Trading mode variant
switch strings.ToLower(strings.TrimSpace(variant)) {
case "aggressive":
sb.WriteString("## 模式:Aggressive(进攻型)\n- 优先捕捉趋势突破,可在信心度≥70时分批建仓\n- 允许更高仓位,但须严格设置止损并说明盈亏比\n\n")
sb.WriteString("## Mode: Aggressive\n- Prioritize capturing trend breakouts, can build positions in batches when confidence ≥ 70\n- Allow higher positions, but must strictly set stop-loss and explain risk-reward ratio\n\n")
case "conservative":
sb.WriteString("## 模式:Conservative(稳健型)\n- 仅在多重信号共振时开仓\n- 优先保留现金,连续亏损必须暂停多个周期\n\n")
sb.WriteString("## Mode: Conservative\n- Only open positions when multiple signals resonate\n- Prioritize cash preservation, must pause for multiple periods after consecutive losses\n\n")
case "scalping":
sb.WriteString("## 模式:Scalping(剥头皮)\n- 聚焦短周期动量,目标收益较小但要求迅速\n- 若价格两根bar内未按预期运行,立即减仓或止损\n\n")
sb.WriteString("## Mode: Scalping\n- Focus on short-term momentum, smaller profit targets but require quick action\n- If price doesn't move as expected within two bars, immediately reduce position or stop-loss\n\n")
}
// 3. 硬约束(风险控制)- 来自策略配置(不可编辑,自动生成)
sb.WriteString("# 硬约束(风险控制)\n\n")
sb.WriteString(fmt.Sprintf("1. 风险回报比: 必须 ≥ 1:%.1f\n", riskControl.MinRiskRewardRatio))
sb.WriteString(fmt.Sprintf("2. 最多持仓: %d个币种(质量>数量)\n", riskControl.MaxPositions))
sb.WriteString(fmt.Sprintf("3. 单币仓位: 山寨%.0f-%.0f U | BTC/ETH %.0f-%.0f U\n",
// 3. Hard constraints (risk control) - from strategy config (non-editable, auto-generated)
sb.WriteString("# Hard Constraints (Risk Control)\n\n")
sb.WriteString(fmt.Sprintf("1. Risk-Reward Ratio: Must be ≥ 1:%.1f\n", riskControl.MinRiskRewardRatio))
sb.WriteString(fmt.Sprintf("2. Max Positions: %d coins (quality > quantity)\n", riskControl.MaxPositions))
sb.WriteString(fmt.Sprintf("3. Single Coin Position: Altcoins %.0f-%.0f U | BTC/ETH %.0f-%.0f U\n",
accountEquity*0.8, accountEquity*riskControl.MaxPositionRatio,
accountEquity*5, accountEquity*10))
sb.WriteString(fmt.Sprintf("4. 杠杆限制: **山寨币最大%dx杠杆** | **BTC/ETH最大%dx杠杆**\n",
sb.WriteString(fmt.Sprintf("4. Leverage Limits: **Altcoins max %dx leverage** | **BTC/ETH max %dx leverage**\n",
riskControl.AltcoinMaxLeverage, riskControl.BTCETHMaxLeverage))
sb.WriteString(fmt.Sprintf("5. 保证金使用率 ≤ %.0f%%\n", riskControl.MaxMarginUsage*100))
sb.WriteString(fmt.Sprintf("6. 开仓金额: 建议 ≥%.0f USDT\n", riskControl.MinPositionSize))
sb.WriteString(fmt.Sprintf("7. 最小信心度: ≥%d\n\n", riskControl.MinConfidence))
sb.WriteString(fmt.Sprintf("5. Margin Usage ≤ %.0f%%\n", riskControl.MaxMarginUsage*100))
sb.WriteString(fmt.Sprintf("6. Opening Amount: Recommended ≥%.0f USDT\n", riskControl.MinPositionSize))
sb.WriteString(fmt.Sprintf("7. Minimum Confidence: ≥%d\n\n", riskControl.MinConfidence))
// 4. 交易频率与信号质量(可编辑)
// 4. Trading frequency and signal quality (editable)
if promptSections.TradingFrequency != "" {
sb.WriteString(promptSections.TradingFrequency)
sb.WriteString("\n\n")
} else {
sb.WriteString("# ⏱️ 交易频率认知\n\n")
sb.WriteString("- 优秀交易员:每天2-4笔 ≈ 每小时0.1-0.2笔\n")
sb.WriteString("- 每小时>2笔 = 过度交易\n")
sb.WriteString("- 单笔持仓时间≥30-60分钟\n")
sb.WriteString("如果你发现自己每个周期都在交易 → 标准过低;若持仓<30分钟就平仓 → 过于急躁。\n\n")
sb.WriteString("# ⏱️ Trading Frequency Awareness\n\n")
sb.WriteString("- Excellent traders: 2-4 trades/day ≈ 0.1-0.2 trades/hour\n")
sb.WriteString("- >2 trades/hour = Overtrading\n")
sb.WriteString("- Single position hold time ≥ 30-60 minutes\n")
sb.WriteString("If you find yourself trading every period → standards too low; if closing positions < 30 minutes → too impatient.\n\n")
}
// 5. 开仓标准(可编辑)
// 5. Entry standards (editable)
if promptSections.EntryStandards != "" {
sb.WriteString(promptSections.EntryStandards)
sb.WriteString("\n\n你拥有以下指标数据:\n")
sb.WriteString("\n\nYou have the following indicator data:\n")
e.writeAvailableIndicators(&sb)
sb.WriteString(fmt.Sprintf("\n**信心度 ≥%d** 才能开仓。\n\n", riskControl.MinConfidence))
sb.WriteString(fmt.Sprintf("\n**Confidence %d** required to open positions.\n\n", riskControl.MinConfidence))
} else {
sb.WriteString("# 🎯 开仓标准(严格)\n\n")
sb.WriteString("只在多重信号共振时开仓。你拥有:\n")
sb.WriteString("# 🎯 Entry Standards (Strict)\n\n")
sb.WriteString("Only open positions when multiple signals resonate. You have:\n")
e.writeAvailableIndicators(&sb)
sb.WriteString(fmt.Sprintf("\n自由运用任何有效的分析方法,但**信心度 ≥%d** 才能开仓;避免单一指标、信号矛盾、横盘震荡、刚平仓即重启等低质量行为。\n\n", riskControl.MinConfidence))
sb.WriteString(fmt.Sprintf("\nFeel free to use any effective analysis method, but **confidence %d** required to open positions; avoid low-quality behaviors such as single indicators, contradictory signals, sideways consolidation, reopening immediately after closing, etc.\n\n", riskControl.MinConfidence))
}
// 6. 决策流程提示(可编辑)
// 6. Decision process tips (editable)
if promptSections.DecisionProcess != "" {
sb.WriteString(promptSections.DecisionProcess)
sb.WriteString("\n\n")
} else {
sb.WriteString("# 📋 决策流程\n\n")
sb.WriteString("1. 检查持仓 → 是否该止盈/止损\n")
sb.WriteString("2. 扫描候选币 + 多时间框 → 是否存在强信号\n")
sb.WriteString("3. 先写思维链,再输出结构化JSON\n\n")
sb.WriteString("# 📋 Decision Process\n\n")
sb.WriteString("1. Check positions → Should we take profit/stop-loss\n")
sb.WriteString("2. Scan candidate coins + multi-timeframe → Are there strong signals\n")
sb.WriteString("3. Write chain of thought first, then output structured JSON\n\n")
}
// 7. 输出格式
sb.WriteString("# 输出格式 (严格遵守)\n\n")
sb.WriteString("**必须使用XML标签 <reasoning> <decision> 标签分隔思维链和决策JSON,避免解析错误**\n\n")
sb.WriteString("## 格式要求\n\n")
// 7. Output format
sb.WriteString("# Output Format (Strictly Follow)\n\n")
sb.WriteString("**Must use XML tags <reasoning> and <decision> to separate chain of thought and decision JSON, avoiding parsing errors**\n\n")
sb.WriteString("## Format Requirements\n\n")
sb.WriteString("<reasoning>\n")
sb.WriteString("你的思维链分析...\n")
sb.WriteString("- 简洁分析你的思考过程 \n")
sb.WriteString("Your chain of thought analysis...\n")
sb.WriteString("- Briefly analyze your thinking process \n")
sb.WriteString("</reasoning>\n\n")
sb.WriteString("<decision>\n")
sb.WriteString("第二步: JSON决策数组\n\n")
sb.WriteString("Step 2: JSON decision array\n\n")
sb.WriteString("```json\n[\n")
sb.WriteString(fmt.Sprintf(" {\"symbol\": \"BTCUSDT\", \"action\": \"open_short\", \"leverage\": %d, \"position_size_usd\": %.0f, \"stop_loss\": 97000, \"take_profit\": 91000, \"confidence\": 85, \"risk_usd\": 300},\n",
riskControl.BTCETHMaxLeverage, accountEquity*5))
sb.WriteString(" {\"symbol\": \"ETHUSDT\", \"action\": \"close_long\"}\n")
sb.WriteString("]\n```\n")
sb.WriteString("</decision>\n\n")
sb.WriteString("## 字段说明\n\n")
sb.WriteString("## Field Description\n\n")
sb.WriteString("- `action`: open_long | open_short | close_long | close_short | hold | wait\n")
sb.WriteString(fmt.Sprintf("- `confidence`: 0-100(开仓建议≥%d\n", riskControl.MinConfidence))
sb.WriteString("- 开仓时必填: leverage, position_size_usd, stop_loss, take_profit, confidence, risk_usd\n\n")
sb.WriteString(fmt.Sprintf("- `confidence`: 0-100 (opening recommended ≥ %d)\n", riskControl.MinConfidence))
sb.WriteString("- Required when opening: leverage, position_size_usd, stop_loss, take_profit, confidence, risk_usd\n\n")
// 8. 自定义 Prompt
// 8. Custom Prompt
if e.config.CustomPrompt != "" {
sb.WriteString("# 📌 个性化交易策略\n\n")
sb.WriteString("# 📌 Personalized Trading Strategy\n\n")
sb.WriteString(e.config.CustomPrompt)
sb.WriteString("\n\n")
sb.WriteString("注意: 以上个性化策略是对基础规则的补充,不能违背基础风险控制原则。\n")
sb.WriteString("Note: The above personalized strategy is a supplement to the basic rules and cannot violate the basic risk control principles.\n")
}
return sb.String()
}
// writeAvailableIndicators 写入可用指标列表
// writeAvailableIndicators writes list of available indicators
func (e *StrategyEngine) writeAvailableIndicators(sb *strings.Builder) {
indicators := e.config.Indicators
kline := indicators.Klines
sb.WriteString(fmt.Sprintf("- %s价格序列", kline.PrimaryTimeframe))
sb.WriteString(fmt.Sprintf("- %s price series", kline.PrimaryTimeframe))
if kline.EnableMultiTimeframe {
sb.WriteString(fmt.Sprintf(" + %s K线序列\n", kline.LongerTimeframe))
sb.WriteString(fmt.Sprintf(" + %s K-line series\n", kline.LongerTimeframe))
} else {
sb.WriteString("\n")
}
if indicators.EnableEMA {
sb.WriteString("- EMA 指标")
sb.WriteString("- EMA indicators")
if len(indicators.EMAPeriods) > 0 {
sb.WriteString(fmt.Sprintf("(周期: %v", indicators.EMAPeriods))
sb.WriteString(fmt.Sprintf(" (periods: %v)", indicators.EMAPeriods))
}
sb.WriteString("\n")
}
if indicators.EnableMACD {
sb.WriteString("- MACD 指标\n")
sb.WriteString("- MACD indicators\n")
}
if indicators.EnableRSI {
sb.WriteString("- RSI 指标")
sb.WriteString("- RSI indicators")
if len(indicators.RSIPeriods) > 0 {
sb.WriteString(fmt.Sprintf("(周期: %v", indicators.RSIPeriods))
sb.WriteString(fmt.Sprintf(" (periods: %v)", indicators.RSIPeriods))
}
sb.WriteString("\n")
}
if indicators.EnableATR {
sb.WriteString("- ATR 指标")
sb.WriteString("- ATR indicators")
if len(indicators.ATRPeriods) > 0 {
sb.WriteString(fmt.Sprintf("(周期: %v", indicators.ATRPeriods))
sb.WriteString(fmt.Sprintf(" (periods: %v)", indicators.ATRPeriods))
}
sb.WriteString("\n")
}
if indicators.EnableVolume {
sb.WriteString("- 成交量数据\n")
sb.WriteString("- Volume data\n")
}
if indicators.EnableOI {
sb.WriteString("- 持仓量(OI)数据\n")
sb.WriteString("- Open Interest (OI) data\n")
}
if indicators.EnableFundingRate {
sb.WriteString("- 资金费率\n")
sb.WriteString("- Funding rate\n")
}
if len(e.config.CoinSource.StaticCoins) > 0 || e.config.CoinSource.UseCoinPool || e.config.CoinSource.UseOITop {
sb.WriteString("- AI500 / OI_Top 筛选标签(若有)\n")
sb.WriteString("- AI500 / OI_Top filter tags (if available)\n")
}
if indicators.EnableQuantData {
sb.WriteString("- 量化数据(机构/散户资金流向、持仓变化、多周期价格变化)\n")
sb.WriteString("- Quantitative data (institutional/retail fund flow, position changes, multi-period price changes)\n")
}
}
// GetRiskControlConfig 获取风险控制配置
// GetRiskControlConfig gets risk control configuration
func (e *StrategyEngine) GetRiskControlConfig() store.RiskControlConfig {
return e.config.RiskControl
}
// GetConfig 获取完整策略配置
// GetConfig gets complete strategy configuration
func (e *StrategyEngine) GetConfig() *store.StrategyConfig {
return e.config
}
+18 -18
View File
@@ -4,7 +4,7 @@ import (
"testing"
)
// TestLeverageFallback 测试杠杆超限时的自动修正功能
// TestLeverageFallback tests automatic correction when leverage exceeds limit
func TestLeverageFallback(t *testing.T) {
tests := []struct {
name string
@@ -12,47 +12,47 @@ func TestLeverageFallback(t *testing.T) {
accountEquity float64
btcEthLeverage int
altcoinLeverage int
wantLeverage int // 期望修正后的杠杆值
wantLeverage int // Expected leverage after correction
wantError bool
}{
{
name: "山寨币杠杆超限_自动修正为上限",
name: "Altcoin leverage exceeded - auto-correct to limit",
decision: Decision{
Symbol: "SOLUSDT",
Action: "open_long",
Leverage: 20, // 超过上限
Leverage: 20, // Exceeds limit
PositionSizeUSD: 100,
StopLoss: 50,
TakeProfit: 200,
},
accountEquity: 100,
btcEthLeverage: 10,
altcoinLeverage: 5, // 上限 5x
wantLeverage: 5, // 应该修正为 5
altcoinLeverage: 5, // Limit 5x
wantLeverage: 5, // Should be corrected to 5
wantError: false,
},
{
name: "BTC杠杆超限_自动修正为上限",
name: "BTC leverage exceeded - auto-correct to limit",
decision: Decision{
Symbol: "BTCUSDT",
Action: "open_long",
Leverage: 20, // 超过上限
Leverage: 20, // Exceeds limit
PositionSizeUSD: 1000,
StopLoss: 90000,
TakeProfit: 110000,
},
accountEquity: 100,
btcEthLeverage: 10, // 上限 10x
btcEthLeverage: 10, // Limit 10x
altcoinLeverage: 5,
wantLeverage: 10, // 应该修正为 10
wantLeverage: 10, // Should be corrected to 10
wantError: false,
},
{
name: "杠杆在上限内_不修正",
name: "Leverage within limit - no correction",
decision: Decision{
Symbol: "ETHUSDT",
Action: "open_short",
Leverage: 5, // 未超限
Leverage: 5, // Not exceeded
PositionSizeUSD: 500,
StopLoss: 4000,
TakeProfit: 3000,
@@ -60,15 +60,15 @@ func TestLeverageFallback(t *testing.T) {
accountEquity: 100,
btcEthLeverage: 10,
altcoinLeverage: 5,
wantLeverage: 5, // 保持不变
wantLeverage: 5, // Stays unchanged
wantError: false,
},
{
name: "杠杆为0_应该报错",
name: "Leverage is 0 - should error",
decision: Decision{
Symbol: "SOLUSDT",
Action: "open_long",
Leverage: 0, // 无效
Leverage: 0, // Invalid
PositionSizeUSD: 100,
StopLoss: 50,
TakeProfit: 200,
@@ -85,13 +85,13 @@ func TestLeverageFallback(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
err := validateDecision(&tt.decision, tt.accountEquity, tt.btcEthLeverage, tt.altcoinLeverage)
// 检查错误状态
// Check error status
if (err != nil) != tt.wantError {
t.Errorf("validateDecision() error = %v, wantError %v", err, tt.wantError)
return
}
// 如果不应该报错,检查杠杆是否被正确修正
// If shouldn't error, check if leverage was correctly corrected
if !tt.wantError && tt.decision.Leverage != tt.wantLeverage {
t.Errorf("Leverage not corrected: got %d, want %d", tt.decision.Leverage, tt.wantLeverage)
}
@@ -100,7 +100,7 @@ func TestLeverageFallback(t *testing.T) {
}
// contains 检查字符串是否包含子串(辅助函数)
// contains checks if string contains substring (helper function)
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
(len(s) > 0 && len(substr) > 0 && stringContains(s, substr)))
+1 -1
View File
@@ -12,7 +12,7 @@ type SetHttpClientResult struct {
func (r *SetHttpClientResult) Error() error {
if r.Err != nil {
log.Printf("⚠️ 执行NewAsterTraderResult时出错: %v", r.Err)
log.Printf("⚠️ Error executing SetHttpClientResult: %v", r.Err)
}
return r.Err
}
+1 -1
View File
@@ -13,7 +13,7 @@ func (r *IpResult) Error() error {
func (r *IpResult) GetResult() string {
if r.Err != nil {
log.Printf("⚠️ 执行GetIP时出错: %v", r.Err)
log.Printf("⚠️ Error executing GetIP: %v", r.Err)
}
return r.IP
}
+2 -2
View File
@@ -14,7 +14,7 @@ type NewBinanceTraderResult struct {
func (r *NewBinanceTraderResult) Error() error {
if r.Err != nil {
log.Printf("⚠️ 执行NewBinanceTraderResult时出错: %v", r.Err)
log.Printf("⚠️ Error executing NewBinanceTraderResult: %v", r.Err)
}
return r.Err
}
@@ -31,7 +31,7 @@ type NewAsterTraderResult struct {
func (r *NewAsterTraderResult) Error() error {
if r.Err != nil {
log.Printf("⚠️ 执行NewAsterTraderResult时出错: %v", r.Err)
log.Printf("⚠️ Error executing NewAsterTraderResult: %v", r.Err)
}
return r.Err
}
+3 -3
View File
@@ -1,11 +1,11 @@
package logger
// Config 日志配置(简化版)
// Config is the logger configuration (simplified version)
type Config struct {
Level string `json:"level"` // 日志级别: debug, info, warn, error (默认: info)
Level string `json:"level"` // Log level: debug, info, warn, error (default: info)
}
// SetDefaults 设置默认值
// SetDefaults sets default values
func (c *Config) SetDefaults() {
if c.Level == "" {
c.Level = "info"
+22 -22
View File
@@ -7,12 +7,12 @@ import (
)
var (
// Log 全局logger实例
// Log is the global logger instance
Log *logrus.Logger
)
func init() {
// 自动初始化默认 logger,确保在 Init 被调用前也能使用
// Auto-initialize default logger to ensure it works before Init is called
Log = logrus.New()
Log.SetLevel(logrus.InfoLevel)
Log.SetFormatter(&logrus.TextFormatter{
@@ -24,66 +24,66 @@ func init() {
}
// ============================================================================
// 初始化函数
// Initialization functions
// ============================================================================
// Init 初始化全局logger
// 如果config为nil,使用默认配置(console输出,info级别)
// Init initializes the global logger
// If config is nil, uses default configuration (console output, info level)
func Init(cfg *Config) error {
Log = logrus.New()
// 如果没有配置,使用默认值
// Use default values if no config provided
if cfg == nil {
cfg = &Config{Level: "info"}
}
// 设置默认值
// Set default values
cfg.SetDefaults()
// 设置日志级别
// Set log level
level, err := logrus.ParseLevel(cfg.Level)
if err != nil {
level = logrus.InfoLevel
}
Log.SetLevel(level)
// 设置格式化器(固定使用彩色文本格式)
// Set formatter (always use colored text format)
Log.SetFormatter(&logrus.TextFormatter{
FullTimestamp: true,
TimestampFormat: "2006-01-02 15:04:05",
ForceColors: true,
})
// 设置输出目标(默认stdout
// Set output target (default stdout)
Log.SetOutput(os.Stdout)
// 启用调用位置信息
// Enable caller location info
Log.SetReportCaller(true)
return nil
}
// InitWithSimpleConfig 使用简化配置初始化logger
// 适用于只需要基本功能的场景
// InitWithSimpleConfig initializes logger with simplified config
// Suitable for scenarios that only need basic functionality
func InitWithSimpleConfig(level string) error {
return Init(&Config{Level: level})
}
// Shutdown 优雅关闭logger
// Shutdown gracefully shuts down the logger
func Shutdown() {
// 预留用于未来扩展
// Reserved for future extensions
}
// ============================================================================
// 日志记录函数
// Logging functions
// ============================================================================
// WithFields 创建带字段的logger entry
// WithFields creates logger entry with fields
func WithFields(fields logrus.Fields) *logrus.Entry {
return Log.WithFields(fields)
}
// WithField 创建带单个字段的logger entry
// WithField creates logger entry with a single field
func WithField(key string, value interface{}) *logrus.Entry {
return Log.WithField(key, value)
}
@@ -138,14 +138,14 @@ func Panicf(format string, args ...interface{}) {
}
// ============================================================================
// MCP Logger 适配器
// MCP Logger adapter
// ============================================================================
// MCPLogger 适配器,使 MCP 包使用全局 logger
// 实现 mcp.Logger 接口
// MCPLogger adapter that allows MCP package to use the global logger
// Implements mcp.Logger interface
type MCPLogger struct{}
// NewMCPLogger 创建 MCP 日志适配器
// NewMCPLogger creates MCP log adapter
func NewMCPLogger() *MCPLogger {
return &MCPLogger{}
}
+39 -39
View File
@@ -19,40 +19,40 @@ import (
)
func main() {
// 加载 .env 环境变量
// Load .env environment variables
_ = godotenv.Load()
// 初始化日志
// Initialize logger
logger.Init(nil)
logger.Info("╔════════════════════════════════════════════════════════════╗")
logger.Info("║ 🤖 AI多模型交易系统 - 支持 DeepSeek & Qwen ║")
logger.Info("║ 🤖 AI Multi-Model Trading System - DeepSeek & Qwen ║")
logger.Info("╚════════════════════════════════════════════════════════════╝")
// 初始化全局配置(从 .env 加载)
// Initialize global configuration (loaded from .env)
config.Init()
cfg := config.Get()
logger.Info("✅ 配置加载完成")
logger.Info("✅ Configuration loaded")
// 初始化数据库
// Initialize database
dbPath := "data.db"
if len(os.Args) > 1 {
dbPath = os.Args[1]
}
logger.Infof("📋 初始化数据库: %s", dbPath)
logger.Infof("📋 Initializing database: %s", dbPath)
st, err := store.New(dbPath)
if err != nil {
logger.Fatalf("❌ 初始化数据库失败: %v", err)
logger.Fatalf("❌ Failed to initialize database: %v", err)
}
defer st.Close()
backtest.UseDatabase(st.DB())
// 初始化加密服务
logger.Info("🔐 初始化加密服务...")
// Initialize encryption service
logger.Info("🔐 Initializing encryption service...")
cryptoService, err := crypto.NewCryptoService()
if err != nil {
logger.Fatalf("❌ 初始化加密服务失败: %v", err)
logger.Fatalf("❌ Failed to initialize encryption service: %v", err)
}
encryptFunc := func(plaintext string) string {
if plaintext == "" {
@@ -60,7 +60,7 @@ func main() {
}
encrypted, err := cryptoService.EncryptForStorage(plaintext)
if err != nil {
logger.Warnf("⚠️ 加密失败: %v", err)
logger.Warnf("⚠️ Encryption failed: %v", err)
return plaintext
}
return encrypted
@@ -74,83 +74,83 @@ func main() {
}
decrypted, err := cryptoService.DecryptFromStorage(encrypted)
if err != nil {
logger.Warnf("⚠️ 解密失败: %v", err)
logger.Warnf("⚠️ Decryption failed: %v", err)
return encrypted
}
return decrypted
}
st.SetCryptoFuncs(encryptFunc, decryptFunc)
logger.Info("✅ 加密服务初始化成功")
logger.Info("✅ Encryption service initialized successfully")
// 设置 JWT 密钥
// Set JWT secret
auth.SetJWTSecret(cfg.JWTSecret)
logger.Info("🔑 JWT 密钥已设置")
logger.Info("🔑 JWT secret configured")
// 创建 TraderManager BacktestManager
// Create TraderManager and BacktestManager
traderManager := manager.NewTraderManager()
mcpClient := newSharedMCPClient()
backtestManager := backtest.NewManager(mcpClient)
if err := backtestManager.RestoreRuns(); err != nil {
logger.Warnf("⚠️ 恢复历史回测失败: %v", err)
logger.Warnf("⚠️ Failed to restore backtest history: %v", err)
}
// 从数据库加载所有交易员到内存
// Load all traders from database to memory
if err := traderManager.LoadTradersFromStore(st); err != nil {
logger.Fatalf("❌ 加载交易员失败: %v", err)
logger.Fatalf("❌ Failed to load traders: %v", err)
}
// 显示加载的交易员信息
// Display loaded trader information
traders, err := st.Trader().List("default")
if err != nil {
logger.Fatalf("❌ 获取交易员列表失败: %v", err)
logger.Fatalf("❌ Failed to get trader list: %v", err)
}
logger.Info("🤖 数据库中的AI交易员配置:")
logger.Info("🤖 AI Trader Configurations in Database:")
if len(traders) == 0 {
logger.Info(" (无交易员配置,请通过Web管理界面创建)")
logger.Info(" (No trader configurations, please create via Web interface)")
} else {
for _, t := range traders {
status := "❌ 已停止"
status := "❌ Stopped"
if t.IsRunning {
status = "✅ 运行中"
status = "✅ Running"
}
logger.Infof(" • %s [%s] %s - AI模型: %s, 交易所: %s",
logger.Infof(" • %s [%s] %s - AI Model: %s, Exchange: %s",
t.Name, t.ID[:8], status, t.AIModelID, t.ExchangeID)
}
}
// 启动 WebSocket 行情监控(获取所有 USDT 永续合约的行情数据)
// Start WebSocket market monitor (get market data for all USDT perpetual contracts)
go market.NewWSMonitor(150).Start(nil)
logger.Info("📊 WebSocket 行情监控已启动")
logger.Info("📊 WebSocket market monitor started")
// 启动API服务器
// Start API server
server := api.NewServer(traderManager, st, cryptoService, backtestManager, cfg.APIServerPort)
go func() {
if err := server.Start(); err != nil {
logger.Fatalf("❌ API服务器启动失败: %v", err)
logger.Fatalf("❌ Failed to start API server: %v", err)
}
}()
// 等待中断信号
// Wait for interrupt signal
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
logger.Info("✅ 系统启动完成,等待交易指令...")
logger.Info("📌 提示: 使用 Ctrl+C 停止系统")
logger.Info("✅ System started successfully, waiting for trading commands...")
logger.Info("📌 Tip: Use Ctrl+C to stop the system")
<-quit
logger.Info("📴 收到停止信号,正在关闭系统...")
logger.Info("📴 Shutdown signal received, closing system...")
// 停止所有交易员
// Stop all traders
traderManager.StopAll()
logger.Info("✅ 系统已安全关闭")
logger.Info("✅ System shut down safely")
}
// newSharedMCPClient 创建共享的 MCP AI 客户端(用于回测)
// newSharedMCPClient creates a shared MCP AI client (for backtesting)
func newSharedMCPClient() mcp.AIClient {
apiKey := os.Getenv("DEEPSEEK_API_KEY")
if apiKey == "" {
logger.Warn("⚠️ DEEPSEEK_API_KEY 未设置,AI 功能将不可用")
logger.Warn("⚠️ DEEPSEEK_API_KEY not set, AI features will be unavailable")
return nil
}
return mcp.NewDeepSeekClient()
+116 -116
View File
@@ -11,21 +11,21 @@ import (
"time"
)
// CompetitionCache 竞赛数据缓存
// CompetitionCache competition data cache
type CompetitionCache struct {
data map[string]interface{}
timestamp time.Time
mu sync.RWMutex
}
// TraderManager 管理多个trader实例
// TraderManager manages multiple trader instances
type TraderManager struct {
traders map[string]*trader.AutoTrader // key: trader ID
competitionCache *CompetitionCache
mu sync.RWMutex
}
// NewTraderManager 创建trader管理器
// NewTraderManager creates a trader manager
func NewTraderManager() *TraderManager {
return &TraderManager{
traders: make(map[string]*trader.AutoTrader),
@@ -35,19 +35,19 @@ func NewTraderManager() *TraderManager {
}
}
// GetTrader 获取指定ID的trader
// GetTrader retrieves a trader by ID
func (tm *TraderManager) GetTrader(id string) (*trader.AutoTrader, error) {
tm.mu.RLock()
defer tm.mu.RUnlock()
t, exists := tm.traders[id]
if !exists {
return nil, fmt.Errorf("trader ID '%s' 不存在", id)
return nil, fmt.Errorf("trader ID '%s' does not exist", id)
}
return t, nil
}
// GetAllTraders 获取所有trader
// GetAllTraders retrieves all traders
func (tm *TraderManager) GetAllTraders() map[string]*trader.AutoTrader {
tm.mu.RLock()
defer tm.mu.RUnlock()
@@ -59,7 +59,7 @@ func (tm *TraderManager) GetAllTraders() map[string]*trader.AutoTrader {
return result
}
// GetTraderIDs 获取所有trader ID列表
// GetTraderIDs retrieves all trader IDs
func (tm *TraderManager) GetTraderIDs() []string {
tm.mu.RLock()
defer tm.mu.RUnlock()
@@ -71,43 +71,43 @@ func (tm *TraderManager) GetTraderIDs() []string {
return ids
}
// StartAll 启动所有trader
// StartAll starts all traders
func (tm *TraderManager) StartAll() {
tm.mu.RLock()
defer tm.mu.RUnlock()
logger.Info("🚀 启动所有Trader...")
logger.Info("🚀 Starting all traders...")
for id, t := range tm.traders {
go func(traderID string, at *trader.AutoTrader) {
logger.Infof("▶️ 启动 %s...", at.GetName())
logger.Infof("▶️ Starting %s...", at.GetName())
if err := at.Run(); err != nil {
logger.Infof("❌ %s 运行错误: %v", at.GetName(), err)
logger.Infof("❌ %s runtime error: %v", at.GetName(), err)
}
}(id, t)
}
}
// StopAll 停止所有trader
// StopAll stops all traders
func (tm *TraderManager) StopAll() {
tm.mu.RLock()
defer tm.mu.RUnlock()
logger.Info("⏹ 停止所有Trader...")
logger.Info("⏹ Stopping all traders...")
for _, t := range tm.traders {
t.Stop()
}
}
// AutoStartRunningTraders 自动启动数据库中标记为运行中的交易员
// AutoStartRunningTraders automatically starts traders marked as running in the database
func (tm *TraderManager) AutoStartRunningTraders(st *store.Store) {
// 先获取所有交易员配置(一次性查询)
// Get all trader configurations (single query)
traderList, err := st.Trader().ListAll()
if err != nil {
logger.Infof("⚠️ 获取交易员列表失败: %v", err)
logger.Infof("⚠️ Failed to get trader list: %v", err)
return
}
// 构建运行中交易员的ID集合
// Build set of running trader IDs
runningTraderIDs := make(map[string]bool)
for _, traderCfg := range traderList {
if traderCfg.IsRunning {
@@ -116,7 +116,7 @@ func (tm *TraderManager) AutoStartRunningTraders(st *store.Store) {
}
if len(runningTraderIDs) == 0 {
logger.Info("📋 没有需要自动恢复的交易员")
logger.Info("📋 No traders to auto-restore")
return
}
@@ -127,9 +127,9 @@ func (tm *TraderManager) AutoStartRunningTraders(st *store.Store) {
for id, t := range tm.traders {
if runningTraderIDs[id] {
go func(traderID string, at *trader.AutoTrader) {
logger.Infof("▶️ 自动恢复启动 %s...", at.GetName())
logger.Infof("▶️ Auto-restoring %s...", at.GetName())
if err := at.Run(); err != nil {
logger.Infof("❌ %s 运行错误: %v", at.GetName(), err)
logger.Infof("❌ %s runtime error: %v", at.GetName(), err)
}
}(id, t)
startedCount++
@@ -137,11 +137,11 @@ func (tm *TraderManager) AutoStartRunningTraders(st *store.Store) {
}
if startedCount > 0 {
logger.Infof("✓ 自动恢复启动了 %d 个交易员", startedCount)
logger.Infof("✓ Auto-restored %d traders", startedCount)
}
}
// GetComparisonData 获取对比数据
// GetComparisonData retrieves comparison data
func (tm *TraderManager) GetComparisonData() (map[string]interface{}, error) {
tm.mu.RLock()
defer tm.mu.RUnlock()
@@ -178,38 +178,38 @@ func (tm *TraderManager) GetComparisonData() (map[string]interface{}, error) {
return comparison, nil
}
// GetCompetitionData 获取竞赛数据(全平台所有交易员)
// GetCompetitionData retrieves competition data (all traders across platform)
func (tm *TraderManager) GetCompetitionData() (map[string]interface{}, error) {
// 检查缓存是否有效(30秒内)
// Check if cache is valid (within 30 seconds)
tm.competitionCache.mu.RLock()
if time.Since(tm.competitionCache.timestamp) < 30*time.Second && len(tm.competitionCache.data) > 0 {
// 返回缓存数据
// Return cached data
cachedData := make(map[string]interface{})
for k, v := range tm.competitionCache.data {
cachedData[k] = v
}
tm.competitionCache.mu.RUnlock()
logger.Infof("📋 返回竞赛数据缓存 (缓存时间: %.1fs)", time.Since(tm.competitionCache.timestamp).Seconds())
logger.Infof("📋 Returning competition data cache (cache age: %.1fs)", time.Since(tm.competitionCache.timestamp).Seconds())
return cachedData, nil
}
tm.competitionCache.mu.RUnlock()
tm.mu.RLock()
// 获取所有交易员列表
// Get all trader list
allTraders := make([]*trader.AutoTrader, 0, len(tm.traders))
for id, t := range tm.traders {
allTraders = append(allTraders, t)
logger.Infof("📋 竞赛数据包含交易员: %s (%s)", t.GetName(), id)
logger.Infof("📋 Competition data includes trader: %s (%s)", t.GetName(), id)
}
tm.mu.RUnlock()
logger.Infof("🔄 重新获取竞赛数据,交易员数量: %d", len(allTraders))
logger.Infof("🔄 Refreshing competition data, trader count: %d", len(allTraders))
// 并发获取交易员数据
// Concurrently fetch trader data
traders := tm.getConcurrentTraderData(allTraders)
// 按收益率排序(降序)
// Sort by profit rate (descending)
sort.Slice(traders, func(i, j int) bool {
pnlPctI, okI := traders[i]["total_pnl_pct"].(float64)
pnlPctJ, okJ := traders[j]["total_pnl_pct"].(float64)
@@ -222,7 +222,7 @@ func (tm *TraderManager) GetCompetitionData() (map[string]interface{}, error) {
return pnlPctI > pnlPctJ
})
// 限制返回前50
// Limit to top 50
totalCount := len(traders)
limit := 50
if len(traders) > limit {
@@ -232,9 +232,9 @@ func (tm *TraderManager) GetCompetitionData() (map[string]interface{}, error) {
comparison := make(map[string]interface{})
comparison["traders"] = traders
comparison["count"] = len(traders)
comparison["total_count"] = totalCount // 总交易员数量
comparison["total_count"] = totalCount // Total number of traders
// 更新缓存
// Update cache
tm.competitionCache.mu.Lock()
tm.competitionCache.data = comparison
tm.competitionCache.timestamp = time.Now()
@@ -243,24 +243,24 @@ func (tm *TraderManager) GetCompetitionData() (map[string]interface{}, error) {
return comparison, nil
}
// getConcurrentTraderData 并发获取多个交易员的数据
// getConcurrentTraderData concurrently fetches data for multiple traders
func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) []map[string]interface{} {
type traderResult struct {
index int
data map[string]interface{}
}
// 创建结果通道
// Create result channel
resultChan := make(chan traderResult, len(traders))
// 并发获取每个交易员的数据
// Concurrently fetch data for each trader
for i, t := range traders {
go func(index int, trader *trader.AutoTrader) {
// 设置单个交易员的超时时间为3秒
// Set timeout to 3 seconds for single trader
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
// 使用通道来实现超时控制
// Use channel for timeout control
accountChan := make(chan map[string]interface{}, 1)
errorChan := make(chan error, 1)
@@ -278,7 +278,7 @@ func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) [
select {
case account := <-accountChan:
// 成功获取账户信息
// Successfully got account info
traderData = map[string]interface{}{
"trader_id": trader.GetID(),
"trader_name": trader.GetName(),
@@ -293,8 +293,8 @@ func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) [
"system_prompt_template": trader.GetSystemPromptTemplate(),
}
case err := <-errorChan:
// 获取账户信息失败
logger.Infof("⚠️ 获取交易员 %s 账户信息失败: %v", trader.GetID(), err)
// Failed to get account info
logger.Infof("⚠️ Failed to get account info for trader %s: %v", trader.GetID(), err)
traderData = map[string]interface{}{
"trader_id": trader.GetID(),
"trader_name": trader.GetName(),
@@ -307,11 +307,11 @@ func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) [
"margin_used_pct": 0.0,
"is_running": status["is_running"],
"system_prompt_template": trader.GetSystemPromptTemplate(),
"error": "账户数据获取失败",
"error": "Failed to get account data",
}
case <-ctx.Done():
// 超时
logger.Infof("⏰ 获取交易员 %s 账户信息超时", trader.GetID())
// Timeout
logger.Infof("⏰ Timeout getting account info for trader %s", trader.GetID())
traderData = map[string]interface{}{
"trader_id": trader.GetID(),
"trader_name": trader.GetName(),
@@ -324,7 +324,7 @@ func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) [
"margin_used_pct": 0.0,
"is_running": status["is_running"],
"system_prompt_template": trader.GetSystemPromptTemplate(),
"error": "获取超时",
"error": "Request timeout",
}
}
@@ -332,7 +332,7 @@ func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) [
}(i, t)
}
// 收集所有结果
// Collect all results
results := make([]map[string]interface{}, len(traders))
for i := 0; i < len(traders); i++ {
result := <-resultChan
@@ -342,21 +342,21 @@ func (tm *TraderManager) getConcurrentTraderData(traders []*trader.AutoTrader) [
return results
}
// GetTopTradersData 获取前5名交易员数据(用于表现对比)
// GetTopTradersData retrieves top 5 traders data (for performance comparison)
func (tm *TraderManager) GetTopTradersData() (map[string]interface{}, error) {
// 复用竞赛数据缓存,因为前5名是从全部数据中筛选出来的
// Reuse competition data cache, as top 5 is filtered from all data
competitionData, err := tm.GetCompetitionData()
if err != nil {
return nil, err
}
// 从竞赛数据中提取前5名
// Extract top 5 from competition data
allTraders, ok := competitionData["traders"].([]map[string]interface{})
if !ok {
return nil, fmt.Errorf("竞赛数据格式错误")
return nil, fmt.Errorf("invalid competition data format")
}
// 限制返回前5名
// Limit to top 5
limit := 5
topTraders := allTraders
if len(allTraders) > limit {
@@ -372,53 +372,53 @@ func (tm *TraderManager) GetTopTradersData() (map[string]interface{}, error) {
}
// RemoveTrader 从内存中移除指定的trader(不影响数据库)
// 用于更新trader配置时强制重新加载
// RemoveTrader removes a trader from memory (does not affect database)
// Used to force reload when updating trader configuration
func (tm *TraderManager) RemoveTrader(traderID string) {
tm.mu.Lock()
defer tm.mu.Unlock()
if _, exists := tm.traders[traderID]; exists {
delete(tm.traders, traderID)
logger.Infof("✓ Trader %s 已从内存中移除", traderID)
logger.Infof("✓ Trader %s removed from memory", traderID)
}
}
// LoadUserTradersFromStore 为特定用户从store加载交易员到内存
// LoadUserTradersFromStore loads traders from store for a specific user to memory
func (tm *TraderManager) LoadUserTradersFromStore(st *store.Store, userID string) error {
tm.mu.Lock()
defer tm.mu.Unlock()
// 获取指定用户的所有交易员
// Get all traders for the specified user
traders, err := st.Trader().List(userID)
if err != nil {
return fmt.Errorf("获取用户 %s 的交易员列表失败: %w", userID, err)
return fmt.Errorf("failed to get trader list for user %s: %w", userID, err)
}
logger.Infof("📋 为用户 %s 加载交易员配置: %d 个", userID, len(traders))
logger.Infof("📋 Loading trader configurations for user %s: %d traders", userID, len(traders))
// 获取AI模型和交易所列表(在循环外只查询一次)
// Get AI model and exchange lists (query only once outside loop)
aiModels, err := st.AIModel().List(userID)
if err != nil {
logger.Infof("⚠️ 获取用户 %s 的AI模型配置失败: %v", userID, err)
return fmt.Errorf("获取AI模型配置失败: %w", err)
logger.Infof("⚠️ Failed to get AI model config for user %s: %v", userID, err)
return fmt.Errorf("failed to get AI model config: %w", err)
}
exchanges, err := st.Exchange().List(userID)
if err != nil {
logger.Infof("⚠️ 获取用户 %s 的交易所配置失败: %v", userID, err)
return fmt.Errorf("获取交易所配置失败: %w", err)
logger.Infof("⚠️ Failed to get exchange config for user %s: %v", userID, err)
return fmt.Errorf("failed to get exchange config: %w", err)
}
// 为每个交易员加载配置
// Load configuration for each trader
for _, traderCfg := range traders {
// 检查是否已经加载过这个交易员
// Check if this trader is already loaded
if _, exists := tm.traders[traderCfg.ID]; exists {
logger.Infof("⚠️ 交易员 %s 已经加载,跳过", traderCfg.Name)
logger.Infof("⚠️ Trader %s already loaded, skipping", traderCfg.Name)
continue
}
// 从已查询的列表中查找AI模型配置
// Find AI model config from already queried list
var aiModelCfg *store.AIModel
for _, model := range aiModels {
if model.ID == traderCfg.AIModelID {
@@ -436,16 +436,16 @@ func (tm *TraderManager) LoadUserTradersFromStore(st *store.Store, userID string
}
if aiModelCfg == nil {
logger.Infof("⚠️ 交易员 %s 的AI模型 %s 不存在,跳过", traderCfg.Name, traderCfg.AIModelID)
logger.Infof("⚠️ AI model %s for trader %s does not exist, skipping", traderCfg.AIModelID, traderCfg.Name)
continue
}
if !aiModelCfg.Enabled {
logger.Infof("⚠️ 交易员 %s 的AI模型 %s 未启用,跳过", traderCfg.Name, traderCfg.AIModelID)
logger.Infof("⚠️ AI model %s for trader %s is not enabled, skipping", traderCfg.AIModelID, traderCfg.Name)
continue
}
// 从已查询的列表中查找交易所配置
// Find exchange config from already queried list
var exchangeCfg *store.Exchange
for _, exchange := range exchanges {
if exchange.ID == traderCfg.ExchangeID {
@@ -455,95 +455,95 @@ func (tm *TraderManager) LoadUserTradersFromStore(st *store.Store, userID string
}
if exchangeCfg == nil {
logger.Infof("⚠️ 交易员 %s 的交易所 %s 不存在,跳过", traderCfg.Name, traderCfg.ExchangeID)
logger.Infof("⚠️ Exchange %s for trader %s does not exist, skipping", traderCfg.ExchangeID, traderCfg.Name)
continue
}
if !exchangeCfg.Enabled {
logger.Infof("⚠️ 交易员 %s 的交易所 %s 未启用,跳过", traderCfg.Name, traderCfg.ExchangeID)
logger.Infof("⚠️ Exchange %s for trader %s is not enabled, skipping", traderCfg.ExchangeID, traderCfg.Name)
continue
}
// 使用现有的方法加载交易员
logger.Infof("📦 正在加载交易员 %s (AI模型: %s, 交易所: %s, 策略ID: %s)", traderCfg.Name, aiModelCfg.Provider, exchangeCfg.ID, traderCfg.StrategyID)
// Use existing method to load trader
logger.Infof("📦 Loading trader %s (AI Model: %s, Exchange: %s, Strategy ID: %s)", traderCfg.Name, aiModelCfg.Provider, exchangeCfg.ID, traderCfg.StrategyID)
err = tm.addTraderFromStore(traderCfg, aiModelCfg, exchangeCfg, st)
if err != nil {
logger.Infof("❌ 加载交易员 %s 失败: %v", traderCfg.Name, err)
logger.Infof("❌ Failed to load trader %s: %v", traderCfg.Name, err)
}
}
return nil
}
// LoadTradersFromStore 从store加载所有交易员到内存(新版API
// LoadTradersFromStore loads all traders from store to memory (new API)
func (tm *TraderManager) LoadTradersFromStore(st *store.Store) error {
tm.mu.Lock()
defer tm.mu.Unlock()
// 获取所有用户
// Get all users
userIDs, err := st.User().GetAllIDs()
if err != nil {
return fmt.Errorf("获取用户列表失败: %w", err)
return fmt.Errorf("failed to get user list: %w", err)
}
logger.Infof("📋 发现 %d 个用户,开始加载所有交易员配置...", len(userIDs))
logger.Infof("📋 Found %d users, loading all trader configurations...", len(userIDs))
var allTraders []*store.Trader
for _, userID := range userIDs {
// 获取每个用户的交易员
// Get traders for each user
traders, err := st.Trader().List(userID)
if err != nil {
logger.Infof("⚠️ 获取用户 %s 的交易员失败: %v", userID, err)
logger.Infof("⚠️ Failed to get traders for user %s: %v", userID, err)
continue
}
logger.Infof("📋 用户 %s: %d 个交易员", userID, len(traders))
logger.Infof("📋 User %s: %d traders", userID, len(traders))
allTraders = append(allTraders, traders...)
}
logger.Infof("📋 总共加载 %d 个交易员配置", len(allTraders))
logger.Infof("📋 Total loaded trader configurations: %d", len(allTraders))
// 为每个交易员获取AI模型和交易所配置
// Get AI model and exchange configs for each trader
for _, traderCfg := range allTraders {
// 获取AI模型配置
// Get AI model config
aiModels, err := st.AIModel().List(traderCfg.UserID)
if err != nil {
logger.Infof("⚠️ 获取AI模型配置失败: %v", err)
logger.Infof("⚠️ Failed to get AI model config: %v", err)
continue
}
var aiModelCfg *store.AIModel
// 优先精确匹配 model.ID
// Prioritize exact match on model.ID
for _, model := range aiModels {
if model.ID == traderCfg.AIModelID {
aiModelCfg = model
break
}
}
// 如果没有精确匹配,尝试匹配 provider(兼容旧数据)
// If no exact match, try matching provider (for backward compatibility)
if aiModelCfg == nil {
for _, model := range aiModels {
if model.Provider == traderCfg.AIModelID {
aiModelCfg = model
logger.Infof("⚠️ 交易员 %s 使用旧版 provider 匹配: %s -> %s", traderCfg.Name, traderCfg.AIModelID, model.ID)
logger.Infof("⚠️ Trader %s using legacy provider match: %s -> %s", traderCfg.Name, traderCfg.AIModelID, model.ID)
break
}
}
}
if aiModelCfg == nil {
logger.Infof("⚠️ 交易员 %s 的AI模型 %s 不存在,跳过", traderCfg.Name, traderCfg.AIModelID)
logger.Infof("⚠️ AI model %s for trader %s does not exist, skipping", traderCfg.AIModelID, traderCfg.Name)
continue
}
if !aiModelCfg.Enabled {
logger.Infof("⚠️ 交易员 %s 的AI模型 %s 未启用,跳过", traderCfg.Name, traderCfg.AIModelID)
logger.Infof("⚠️ AI model %s for trader %s is not enabled, skipping", traderCfg.AIModelID, traderCfg.Name)
continue
}
// 获取交易所配置
// Get exchange config
exchanges, err := st.Exchange().List(traderCfg.UserID)
if err != nil {
logger.Infof("⚠️ 获取交易所配置失败: %v", err)
logger.Infof("⚠️ Failed to get exchange config: %v", err)
continue
}
@@ -556,51 +556,51 @@ func (tm *TraderManager) LoadTradersFromStore(st *store.Store) error {
}
if exchangeCfg == nil {
logger.Infof("⚠️ 交易员 %s 的交易所 %s 不存在,跳过", traderCfg.Name, traderCfg.ExchangeID)
logger.Infof("⚠️ Exchange %s for trader %s does not exist, skipping", traderCfg.ExchangeID, traderCfg.Name)
continue
}
if !exchangeCfg.Enabled {
logger.Infof("⚠️ 交易员 %s 的交易所 %s 未启用,跳过", traderCfg.Name, traderCfg.ExchangeID)
logger.Infof("⚠️ Exchange %s for trader %s is not enabled, skipping", traderCfg.ExchangeID, traderCfg.Name)
continue
}
// 添加到TraderManagercoinPoolURL/oiTopURL 已从策略配置中获取)
// Add to TraderManager (coinPoolURL/oiTopURL already obtained from strategy config)
err = tm.addTraderFromStore(traderCfg, aiModelCfg, exchangeCfg, st)
if err != nil {
logger.Infof("❌ 添加交易员 %s 失败: %v", traderCfg.Name, err)
logger.Infof("❌ Failed to add trader %s: %v", traderCfg.Name, err)
continue
}
}
logger.Infof("✓ 成功加载 %d 个交易员到内存", len(tm.traders))
logger.Infof("✓ Successfully loaded %d traders to memory", len(tm.traders))
return nil
}
// addTraderFromStore 内部方法:从store配置添加交易员
// addTraderFromStore internal method: adds trader from store configuration
func (tm *TraderManager) addTraderFromStore(traderCfg *store.Trader, aiModelCfg *store.AIModel, exchangeCfg *store.Exchange, st *store.Store) error {
if _, exists := tm.traders[traderCfg.ID]; exists {
return fmt.Errorf("trader ID '%s' 已存在", traderCfg.ID)
return fmt.Errorf("trader ID '%s' already exists", traderCfg.ID)
}
// 加载策略配置(必须有策略)
// Load strategy config (must have strategy)
var strategyConfig *store.StrategyConfig
if traderCfg.StrategyID != "" {
strategy, err := st.Strategy().Get(traderCfg.UserID, traderCfg.StrategyID)
if err != nil {
return fmt.Errorf("交易员 %s 的策略 %s 加载失败: %w", traderCfg.Name, traderCfg.StrategyID, err)
return fmt.Errorf("failed to load strategy %s for trader %s: %w", traderCfg.StrategyID, traderCfg.Name, err)
}
// 解析 JSON 配置
// Parse JSON config
strategyConfig, err = strategy.ParseConfig()
if err != nil {
return fmt.Errorf("交易员 %s 的策略配置解析失败: %w", traderCfg.Name, err)
return fmt.Errorf("failed to parse strategy config for trader %s: %w", traderCfg.Name, err)
}
logger.Infof("✓ 交易员 %s 加载策略配置: %s", traderCfg.Name, strategy.Name)
logger.Infof("✓ Trader %s loaded strategy config: %s", traderCfg.Name, strategy.Name)
} else {
return fmt.Errorf("交易员 %s 未配置策略", traderCfg.Name)
return fmt.Errorf("trader %s has no strategy configured", traderCfg.Name)
}
// 构建AutoTraderConfigcoinPoolURL/oiTopURL 从策略配置获取,在 StrategyEngine 中使用)
// Build AutoTraderConfig (coinPoolURL/oiTopURL obtained from strategy config, used in StrategyEngine)
traderConfig := trader.AutoTraderConfig{
ID: traderCfg.ID,
Name: traderCfg.Name,
@@ -621,7 +621,7 @@ func (tm *TraderManager) addTraderFromStore(traderCfg *store.Trader, aiModelCfg
StrategyConfig: strategyConfig,
}
// 根据交易所类型设置API密钥
// Set API keys based on exchange type
switch exchangeCfg.ID {
case "binance":
traderConfig.BinanceAPIKey = exchangeCfg.APIKey
@@ -646,31 +646,31 @@ func (tm *TraderManager) addTraderFromStore(traderCfg *store.Trader, aiModelCfg
traderConfig.LighterTestnet = exchangeCfg.Testnet
}
// 根据AI模型设置API密钥
// Set API keys based on AI model
if aiModelCfg.Provider == "qwen" {
traderConfig.QwenKey = aiModelCfg.APIKey
} else if aiModelCfg.Provider == "deepseek" {
traderConfig.DeepSeekKey = aiModelCfg.APIKey
}
// 创建trader实例
// Create trader instance
at, err := trader.NewAutoTrader(traderConfig, st, traderCfg.UserID)
if err != nil {
return fmt.Errorf("创建trader失败: %w", err)
return fmt.Errorf("failed to create trader: %w", err)
}
// 设置自定义prompt(如果有)
// Set custom prompt (if exists)
if traderCfg.CustomPrompt != "" {
at.SetCustomPrompt(traderCfg.CustomPrompt)
at.SetOverrideBasePrompt(traderCfg.OverrideBasePrompt)
if traderCfg.OverrideBasePrompt {
logger.Infof("✓ 已设置自定义交易策略prompt (覆盖基础prompt)")
logger.Infof("✓ Set custom trading strategy prompt (overriding base prompt)")
} else {
logger.Infof("✓ 已设置自定义交易策略prompt (补充基础prompt)")
logger.Infof("✓ Set custom trading strategy prompt (supplementing base prompt)")
}
}
tm.traders[traderCfg.ID] = at
logger.Infof("✓ Trader '%s' (%s + %s) 已加载到内存", traderCfg.Name, aiModelCfg.Provider, exchangeCfg.ID)
logger.Infof("✓ Trader '%s' (%s + %s) loaded to memory", traderCfg.Name, aiModelCfg.Provider, exchangeCfg.ID)
return nil
}
+22 -22
View File
@@ -4,51 +4,51 @@ import (
"testing"
)
// TestRemoveTrader 测试从内存中移除trader
// TestRemoveTrader tests removing trader from memory
func TestRemoveTrader(t *testing.T) {
tm := NewTraderManager()
// 创建一个模拟的 trader 并添加到 map
// Create a mock trader and add it to map
traderID := "test-trader-123"
tm.traders[traderID] = nil // 使用 nil 作为占位符,实际测试中只需验证删除逻辑
tm.traders[traderID] = nil // Use nil as placeholder, only need to verify deletion logic in test
// 验证 trader 存在
// Verify trader exists
if _, exists := tm.traders[traderID]; !exists {
t.Fatal("trader 应该存在于 map")
t.Fatal("trader should exist in map")
}
// 调用 RemoveTrader
// Call RemoveTrader
tm.RemoveTrader(traderID)
// 验证 trader 已被移除
// Verify trader has been removed
if _, exists := tm.traders[traderID]; exists {
t.Error("trader 应该已从 map 中移除")
t.Error("trader should be removed from map")
}
}
// TestRemoveTrader_NonExistent 测试移除不存在的trader不会报错
// TestRemoveTrader_NonExistent tests that removing non-existent trader doesn't error
func TestRemoveTrader_NonExistent(t *testing.T) {
tm := NewTraderManager()
// 尝试移除不存在的 trader,不应该 panic
// Trying to remove non-existent trader should not panic
defer func() {
if r := recover(); r != nil {
t.Errorf("移除不存在的 trader 不应该 panic: %v", r)
t.Errorf("removing non-existent trader should not panic: %v", r)
}
}()
tm.RemoveTrader("non-existent-trader")
}
// TestRemoveTrader_Concurrent 测试并发移除trader的安全性
// TestRemoveTrader_Concurrent tests concurrent removal of trader safety
func TestRemoveTrader_Concurrent(t *testing.T) {
tm := NewTraderManager()
traderID := "test-trader-concurrent"
// 添加 trader
// Add trader
tm.traders[traderID] = nil
// 并发调用 RemoveTrader
// Concurrently call RemoveTrader
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func() {
@@ -57,31 +57,31 @@ func TestRemoveTrader_Concurrent(t *testing.T) {
}()
}
// 等待所有 goroutine 完成
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
<-done
}
// 验证 trader 已被移除
// Verify trader has been removed
if _, exists := tm.traders[traderID]; exists {
t.Error("trader 应该已从 map 中移除")
t.Error("trader should be removed from map")
}
}
// TestGetTrader_AfterRemove 测试移除后获取trader返回错误
// TestGetTrader_AfterRemove tests that getting trader after removal returns error
func TestGetTrader_AfterRemove(t *testing.T) {
tm := NewTraderManager()
traderID := "test-trader-get"
// 添加 trader
// Add trader
tm.traders[traderID] = nil
// 移除 trader
// Remove trader
tm.RemoveTrader(traderID)
// 尝试获取已移除的 trader
// Try to get removed trader
_, err := tm.GetTrader(traderID)
if err == nil {
t.Error("获取已移除的 trader 应该返回错误")
t.Error("getting removed trader should return error")
}
}
+4 -4
View File
@@ -26,7 +26,7 @@ func NewAPIClient() *APIClient {
hookRes := hook.HookExec[hook.SetHttpClientResult](hook.SET_HTTP_CLIENT, client)
if hookRes != nil && hookRes.Error() == nil {
log.Printf("使用Hook设置的HTTP客户端")
log.Printf("Using HTTP client set by Hook")
client = hookRes.GetResult()
}
@@ -83,7 +83,7 @@ func (c *APIClient) GetKlines(symbol, interval string, limit int) ([]Kline, erro
var klineResponses []KlineResponse
err = json.Unmarshal(body, &klineResponses)
if err != nil {
log.Printf("获取K线数据失败,响应内容: %s", string(body))
log.Printf("Failed to get K-line data, response content: %s", string(body))
return nil, err
}
@@ -91,7 +91,7 @@ func (c *APIClient) GetKlines(symbol, interval string, limit int) ([]Kline, erro
for _, kr := range klineResponses {
kline, err := parseKline(kr)
if err != nil {
log.Printf("解析K线数据失败: %v", err)
log.Printf("Failed to parse K-line data: %v", err)
continue
}
klines = append(klines, kline)
@@ -107,7 +107,7 @@ func parseKline(kr KlineResponse) (Kline, error) {
return kline, fmt.Errorf("invalid kline data")
}
// 解析各个字段
// Parse each field
kline.OpenTime = int64(kr[0].(float64))
kline.Open, _ = strconv.ParseFloat(kr[1].(string), 64)
kline.High, _ = strconv.ParseFloat(kr[2].(string), 64)
+18 -18
View File
@@ -17,7 +17,7 @@ type CombinedStreamsClient struct {
subscribers map[string]chan []byte
reconnect bool
done chan struct{}
batchSize int // 每批订阅的流数量
batchSize int // Number of streams per batch subscription
}
func NewCombinedStreamsClient(batchSize int) *CombinedStreamsClient {
@@ -34,29 +34,29 @@ func (c *CombinedStreamsClient) Connect() error {
HandshakeTimeout: 10 * time.Second,
}
// 组合流使用不同的端点
// Combined streams use a different endpoint
conn, _, err := dialer.Dial("wss://fstream.binance.com/stream", nil)
if err != nil {
return fmt.Errorf("组合流WebSocket连接失败: %v", err)
return fmt.Errorf("Combined stream WebSocket connection failed: %v", err)
}
c.mu.Lock()
c.conn = conn
c.mu.Unlock()
log.Println("组合流WebSocket连接成功")
log.Println("Combined stream WebSocket connected successfully")
go c.readMessages()
return nil
}
// BatchSubscribeKlines 批量订阅K线
// BatchSubscribeKlines subscribes to K-lines in batches
func (c *CombinedStreamsClient) BatchSubscribeKlines(symbols []string, interval string) error {
// symbols分批处理
// Split symbols into batches
batches := c.splitIntoBatches(symbols, c.batchSize)
for i, batch := range batches {
log.Printf("订阅第 %d 批, 数量: %d", i+1, len(batch))
log.Printf("Subscribing batch %d, count: %d", i+1, len(batch))
streams := make([]string, len(batch))
for j, symbol := range batch {
@@ -64,10 +64,10 @@ func (c *CombinedStreamsClient) BatchSubscribeKlines(symbols []string, interval
}
if err := c.subscribeStreams(streams); err != nil {
return fmt.Errorf("第 %d 批订阅失败: %v", i+1, err)
return fmt.Errorf("Batch %d subscription failed: %v", i+1, err)
}
// 批次间延迟,避免被限制
// Delay between batches to avoid rate limiting
if i < len(batches)-1 {
time.Sleep(100 * time.Millisecond)
}
@@ -76,7 +76,7 @@ func (c *CombinedStreamsClient) BatchSubscribeKlines(symbols []string, interval
return nil
}
// splitIntoBatches 将切片分成指定大小的批次
// splitIntoBatches splits a slice into batches of specified size
func (c *CombinedStreamsClient) splitIntoBatches(symbols []string, batchSize int) [][]string {
var batches [][]string
@@ -91,7 +91,7 @@ func (c *CombinedStreamsClient) splitIntoBatches(symbols []string, batchSize int
return batches
}
// subscribeStreams 订阅多个流
// subscribeStreams subscribes to multiple streams
func (c *CombinedStreamsClient) subscribeStreams(streams []string) error {
subscribeMsg := map[string]interface{}{
"method": "SUBSCRIBE",
@@ -103,10 +103,10 @@ func (c *CombinedStreamsClient) subscribeStreams(streams []string) error {
defer c.mu.RUnlock()
if c.conn == nil {
return fmt.Errorf("WebSocket未连接")
return fmt.Errorf("WebSocket not connected")
}
log.Printf("订阅流: %v", streams)
log.Printf("Subscribing to streams: %v", streams)
return c.conn.WriteJSON(subscribeMsg)
}
@@ -127,7 +127,7 @@ func (c *CombinedStreamsClient) readMessages() {
_, message, err := conn.ReadMessage()
if err != nil {
log.Printf("读取组合流消息失败: %v", err)
log.Printf("Failed to read combined stream message: %v", err)
c.handleReconnect()
return
}
@@ -144,7 +144,7 @@ func (c *CombinedStreamsClient) handleCombinedMessage(message []byte) {
}
if err := json.Unmarshal(message, &combinedMsg); err != nil {
log.Printf("解析组合消息失败: %v", err)
log.Printf("Failed to parse combined message: %v", err)
return
}
@@ -156,7 +156,7 @@ func (c *CombinedStreamsClient) handleCombinedMessage(message []byte) {
select {
case ch <- combinedMsg.Data:
default:
log.Printf("订阅者通道已满: %s", combinedMsg.Stream)
log.Printf("Subscriber channel is full: %s", combinedMsg.Stream)
}
}
}
@@ -174,11 +174,11 @@ func (c *CombinedStreamsClient) handleReconnect() {
return
}
log.Println("组合流尝试重新连接...")
log.Println("Combined stream attempting to reconnect...")
time.Sleep(3 * time.Second)
if err := c.Connect(); err != nil {
log.Printf("组合流重新连接失败: %v", err)
log.Printf("Combined stream reconnection failed: %v", err)
go c.handleReconnect()
}
}
+110 -110
View File
@@ -12,8 +12,8 @@ import (
"time"
)
// FundingRateCache 资金费率缓存结构
// Binance Funding Rate 每 8 小时才更新一次,使用 1 小时缓存可显著减少 API 调用
// FundingRateCache is the funding rate cache structure
// Binance Funding Rate only updates every 8 hours, using 1-hour cache can significantly reduce API calls
type FundingRateCache struct {
Rate float64
UpdatedAt time.Time
@@ -24,16 +24,16 @@ var (
frCacheTTL = 1 * time.Hour
)
// Get 获取指定代币的市场数据
// Get retrieves market data for the specified token
func Get(symbol string) (*Data, error) {
var klines3m, klines4h []Kline
var err error
// 标准化symbol
// Normalize symbol
symbol = Normalize(symbol)
// 获取3分钟K线数据 (最近10)
klines3m, err = WSMonitorCli.GetCurrentKlines(symbol, "3m") // 多获取一些用于计算
// Get 3-minute K-line data (latest 10)
klines3m, err = WSMonitorCli.GetCurrentKlines(symbol, "3m") // Get more for calculation
if err != nil {
return nil, fmt.Errorf("获取3分钟K线失败: %v", err)
return nil, fmt.Errorf("Failed to get 3-minute K-line: %v", err)
}
// Data staleness detection: Prevent DOGEUSDT-style price freeze issues
@@ -42,37 +42,37 @@ func Get(symbol string) (*Data, error) {
return nil, fmt.Errorf("%s data is stale, possible cache failure", symbol)
}
// 获取4小时K线数据 (最近10)
klines4h, err = WSMonitorCli.GetCurrentKlines(symbol, "4h") // 多获取用于计算指标
// Get 4-hour K-line data (latest 10)
klines4h, err = WSMonitorCli.GetCurrentKlines(symbol, "4h") // Get more for indicator calculation
if err != nil {
return nil, fmt.Errorf("获取4小时K线失败: %v", err)
return nil, fmt.Errorf("Failed to get 4-hour K-line: %v", err)
}
// 检查数据是否为空
// Check if data is empty
if len(klines3m) == 0 {
return nil, fmt.Errorf("3分钟K线数据为空")
return nil, fmt.Errorf("3-minute K-line data is empty")
}
if len(klines4h) == 0 {
return nil, fmt.Errorf("4小时K线数据为空")
return nil, fmt.Errorf("4-hour K-line data is empty")
}
// 计算当前指标 (基于3分钟最新数据)
// Calculate current indicators (based on 3-minute latest data)
currentPrice := klines3m[len(klines3m)-1].Close
currentEMA20 := calculateEMA(klines3m, 20)
currentMACD := calculateMACD(klines3m)
currentRSI7 := calculateRSI(klines3m, 7)
// 计算价格变化百分比
// 1小时价格变化 = 20个3分钟K线前的价格
// Calculate price change percentage
// 1-hour price change = price from 20 3-minute K-lines ago
priceChange1h := 0.0
if len(klines3m) >= 21 { // 至少需要21根K线 (当前 + 20根前)
if len(klines3m) >= 21 { // Need at least 21 K-lines (current + 20 previous)
price1hAgo := klines3m[len(klines3m)-21].Close
if price1hAgo > 0 {
priceChange1h = ((currentPrice - price1hAgo) / price1hAgo) * 100
}
}
// 4小时价格变化 = 1个4小时K线前的价格
// 4-hour price change = price from 1 4-hour K-line ago
priceChange4h := 0.0
if len(klines4h) >= 2 {
price4hAgo := klines4h[len(klines4h)-2].Close
@@ -81,20 +81,20 @@ func Get(symbol string) (*Data, error) {
}
}
// 获取OI数据
// Get OI data
oiData, err := getOpenInterestData(symbol)
if err != nil {
// OI失败不影响整体,使用默认值
// OI failure doesn't affect overall result, use default values
oiData = &OIData{Latest: 0, Average: 0}
}
// 获取Funding Rate
// Get Funding Rate
fundingRate, _ := getFundingRate(symbol)
// 计算日内系列数据
// Calculate intraday series data
intradayData := calculateIntradaySeries(klines3m)
// 计算长期数据
// Calculate longer-term data
longerTermData := calculateLongerTermData(klines4h)
return &Data{
@@ -112,23 +112,23 @@ func Get(symbol string) (*Data, error) {
}, nil
}
// GetWithTimeframes 获取指定多个时间周期的市场数据
// timeframes: 时间周期列表,如 ["5m", "15m", "1h", "4h"]
// primaryTimeframe: 主时间周期(用于计算当前指标),默认使用 timeframes[0]
// count: 每个时间周期的 K 线数量
// GetWithTimeframes retrieves market data for specified multiple timeframes
// timeframes: list of timeframes, e.g. ["5m", "15m", "1h", "4h"]
// primaryTimeframe: primary timeframe (used for calculating current indicators), defaults to timeframes[0]
// count: number of K-lines for each timeframe
func GetWithTimeframes(symbol string, timeframes []string, primaryTimeframe string, count int) (*Data, error) {
symbol = Normalize(symbol)
if len(timeframes) == 0 {
return nil, fmt.Errorf("至少需要一个时间周期")
return nil, fmt.Errorf("at least one timeframe is required")
}
// 如果未指定主周期,使用第一个
// If primary timeframe is not specified, use the first one
if primaryTimeframe == "" {
primaryTimeframe = timeframes[0]
}
// 确保主周期在列表中
// Ensure primary timeframe is in the list
hasPrimary := false
for _, tf := range timeframes {
if tf == primaryTimeframe {
@@ -140,36 +140,36 @@ func GetWithTimeframes(symbol string, timeframes []string, primaryTimeframe stri
timeframes = append([]string{primaryTimeframe}, timeframes...)
}
// 存储所有时间周期的数据
// Store data for all timeframes
timeframeData := make(map[string]*TimeframeSeriesData)
var primaryKlines []Kline
// 获取每个时间周期的 K 线数据
// Get K-line data for each timeframe
for _, tf := range timeframes {
klines, err := WSMonitorCli.GetCurrentKlines(symbol, tf)
if err != nil {
logger.Infof("⚠️ 获取 %s %s K线失败: %v", symbol, tf, err)
logger.Infof("⚠️ Failed to get %s %s K-line: %v", symbol, tf, err)
continue
}
if len(klines) == 0 {
logger.Infof("⚠️ %s %s K线数据为空", symbol, tf)
logger.Infof("⚠️ %s %s K-line data is empty", symbol, tf)
continue
}
// 保存主周期的 K 线用于计算基础指标
// Save primary timeframe K-lines for calculating base indicators
if tf == primaryTimeframe {
primaryKlines = klines
}
// 计算该时间周期的系列数据
// Calculate series data for this timeframe
seriesData := calculateTimeframeSeries(klines, tf)
timeframeData[tf] = seriesData
}
// 如果主周期数据为空,返回错误
// If primary timeframe data is empty, return error
if len(primaryKlines) == 0 {
return nil, fmt.Errorf("主时间周期 %s K线数据为空", primaryTimeframe)
return nil, fmt.Errorf("Primary timeframe %s K-line data is empty", primaryTimeframe)
}
// Data staleness detection
@@ -178,23 +178,23 @@ func GetWithTimeframes(symbol string, timeframes []string, primaryTimeframe stri
return nil, fmt.Errorf("%s data is stale, possible cache failure", symbol)
}
// 计算当前指标 (基于主周期最新数据)
// Calculate current indicators (based on primary timeframe latest data)
currentPrice := primaryKlines[len(primaryKlines)-1].Close
currentEMA20 := calculateEMA(primaryKlines, 20)
currentMACD := calculateMACD(primaryKlines)
currentRSI7 := calculateRSI(primaryKlines, 7)
// 计算价格变化
priceChange1h := calculatePriceChangeByBars(primaryKlines, primaryTimeframe, 60) // 1小时
priceChange4h := calculatePriceChangeByBars(primaryKlines, primaryTimeframe, 240) // 4小时
// Calculate price changes
priceChange1h := calculatePriceChangeByBars(primaryKlines, primaryTimeframe, 60) // 1 hour
priceChange4h := calculatePriceChangeByBars(primaryKlines, primaryTimeframe, 240) // 4 hours
// 获取OI数据
// Get OI data
oiData, err := getOpenInterestData(symbol)
if err != nil {
oiData = &OIData{Latest: 0, Average: 0}
}
// 获取Funding Rate
// Get Funding Rate
fundingRate, _ := getFundingRate(symbol)
return &Data{
@@ -211,7 +211,7 @@ func GetWithTimeframes(symbol string, timeframes []string, primaryTimeframe stri
}, nil
}
// calculateTimeframeSeries 计算单个时间周期的系列数据
// calculateTimeframeSeries calculates series data for a single timeframe
func calculateTimeframeSeries(klines []Kline, timeframe string) *TimeframeSeriesData {
data := &TimeframeSeriesData{
Timeframe: timeframe,
@@ -224,7 +224,7 @@ func calculateTimeframeSeries(klines []Kline, timeframe string) *TimeframeSeries
Volume: make([]float64, 0, 10),
}
// 获取最近10个数据点
// Get latest 10 data points
start := len(klines) - 10
if start < 0 {
start = 0
@@ -234,25 +234,25 @@ func calculateTimeframeSeries(klines []Kline, timeframe string) *TimeframeSeries
data.MidPrices = append(data.MidPrices, klines[i].Close)
data.Volume = append(data.Volume, klines[i].Volume)
// 计算每个点的 EMA20
// Calculate EMA20 for each point
if i >= 19 {
ema20 := calculateEMA(klines[:i+1], 20)
data.EMA20Values = append(data.EMA20Values, ema20)
}
// 计算每个点的 EMA50
// Calculate EMA50 for each point
if i >= 49 {
ema50 := calculateEMA(klines[:i+1], 50)
data.EMA50Values = append(data.EMA50Values, ema50)
}
// 计算每个点的 MACD
// Calculate MACD for each point
if i >= 25 {
macd := calculateMACD(klines[:i+1])
data.MACDValues = append(data.MACDValues, macd)
}
// 计算每个点的 RSI
// Calculate RSI for each point
if i >= 7 {
rsi7 := calculateRSI(klines[:i+1], 7)
data.RSI7Values = append(data.RSI7Values, rsi7)
@@ -263,25 +263,25 @@ func calculateTimeframeSeries(klines []Kline, timeframe string) *TimeframeSeries
}
}
// 计算 ATR14
// Calculate ATR14
data.ATR14 = calculateATR(klines, 14)
return data
}
// calculatePriceChangeByBars 根据时间周期计算需要回溯多少根 K 线来计算价格变化
// calculatePriceChangeByBars calculates how many K-lines to look back for price change based on timeframe
func calculatePriceChangeByBars(klines []Kline, timeframe string, targetMinutes int) float64 {
if len(klines) < 2 {
return 0
}
// 解析时间周期为分钟数
// Parse timeframe to minutes
tfMinutes := parseTimeframeToMinutes(timeframe)
if tfMinutes <= 0 {
return 0
}
// 计算需要回溯多少根 K 线
// Calculate how many K-lines to look back
barsBack := targetMinutes / tfMinutes
if barsBack < 1 {
barsBack = 1
@@ -300,7 +300,7 @@ func calculatePriceChangeByBars(klines []Kline, timeframe string, targetMinutes
return 0
}
// parseTimeframeToMinutes 将时间周期字符串解析为分钟数
// parseTimeframeToMinutes parses timeframe string to minutes
func parseTimeframeToMinutes(tf string) int {
switch tf {
case "1m":
@@ -336,20 +336,20 @@ func parseTimeframeToMinutes(tf string) int {
}
}
// calculateEMA 计算EMA
// calculateEMA calculates EMA
func calculateEMA(klines []Kline, period int) float64 {
if len(klines) < period {
return 0
}
// 计算SMA作为初始EMA
// Calculate SMA as initial EMA
sum := 0.0
for i := 0; i < period; i++ {
sum += klines[i].Close
}
ema := sum / float64(period)
// 计算EMA
// Calculate EMA
multiplier := 2.0 / float64(period+1)
for i := period; i < len(klines); i++ {
ema = (klines[i].Close-ema)*multiplier + ema
@@ -358,13 +358,13 @@ func calculateEMA(klines []Kline, period int) float64 {
return ema
}
// calculateMACD 计算MACD
// calculateMACD calculates MACD
func calculateMACD(klines []Kline) float64 {
if len(klines) < 26 {
return 0
}
// 计算12期和26期EMA
// Calculate 12-period and 26-period EMA
ema12 := calculateEMA(klines, 12)
ema26 := calculateEMA(klines, 26)
@@ -372,7 +372,7 @@ func calculateMACD(klines []Kline) float64 {
return ema12 - ema26
}
// calculateRSI 计算RSI
// calculateRSI calculates RSI
func calculateRSI(klines []Kline, period int) float64 {
if len(klines) <= period {
return 0
@@ -381,7 +381,7 @@ func calculateRSI(klines []Kline, period int) float64 {
gains := 0.0
losses := 0.0
// 计算初始平均涨跌幅
// Calculate initial average gain/loss
for i := 1; i <= period; i++ {
change := klines[i].Close - klines[i-1].Close
if change > 0 {
@@ -394,7 +394,7 @@ func calculateRSI(klines []Kline, period int) float64 {
avgGain := gains / float64(period)
avgLoss := losses / float64(period)
// 使用Wilder平滑方法计算后续RSI
// Use Wilder smoothing method to calculate subsequent RSI
for i := period + 1; i < len(klines); i++ {
change := klines[i].Close - klines[i-1].Close
if change > 0 {
@@ -416,7 +416,7 @@ func calculateRSI(klines []Kline, period int) float64 {
return rsi
}
// calculateATR 计算ATR
// calculateATR calculates ATR
func calculateATR(klines []Kline, period int) float64 {
if len(klines) <= period {
return 0
@@ -435,14 +435,14 @@ func calculateATR(klines []Kline, period int) float64 {
trs[i] = math.Max(tr1, math.Max(tr2, tr3))
}
// 计算初始ATR
// Calculate initial ATR
sum := 0.0
for i := 1; i <= period; i++ {
sum += trs[i]
}
atr := sum / float64(period)
// Wilder平滑
// Wilder smoothing
for i := period + 1; i < len(klines); i++ {
atr = (atr*float64(period-1) + trs[i]) / float64(period)
}
@@ -450,7 +450,7 @@ func calculateATR(klines []Kline, period int) float64 {
return atr
}
// calculateIntradaySeries 计算日内系列数据
// calculateIntradaySeries calculates intraday series data
func calculateIntradaySeries(klines []Kline) *IntradayData {
data := &IntradayData{
MidPrices: make([]float64, 0, 10),
@@ -461,7 +461,7 @@ func calculateIntradaySeries(klines []Kline) *IntradayData {
Volume: make([]float64, 0, 10),
}
// 获取最近10个数据点
// Get latest 10 data points
start := len(klines) - 10
if start < 0 {
start = 0
@@ -471,19 +471,19 @@ func calculateIntradaySeries(klines []Kline) *IntradayData {
data.MidPrices = append(data.MidPrices, klines[i].Close)
data.Volume = append(data.Volume, klines[i].Volume)
// 计算每个点的EMA20
// Calculate EMA20 for each point
if i >= 19 {
ema20 := calculateEMA(klines[:i+1], 20)
data.EMA20Values = append(data.EMA20Values, ema20)
}
// 计算每个点的MACD
// Calculate MACD for each point
if i >= 25 {
macd := calculateMACD(klines[:i+1])
data.MACDValues = append(data.MACDValues, macd)
}
// 计算每个点的RSI
// Calculate RSI for each point
if i >= 7 {
rsi7 := calculateRSI(klines[:i+1], 7)
data.RSI7Values = append(data.RSI7Values, rsi7)
@@ -494,31 +494,31 @@ func calculateIntradaySeries(klines []Kline) *IntradayData {
}
}
// 计算3m ATR14
// Calculate 3m ATR14
data.ATR14 = calculateATR(klines, 14)
return data
}
// calculateLongerTermData 计算长期数据
// calculateLongerTermData calculates longer-term data
func calculateLongerTermData(klines []Kline) *LongerTermData {
data := &LongerTermData{
MACDValues: make([]float64, 0, 10),
RSI14Values: make([]float64, 0, 10),
}
// 计算EMA
// Calculate EMA
data.EMA20 = calculateEMA(klines, 20)
data.EMA50 = calculateEMA(klines, 50)
// 计算ATR
// Calculate ATR
data.ATR3 = calculateATR(klines, 3)
data.ATR14 = calculateATR(klines, 14)
// 计算成交量
// Calculate volume
if len(klines) > 0 {
data.CurrentVolume = klines[len(klines)-1].Volume
// 计算平均成交量
// Calculate average volume
sum := 0.0
for _, k := range klines {
sum += k.Volume
@@ -526,7 +526,7 @@ func calculateLongerTermData(klines []Kline) *LongerTermData {
data.AverageVolume = sum / float64(len(klines))
}
// 计算MACD和RSI序列
// Calculate MACD and RSI series
start := len(klines) - 10
if start < 0 {
start = 0
@@ -546,7 +546,7 @@ func calculateLongerTermData(klines []Kline) *LongerTermData {
return data
}
// getOpenInterestData 获取OI数据
// getOpenInterestData retrieves OI data
func getOpenInterestData(symbol string) (*OIData, error) {
url := fmt.Sprintf("https://fapi.binance.com/fapi/v1/openInterest?symbol=%s", symbol)
@@ -576,23 +576,23 @@ func getOpenInterestData(symbol string) (*OIData, error) {
return &OIData{
Latest: oi,
Average: oi * 0.999, // 近似平均值
Average: oi * 0.999, // Approximate average
}, nil
}
// getFundingRate 获取资金费率(优化:使用 1 小时缓存)
// getFundingRate retrieves funding rate (optimized: uses 1-hour cache)
func getFundingRate(symbol string) (float64, error) {
// 检查缓存(有效期 1 小时)
// Funding Rate 每 8 小时才更新,1 小时缓存非常合理
// Check cache (1-hour validity)
// Funding Rate only updates every 8 hours, 1-hour cache is very reasonable
if cached, ok := fundingRateMap.Load(symbol); ok {
cache := cached.(*FundingRateCache)
if time.Since(cache.UpdatedAt) < frCacheTTL {
// 缓存命中,直接返回
// Cache hit, return directly
return cache.Rate, nil
}
}
// 缓存过期或不存在,调用 API
// Cache expired or doesn't exist, call API
url := fmt.Sprintf("https://fapi.binance.com/fapi/v1/premiumIndex?symbol=%s", symbol)
apiClient := NewAPIClient()
@@ -623,7 +623,7 @@ func getFundingRate(symbol string) (float64, error) {
rate, _ := strconv.ParseFloat(result.LastFundingRate, 64)
// 更新缓存
// Update cache
fundingRateMap.Store(symbol, &FundingRateCache{
Rate: rate,
UpdatedAt: time.Now(),
@@ -632,11 +632,11 @@ func getFundingRate(symbol string) (float64, error) {
return rate, nil
}
// Format 格式化输出市场数据
// Format formats and outputs market data
func Format(data *Data) string {
var sb strings.Builder
// 使用动态精度格式化价格
// Format price with dynamic precision
priceStr := formatPriceWithDynamicPrecision(data.CurrentPrice)
sb.WriteString(fmt.Sprintf("current_price = %s, current_ema20 = %.3f, current_macd = %.3f, current_rsi (7 period) = %.3f\n\n",
priceStr, data.CurrentEMA20, data.CurrentMACD, data.CurrentRSI7))
@@ -645,7 +645,7 @@ func Format(data *Data) string {
data.Symbol))
if data.OpenInterest != nil {
// 使用动态精度格式化 OI 数据
// Format OI data with dynamic precision
oiLatestStr := formatPriceWithDynamicPrecision(data.OpenInterest.Latest)
oiAverageStr := formatPriceWithDynamicPrecision(data.OpenInterest.Average)
sb.WriteString(fmt.Sprintf("Open Interest: Latest: %s Average: %s\n\n",
@@ -705,9 +705,9 @@ func Format(data *Data) string {
}
}
// 多时间周期数据(新增)
// Multi-timeframe data (new)
if len(data.TimeframeData) > 0 {
// 按时间周期排序输出
// Output sorted by timeframe
timeframeOrder := []string{"1m", "3m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "8h", "12h", "1d", "3d", "1w"}
for _, tf := range timeframeOrder {
if tfData, ok := data.TimeframeData[tf]; ok {
@@ -720,7 +720,7 @@ func Format(data *Data) string {
return sb.String()
}
// formatTimeframeData 格式化单个时间周期的数据
// formatTimeframeData formats data for a single timeframe
func formatTimeframeData(sb *strings.Builder, data *TimeframeSeriesData) {
if len(data.MidPrices) > 0 {
sb.WriteString(fmt.Sprintf("Mid prices: %s\n\n", formatFloatSlice(data.MidPrices)))
@@ -753,38 +753,38 @@ func formatTimeframeData(sb *strings.Builder, data *TimeframeSeriesData) {
sb.WriteString(fmt.Sprintf("ATR (14period): %.3f\n\n", data.ATR14))
}
// formatPriceWithDynamicPrecision 根据价格区间动态选择精度
// 这样可以完美支持从超低价 meme coin (< 0.0001) BTC/ETH 的所有币种
// formatPriceWithDynamicPrecision dynamically selects precision based on price range
// This perfectly supports all coins from ultra-low price meme coins (< 0.0001) to BTC/ETH
func formatPriceWithDynamicPrecision(price float64) string {
switch {
case price < 0.0001:
// 超低价 meme coin: 1000SATS, 1000WHY, DOGS
// 0.00002070 → "0.00002070" (8位小数)
// Ultra-low price meme coins: 1000SATS, 1000WHY, DOGS
// 0.00002070 → "0.00002070" (8 decimal places)
return fmt.Sprintf("%.8f", price)
case price < 0.001:
// 低价 meme coin: NEIRO, HMSTR, HOT, NOT
// 0.00015060 → "0.000151" (6位小数)
// Low price meme coins: NEIRO, HMSTR, HOT, NOT
// 0.00015060 → "0.000151" (6 decimal places)
return fmt.Sprintf("%.6f", price)
case price < 0.01:
// 中低价币: PEPE, SHIB, MEME
// 0.00556800 → "0.005568" (6位小数)
// Mid-low price coins: PEPE, SHIB, MEME
// 0.00556800 → "0.005568" (6 decimal places)
return fmt.Sprintf("%.6f", price)
case price < 1.0:
// 低价币: ASTER, DOGE, ADA, TRX
// 0.9954 → "0.9954" (4位小数)
// Low price coins: ASTER, DOGE, ADA, TRX
// 0.9954 → "0.9954" (4 decimal places)
return fmt.Sprintf("%.4f", price)
case price < 100:
// 中价币: SOL, AVAX, LINK, MATIC
// 23.4567 → "23.4567" (4位小数)
// Mid price coins: SOL, AVAX, LINK, MATIC
// 23.4567 → "23.4567" (4 decimal places)
return fmt.Sprintf("%.4f", price)
default:
// 高价币: BTC, ETH (节省 Token)
// 45678.9123 → "45678.91" (2位小数)
// High price coins: BTC, ETH (save tokens)
// 45678.9123 → "45678.91" (2 decimal places)
return fmt.Sprintf("%.2f", price)
}
}
// formatFloatSlice 格式化float64切片为字符串(使用动态精度)
// formatFloatSlice formats float64 slice to string (using dynamic precision)
func formatFloatSlice(values []float64) string {
strValues := make([]string, len(values))
for i, v := range values {
@@ -793,7 +793,7 @@ func formatFloatSlice(values []float64) string {
return "[" + strings.Join(strValues, ", ") + "]"
}
// Normalize 标准化symbol,确保是USDT交易对
// Normalize normalizes symbol, ensures it's a USDT trading pair
func Normalize(symbol string) string {
symbol = strings.ToUpper(symbol)
if strings.HasSuffix(symbol, "USDT") {
@@ -802,7 +802,7 @@ func Normalize(symbol string) string {
return symbol + "USDT"
}
// parseFloat 解析float值
// parseFloat parses float value
func parseFloat(v interface{}) (float64, error) {
switch val := v.(type) {
case string:
@@ -818,7 +818,7 @@ func parseFloat(v interface{}) (float64, error) {
}
}
// BuildDataFromKlines 根据预加载的K线序列构造市场数据快照(用于回测/模拟)。
// BuildDataFromKlines constructs market data snapshot from preloaded K-line series (for backtesting/simulation).
func BuildDataFromKlines(symbol string, primary []Kline, longer []Kline) (*Data, error) {
if len(primary) == 0 {
return nil, fmt.Errorf("primary series is empty")
+40 -40
View File
@@ -5,11 +5,11 @@ import (
"testing"
)
// generateTestKlines 生成测试用的 K线数据
// generateTestKlines generates test K-line data
func generateTestKlines(count int) []Kline {
klines := make([]Kline, count)
for i := 0; i < count; i++ {
// 生成模拟的价格数据,有一定的波动
// Generate simulated price data with some fluctuation
basePrice := 100.0
variance := float64(i%10) * 0.5
open := basePrice + variance
@@ -19,7 +19,7 @@ func generateTestKlines(count int) []Kline {
volume := 1000.0 + float64(i*100)
klines[i] = Kline{
OpenTime: int64(i * 180000), // 3分钟间隔
OpenTime: int64(i * 180000), // 3-minute interval
Open: open,
High: high,
Low: low,
@@ -31,7 +31,7 @@ func generateTestKlines(count int) []Kline {
return klines
}
// TestCalculateIntradaySeries_VolumeCollection 测试 Volume 数据收集
// TestCalculateIntradaySeries_VolumeCollection tests Volume data collection
func TestCalculateIntradaySeries_VolumeCollection(t *testing.T) {
tests := []struct {
name string
@@ -39,24 +39,24 @@ func TestCalculateIntradaySeries_VolumeCollection(t *testing.T) {
expectedVolLen int
}{
{
name: "正常情况 - 20个K线",
name: "Normal case - 20 K-lines",
klineCount: 20,
expectedVolLen: 10, // 应该收集最近10
expectedVolLen: 10, // Should collect latest 10
},
{
name: "刚好10个K线",
name: "Exactly 10 K-lines",
klineCount: 10,
expectedVolLen: 10,
},
{
name: "少于10个K线",
name: "Less than 10 K-lines",
klineCount: 5,
expectedVolLen: 5, // 应该返回所有5个
expectedVolLen: 5, // Should return all 5
},
{
name: "超过10个K线",
name: "More than 10 K-lines",
klineCount: 30,
expectedVolLen: 10, // 应该只返回最近10
expectedVolLen: 10, // Should only return latest 10
},
}
@@ -73,21 +73,21 @@ func TestCalculateIntradaySeries_VolumeCollection(t *testing.T) {
t.Errorf("Volume length = %d, want %d", len(data.Volume), tt.expectedVolLen)
}
// 验证 Volume 数据正确性
// Verify Volume data correctness
if len(data.Volume) > 0 {
// 计算期望的起始索引
// Calculate expected start index
start := tt.klineCount - 10
if start < 0 {
start = 0
}
// 验证第一个 Volume
// Verify first Volume value
expectedFirstVolume := klines[start].Volume
if data.Volume[0] != expectedFirstVolume {
t.Errorf("First volume = %.2f, want %.2f", data.Volume[0], expectedFirstVolume)
}
// 验证最后一个 Volume
// Verify last Volume value
expectedLastVolume := klines[tt.klineCount-1].Volume
lastVolume := data.Volume[len(data.Volume)-1]
if lastVolume != expectedLastVolume {
@@ -98,7 +98,7 @@ func TestCalculateIntradaySeries_VolumeCollection(t *testing.T) {
}
}
// TestCalculateIntradaySeries_VolumeValues 测试 Volume 值的正确性
// TestCalculateIntradaySeries_VolumeValues tests Volume value correctness
func TestCalculateIntradaySeries_VolumeValues(t *testing.T) {
klines := []Kline{
{Close: 100.0, Volume: 1000.0, High: 101.0, Low: 99.0, Open: 100.0},
@@ -128,7 +128,7 @@ func TestCalculateIntradaySeries_VolumeValues(t *testing.T) {
}
}
// TestCalculateIntradaySeries_ATR14 测试 ATR14 计算
// TestCalculateIntradaySeries_ATR14 tests ATR14 calculation
func TestCalculateIntradaySeries_ATR14(t *testing.T) {
tests := []struct {
name string
@@ -137,27 +137,27 @@ func TestCalculateIntradaySeries_ATR14(t *testing.T) {
expectNonZero bool
}{
{
name: "足够数据 - 20个K线",
name: "Sufficient data - 20 K-lines",
klineCount: 20,
expectNonZero: true,
},
{
name: "刚好15个K线(ATR14需要至少15个)",
name: "Exactly 15 K-lines (ATR14 requires at least 15)",
klineCount: 15,
expectNonZero: true,
},
{
name: "数据不足 - 14个K线",
name: "Insufficient data - 14 K-lines",
klineCount: 14,
expectZero: true,
},
{
name: "数据不足 - 10个K线",
name: "Insufficient data - 10 K-lines",
klineCount: 10,
expectZero: true,
},
{
name: "数据不足 - 5个K线",
name: "Insufficient data - 5 K-lines",
klineCount: 5,
expectZero: true,
},
@@ -183,7 +183,7 @@ func TestCalculateIntradaySeries_ATR14(t *testing.T) {
}
}
// TestCalculateATR 测试 ATR 计算函数
// TestCalculateATR tests ATR calculation function
func TestCalculateATR(t *testing.T) {
tests := []struct {
name string
@@ -192,7 +192,7 @@ func TestCalculateATR(t *testing.T) {
expectZero bool
}{
{
name: "正常计算 - 足够数据",
name: "Normal calculation - sufficient data",
klines: []Kline{
{High: 102.0, Low: 100.0, Close: 101.0},
{High: 103.0, Low: 101.0, Close: 102.0},
@@ -214,7 +214,7 @@ func TestCalculateATR(t *testing.T) {
expectZero: false,
},
{
name: "数据不足 - 等于period",
name: "Insufficient data - equal to period",
klines: []Kline{
{High: 102.0, Low: 100.0, Close: 101.0},
{High: 103.0, Low: 101.0, Close: 102.0},
@@ -223,7 +223,7 @@ func TestCalculateATR(t *testing.T) {
expectZero: true,
},
{
name: "数据不足 - 少于period",
name: "Insufficient data - less than period",
klines: []Kline{
{High: 102.0, Low: 100.0, Close: 101.0},
},
@@ -249,9 +249,9 @@ func TestCalculateATR(t *testing.T) {
}
}
// TestCalculateATR_TrueRange 测试 ATR True Range 计算正确性
// TestCalculateATR_TrueRange tests ATR True Range calculation correctness
func TestCalculateATR_TrueRange(t *testing.T) {
// 创建一个简单的测试用例,手动计算期望的 ATR
// Create a simple test case, manually calculate expected ATR
klines := []Kline{
{High: 50.0, Low: 48.0, Close: 49.0}, // TR = 2.0
{High: 51.0, Low: 49.0, Close: 50.0}, // TR = max(2.0, 2.0, 1.0) = 2.0
@@ -262,28 +262,28 @@ func TestCalculateATR_TrueRange(t *testing.T) {
atr := calculateATR(klines, 3)
// 期望的计算:
// Expected calculation:
// TR[1] = max(51-49, |51-49|, |49-49|) = 2.0
// TR[2] = max(52-50, |52-50|, |50-50|) = 2.0
// TR[3] = max(53-51, |53-51|, |51-51|) = 2.0
// 初始 ATR = (2.0 + 2.0 + 2.0) / 3 = 2.0
// Initial ATR = (2.0 + 2.0 + 2.0) / 3 = 2.0
// TR[4] = max(54-52, |54-52|, |52-52|) = 2.0
// 平滑 ATR = (2.0*2 + 2.0) / 3 = 2.0
// Smoothed ATR = (2.0*2 + 2.0) / 3 = 2.0
expectedATR := 2.0
tolerance := 0.01 // 允许小的浮点误差
tolerance := 0.01 // Allow small floating point error
if math.Abs(atr-expectedATR) > tolerance {
t.Errorf("calculateATR() = %.3f, want approximately %.3f", atr, expectedATR)
}
}
// TestCalculateIntradaySeries_ConsistencyWithOtherIndicators 测试 Volume 和其他指标的一致性
// TestCalculateIntradaySeries_ConsistencyWithOtherIndicators tests Volume and other indicators consistency
func TestCalculateIntradaySeries_ConsistencyWithOtherIndicators(t *testing.T) {
klines := generateTestKlines(30)
data := calculateIntradaySeries(klines)
// 所有数组应该存在
// All arrays should exist
if data.MidPrices == nil {
t.Error("MidPrices should not be nil")
}
@@ -291,13 +291,13 @@ func TestCalculateIntradaySeries_ConsistencyWithOtherIndicators(t *testing.T) {
t.Error("Volume should not be nil")
}
// MidPrices Volume 应该有相同的长度(都是最近10个)
// MidPrices and Volume should have the same length (both latest 10)
if len(data.MidPrices) != len(data.Volume) {
t.Errorf("MidPrices length (%d) should equal Volume length (%d)",
len(data.MidPrices), len(data.Volume))
}
// 所有 Volume 值应该大于 0
// All Volume values should be > 0
for i, vol := range data.Volume {
if vol <= 0 {
t.Errorf("Volume[%d] = %.2f, should be > 0", i, vol)
@@ -305,7 +305,7 @@ func TestCalculateIntradaySeries_ConsistencyWithOtherIndicators(t *testing.T) {
}
}
// TestCalculateIntradaySeries_EmptyKlines 测试空 K线数据
// TestCalculateIntradaySeries_EmptyKlines tests empty K-line data
func TestCalculateIntradaySeries_EmptyKlines(t *testing.T) {
klines := []Kline{}
data := calculateIntradaySeries(klines)
@@ -314,7 +314,7 @@ func TestCalculateIntradaySeries_EmptyKlines(t *testing.T) {
t.Fatal("calculateIntradaySeries should not return nil for empty klines")
}
// 所有切片应该为空
// All slices should be empty
if len(data.MidPrices) != 0 {
t.Errorf("MidPrices length = %d, want 0", len(data.MidPrices))
}
@@ -322,13 +322,13 @@ func TestCalculateIntradaySeries_EmptyKlines(t *testing.T) {
t.Errorf("Volume length = %d, want 0", len(data.Volume))
}
// ATR14 应该为 0(数据不足)
// ATR14 should be 0 (insufficient data)
if data.ATR14 != 0 {
t.Errorf("ATR14 = %.3f, want 0", data.ATR14)
}
}
// TestCalculateIntradaySeries_VolumePrecision 测试 Volume 精度保持
// TestCalculateIntradaySeries_VolumePrecision tests Volume precision preservation
func TestCalculateIntradaySeries_VolumePrecision(t *testing.T) {
klines := []Kline{
{Close: 100.0, Volume: 1234.5678, High: 101.0, Low: 99.0},
+2 -2
View File
@@ -13,7 +13,7 @@ const (
binanceMaxKlineLimit = 1500
)
// GetKlinesRange 拉取指定时间范围内的 K 线序列(闭区间),返回按时间升序排列的数据。
// GetKlinesRange fetches K-line series within specified time range (closed interval), returns data sorted by time in ascending order.
func GetKlinesRange(symbol string, timeframe string, start, end time.Time) ([]Kline, error) {
symbol = Normalize(symbol)
normTF, err := NormalizeTimeframe(timeframe)
@@ -94,7 +94,7 @@ func GetKlinesRange(symbol string, timeframe string, start, end time.Time) ([]Kl
last := batch[len(batch)-1]
cursor = last.CloseTime + 1
// 若返回数量少于请求上限,说明已到达末尾,可提前退出。
// If returned quantity is less than request limit, reached the end, can exit early.
if len(batch) < binanceMaxKlineLimit {
break
}
+49 -49
View File
@@ -15,24 +15,24 @@ type WSMonitor struct {
symbols []string
featuresMap sync.Map
alertsChan chan Alert
klineDataMap3m sync.Map // 存储每个交易对的K线历史数据
klineDataMap4h sync.Map // 存储每个交易对的K线历史数据
tickerDataMap sync.Map // 存储每个交易对的ticker数据
klineDataMap3m sync.Map // Store K-line historical data for each trading pair
klineDataMap4h sync.Map // Store K-line historical data for each trading pair
tickerDataMap sync.Map // Store ticker data for each trading pair
batchSize int
filterSymbols sync.Map // 使用sync.Map来存储需要监控的币种和其状态
symbolStats sync.Map // 存储币种统计信息
FilterSymbol []string //经过筛选的币种
filterSymbols sync.Map // Use sync.Map to store monitored coins and their status
symbolStats sync.Map // Store symbol statistics
FilterSymbol []string // Filtered symbols
}
type SymbolStats struct {
LastActiveTime time.Time
AlertCount int
VolumeSpikeCount int
LastAlertTime time.Time
Score float64 // 综合评分
Score float64 // Composite score
}
var WSMonitorCli *WSMonitor
var subKlineTime = []string{"3m", "4h"} // 管理订阅流的K线周期
var subKlineTime = []string{"3m", "4h"} // Manage K-line periods for subscription streams
func NewWSMonitor(batchSize int) *WSMonitor {
WSMonitorCli = &WSMonitor{
@@ -45,16 +45,16 @@ func NewWSMonitor(batchSize int) *WSMonitor {
}
func (m *WSMonitor) Initialize(coins []string) error {
log.Println("初始化WebSocket监控器...")
// 获取交易对信息
log.Println("Initializing WebSocket monitor...")
// Get trading pair information
apiClient := NewAPIClient()
// 如果不指定交易对,则使用market市场的所有交易对币种
// If trading pairs are not specified, use all trading pairs from the market
if len(coins) == 0 {
exchangeInfo, err := apiClient.GetExchangeInfo()
if err != nil {
return err
}
// 筛选永续合约交易对 --仅测试时使用
// Filter perpetual contract trading pairs -- only use for testing
//exchangeInfo.Symbols = exchangeInfo.Symbols[0:2]
for _, symbol := range exchangeInfo.Symbols {
if symbol.Status == "TRADING" && symbol.ContractType == "PERPETUAL" && strings.ToUpper(symbol.Symbol[len(symbol.Symbol)-4:]) == "USDT" {
@@ -66,10 +66,10 @@ func (m *WSMonitor) Initialize(coins []string) error {
m.symbols = coins
}
log.Printf("找到 %d 个交易对", len(m.symbols))
// 初始化历史数据
log.Printf("Found %d trading pairs", len(m.symbols))
// Initialize historical data
if err := m.initializeHistoricalData(); err != nil {
log.Printf("初始化历史数据失败: %v", err)
log.Printf("Failed to initialize historical data: %v", err)
}
return nil
@@ -79,7 +79,7 @@ func (m *WSMonitor) initializeHistoricalData() error {
apiClient := NewAPIClient()
var wg sync.WaitGroup
semaphore := make(chan struct{}, 5) // 限制并发数
semaphore := make(chan struct{}, 5) // Limit concurrency
for _, symbol := range m.symbols {
wg.Add(1)
@@ -89,25 +89,25 @@ func (m *WSMonitor) initializeHistoricalData() error {
defer wg.Done()
defer func() { <-semaphore }()
// 获取历史K线数据
// Get historical K-line data
klines, err := apiClient.GetKlines(s, "3m", 100)
if err != nil {
log.Printf("获取 %s 历史数据失败: %v", s, err)
log.Printf("Failed to get %s historical data: %v", s, err)
return
}
if len(klines) > 0 {
m.klineDataMap3m.Store(s, klines)
log.Printf("已加载 %s 的历史K线数据-3m: %d ", s, len(klines))
log.Printf("Loaded %s historical K-line data-3m: %d entries", s, len(klines))
}
// 获取历史K线数据
// Get historical K-line data
klines4h, err := apiClient.GetKlines(s, "4h", 100)
if err != nil {
log.Printf("获取 %s 历史数据失败: %v", s, err)
log.Printf("Failed to get %s historical data: %v", s, err)
return
}
if len(klines4h) > 0 {
m.klineDataMap4h.Store(s, klines4h)
log.Printf("已加载 %s 的历史K线数据-4h: %d ", s, len(klines4h))
log.Printf("Loaded %s historical K-line data-4h: %d entries", s, len(klines4h))
}
}(symbol)
}
@@ -117,28 +117,28 @@ func (m *WSMonitor) initializeHistoricalData() error {
}
func (m *WSMonitor) Start(coins []string) {
log.Printf("启动WebSocket实时监控...")
// 初始化交易对
log.Printf("Starting WebSocket real-time monitoring...")
// Initialize trading pairs
err := m.Initialize(coins)
if err != nil {
log.Printf("❌ 初始化币种失败: %v", err)
log.Printf("❌ Failed to initialize coins: %v", err)
return
}
err = m.combinedClient.Connect()
if err != nil {
log.Printf("❌ 批量订阅流失败: %v", err)
log.Printf("❌ Failed to batch subscribe to streams: %v", err)
return
}
// 订阅所有交易对
// Subscribe to all trading pairs
err = m.subscribeAll()
if err != nil {
log.Printf("❌ 订阅币种交易对失败: %v", err)
log.Printf("❌ Failed to subscribe to coin trading pairs: %v", err)
return
}
}
// subscribeSymbol 注册监听
// subscribeSymbol registers listener
func (m *WSMonitor) subscribeSymbol(symbol, st string) []string {
var streams []string
stream := fmt.Sprintf("%s@kline_%s", strings.ToLower(symbol), st)
@@ -149,8 +149,8 @@ func (m *WSMonitor) subscribeSymbol(symbol, st string) []string {
return streams
}
func (m *WSMonitor) subscribeAll() error {
// 执行批量订阅
log.Println("开始订阅所有交易对...")
// Execute batch subscription
log.Println("Starting to subscribe to all trading pairs...")
for _, symbol := range m.symbols {
for _, st := range subKlineTime {
m.subscribeSymbol(symbol, st)
@@ -159,11 +159,11 @@ func (m *WSMonitor) subscribeAll() error {
for _, st := range subKlineTime {
err := m.combinedClient.BatchSubscribeKlines(m.symbols, st)
if err != nil {
log.Printf("❌ 订阅 %s K线失败: %v", st, err)
log.Printf("❌ Failed to subscribe to %s K-line: %v", st, err)
return err
}
}
log.Println("所有交易对订阅完成")
log.Println("All trading pair subscriptions completed")
return nil
}
@@ -171,7 +171,7 @@ func (m *WSMonitor) handleKlineData(symbol string, ch <-chan []byte, _time strin
for data := range ch {
var klineData KlineWSData
if err := json.Unmarshal(data, &klineData); err != nil {
log.Printf("解析Kline数据失败: %v", err)
log.Printf("Failed to parse Kline data: %v", err)
continue
}
m.processKlineUpdate(symbol, klineData, _time)
@@ -190,7 +190,7 @@ func (m *WSMonitor) getKlineDataMap(_time string) *sync.Map {
return klineDataMap
}
func (m *WSMonitor) processKlineUpdate(symbol string, wsData KlineWSData, _time string) {
// 转换WebSocket数据为Kline结构
// Convert WebSocket data to Kline structure
kline := Kline{
OpenTime: wsData.Kline.StartTime,
CloseTime: wsData.Kline.CloseTime,
@@ -205,22 +205,22 @@ func (m *WSMonitor) processKlineUpdate(symbol string, wsData KlineWSData, _time
kline.QuoteVolume, _ = parseFloat(wsData.Kline.QuoteVolume)
kline.TakerBuyBaseVolume, _ = parseFloat(wsData.Kline.TakerBuyBaseVolume)
kline.TakerBuyQuoteVolume, _ = parseFloat(wsData.Kline.TakerBuyQuoteVolume)
// 更新K线数据
// Update K-line data
var klineDataMap = m.getKlineDataMap(_time)
value, exists := klineDataMap.Load(symbol)
var klines []Kline
if exists {
klines = value.([]Kline)
// 检查是否是新的K线
// Check if it's a new K-line
if len(klines) > 0 && klines[len(klines)-1].OpenTime == kline.OpenTime {
// 更新当前K线
// Update current K-line
klines[len(klines)-1] = kline
} else {
// 添加新K线
// Add new K-line
klines = append(klines, kline)
// 保持数据长度
// Maintain data length
if len(klines) > 100 {
klines = klines[1:]
}
@@ -233,34 +233,34 @@ func (m *WSMonitor) processKlineUpdate(symbol string, wsData KlineWSData, _time
}
func (m *WSMonitor) GetCurrentKlines(symbol string, duration string) ([]Kline, error) {
// 对每一个进来的symbol检测是否存在内类 是否的话就订阅它
// Check if each incoming symbol exists internally, if not subscribe to it
value, exists := m.getKlineDataMap(duration).Load(symbol)
if !exists {
// 如果Ws数据未初始化完成时,单独使用api获取 - 兼容性代码 (防止在未初始化完成是,已经有交易员运行)
// If WS data is not initialized, use API separately - compatibility code (prevents trader from running when not initialized)
apiClient := NewAPIClient()
klines, err := apiClient.GetKlines(symbol, duration, 100)
if err != nil {
return nil, fmt.Errorf("获取%v分钟K线失败: %v", duration, err)
return nil, fmt.Errorf("Failed to get %v-minute K-line: %v", duration, err)
}
// 动态缓存进缓存
// Dynamically cache into cache
m.getKlineDataMap(duration).Store(strings.ToUpper(symbol), klines)
// 订阅 WebSocket
// Subscribe to WebSocket stream
subStr := m.subscribeSymbol(symbol, duration)
subErr := m.combinedClient.subscribeStreams(subStr)
log.Printf("动态订阅流: %v", subStr)
log.Printf("Dynamic subscription to stream: %v", subStr)
if subErr != nil {
log.Printf("警告: 动态订阅%v分钟K线失败: %v (使用API数据)", duration, subErr)
log.Printf("Warning: Failed to dynamically subscribe to %v-minute K-line: %v (using API data)", duration, subErr)
}
// ✅ FIX: 返回深拷贝而非引用
// ✅ FIX: Return deep copy instead of reference
result := make([]Kline, len(klines))
copy(result, klines)
return result, nil
}
// ✅ FIX: 返回深拷贝而非引用,避免并发竞态条件
// ✅ FIX: Return deep copy instead of reference, avoid concurrent race conditions
klines := value.([]Kline)
result := make([]Kline, len(klines))
copy(result, klines)
+5 -5
View File
@@ -7,7 +7,7 @@ import (
"time"
)
// supportedTimeframes 定义支持的时间周期与其对应的分钟数。
// supportedTimeframes defines supported timeframes and their corresponding durations.
var supportedTimeframes = map[string]time.Duration{
"1m": time.Minute,
"3m": 3 * time.Minute,
@@ -22,7 +22,7 @@ var supportedTimeframes = map[string]time.Duration{
"1d": 24 * time.Hour,
}
// NormalizeTimeframe 规范化传入的时间周期字符串(大小写、不带空格),并校验是否受支持。
// NormalizeTimeframe normalizes the incoming timeframe string (case-insensitive, no spaces), and validates if it's supported.
func NormalizeTimeframe(tf string) (string, error) {
trimmed := strings.TrimSpace(strings.ToLower(tf))
if trimmed == "" {
@@ -34,7 +34,7 @@ func NormalizeTimeframe(tf string) (string, error) {
return trimmed, nil
}
// TFDuration 返回给定周期对应的时间长度。
// TFDuration returns the time duration corresponding to the given timeframe.
func TFDuration(tf string) (time.Duration, error) {
norm, err := NormalizeTimeframe(tf)
if err != nil {
@@ -43,7 +43,7 @@ func TFDuration(tf string) (time.Duration, error) {
return supportedTimeframes[norm], nil
}
// MustNormalizeTimeframe NormalizeTimeframe 类似,但在不受支持时 panic。
// MustNormalizeTimeframe is similar to NormalizeTimeframe, but panics when unsupported.
func MustNormalizeTimeframe(tf string) string {
norm, err := NormalizeTimeframe(tf)
if err != nil {
@@ -52,7 +52,7 @@ func MustNormalizeTimeframe(tf string) string {
return norm
}
// SupportedTimeframes 返回所有受支持的时间周期(已排序的切片)。
// SupportedTimeframes returns all supported timeframes (sorted slice).
func SupportedTimeframes() []string {
keys := make([]string, 0, len(supportedTimeframes))
for k := range supportedTimeframes {
+23 -23
View File
@@ -2,12 +2,12 @@ package market
import "time"
// Data 市场数据结构
// Data market data structure
type Data struct {
Symbol string
CurrentPrice float64
PriceChange1h float64 // 1小时价格变化百分比
PriceChange4h float64 // 4小时价格变化百分比
PriceChange1h float64 // 1-hour price change percentage
PriceChange4h float64 // 4-hour price change percentage
CurrentEMA20 float64
CurrentMACD float64
CurrentRSI7 float64
@@ -15,30 +15,30 @@ type Data struct {
FundingRate float64
IntradaySeries *IntradayData
LongerTermContext *LongerTermData
// 多时间周期数据(新增)
// Multi-timeframe data (new)
TimeframeData map[string]*TimeframeSeriesData `json:"timeframe_data,omitempty"`
}
// TimeframeSeriesData 单个时间周期的序列数据
// TimeframeSeriesData series data for a single timeframe
type TimeframeSeriesData struct {
Timeframe string `json:"timeframe"` // 时间周期标识,如 "5m", "15m", "1h"
MidPrices []float64 `json:"mid_prices"` // 价格序列
EMA20Values []float64 `json:"ema20_values"` // EMA20 序列
EMA50Values []float64 `json:"ema50_values"` // EMA50 序列
MACDValues []float64 `json:"macd_values"` // MACD 序列
RSI7Values []float64 `json:"rsi7_values"` // RSI7 序列
RSI14Values []float64 `json:"rsi14_values"` // RSI14 序列
Volume []float64 `json:"volume"` // 成交量序列
Timeframe string `json:"timeframe"` // Timeframe identifier, e.g. "5m", "15m", "1h"
MidPrices []float64 `json:"mid_prices"` // Price series
EMA20Values []float64 `json:"ema20_values"` // EMA20 series
EMA50Values []float64 `json:"ema50_values"` // EMA50 series
MACDValues []float64 `json:"macd_values"` // MACD series
RSI7Values []float64 `json:"rsi7_values"` // RSI7 series
RSI14Values []float64 `json:"rsi14_values"` // RSI14 series
Volume []float64 `json:"volume"` // Volume series
ATR14 float64 `json:"atr14"` // ATR14
}
// OIData Open Interest数据
// OIData Open Interest data
type OIData struct {
Latest float64
Average float64
}
// IntradayData 日内数据(3分钟间隔)
// IntradayData intraday data (3-minute interval)
type IntradayData struct {
MidPrices []float64
EMA20Values []float64
@@ -49,7 +49,7 @@ type IntradayData struct {
ATR14 float64
}
// LongerTermData 长期数据(4小时时间框架)
// LongerTermData longer-term data (4-hour timeframe)
type LongerTermData struct {
EMA20 float64
EMA50 float64
@@ -61,7 +61,7 @@ type LongerTermData struct {
RSI14Values []float64
}
// Binance API 响应结构
// Binance API response structure
type ExchangeInfo struct {
Symbols []SymbolInfo `json:"symbols"`
}
@@ -105,7 +105,7 @@ type Ticker24hr struct {
QuoteVolume string `json:"quoteVolume"`
}
// 特征数据结构
// SymbolFeatures feature data structure
type SymbolFeatures struct {
Symbol string `json:"symbol"`
Timestamp time.Time `json:"timestamp"`
@@ -126,7 +126,7 @@ type SymbolFeatures struct {
PositionInRange float64 `json:"position_in_range"`
}
// 警报数据结构
// Alert alert data structure
type Alert struct {
Type string `json:"type"`
Symbol string `json:"symbol"`
@@ -150,10 +150,10 @@ type AlertThresholds struct {
RSIOversold float64 `json:"rsi_oversold"`
}
type CleanupConfig struct {
InactiveTimeout time.Duration `json:"inactive_timeout"` // 不活跃超时时间
MinScoreThreshold float64 `json:"min_score_threshold"` // 最低评分阈值
NoAlertTimeout time.Duration `json:"no_alert_timeout"` // 无警报超时时间
CheckInterval time.Duration `json:"check_interval"` // 检查间隔
InactiveTimeout time.Duration `json:"inactive_timeout"` // Inactive timeout duration
MinScoreThreshold float64 `json:"min_score_threshold"` // Minimum score threshold
NoAlertTimeout time.Duration `json:"no_alert_timeout"` // No alert timeout duration
CheckInterval time.Duration `json:"check_interval"` // Check interval
}
var config = Config{
+11 -11
View File
@@ -83,16 +83,16 @@ func (w *WSClient) Connect() error {
conn, _, err := dialer.Dial("wss://ws-fapi.binance.com/ws-fapi/v1", nil)
if err != nil {
return fmt.Errorf("WebSocket连接失败: %v", err)
return fmt.Errorf("WebSocket connection failed: %v", err)
}
w.mu.Lock()
w.conn = conn
w.mu.Unlock()
log.Println("WebSocket连接成功")
log.Println("WebSocket connected successfully")
// 启动消息读取循环
// Start message reading loop
go w.readMessages()
return nil
@@ -124,7 +124,7 @@ func (w *WSClient) subscribe(stream string) error {
defer w.mu.RUnlock()
if w.conn == nil {
return fmt.Errorf("WebSocket未连接")
return fmt.Errorf("WebSocket not connected")
}
err := w.conn.WriteJSON(subscribeMsg)
@@ -132,7 +132,7 @@ func (w *WSClient) subscribe(stream string) error {
return err
}
log.Printf("订阅流: %s", stream)
log.Printf("Subscribing to stream: %s", stream)
return nil
}
@@ -153,7 +153,7 @@ func (w *WSClient) readMessages() {
_, message, err := conn.ReadMessage()
if err != nil {
log.Printf("读取WebSocket消息失败: %v", err)
log.Printf("Failed to read WebSocket message: %v", err)
w.handleReconnect()
return
}
@@ -166,7 +166,7 @@ func (w *WSClient) readMessages() {
func (w *WSClient) handleMessage(message []byte) {
var wsMsg WSMessage
if err := json.Unmarshal(message, &wsMsg); err != nil {
// 可能是其他格式的消息
// Might be a different message format
return
}
@@ -178,7 +178,7 @@ func (w *WSClient) handleMessage(message []byte) {
select {
case ch <- wsMsg.Data:
default:
log.Printf("订阅者通道已满: %s", wsMsg.Stream)
log.Printf("Subscriber channel is full: %s", wsMsg.Stream)
}
}
}
@@ -188,11 +188,11 @@ func (w *WSClient) handleReconnect() {
return
}
log.Println("尝试重新连接...")
log.Println("Attempting to reconnect...")
time.Sleep(3 * time.Second)
if err := w.Connect(); err != nil {
log.Printf("重新连接失败: %v", err)
log.Printf("Reconnection failed: %v", err)
go w.handleReconnect()
}
}
@@ -223,7 +223,7 @@ func (w *WSClient) Close() {
w.conn = nil
}
// 关闭所有订阅者通道
// Close all subscriber channels
for stream, ch := range w.subscribers {
close(ch)
delete(w.subscribers, stream)
+99 -99
View File
@@ -28,65 +28,65 @@ var (
"connection refused",
"temporary failure",
"no such host",
"stream error", // HTTP/2 stream 错误
"INTERNAL_ERROR", // 服务端内部错误
"stream error", // HTTP/2 stream error
"INTERNAL_ERROR", // Server internal error
}
)
// Client AI API配置
// Client AI API configuration
type Client struct {
Provider string
APIKey string
BaseURL string
Model string
UseFullURL bool // 是否使用完整URL(不添加/chat/completions
MaxTokens int // AI响应的最大token数
UseFullURL bool // Whether to use full URL (without appending /chat/completions)
MaxTokens int // Maximum tokens for AI response
httpClient *http.Client
logger Logger // 日志器(可替换)
config *Config // 配置对象(保存所有配置)
logger Logger // Logger (replaceable)
config *Config // Config object (stores all configurations)
// hooks 用于实现动态分派(多态)
// DeepSeekClient 嵌入 Client 时,hooks 指向 DeepSeekClient
// 这样 call() 中调用的方法会自动分派到子类重写的版本
// hooks are used to implement dynamic dispatch (polymorphism)
// When DeepSeekClient embeds Client, hooks point to DeepSeekClient
// This way methods called in call() are automatically dispatched to the overridden version in subclass
hooks clientHooks
}
// New 创建默认客户端(向前兼容)
// New creates default client (backward compatible)
//
// Deprecated: 推荐使用 NewClient(...opts) 以获得更好的灵活性
// Deprecated: Recommend using NewClient(...opts) for better flexibility
func New() AIClient {
return NewClient()
}
// NewClient 创建客户端(支持选项模式)
// NewClient creates client (supports options pattern)
//
// 使用示例:
// // 基础用法(向前兼容)
// Usage examples:
// // Basic usage (backward compatible)
// client := mcp.NewClient()
//
// // 自定义日志
// // Custom logger
// client := mcp.NewClient(mcp.WithLogger(customLogger))
//
// // 自定义超时
// // Custom timeout
// client := mcp.NewClient(mcp.WithTimeout(60*time.Second))
//
// // 组合多个选项
// // Combine multiple options
// client := mcp.NewClient(
// mcp.WithDeepSeekConfig("sk-xxx"),
// mcp.WithLogger(customLogger),
// mcp.WithTimeout(60*time.Second),
// )
func NewClient(opts ...ClientOption) AIClient {
// 1. 创建默认配置
// 1. Create default config
cfg := DefaultConfig()
// 2. 应用用户选项
// 2. Apply user options
for _, opt := range opts {
opt(cfg)
}
// 3. 创建客户端实例
// 3. Create client instance
client := &Client{
Provider: cfg.Provider,
APIKey: cfg.APIKey,
@@ -99,25 +99,25 @@ func NewClient(opts ...ClientOption) AIClient {
config: cfg,
}
// 4. 设置默认 Provider(如果未设置)
// 4. Set default Provider (if not set)
if client.Provider == "" {
client.Provider = ProviderDeepSeek
client.BaseURL = DefaultDeepSeekBaseURL
client.Model = DefaultDeepSeekModel
}
// 5. 设置 hooks 指向自己
// 5. Set hooks to point to self
client.hooks = client
return client
}
// SetCustomAPI 设置自定义OpenAI兼容API
// SetCustomAPI sets custom OpenAI-compatible API
func (client *Client) SetAPIKey(apiKey, apiURL, customModel string) {
client.Provider = ProviderCustom
client.APIKey = apiKey
// 检查URL是否以#结尾,如果是则使用完整URL(不添加/chat/completions
// Check if URL ends with #, if so use full URL (without appending /chat/completions)
if strings.HasSuffix(apiURL, "#") {
client.BaseURL = strings.TrimSuffix(apiURL, "#")
client.UseFullURL = true
@@ -133,45 +133,45 @@ func (client *Client) SetTimeout(timeout time.Duration) {
client.httpClient.Timeout = timeout
}
// CallWithMessages 模板方法 - 固定的重试流程(不可重写)
// CallWithMessages template method - fixed retry flow (cannot be overridden)
func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, error) {
if client.APIKey == "" {
return "", fmt.Errorf("AI API密钥未设置,请先调用 SetAPIKey")
return "", fmt.Errorf("AI API key not set, please call SetAPIKey first")
}
// 固定的重试流程
// Fixed retry flow
var lastErr error
maxRetries := client.config.MaxRetries
for attempt := 1; attempt <= maxRetries; attempt++ {
if attempt > 1 {
client.logger.Warnf("⚠️ AI API调用失败,正在重试 (%d/%d)...", attempt, maxRetries)
client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
}
// 调用固定的单次调用流程
// Call the fixed single-call flow
result, err := client.hooks.call(systemPrompt, userPrompt)
if err == nil {
if attempt > 1 {
client.logger.Infof("✓ AI API重试成功")
client.logger.Infof("✓ AI API retry succeeded")
}
return result, nil
}
lastErr = err
// 通过 hooks 判断是否可重试(支持子类自定义重试策略)
// Check if error is retryable via hooks (supports custom retry strategy in subclass)
if !client.hooks.isRetryableError(err) {
return "", err
}
// 重试前等待
// Wait before retry
if attempt < maxRetries {
waitTime := client.config.RetryWaitBase * time.Duration(attempt)
client.logger.Infof("⏳ 等待%v后重试...", waitTime)
client.logger.Infof("⏳ Waiting %v before retry...", waitTime)
time.Sleep(waitTime)
}
}
return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
}
func (client *Client) setAuthHeader(reqHeader http.Header) {
@@ -179,27 +179,27 @@ func (client *Client) setAuthHeader(reqHeader http.Header) {
}
func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[string]any {
// 构建 messages 数组
// Build messages array
messages := []map[string]string{}
// 如果有 system prompt,添加 system message
// If system prompt exists, add system message
if systemPrompt != "" {
messages = append(messages, map[string]string{
"role": "system",
"content": systemPrompt,
})
}
// 添加 user message
// Add user message
messages = append(messages, map[string]string{
"role": "user",
"content": userPrompt,
})
// 构建请求体
// Build request body
requestBody := map[string]interface{}{
"model": client.Model,
"messages": messages,
"temperature": client.config.Temperature, // 使用配置的 temperature
"temperature": client.config.Temperature, // Use configured temperature
"max_tokens": client.MaxTokens,
}
return requestBody
@@ -209,7 +209,7 @@ func (client *Client) buildMCPRequestBody(systemPrompt, userPrompt string) map[s
func (client *Client) marshalRequestBody(requestBody map[string]any) ([]byte, error) {
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
return nil, fmt.Errorf("failed to serialize request: %w", err)
}
return jsonData, nil
}
@@ -224,11 +224,11 @@ func (client *Client) parseMCPResponse(body []byte) (string, error) {
}
if err := json.Unmarshal(body, &result); err != nil {
return "", fmt.Errorf("解析响应失败: %w", err)
return "", fmt.Errorf("failed to parse response: %w", err)
}
if len(result.Choices) == 0 {
return "", fmt.Errorf("API返回空响应")
return "", fmt.Errorf("API returned empty response")
}
return result.Choices[0].Message.Content, nil
@@ -250,59 +250,59 @@ func (client *Client) buildRequest(url string, jsonData []byte) (*http.Request,
req.Header.Set("Content-Type", "application/json")
// 通过 hooks 设置认证头(支持子类重写)
// Set auth header via hooks (supports overriding in subclass)
client.hooks.setAuthHeader(req.Header)
return req, nil
}
// call 单次调用AI API(固定流程,不可重写)
// call single AI API call (fixed flow, cannot be overridden)
func (client *Client) call(systemPrompt, userPrompt string) (string, error) {
// 打印当前 AI 配置
// Print current AI configuration
client.logger.Infof("📡 [%s] Request AI Server: BaseURL: %s", client.String(), client.BaseURL)
client.logger.Debugf("[%s] UseFullURL: %v", client.String(), client.UseFullURL)
if len(client.APIKey) > 8 {
client.logger.Debugf("[%s] API Key: %s...%s", client.String(), client.APIKey[:4], client.APIKey[len(client.APIKey)-4:])
}
// Step 1: 构建请求体(通过 hooks 实现动态分派)
// Step 1: Build request body (via hooks for dynamic dispatch)
requestBody := client.hooks.buildMCPRequestBody(systemPrompt, userPrompt)
// Step 2: 序列化请求体(通过 hooks 实现动态分派)
// Step 2: Serialize request body (via hooks for dynamic dispatch)
jsonData, err := client.hooks.marshalRequestBody(requestBody)
if err != nil {
return "", err
}
// Step 3: 构建 URL(通过 hooks 实现动态分派)
// Step 3: Build URL (via hooks for dynamic dispatch)
url := client.hooks.buildUrl()
client.logger.Infof("📡 [MCP %s] 请求 URL: %s", client.String(), url)
client.logger.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
// Step 4: 创建 HTTP 请求(固定逻辑)
// Step 4: Create HTTP request (fixed logic)
req, err := client.hooks.buildRequest(url, jsonData)
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
return "", fmt.Errorf("failed to create request: %w", err)
}
// Step 5: 发送 HTTP 请求(固定逻辑)
// Step 5: Send HTTP request (fixed logic)
resp, err := client.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("发送请求失败: %w", err)
return "", fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
// Step 6: 读取响应体(固定逻辑)
// Step 6: Read response body (fixed logic)
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("读取响应失败: %w", err)
return "", fmt.Errorf("failed to read response: %w", err)
}
// Step 7: 检查 HTTP 状态码(固定逻辑)
// Step 7: Check HTTP status code (fixed logic)
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("API返回错误 (status %d): %s", resp.StatusCode, string(body))
return "", fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body))
}
// Step 8: 解析响应(通过 hooks 实现动态分派)
// Step 8: Parse response (via hooks for dynamic dispatch)
result, err := client.hooks.parseMCPResponse(body)
if err != nil {
return "", fmt.Errorf("fail to parse AI server response: %w", err)
@@ -316,10 +316,10 @@ func (client *Client) String() string {
client.Provider, client.Model)
}
// isRetryableError 判断错误是否可重试(网络错误、超时等)
// isRetryableError determines if error is retryable (network errors, timeouts, etc.)
func (client *Client) isRetryableError(err error) bool {
errStr := err.Error()
// 网络错误、超时、EOF等可以重试
// Network errors, timeouts, EOF, etc. can be retried
for _, retryable := range client.config.RetryableErrors {
if strings.Contains(errStr, retryable) {
return true
@@ -329,18 +329,18 @@ func (client *Client) isRetryableError(err error) bool {
}
// ============================================================
// 构建器模式 API(高级功能)
// Builder Pattern API (Advanced Features)
// ============================================================
// CallWithRequest 使用 Request 对象调用 AI API(支持高级功能)
// CallWithRequest calls AI API using Request object (supports advanced features)
//
// 此方法支持:
// - 多轮对话历史
// - 精细参数控制(temperaturetop_ppenalties 等)
// This method supports:
// - Multi-turn conversation history
// - Fine-grained parameter control (temperature, top_p, penalties, etc.)
// - Function Calling / Tools
// - 流式响应(未来支持)
// - Streaming response (future support)
//
// 使用示例:
// Usage example:
// request := NewRequestBuilder().
// WithSystemPrompt("You are helpful").
// WithUserPrompt("Hello").
@@ -349,93 +349,93 @@ func (client *Client) isRetryableError(err error) bool {
// result, err := client.CallWithRequest(request)
func (client *Client) CallWithRequest(req *Request) (string, error) {
if client.APIKey == "" {
return "", fmt.Errorf("AI API密钥未设置,请先调用 SetAPIKey")
return "", fmt.Errorf("AI API key not set, please call SetAPIKey first")
}
// 如果 Request 中没有设置 Model,使用 Client Model
// If Model is not set in Request, use Client's Model
if req.Model == "" {
req.Model = client.Model
}
// 固定的重试流程
// Fixed retry flow
var lastErr error
maxRetries := client.config.MaxRetries
for attempt := 1; attempt <= maxRetries; attempt++ {
if attempt > 1 {
client.logger.Warnf("⚠️ AI API调用失败,正在重试 (%d/%d)...", attempt, maxRetries)
client.logger.Warnf("⚠️ AI API call failed, retrying (%d/%d)...", attempt, maxRetries)
}
// 调用单次请求
// Call single request
result, err := client.callWithRequest(req)
if err == nil {
if attempt > 1 {
client.logger.Infof("✓ AI API重试成功")
client.logger.Infof("✓ AI API retry succeeded")
}
return result, nil
}
lastErr = err
// 判断是否可重试
// Check if error is retryable
if !client.hooks.isRetryableError(err) {
return "", err
}
// 重试前等待
// Wait before retry
if attempt < maxRetries {
waitTime := client.config.RetryWaitBase * time.Duration(attempt)
client.logger.Infof("⏳ 等待%v后重试...", waitTime)
client.logger.Infof("⏳ Waiting %v before retry...", waitTime)
time.Sleep(waitTime)
}
}
return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
return "", fmt.Errorf("still failed after %d retries: %w", maxRetries, lastErr)
}
// callWithRequest 单次调用 AI API(使用 Request 对象)
// callWithRequest single AI API call (using Request object)
func (client *Client) callWithRequest(req *Request) (string, error) {
// 打印当前 AI 配置
// Print current AI configuration
client.logger.Infof("📡 [%s] Request AI Server with Builder: BaseURL: %s", client.String(), client.BaseURL)
client.logger.Debugf("[%s] Messages count: %d", client.String(), len(req.Messages))
// 构建请求体(从 Request 对象)
// Build request body (from Request object)
requestBody := client.buildRequestBodyFromRequest(req)
// 序列化请求体
// Serialize request body
jsonData, err := client.hooks.marshalRequestBody(requestBody)
if err != nil {
return "", err
}
// 构建 URL
// Build URL
url := client.hooks.buildUrl()
client.logger.Infof("📡 [MCP %s] 请求 URL: %s", client.String(), url)
client.logger.Infof("📡 [MCP %s] Request URL: %s", client.String(), url)
// 创建 HTTP 请求
// Create HTTP request
httpReq, err := client.hooks.buildRequest(url, jsonData)
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
return "", fmt.Errorf("failed to create request: %w", err)
}
// 发送 HTTP 请求
// Send HTTP request
resp, err := client.httpClient.Do(httpReq)
if err != nil {
return "", fmt.Errorf("发送请求失败: %w", err)
return "", fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
// 读取响应体
// Read response body
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("读取响应失败: %w", err)
return "", fmt.Errorf("failed to read response: %w", err)
}
// 检查 HTTP 状态码
// Check HTTP status code
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("API返回错误 (status %d): %s", resp.StatusCode, string(body))
return "", fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body))
}
// 解析响应
// Parse response
result, err := client.hooks.parseMCPResponse(body)
if err != nil {
return "", fmt.Errorf("fail to parse AI server response: %w", err)
@@ -444,9 +444,9 @@ func (client *Client) callWithRequest(req *Request) (string, error) {
return result, nil
}
// buildRequestBodyFromRequest Request 对象构建请求体
// buildRequestBodyFromRequest builds request body from Request object
func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any {
// 转换 Message API 格式
// Convert Message to API format
messages := make([]map[string]string, 0, len(req.Messages))
for _, msg := range req.Messages {
messages = append(messages, map[string]string{
@@ -455,24 +455,24 @@ func (client *Client) buildRequestBodyFromRequest(req *Request) map[string]any {
})
}
// 构建基础请求体
// Build basic request body
requestBody := map[string]interface{}{
"model": req.Model,
"messages": messages,
}
// 添加可选参数(只添加非 nil 的参数)
// Add optional parameters (only add non-nil parameters)
if req.Temperature != nil {
requestBody["temperature"] = *req.Temperature
} else {
// 如果 Request 中没有设置,使用 Client 的配置
// If not set in Request, use Client's configuration
requestBody["temperature"] = client.config.Temperature
}
if req.MaxTokens != nil {
requestBody["max_tokens"] = *req.MaxTokens
} else {
// 如果 Request 中没有设置,使用 Client MaxTokens
// If not set in Request, use Client's MaxTokens
requestBody["max_tokens"] = client.MaxTokens
}
+21 -21
View File
@@ -8,7 +8,7 @@ import (
)
// ============================================================
// 测试 Client 创建和配置
// Test Client Creation and Configuration
// ============================================================
func TestNewClient_Default(t *testing.T) {
@@ -72,7 +72,7 @@ func TestNewClient_WithOptions(t *testing.T) {
}
// ============================================================
// 测试 CallWithMessages
// Test CallWithMessages
// ============================================================
func TestClient_CallWithMessages_Success(t *testing.T) {
@@ -97,7 +97,7 @@ func TestClient_CallWithMessages_Success(t *testing.T) {
t.Errorf("expected 'AI response content', got '%s'", result)
}
// 验证请求
// Verify request
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Errorf("expected 1 request, got %d", len(requests))
@@ -123,7 +123,7 @@ func TestClient_CallWithMessages_NoAPIKey(t *testing.T) {
t.Error("should error when API key is not set")
}
if err.Error() != "AI API密钥未设置,请先调用 SetAPIKey" {
if err.Error() != "AI API key not set, please call SetAPIKey first" {
t.Errorf("unexpected error message: %v", err)
}
}
@@ -147,14 +147,14 @@ func TestClient_CallWithMessages_HTTPError(t *testing.T) {
}
// ============================================================
// 测试重试逻辑
// Test Retry Logic
// ============================================================
func TestClient_Retry_Success(t *testing.T) {
mockHTTP := NewMockHTTPClient()
mockLogger := NewMockLogger()
// 模拟:第一次失败,第二次成功
// Simulate: first call fails, second call succeeds
callCount := 0
mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) {
callCount++
@@ -174,40 +174,40 @@ func TestClient_Retry_Success(t *testing.T) {
WithMaxRetries(3),
)
// 由于我们的 client 使用 hooks.call,需要特殊处理
// 这里我们测试的是 CallWithMessages 会调用 retry 逻辑
// Since our client uses hooks.call, need special handling
// Here we test that CallWithMessages will invoke retry logic
c := client.(*Client)
// 临时修改重试等待时间为 0 以加速测试
// Temporarily modify retry wait time to 0 to speed up test
oldRetries := MaxRetryTimes
MaxRetryTimes = 3
defer func() { MaxRetryTimes = oldRetries }()
_, err := c.CallWithMessages("system", "user")
// 第一次失败(connection reset),第二次成功,但是响应格式不对,会失败
// 但至少验证了重试逻辑被触发
// First fails (connection reset), second succeeds, but response format is wrong, will fail
// But at least verify retry logic was triggered
if callCount < 2 {
t.Errorf("should retry, got %d calls", callCount)
}
// 检查日志中是否有重试信息
// Check if there's retry information in logs
logs := mockLogger.GetLogsByLevel("WARN")
hasRetryLog := false
for _, log := range logs {
if log.Message == "⚠️ AI API调用失败,正在重试 (2/3)..." {
if log.Message == "⚠️ AI API call failed, retrying (2/3)..." {
hasRetryLog = true
break
}
}
if !hasRetryLog && callCount >= 2 {
// 如果确实重试了,应该有警告日志
// 但由于我们的测试设置,可能不会触发,所以这里只是检查
// If retry was indeed attempted, there should be warning logs
// But due to our test setup, it may not trigger, so just check here
t.Log("Retry was attempted")
}
_ = err // 忽略错误,我们主要测试重试逻辑被触发
_ = err // Ignore error, we mainly test retry logic was triggered
}
func TestClient_Retry_NonRetryableError(t *testing.T) {
@@ -227,7 +227,7 @@ func TestClient_Retry_NonRetryableError(t *testing.T) {
t.Error("should error")
}
// 验证没有重试(因为 400 不是可重试错误)
// Verify no retry (because 400 is not a retryable error)
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Errorf("should not retry for 400 error, got %d requests", len(requests))
@@ -235,7 +235,7 @@ func TestClient_Retry_NonRetryableError(t *testing.T) {
}
// ============================================================
// 测试钩子方法
// Test Hook Methods
// ============================================================
func TestClient_BuildMCPRequestBody(t *testing.T) {
@@ -368,7 +368,7 @@ func TestClient_IsRetryableError(t *testing.T) {
}
// ============================================================
// 测试 SetTimeout
// Test SetTimeout
// ============================================================
func TestClient_SetTimeout(t *testing.T) {
@@ -384,7 +384,7 @@ func TestClient_SetTimeout(t *testing.T) {
}
// ============================================================
// 测试 String 方法
// Test String Method
// ============================================================
func TestClient_String(t *testing.T) {
@@ -404,7 +404,7 @@ func TestClient_String(t *testing.T) {
}
}
// 辅助函数
// Helper function
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr))
}
+11 -11
View File
@@ -9,36 +9,36 @@ import (
"nofx/logger"
)
// Config 客户端配置(集中管理所有配置)
// Config client configuration (centralized management of all configurations)
type Config struct {
// Provider 配置
// Provider configuration
Provider string
APIKey string
BaseURL string
Model string
// 行为配置
// Behavior configuration
MaxTokens int
Temperature float64
UseFullURL bool
// 重试配置
// Retry configuration
MaxRetries int
RetryWaitBase time.Duration
RetryableErrors []string
// 超时配置
// Timeout configuration
Timeout time.Duration
// 依赖注入
// Dependency injection
Logger Logger
HTTPClient *http.Client
}
// DefaultConfig 返回默认配置
// DefaultConfig returns default configuration
func DefaultConfig() *Config {
return &Config{
// 默认值
// Default values
MaxTokens: getEnvInt("AI_MAX_TOKENS", 2000),
Temperature: MCPClientTemperature,
MaxRetries: MaxRetryTimes,
@@ -46,13 +46,13 @@ func DefaultConfig() *Config {
Timeout: DefaultTimeout,
RetryableErrors: retryableErrors,
// 默认依赖(使用全局 logger
// Default dependencies (use global logger)
Logger: logger.NewMCPLogger(),
HTTPClient: &http.Client{Timeout: DefaultTimeout},
}
}
// getEnvInt 从环境变量读取整数,失败则返回默认值
// getEnvInt reads integer from environment variable, returns default value if failed
func getEnvInt(key string, defaultValue int) int {
if val := os.Getenv(key); val != "" {
if parsed, err := strconv.Atoi(val); err == nil && parsed > 0 {
@@ -62,7 +62,7 @@ func getEnvInt(key string, defaultValue int) int {
return defaultValue
}
// getEnvString 从环境变量读取字符串,为空则返回默认值
// getEnvString reads string from environment variable, returns default value if empty
func getEnvString(key string, defaultValue string) string {
if val := os.Getenv(key); val != "" {
return val
+38 -38
View File
@@ -11,49 +11,49 @@ import (
)
// ============================================================
// 测试 Config 字段真正被使用(验证问题2修复)
// Test Config Fields Are Actually Used (Verify Issue 2 Fix)
// ============================================================
func TestConfig_MaxRetries_IsUsed(t *testing.T) {
mockHTTP := NewMockHTTPClient()
mockLogger := NewMockLogger()
// 设置 HTTP 客户端返回错误
// Set HTTP client to return error
callCount := 0
mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) {
callCount++
return nil, errors.New("connection reset")
}
// 创建客户端并设置自定义重试次数为 5
// Create client and set custom retry count to 5
client := NewClient(
WithHTTPClient(mockHTTP.ToHTTPClient()),
WithLogger(mockLogger),
WithAPIKey("sk-test-key"),
WithMaxRetries(5), // ✅ 设置重试5次
WithMaxRetries(5), // Set to retry 5 times
)
// 调用 API(应该失败)
// Call API (should fail)
_, err := client.CallWithMessages("system", "user")
if err == nil {
t.Error("should error")
}
// 验证确实重试了5次(而不是默认的3次)
// Verify indeed retried 5 times (not the default 3 times)
if callCount != 5 {
t.Errorf("expected 5 retry attempts (from WithMaxRetries(5)), got %d", callCount)
}
// 验证日志中显示正确的重试次数
// Verify logs show correct retry count
logs := mockLogger.GetLogsByLevel("WARN")
expectedWarningCount := 4 // 第2、3、4、5次重试时会打印警告
expectedWarningCount := 4 // Warnings will be printed on 2nd, 3rd, 4th, 5th retry
actualWarningCount := 0
for _, log := range logs {
if log.Message == "⚠️ AI API调用失败,正在重试 (2/5)..." ||
log.Message == "⚠️ AI API调用失败,正在重试 (3/5)..." ||
log.Message == "⚠️ AI API调用失败,正在重试 (4/5)..." ||
log.Message == "⚠️ AI API调用失败,正在重试 (5/5)..." {
if log.Message == "⚠️ AI API call failed, retrying (2/5)..." ||
log.Message == "⚠️ AI API call failed, retrying (3/5)..." ||
log.Message == "⚠️ AI API call failed, retrying (4/5)..." ||
log.Message == "⚠️ AI API call failed, retrying (5/5)..." {
actualWarningCount++
}
}
@@ -73,20 +73,20 @@ func TestConfig_Temperature_IsUsed(t *testing.T) {
customTemperature := 0.8
// 创建客户端并设置自定义 temperature
// Create client and set custom temperature
client := NewClient(
WithHTTPClient(mockHTTP.ToHTTPClient()),
WithLogger(mockLogger),
WithAPIKey("sk-test-key"),
WithTemperature(customTemperature), // ✅ 设置自定义 temperature
WithTemperature(customTemperature), // Set custom temperature
)
c := client.(*Client)
// 构建请求体
// Build request body
requestBody := c.buildMCPRequestBody("system", "user")
// 验证 temperature 字段
// Verify temperature field
temp, ok := requestBody["temperature"].(float64)
if !ok {
t.Fatal("temperature should be float64")
@@ -96,26 +96,26 @@ func TestConfig_Temperature_IsUsed(t *testing.T) {
t.Errorf("expected temperature %f (from WithTemperature), got %f", customTemperature, temp)
}
// 也可以通过实际 HTTP 请求验证
// Can also verify through actual HTTP request
_, err := client.CallWithMessages("system", "user")
if err != nil {
t.Fatalf("should not error: %v", err)
}
// 检查发送的请求体
// Check sent request body
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Fatalf("expected 1 request, got %d", len(requests))
}
// 解析请求体
// Parse request body
var body map[string]interface{}
decoder := json.NewDecoder(requests[0].Body)
if err := decoder.Decode(&body); err != nil {
t.Fatalf("failed to decode request body: %v", err)
}
// 验证 temperature
// Verify temperature
if body["temperature"] != customTemperature {
t.Errorf("expected temperature %f in HTTP request, got %v", customTemperature, body["temperature"])
}
@@ -125,18 +125,18 @@ func TestConfig_RetryWaitBase_IsUsed(t *testing.T) {
mockHTTP := NewMockHTTPClient()
mockLogger := NewMockLogger()
// 设置成功响应(在 ResponseFunc 之前)
// Set success response (before ResponseFunc)
mockHTTP.SetSuccessResponse("AI response")
// 设置 HTTP 客户端前2次返回错误,第3次成功
// Set HTTP client to return error first 2 times, success on 3rd time
callCount := 0
successResponse := mockHTTP.Response // 保存成功响应字符串
successResponse := mockHTTP.Response // Save success response string
mockHTTP.ResponseFunc = func(req *http.Request) (*http.Response, error) {
callCount++
if callCount <= 2 {
return nil, errors.New("timeout exceeded")
}
// 第3次返回成功响应
// 3rd time return success response
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewBufferString(successResponse)),
@@ -144,27 +144,27 @@ func TestConfig_RetryWaitBase_IsUsed(t *testing.T) {
}, nil
}
// 设置自定义重试等待基数为 1 秒(而不是默认的 2 秒)
// Set custom retry wait base to 1 second (instead of default 2 seconds)
customWaitBase := 1 * time.Second
client := NewClient(
WithHTTPClient(mockHTTP.ToHTTPClient()),
WithLogger(mockLogger),
WithAPIKey("sk-test-key"),
WithRetryWaitBase(customWaitBase), // ✅ 设置自定义等待时间
WithRetryWaitBase(customWaitBase), // Set custom wait time
WithMaxRetries(3),
)
// 记录开始时间
// Record start time
start := time.Now()
// 调用 API
// Call API
_, err := client.CallWithMessages("system", "user")
// 记录结束时间
// Record end time
elapsed := time.Since(start)
// 第3次成功,但前面失败了2次
// 3rd time succeeds, but failed 2 times before
if err != nil {
t.Fatalf("should succeed on 3rd attempt, got error: %v", err)
}
@@ -173,10 +173,10 @@ func TestConfig_RetryWaitBase_IsUsed(t *testing.T) {
t.Errorf("expected 3 attempts, got %d", callCount)
}
// 验证等待时间
// 第1次失败后等待 1s (customWaitBase * 1)
// 第2次失败后等待 2s (customWaitBase * 2)
// 总等待时间应该约为 3s (允许一些误差)
// Verify wait time
// After 1st failure wait 1s (customWaitBase * 1)
// After 2nd failure wait 2s (customWaitBase * 2)
// Total wait time should be about 3s (allow some error)
expectedWait := 3 * time.Second
tolerance := 200 * time.Millisecond
@@ -189,7 +189,7 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) {
mockHTTP := NewMockHTTPClient()
mockLogger := NewMockLogger()
// 自定义可重试错误列表(只包含 "custom error"
// Custom retryable error list (only contains "custom error")
customRetryableErrors := []string{"custom error"}
client := NewClient(
@@ -200,7 +200,7 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) {
c := client.(*Client)
// 修改 config RetryableErrors(暂时没有 WithRetryableErrors 选项)
// Modify config's RetryableErrors (no WithRetryableErrors option yet)
c.config.RetryableErrors = customRetryableErrors
tests := []struct {
@@ -236,14 +236,14 @@ func TestConfig_RetryableErrors_IsUsed(t *testing.T) {
}
// ============================================================
// 测试默认值
// Test Default Values
// ============================================================
func TestConfig_DefaultValues(t *testing.T) {
client := NewClient()
c := client.(*Client)
// 验证默认值
// Verify default values
if c.config.MaxRetries != 3 {
t.Errorf("default MaxRetries should be 3, got %d", c.config.MaxRetries)
}
+15 -15
View File
@@ -14,45 +14,45 @@ type DeepSeekClient struct {
*Client
}
// NewDeepSeekClient 创建 DeepSeek 客户端(向前兼容)
// NewDeepSeekClient creates DeepSeek client (backward compatible)
//
// Deprecated: 推荐使用 NewDeepSeekClientWithOptions 以获得更好的灵活性
// Deprecated: Recommend using NewDeepSeekClientWithOptions for better flexibility
func NewDeepSeekClient() AIClient {
return NewDeepSeekClientWithOptions()
}
// NewDeepSeekClientWithOptions 创建 DeepSeek 客户端(支持选项模式)
// NewDeepSeekClientWithOptions creates DeepSeek client (supports options pattern)
//
// 使用示例:
// // 基础用法
// Usage examples:
// // Basic usage
// client := mcp.NewDeepSeekClientWithOptions()
//
// // 自定义配置
// // Custom configuration
// client := mcp.NewDeepSeekClientWithOptions(
// mcp.WithAPIKey("sk-xxx"),
// mcp.WithLogger(customLogger),
// mcp.WithTimeout(60*time.Second),
// )
func NewDeepSeekClientWithOptions(opts ...ClientOption) AIClient {
// 1. 创建 DeepSeek 预设选项
// 1. Create DeepSeek preset options
deepseekOpts := []ClientOption{
WithProvider(ProviderDeepSeek),
WithModel(DefaultDeepSeekModel),
WithBaseURL(DefaultDeepSeekBaseURL),
}
// 2. 合并用户选项(用户选项优先级更高)
// 2. Merge user options (user options have higher priority)
allOpts := append(deepseekOpts, opts...)
// 3. 创建基础客户端
// 3. Create base client
baseClient := NewClient(allOpts...).(*Client)
// 4. 创建 DeepSeek 客户端
// 4. Create DeepSeek client
dsClient := &DeepSeekClient{
Client: baseClient,
}
// 5. 设置 hooks 指向 DeepSeekClient(实现动态分派)
// 5. Set hooks to point to DeepSeekClient (implement dynamic dispatch)
baseClient.hooks = dsClient
return dsClient
@@ -66,15 +66,15 @@ func (dsClient *DeepSeekClient) SetAPIKey(apiKey string, customURL string, custo
}
if customURL != "" {
dsClient.BaseURL = customURL
dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s", customURL)
dsClient.logger.Infof("🔧 [MCP] DeepSeek using custom BaseURL: %s", customURL)
} else {
dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用默认 BaseURL: %s", dsClient.BaseURL)
dsClient.logger.Infof("🔧 [MCP] DeepSeek using default BaseURL: %s", dsClient.BaseURL)
}
if customModel != "" {
dsClient.Model = customModel
dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用自定义 Model: %s", customModel)
dsClient.logger.Infof("🔧 [MCP] DeepSeek using custom Model: %s", customModel)
} else {
dsClient.logger.Infof("🔧 [MCP] DeepSeek 使用默认 Model: %s", dsClient.Model)
dsClient.logger.Infof("🔧 [MCP] DeepSeek using default Model: %s", dsClient.Model)
}
}
+22 -22
View File
@@ -6,7 +6,7 @@ import (
)
// ============================================================
// 测试 DeepSeekClient 创建和配置
// Test DeepSeekClient Creation and Configuration
// ============================================================
func TestNewDeepSeekClient_Default(t *testing.T) {
@@ -16,13 +16,13 @@ func TestNewDeepSeekClient_Default(t *testing.T) {
t.Fatal("client should not be nil")
}
// 类型断言检查
// Type assertion check
dsClient, ok := client.(*DeepSeekClient)
if !ok {
t.Fatal("client should be *DeepSeekClient")
}
// 验证默认值
// Verify default values
if dsClient.Provider != ProviderDeepSeek {
t.Errorf("Provider should be '%s', got '%s'", ProviderDeepSeek, dsClient.Provider)
}
@@ -58,7 +58,7 @@ func TestNewDeepSeekClientWithOptions(t *testing.T) {
dsClient := client.(*DeepSeekClient)
// 验证自定义选项被应用
// Verify custom options are applied
if dsClient.logger != mockLogger {
t.Error("logger should be set from option")
}
@@ -75,7 +75,7 @@ func TestNewDeepSeekClientWithOptions(t *testing.T) {
t.Error("MaxTokens should be 4000")
}
// 验证 DeepSeek 默认值仍然保留
// Verify DeepSeek default values are retained
if dsClient.Provider != ProviderDeepSeek {
t.Errorf("Provider should still be '%s'", ProviderDeepSeek)
}
@@ -86,7 +86,7 @@ func TestNewDeepSeekClientWithOptions(t *testing.T) {
}
// ============================================================
// 测试 SetAPIKey
// Test SetAPIKey
// ============================================================
func TestDeepSeekClient_SetAPIKey(t *testing.T) {
@@ -97,20 +97,20 @@ func TestDeepSeekClient_SetAPIKey(t *testing.T) {
dsClient := client.(*DeepSeekClient)
// 测试设置 API Key(默认 URL Model
// Test setting API Key (default URL and Model)
dsClient.SetAPIKey("sk-test-key-12345678", "", "")
if dsClient.APIKey != "sk-test-key-12345678" {
t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", dsClient.APIKey)
}
// 验证日志记录
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
if len(logs) == 0 {
t.Error("should have logged API key setting")
}
// 验证 BaseURL Model 保持默认
// Verify BaseURL and Model remain default
if dsClient.BaseURL != DefaultDeepSeekBaseURL {
t.Error("BaseURL should remain default")
}
@@ -135,11 +135,11 @@ func TestDeepSeekClient_SetAPIKey_WithCustomURL(t *testing.T) {
t.Errorf("BaseURL should be '%s', got '%s'", customURL, dsClient.BaseURL)
}
// 验证日志记录
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
hasCustomURLLog := false
for _, log := range logs {
if log.Format == "🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s" {
if log.Format == "🔧 [MCP] DeepSeek using custom BaseURL: %s" {
hasCustomURLLog = true
break
}
@@ -165,11 +165,11 @@ func TestDeepSeekClient_SetAPIKey_WithCustomModel(t *testing.T) {
t.Errorf("Model should be '%s', got '%s'", customModel, dsClient.Model)
}
// 验证日志记录
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
hasCustomModelLog := false
for _, log := range logs {
if log.Format == "🔧 [MCP] DeepSeek 使用自定义 Model: %s" {
if log.Format == "🔧 [MCP] DeepSeek using custom Model: %s" {
hasCustomModelLog = true
break
}
@@ -181,7 +181,7 @@ func TestDeepSeekClient_SetAPIKey_WithCustomModel(t *testing.T) {
}
// ============================================================
// 测试集成功能
// Test Integration Features
// ============================================================
func TestDeepSeekClient_CallWithMessages_Success(t *testing.T) {
@@ -205,7 +205,7 @@ func TestDeepSeekClient_CallWithMessages_Success(t *testing.T) {
t.Errorf("expected 'DeepSeek AI response', got '%s'", result)
}
// 验证请求
// Verify request
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Fatalf("expected 1 request, got %d", len(requests))
@@ -213,19 +213,19 @@ func TestDeepSeekClient_CallWithMessages_Success(t *testing.T) {
req := requests[0]
// 验证 URL
// Verify URL
expectedURL := DefaultDeepSeekBaseURL + "/chat/completions"
if req.URL.String() != expectedURL {
t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String())
}
// 验证 Authorization header
// Verify Authorization header
authHeader := req.Header.Get("Authorization")
if authHeader != "Bearer sk-test-key" {
t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader)
}
// 验证 Content-Type
// Verify Content-Type
if req.Header.Get("Content-Type") != "application/json" {
t.Error("Content-Type should be application/json")
}
@@ -242,7 +242,7 @@ func TestDeepSeekClient_Timeout(t *testing.T) {
t.Errorf("expected timeout 30s, got %v", dsClient.httpClient.Timeout)
}
// 测试 SetTimeout
// Test SetTimeout
client.SetTimeout(60 * time.Second)
if dsClient.httpClient.Timeout != 60*time.Second {
@@ -251,19 +251,19 @@ func TestDeepSeekClient_Timeout(t *testing.T) {
}
// ============================================================
// 测试 hooks 机制
// Test hooks Mechanism
// ============================================================
func TestDeepSeekClient_HooksIntegration(t *testing.T) {
client := NewDeepSeekClientWithOptions()
dsClient := client.(*DeepSeekClient)
// 验证 hooks 指向 dsClient 自己(实现多态)
// Verify hooks point to dsClient itself (implements polymorphism)
if dsClient.hooks != dsClient {
t.Error("hooks should point to dsClient for polymorphism")
}
// 验证 buildUrl 使用 DeepSeek 配置
// Verify buildUrl uses DeepSeek configuration
url := dsClient.buildUrl()
expectedURL := DefaultDeepSeekBaseURL + "/chat/completions"
if url != expectedURL {
+44 -44
View File
@@ -9,21 +9,21 @@ import (
)
// ============================================================
// 示例 1: 基础用法(向前兼容)
// Example 1: Basic Usage (Backward Compatible)
// ============================================================
func Example_backward_compatible() {
// ✅ 旧代码继续工作,无需修改
// Old code continues to work without modification
client := mcp.New()
client.SetAPIKey("sk-xxx", "https://api.custom.com", "gpt-4")
// 使用
// Usage
result, _ := client.CallWithMessages("system prompt", "user prompt")
fmt.Println(result)
}
func Example_deepseek_backward_compatible() {
// DeepSeek 旧代码继续工作
// DeepSeek old code continues to work
client := mcp.NewDeepSeekClient()
client.SetAPIKey("sk-xxx", "", "")
@@ -32,19 +32,19 @@ func Example_deepseek_backward_compatible() {
}
// ============================================================
// 示例 2: 新的推荐用法(选项模式)
// Example 2: New Recommended Usage (Options Pattern)
// ============================================================
func Example_new_client_basic() {
// 使用默认配置
// Use default configuration
client := mcp.NewClient()
// 使用 DeepSeek
// Use DeepSeek
client = mcp.NewClient(
mcp.WithDeepSeekConfig("sk-xxx"),
)
// 使用 Qwen
// Use Qwen
client = mcp.NewClient(
mcp.WithQwenConfig("sk-xxx"),
)
@@ -53,7 +53,7 @@ func Example_new_client_basic() {
}
func Example_new_client_with_options() {
// 组合多个选项
// Combine multiple options
client := mcp.NewClient(
mcp.WithDeepSeekConfig("sk-xxx"),
mcp.WithTimeout(60*time.Second),
@@ -67,10 +67,10 @@ func Example_new_client_with_options() {
}
// ============================================================
// 示例 3: 自定义日志器
// Example 3: Custom Logger
// ============================================================
// CustomLogger 自定义日志器示例
// CustomLogger custom logger example
type CustomLogger struct{}
func (l *CustomLogger) Debugf(format string, args ...any) {
@@ -90,7 +90,7 @@ func (l *CustomLogger) Errorf(format string, args ...any) {
}
func Example_custom_logger() {
// 使用自定义日志器
// Use custom logger
customLogger := &CustomLogger{}
client := mcp.NewClient(
@@ -103,7 +103,7 @@ func Example_custom_logger() {
}
func Example_no_logger_for_testing() {
// 测试时禁用日志
// Disable logging during testing
client := mcp.NewClient(
mcp.WithLogger(mcp.NewNoopLogger()),
)
@@ -113,16 +113,16 @@ func Example_no_logger_for_testing() {
}
// ============================================================
// 示例 4: 自定义 HTTP 客户端
// Example 4: Custom HTTP Client
// ============================================================
func Example_custom_http_client() {
// 自定义 HTTP 客户端(添加代理、TLS等)
// Custom HTTP client (add proxy, TLS, etc.)
customHTTP := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
// 自定义 TLS、连接池等
// Custom TLS, connection pool, etc.
},
}
@@ -136,16 +136,16 @@ func Example_custom_http_client() {
}
// ============================================================
// 示例 5: DeepSeek 客户端(新 API
// Example 5: DeepSeek Client (New API)
// ============================================================
func Example_deepseek_new_api() {
// 基础用法
// Basic usage
client := mcp.NewDeepSeekClientWithOptions(
mcp.WithAPIKey("sk-xxx"),
)
// 高级用法
// Advanced usage
client = mcp.NewDeepSeekClientWithOptions(
mcp.WithAPIKey("sk-xxx"),
mcp.WithLogger(&CustomLogger{}),
@@ -158,16 +158,16 @@ func Example_deepseek_new_api() {
}
// ============================================================
// 示例 6: Qwen 客户端(新 API
// Example 6: Qwen Client (New API)
// ============================================================
func Example_qwen_new_api() {
// 基础用法
// Basic usage
client := mcp.NewQwenClientWithOptions(
mcp.WithAPIKey("sk-xxx"),
)
// 高级用法
// Advanced usage
client = mcp.NewQwenClientWithOptions(
mcp.WithAPIKey("sk-xxx"),
mcp.WithLogger(&CustomLogger{}),
@@ -179,18 +179,18 @@ func Example_qwen_new_api() {
}
// ============================================================
// 示例 7: 在 trader/auto_trader.go 中的迁移示例
// Example 7: Migration Example in trader/auto_trader.go
// ============================================================
func Example_trader_migration() {
// === 旧代码(继续工作)===
// Old code (continues to work)
oldStyleClient := func(apiKey, customURL, customModel string) mcp.AIClient {
client := mcp.NewDeepSeekClient()
client.SetAPIKey(apiKey, customURL, customModel)
return client
}
// === 新代码(推荐)===
// New code (recommended)
newStyleClient := func(apiKey, customURL, customModel string) mcp.AIClient {
opts := []mcp.ClientOption{
mcp.WithAPIKey(apiKey),
@@ -207,37 +207,37 @@ func Example_trader_migration() {
return mcp.NewDeepSeekClientWithOptions(opts...)
}
// 两种方式都能工作
// Both approaches work
_ = oldStyleClient("sk-xxx", "", "")
_ = newStyleClient("sk-xxx", "", "")
}
// ============================================================
// 示例 8: 测试场景
// Example 8: Testing Scenarios
// ============================================================
// MockHTTPClient Mock HTTP 客户端
// MockHTTPClient Mock HTTP client
type MockHTTPClient struct {
Response string
}
func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
// 返回预设的响应
// Return preset response
return &http.Response{
StatusCode: 200,
Body: nil, // 实际测试中需要实现
Body: nil, // Need to implement in actual tests
}, nil
}
func Example_testing_with_mock() {
// 测试时使用 Mock
// Use Mock during testing
// mockHTTP := &MockHTTPClient{
// Response: `{"choices":[{"message":{"content":"test response"}}]}`,
// }
client := mcp.NewClient(
// mcp.WithHTTPClient(mockHTTP), // 实际测试中使用 mockHTTP
mcp.WithLogger(mcp.NewNoopLogger()), // 禁用日志
// mcp.WithHTTPClient(mockHTTP), // Use mockHTTP in actual tests
mcp.WithLogger(mcp.NewNoopLogger()), // Disable logging
)
result, _ := client.CallWithMessages("system", "user")
@@ -245,20 +245,20 @@ func Example_testing_with_mock() {
}
// ============================================================
// 示例 9: 环境特定配置
// Example 9: Environment-Specific Configuration
// ============================================================
func Example_environment_specific() {
// 开发环境:详细日志
// Development environment: detailed logging
devClient := mcp.NewClient(
mcp.WithDeepSeekConfig("sk-xxx"),
mcp.WithLogger(&CustomLogger{}), // 详细日志
mcp.WithLogger(&CustomLogger{}), // Detailed logging
)
// 生产环境:结构化日志 + 超时保护
// Production environment: structured logging + timeout protection
prodClient := mcp.NewClient(
mcp.WithDeepSeekConfig("sk-xxx"),
// mcp.WithLogger(&ZapLogger{}), // 生产级日志
// mcp.WithLogger(&ZapLogger{}), // Production-grade logging
mcp.WithTimeout(30*time.Second),
mcp.WithMaxRetries(3),
)
@@ -268,11 +268,11 @@ func Example_environment_specific() {
}
// ============================================================
// 示例 10: 完整实战示例
// Example 10: Complete Real-World Example
// ============================================================
func Example_real_world_usage() {
// 创建带有完整配置的客户端
// Create client with complete configuration
client := mcp.NewDeepSeekClientWithOptions(
mcp.WithAPIKey("sk-xxxxxxxxxx"),
mcp.WithTimeout(60*time.Second),
@@ -282,9 +282,9 @@ func Example_real_world_usage() {
mcp.WithLogger(&CustomLogger{}),
)
// 使用客户端
systemPrompt := "你是一个专业的量化交易顾问"
userPrompt := "分析 BTC 当前走势"
// Use client
systemPrompt := "You are a professional quantitative trading advisor"
userPrompt := "Analyze current BTC trend"
result, err := client.CallWithMessages(systemPrompt, userPrompt)
if err != nil {
@@ -292,5 +292,5 @@ func Example_real_world_usage() {
return
}
fmt.Printf("AI 响应: %s\n", result)
fmt.Printf("AI response: %s\n", result)
}
+5 -5
View File
@@ -5,18 +5,18 @@ import (
"time"
)
// AIClient AI客户端公开接口(给外部使用)
// AIClient public AI client interface (for external use)
type AIClient interface {
SetAPIKey(apiKey string, customURL string, customModel string)
SetTimeout(timeout time.Duration)
CallWithMessages(systemPrompt, userPrompt string) (string, error)
CallWithRequest(req *Request) (string, error) // 构建器模式 API(支持高级功能)
CallWithRequest(req *Request) (string, error) // Builder pattern API (supports advanced features)
}
// clientHooks 内部钩子接口(用于子类重写特定步骤)
// 这些方法只在包内部使用,实现动态分派
// clientHooks internal hook interface (for subclass to override specific steps)
// These methods are only used inside the package to implement dynamic dispatch
type clientHooks interface {
// 可被子类重写的钩子方法
// Hook methods that can be overridden by subclass
call(systemPrompt, userPrompt string) (string, error)
+5 -5
View File
@@ -1,8 +1,8 @@
package mcp
// Logger 日志接口(抽象依赖)
// 使用 Printf 风格的方法名,方便集成 logruszap 等主流日志库
// 默认使用全局 logger 包(见 mcp/config.go
// Logger interface (abstract dependency)
// Uses Printf-style method names for easy integration with mainstream logging libraries like logrus, zap, etc.
// Default uses global logger package (see mcp/config.go)
type Logger interface {
Debugf(format string, args ...any)
Infof(format string, args ...any)
@@ -10,7 +10,7 @@ type Logger interface {
Errorf(format string, args ...any)
}
// noopLogger 空日志实现(测试时使用)
// noopLogger no-op logger implementation (used in tests)
type noopLogger struct{}
func (l *noopLogger) Debugf(format string, args ...any) {}
@@ -18,7 +18,7 @@ func (l *noopLogger) Infof(format string, args ...any) {}
func (l *noopLogger) Warnf(format string, args ...any) {}
func (l *noopLogger) Errorf(format string, args ...any) {}
// NewNoopLogger 创建空日志器(测试使用)
// NewNoopLogger creates no-op logger (for testing)
func NewNoopLogger() Logger {
return &noopLogger{}
}
+28 -28
View File
@@ -13,19 +13,19 @@ import (
// Mock Logger
// ============================================================
// MockLogger Mock 日志器(用于测试)
// MockLogger Mock logger (for testing)
type MockLogger struct {
mu sync.Mutex
Logs []LogEntry
Enabled bool // 是否启用日志记录
Enabled bool // Whether logging is enabled
}
// LogEntry 日志条目
// LogEntry log entry
type LogEntry struct {
Level string
Format string
Args []any
Message string // 格式化后的消息
Message string // Formatted message
}
func NewMockLogger() *MockLogger {
@@ -68,14 +68,14 @@ func (m *MockLogger) log(level, format string, args ...any) {
})
}
// GetLogs 获取所有日志
// GetLogs gets all logs
func (m *MockLogger) GetLogs() []LogEntry {
m.mu.Lock()
defer m.mu.Unlock()
return append([]LogEntry{}, m.Logs...)
}
// GetLogsByLevel 获取指定级别的日志
// GetLogsByLevel gets logs by specified level
func (m *MockLogger) GetLogsByLevel(level string) []LogEntry {
m.mu.Lock()
defer m.mu.Unlock()
@@ -89,14 +89,14 @@ func (m *MockLogger) GetLogsByLevel(level string) []LogEntry {
return result
}
// Clear 清空日志
// Clear clears all logs
func (m *MockLogger) Clear() {
m.mu.Lock()
defer m.mu.Unlock()
m.Logs = make([]LogEntry, 0)
}
// HasLog 检查是否包含指定消息
// HasLog checks if contains specified message
func (m *MockLogger) HasLog(level, message string) bool {
m.mu.Lock()
defer m.mu.Unlock()
@@ -110,20 +110,20 @@ func (m *MockLogger) HasLog(level, message string) bool {
}
// ============================================================
// Mock HTTP Client (实现 http.RoundTripper)
// Mock HTTP Client (implements http.RoundTripper)
// ============================================================
// MockHTTPClient Mock HTTP 客户端(实现 http.RoundTripper
// MockHTTPClient Mock HTTP client (implements http.RoundTripper)
type MockHTTPClient struct {
mu sync.Mutex
// 配置
// Configuration
Response string
StatusCode int
Error error
ResponseFunc func(req *http.Request) (*http.Response, error) // 自定义响应函数
ResponseFunc func(req *http.Request) (*http.Response, error) // Custom response function
// 记录
// Recording
Requests []*http.Request
}
@@ -134,32 +134,32 @@ func NewMockHTTPClient() *MockHTTPClient {
}
}
// ToHTTPClient 转换为 http.Client
// ToHTTPClient converts to http.Client
func (m *MockHTTPClient) ToHTTPClient() *http.Client {
return &http.Client{
Transport: m,
}
}
// RoundTrip 实现 http.RoundTripper 接口
// RoundTrip implements http.RoundTripper interface
func (m *MockHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) {
m.mu.Lock()
defer m.mu.Unlock()
// 记录请求
// Record request
m.Requests = append(m.Requests, req)
// 如果有自定义响应函数,使用它
// If custom response function exists, use it
if m.ResponseFunc != nil {
return m.ResponseFunc(req)
}
// 如果设置了错误,返回错误
// If error is set, return error
if m.Error != nil {
return nil, m.Error
}
// 返回模拟响应
// Return mock response
resp := &http.Response{
StatusCode: m.StatusCode,
Body: io.NopCloser(bytes.NewBufferString(m.Response)),
@@ -169,14 +169,14 @@ func (m *MockHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) {
return resp, nil
}
// GetRequests 获取所有请求
// GetRequests gets all requests
func (m *MockHTTPClient) GetRequests() []*http.Request {
m.mu.Lock()
defer m.mu.Unlock()
return append([]*http.Request{}, m.Requests...)
}
// GetLastRequest 获取最后一次请求
// GetLastRequest gets last request
func (m *MockHTTPClient) GetLastRequest() *http.Request {
m.mu.Lock()
defer m.mu.Unlock()
@@ -187,14 +187,14 @@ func (m *MockHTTPClient) GetLastRequest() *http.Request {
return m.Requests[len(m.Requests)-1]
}
// Reset 重置状态
// Reset resets state
func (m *MockHTTPClient) Reset() {
m.mu.Lock()
defer m.mu.Unlock()
m.Requests = make([]*http.Request, 0)
}
// SetSuccessResponse 设置成功响应
// SetSuccessResponse sets success response
func (m *MockHTTPClient) SetSuccessResponse(content string) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -204,7 +204,7 @@ func (m *MockHTTPClient) SetSuccessResponse(content string) {
m.Error = nil
}
// SetErrorResponse 设置错误响应
// SetErrorResponse sets error response
func (m *MockHTTPClient) SetErrorResponse(statusCode int, message string) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -214,7 +214,7 @@ func (m *MockHTTPClient) SetErrorResponse(statusCode int, message string) {
m.Error = nil
}
// SetNetworkError 设置网络错误
// SetNetworkError sets network error
func (m *MockHTTPClient) SetNetworkError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -223,10 +223,10 @@ func (m *MockHTTPClient) SetNetworkError(err error) {
}
// ============================================================
// Mock Client Hooks (用于测试钩子机制)
// Mock Client Hooks (for testing hook mechanism)
// ============================================================
// MockClientHooks Mock 客户端钩子
// MockClientHooks Mock client hooks
type MockClientHooks struct {
BuildRequestBodyCalled int
BuildUrlCalled int
@@ -235,7 +235,7 @@ type MockClientHooks struct {
ParseResponseCalled int
IsRetryableErrorCalled int
// 自定义返回值
// Custom return values
BuildUrlFunc func() string
ParseResponseFunc func([]byte) (string, error)
IsRetryableErrorFunc func(error) bool
+29 -29
View File
@@ -5,16 +5,16 @@ import (
"time"
)
// ClientOption 客户端选项函数(Functional Options 模式)
// ClientOption client option function (Functional Options pattern)
type ClientOption func(*Config)
// ============================================================
// 依赖注入选项
// Dependency Injection Options
// ============================================================
// WithLogger 设置自定义日志器
// WithLogger sets custom logger
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithLogger(customLogger))
func WithLogger(logger Logger) ClientOption {
return func(c *Config) {
@@ -22,9 +22,9 @@ func WithLogger(logger Logger) ClientOption {
}
}
// WithHTTPClient 设置自定义 HTTP 客户端
// WithHTTPClient sets custom HTTP client
//
// 使用示例:
// Usage example:
// httpClient := &http.Client{Timeout: 60 * time.Second}
// client := mcp.NewClient(mcp.WithHTTPClient(httpClient))
func WithHTTPClient(client *http.Client) ClientOption {
@@ -34,12 +34,12 @@ func WithHTTPClient(client *http.Client) ClientOption {
}
// ============================================================
// 超时和重试选项
// Timeout and Retry Options
// ============================================================
// WithTimeout 设置请求超时时间
// WithTimeout sets request timeout duration
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithTimeout(60 * time.Second))
func WithTimeout(timeout time.Duration) ClientOption {
return func(c *Config) {
@@ -48,9 +48,9 @@ func WithTimeout(timeout time.Duration) ClientOption {
}
}
// WithMaxRetries 设置最大重试次数
// WithMaxRetries sets maximum retry count
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithMaxRetries(5))
func WithMaxRetries(maxRetries int) ClientOption {
return func(c *Config) {
@@ -58,9 +58,9 @@ func WithMaxRetries(maxRetries int) ClientOption {
}
}
// WithRetryWaitBase 设置重试等待基础时长
// WithRetryWaitBase sets base retry wait duration
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithRetryWaitBase(3 * time.Second))
func WithRetryWaitBase(waitTime time.Duration) ClientOption {
return func(c *Config) {
@@ -69,12 +69,12 @@ func WithRetryWaitBase(waitTime time.Duration) ClientOption {
}
// ============================================================
// AI 参数选项
// AI Parameter Options
// ============================================================
// WithMaxTokens 设置最大 token
// WithMaxTokens sets maximum token count
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithMaxTokens(4000))
func WithMaxTokens(maxTokens int) ClientOption {
return func(c *Config) {
@@ -82,9 +82,9 @@ func WithMaxTokens(maxTokens int) ClientOption {
}
}
// WithTemperature 设置温度参数
// WithTemperature sets temperature parameter
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithTemperature(0.7))
func WithTemperature(temperature float64) ClientOption {
return func(c *Config) {
@@ -93,38 +93,38 @@ func WithTemperature(temperature float64) ClientOption {
}
// ============================================================
// Provider 配置选项
// Provider Configuration Options
// ============================================================
// WithAPIKey 设置 API Key
// WithAPIKey sets API Key
func WithAPIKey(apiKey string) ClientOption {
return func(c *Config) {
c.APIKey = apiKey
}
}
// WithBaseURL 设置基础 URL
// WithBaseURL sets base URL
func WithBaseURL(baseURL string) ClientOption {
return func(c *Config) {
c.BaseURL = baseURL
}
}
// WithModel 设置模型名称
// WithModel sets model name
func WithModel(model string) ClientOption {
return func(c *Config) {
c.Model = model
}
}
// WithProvider 设置提供商
// WithProvider sets provider
func WithProvider(provider string) ClientOption {
return func(c *Config) {
c.Provider = provider
}
}
// WithUseFullURL 设置是否使用完整 URL
// WithUseFullURL sets whether to use full URL
func WithUseFullURL(useFullURL bool) ClientOption {
return func(c *Config) {
c.UseFullURL = useFullURL
@@ -132,12 +132,12 @@ func WithUseFullURL(useFullURL bool) ClientOption {
}
// ============================================================
// 组合选项(便捷方法)
// Combined Options (Convenience Methods)
// ============================================================
// WithDeepSeekConfig 设置 DeepSeek 配置
// WithDeepSeekConfig sets DeepSeek configuration
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithDeepSeekConfig("sk-xxx"))
func WithDeepSeekConfig(apiKey string) ClientOption {
return func(c *Config) {
@@ -148,9 +148,9 @@ func WithDeepSeekConfig(apiKey string) ClientOption {
}
}
// WithQwenConfig 设置 Qwen 配置
// WithQwenConfig sets Qwen configuration
//
// 使用示例:
// Usage example:
// client := mcp.NewClient(mcp.WithQwenConfig("sk-xxx"))
func WithQwenConfig(apiKey string) ClientOption {
return func(c *Config) {
+15 -15
View File
@@ -7,7 +7,7 @@ import (
)
// ============================================================
// 测试基础选项
// Test Basic Options
// ============================================================
func TestWithProvider(t *testing.T) {
@@ -116,7 +116,7 @@ func TestWithHTTPClient(t *testing.T) {
}
// ============================================================
// 测试预设配置选项
// Test Preset Configuration Options
// ============================================================
func TestWithDeepSeekConfig(t *testing.T) {
@@ -162,7 +162,7 @@ func TestWithQwenConfig(t *testing.T) {
}
// ============================================================
// 测试选项组合
// Test Options Combination
// ============================================================
func TestMultipleOptions(t *testing.T) {
@@ -170,7 +170,7 @@ func TestMultipleOptions(t *testing.T) {
cfg := DefaultConfig()
// 应用多个选项
// Apply multiple options
options := []ClientOption{
WithProvider("test-provider"),
WithAPIKey("sk-test-key"),
@@ -186,7 +186,7 @@ func TestMultipleOptions(t *testing.T) {
opt(cfg)
}
// 验证所有选项都被应用
// Verify all options are applied
if cfg.Provider != "test-provider" {
t.Error("Provider should be set")
}
@@ -223,14 +223,14 @@ func TestMultipleOptions(t *testing.T) {
func TestOptionsOverride(t *testing.T) {
cfg := DefaultConfig()
// 先应用 DeepSeek 配置
// First apply DeepSeek configuration
WithDeepSeekConfig("sk-deepseek-key")(cfg)
// 然后覆盖某些选项
// Then override some options
WithModel("custom-model")(cfg)
WithMaxTokens(5000)(cfg)
// 验证覆盖成功
// Verify override succeeded
if cfg.Model != "custom-model" {
t.Errorf("Model should be overridden to 'custom-model', got '%s'", cfg.Model)
}
@@ -239,7 +239,7 @@ func TestOptionsOverride(t *testing.T) {
t.Errorf("MaxTokens should be overridden to 5000, got %d", cfg.MaxTokens)
}
// 验证其他 DeepSeek 配置保持不变
// Verify other DeepSeek configurations remain unchanged
if cfg.Provider != ProviderDeepSeek {
t.Error("Provider should still be DeepSeek")
}
@@ -250,7 +250,7 @@ func TestOptionsOverride(t *testing.T) {
}
// ============================================================
// 测试与客户端集成
// Test Integration with Client
// ============================================================
func TestOptionsWithNewClient(t *testing.T) {
@@ -266,7 +266,7 @@ func TestOptionsWithNewClient(t *testing.T) {
c := client.(*Client)
// 验证选项被正确应用到客户端
// Verify options are correctly applied to client
if c.Provider != "test-provider" {
t.Error("Provider should be set from options")
}
@@ -299,7 +299,7 @@ func TestOptionsWithDeepSeekClient(t *testing.T) {
dsClient := client.(*DeepSeekClient)
// 验证 DeepSeek 默认值
// Verify DeepSeek default values
if dsClient.Provider != ProviderDeepSeek {
t.Error("Provider should be DeepSeek")
}
@@ -312,7 +312,7 @@ func TestOptionsWithDeepSeekClient(t *testing.T) {
t.Error("Model should be DeepSeek default")
}
// 验证自定义选项
// Verify custom options
if dsClient.APIKey != "sk-deepseek-key" {
t.Error("APIKey should be set from options")
}
@@ -337,7 +337,7 @@ func TestOptionsWithQwenClient(t *testing.T) {
qwenClient := client.(*QwenClient)
// 验证 Qwen 默认值
// Verify Qwen default values
if qwenClient.Provider != ProviderQwen {
t.Error("Provider should be Qwen")
}
@@ -350,7 +350,7 @@ func TestOptionsWithQwenClient(t *testing.T) {
t.Error("Model should be Qwen default")
}
// 验证自定义选项
// Verify custom options
if qwenClient.APIKey != "sk-qwen-key" {
t.Error("APIKey should be set from options")
}
+15 -15
View File
@@ -14,45 +14,45 @@ type QwenClient struct {
*Client
}
// NewQwenClient 创建 Qwen 客户端(向前兼容)
// NewQwenClient creates Qwen client (backward compatible)
//
// Deprecated: 推荐使用 NewQwenClientWithOptions 以获得更好的灵活性
// Deprecated: Recommend using NewQwenClientWithOptions for better flexibility
func NewQwenClient() AIClient {
return NewQwenClientWithOptions()
}
// NewQwenClientWithOptions 创建 Qwen 客户端(支持选项模式)
// NewQwenClientWithOptions creates Qwen client (supports options pattern)
//
// 使用示例:
// // 基础用法
// Usage examples:
// // Basic usage
// client := mcp.NewQwenClientWithOptions()
//
// // 自定义配置
// // Custom configuration
// client := mcp.NewQwenClientWithOptions(
// mcp.WithAPIKey("sk-xxx"),
// mcp.WithLogger(customLogger),
// mcp.WithTimeout(60*time.Second),
// )
func NewQwenClientWithOptions(opts ...ClientOption) AIClient {
// 1. 创建 Qwen 预设选项
// 1. Create Qwen preset options
qwenOpts := []ClientOption{
WithProvider(ProviderQwen),
WithModel(DefaultQwenModel),
WithBaseURL(DefaultQwenBaseURL),
}
// 2. 合并用户选项(用户选项优先级更高)
// 2. Merge user options (user options have higher priority)
allOpts := append(qwenOpts, opts...)
// 3. 创建基础客户端
// 3. Create base client
baseClient := NewClient(allOpts...).(*Client)
// 4. 创建 Qwen 客户端
// 4. Create Qwen client
qwenClient := &QwenClient{
Client: baseClient,
}
// 5. 设置 hooks 指向 QwenClient(实现动态分派)
// 5. Set hooks to point to QwenClient (implement dynamic dispatch)
baseClient.hooks = qwenClient
return qwenClient
@@ -66,15 +66,15 @@ func (qwenClient *QwenClient) SetAPIKey(apiKey string, customURL string, customM
}
if customURL != "" {
qwenClient.BaseURL = customURL
qwenClient.logger.Infof("🔧 [MCP] Qwen 使用自定义 BaseURL: %s", customURL)
qwenClient.logger.Infof("🔧 [MCP] Qwen using custom BaseURL: %s", customURL)
} else {
qwenClient.logger.Infof("🔧 [MCP] Qwen 使用默认 BaseURL: %s", qwenClient.BaseURL)
qwenClient.logger.Infof("🔧 [MCP] Qwen using default BaseURL: %s", qwenClient.BaseURL)
}
if customModel != "" {
qwenClient.Model = customModel
qwenClient.logger.Infof("🔧 [MCP] Qwen 使用自定义 Model: %s", customModel)
qwenClient.logger.Infof("🔧 [MCP] Qwen using custom Model: %s", customModel)
} else {
qwenClient.logger.Infof("🔧 [MCP] Qwen 使用默认 Model: %s", qwenClient.Model)
qwenClient.logger.Infof("🔧 [MCP] Qwen using default Model: %s", qwenClient.Model)
}
}
+22 -22
View File
@@ -6,7 +6,7 @@ import (
)
// ============================================================
// 测试 QwenClient 创建和配置
// Test QwenClient Creation and Configuration
// ============================================================
func TestNewQwenClient_Default(t *testing.T) {
@@ -16,13 +16,13 @@ func TestNewQwenClient_Default(t *testing.T) {
t.Fatal("client should not be nil")
}
// 类型断言检查
// Type assertion check
qwenClient, ok := client.(*QwenClient)
if !ok {
t.Fatal("client should be *QwenClient")
}
// 验证默认值
// Verify default values
if qwenClient.Provider != ProviderQwen {
t.Errorf("Provider should be '%s', got '%s'", ProviderQwen, qwenClient.Provider)
}
@@ -58,7 +58,7 @@ func TestNewQwenClientWithOptions(t *testing.T) {
qwenClient := client.(*QwenClient)
// 验证自定义选项被应用
// Verify custom options are applied
if qwenClient.logger != mockLogger {
t.Error("logger should be set from option")
}
@@ -75,7 +75,7 @@ func TestNewQwenClientWithOptions(t *testing.T) {
t.Error("MaxTokens should be 4000")
}
// 验证 Qwen 默认值仍然保留
// Verify Qwen default values are retained
if qwenClient.Provider != ProviderQwen {
t.Errorf("Provider should still be '%s'", ProviderQwen)
}
@@ -86,7 +86,7 @@ func TestNewQwenClientWithOptions(t *testing.T) {
}
// ============================================================
// 测试 SetAPIKey
// Test SetAPIKey
// ============================================================
func TestQwenClient_SetAPIKey(t *testing.T) {
@@ -97,20 +97,20 @@ func TestQwenClient_SetAPIKey(t *testing.T) {
qwenClient := client.(*QwenClient)
// 测试设置 API Key(默认 URL Model
// Test setting API Key (default URL and Model)
qwenClient.SetAPIKey("sk-test-key-12345678", "", "")
if qwenClient.APIKey != "sk-test-key-12345678" {
t.Errorf("APIKey should be 'sk-test-key-12345678', got '%s'", qwenClient.APIKey)
}
// 验证日志记录
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
if len(logs) == 0 {
t.Error("should have logged API key setting")
}
// 验证 BaseURL Model 保持默认
// Verify BaseURL and Model remain default
if qwenClient.BaseURL != DefaultQwenBaseURL {
t.Error("BaseURL should remain default")
}
@@ -135,11 +135,11 @@ func TestQwenClient_SetAPIKey_WithCustomURL(t *testing.T) {
t.Errorf("BaseURL should be '%s', got '%s'", customURL, qwenClient.BaseURL)
}
// 验证日志记录
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
hasCustomURLLog := false
for _, log := range logs {
if log.Format == "🔧 [MCP] Qwen 使用自定义 BaseURL: %s" {
if log.Format == "🔧 [MCP] Qwen using custom BaseURL: %s" {
hasCustomURLLog = true
break
}
@@ -165,11 +165,11 @@ func TestQwenClient_SetAPIKey_WithCustomModel(t *testing.T) {
t.Errorf("Model should be '%s', got '%s'", customModel, qwenClient.Model)
}
// 验证日志记录
// Verify logging
logs := mockLogger.GetLogsByLevel("INFO")
hasCustomModelLog := false
for _, log := range logs {
if log.Format == "🔧 [MCP] Qwen 使用自定义 Model: %s" {
if log.Format == "🔧 [MCP] Qwen using custom Model: %s" {
hasCustomModelLog = true
break
}
@@ -181,7 +181,7 @@ func TestQwenClient_SetAPIKey_WithCustomModel(t *testing.T) {
}
// ============================================================
// 测试集成功能
// Test Integration Features
// ============================================================
func TestQwenClient_CallWithMessages_Success(t *testing.T) {
@@ -205,7 +205,7 @@ func TestQwenClient_CallWithMessages_Success(t *testing.T) {
t.Errorf("expected 'Qwen AI response', got '%s'", result)
}
// 验证请求
// Verify request
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Fatalf("expected 1 request, got %d", len(requests))
@@ -213,19 +213,19 @@ func TestQwenClient_CallWithMessages_Success(t *testing.T) {
req := requests[0]
// 验证 URL
// Verify URL
expectedURL := DefaultQwenBaseURL + "/chat/completions"
if req.URL.String() != expectedURL {
t.Errorf("expected URL '%s', got '%s'", expectedURL, req.URL.String())
}
// 验证 Authorization header
// Verify Authorization header
authHeader := req.Header.Get("Authorization")
if authHeader != "Bearer sk-test-key" {
t.Errorf("expected 'Bearer sk-test-key', got '%s'", authHeader)
}
// 验证 Content-Type
// Verify Content-Type
if req.Header.Get("Content-Type") != "application/json" {
t.Error("Content-Type should be application/json")
}
@@ -242,7 +242,7 @@ func TestQwenClient_Timeout(t *testing.T) {
t.Errorf("expected timeout 30s, got %v", qwenClient.httpClient.Timeout)
}
// 测试 SetTimeout
// Test SetTimeout
client.SetTimeout(60 * time.Second)
if qwenClient.httpClient.Timeout != 60*time.Second {
@@ -251,19 +251,19 @@ func TestQwenClient_Timeout(t *testing.T) {
}
// ============================================================
// 测试 hooks 机制
// Test hooks Mechanism
// ============================================================
func TestQwenClient_HooksIntegration(t *testing.T) {
client := NewQwenClientWithOptions()
qwenClient := client.(*QwenClient)
// 验证 hooks 指向 qwenClient 自己(实现多态)
// Verify hooks point to qwenClient itself (implements polymorphism)
if qwenClient.hooks != qwenClient {
t.Error("hooks should point to qwenClient for polymorphism")
}
// 验证 buildUrl 使用 Qwen 配置
// Verify buildUrl uses Qwen configuration
url := qwenClient.buildUrl()
expectedURL := DefaultQwenBaseURL + "/chat/completions"
if url != expectedURL {
+28 -28
View File
@@ -1,45 +1,45 @@
package mcp
// Message 表示一条对话消息
// Message represents a conversation message
type Message struct {
Role string `json:"role"` // "system", "user", "assistant"
Content string `json:"content"` // 消息内容
Content string `json:"content"` // Message content
}
// Tool 表示 AI 可以调用的工具/函数
// Tool represents a tool/function that AI can call
type Tool struct {
Type string `json:"type"` // 通常为 "function"
Function FunctionDef `json:"function"` // 函数定义
Type string `json:"type"` // Usually "function"
Function FunctionDef `json:"function"` // Function definition
}
// FunctionDef 函数定义
// FunctionDef function definition
type FunctionDef struct {
Name string `json:"name"` // 函数名
Description string `json:"description,omitempty"` // 函数描述
Parameters map[string]any `json:"parameters,omitempty"` // 参数 schema (JSON Schema)
Name string `json:"name"` // Function name
Description string `json:"description,omitempty"` // Function description
Parameters map[string]any `json:"parameters,omitempty"` // Parameter schema (JSON Schema)
}
// Request AI API 请求(支持高级功能)
// Request AI API request (supports advanced features)
type Request struct {
// 基础字段
Model string `json:"model"` // 模型名称
Messages []Message `json:"messages"` // 对话消息列表
Stream bool `json:"stream,omitempty"` // 是否流式响应
// Basic fields
Model string `json:"model"` // Model name
Messages []Message `json:"messages"` // Conversation message list
Stream bool `json:"stream,omitempty"` // Whether to stream response
// 可选参数(用于精细控制)
Temperature *float64 `json:"temperature,omitempty"` // 温度 (0-2),控制随机性
MaxTokens *int `json:"max_tokens,omitempty"` // 最大 token
TopP *float64 `json:"top_p,omitempty"` // 核采样参数 (0-1)
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // 频率惩罚 (-2 to 2)
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // 存在惩罚 (-2 to 2)
Stop []string `json:"stop,omitempty"` // 停止序列
// Optional parameters (for fine-grained control)
Temperature *float64 `json:"temperature,omitempty"` // Temperature (0-2), controls randomness
MaxTokens *int `json:"max_tokens,omitempty"` // Maximum token count
TopP *float64 `json:"top_p,omitempty"` // Nucleus sampling parameter (0-1)
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Frequency penalty (-2 to 2)
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Presence penalty (-2 to 2)
Stop []string `json:"stop,omitempty"` // Stop sequences
// 高级功能
Tools []Tool `json:"tools,omitempty"` // 可用工具列表
ToolChoice string `json:"tool_choice,omitempty"` // 工具选择策略 ("auto", "none", {"type": "function", "function": {"name": "xxx"}})
// Advanced features
Tools []Tool `json:"tools,omitempty"` // Available tools list
ToolChoice string `json:"tool_choice,omitempty"` // Tool choice strategy ("auto", "none", {"type": "function", "function": {"name": "xxx"}})
}
// NewMessage 创建一条消息
// NewMessage creates a message
func NewMessage(role, content string) Message {
return Message{
Role: role,
@@ -47,7 +47,7 @@ func NewMessage(role, content string) Message {
}
}
// NewSystemMessage 创建系统消息
// NewSystemMessage creates a system message
func NewSystemMessage(content string) Message {
return Message{
Role: "system",
@@ -55,7 +55,7 @@ func NewSystemMessage(content string) Message {
}
}
// NewUserMessage 创建用户消息
// NewUserMessage creates a user message
func NewUserMessage(content string) Message {
return Message{
Role: "user",
@@ -63,7 +63,7 @@ func NewUserMessage(content string) Message {
}
}
// NewAssistantMessage 创建助手消息
// NewAssistantMessage creates an assistant message
func NewAssistantMessage(content string) Message {
return Message{
Role: "assistant",
+49 -49
View File
@@ -4,7 +4,7 @@ import (
"errors"
)
// RequestBuilder 请求构建器
// RequestBuilder request builder
type RequestBuilder struct {
model string
messages []Message
@@ -19,9 +19,9 @@ type RequestBuilder struct {
toolChoice string
}
// NewRequestBuilder 创建请求构建器
// NewRequestBuilder creates request builder
//
// 使用示例:
// Usage example:
// request := NewRequestBuilder().
// WithSystemPrompt("You are helpful").
// WithUserPrompt("Hello").
@@ -35,26 +35,26 @@ func NewRequestBuilder() *RequestBuilder {
}
// ============================================================
// 模型和流式配置
// Model and Stream Configuration
// ============================================================
// WithModel 设置模型名称
// WithModel sets model name
func (b *RequestBuilder) WithModel(model string) *RequestBuilder {
b.model = model
return b
}
// WithStream 设置是否使用流式响应
// WithStream sets whether to use streaming response
func (b *RequestBuilder) WithStream(stream bool) *RequestBuilder {
b.stream = stream
return b
}
// ============================================================
// 消息构建方法
// Message Building Methods
// ============================================================
// WithSystemPrompt 添加系统提示词(便捷方法)
// WithSystemPrompt adds system prompt (convenience method)
func (b *RequestBuilder) WithSystemPrompt(prompt string) *RequestBuilder {
if prompt != "" {
b.messages = append(b.messages, NewSystemMessage(prompt))
@@ -62,7 +62,7 @@ func (b *RequestBuilder) WithSystemPrompt(prompt string) *RequestBuilder {
return b
}
// WithUserPrompt 添加用户提示词(便捷方法)
// WithUserPrompt adds user prompt (convenience method)
func (b *RequestBuilder) WithUserPrompt(prompt string) *RequestBuilder {
if prompt != "" {
b.messages = append(b.messages, NewUserMessage(prompt))
@@ -70,17 +70,17 @@ func (b *RequestBuilder) WithUserPrompt(prompt string) *RequestBuilder {
return b
}
// AddSystemMessage 添加系统消息
// AddSystemMessage adds system message
func (b *RequestBuilder) AddSystemMessage(content string) *RequestBuilder {
return b.WithSystemPrompt(content)
}
// AddUserMessage 添加用户消息
// AddUserMessage adds user message
func (b *RequestBuilder) AddUserMessage(content string) *RequestBuilder {
return b.WithUserPrompt(content)
}
// AddAssistantMessage 添加助手消息(用于多轮对话上下文)
// AddAssistantMessage adds assistant message (for multi-turn conversation context)
func (b *RequestBuilder) AddAssistantMessage(content string) *RequestBuilder {
if content != "" {
b.messages = append(b.messages, NewAssistantMessage(content))
@@ -88,7 +88,7 @@ func (b *RequestBuilder) AddAssistantMessage(content string) *RequestBuilder {
return b
}
// AddMessage 添加自定义角色的消息
// AddMessage adds message with custom role
func (b *RequestBuilder) AddMessage(role, content string) *RequestBuilder {
if content != "" {
b.messages = append(b.messages, NewMessage(role, content))
@@ -96,33 +96,33 @@ func (b *RequestBuilder) AddMessage(role, content string) *RequestBuilder {
return b
}
// AddMessages 批量添加消息
// AddMessages adds messages in batch
func (b *RequestBuilder) AddMessages(messages ...Message) *RequestBuilder {
b.messages = append(b.messages, messages...)
return b
}
// AddConversationHistory 添加对话历史
// AddConversationHistory adds conversation history
func (b *RequestBuilder) AddConversationHistory(history []Message) *RequestBuilder {
b.messages = append(b.messages, history...)
return b
}
// ClearMessages 清空所有消息
// ClearMessages clears all messages
func (b *RequestBuilder) ClearMessages() *RequestBuilder {
b.messages = make([]Message, 0)
return b
}
// ============================================================
// 参数控制方法
// Parameter Control Methods
// ============================================================
// WithTemperature 设置温度参数 (0-2)
// 较高的温度(如 1.2)会使输出更随机,较低的温度(如 0.2)会使输出更确定
// WithTemperature sets temperature parameter (0-2)
// Higher temperature (e.g. 1.2) makes output more random, lower temperature (e.g. 0.2) makes output more deterministic
func (b *RequestBuilder) WithTemperature(t float64) *RequestBuilder {
if t < 0 || t > 2 {
// 可以选择 panic 或者静默忽略,这里选择限制范围
// Can choose to panic or silently ignore, here we choose to limit the range
if t < 0 {
t = 0
}
@@ -134,7 +134,7 @@ func (b *RequestBuilder) WithTemperature(t float64) *RequestBuilder {
return b
}
// WithMaxTokens 设置最大 token
// WithMaxTokens sets maximum token count
func (b *RequestBuilder) WithMaxTokens(tokens int) *RequestBuilder {
if tokens > 0 {
b.maxTokens = &tokens
@@ -142,8 +142,8 @@ func (b *RequestBuilder) WithMaxTokens(tokens int) *RequestBuilder {
return b
}
// WithTopP 设置 top-p 核采样参数 (0-1)
// 控制考虑的 token 范围,较小的值(如 0.1)使输出更聚焦
// WithTopP sets top-p nucleus sampling parameter (0-1)
// Controls the range of tokens considered, smaller values (e.g. 0.1) make output more focused
func (b *RequestBuilder) WithTopP(p float64) *RequestBuilder {
if p >= 0 && p <= 1 {
b.topP = &p
@@ -151,8 +151,8 @@ func (b *RequestBuilder) WithTopP(p float64) *RequestBuilder {
return b
}
// WithFrequencyPenalty 设置频率惩罚 (-2 to 2)
// 正值会根据 token 在文本中出现的频率惩罚它们,减少重复
// WithFrequencyPenalty sets frequency penalty (-2 to 2)
// Positive values penalize tokens based on their frequency in the text, reducing repetition
func (b *RequestBuilder) WithFrequencyPenalty(penalty float64) *RequestBuilder {
if penalty >= -2 && penalty <= 2 {
b.frequencyPenalty = &penalty
@@ -160,8 +160,8 @@ func (b *RequestBuilder) WithFrequencyPenalty(penalty float64) *RequestBuilder {
return b
}
// WithPresencePenalty 设置存在惩罚 (-2 to 2)
// 正值会根据 token 是否出现在文本中惩罚它们,增加话题多样性
// WithPresencePenalty sets presence penalty (-2 to 2)
// Positive values penalize tokens based on whether they appear in the text, increasing topic diversity
func (b *RequestBuilder) WithPresencePenalty(penalty float64) *RequestBuilder {
if penalty >= -2 && penalty <= 2 {
b.presencePenalty = &penalty
@@ -169,14 +169,14 @@ func (b *RequestBuilder) WithPresencePenalty(penalty float64) *RequestBuilder {
return b
}
// WithStopSequences 设置停止序列
// 当模型生成这些序列之一时,将停止生成
// WithStopSequences sets stop sequences
// Model will stop generating when it generates one of these sequences
func (b *RequestBuilder) WithStopSequences(sequences []string) *RequestBuilder {
b.stop = sequences
return b
}
// AddStopSequence 添加单个停止序列
// AddStopSequence adds a single stop sequence
func (b *RequestBuilder) AddStopSequence(sequence string) *RequestBuilder {
if sequence != "" {
b.stop = append(b.stop, sequence)
@@ -185,16 +185,16 @@ func (b *RequestBuilder) AddStopSequence(sequence string) *RequestBuilder {
}
// ============================================================
// 工具/函数调用相关
// Tool/Function Calling Related
// ============================================================
// AddTool 添加工具
// AddTool adds a tool
func (b *RequestBuilder) AddTool(tool Tool) *RequestBuilder {
b.tools = append(b.tools, tool)
return b
}
// AddFunction 添加函数(便捷方法)
// AddFunction adds a function (convenience method)
func (b *RequestBuilder) AddFunction(name, description string, parameters map[string]any) *RequestBuilder {
tool := Tool{
Type: "function",
@@ -208,27 +208,27 @@ func (b *RequestBuilder) AddFunction(name, description string, parameters map[st
return b
}
// WithToolChoice 设置工具选择策略
// - "auto": 自动选择是否调用工具
// - "none": 不调用工具
// - 也可以指定特定工具: `{"type": "function", "function": {"name": "my_function"}}`
// WithToolChoice sets tool choice strategy
// - "auto": automatically choose whether to call tools
// - "none": don't call tools
// - Can also specify a specific tool: `{"type": "function", "function": {"name": "my_function"}}`
func (b *RequestBuilder) WithToolChoice(choice string) *RequestBuilder {
b.toolChoice = choice
return b
}
// ============================================================
// 构建方法
// Build Methods
// ============================================================
// Build 构建请求对象
// Build builds request object
func (b *RequestBuilder) Build() (*Request, error) {
// 验证:至少需要一条消息
// Validation: at least one message is required
if len(b.messages) == 0 {
return nil, errors.New("至少需要一条消息")
return nil, errors.New("at least one message is required")
}
// 创建请求
// Create request
req := &Request{
Model: b.model,
Messages: b.messages,
@@ -238,7 +238,7 @@ func (b *RequestBuilder) Build() (*Request, error) {
ToolChoice: b.toolChoice,
}
// 只设置非 nil 的可选参数(避免发送 0 值覆盖服务端默认值)
// Only set non-nil optional parameters (avoid sending 0 values that override server defaults)
if b.temperature != nil {
req.Temperature = b.temperature
}
@@ -258,8 +258,8 @@ func (b *RequestBuilder) Build() (*Request, error) {
return req, nil
}
// MustBuild 构建请求对象,如果失败则 panic
// 适用于构建过程中确定不会出错的场景
// MustBuild builds request object, panics if failed
// Suitable for scenarios where build is guaranteed not to fail
func (b *RequestBuilder) MustBuild() *Request {
req, err := b.Build()
if err != nil {
@@ -269,10 +269,10 @@ func (b *RequestBuilder) MustBuild() *Request {
}
// ============================================================
// 便捷方法:预设场景
// Convenience Methods: Preset Scenarios
// ============================================================
// ForChat 创建用于聊天的构建器(预设合理的参数)
// ForChat creates builder for chat (preset with reasonable parameters)
func ForChat() *RequestBuilder {
temp := 0.7
tokens := 2000
@@ -284,7 +284,7 @@ func ForChat() *RequestBuilder {
}
}
// ForCodeGeneration 创建用于代码生成的构建器(低温度,更确定)
// ForCodeGeneration creates builder for code generation (low temperature, more deterministic)
func ForCodeGeneration() *RequestBuilder {
temp := 0.2
tokens := 2000
@@ -298,7 +298,7 @@ func ForCodeGeneration() *RequestBuilder {
}
}
// ForCreativeWriting 创建用于创意写作的构建器(高温度,更随机)
// ForCreativeWriting creates builder for creative writing (high temperature, more random)
func ForCreativeWriting() *RequestBuilder {
temp := 1.2
tokens := 4000
+17 -17
View File
@@ -6,7 +6,7 @@ import (
)
// ============================================================
// 测试 RequestBuilder 基本功能
// Test RequestBuilder Basic Features
// ============================================================
func TestRequestBuilder_BasicUsage(t *testing.T) {
@@ -39,13 +39,13 @@ func TestRequestBuilder_EmptyMessages(t *testing.T) {
t.Error("Build should error when no messages")
}
if err.Error() != "至少需要一条消息" {
if err.Error() != "at least one message is required" {
t.Errorf("unexpected error: %v", err)
}
}
// ============================================================
// 测试消息构建方法
// Test Message Building Methods
// ============================================================
func TestRequestBuilder_MultipleMessages(t *testing.T) {
@@ -85,7 +85,7 @@ func TestRequestBuilder_AddConversationHistory(t *testing.T) {
}
// ============================================================
// 测试参数控制方法
// Test Parameter Control Methods
// ============================================================
func TestRequestBuilder_WithTemperature(t *testing.T) {
@@ -165,7 +165,7 @@ func TestRequestBuilder_WithStopSequences(t *testing.T) {
}
// ============================================================
// 测试工具/函数调用
// Test Tool/Function Calling
// ============================================================
func TestRequestBuilder_AddTool(t *testing.T) {
@@ -229,7 +229,7 @@ func TestRequestBuilder_AddFunction(t *testing.T) {
}
// ============================================================
// 测试便捷方法
// Test Convenience Methods
// ============================================================
func TestRequestBuilder_ForChat(t *testing.T) {
@@ -287,7 +287,7 @@ func TestRequestBuilder_ForCreativeWriting(t *testing.T) {
}
// ============================================================
// 测试 CallWithRequest 集成
// Test CallWithRequest Integration
// ============================================================
func TestClient_CallWithRequest_Success(t *testing.T) {
@@ -317,25 +317,25 @@ func TestClient_CallWithRequest_Success(t *testing.T) {
t.Errorf("expected 'Builder response', got '%s'", result)
}
// 验证请求体
// Verify request body
requests := mockHTTP.GetRequests()
if len(requests) != 1 {
t.Fatalf("expected 1 request, got %d", len(requests))
}
// 解析请求体验证参数
// Parse request body to verify parameters
var body map[string]interface{}
decoder := json.NewDecoder(requests[0].Body)
if err := decoder.Decode(&body); err != nil {
t.Fatalf("failed to decode request body: %v", err)
}
// 验证 temperature
// Verify temperature
if body["temperature"] != 0.8 {
t.Errorf("expected temperature 0.8, got %v", body["temperature"])
}
// 验证 messages
// Verify messages
messages, ok := body["messages"].([]interface{})
if !ok || len(messages) != 2 {
t.Error("messages not correctly formatted")
@@ -353,7 +353,7 @@ func TestClient_CallWithRequest_MultiRound(t *testing.T) {
WithAPIKey("sk-test-key"),
)
// 构建多轮对话
// Build multi-round conversation
request := NewRequestBuilder().
AddSystemMessage("You are a trading advisor").
AddUserMessage("Analyze BTC").
@@ -372,7 +372,7 @@ func TestClient_CallWithRequest_MultiRound(t *testing.T) {
t.Errorf("expected 'Multi-round response', got '%s'", result)
}
// 验证请求体包含所有消息
// Verify request body contains all messages
requests := mockHTTP.GetRequests()
var body map[string]interface{}
json.NewDecoder(requests[0].Body).Decode(&body)
@@ -411,7 +411,7 @@ func TestClient_CallWithRequest_WithTools(t *testing.T) {
t.Fatalf("should not error: %v", err)
}
// 验证请求体包含 tools
// Verify request body contains tools
requests := mockHTTP.GetRequests()
var body map[string]interface{}
json.NewDecoder(requests[0].Body).Decode(&body)
@@ -440,7 +440,7 @@ func TestClient_CallWithRequest_NoAPIKey(t *testing.T) {
t.Error("should error when API key not set")
}
if err.Error() != "AI API密钥未设置,请先调用 SetAPIKey" {
if err.Error() != "AI API key not set, please call SetAPIKey first" {
t.Errorf("unexpected error: %v", err)
}
}
@@ -456,7 +456,7 @@ func TestClient_CallWithRequest_UsesClientModel(t *testing.T) {
WithAPIKey("sk-test-key"),
)
// Request 不设置 model,应该使用 Client model
// Request does not set model, should use Client's model
request := NewRequestBuilder().
WithUserPrompt("Hello").
MustBuild()
@@ -467,7 +467,7 @@ func TestClient_CallWithRequest_UsesClientModel(t *testing.T) {
client.CallWithRequest(request)
// 验证使用了 DeepSeek model
// Verify DeepSeek's model is used
requests := mockHTTP.GetRequests()
var body map[string]interface{}
json.NewDecoder(requests[0].Body).Decode(&body)
+144 -144
View File
@@ -12,7 +12,7 @@ import (
"time"
)
// defaultMainstreamCoins 默认主流币种池(从配置文件读取)
// defaultMainstreamCoins default mainstream coin pool (read from config file)
var defaultMainstreamCoins = []string{
"BTCUSDT",
"ETHUSDT",
@@ -24,42 +24,42 @@ var defaultMainstreamCoins = []string{
"HYPEUSDT",
}
// CoinPoolConfig 币种池配置
// CoinPoolConfig coin pool configuration
type CoinPoolConfig struct {
APIURL string
Timeout time.Duration
CacheDir string
UseDefaultCoins bool // 是否使用默认主流币种
UseDefaultCoins bool // Whether to use default mainstream coins
}
var coinPoolConfig = CoinPoolConfig{
APIURL: "",
Timeout: 30 * time.Second, // 增加到30秒
Timeout: 30 * time.Second, // Increased to 30 seconds
CacheDir: "coin_pool_cache",
UseDefaultCoins: false, // 默认不使用
UseDefaultCoins: false, // Default is not to use
}
// CoinPoolCache 币种池缓存
// CoinPoolCache coin pool cache
type CoinPoolCache struct {
Coins []CoinInfo `json:"coins"`
FetchedAt time.Time `json:"fetched_at"`
SourceType string `json:"source_type"` // "api" or "cache"
}
// CoinInfo 币种信息
// CoinInfo coin information
type CoinInfo struct {
Pair string `json:"pair"` // 交易对符号(例如:BTCUSDT
Score float64 `json:"score"` // 当前评分
StartTime int64 `json:"start_time"` // 开始时间(Unix时间戳)
StartPrice float64 `json:"start_price"` // 开始价格
LastScore float64 `json:"last_score"` // 最新评分
MaxScore float64 `json:"max_score"` // 最高评分
MaxPrice float64 `json:"max_price"` // 最高价格
IncreasePercent float64 `json:"increase_percent"` // 涨幅百分比
IsAvailable bool `json:"-"` // 是否可交易(内部使用)
Pair string `json:"pair"` // Trading pair symbol (e.g.: BTCUSDT)
Score float64 `json:"score"` // Current score
StartTime int64 `json:"start_time"` // Start time (Unix timestamp)
StartPrice float64 `json:"start_price"` // Start price
LastScore float64 `json:"last_score"` // Latest score
MaxScore float64 `json:"max_score"` // Highest score
MaxPrice float64 `json:"max_price"` // Highest price
IncreasePercent float64 `json:"increase_percent"` // Increase percentage
IsAvailable bool `json:"-"` // Whether tradable (internal use)
}
// CoinPoolAPIResponse API返回的原始数据结构
// CoinPoolAPIResponse raw data structure returned by API
type CoinPoolAPIResponse struct {
Success bool `json:"success"`
Data struct {
@@ -68,85 +68,85 @@ type CoinPoolAPIResponse struct {
} `json:"data"`
}
// SetCoinPoolAPI 设置币种池API
// SetCoinPoolAPI sets coin pool API
func SetCoinPoolAPI(apiURL string) {
coinPoolConfig.APIURL = apiURL
}
// SetOITopAPI 设置OI Top API
// SetOITopAPI sets OI Top API
func SetOITopAPI(apiURL string) {
oiTopConfig.APIURL = apiURL
}
// SetUseDefaultCoins 设置是否使用默认主流币种
// SetUseDefaultCoins sets whether to use default mainstream coins
func SetUseDefaultCoins(useDefault bool) {
coinPoolConfig.UseDefaultCoins = useDefault
}
// SetDefaultCoins 设置默认主流币种列表
// SetDefaultCoins sets default mainstream coin list
func SetDefaultCoins(coins []string) {
if len(coins) > 0 {
defaultMainstreamCoins = coins
log.Printf("✓ 已设置默认币种池(共%d个币种): %v", len(coins), coins)
log.Printf("✓ Default coin pool set (%d coins): %v", len(coins), coins)
}
}
// GetCoinPool 获取币种池列表(带重试和缓存机制)
// GetCoinPool retrieves coin pool list (with retry and cache mechanism)
func GetCoinPool() ([]CoinInfo, error) {
// 优先检查是否启用默认币种列表
// First check if default coin list is enabled
if coinPoolConfig.UseDefaultCoins {
log.Printf("✓ 已启用默认主流币种列表")
log.Printf("✓ Default mainstream coin list enabled")
return convertSymbolsToCoins(defaultMainstreamCoins), nil
}
// 检查API URL是否配置
// Check if API URL is configured
if strings.TrimSpace(coinPoolConfig.APIURL) == "" {
log.Printf("⚠️ 未配置币种池API URL,使用默认主流币种列表")
log.Printf("⚠️ Coin pool API URL not configured, using default mainstream coin list")
return convertSymbolsToCoins(defaultMainstreamCoins), nil
}
maxRetries := 3
var lastErr error
// 尝试从API获取
// Try to fetch from API
for attempt := 1; attempt <= maxRetries; attempt++ {
if attempt > 1 {
log.Printf("⚠️ 第%d次重试获取币种池(共%d次)...", attempt, maxRetries)
time.Sleep(2 * time.Second) // 重试前等待2秒
log.Printf("⚠️ Retry attempt %d of %d to fetch coin pool...", attempt, maxRetries)
time.Sleep(2 * time.Second) // Wait 2 seconds before retry
}
coins, err := fetchCoinPool()
if err == nil {
if attempt > 1 {
log.Printf("✓ 第%d次重试成功", attempt)
log.Printf("✓ Retry attempt %d succeeded", attempt)
}
// 成功获取后保存到缓存
// Save to cache after successful fetch
if err := saveCoinPoolCache(coins); err != nil {
log.Printf("⚠️ 保存币种池缓存失败: %v", err)
log.Printf("⚠️ Failed to save coin pool cache: %v", err)
}
return coins, nil
}
lastErr = err
log.Printf("❌ 第%d次请求失败: %v", attempt, err)
log.Printf("❌ Request attempt %d failed: %v", attempt, err)
}
// API获取失败,尝试使用缓存
log.Printf("⚠️ API请求全部失败,尝试使用历史缓存数据...")
// API fetch failed, try to use cache
log.Printf("⚠️ All API requests failed, trying to use historical cache data...")
cachedCoins, err := loadCoinPoolCache()
if err == nil {
log.Printf("✓ 使用历史缓存数据(共%d个币种)", len(cachedCoins))
log.Printf("✓ Using historical cache data (%d coins)", len(cachedCoins))
return cachedCoins, nil
}
// 缓存也失败,使用默认主流币种
log.Printf("⚠️ 无法加载缓存数据(最后错误: %v),使用默认主流币种列表", lastErr)
// Cache also failed, use default mainstream coins
log.Printf("⚠️ Unable to load cache data (last error: %v), using default mainstream coin list", lastErr)
return convertSymbolsToCoins(defaultMainstreamCoins), nil
}
// fetchCoinPool 实际执行币种池请求
// fetchCoinPool actually executes coin pool request
func fetchCoinPool() ([]CoinInfo, error) {
log.Printf("🔄 正在请求AI500币种池...")
log.Printf("🔄 Requesting AI500 coin pool...")
client := &http.Client{
Timeout: coinPoolConfig.Timeout,
@@ -154,48 +154,48 @@ func fetchCoinPool() ([]CoinInfo, error) {
resp, err := client.Get(coinPoolConfig.APIURL)
if err != nil {
return nil, fmt.Errorf("请求币种池API失败: %w", err)
return nil, fmt.Errorf("failed to request coin pool API: %w", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API返回错误 (status %d): %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("API returned error (status %d): %s", resp.StatusCode, string(body))
}
// 解析API响应
// Parse API response
var response CoinPoolAPIResponse
if err := json.Unmarshal(body, &response); err != nil {
return nil, fmt.Errorf("JSON解析失败: %w", err)
return nil, fmt.Errorf("JSON parsing failed: %w", err)
}
if !response.Success {
return nil, fmt.Errorf("API返回失败状态")
return nil, fmt.Errorf("API returned failure status")
}
if len(response.Data.Coins) == 0 {
return nil, fmt.Errorf("币种列表为空")
return nil, fmt.Errorf("coin list is empty")
}
// 设置IsAvailable标志
// Set IsAvailable flag
coins := response.Data.Coins
for i := range coins {
coins[i].IsAvailable = true
}
log.Printf("✓ 成功获取%d个币种", len(coins))
log.Printf("✓ Successfully fetched %d coins", len(coins))
return coins, nil
}
// saveCoinPoolCache 保存币种池到缓存文件
// saveCoinPoolCache saves coin pool to cache file
func saveCoinPoolCache(coins []CoinInfo) error {
// 确保缓存目录存在
// Ensure cache directory exists
if err := os.MkdirAll(coinPoolConfig.CacheDir, 0755); err != nil {
return fmt.Errorf("创建缓存目录失败: %w", err)
return fmt.Errorf("failed to create cache directory: %w", err)
}
cache := CoinPoolCache{
@@ -206,43 +206,43 @@ func saveCoinPoolCache(coins []CoinInfo) error {
data, err := json.MarshalIndent(cache, "", " ")
if err != nil {
return fmt.Errorf("序列化缓存数据失败: %w", err)
return fmt.Errorf("failed to serialize cache data: %w", err)
}
cachePath := filepath.Join(coinPoolConfig.CacheDir, "latest.json")
if err := ioutil.WriteFile(cachePath, data, 0644); err != nil {
return fmt.Errorf("写入缓存文件失败: %w", err)
return fmt.Errorf("failed to write cache file: %w", err)
}
log.Printf("💾 已保存币种池缓存(%d个币种)", len(coins))
log.Printf("💾 Coin pool cache saved (%d coins)", len(coins))
return nil
}
// loadCoinPoolCache 从缓存文件加载币种池
// loadCoinPoolCache loads coin pool from cache file
func loadCoinPoolCache() ([]CoinInfo, error) {
cachePath := filepath.Join(coinPoolConfig.CacheDir, "latest.json")
// 检查文件是否存在
// Check if file exists
if _, err := os.Stat(cachePath); os.IsNotExist(err) {
return nil, fmt.Errorf("缓存文件不存在")
return nil, fmt.Errorf("cache file does not exist")
}
data, err := ioutil.ReadFile(cachePath)
if err != nil {
return nil, fmt.Errorf("读取缓存文件失败: %w", err)
return nil, fmt.Errorf("failed to read cache file: %w", err)
}
var cache CoinPoolCache
if err := json.Unmarshal(data, &cache); err != nil {
return nil, fmt.Errorf("解析缓存数据失败: %w", err)
return nil, fmt.Errorf("failed to parse cache data: %w", err)
}
// 检查缓存年龄
// Check cache age
cacheAge := time.Since(cache.FetchedAt)
if cacheAge > 24*time.Hour {
log.Printf("⚠️ 缓存数据较旧(%.1f小时前),但仍可使用", cacheAge.Hours())
log.Printf("⚠️ Cache data is old (%.1f hours ago), but still usable", cacheAge.Hours())
} else {
log.Printf("📂 缓存数据时间: %s%.1f分钟前)",
log.Printf("📂 Cache data timestamp: %s (%.1f minutes ago)",
cache.FetchedAt.Format("2006-01-02 15:04:05"),
cacheAge.Minutes())
}
@@ -250,7 +250,7 @@ func loadCoinPoolCache() ([]CoinInfo, error) {
return cache.Coins, nil
}
// GetAvailableCoins 获取可用的币种列表(过滤不可用的)
// GetAvailableCoins retrieves available coin list (filters out unavailable ones)
func GetAvailableCoins() ([]string, error) {
coins, err := GetCoinPool()
if err != nil {
@@ -260,27 +260,27 @@ func GetAvailableCoins() ([]string, error) {
var symbols []string
for _, coin := range coins {
if coin.IsAvailable {
// 确保symbol格式正确(转为大写USDT交易对)
// Ensure symbol format is correct (convert to uppercase USDT pair)
symbol := normalizeSymbol(coin.Pair)
symbols = append(symbols, symbol)
}
}
if len(symbols) == 0 {
return nil, fmt.Errorf("没有可用的币种")
return nil, fmt.Errorf("no available coins")
}
return symbols, nil
}
// GetTopRatedCoins 获取评分最高的N个币种(按评分从大到小排序)
// GetTopRatedCoins retrieves top N coins by score (sorted by score descending)
func GetTopRatedCoins(limit int) ([]string, error) {
coins, err := GetCoinPool()
if err != nil {
return nil, err
}
// 过滤可用的币种
// Filter available coins
var availableCoins []CoinInfo
for _, coin := range coins {
if coin.IsAvailable {
@@ -289,10 +289,10 @@ func GetTopRatedCoins(limit int) ([]string, error) {
}
if len(availableCoins) == 0 {
return nil, fmt.Errorf("没有可用的币种")
return nil, fmt.Errorf("no available coins")
}
// 按Score降序排序(冒泡排序)
// Sort by Score descending (bubble sort)
for i := 0; i < len(availableCoins); i++ {
for j := i + 1; j < len(availableCoins); j++ {
if availableCoins[i].Score < availableCoins[j].Score {
@@ -301,7 +301,7 @@ func GetTopRatedCoins(limit int) ([]string, error) {
}
}
// 取前N个
// Take top N
maxCount := limit
if len(availableCoins) < maxCount {
maxCount = len(availableCoins)
@@ -316,15 +316,15 @@ func GetTopRatedCoins(limit int) ([]string, error) {
return symbols, nil
}
// normalizeSymbol 标准化币种符号
// normalizeSymbol normalizes coin symbol
func normalizeSymbol(symbol string) string {
// 移除空格
// Remove spaces
symbol = trimSpaces(symbol)
// 转为大写
// Convert to uppercase
symbol = toUpper(symbol)
// 确保以USDT结尾
// Ensure ends with USDT
if !endsWith(symbol, "USDT") {
symbol = symbol + "USDT"
}
@@ -332,7 +332,7 @@ func normalizeSymbol(symbol string) string {
return symbol
}
// 辅助函数
// Helper functions
func trimSpaces(s string) string {
result := ""
for i := 0; i < len(s); i++ {
@@ -362,7 +362,7 @@ func endsWith(s, suffix string) bool {
return s[len(s)-len(suffix):] == suffix
}
// convertSymbolsToCoins 将币种符号列表转换为CoinInfo列表
// convertSymbolsToCoins converts symbol list to CoinInfo list
func convertSymbolsToCoins(symbols []string) []CoinInfo {
coins := make([]CoinInfo, 0, len(symbols))
for _, symbol := range symbols {
@@ -375,22 +375,22 @@ func convertSymbolsToCoins(symbols []string) []CoinInfo {
return coins
}
// ========== OI Top(持仓量增长Top20)数据 ==========
// ========== OI Top (Open Interest Growth Top 20) Data ==========
// OIPosition 持仓量数据
// OIPosition open interest data
type OIPosition struct {
Symbol string `json:"symbol"`
Rank int `json:"rank"`
CurrentOI float64 `json:"current_oi"` // 当前持仓量
OIDelta float64 `json:"oi_delta"` // 持仓量变化
OIDeltaPercent float64 `json:"oi_delta_percent"` // 持仓量变化百分比
OIDeltaValue float64 `json:"oi_delta_value"` // 持仓量变化价值
PriceDeltaPercent float64 `json:"price_delta_percent"` // 价格变化百分比
NetLong float64 `json:"net_long"` // 净多仓
NetShort float64 `json:"net_short"` // 净空仓
CurrentOI float64 `json:"current_oi"` // Current open interest
OIDelta float64 `json:"oi_delta"` // Open interest change
OIDeltaPercent float64 `json:"oi_delta_percent"` // Open interest change percentage
OIDeltaValue float64 `json:"oi_delta_value"` // Open interest change value
PriceDeltaPercent float64 `json:"price_delta_percent"` // Price change percentage
NetLong float64 `json:"net_long"` // Net long position
NetShort float64 `json:"net_short"` // Net short position
}
// OITopAPIResponse OI Top API返回的数据结构
// OITopAPIResponse data structure returned by OI Top API
type OITopAPIResponse struct {
Success bool `json:"success"`
Data struct {
@@ -401,7 +401,7 @@ type OITopAPIResponse struct {
} `json:"data"`
}
// OITopCache OI Top 缓存
// OITopCache OI Top cache
type OITopCache struct {
Positions []OIPosition `json:"positions"`
FetchedAt time.Time `json:"fetched_at"`
@@ -418,56 +418,56 @@ var oiTopConfig = struct {
CacheDir: "coin_pool_cache",
}
// GetOITopPositions 获取持仓量增长Top20数据(带重试和缓存)
// GetOITopPositions retrieves OI Top 20 data (with retry and cache)
func GetOITopPositions() ([]OIPosition, error) {
// 检查API URL是否配置
// Check if API URL is configured
if strings.TrimSpace(oiTopConfig.APIURL) == "" {
log.Printf("⚠️ 未配置OI Top API URL,跳过OI Top数据获取")
return []OIPosition{}, nil // 返回空列表,不是错误
log.Printf("⚠️ OI Top API URL not configured, skipping OI Top data fetch")
return []OIPosition{}, nil // Return empty list, not an error
}
maxRetries := 3
var lastErr error
// 尝试从API获取
// Try to fetch from API
for attempt := 1; attempt <= maxRetries; attempt++ {
if attempt > 1 {
log.Printf("⚠️ 第%d次重试获取OI Top数据(共%d次)...", attempt, maxRetries)
log.Printf("⚠️ Retry attempt %d of %d to fetch OI Top data...", attempt, maxRetries)
time.Sleep(2 * time.Second)
}
positions, err := fetchOITop()
if err == nil {
if attempt > 1 {
log.Printf("✓ 第%d次重试成功", attempt)
log.Printf("✓ Retry attempt %d succeeded", attempt)
}
// 成功获取后保存到缓存
// Save to cache after successful fetch
if err := saveOITopCache(positions); err != nil {
log.Printf("⚠️ 保存OI Top缓存失败: %v", err)
log.Printf("⚠️ Failed to save OI Top cache: %v", err)
}
return positions, nil
}
lastErr = err
log.Printf("❌ 第%d次请求OI Top失败: %v", attempt, err)
log.Printf("❌ OI Top request attempt %d failed: %v", attempt, err)
}
// API获取失败,尝试使用缓存
log.Printf("⚠️ OI Top API请求全部失败,尝试使用历史缓存数据...")
// API fetch failed, try to use cache
log.Printf("⚠️ All OI Top API requests failed, trying to use historical cache data...")
cachedPositions, err := loadOITopCache()
if err == nil {
log.Printf("✓ 使用历史OI Top缓存数据(共%d个币种)", len(cachedPositions))
log.Printf("✓ Using historical OI Top cache data (%d coins)", len(cachedPositions))
return cachedPositions, nil
}
// 缓存也失败,返回空列表(OI Top是可选的)
log.Printf("⚠️ 无法加载OI Top缓存数据(最后错误: %v),跳过OI Top数据", lastErr)
// Cache also failed, return empty list (OI Top is optional)
log.Printf("⚠️ Unable to load OI Top cache data (last error: %v), skipping OI Top data", lastErr)
return []OIPosition{}, nil
}
// fetchOITop 实际执行OI Top请求
// fetchOITop actually executes OI Top request
func fetchOITop() ([]OIPosition, error) {
log.Printf("🔄 正在请求OI Top数据...")
log.Printf("🔄 Requesting OI Top data...")
client := &http.Client{
Timeout: oiTopConfig.Timeout,
@@ -475,42 +475,42 @@ func fetchOITop() ([]OIPosition, error) {
resp, err := client.Get(oiTopConfig.APIURL)
if err != nil {
return nil, fmt.Errorf("请求OI Top API失败: %w", err)
return nil, fmt.Errorf("failed to request OI Top API: %w", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取OI Top响应失败: %w", err)
return nil, fmt.Errorf("failed to read OI Top response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("OI Top API返回错误 (status %d): %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("OI Top API returned error (status %d): %s", resp.StatusCode, string(body))
}
// 解析API响应
// Parse API response
var response OITopAPIResponse
if err := json.Unmarshal(body, &response); err != nil {
return nil, fmt.Errorf("OI Top JSON解析失败: %w", err)
return nil, fmt.Errorf("OI Top JSON parsing failed: %w", err)
}
if !response.Success {
return nil, fmt.Errorf("OI Top API返回失败状态")
return nil, fmt.Errorf("OI Top API returned failure status")
}
if len(response.Data.Positions) == 0 {
return nil, fmt.Errorf("OI Top持仓列表为空")
return nil, fmt.Errorf("OI Top position list is empty")
}
log.Printf("✓ 成功获取%dOI Top币种(时间范围: %s",
log.Printf("✓ Successfully fetched %d OI Top coins (time range: %s)",
len(response.Data.Positions), response.Data.TimeRange)
return response.Data.Positions, nil
}
// saveOITopCache 保存OI Top数据到缓存
// saveOITopCache saves OI Top data to cache
func saveOITopCache(positions []OIPosition) error {
if err := os.MkdirAll(oiTopConfig.CacheDir, 0755); err != nil {
return fmt.Errorf("创建缓存目录失败: %w", err)
return fmt.Errorf("failed to create cache directory: %w", err)
}
cache := OITopCache{
@@ -521,41 +521,41 @@ func saveOITopCache(positions []OIPosition) error {
data, err := json.MarshalIndent(cache, "", " ")
if err != nil {
return fmt.Errorf("序列化OI Top缓存数据失败: %w", err)
return fmt.Errorf("failed to serialize OI Top cache data: %w", err)
}
cachePath := filepath.Join(oiTopConfig.CacheDir, "oi_top_latest.json")
if err := ioutil.WriteFile(cachePath, data, 0644); err != nil {
return fmt.Errorf("写入OI Top缓存文件失败: %w", err)
return fmt.Errorf("failed to write OI Top cache file: %w", err)
}
log.Printf("💾 已保存OI Top缓存(%d个币种)", len(positions))
log.Printf("💾 OI Top cache saved (%d coins)", len(positions))
return nil
}
// loadOITopCache 从缓存加载OI Top数据
// loadOITopCache loads OI Top data from cache
func loadOITopCache() ([]OIPosition, error) {
cachePath := filepath.Join(oiTopConfig.CacheDir, "oi_top_latest.json")
if _, err := os.Stat(cachePath); os.IsNotExist(err) {
return nil, fmt.Errorf("OI Top缓存文件不存在")
return nil, fmt.Errorf("OI Top cache file does not exist")
}
data, err := ioutil.ReadFile(cachePath)
if err != nil {
return nil, fmt.Errorf("读取OI Top缓存文件失败: %w", err)
return nil, fmt.Errorf("failed to read OI Top cache file: %w", err)
}
var cache OITopCache
if err := json.Unmarshal(data, &cache); err != nil {
return nil, fmt.Errorf("解析OI Top缓存数据失败: %w", err)
return nil, fmt.Errorf("failed to parse OI Top cache data: %w", err)
}
cacheAge := time.Since(cache.FetchedAt)
if cacheAge > 24*time.Hour {
log.Printf("⚠️ OI Top缓存数据较旧(%.1f小时前),但仍可使用", cacheAge.Hours())
log.Printf("⚠️ OI Top cache data is old (%.1f hours ago), but still usable", cacheAge.Hours())
} else {
log.Printf("📂 OI Top缓存数据时间: %s%.1f分钟前)",
log.Printf("📂 OI Top cache data timestamp: %s (%.1f minutes ago)",
cache.FetchedAt.Format("2006-01-02 15:04:05"),
cacheAge.Minutes())
}
@@ -563,7 +563,7 @@ func loadOITopCache() ([]OIPosition, error) {
return cache.Positions, nil
}
// GetOITopSymbols 获取OI Top的币种符号列表
// GetOITopSymbols retrieves OI Top coin symbol list
func GetOITopSymbols() ([]string, error) {
positions, err := GetOITopPositions()
if err != nil {
@@ -579,41 +579,41 @@ func GetOITopSymbols() ([]string, error) {
return symbols, nil
}
// MergedCoinPool 合并的币种池(AI500 + OI Top
// MergedCoinPool merged coin pool (AI500 + OI Top)
type MergedCoinPool struct {
AI500Coins []CoinInfo // AI500评分币种
OITopCoins []OIPosition // 持仓量增长Top20
AllSymbols []string // 所有不重复的币种符号
SymbolSources map[string][]string // 每个币种的来源("ai500"/"oi_top"
AI500Coins []CoinInfo // AI500 score coins
OITopCoins []OIPosition // Open interest growth Top 20
AllSymbols []string // All unique coin symbols
SymbolSources map[string][]string // Source of each coin ("ai500"/"oi_top")
}
// GetMergedCoinPool 获取合并后的币种池(AI500 + OI Top,去重)
// GetMergedCoinPool retrieves merged coin pool (AI500 + OI Top, deduplicated)
func GetMergedCoinPool(ai500Limit int) (*MergedCoinPool, error) {
// 1. 获取AI500数据
// 1. Get AI500 data
ai500TopSymbols, err := GetTopRatedCoins(ai500Limit)
if err != nil {
log.Printf("⚠️ 获取AI500数据失败: %v", err)
ai500TopSymbols = []string{} // 失败时用空列表
log.Printf("⚠️ Failed to get AI500 data: %v", err)
ai500TopSymbols = []string{} // Use empty list on failure
}
// 2. 获取OI Top数据
// 2. Get OI Top data
oiTopSymbols, err := GetOITopSymbols()
if err != nil {
log.Printf("⚠️ 获取OI Top数据失败: %v", err)
oiTopSymbols = []string{} // 失败时用空列表
log.Printf("⚠️ Failed to get OI Top data: %v", err)
oiTopSymbols = []string{} // Use empty list on failure
}
// 3. 合并并去重
// 3. Merge and deduplicate
symbolSet := make(map[string]bool)
symbolSources := make(map[string][]string)
// 添加AI500币种
// Add AI500 coins
for _, symbol := range ai500TopSymbols {
symbolSet[symbol] = true
symbolSources[symbol] = append(symbolSources[symbol], "ai500")
}
// 添加OI Top币种
// Add OI Top coins
for _, symbol := range oiTopSymbols {
if !symbolSet[symbol] {
symbolSet[symbol] = true
@@ -621,13 +621,13 @@ func GetMergedCoinPool(ai500Limit int) (*MergedCoinPool, error) {
symbolSources[symbol] = append(symbolSources[symbol], "oi_top")
}
// 转换为数组
// Convert to array
var allSymbols []string
for symbol := range symbolSet {
allSymbols = append(allSymbols, symbol)
}
// 获取完整数据
// Get complete data
ai500Coins, _ := GetCoinPool()
oiTopPositions, _ := GetOITopPositions()
@@ -638,7 +638,7 @@ func GetMergedCoinPool(ai500Limit int) (*MergedCoinPool, error) {
SymbolSources: symbolSources,
}
log.Printf("📊 币种池合并完成: AI500=%d, OI_Top=%d, 总计(去重)=%d",
log.Printf("📊 Coin pool merge complete: AI500=%d, OI_Top=%d, Total(deduplicated)=%d",
len(ai500TopSymbols), len(oiTopSymbols), len(allSymbols))
return merged, nil
+36 -36
View File
@@ -12,64 +12,64 @@ import (
)
func main() {
log.Println("🔄 开始迁移数据库到加密格式...")
log.Println("🔄 Starting database migration to encrypted format...")
// 1. 检查数据库文件
// 1. Check database file
dbPath := "data.db"
if len(os.Args) > 1 {
dbPath = os.Args[1]
}
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
log.Fatalf("❌ 数据库文件不存在: %s", dbPath)
log.Fatalf("❌ Database file does not exist: %s", dbPath)
}
// 2. 备份数据库
// 2. Backup database
backupPath := fmt.Sprintf("%s.pre_encryption_backup", dbPath)
log.Printf("📦 备份数据库到: %s", backupPath)
log.Printf("📦 Backing up database to: %s", backupPath)
input, err := os.ReadFile(dbPath)
if err != nil {
log.Fatalf("❌ 读取数据库失败: %v", err)
log.Fatalf("❌ Failed to read database: %v", err)
}
if err := os.WriteFile(backupPath, input, 0600); err != nil {
log.Fatalf("❌ 备份失败: %v", err)
log.Fatalf("❌ Backup failed: %v", err)
}
// 3. 打开数据库
// 3. Open database
db, err := sql.Open("sqlite", dbPath)
if err != nil {
log.Fatalf("❌ 打开数据库失败: %v", err)
log.Fatalf("❌ Failed to open database: %v", err)
}
defer db.Close()
// 4. 初始化 CryptoService(从环境变量加载密钥)
// 4. Initialize CryptoService (load key from environment variables)
cs, err := crypto.NewCryptoService()
if err != nil {
log.Fatalf("❌ 初始化加密服务失败: %v", err)
log.Fatalf("❌ Failed to initialize encryption service: %v", err)
}
// 5. 迁移交易所配置
// 5. Migrate exchange configurations
if err := migrateExchanges(db, cs); err != nil {
log.Fatalf("❌ 迁移交易所配置失败: %v", err)
log.Fatalf("❌ Failed to migrate exchange configurations: %v", err)
}
// 6. 迁移 AI 模型配置
// 6. Migrate AI model configurations
if err := migrateAIModels(db, cs); err != nil {
log.Fatalf("❌ 迁移 AI 模型配置失败: %v", err)
log.Fatalf("❌ Failed to migrate AI model configurations: %v", err)
}
log.Println("✅ 数据迁移完成!")
log.Printf("📝 原始数据备份位于: %s", backupPath)
log.Println("⚠️ 请验证系统功能正常后,手动删除备份文件")
log.Println("✅ Data migration completed!")
log.Printf("📝 Original data backed up at: %s", backupPath)
log.Println("⚠️ Please verify system functionality before manually deleting backup file")
}
// migrateExchanges 迁移交易所配置
// migrateExchanges migrates exchange configurations
func migrateExchanges(db *sql.DB, cs *crypto.CryptoService) error {
log.Println("🔄 迁移交易所配置...")
log.Println("🔄 Migrating exchange configurations...")
// 查询所有未加密的记录(加密数据以 ENC:v1: 开头)
// Query all unencrypted records (encrypted data starts with ENC:v1:)
rows, err := db.Query(`
SELECT user_id, id, api_key, secret_key,
COALESCE(hyperliquid_private_key, ''),
@@ -96,22 +96,22 @@ func migrateExchanges(db *sql.DB, cs *crypto.CryptoService) error {
return err
}
// 加密每个字段
// Encrypt each field
encAPIKey, err := cs.EncryptForStorage(apiKey)
if err != nil {
return fmt.Errorf("加密 API Key 失败: %w", err)
return fmt.Errorf("failed to encrypt API Key: %w", err)
}
encSecretKey, err := cs.EncryptForStorage(secretKey)
if err != nil {
return fmt.Errorf("加密 Secret Key 失败: %w", err)
return fmt.Errorf("failed to encrypt Secret Key: %w", err)
}
encHLPrivateKey := ""
if hlPrivateKey != "" {
encHLPrivateKey, err = cs.EncryptForStorage(hlPrivateKey)
if err != nil {
return fmt.Errorf("加密 Hyperliquid Private Key 失败: %w", err)
return fmt.Errorf("failed to encrypt Hyperliquid Private Key: %w", err)
}
}
@@ -119,11 +119,11 @@ func migrateExchanges(db *sql.DB, cs *crypto.CryptoService) error {
if asterPrivateKey != "" {
encAsterPrivateKey, err = cs.EncryptForStorage(asterPrivateKey)
if err != nil {
return fmt.Errorf("加密 Aster Private Key 失败: %w", err)
return fmt.Errorf("failed to encrypt Aster Private Key: %w", err)
}
}
// 更新数据库
// Update database
_, err = tx.Exec(`
UPDATE exchanges
SET api_key = ?, secret_key = ?,
@@ -132,10 +132,10 @@ func migrateExchanges(db *sql.DB, cs *crypto.CryptoService) error {
`, encAPIKey, encSecretKey, encHLPrivateKey, encAsterPrivateKey, userID, exchangeID)
if err != nil {
return fmt.Errorf("更新数据库失败: %w", err)
return fmt.Errorf("failed to update database: %w", err)
}
log.Printf(" ✓ 已加密: [%s] %s", userID, exchangeID)
log.Printf(" ✓ Encrypted: [%s] %s", userID, exchangeID)
count++
}
@@ -143,13 +143,13 @@ func migrateExchanges(db *sql.DB, cs *crypto.CryptoService) error {
return err
}
log.Printf("✅ 已迁移 %d 个交易所配置", count)
log.Printf("✅ Migrated %d exchange configurations", count)
return nil
}
// migrateAIModels 迁移 AI 模型配置
// migrateAIModels migrates AI model configurations
func migrateAIModels(db *sql.DB, cs *crypto.CryptoService) error {
log.Println("🔄 迁移 AI 模型配置...")
log.Println("🔄 Migrating AI model configurations...")
rows, err := db.Query(`
SELECT user_id, id, api_key
@@ -176,7 +176,7 @@ func migrateAIModels(db *sql.DB, cs *crypto.CryptoService) error {
encAPIKey, err := cs.EncryptForStorage(apiKey)
if err != nil {
return fmt.Errorf("加密 API Key 失败: %w", err)
return fmt.Errorf("failed to encrypt API Key: %w", err)
}
_, err = tx.Exec(`
@@ -184,10 +184,10 @@ func migrateAIModels(db *sql.DB, cs *crypto.CryptoService) error {
`, encAPIKey, userID, modelID)
if err != nil {
return fmt.Errorf("更新数据库失败: %w", err)
return fmt.Errorf("failed to update database: %w", err)
}
log.Printf(" ✓ 已加密: [%s] %s", userID, modelID)
log.Printf(" ✓ Encrypted: [%s] %s", userID, modelID)
count++
}
@@ -195,6 +195,6 @@ func migrateAIModels(db *sql.DB, cs *crypto.CryptoService) error {
return err
}
log.Printf("✅ 已迁移 %d AI 模型配置", count)
log.Printf("✅ Migrated %d AI model configurations", count)
return nil
}
+17 -17
View File
@@ -9,14 +9,14 @@ import (
"time"
)
// AIModelStore AI模型存储
// AIModelStore AI model storage
type AIModelStore struct {
db *sql.DB
encryptFunc func(string) string
decryptFunc func(string) string
}
// AIModel AI模型配置
// AIModel AI model configuration
type AIModel struct {
ID string `json:"id"`
UserID string `json:"user_id"`
@@ -49,7 +49,7 @@ func (s *AIModelStore) initTables() error {
return err
}
// 触发器
// Trigger
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_ai_models_updated_at
AFTER UPDATE ON ai_models
@@ -61,7 +61,7 @@ func (s *AIModelStore) initTables() error {
return err
}
// 向后兼容:添加可能缺失的列
// Backward compatibility: add potentially missing columns
s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_api_url TEXT DEFAULT ''`)
s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_model_name TEXT DEFAULT ''`)
@@ -82,7 +82,7 @@ func (s *AIModelStore) initDefaultData() error {
VALUES (?, 'default', ?, ?, 0)
`, model.id, model.name, model.provider)
if err != nil {
return fmt.Errorf("初始化AI模型失败: %w", err)
return fmt.Errorf("failed to initialize AI model: %w", err)
}
}
return nil
@@ -102,7 +102,7 @@ func (s *AIModelStore) decrypt(encrypted string) string {
return encrypted
}
// List 获取用户的AI模型列表
// List retrieves user's AI model list
func (s *AIModelStore) List(userID string) ([]*AIModel, error) {
rows, err := s.db.Query(`
SELECT id, user_id, name, provider, enabled, api_key,
@@ -136,10 +136,10 @@ func (s *AIModelStore) List(userID string) ([]*AIModel, error) {
return models, nil
}
// Get 获取单个AI模型
// Get retrieves a single AI model
func (s *AIModelStore) Get(userID, modelID string) (*AIModel, error) {
if modelID == "" {
return nil, fmt.Errorf("模型ID不能为空")
return nil, fmt.Errorf("model ID cannot be empty")
}
candidates := []string{}
@@ -178,7 +178,7 @@ func (s *AIModelStore) Get(userID, modelID string) (*AIModel, error) {
return nil, sql.ErrNoRows
}
// GetDefault 获取默认启用的AI模型
// GetDefault retrieves the default enabled AI model
func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) {
if userID == "" {
userID = "default"
@@ -193,7 +193,7 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) {
if userID != "default" {
return s.firstEnabled("default")
}
return nil, fmt.Errorf("请先在系统中配置可用的AI模型")
return nil, fmt.Errorf("please configure an available AI model in the system first")
}
func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) {
@@ -218,9 +218,9 @@ func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) {
return &model, nil
}
// Update 更新AI模型,不存在则创建
// Update updates AI model, creates if not exists
func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error {
// 先尝试精确匹配ID
// Try exact ID match first
var existingID string
err := s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1`, userID, id).Scan(&existingID)
if err == nil {
@@ -232,11 +232,11 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI
return err
}
// 尝试兼容旧逻辑:将id作为provider查找
// Try legacy logic compatibility: use id as provider to search
provider := id
err = s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND provider = ? LIMIT 1`, userID, provider).Scan(&existingID)
if err == nil {
logger.Warnf("⚠️ 使用旧版 provider 匹配更新模型: %s -> %s", provider, existingID)
logger.Warnf("⚠️ Using legacy provider matching to update model: %s -> %s", provider, existingID)
encryptedAPIKey := s.encrypt(apiKey)
_, err = s.db.Exec(`
UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
@@ -245,7 +245,7 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI
return err
}
// 创建新记录
// Create new record
if provider == id && (provider == "deepseek" || provider == "qwen") {
provider = id
} else {
@@ -274,7 +274,7 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI
newModelID = fmt.Sprintf("%s_%s", userID, provider)
}
logger.Infof("✓ 创建新的 AI 模型配置: ID=%s, Provider=%s, Name=%s", newModelID, provider, name)
logger.Infof("✓ Creating new AI model configuration: ID=%s, Provider=%s, Name=%s", newModelID, provider, name)
encryptedAPIKey := s.encrypt(apiKey)
_, err = s.db.Exec(`
INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, custom_model_name, created_at, updated_at)
@@ -283,7 +283,7 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI
return err
}
// Create 创建AI模型
// Create creates an AI model
func (s *AIModelStore) Create(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error {
_, err := s.db.Exec(`
INSERT OR IGNORE INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url)
+38 -38
View File
@@ -7,12 +7,12 @@ import (
"time"
)
// BacktestStore 回测数据存储
// BacktestStore backtest data storage
type BacktestStore struct {
db *sql.DB
}
// RunState 回测状态
// RunState backtest state
type RunState string
const (
@@ -23,7 +23,7 @@ const (
RunStateFailed RunState = "failed"
)
// RunMetadata 回测元数据
// RunMetadata backtest metadata
type RunMetadata struct {
RunID string `json:"run_id"`
UserID string `json:"user_id"`
@@ -36,7 +36,7 @@ type RunMetadata struct {
UpdatedAt time.Time `json:"updated_at"`
}
// RunSummary 回测摘要
// RunSummary backtest summary
type RunSummary struct {
SymbolCount int `json:"symbol_count"`
DecisionTF string `json:"decision_tf"`
@@ -48,7 +48,7 @@ type RunSummary struct {
LiquidationNote string `json:"liquidation_note"`
}
// EquityPoint 权益点
// EquityPoint equity point
type EquityPoint struct {
Timestamp int64 `json:"timestamp"`
Equity float64 `json:"equity"`
@@ -59,7 +59,7 @@ type EquityPoint struct {
Cycle int `json:"cycle"`
}
// TradeEvent 交易事件
// TradeEvent trade event
type TradeEvent struct {
Timestamp int64 `json:"timestamp"`
Symbol string `json:"symbol"`
@@ -78,7 +78,7 @@ type TradeEvent struct {
Note string `json:"note"`
}
// RunIndexEntry 回测索引条目
// RunIndexEntry backtest index entry
type RunIndexEntry struct {
RunID string `json:"run_id"`
State string `json:"state"`
@@ -92,10 +92,10 @@ type RunIndexEntry struct {
UpdatedAtISO string `json:"updated_at"`
}
// initTables 初始化回测相关表
// initTables initializes backtest related tables
func (s *BacktestStore) initTables() error {
queries := []string{
// 回测运行主表
// Backtest runs main table
`CREATE TABLE IF NOT EXISTS backtest_runs (
run_id TEXT PRIMARY KEY,
user_id TEXT NOT NULL DEFAULT '',
@@ -120,7 +120,7 @@ func (s *BacktestStore) initTables() error {
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
// 回测检查点
// Backtest checkpoints
`CREATE TABLE IF NOT EXISTS backtest_checkpoints (
run_id TEXT PRIMARY KEY,
payload BLOB NOT NULL,
@@ -128,7 +128,7 @@ func (s *BacktestStore) initTables() error {
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
)`,
// 回测权益曲线
// Backtest equity curve
`CREATE TABLE IF NOT EXISTS backtest_equity (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_id TEXT NOT NULL,
@@ -142,7 +142,7 @@ func (s *BacktestStore) initTables() error {
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
)`,
// 回测交易记录
// Backtest trade records
`CREATE TABLE IF NOT EXISTS backtest_trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_id TEXT NOT NULL,
@@ -164,7 +164,7 @@ func (s *BacktestStore) initTables() error {
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
)`,
// 回测指标
// Backtest metrics
`CREATE TABLE IF NOT EXISTS backtest_metrics (
run_id TEXT PRIMARY KEY,
payload BLOB NOT NULL,
@@ -172,7 +172,7 @@ func (s *BacktestStore) initTables() error {
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
)`,
// 回测决策日志
// Backtest decision logs
`CREATE TABLE IF NOT EXISTS backtest_decisions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_id TEXT NOT NULL,
@@ -182,7 +182,7 @@ func (s *BacktestStore) initTables() error {
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
)`,
// 索引
// Indexes
`CREATE INDEX IF NOT EXISTS idx_backtest_runs_state ON backtest_runs(state, updated_at)`,
`CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`,
`CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`,
@@ -191,11 +191,11 @@ func (s *BacktestStore) initTables() error {
for _, query := range queries {
if _, err := s.db.Exec(query); err != nil {
return fmt.Errorf("执行SQL失败: %w", err)
return fmt.Errorf("failed to execute SQL: %w", err)
}
}
// 添加可能缺失的列(向后兼容)
// Add potentially missing columns (backward compatibility)
s.addColumnIfNotExists("backtest_runs", "label", "TEXT DEFAULT ''")
s.addColumnIfNotExists("backtest_runs", "last_error", "TEXT DEFAULT ''")
s.addColumnIfNotExists("backtest_trades", "leverage", "INTEGER DEFAULT 0")
@@ -219,14 +219,14 @@ func (s *BacktestStore) addColumnIfNotExists(table, column, definition string) {
continue
}
if name == column {
return // 列已存在
return // Column already exists
}
}
s.db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, definition))
}
// SaveCheckpoint 保存检查点
// SaveCheckpoint saves checkpoint
func (s *BacktestStore) SaveCheckpoint(runID string, payload []byte) error {
_, err := s.db.Exec(`
INSERT INTO backtest_checkpoints (run_id, payload, updated_at)
@@ -236,14 +236,14 @@ func (s *BacktestStore) SaveCheckpoint(runID string, payload []byte) error {
return err
}
// LoadCheckpoint 加载检查点
// LoadCheckpoint loads checkpoint
func (s *BacktestStore) LoadCheckpoint(runID string) ([]byte, error) {
var payload []byte
err := s.db.QueryRow(`SELECT payload FROM backtest_checkpoints WHERE run_id = ?`, runID).Scan(&payload)
return payload, err
}
// SaveRunMetadata 保存运行元数据
// SaveRunMetadata saves run metadata
func (s *BacktestStore) SaveRunMetadata(meta *RunMetadata) error {
created := meta.CreatedAt.UTC().Format(time.RFC3339)
updated := meta.UpdatedAt.UTC().Format(time.RFC3339)
@@ -270,7 +270,7 @@ func (s *BacktestStore) SaveRunMetadata(meta *RunMetadata) error {
return err
}
// LoadRunMetadata 加载运行元数据
// LoadRunMetadata loads run metadata
func (s *BacktestStore) LoadRunMetadata(runID string) (*RunMetadata, error) {
var (
userID string
@@ -326,7 +326,7 @@ func (s *BacktestStore) LoadRunMetadata(runID string) (*RunMetadata, error) {
return meta, nil
}
// ListRunIDs 列出所有运行ID
// ListRunIDs lists all run IDs
func (s *BacktestStore) ListRunIDs() ([]string, error) {
rows, err := s.db.Query(`SELECT run_id FROM backtest_runs ORDER BY datetime(updated_at) DESC`)
if err != nil {
@@ -345,7 +345,7 @@ func (s *BacktestStore) ListRunIDs() ([]string, error) {
return ids, rows.Err()
}
// AppendEquityPoint 添加权益点
// AppendEquityPoint appends equity point
func (s *BacktestStore) AppendEquityPoint(runID string, point EquityPoint) error {
_, err := s.db.Exec(`
INSERT INTO backtest_equity (run_id, ts, equity, available, pnl, pnl_pct, dd_pct, cycle)
@@ -355,7 +355,7 @@ func (s *BacktestStore) AppendEquityPoint(runID string, point EquityPoint) error
return err
}
// LoadEquityPoints 加载权益点
// LoadEquityPoints loads equity points
func (s *BacktestStore) LoadEquityPoints(runID string) ([]EquityPoint, error) {
rows, err := s.db.Query(`
SELECT ts, equity, available, pnl, pnl_pct, dd_pct, cycle
@@ -378,7 +378,7 @@ func (s *BacktestStore) LoadEquityPoints(runID string) ([]EquityPoint, error) {
return points, rows.Err()
}
// AppendTradeEvent 添加交易事件
// AppendTradeEvent appends trade event
func (s *BacktestStore) AppendTradeEvent(runID string, event TradeEvent) error {
_, err := s.db.Exec(`
INSERT INTO backtest_trades (run_id, ts, symbol, action, side, qty, price, fee,
@@ -391,7 +391,7 @@ func (s *BacktestStore) AppendTradeEvent(runID string, event TradeEvent) error {
return err
}
// LoadTradeEvents 加载交易事件
// LoadTradeEvents loads trade events
func (s *BacktestStore) LoadTradeEvents(runID string) ([]TradeEvent, error) {
rows, err := s.db.Query(`
SELECT ts, symbol, action, side, qty, price, fee, slippage, order_value,
@@ -417,7 +417,7 @@ func (s *BacktestStore) LoadTradeEvents(runID string) ([]TradeEvent, error) {
return events, rows.Err()
}
// SaveMetrics 保存指标
// SaveMetrics saves metrics
func (s *BacktestStore) SaveMetrics(runID string, payload []byte) error {
_, err := s.db.Exec(`
INSERT INTO backtest_metrics (run_id, payload, updated_at)
@@ -427,14 +427,14 @@ func (s *BacktestStore) SaveMetrics(runID string, payload []byte) error {
return err
}
// LoadMetrics 加载指标
// LoadMetrics loads metrics
func (s *BacktestStore) LoadMetrics(runID string) ([]byte, error) {
var payload []byte
err := s.db.QueryRow(`SELECT payload FROM backtest_metrics WHERE run_id = ?`, runID).Scan(&payload)
return payload, err
}
// SaveDecisionRecord 保存决策记录
// SaveDecisionRecord saves decision record
func (s *BacktestStore) SaveDecisionRecord(runID string, cycle int, payload []byte) error {
_, err := s.db.Exec(`
INSERT INTO backtest_decisions (run_id, cycle, payload)
@@ -443,7 +443,7 @@ func (s *BacktestStore) SaveDecisionRecord(runID string, cycle int, payload []by
return err
}
// LoadDecisionRecords 加载决策记录
// LoadDecisionRecords loads decision records
func (s *BacktestStore) LoadDecisionRecords(runID string, limit, offset int) ([]json.RawMessage, error) {
rows, err := s.db.Query(`
SELECT payload FROM backtest_decisions
@@ -467,7 +467,7 @@ func (s *BacktestStore) LoadDecisionRecords(runID string, limit, offset int) ([]
return records, rows.Err()
}
// LoadLatestDecision 加载最新决策
// LoadLatestDecision loads latest decision
func (s *BacktestStore) LoadLatestDecision(runID string, cycle int) ([]byte, error) {
var query string
var args []interface{}
@@ -485,7 +485,7 @@ func (s *BacktestStore) LoadLatestDecision(runID string, cycle int) ([]byte, err
return payload, err
}
// UpdateProgress 更新进度
// UpdateProgress updates progress
func (s *BacktestStore) UpdateProgress(runID string, progressPct, equity float64, barIndex int, liquidated bool) error {
_, err := s.db.Exec(`
UPDATE backtest_runs
@@ -495,7 +495,7 @@ func (s *BacktestStore) UpdateProgress(runID string, progressPct, equity float64
return err
}
// ListIndexEntries 列出索引条目
// ListIndexEntries lists index entries
func (s *BacktestStore) ListIndexEntries() ([]RunIndexEntry, error) {
rows, err := s.db.Query(`
SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct,
@@ -524,7 +524,7 @@ func (s *BacktestStore) ListIndexEntries() ([]RunIndexEntry, error) {
entry.UpdatedAtISO = updatedISO
entry.Symbols = make([]string, 0, symbolCnt)
// 尝试从配置中提取更多信息
// Try to extract more information from config
if len(cfgJSON) > 0 {
var cfg struct {
Symbols []string `json:"symbols"`
@@ -543,13 +543,13 @@ func (s *BacktestStore) ListIndexEntries() ([]RunIndexEntry, error) {
return entries, rows.Err()
}
// DeleteRun 删除运行
// DeleteRun deletes run
func (s *BacktestStore) DeleteRun(runID string) error {
_, err := s.db.Exec(`DELETE FROM backtest_runs WHERE run_id = ?`, runID)
return err
}
// SaveConfig 保存配置
// SaveConfig saves config
func (s *BacktestStore) SaveConfig(runID, userID, template, customPrompt, provider, model string, override bool, configJSON []byte) error {
now := time.Now().UTC().Format(time.RFC3339)
if userID == "" {
@@ -575,7 +575,7 @@ func (s *BacktestStore) SaveConfig(runID, userID, template, customPrompt, provid
return err
}
// LoadConfig 加载配置
// LoadConfig loads config
func (s *BacktestStore) LoadConfig(runID string) ([]byte, error) {
var payload []byte
err := s.db.QueryRow(`SELECT config_json FROM backtest_runs WHERE run_id = ?`, runID).Scan(&payload)
+43 -43
View File
@@ -7,12 +7,12 @@ import (
"time"
)
// DecisionStore 决策日志存储
// DecisionStore decision log storage
type DecisionStore struct {
db *sql.DB
}
// DecisionRecord 决策记录
// DecisionRecord decision record
type DecisionRecord struct {
ID int64 `json:"id"`
TraderID string `json:"trader_id"`
@@ -32,7 +32,7 @@ type DecisionRecord struct {
Decisions []DecisionAction `json:"decisions"`
}
// AccountSnapshot 账户状态快照
// AccountSnapshot account state snapshot
type AccountSnapshot struct {
TotalBalance float64 `json:"total_balance"`
AvailableBalance float64 `json:"available_balance"`
@@ -42,7 +42,7 @@ type AccountSnapshot struct {
InitialBalance float64 `json:"initial_balance"`
}
// PositionSnapshot 持仓快照
// PositionSnapshot position snapshot
type PositionSnapshot struct {
Symbol string `json:"symbol"`
Side string `json:"side"`
@@ -54,8 +54,8 @@ type PositionSnapshot struct {
LiquidationPrice float64 `json:"liquidation_price"`
}
// DecisionAction 决策动作
type DecisionAction struct {
// DecisionAction decision action
type DecisionAction struct{
Action string `json:"action"`
Symbol string `json:"symbol"`
Quantity float64 `json:"quantity"`
@@ -67,7 +67,7 @@ type DecisionAction struct {
Error string `json:"error"`
}
// Statistics 统计信息
// Statistics statistics information
type Statistics struct {
TotalCycles int `json:"total_cycles"`
SuccessfulCycles int `json:"successful_cycles"`
@@ -76,11 +76,11 @@ type Statistics struct {
TotalClosePositions int `json:"total_close_positions"`
}
// initTables 初始化 AI 决策日志表
// 注意:账户净值曲线数据已迁移到 trader_equity_snapshots 表(由 EquityStore 管理)
// initTables initializes AI decision log tables
// Note: Account equity curve data has been migrated to trader_equity_snapshots table (managed by EquityStore)
func (s *DecisionStore) initTables() error {
queries := []string{
// AI 决策日志表(记录 AI 的输入输出、思维链等)
// AI decision log table (records AI input/output, chain of thought, etc.)
`CREATE TABLE IF NOT EXISTS decision_records (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trader_id TEXT NOT NULL,
@@ -97,21 +97,21 @@ func (s *DecisionStore) initTables() error {
ai_request_duration_ms INTEGER DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
// 索引
// Indexes
`CREATE INDEX IF NOT EXISTS idx_decision_records_trader_time ON decision_records(trader_id, timestamp DESC)`,
`CREATE INDEX IF NOT EXISTS idx_decision_records_timestamp ON decision_records(timestamp DESC)`,
}
for _, query := range queries {
if _, err := s.db.Exec(query); err != nil {
return fmt.Errorf("执行SQL失败: %w", err)
return fmt.Errorf("failed to execute SQL: %w", err)
}
}
return nil
}
// LogDecision 记录决策(仅保存 AI 决策日志,净值曲线已迁移到 equity 表)
// LogDecision logs decision (only saves AI decision log, equity curve has been migrated to equity table)
func (s *DecisionStore) LogDecision(record *DecisionRecord) error {
if record.Timestamp.IsZero() {
record.Timestamp = time.Now().UTC()
@@ -119,11 +119,11 @@ func (s *DecisionStore) LogDecision(record *DecisionRecord) error {
record.Timestamp = record.Timestamp.UTC()
}
// 序列化候选币种和执行日志为 JSON
// Serialize candidate coins and execution log to JSON
candidateCoinsJSON, _ := json.Marshal(record.CandidateCoins)
executionLogJSON, _ := json.Marshal(record.ExecutionLog)
// 插入决策记录主表(仅保存 AI 决策相关内容)
// Insert decision record main table (only save AI decision related content)
result, err := s.db.Exec(`
INSERT INTO decision_records (
trader_id, cycle_number, timestamp, system_prompt, input_prompt,
@@ -137,19 +137,19 @@ func (s *DecisionStore) LogDecision(record *DecisionRecord) error {
record.Success, record.ErrorMessage, record.AIRequestDurationMs,
)
if err != nil {
return fmt.Errorf("插入决策记录失败: %w", err)
return fmt.Errorf("failed to insert decision record: %w", err)
}
decisionID, err := result.LastInsertId()
if err != nil {
return fmt.Errorf("获取决策ID失败: %w", err)
return fmt.Errorf("failed to get decision ID: %w", err)
}
record.ID = decisionID
return nil
}
// GetLatestRecords 获取指定交易员最近N条记录(按时间正序:从旧到新)
// GetLatestRecords gets the latest N records for specified trader (sorted by time in ascending order: old to new)
func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRecord, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt,
@@ -161,7 +161,7 @@ func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRec
LIMIT ?
`, traderID, n)
if err != nil {
return nil, fmt.Errorf("查询决策记录失败: %w", err)
return nil, fmt.Errorf("failed to query decision records: %w", err)
}
defer rows.Close()
@@ -174,12 +174,12 @@ func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRec
records = append(records, record)
}
// 填充关联数据
// Fill associated data
for _, record := range records {
s.fillRecordDetails(record)
}
// 反转数组,让时间从旧到新排列
// Reverse array to sort time from old to new
for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 {
records[i], records[j] = records[j], records[i]
}
@@ -187,7 +187,7 @@ func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRec
return records, nil
}
// GetAllLatestRecords 获取所有交易员最近N条记录
// GetAllLatestRecords gets the latest N records for all traders
func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt,
@@ -198,7 +198,7 @@ func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) {
LIMIT ?
`, n)
if err != nil {
return nil, fmt.Errorf("查询决策记录失败: %w", err)
return nil, fmt.Errorf("failed to query decision records: %w", err)
}
defer rows.Close()
@@ -211,7 +211,7 @@ func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) {
records = append(records, record)
}
// 反转数组
// Reverse array
for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 {
records[i], records[j] = records[j], records[i]
}
@@ -219,7 +219,7 @@ func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) {
return records, nil
}
// GetRecordsByDate 获取指定交易员指定日期的所有记录
// GetRecordsByDate gets all records for a specified trader on a specified date
func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*DecisionRecord, error) {
dateStr := date.Format("2006-01-02")
@@ -232,7 +232,7 @@ func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*De
ORDER BY timestamp ASC
`, traderID, dateStr)
if err != nil {
return nil, fmt.Errorf("查询决策记录失败: %w", err)
return nil, fmt.Errorf("failed to query decision records: %w", err)
}
defer rows.Close()
@@ -248,7 +248,7 @@ func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*De
return records, nil
}
// CleanOldRecords 清理N天前的旧记录
// CleanOldRecords cleans old records from N days ago
func (s *DecisionStore) CleanOldRecords(traderID string, days int) (int64, error) {
cutoffTime := time.Now().AddDate(0, 0, -days).Format(time.RFC3339)
@@ -257,13 +257,13 @@ func (s *DecisionStore) CleanOldRecords(traderID string, days int) (int64, error
WHERE trader_id = ? AND timestamp < ?
`, traderID, cutoffTime)
if err != nil {
return 0, fmt.Errorf("清理旧记录失败: %w", err)
return 0, fmt.Errorf("failed to clean old records: %w", err)
}
return result.RowsAffected()
}
// GetStatistics 获取指定交易员的统计信息
// GetStatistics gets statistics information for specified trader
func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) {
stats := &Statistics{}
@@ -271,24 +271,24 @@ func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) {
SELECT COUNT(*) FROM decision_records WHERE trader_id = ?
`, traderID).Scan(&stats.TotalCycles)
if err != nil {
return nil, fmt.Errorf("查询总周期数失败: %w", err)
return nil, fmt.Errorf("failed to query total cycles: %w", err)
}
err = s.db.QueryRow(`
SELECT COUNT(*) FROM decision_records WHERE trader_id = ? AND success = 1
`, traderID).Scan(&stats.SuccessfulCycles)
if err != nil {
return nil, fmt.Errorf("查询成功周期数失败: %w", err)
return nil, fmt.Errorf("failed to query successful cycles: %w", err)
}
stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles
// trader_orders 表统计开仓次数
// Count open positions from trader_orders table
s.db.QueryRow(`
SELECT COUNT(*) FROM trader_orders
WHERE trader_id = ? AND status = 'FILLED' AND action IN ('open_long', 'open_short')
`, traderID).Scan(&stats.TotalOpenPositions)
// trader_orders 表统计平仓次数
// Count close positions from trader_orders table
s.db.QueryRow(`
SELECT COUNT(*) FROM trader_orders
WHERE trader_id = ? AND status = 'FILLED' AND action IN ('close_long', 'close_short', 'auto_close_long', 'auto_close_short')
@@ -297,7 +297,7 @@ func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) {
return stats, nil
}
// GetAllStatistics 获取所有交易员的统计信息
// GetAllStatistics gets statistics information for all traders
func (s *DecisionStore) GetAllStatistics() (*Statistics, error) {
stats := &Statistics{}
@@ -305,7 +305,7 @@ func (s *DecisionStore) GetAllStatistics() (*Statistics, error) {
s.db.QueryRow(`SELECT COUNT(*) FROM decision_records WHERE success = 1`).Scan(&stats.SuccessfulCycles)
stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles
// trader_orders 表统计
// Count from trader_orders table
s.db.QueryRow(`
SELECT COUNT(*) FROM trader_orders
WHERE status = 'FILLED' AND action IN ('open_long', 'open_short')
@@ -319,7 +319,7 @@ func (s *DecisionStore) GetAllStatistics() (*Statistics, error) {
return stats, nil
}
// GetLastCycleNumber 获取指定交易员的最后周期编号
// GetLastCycleNumber gets the last cycle number for specified trader
func (s *DecisionStore) GetLastCycleNumber(traderID string) (int, error) {
var cycleNumber int
err := s.db.QueryRow(`
@@ -331,7 +331,7 @@ func (s *DecisionStore) GetLastCycleNumber(traderID string) (int, error) {
return cycleNumber, nil
}
// scanDecisionRecord 从行中扫描决策记录
// scanDecisionRecord scans decision record from row
func (s *DecisionStore) scanDecisionRecord(rows *sql.Rows) (*DecisionRecord, error) {
var record DecisionRecord
var timestampStr string
@@ -354,11 +354,11 @@ func (s *DecisionStore) scanDecisionRecord(rows *sql.Rows) (*DecisionRecord, err
return &record, nil
}
// fillRecordDetails 填充决策记录的关联数据(旧的关联表已删除,此函数保留用于兼容性)
// 注意:账户快照、持仓快照、决策动作等数据已不再存储在 decision 相关表中
// - 净值数据请使用 EquityStore.GetLatest()
// - 订单数据请使用 OrderStore
// fillRecordDetails fills associated data for decision record (old associated tables removed, this function kept for compatibility)
// Note: Account snapshot, position snapshot, decision action data are no longer stored in decision related tables
// - For equity data use EquityStore.GetLatest()
// - For order data use OrderStore
func (s *DecisionStore) fillRecordDetails(record *DecisionRecord) {
// 旧的关联表已删除,不再需要填充
// AccountState, Positions, Decisions 字段将保持为零值
// Old associated tables removed, no longer need to fill
// AccountState, Positions, Decisions fields will remain at zero values
}
+30 -30
View File
@@ -6,27 +6,27 @@ import (
"time"
)
// EquityStore 账户净值存储(用于绘制收益率曲线)
// EquityStore account equity storage (for plotting return curves)
type EquityStore struct {
db *sql.DB
}
// EquitySnapshot 净值快照
// EquitySnapshot equity snapshot
type EquitySnapshot struct {
ID int64 `json:"id"`
TraderID string `json:"trader_id"`
Timestamp time.Time `json:"timestamp"`
TotalEquity float64 `json:"total_equity"` // 账户净值 (余额 + 未实现盈亏)
Balance float64 `json:"balance"` // 账户余额
UnrealizedPnL float64 `json:"unrealized_pnl"` // 未实现盈亏
PositionCount int `json:"position_count"` // 持仓数量
MarginUsedPct float64 `json:"margin_used_pct"` // 保证金使用率
TotalEquity float64 `json:"total_equity"` // Account equity (balance + unrealized PnL)
Balance float64 `json:"balance"` // Account balance
UnrealizedPnL float64 `json:"unrealized_pnl"` // Unrealized profit and loss
PositionCount int `json:"position_count"` // Position count
MarginUsedPct float64 `json:"margin_used_pct"` // Margin usage percentage
}
// initTables 初始化净值表
// initTables initializes equity tables
func (s *EquityStore) initTables() error {
queries := []string{
// 净值快照表 - 专门用于收益率曲线
// Equity snapshot table - specifically for return curves
`CREATE TABLE IF NOT EXISTS trader_equity_snapshots (
id INTEGER PRIMARY KEY AUTOINCREMENT,
trader_id TEXT NOT NULL,
@@ -38,21 +38,21 @@ func (s *EquityStore) initTables() error {
margin_used_pct REAL DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
// 索引
// Indexes
`CREATE INDEX IF NOT EXISTS idx_equity_trader_time ON trader_equity_snapshots(trader_id, timestamp DESC)`,
`CREATE INDEX IF NOT EXISTS idx_equity_timestamp ON trader_equity_snapshots(timestamp DESC)`,
}
for _, query := range queries {
if _, err := s.db.Exec(query); err != nil {
return fmt.Errorf("执行SQL失败: %w", err)
return fmt.Errorf("failed to execute SQL: %w", err)
}
}
return nil
}
// Save 保存净值快照
// Save saves equity snapshot
func (s *EquityStore) Save(snapshot *EquitySnapshot) error {
if snapshot.Timestamp.IsZero() {
snapshot.Timestamp = time.Now().UTC()
@@ -75,7 +75,7 @@ func (s *EquityStore) Save(snapshot *EquitySnapshot) error {
snapshot.MarginUsedPct,
)
if err != nil {
return fmt.Errorf("保存净值快照失败: %w", err)
return fmt.Errorf("failed to save equity snapshot: %w", err)
}
id, _ := result.LastInsertId()
@@ -83,7 +83,7 @@ func (s *EquityStore) Save(snapshot *EquitySnapshot) error {
return nil
}
// GetLatest 获取指定交易员最近N条净值记录(按时间正序:从旧到新)
// GetLatest gets the latest N equity records for specified trader (sorted in ascending chronological order: old to new)
func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, timestamp, total_equity, balance,
@@ -94,7 +94,7 @@ func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot,
LIMIT ?
`, traderID, limit)
if err != nil {
return nil, fmt.Errorf("查询净值记录失败: %w", err)
return nil, fmt.Errorf("failed to query equity records: %w", err)
}
defer rows.Close()
@@ -113,7 +113,7 @@ func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot,
snapshots = append(snapshots, snap)
}
// 反转数组,让时间从旧到新排列(适合绘制曲线)
// Reverse the array to sort time from old to new (suitable for plotting curves)
for i, j := 0, len(snapshots)-1; i < j; i, j = i+1, j-1 {
snapshots[i], snapshots[j] = snapshots[j], snapshots[i]
}
@@ -121,7 +121,7 @@ func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot,
return snapshots, nil
}
// GetByTimeRange 获取指定时间范围内的净值记录
// GetByTimeRange gets equity records within specified time range
func (s *EquityStore) GetByTimeRange(traderID string, start, end time.Time) ([]*EquitySnapshot, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, timestamp, total_equity, balance,
@@ -131,7 +131,7 @@ func (s *EquityStore) GetByTimeRange(traderID string, start, end time.Time) ([]*
ORDER BY timestamp ASC
`, traderID, start.Format(time.RFC3339), end.Format(time.RFC3339))
if err != nil {
return nil, fmt.Errorf("查询净值记录失败: %w", err)
return nil, fmt.Errorf("failed to query equity records: %w", err)
}
defer rows.Close()
@@ -153,7 +153,7 @@ func (s *EquityStore) GetByTimeRange(traderID string, start, end time.Time) ([]*
return snapshots, nil
}
// GetAllTradersLatest 获取所有交易员的最新净值(用于排行榜)
// GetAllTradersLatest gets latest equity for all traders (for leaderboards)
func (s *EquityStore) GetAllTradersLatest() (map[string]*EquitySnapshot, error) {
rows, err := s.db.Query(`
SELECT e.id, e.trader_id, e.timestamp, e.total_equity, e.balance,
@@ -166,7 +166,7 @@ func (s *EquityStore) GetAllTradersLatest() (map[string]*EquitySnapshot, error)
) latest ON e.trader_id = latest.trader_id AND e.timestamp = latest.max_ts
`)
if err != nil {
return nil, fmt.Errorf("查询最新净值失败: %w", err)
return nil, fmt.Errorf("failed to query latest equity: %w", err)
}
defer rows.Close()
@@ -188,7 +188,7 @@ func (s *EquityStore) GetAllTradersLatest() (map[string]*EquitySnapshot, error)
return result, nil
}
// CleanOldRecords 清理N天前的旧记录
// CleanOldRecords cleans old records from N days ago
func (s *EquityStore) CleanOldRecords(traderID string, days int) (int64, error) {
cutoffTime := time.Now().AddDate(0, 0, -days).Format(time.RFC3339)
@@ -197,13 +197,13 @@ func (s *EquityStore) CleanOldRecords(traderID string, days int) (int64, error)
WHERE trader_id = ? AND timestamp < ?
`, traderID, cutoffTime)
if err != nil {
return 0, fmt.Errorf("清理旧记录失败: %w", err)
return 0, fmt.Errorf("failed to clean old records: %w", err)
}
return result.RowsAffected()
}
// GetCount 获取指定交易员的记录数
// GetCount gets record count for specified trader
func (s *EquityStore) GetCount(traderID string) (int, error) {
var count int
err := s.db.QueryRow(`
@@ -212,26 +212,26 @@ func (s *EquityStore) GetCount(traderID string) (int, error) {
return count, err
}
// MigrateFromDecision 从旧的 decision_account_snapshots 迁移数据
// MigrateFromDecision migrates data from old decision_account_snapshots table
func (s *EquityStore) MigrateFromDecision() (int64, error) {
// 检查是否需要迁移(新表是否为空)
// Check if migration is needed (whether new table is empty)
var count int
s.db.QueryRow(`SELECT COUNT(*) FROM trader_equity_snapshots`).Scan(&count)
if count > 0 {
return 0, nil // 已有数据,跳过迁移
return 0, nil // Already has data, skip migration
}
// 检查旧表是否存在
// Check if old table exists
var tableName string
err := s.db.QueryRow(`
SELECT name FROM sqlite_master
WHERE type='table' AND name='decision_account_snapshots'
`).Scan(&tableName)
if err != nil {
return 0, nil // 旧表不存在,跳过
return 0, nil // Old table doesn't exist, skip
}
// 迁移数据:从 decision_records + decision_account_snapshots 联合查询
// Migrate data: join query from decision_records + decision_account_snapshots
result, err := s.db.Exec(`
INSERT INTO trader_equity_snapshots (
trader_id, timestamp, total_equity, balance,
@@ -250,7 +250,7 @@ func (s *EquityStore) MigrateFromDecision() (int64, error) {
ORDER BY dr.timestamp ASC
`)
if err != nil {
return 0, fmt.Errorf("迁移数据失败: %w", err)
return 0, fmt.Errorf("failed to migrate data: %w", err)
}
return result.RowsAffected()
+14 -14
View File
@@ -8,14 +8,14 @@ import (
"time"
)
// ExchangeStore 交易所存储
// ExchangeStore exchange storage
type ExchangeStore struct {
db *sql.DB
encryptFunc func(string) string
decryptFunc func(string) string
}
// Exchange 交易所配置
// Exchange exchange configuration
type Exchange struct {
ID string `json:"id"`
UserID string `json:"user_id"`
@@ -24,7 +24,7 @@ type Exchange struct {
Enabled bool `json:"enabled"`
APIKey string `json:"apiKey"`
SecretKey string `json:"secretKey"`
Passphrase string `json:"passphrase"` // OKX专用
Passphrase string `json:"passphrase"` // OKX-specific
Testnet bool `json:"testnet"`
HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"`
AsterUser string `json:"asterUser"`
@@ -65,10 +65,10 @@ func (s *ExchangeStore) initTables() error {
return err
}
// 迁移:添加 passphrase 列(如果不存在)
// Migration: add passphrase column (if not exists)
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN passphrase TEXT DEFAULT ''`)
// 触发器
// Trigger
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_exchanges_updated_at
AFTER UPDATE ON exchanges
@@ -97,7 +97,7 @@ func (s *ExchangeStore) initDefaultData() error {
VALUES (?, 'default', ?, ?, 0)
`, exchange.id, exchange.name, exchange.typ)
if err != nil {
return fmt.Errorf("初始化交易所失败: %w", err)
return fmt.Errorf("failed to initialize exchange: %w", err)
}
}
return nil
@@ -117,7 +117,7 @@ func (s *ExchangeStore) decrypt(encrypted string) string {
return encrypted
}
// EnsureUserExchanges 确保用户有所有支持的交易所记录
// EnsureUserExchanges ensures user has records for all supported exchanges
func (s *ExchangeStore) EnsureUserExchanges(userID string) error {
exchanges := []struct {
id, name, typ string
@@ -136,17 +136,17 @@ func (s *ExchangeStore) EnsureUserExchanges(userID string) error {
VALUES (?, ?, ?, ?, 0)
`, exchange.id, userID, exchange.name, exchange.typ)
if err != nil {
return fmt.Errorf("确保用户交易所失败: %w", err)
return fmt.Errorf("failed to ensure user exchanges: %w", err)
}
}
return nil
}
// List 获取用户的交易所列表
// List gets user's exchange list
func (s *ExchangeStore) List(userID string) ([]*Exchange, error) {
// 确保用户有所有支持的交易所记录
// Ensure user has records for all supported exchanges
if err := s.EnsureUserExchanges(userID); err != nil {
logger.Debugf("⚠️ 确保用户交易所记录失败: %v", err)
logger.Debugf("Warning: failed to ensure user exchange records: %v", err)
}
rows, err := s.db.Query(`
@@ -194,7 +194,7 @@ func (s *ExchangeStore) List(userID string) ([]*Exchange, error) {
return exchanges, nil
}
// Update 更新交易所配置
// Update updates exchange configuration
func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKey, passphrase string, testnet bool,
hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, lighterWalletAddr, lighterPrivateKey, lighterApiKeyPrivateKey string) error {
@@ -246,7 +246,7 @@ func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKe
rowsAffected, _ := result.RowsAffected()
if rowsAffected == 0 {
// 创建新记录,type 使用交易所 ID 以便后续正确识别
// Create new record, use exchange ID as type for correct identification
var name, typ string
switch id {
case "binance":
@@ -278,7 +278,7 @@ func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKe
return nil
}
// Create 创建交易所配置
// Create creates exchange configuration
func (s *ExchangeStore) Create(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool,
hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error {
_, err := s.db.Exec(`
+73 -73
View File
@@ -7,73 +7,73 @@ import (
"time"
)
// TraderOrder 交易员订单记录
// TraderOrder trader order record
type TraderOrder struct {
ID int64 `json:"id"`
TraderID string `json:"trader_id"` // 交易员ID
OrderID string `json:"order_id"` // 交易所订单ID
ClientOrderID string `json:"client_order_id"` // 客户端订单ID
Symbol string `json:"symbol"` // 交易对
TraderID string `json:"trader_id"` // Trader ID
OrderID string `json:"order_id"` // Exchange order ID
ClientOrderID string `json:"client_order_id"` // Client order ID
Symbol string `json:"symbol"` // Trading pair
Side string `json:"side"` // BUY/SELL
PositionSide string `json:"position_side"` // LONG/SHORT/BOTH
Action string `json:"action"` // open_long/close_long/open_short/close_short
OrderType string `json:"order_type"` // MARKET/LIMIT
Quantity float64 `json:"quantity"` // 订单数量
Price float64 `json:"price"` // 订单价格
AvgPrice float64 `json:"avg_price"` // 实际成交均价
ExecutedQty float64 `json:"executed_qty"` // 已成交数量
Leverage int `json:"leverage"` // 杠杆倍数
Quantity float64 `json:"quantity"` // Order quantity
Price float64 `json:"price"` // Order price
AvgPrice float64 `json:"avg_price"` // Actual average execution price
ExecutedQty float64 `json:"executed_qty"` // Executed quantity
Leverage int `json:"leverage"` // Leverage multiplier
Status string `json:"status"` // NEW/FILLED/CANCELED/EXPIRED
Fee float64 `json:"fee"` // 手续费
FeeAsset string `json:"fee_asset"` // 手续费资产
RealizedPnL float64 `json:"realized_pnl"` // 已实现盈亏(平仓时)
EntryPrice float64 `json:"entry_price"` // 开仓价(平仓时记录)
Fee float64 `json:"fee"` // Fee
FeeAsset string `json:"fee_asset"` // Fee asset
RealizedPnL float64 `json:"realized_pnl"` // Realized PnL (when closing)
EntryPrice float64 `json:"entry_price"` // Entry price (recorded when closing)
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
FilledAt time.Time `json:"filled_at"` // 成交时间
FilledAt time.Time `json:"filled_at"` // Filled time
}
// TraderStats 交易统计指标
// TraderStats trading statistics metrics
type TraderStats struct {
TotalTrades int `json:"total_trades"` // 总交易数(已平仓)
WinTrades int `json:"win_trades"` // 盈利交易数
LossTrades int `json:"loss_trades"` // 亏损交易数
WinRate float64 `json:"win_rate"` // 胜率 (%)
ProfitFactor float64 `json:"profit_factor"` // 盈亏比
SharpeRatio float64 `json:"sharpe_ratio"` // 夏普比
TotalPnL float64 `json:"total_pnl"` // 总盈亏
TotalFee float64 `json:"total_fee"` // 总手续费
AvgWin float64 `json:"avg_win"` // 平均盈利
AvgLoss float64 `json:"avg_loss"` // 平均亏损
MaxDrawdownPct float64 `json:"max_drawdown_pct"` // 最大回撤 (%)
TotalTrades int `json:"total_trades"` // Total trades (closed)
WinTrades int `json:"win_trades"` // Winning trades
LossTrades int `json:"loss_trades"` // Losing trades
WinRate float64 `json:"win_rate"` // Win rate (%)
ProfitFactor float64 `json:"profit_factor"` // Profit factor
SharpeRatio float64 `json:"sharpe_ratio"` // Sharpe ratio
TotalPnL float64 `json:"total_pnl"` // Total PnL
TotalFee float64 `json:"total_fee"` // Total fees
AvgWin float64 `json:"avg_win"` // Average win
AvgLoss float64 `json:"avg_loss"` // Average loss
MaxDrawdownPct float64 `json:"max_drawdown_pct"` // Max drawdown (%)
}
// CompletedOrder 已完成订单(用于AI输入)
// CompletedOrder completed order (for AI input)
type CompletedOrder struct {
Symbol string `json:"symbol"` // 交易对
Symbol string `json:"symbol"` // Trading pair
Action string `json:"action"` // close_long/close_short
Side string `json:"side"` // long/short
Quantity float64 `json:"quantity"` // 数量
EntryPrice float64 `json:"entry_price"` // 开仓价
ExitPrice float64 `json:"exit_price"` // 平仓价
RealizedPnL float64 `json:"realized_pnl"` // 已实现盈亏
PnLPct float64 `json:"pnl_pct"` // 盈亏百分比
Fee float64 `json:"fee"` // 手续费
Leverage int `json:"leverage"` // 杠杆
FilledAt time.Time `json:"filled_at"` // 成交时间
Quantity float64 `json:"quantity"` // Quantity
EntryPrice float64 `json:"entry_price"` // Entry price
ExitPrice float64 `json:"exit_price"` // Exit price
RealizedPnL float64 `json:"realized_pnl"` // Realized PnL
PnLPct float64 `json:"pnl_pct"` // PnL percentage
Fee float64 `json:"fee"` // Fee
Leverage int `json:"leverage"` // Leverage
FilledAt time.Time `json:"filled_at"` // Filled time
}
// OrderStore 订单存储
// OrderStore order storage
type OrderStore struct {
db *sql.DB
}
// NewOrderStore 创建订单存储实例
// NewOrderStore creates order storage instance
func NewOrderStore(db *sql.DB) *OrderStore {
return &OrderStore{db: db}
}
// InitTables 初始化订单表
// InitTables initializes order tables
func (s *OrderStore) InitTables() error {
_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS trader_orders (
@@ -103,10 +103,10 @@ func (s *OrderStore) InitTables() error {
)
`)
if err != nil {
return fmt.Errorf("创建trader_orders表失败: %w", err)
return fmt.Errorf("failed to create trader_orders table: %w", err)
}
// 创建索引
// Create indexes
indices := []string{
`CREATE INDEX IF NOT EXISTS idx_trader_orders_trader ON trader_orders(trader_id)`,
`CREATE INDEX IF NOT EXISTS idx_trader_orders_status ON trader_orders(trader_id, status)`,
@@ -115,14 +115,14 @@ func (s *OrderStore) InitTables() error {
}
for _, idx := range indices {
if _, err := s.db.Exec(idx); err != nil {
return fmt.Errorf("创建索引失败: %w", err)
return fmt.Errorf("failed to create index: %w", err)
}
}
return nil
}
// Create 创建订单记录
// Create creates order record
func (s *OrderStore) Create(order *TraderOrder) error {
now := time.Now().Format(time.RFC3339)
result, err := s.db.Exec(`
@@ -140,7 +140,7 @@ func (s *OrderStore) Create(order *TraderOrder) error {
order.RealizedPnL, order.EntryPrice, now, now,
)
if err != nil {
return fmt.Errorf("创建订单记录失败: %w", err)
return fmt.Errorf("failed to create order record: %w", err)
}
id, _ := result.LastInsertId()
@@ -148,7 +148,7 @@ func (s *OrderStore) Create(order *TraderOrder) error {
return nil
}
// Update 更新订单记录
// Update updates order record
func (s *OrderStore) Update(order *TraderOrder) error {
now := time.Now().Format(time.RFC3339)
filledAt := ""
@@ -167,12 +167,12 @@ func (s *OrderStore) Update(order *TraderOrder) error {
order.TraderID, order.OrderID,
)
if err != nil {
return fmt.Errorf("更新订单记录失败: %w", err)
return fmt.Errorf("failed to update order record: %w", err)
}
return nil
}
// GetByOrderID 根据订单ID获取订单
// GetByOrderID gets order by order ID
func (s *OrderStore) GetByOrderID(traderID, orderID string) (*TraderOrder, error) {
var order TraderOrder
var createdAt, updatedAt, filledAt sql.NullString
@@ -208,9 +208,9 @@ func (s *OrderStore) GetByOrderID(traderID, orderID string) (*TraderOrder, error
return &order, nil
}
// GetLatestOpenOrder 获取某币种最近的开仓订单(用于计算平仓盈亏)
// GetLatestOpenOrder gets the latest open order for a symbol (for calculating close PnL)
func (s *OrderStore) GetLatestOpenOrder(traderID, symbol, side string) (*TraderOrder, error) {
// side: long -> open_long, short -> open_short
// side: long -> find open_long, short -> find open_short
action := "open_long"
if side == "short" {
action = "open_short"
@@ -252,7 +252,7 @@ func (s *OrderStore) GetLatestOpenOrder(traderID, symbol, side string) (*TraderO
return &order, nil
}
// GetRecentCompletedOrders 获取最近已完成的平仓订单
// GetRecentCompletedOrders gets recent completed close orders
func (s *OrderStore) GetRecentCompletedOrders(traderID string, limit int) ([]CompletedOrder, error) {
rows, err := s.db.Query(`
SELECT symbol, action, side, executed_qty, entry_price, avg_price,
@@ -264,7 +264,7 @@ func (s *OrderStore) GetRecentCompletedOrders(traderID string, limit int) ([]Com
LIMIT ?
`, traderID, limit)
if err != nil {
return nil, fmt.Errorf("查询已完成订单失败: %w", err)
return nil, fmt.Errorf("failed to query completed orders: %w", err)
}
defer rows.Close()
@@ -282,7 +282,7 @@ func (s *OrderStore) GetRecentCompletedOrders(traderID string, limit int) ([]Com
continue
}
// 根据action推断side
// Infer side from action
if o.Action == "close_long" {
o.Side = "long"
} else if o.Action == "close_short" {
@@ -291,7 +291,7 @@ func (s *OrderStore) GetRecentCompletedOrders(traderID string, limit int) ([]Com
o.Side = side.String
}
// 计算盈亏百分比
// Calculate PnL percentage
if o.EntryPrice > 0 {
if o.Side == "long" {
o.PnLPct = (o.ExitPrice - o.EntryPrice) / o.EntryPrice * 100 * float64(o.Leverage)
@@ -310,11 +310,11 @@ func (s *OrderStore) GetRecentCompletedOrders(traderID string, limit int) ([]Com
return orders, nil
}
// GetTraderStats 获取交易统计指标
// GetTraderStats gets trading statistics metrics
func (s *OrderStore) GetTraderStats(traderID string) (*TraderStats, error) {
stats := &TraderStats{}
// 查询所有已完成的平仓订单
// Query all completed close orders
rows, err := s.db.Query(`
SELECT realized_pnl, fee, filled_at
FROM trader_orders
@@ -323,7 +323,7 @@ func (s *OrderStore) GetTraderStats(traderID string) (*TraderStats, error) {
ORDER BY filled_at ASC
`, traderID)
if err != nil {
return nil, fmt.Errorf("查询订单统计失败: %w", err)
return nil, fmt.Errorf("failed to query order statistics: %w", err)
}
defer rows.Close()
@@ -351,17 +351,17 @@ func (s *OrderStore) GetTraderStats(traderID string) (*TraderStats, error) {
}
}
// 计算胜率
// Calculate win rate
if stats.TotalTrades > 0 {
stats.WinRate = float64(stats.WinTrades) / float64(stats.TotalTrades) * 100
}
// 计算盈亏比
// Calculate profit factor
if totalLoss > 0 {
stats.ProfitFactor = totalWin / totalLoss
}
// 计算平均盈亏
// Calculate average win/loss
if stats.WinTrades > 0 {
stats.AvgWin = totalWin / float64(stats.WinTrades)
}
@@ -369,12 +369,12 @@ func (s *OrderStore) GetTraderStats(traderID string) (*TraderStats, error) {
stats.AvgLoss = totalLoss / float64(stats.LossTrades)
}
// 计算夏普比(使用盈亏序列)
// Calculate Sharpe ratio (using PnL sequence)
if len(pnls) > 1 {
stats.SharpeRatio = calculateSharpeRatio(pnls)
}
// 计算最大回撤
// Calculate max drawdown
if len(pnls) > 0 {
stats.MaxDrawdownPct = calculateMaxDrawdown(pnls)
}
@@ -382,20 +382,20 @@ func (s *OrderStore) GetTraderStats(traderID string) (*TraderStats, error) {
return stats, nil
}
// calculateSharpeRatio 计算夏普比
// calculateSharpeRatio calculates Sharpe ratio
func calculateSharpeRatio(pnls []float64) float64 {
if len(pnls) < 2 {
return 0
}
// 计算平均收益
// Calculate average return
var sum float64
for _, pnl := range pnls {
sum += pnl
}
mean := sum / float64(len(pnls))
// 计算标准差
// Calculate standard deviation
var variance float64
for _, pnl := range pnls {
variance += (pnl - mean) * (pnl - mean)
@@ -406,17 +406,17 @@ func calculateSharpeRatio(pnls []float64) float64 {
return 0
}
// 夏普比 = 平均收益 / 标准差
// Sharpe ratio = average return / standard deviation
return mean / stdDev
}
// calculateMaxDrawdown 计算最大回撤
// calculateMaxDrawdown calculates max drawdown
func calculateMaxDrawdown(pnls []float64) float64 {
if len(pnls) == 0 {
return 0
}
// 计算累计权益曲线
// Calculate cumulative equity curve
var cumulative float64
var peak float64
var maxDD float64
@@ -437,7 +437,7 @@ func calculateMaxDrawdown(pnls []float64) float64 {
return maxDD
}
// GetPendingOrders 获取未成交的订单(用于轮询)
// GetPendingOrders gets pending orders (for polling)
func (s *OrderStore) GetPendingOrders(traderID string) ([]*TraderOrder, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, order_id, client_order_id, symbol, side, position_side,
@@ -449,14 +449,14 @@ func (s *OrderStore) GetPendingOrders(traderID string) ([]*TraderOrder, error) {
ORDER BY created_at ASC
`, traderID)
if err != nil {
return nil, fmt.Errorf("查询未成交订单失败: %w", err)
return nil, fmt.Errorf("failed to query pending orders: %w", err)
}
defer rows.Close()
return s.scanOrders(rows)
}
// GetAllPendingOrders 获取所有未成交的订单(用于全局同步)
// GetAllPendingOrders gets all pending orders (for global sync)
func (s *OrderStore) GetAllPendingOrders() ([]*TraderOrder, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, order_id, client_order_id, symbol, side, position_side,
@@ -468,14 +468,14 @@ func (s *OrderStore) GetAllPendingOrders() ([]*TraderOrder, error) {
ORDER BY trader_id, created_at ASC
`)
if err != nil {
return nil, fmt.Errorf("查询未成交订单失败: %w", err)
return nil, fmt.Errorf("failed to query pending orders: %w", err)
}
defer rows.Close()
return s.scanOrders(rows)
}
// scanOrders 扫描订单行到结构体
// scanOrders scans order rows to structs
func (s *OrderStore) scanOrders(rows *sql.Rows) ([]*TraderOrder, error) {
var orders []*TraderOrder
for rows.Next() {
+53 -53
View File
@@ -7,40 +7,40 @@ import (
"time"
)
// TraderPosition 仓位记录(完整的开平仓追踪)
// TraderPosition position record (complete open/close position tracking)
type TraderPosition struct {
ID int64 `json:"id"`
TraderID string `json:"trader_id"`
ExchangeID string `json:"exchange_id"` // 交易所ID: binance/bybit/hyperliquid/aster/lighter
ExchangeID string `json:"exchange_id"` // Exchange ID: binance/bybit/hyperliquid/aster/lighter
Symbol string `json:"symbol"`
Side string `json:"side"` // LONG/SHORT
Quantity float64 `json:"quantity"` // 开仓数量
EntryPrice float64 `json:"entry_price"` // 开仓均价
EntryOrderID string `json:"entry_order_id"` // 开仓订单ID
EntryTime time.Time `json:"entry_time"` // 开仓时间
ExitPrice float64 `json:"exit_price"` // 平仓均价
ExitOrderID string `json:"exit_order_id"` // 平仓订单ID
ExitTime *time.Time `json:"exit_time"` // 平仓时间
RealizedPnL float64 `json:"realized_pnl"` // 已实现盈亏
Fee float64 `json:"fee"` // 手续费
Leverage int `json:"leverage"` // 杠杆倍数
Quantity float64 `json:"quantity"` // Opening quantity
EntryPrice float64 `json:"entry_price"` // Entry price
EntryOrderID string `json:"entry_order_id"` // Entry order ID
EntryTime time.Time `json:"entry_time"` // Entry time
ExitPrice float64 `json:"exit_price"` // Exit price
ExitOrderID string `json:"exit_order_id"` // Exit order ID
ExitTime *time.Time `json:"exit_time"` // Exit time
RealizedPnL float64 `json:"realized_pnl"` // Realized profit and loss
Fee float64 `json:"fee"` // Fee
Leverage int `json:"leverage"` // Leverage multiplier
Status string `json:"status"` // OPEN/CLOSED
CloseReason string `json:"close_reason"` // 平仓原因: ai_decision/manual/stop_loss/take_profit
CloseReason string `json:"close_reason"` // Close reason: ai_decision/manual/stop_loss/take_profit
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// PositionStore 仓位存储
// PositionStore position storage
type PositionStore struct {
db *sql.DB
}
// NewPositionStore 创建仓位存储实例
// NewPositionStore creates position storage instance
func NewPositionStore(db *sql.DB) *PositionStore {
return &PositionStore{db: db}
}
// InitTables 初始化仓位表
// InitTables initializes position tables
func (s *PositionStore) InitTables() error {
_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS trader_positions (
@@ -66,14 +66,14 @@ func (s *PositionStore) InitTables() error {
)
`)
if err != nil {
return fmt.Errorf("创建trader_positions表失败: %w", err)
return fmt.Errorf("failed to create trader_positions table: %w", err)
}
// 迁移:为现有表添加 exchange_id 列(如果不存在)
// 必须在创建索引之前执行!
// Migration: add exchange_id column to existing table (if not exists)
// Must be executed before creating indexes!
s.db.Exec(`ALTER TABLE trader_positions ADD COLUMN exchange_id TEXT NOT NULL DEFAULT ''`)
// 创建索引(在迁移之后)
// Create indexes (after migration)
indices := []string{
`CREATE INDEX IF NOT EXISTS idx_positions_trader ON trader_positions(trader_id)`,
`CREATE INDEX IF NOT EXISTS idx_positions_exchange ON trader_positions(exchange_id)`,
@@ -84,14 +84,14 @@ func (s *PositionStore) InitTables() error {
}
for _, idx := range indices {
if _, err := s.db.Exec(idx); err != nil {
return fmt.Errorf("创建索引失败: %w", err)
return fmt.Errorf("failed to create index: %w", err)
}
}
return nil
}
// Create 创建仓位记录(开仓时调用)
// Create creates position record (called when opening position)
func (s *PositionStore) Create(pos *TraderPosition) error {
now := time.Now()
pos.CreatedAt = now
@@ -109,7 +109,7 @@ func (s *PositionStore) Create(pos *TraderPosition) error {
pos.Status, now.Format(time.RFC3339), now.Format(time.RFC3339),
)
if err != nil {
return fmt.Errorf("创建仓位记录失败: %w", err)
return fmt.Errorf("failed to create position record: %w", err)
}
id, _ := result.LastInsertId()
@@ -117,7 +117,7 @@ func (s *PositionStore) Create(pos *TraderPosition) error {
return nil
}
// ClosePosition 平仓(更新仓位记录)
// ClosePosition closes position (updates position record)
func (s *PositionStore) ClosePosition(id int64, exitPrice float64, exitOrderID string, realizedPnL float64, fee float64, closeReason string) error {
now := time.Now()
_, err := s.db.Exec(`
@@ -131,12 +131,12 @@ func (s *PositionStore) ClosePosition(id int64, exitPrice float64, exitOrderID s
realizedPnL, fee, closeReason, now.Format(time.RFC3339), id,
)
if err != nil {
return fmt.Errorf("更新仓位记录失败: %w", err)
return fmt.Errorf("failed to update position record: %w", err)
}
return nil
}
// GetOpenPositions 获取所有未平仓位
// GetOpenPositions gets all open positions
func (s *PositionStore) GetOpenPositions(traderID string) ([]*TraderPosition, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, exchange_id, symbol, side, quantity, entry_price, entry_order_id,
@@ -147,14 +147,14 @@ func (s *PositionStore) GetOpenPositions(traderID string) ([]*TraderPosition, er
ORDER BY entry_time DESC
`, traderID)
if err != nil {
return nil, fmt.Errorf("查询未平仓位失败: %w", err)
return nil, fmt.Errorf("failed to query open positions: %w", err)
}
defer rows.Close()
return s.scanPositions(rows)
}
// GetOpenPositionBySymbol 获取指定币种方向的未平仓位
// GetOpenPositionBySymbol gets open position for specified symbol and direction
func (s *PositionStore) GetOpenPositionBySymbol(traderID, symbol, side string) (*TraderPosition, error) {
var pos TraderPosition
var entryTime, exitTime, createdAt, updatedAt sql.NullString
@@ -183,7 +183,7 @@ func (s *PositionStore) GetOpenPositionBySymbol(traderID, symbol, side string) (
return &pos, nil
}
// GetClosedPositions 获取已平仓位(历史记录)
// GetClosedPositions gets closed positions (historical records)
func (s *PositionStore) GetClosedPositions(traderID string, limit int) ([]*TraderPosition, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, exchange_id, symbol, side, quantity, entry_price, entry_order_id,
@@ -195,14 +195,14 @@ func (s *PositionStore) GetClosedPositions(traderID string, limit int) ([]*Trade
LIMIT ?
`, traderID, limit)
if err != nil {
return nil, fmt.Errorf("查询已平仓位失败: %w", err)
return nil, fmt.Errorf("failed to query closed positions: %w", err)
}
defer rows.Close()
return s.scanPositions(rows)
}
// GetAllOpenPositions 获取所有trader的未平仓位(用于全局同步)
// GetAllOpenPositions gets all traders' open positions (for global sync)
func (s *PositionStore) GetAllOpenPositions() ([]*TraderPosition, error) {
rows, err := s.db.Query(`
SELECT id, trader_id, exchange_id, symbol, side, quantity, entry_price, entry_order_id,
@@ -213,18 +213,18 @@ func (s *PositionStore) GetAllOpenPositions() ([]*TraderPosition, error) {
ORDER BY trader_id, entry_time DESC
`)
if err != nil {
return nil, fmt.Errorf("查询所有未平仓位失败: %w", err)
return nil, fmt.Errorf("failed to query all open positions: %w", err)
}
defer rows.Close()
return s.scanPositions(rows)
}
// GetPositionStats 获取仓位统计(简单版)
// GetPositionStats gets position statistics (simplified version)
func (s *PositionStore) GetPositionStats(traderID string) (map[string]interface{}, error) {
stats := make(map[string]interface{})
// 总交易数
// Total trades
var totalTrades, winTrades int
var totalPnL, totalFee float64
@@ -254,11 +254,11 @@ func (s *PositionStore) GetPositionStats(traderID string) (map[string]interface{
return stats, nil
}
// GetFullStats 获取完整的交易统计(与 TraderStats 兼容)
// GetFullStats gets complete trading statistics (compatible with TraderStats)
func (s *PositionStore) GetFullStats(traderID string) (*TraderStats, error) {
stats := &TraderStats{}
// 查询所有已平仓位
// Query all closed positions
rows, err := s.db.Query(`
SELECT realized_pnl, fee, exit_time
FROM trader_positions
@@ -266,7 +266,7 @@ func (s *PositionStore) GetFullStats(traderID string) (*TraderStats, error) {
ORDER BY exit_time ASC
`, traderID)
if err != nil {
return nil, fmt.Errorf("查询仓位统计失败: %w", err)
return nil, fmt.Errorf("failed to query position statistics: %w", err)
}
defer rows.Close()
@@ -290,21 +290,21 @@ func (s *PositionStore) GetFullStats(traderID string) (*TraderStats, error) {
totalWin += pnl
} else if pnl < 0 {
stats.LossTrades++
totalLoss += -pnl // 转为正数
totalLoss += -pnl // Convert to positive
}
}
// 计算胜率
// Calculate win rate
if stats.TotalTrades > 0 {
stats.WinRate = float64(stats.WinTrades) / float64(stats.TotalTrades) * 100
}
// 计算盈亏比
// Calculate profit factor
if totalLoss > 0 {
stats.ProfitFactor = totalWin / totalLoss
}
// 计算平均盈亏
// Calculate average profit/loss
if stats.WinTrades > 0 {
stats.AvgWin = totalWin / float64(stats.WinTrades)
}
@@ -312,12 +312,12 @@ func (s *PositionStore) GetFullStats(traderID string) (*TraderStats, error) {
stats.AvgLoss = totalLoss / float64(stats.LossTrades)
}
// 计算夏普比
// Calculate Sharpe ratio
if len(pnls) > 1 {
stats.SharpeRatio = calculateSharpeRatioFromPnls(pnls)
}
// 计算最大回撤
// Calculate maximum drawdown
if len(pnls) > 0 {
stats.MaxDrawdownPct = calculateMaxDrawdownFromPnls(pnls)
}
@@ -325,7 +325,7 @@ func (s *PositionStore) GetFullStats(traderID string) (*TraderStats, error) {
return stats, nil
}
// RecentTrade 最近的交易记录(用于AI输入)
// RecentTrade recent trade record (for AI input)
type RecentTrade struct {
Symbol string `json:"symbol"`
Side string `json:"side"` // long/short
@@ -336,7 +336,7 @@ type RecentTrade struct {
ExitTime string `json:"exit_time"`
}
// GetRecentTrades 获取最近的已平仓交易
// GetRecentTrades gets recent closed trades
func (s *PositionStore) GetRecentTrades(traderID string, limit int) ([]RecentTrade, error) {
rows, err := s.db.Query(`
SELECT symbol, side, entry_price, exit_price, realized_pnl, leverage, exit_time
@@ -346,7 +346,7 @@ func (s *PositionStore) GetRecentTrades(traderID string, limit int) ([]RecentTra
LIMIT ?
`, traderID, limit)
if err != nil {
return nil, fmt.Errorf("查询最近交易失败: %w", err)
return nil, fmt.Errorf("failed to query recent trades: %w", err)
}
defer rows.Close()
@@ -361,14 +361,14 @@ func (s *PositionStore) GetRecentTrades(traderID string, limit int) ([]RecentTra
continue
}
// 转换 side 格式
// Convert side format
if t.Side == "LONG" {
t.Side = "long"
} else if t.Side == "SHORT" {
t.Side = "short"
}
// 计算盈亏百分比
// Calculate profit/loss percentage
if t.EntryPrice > 0 {
if t.Side == "long" {
t.PnLPct = (t.ExitPrice - t.EntryPrice) / t.EntryPrice * 100 * float64(leverage)
@@ -377,7 +377,7 @@ func (s *PositionStore) GetRecentTrades(traderID string, limit int) ([]RecentTra
}
}
// 格式化时间
// Format time
if exitTime.Valid {
if parsed, err := time.Parse(time.RFC3339, exitTime.String); err == nil {
t.ExitTime = parsed.Format("01-02 15:04")
@@ -390,7 +390,7 @@ func (s *PositionStore) GetRecentTrades(traderID string, limit int) ([]RecentTra
return trades, nil
}
// calculateSharpeRatioFromPnls 计算夏普比
// calculateSharpeRatioFromPnls calculates Sharpe ratio
func calculateSharpeRatioFromPnls(pnls []float64) float64 {
if len(pnls) < 2 {
return 0
@@ -415,7 +415,7 @@ func calculateSharpeRatioFromPnls(pnls []float64) float64 {
return mean / stdDev
}
// calculateMaxDrawdownFromPnls 计算最大回撤
// calculateMaxDrawdownFromPnls calculates maximum drawdown
func calculateMaxDrawdownFromPnls(pnls []float64) float64 {
if len(pnls) == 0 {
return 0
@@ -438,7 +438,7 @@ func calculateMaxDrawdownFromPnls(pnls []float64) float64 {
return maxDD
}
// scanPositions 扫描仓位行到结构体
// scanPositions scans position rows into structs
func (s *PositionStore) scanPositions(rows *sql.Rows) ([]*TraderPosition, error) {
var positions []*TraderPosition
for rows.Next() {
@@ -462,7 +462,7 @@ func (s *PositionStore) scanPositions(rows *sql.Rows) ([]*TraderPosition, error)
return positions, nil
}
// parsePositionTimes 解析时间字段
// parsePositionTimes parses time fields
func (s *PositionStore) parsePositionTimes(pos *TraderPosition, entryTime, exitTime, createdAt, updatedAt sql.NullString) {
if entryTime.Valid {
pos.EntryTime, _ = time.Parse(time.RFC3339, entryTime.String)
+57 -57
View File
@@ -1,5 +1,5 @@
// Package store 提供统一的数据库存储层
// 所有数据库操作都应该通过这个包进行
// Package store provides unified database storage layer
// All database operations should go through this package
package store
import (
@@ -11,11 +11,11 @@ import (
_ "modernc.org/sqlite"
)
// Store 统一的数据存储接口
// Store unified data storage interface
type Store struct {
db *sql.DB
// 子存储(延迟初始化)
// Sub-stores (lazy initialization)
user *UserStore
aiModel *AIModelStore
exchange *ExchangeStore
@@ -27,80 +27,80 @@ type Store struct {
strategy *StrategyStore
equity *EquityStore
// 加密函数
// Encryption functions
encryptFunc func(string) string
decryptFunc func(string) string
mu sync.RWMutex
}
// New 创建新的 Store 实例
// New creates new Store instance
func New(dbPath string) (*Store, error) {
db, err := sql.Open("sqlite", dbPath)
if err != nil {
return nil, fmt.Errorf("打开数据库失败: %w", err)
return nil, fmt.Errorf("failed to open database: %w", err)
}
// SQLite 配置
// SQLite configuration
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
// 启用外键约束
// Enable foreign key constraints
if _, err := db.Exec(`PRAGMA foreign_keys = ON`); err != nil {
db.Close()
return nil, fmt.Errorf("启用外键失败: %w", err)
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
}
// 使用 DELETE 模式(传统模式)以确保 Docker bind mount 兼容性
// 注意:WAL 模式在 macOS Docker 下会导致数据同步问题
// Use DELETE mode (traditional mode) to ensure Docker bind mount compatibility
// Note: WAL mode causes data sync issues on macOS Docker
if _, err := db.Exec("PRAGMA journal_mode=DELETE"); err != nil {
db.Close()
return nil, fmt.Errorf("设置journal_mode失败: %w", err)
return nil, fmt.Errorf("failed to set journal_mode: %w", err)
}
// 设置 synchronous=FULL
// Set synchronous=FULL
if _, err := db.Exec("PRAGMA synchronous=FULL"); err != nil {
db.Close()
return nil, fmt.Errorf("设置synchronous失败: %w", err)
return nil, fmt.Errorf("failed to set synchronous: %w", err)
}
// 设置 busy_timeout
// Set busy_timeout
if _, err := db.Exec("PRAGMA busy_timeout = 5000"); err != nil {
db.Close()
return nil, fmt.Errorf("设置busy_timeout失败: %w", err)
return nil, fmt.Errorf("failed to set busy_timeout: %w", err)
}
s := &Store{db: db}
// 初始化所有表结构
// Initialize all table structures
if err := s.initTables(); err != nil {
db.Close()
return nil, fmt.Errorf("初始化表结构失败: %w", err)
return nil, fmt.Errorf("failed to initialize table structure: %w", err)
}
// 初始化默认数据
// Initialize default data
if err := s.initDefaultData(); err != nil {
db.Close()
return nil, fmt.Errorf("初始化默认数据失败: %w", err)
return nil, fmt.Errorf("failed to initialize default data: %w", err)
}
logger.Info("✅ 数据库已启用 DELETE 模式和 FULL 同步")
logger.Info("✅ Database enabled DELETE mode and FULL sync")
return s, nil
}
// NewFromDB 从现有数据库连接创建 Store
// NewFromDB creates Store from existing database connection
func NewFromDB(db *sql.DB) *Store {
return &Store{db: db}
}
// SetCryptoFuncs 设置加密解密函数
// SetCryptoFuncs sets encryption/decryption functions
func (s *Store) SetCryptoFuncs(encrypt, decrypt func(string) string) {
s.mu.Lock()
defer s.mu.Unlock()
s.encryptFunc = encrypt
s.decryptFunc = decrypt
// 更新已初始化的子存储
// Update already initialized sub-stores
if s.aiModel != nil {
s.aiModel.encryptFunc = encrypt
s.aiModel.decryptFunc = decrypt
@@ -114,43 +114,43 @@ func (s *Store) SetCryptoFuncs(encrypt, decrypt func(string) string) {
}
}
// initTables 初始化所有数据库表
// initTables initializes all database tables
func (s *Store) initTables() error {
// 按依赖顺序初始化
// Initialize in dependency order
if err := s.User().initTables(); err != nil {
return fmt.Errorf("初始化用户表失败: %w", err)
return fmt.Errorf("failed to initialize user tables: %w", err)
}
if err := s.AIModel().initTables(); err != nil {
return fmt.Errorf("初始化AI模型表失败: %w", err)
return fmt.Errorf("failed to initialize AI model tables: %w", err)
}
if err := s.Exchange().initTables(); err != nil {
return fmt.Errorf("初始化交易所表失败: %w", err)
return fmt.Errorf("failed to initialize exchange tables: %w", err)
}
if err := s.Trader().initTables(); err != nil {
return fmt.Errorf("初始化交易员表失败: %w", err)
return fmt.Errorf("failed to initialize trader tables: %w", err)
}
if err := s.Decision().initTables(); err != nil {
return fmt.Errorf("初始化决策日志表失败: %w", err)
return fmt.Errorf("failed to initialize decision log tables: %w", err)
}
if err := s.Backtest().initTables(); err != nil {
return fmt.Errorf("初始化回测表失败: %w", err)
return fmt.Errorf("failed to initialize backtest tables: %w", err)
}
if err := s.Order().InitTables(); err != nil {
return fmt.Errorf("初始化订单表失败: %w", err)
return fmt.Errorf("failed to initialize order tables: %w", err)
}
if err := s.Position().InitTables(); err != nil {
return fmt.Errorf("初始化仓位表失败: %w", err)
return fmt.Errorf("failed to initialize position tables: %w", err)
}
if err := s.Strategy().initTables(); err != nil {
return fmt.Errorf("初始化策略表失败: %w", err)
return fmt.Errorf("failed to initialize strategy tables: %w", err)
}
if err := s.Equity().initTables(); err != nil {
return fmt.Errorf("初始化净值表失败: %w", err)
return fmt.Errorf("failed to initialize equity tables: %w", err)
}
return nil
}
// initDefaultData 初始化默认数据
// initDefaultData initializes default data
func (s *Store) initDefaultData() error {
if err := s.AIModel().initDefaultData(); err != nil {
return err
@@ -161,16 +161,16 @@ func (s *Store) initDefaultData() error {
if err := s.Strategy().initDefaultData(); err != nil {
return err
}
// 迁移旧的 decision_account_snapshots 数据到新的 trader_equity_snapshots
// Migrate old decision_account_snapshots data to new trader_equity_snapshots table
if migrated, err := s.Equity().MigrateFromDecision(); err != nil {
logger.Warnf("迁移净值数据失败: %v", err)
logger.Warnf("failed to migrate equity data: %v", err)
} else if migrated > 0 {
logger.Infof("✅ 已迁移 %d 条净值数据到新表", migrated)
logger.Infof("✅ Migrated %d equity records to new table", migrated)
}
return nil
}
// User 获取用户存储
// User gets user storage
func (s *Store) User() *UserStore {
s.mu.Lock()
defer s.mu.Unlock()
@@ -180,7 +180,7 @@ func (s *Store) User() *UserStore {
return s.user
}
// AIModel 获取AI模型存储
// AIModel gets AI model storage
func (s *Store) AIModel() *AIModelStore {
s.mu.Lock()
defer s.mu.Unlock()
@@ -194,7 +194,7 @@ func (s *Store) AIModel() *AIModelStore {
return s.aiModel
}
// Exchange 获取交易所存储
// Exchange gets exchange storage
func (s *Store) Exchange() *ExchangeStore {
s.mu.Lock()
defer s.mu.Unlock()
@@ -208,7 +208,7 @@ func (s *Store) Exchange() *ExchangeStore {
return s.exchange
}
// Trader 获取交易员存储
// Trader gets trader storage
func (s *Store) Trader() *TraderStore {
s.mu.Lock()
defer s.mu.Unlock()
@@ -221,7 +221,7 @@ func (s *Store) Trader() *TraderStore {
return s.trader
}
// Decision 获取决策日志存储
// Decision gets decision log storage
func (s *Store) Decision() *DecisionStore {
s.mu.Lock()
defer s.mu.Unlock()
@@ -231,7 +231,7 @@ func (s *Store) Decision() *DecisionStore {
return s.decision
}
// Backtest 获取回测数据存储
// Backtest gets backtest data storage
func (s *Store) Backtest() *BacktestStore {
s.mu.Lock()
defer s.mu.Unlock()
@@ -241,7 +241,7 @@ func (s *Store) Backtest() *BacktestStore {
return s.backtest
}
// Order 获取订单存储
// Order gets order storage
func (s *Store) Order() *OrderStore {
s.mu.Lock()
defer s.mu.Unlock()
@@ -251,7 +251,7 @@ func (s *Store) Order() *OrderStore {
return s.order
}
// Position 获取仓位存储
// Position gets position storage
func (s *Store) Position() *PositionStore {
s.mu.Lock()
defer s.mu.Unlock()
@@ -261,7 +261,7 @@ func (s *Store) Position() *PositionStore {
return s.position
}
// Strategy 获取策略存储
// Strategy gets strategy storage
func (s *Store) Strategy() *StrategyStore {
s.mu.Lock()
defer s.mu.Unlock()
@@ -271,7 +271,7 @@ func (s *Store) Strategy() *StrategyStore {
return s.strategy
}
// Equity 获取净值存储
// Equity gets equity storage
func (s *Store) Equity() *EquityStore {
s.mu.Lock()
defer s.mu.Unlock()
@@ -281,22 +281,22 @@ func (s *Store) Equity() *EquityStore {
return s.equity
}
// Close 关闭数据库连接
// Close closes database connection
func (s *Store) Close() error {
return s.db.Close()
}
// DB 获取底层数据库连接(仅用于兼容旧代码,逐步废弃)
// Deprecated: 使用 Store 的方法代替
// DB gets underlying database connection (for legacy code compatibility, gradually deprecated)
// Deprecated: use Store methods instead
func (s *Store) DB() *sql.DB {
return s.db
}
// Transaction 执行事务
// Transaction executes transaction
func (s *Store) Transaction(fn func(tx *sql.Tx) error) error {
tx, err := s.db.Begin()
if err != nil {
return fmt.Errorf("开始事务失败: %w", err)
return fmt.Errorf("failed to begin transaction: %w", err)
}
if err := fn(tx); err != nil {
@@ -305,7 +305,7 @@ func (s *Store) Transaction(fn func(tx *sql.Tx) error) error {
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("提交事务失败: %w", err)
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
+104 -104
View File
@@ -7,139 +7,139 @@ import (
"time"
)
// StrategyStore 策略存储
// StrategyStore strategy storage
type StrategyStore struct {
db *sql.DB
}
// Strategy 策略配置
// Strategy strategy configuration
type Strategy struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Name string `json:"name"`
Description string `json:"description"`
IsActive bool `json:"is_active"` // 是否激活(一个用户只能有一个激活的策略)
IsDefault bool `json:"is_default"` // 是否为系统默认策略
Config string `json:"config"` // JSON 格式的策略配置
IsActive bool `json:"is_active"` // whether it is active (a user can only have one active strategy)
IsDefault bool `json:"is_default"` // whether it is a system default strategy
Config string `json:"config"` // strategy configuration in JSON format
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// StrategyConfig 策略配置详情(JSON 结构)
// StrategyConfig strategy configuration details (JSON structure)
type StrategyConfig struct {
// 币种来源配置
// coin source configuration
CoinSource CoinSourceConfig `json:"coin_source"`
// 量化数据配置
// quantitative data configuration
Indicators IndicatorConfig `json:"indicators"`
// 自定义 Prompt(附加在最后)
// custom prompt (appended at the end)
CustomPrompt string `json:"custom_prompt,omitempty"`
// 风险控制配置
// risk control configuration
RiskControl RiskControlConfig `json:"risk_control"`
// System Prompt 可编辑部分
// editable sections of System Prompt
PromptSections PromptSectionsConfig `json:"prompt_sections,omitempty"`
}
// PromptSectionsConfig System Prompt 可编辑部分
// PromptSectionsConfig editable sections of System Prompt
type PromptSectionsConfig struct {
// 角色定义(标题+描述)
// role definition (title + description)
RoleDefinition string `json:"role_definition,omitempty"`
// 交易频率认知
// trading frequency awareness
TradingFrequency string `json:"trading_frequency,omitempty"`
// 开仓标准
// entry standards
EntryStandards string `json:"entry_standards,omitempty"`
// 决策流程
// decision process
DecisionProcess string `json:"decision_process,omitempty"`
}
// CoinSourceConfig 币种来源配置
// CoinSourceConfig coin source configuration
type CoinSourceConfig struct {
// 来源类型: "static" | "coinpool" | "oi_top" | "mixed"
// source type: "static" | "coinpool" | "oi_top" | "mixed"
SourceType string `json:"source_type"`
// 静态币种列表(当 source_type = "static" 时使用)
// static coin list (used when source_type = "static")
StaticCoins []string `json:"static_coins,omitempty"`
// 是否使用 AI500 币种池
// whether to use AI500 coin pool
UseCoinPool bool `json:"use_coin_pool"`
// AI500 币种池最大数量
// AI500 coin pool maximum count
CoinPoolLimit int `json:"coin_pool_limit,omitempty"`
// AI500 币种池 API URL(策略级别配置)
// AI500 coin pool API URL (strategy-level configuration)
CoinPoolAPIURL string `json:"coin_pool_api_url,omitempty"`
// 是否使用 OI Top
// whether to use OI Top
UseOITop bool `json:"use_oi_top"`
// OI Top 最大数量
// OI Top maximum count
OITopLimit int `json:"oi_top_limit,omitempty"`
// OI Top API URL(策略级别配置)
// OI Top API URL (strategy-level configuration)
OITopAPIURL string `json:"oi_top_api_url,omitempty"`
}
// IndicatorConfig 指标配置
// IndicatorConfig indicator configuration
type IndicatorConfig struct {
// K线配置
// K-line configuration
Klines KlineConfig `json:"klines"`
// 技术指标开关
// technical indicator switches
EnableEMA bool `json:"enable_ema"`
EnableMACD bool `json:"enable_macd"`
EnableRSI bool `json:"enable_rsi"`
EnableATR bool `json:"enable_atr"`
EnableVolume bool `json:"enable_volume"`
EnableOI bool `json:"enable_oi"` // 持仓量
EnableFundingRate bool `json:"enable_funding_rate"` // 资金费率
// EMA 周期配置
EMAPeriods []int `json:"ema_periods,omitempty"` // 默认 [20, 50]
// RSI 周期配置
RSIPeriods []int `json:"rsi_periods,omitempty"` // 默认 [7, 14]
// ATR 周期配置
ATRPeriods []int `json:"atr_periods,omitempty"` // 默认 [14]
// 外部数据源
EnableOI bool `json:"enable_oi"` // open interest
EnableFundingRate bool `json:"enable_funding_rate"` // funding rate
// EMA period configuration
EMAPeriods []int `json:"ema_periods,omitempty"` // default [20, 50]
// RSI period configuration
RSIPeriods []int `json:"rsi_periods,omitempty"` // default [7, 14]
// ATR period configuration
ATRPeriods []int `json:"atr_periods,omitempty"` // default [14]
// external data sources
ExternalDataSources []ExternalDataSource `json:"external_data_sources,omitempty"`
// 量化数据源(资金流向、持仓变化、价格变化)
EnableQuantData bool `json:"enable_quant_data"` // 是否启用量化数据
QuantDataAPIURL string `json:"quant_data_api_url,omitempty"` // 量化数据 API 地址
// quantitative data sources (capital flow, position changes, price changes)
EnableQuantData bool `json:"enable_quant_data"` // whether to enable quantitative data
QuantDataAPIURL string `json:"quant_data_api_url,omitempty"` // quantitative data API address
}
// KlineConfig K线配置
// KlineConfig K-line configuration
type KlineConfig struct {
// 主时间周期: "1m", "3m", "5m", "15m", "1h", "4h"
// primary timeframe: "1m", "3m", "5m", "15m", "1h", "4h"
PrimaryTimeframe string `json:"primary_timeframe"`
// 主时间周期 K 线数量
// primary timeframe K-line count
PrimaryCount int `json:"primary_count"`
// 长周期时间框架
// longer timeframe
LongerTimeframe string `json:"longer_timeframe,omitempty"`
// 长周期 K 线数量
// longer timeframe K-line count
LongerCount int `json:"longer_count,omitempty"`
// 是否启用多时间框架分析
// whether to enable multi-timeframe analysis
EnableMultiTimeframe bool `json:"enable_multi_timeframe"`
// 选中的时间周期列表(新增:支持多周期选择)
// selected timeframe list (new: supports multi-timeframe selection)
SelectedTimeframes []string `json:"selected_timeframes,omitempty"`
}
// ExternalDataSource 外部数据源配置
// ExternalDataSource external data source configuration
type ExternalDataSource struct {
Name string `json:"name"` // 数据源名称
Type string `json:"type"` // 类型: "api" | "webhook"
Name string `json:"name"` // data source name
Type string `json:"type"` // type: "api" | "webhook"
URL string `json:"url"` // API URL
Method string `json:"method"` // HTTP 方法
Method string `json:"method"` // HTTP method
Headers map[string]string `json:"headers,omitempty"`
DataPath string `json:"data_path,omitempty"` // JSON 数据路径
RefreshSecs int `json:"refresh_secs,omitempty"` // 刷新间隔(秒)
DataPath string `json:"data_path,omitempty"` // JSON data path
RefreshSecs int `json:"refresh_secs,omitempty"` // refresh interval (seconds)
}
// RiskControlConfig 风险控制配置
// RiskControlConfig risk control configuration
type RiskControlConfig struct {
// 最大持仓数量
// maximum number of positions
MaxPositions int `json:"max_positions"`
// BTC/ETH 最大杠杆
// BTC/ETH maximum leverage
BTCETHMaxLeverage int `json:"btc_eth_max_leverage"`
// 山寨币最大杠杆
// altcoin maximum leverage
AltcoinMaxLeverage int `json:"altcoin_max_leverage"`
// 最小风险回报比
// minimum risk-reward ratio
MinRiskRewardRatio float64 `json:"min_risk_reward_ratio"`
// 最大保证金使用率
// maximum margin usage
MaxMarginUsage float64 `json:"max_margin_usage"`
// 单币种最大仓位比例(相对账户净值)
// maximum position ratio per coin (relative to account equity)
MaxPositionRatio float64 `json:"max_position_ratio"`
// 最小开仓金额(USDT
// minimum position size (USDT)
MinPositionSize float64 `json:"min_position_size"`
// 最小信心度
// minimum confidence level
MinConfidence int `json:"min_confidence"`
}
@@ -161,11 +161,11 @@ func (s *StrategyStore) initTables() error {
return err
}
// 创建索引
// create indexes
_, _ = s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_strategies_user_id ON strategies(user_id)`)
_, _ = s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_strategies_is_active ON strategies(is_active)`)
// 触发器:更新时自动更新 updated_at
// trigger: automatically update updated_at on update
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_strategies_updated_at
AFTER UPDATE ON strategies
@@ -178,14 +178,14 @@ func (s *StrategyStore) initTables() error {
}
func (s *StrategyStore) initDefaultData() error {
// 检查是否已有默认策略
// check if default strategy already exists
var count int
s.db.QueryRow(`SELECT COUNT(*) FROM strategies WHERE is_default = 1`).Scan(&count)
if count > 0 {
return nil
}
// 创建系统默认策略
// create system default strategy
defaultConfig := StrategyConfig{
CoinSource: CoinSourceConfig{
SourceType: "coinpool",
@@ -228,23 +228,23 @@ func (s *StrategyStore) initDefaultData() error {
MinConfidence: 75,
},
PromptSections: PromptSectionsConfig{
RoleDefinition: `# 你是专业的加密货币交易AI
RoleDefinition: `# You are a professional cryptocurrency trading AI
你的任务是根据提供的市场数据做出交易决策你是一位经验丰富的量化交易员擅长技术分析和风险管理`,
TradingFrequency: `# 交易频率认知
Your task is to make trading decisions based on the provided market data. You are an experienced quantitative trader skilled in technical analysis and risk management.`,
TradingFrequency: `# Trading Frequency Awareness
- 优秀交易员每天2-4 每小时0.1-0.2
- 每小时>2 = 过度交易
- 单笔持仓时间30-60分钟
如果你发现自己每个周期都在交易 标准过低若持仓<30分钟就平仓 过于急躁`,
EntryStandards: `# 🎯 开仓标准严格
- Excellent trader: 2-4 trades per day 0.1-0.2 trades per hour
- >2 trades per hour = overtrading
- Single position holding time 30-60 minutes
If you find yourself trading every cycle standards are too low; if closing positions in <30 minutes too impulsive.`,
EntryStandards: `# 🎯 Entry Standards (Strict)
只在多重信号共振时开仓自由运用任何有效的分析方法避免单一指标信号矛盾横盘震荡刚平仓即重启等低质量行为`,
DecisionProcess: `# 📋 决策流程
Only enter positions when multiple signals resonate. Freely use any effective analysis methods, avoid low-quality behaviors such as single indicators, contradictory signals, sideways oscillation, or immediately restarting after closing positions.`,
DecisionProcess: `# 📋 Decision Process
1. 检查持仓 是否该止盈/止损
2. 扫描候选币 + 多时间框 是否存在强信号
3. 先写思维链再输出结构化JSON`,
1. Check positions whether to take profit/stop loss
2. Scan candidate coins + multi-timeframe whether strong signals exist
3. Write chain of thought first, then output structured JSON`,
},
}
@@ -252,13 +252,13 @@ func (s *StrategyStore) initDefaultData() error {
_, err := s.db.Exec(`
INSERT INTO strategies (id, user_id, name, description, is_active, is_default, config)
VALUES ('default', 'system', '默认山寨策略', '系统默认的山寨币交易策略使用 AI500 币种池包含完整的技术指标', 0, 1, ?)
VALUES ('default', 'system', 'Default Altcoin Strategy', 'System default altcoin trading strategy, uses AI500 coin pool, includes complete technical indicators', 0, 1, ?)
`, string(configJSON))
return err
}
// Create 创建策略
// Create create a strategy
func (s *StrategyStore) Create(strategy *Strategy) error {
_, err := s.db.Exec(`
INSERT INTO strategies (id, user_id, name, description, is_active, is_default, config)
@@ -267,7 +267,7 @@ func (s *StrategyStore) Create(strategy *Strategy) error {
return err
}
// Update 更新策略
// Update update a strategy
func (s *StrategyStore) Update(strategy *Strategy) error {
_, err := s.db.Exec(`
UPDATE strategies SET
@@ -277,22 +277,22 @@ func (s *StrategyStore) Update(strategy *Strategy) error {
return err
}
// Delete 删除策略
// Delete delete a strategy
func (s *StrategyStore) Delete(userID, id string) error {
// 不允许删除系统默认策略
// do not allow deleting system default strategy
var isDefault bool
s.db.QueryRow(`SELECT is_default FROM strategies WHERE id = ?`, id).Scan(&isDefault)
if isDefault {
return fmt.Errorf("不能删除系统默认策略")
return fmt.Errorf("cannot delete system default strategy")
}
_, err := s.db.Exec(`DELETE FROM strategies WHERE id = ? AND user_id = ?`, id, userID)
return err
}
// List 获取用户的策略列表
// List get user's strategy list
func (s *StrategyStore) List(userID string) ([]*Strategy, error) {
// 获取用户自己的策略 + 系统默认策略
// get user's own strategies + system default strategy
rows, err := s.db.Query(`
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies
@@ -323,7 +323,7 @@ func (s *StrategyStore) List(userID string) ([]*Strategy, error) {
return strategies, nil
}
// Get 获取单个策略
// Get get a single strategy
func (s *StrategyStore) Get(userID, id string) (*Strategy, error) {
var st Strategy
var createdAt, updatedAt string
@@ -344,7 +344,7 @@ func (s *StrategyStore) Get(userID, id string) (*Strategy, error) {
return &st, nil
}
// GetActive 获取用户当前激活的策略
// GetActive get user's currently active strategy
func (s *StrategyStore) GetActive(userID string) (*Strategy, error) {
var st Strategy
var createdAt, updatedAt string
@@ -358,7 +358,7 @@ func (s *StrategyStore) GetActive(userID string) (*Strategy, error) {
&createdAt, &updatedAt,
)
if err == sql.ErrNoRows {
// 没有激活的策略,返回系统默认策略
// no active strategy, return system default strategy
return s.GetDefault()
}
if err != nil {
@@ -369,7 +369,7 @@ func (s *StrategyStore) GetActive(userID string) (*Strategy, error) {
return &st, nil
}
// GetDefault 获取系统默认策略
// GetDefault get system default strategy
func (s *StrategyStore) GetDefault() (*Strategy, error) {
var st Strategy
var createdAt, updatedAt string
@@ -391,22 +391,22 @@ func (s *StrategyStore) GetDefault() (*Strategy, error) {
return &st, nil
}
// SetActive 设置激活策略(会先取消其他策略的激活状态)
// SetActive set active strategy (will first deactivate other strategies)
func (s *StrategyStore) SetActive(userID, strategyID string) error {
// 开启事务
// begin transaction
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// 先取消该用户所有策略的激活状态
// first deactivate all strategies for the user
_, err = tx.Exec(`UPDATE strategies SET is_active = 0 WHERE user_id = ?`, userID)
if err != nil {
return err
}
// 激活指定策略
// activate specified strategy
_, err = tx.Exec(`UPDATE strategies SET is_active = 1 WHERE id = ? AND (user_id = ? OR is_default = 1)`, strategyID, userID)
if err != nil {
return err
@@ -415,20 +415,20 @@ func (s *StrategyStore) SetActive(userID, strategyID string) error {
return tx.Commit()
}
// Duplicate 复制策略(用于基于默认策略创建自定义策略)
// Duplicate duplicate a strategy (used to create custom strategy based on default strategy)
func (s *StrategyStore) Duplicate(userID, sourceID, newID, newName string) error {
// 获取源策略
// get source strategy
source, err := s.Get(userID, sourceID)
if err != nil {
return fmt.Errorf("获取源策略失败: %w", err)
return fmt.Errorf("failed to get source strategy: %w", err)
}
// 创建新策略
// create new strategy
newStrategy := &Strategy{
ID: newID,
UserID: userID,
Name: newName,
Description: "基于 [" + source.Name + "] 创建",
Description: "Created based on [" + source.Name + "]",
IsActive: false,
IsDefault: false,
Config: source.Config,
@@ -437,20 +437,20 @@ func (s *StrategyStore) Duplicate(userID, sourceID, newID, newName string) error
return s.Create(newStrategy)
}
// ParseConfig 解析策略配置 JSON
// ParseConfig parse strategy configuration JSON
func (s *Strategy) ParseConfig() (*StrategyConfig, error) {
var config StrategyConfig
if err := json.Unmarshal([]byte(s.Config), &config); err != nil {
return nil, fmt.Errorf("解析策略配置失败: %w", err)
return nil, fmt.Errorf("failed to parse strategy configuration: %w", err)
}
return &config, nil
}
// SetConfig 设置策略配置
// SetConfig set strategy configuration
func (s *Strategy) SetConfig(config *StrategyConfig) error {
data, err := json.Marshal(config)
if err != nil {
return fmt.Errorf("序列化策略配置失败: %w", err)
return fmt.Errorf("failed to serialize strategy configuration: %w", err)
}
s.Config = string(data)
return nil
+24 -24
View File
@@ -5,20 +5,20 @@ import (
"time"
)
// TraderStore 交易员存储
// TraderStore trader storage
type TraderStore struct {
db *sql.DB
decryptFunc func(string) string
}
// Trader 交易员配置
// Trader trader configuration
type Trader struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Name string `json:"name"`
AIModelID string `json:"ai_model_id"`
ExchangeID string `json:"exchange_id"`
StrategyID string `json:"strategy_id"` // 关联策略ID
StrategyID string `json:"strategy_id"` // Associated strategy ID
InitialBalance float64 `json:"initial_balance"`
ScanIntervalMinutes int `json:"scan_interval_minutes"`
IsRunning bool `json:"is_running"`
@@ -26,7 +26,7 @@ type Trader struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
// 以下字段已废弃,保留用于向后兼容,新交易员应使用 StrategyID
// Following fields are deprecated, kept for backward compatibility, new traders should use StrategyID
BTCETHLeverage int `json:"btc_eth_leverage,omitempty"`
AltcoinLeverage int `json:"altcoin_leverage,omitempty"`
TradingSymbols string `json:"trading_symbols,omitempty"`
@@ -37,12 +37,12 @@ type Trader struct {
SystemPromptTemplate string `json:"system_prompt_template,omitempty"`
}
// TraderFullConfig 交易员完整配置(包含AI模型、交易所和策略)
// TraderFullConfig trader full configuration (includes AI model, exchange and strategy)
type TraderFullConfig struct {
Trader *Trader
AIModel *AIModel
Exchange *Exchange
Strategy *Strategy // 关联的策略配置
Strategy *Strategy // Associated strategy configuration
}
func (s *TraderStore) initTables() error {
@@ -74,7 +74,7 @@ func (s *TraderStore) initTables() error {
return err
}
// 触发器
// Trigger
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_traders_updated_at
AFTER UPDATE ON traders
@@ -86,7 +86,7 @@ func (s *TraderStore) initTables() error {
return err
}
// 向后兼容
// Backward compatibility
alterQueries := []string{
`ALTER TABLE traders ADD COLUMN custom_prompt TEXT DEFAULT ''`,
`ALTER TABLE traders ADD COLUMN override_base_prompt BOOLEAN DEFAULT 0`,
@@ -113,7 +113,7 @@ func (s *TraderStore) decrypt(encrypted string) string {
return encrypted
}
// Create 创建交易员
// Create creates trader
func (s *TraderStore) Create(trader *Trader) error {
_, err := s.db.Exec(`
INSERT INTO traders (id, user_id, name, ai_model_id, exchange_id, strategy_id, initial_balance,
@@ -128,7 +128,7 @@ func (s *TraderStore) Create(trader *Trader) error {
return err
}
// List 获取用户的交易员列表
// List gets user's trader list
func (s *TraderStore) List(userID string) ([]*Trader, error) {
rows, err := s.db.Query(`
SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''),
@@ -165,13 +165,13 @@ func (s *TraderStore) List(userID string) ([]*Trader, error) {
return traders, nil
}
// UpdateStatus 更新交易员运行状态
// UpdateStatus updates trader running status
func (s *TraderStore) UpdateStatus(userID, id string, isRunning bool) error {
_, err := s.db.Exec(`UPDATE traders SET is_running = ? WHERE id = ? AND user_id = ?`, isRunning, id, userID)
return err
}
// Update 更新交易员配置
// Update updates trader configuration
func (s *TraderStore) Update(trader *Trader) error {
_, err := s.db.Exec(`
UPDATE traders SET
@@ -184,26 +184,26 @@ func (s *TraderStore) Update(trader *Trader) error {
return err
}
// UpdateInitialBalance 更新初始余额
// UpdateInitialBalance updates initial balance
func (s *TraderStore) UpdateInitialBalance(userID, id string, newBalance float64) error {
_, err := s.db.Exec(`UPDATE traders SET initial_balance = ? WHERE id = ? AND user_id = ?`, newBalance, id, userID)
return err
}
// UpdateCustomPrompt 更新自定义提示词
// UpdateCustomPrompt updates custom prompt
func (s *TraderStore) UpdateCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error {
_, err := s.db.Exec(`UPDATE traders SET custom_prompt = ?, override_base_prompt = ? WHERE id = ? AND user_id = ?`,
customPrompt, overrideBase, id, userID)
return err
}
// Delete 删除交易员
// Delete deletes trader
func (s *TraderStore) Delete(userID, id string) error {
_, err := s.db.Exec(`DELETE FROM traders WHERE id = ? AND user_id = ?`, id, userID)
return err
}
// GetFullConfig 获取交易员完整配置
// GetFullConfig gets trader full configuration
func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig, error) {
var trader Trader
var aiModel AIModel
@@ -255,7 +255,7 @@ func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig,
exchange.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeCreatedAt)
exchange.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeUpdatedAt)
// 解密
// Decrypt
aiModel.APIKey = s.decrypt(aiModel.APIKey)
exchange.APIKey = s.decrypt(exchange.APIKey)
exchange.SecretKey = s.decrypt(exchange.SecretKey)
@@ -264,12 +264,12 @@ func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig,
exchange.LighterPrivateKey = s.decrypt(exchange.LighterPrivateKey)
exchange.LighterAPIKeyPrivateKey = s.decrypt(exchange.LighterAPIKeyPrivateKey)
// 加载关联的策略
// Load associated strategy
var strategy *Strategy
if trader.StrategyID != "" {
strategy, _ = s.getStrategyByID(userID, trader.StrategyID)
}
// 如果没有关联策略,获取用户的激活策略或默认策略
// If no associated strategy, get user's active strategy or default strategy
if strategy == nil {
strategy, _ = s.getActiveOrDefaultStrategy(userID)
}
@@ -282,7 +282,7 @@ func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig,
}, nil
}
// getStrategyByID 内部方法:根据ID获取策略
// getStrategyByID internal method: gets strategy by ID
func (s *TraderStore) getStrategyByID(userID, strategyID string) (*Strategy, error) {
var strategy Strategy
var createdAt, updatedAt string
@@ -301,12 +301,12 @@ func (s *TraderStore) getStrategyByID(userID, strategyID string) (*Strategy, err
return &strategy, nil
}
// getActiveOrDefaultStrategy 内部方法:获取用户激活的策略或系统默认策略
// getActiveOrDefaultStrategy internal method: gets user's active strategy or system default strategy
func (s *TraderStore) getActiveOrDefaultStrategy(userID string) (*Strategy, error) {
var strategy Strategy
var createdAt, updatedAt string
// 先尝试获取用户激活的策略
// First try to get user's active strategy
err := s.db.QueryRow(`
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies WHERE user_id = ? AND is_active = 1
@@ -320,7 +320,7 @@ func (s *TraderStore) getActiveOrDefaultStrategy(userID string) (*Strategy, erro
return &strategy, nil
}
// 回退到系统默认策略
// Fallback to system default strategy
err = s.db.QueryRow(`
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
FROM strategies WHERE is_default = 1 LIMIT 1
@@ -336,7 +336,7 @@ func (s *TraderStore) getActiveOrDefaultStrategy(userID string) (*Strategy, erro
return &strategy, nil
}
// ListAll 获取所有用户的交易员列表
// ListAll gets all users' trader list
func (s *TraderStore) ListAll() ([]*Trader, error) {
rows, err := s.db.Query(`
SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''),
+11 -11
View File
@@ -7,12 +7,12 @@ import (
"time"
)
// UserStore 用户存储
// UserStore user storage
type UserStore struct {
db *sql.DB
}
// User 用户
// User user
type User struct {
ID string `json:"id"`
Email string `json:"email"`
@@ -23,7 +23,7 @@ type User struct {
UpdatedAt time.Time `json:"updated_at"`
}
// GenerateOTPSecret 生成OTP密钥
// GenerateOTPSecret generates OTP secret
func GenerateOTPSecret() (string, error) {
secret := make([]byte, 20)
_, err := rand.Read(secret)
@@ -49,7 +49,7 @@ func (s *UserStore) initTables() error {
return err
}
// 触发器
// Trigger
_, err = s.db.Exec(`
CREATE TRIGGER IF NOT EXISTS update_users_updated_at
AFTER UPDATE ON users
@@ -64,7 +64,7 @@ func (s *UserStore) initTables() error {
return nil
}
// Create 创建用户
// Create creates user
func (s *UserStore) Create(user *User) error {
_, err := s.db.Exec(`
INSERT INTO users (id, email, password_hash, otp_secret, otp_verified)
@@ -73,7 +73,7 @@ func (s *UserStore) Create(user *User) error {
return err
}
// GetByEmail 通过邮箱获取用户
// GetByEmail gets user by email
func (s *UserStore) GetByEmail(email string) (*User, error) {
var user User
var createdAt, updatedAt string
@@ -92,7 +92,7 @@ func (s *UserStore) GetByEmail(email string) (*User, error) {
return &user, nil
}
// GetByID 通过ID获取用户
// GetByID gets user by ID
func (s *UserStore) GetByID(userID string) (*User, error) {
var user User
var createdAt, updatedAt string
@@ -111,7 +111,7 @@ func (s *UserStore) GetByID(userID string) (*User, error) {
return &user, nil
}
// GetAllIDs 获取所有用户ID
// GetAllIDs gets all user IDs
func (s *UserStore) GetAllIDs() ([]string, error) {
rows, err := s.db.Query(`SELECT id FROM users ORDER BY id`)
if err != nil {
@@ -130,13 +130,13 @@ func (s *UserStore) GetAllIDs() ([]string, error) {
return userIDs, nil
}
// UpdateOTPVerified 更新OTP验证状态
// UpdateOTPVerified updates OTP verification status
func (s *UserStore) UpdateOTPVerified(userID string, verified bool) error {
_, err := s.db.Exec(`UPDATE users SET otp_verified = ? WHERE id = ?`, verified, userID)
return err
}
// UpdatePassword 更新密码
// UpdatePassword updates password
func (s *UserStore) UpdatePassword(userID, passwordHash string) error {
_, err := s.db.Exec(`
UPDATE users SET password_hash = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?
@@ -144,7 +144,7 @@ func (s *UserStore) UpdatePassword(userID, passwordHash string) error {
return err
}
// EnsureAdmin 确保admin用户存在
// EnsureAdmin ensures admin user exists
func (s *UserStore) EnsureAdmin() error {
var count int
err := s.db.QueryRow(`SELECT COUNT(*) FROM users WHERE id = 'admin'`).Scan(&count)
+195 -195
View File
File diff suppressed because it is too large Load Diff
+30 -30
View File
@@ -13,27 +13,27 @@ import (
)
// ============================================================
// 一、AsterTraderTestSuite - 继承 base test suite
// 1. AsterTraderTestSuite - inherits base test suite
// ============================================================
// AsterTraderTestSuite Aster交易器测试套件
// 继承 TraderTestSuite 并添加 Aster 特定的 mock 逻辑
// AsterTraderTestSuite Aster trader test suite
// Inherits TraderTestSuite and adds Aster specific mock logic
type AsterTraderTestSuite struct {
*TraderTestSuite // 嵌入基础测试套件
*TraderTestSuite // Embeds base test suite
mockServer *httptest.Server
}
// NewAsterTraderTestSuite 创建 Aster 测试套件
// NewAsterTraderTestSuite creates Aster test suite
func NewAsterTraderTestSuite(t *testing.T) *AsterTraderTestSuite {
// 创建 mock HTTP 服务器
// Create mock HTTP server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 根据不同的 URL 路径返回不同的 mock 响应
// Return different mock responses based on URL path
path := r.URL.Path
var respBody interface{}
switch {
// Mock GetBalance - /fapi/v3/balance (返回数组)
// Mock GetBalance - /fapi/v3/balance (returns array)
case path == "/fapi/v3/balance":
respBody = []map[string]interface{}{
{
@@ -65,19 +65,19 @@ func NewAsterTraderTestSuite(t *testing.T) *AsterTraderTestSuite {
},
}
// Mock GetMarketPrice - /fapi/v3/ticker/price (返回单个对象)
// Mock GetMarketPrice - /fapi/v3/ticker/price (returns single object)
case path == "/fapi/v3/ticker/price":
// 从查询参数获取symbol
// Get symbol from query parameters
symbol := r.URL.Query().Get("symbol")
if symbol == "" {
symbol = "BTCUSDT"
}
// 根据symbol返回不同价格
// Return different price based on symbol
price := "50000.00"
if symbol == "ETHUSDT" {
price = "3000.00"
} else if symbol == "INVALIDUSDT" {
// 返回错误响应
// Return error response
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]interface{}{
"code": -1121,
@@ -133,7 +133,7 @@ func NewAsterTraderTestSuite(t *testing.T) *AsterTraderTestSuite {
// Mock CreateOrder - /fapi/v1/order and /fapi/v3/order
case (path == "/fapi/v1/order" || path == "/fapi/v3/order") && r.Method == "POST":
// 从请求中解析参数以确定symbol
// Parse parameters from request to determine symbol
bodyBytes, _ := io.ReadAll(r.Body)
var orderParams map[string]interface{}
json.Unmarshal(bodyBytes, &orderParams)
@@ -182,26 +182,26 @@ func NewAsterTraderTestSuite(t *testing.T) *AsterTraderTestSuite {
respBody = map[string]interface{}{}
}
// 序列化响应
// Serialize response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
}))
// 生成一个测试用的私钥
// Generate a private key for testing
privateKey, _ := crypto.GenerateKey()
// 创建 mock trader,使用 mock server URL
// Create mock trader using mock server's URL
trader := &AsterTrader{
ctx: context.Background(),
user: "0x1234567890123456789012345678901234567890",
signer: "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd",
privateKey: privateKey,
client: mockServer.Client(),
baseURL: mockServer.URL, // 使用 mock server URL
baseURL: mockServer.URL, // Use mock server's URL
symbolPrecision: make(map[string]SymbolPrecision),
}
// 创建基础套件
// Create base suite
baseSuite := NewTraderTestSuite(t, trader)
return &AsterTraderTestSuite{
@@ -210,7 +210,7 @@ func NewAsterTraderTestSuite(t *testing.T) *AsterTraderTestSuite {
}
}
// Cleanup 清理资源
// Cleanup cleans up resources
func (s *AsterTraderTestSuite) Cleanup() {
if s.mockServer != nil {
s.mockServer.Close()
@@ -219,29 +219,29 @@ func (s *AsterTraderTestSuite) Cleanup() {
}
// ============================================================
// 二、使用 AsterTraderTestSuite 运行通用测试
// 2. Run common tests using AsterTraderTestSuite
// ============================================================
// TestAsterTrader_InterfaceCompliance 测试接口兼容性
// TestAsterTrader_InterfaceCompliance tests interface compliance
func TestAsterTrader_InterfaceCompliance(t *testing.T) {
var _ Trader = (*AsterTrader)(nil)
}
// TestAsterTrader_CommonInterface 使用测试套件运行所有通用接口测试
// TestAsterTrader_CommonInterface runs all common interface tests using test suite
func TestAsterTrader_CommonInterface(t *testing.T) {
// 创建测试套件
// Create test suite
suite := NewAsterTraderTestSuite(t)
defer suite.Cleanup()
// 运行所有通用接口测试
// Run all common interface tests
suite.RunAllTests()
}
// ============================================================
// 三、Aster 特定功能的单元测试
// 3. Aster specific unit tests
// ============================================================
// TestNewAsterTrader 测试创建 Aster 交易器
// TestNewAsterTrader tests creating Aster trader
func TestNewAsterTrader(t *testing.T) {
tests := []struct {
name string
@@ -252,22 +252,22 @@ func TestNewAsterTrader(t *testing.T) {
errorContains string
}{
{
name: "成功创建",
name: "successful creation",
user: "0x1234567890123456789012345678901234567890",
signer: "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd",
privateKeyHex: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
wantError: false,
},
{
name: "无效私钥格式",
name: "invalid private key format",
user: "0x1234567890123456789012345678901234567890",
signer: "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd",
privateKeyHex: "invalid_key",
wantError: true,
errorContains: "解析私钥失败",
errorContains: "failed to parse private key",
},
{
name: "带0x前缀的私钥",
name: "private key with 0x prefix",
user: "0x1234567890123456789012345678901234567890",
signer: "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd",
privateKeyHex: "0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
+358 -358
View File
File diff suppressed because it is too large Load Diff
+129 -129
View File
@@ -17,44 +17,44 @@ import (
)
// ============================================================
// AutoTraderTestSuite - 使用 testify/suite 进行结构化测试
// AutoTraderTestSuite - Structured testing using testify/suite
// ============================================================
// AutoTraderTestSuite AutoTrader 的测试套件
// 使用 testify/suite 来组织测试,提供统一的 setup/teardown mock 管理
// AutoTraderTestSuite Test suite for AutoTrader
// Uses testify/suite to organize tests, providing unified setup/teardown and mock management
type AutoTraderTestSuite struct {
suite.Suite
// 测试对象
// Test subject
autoTrader *AutoTrader
// Mock 依赖
// Mock dependencies
mockTrader *MockTrader
mockStore *store.Store
// gomonkey patches
patches *gomonkey.Patches
// 测试配置
// Test configuration
config AutoTraderConfig
}
// SetupSuite 在整个测试套件开始前执行一次
// SetupSuite Executed once before the entire test suite starts
func (s *AutoTraderTestSuite) SetupSuite() {
// 可以在这里初始化一些全局资源
// Can initialize some global resources here
}
// TearDownSuite 在整个测试套件结束后执行一次
// TearDownSuite Executed once after the entire test suite ends
func (s *AutoTraderTestSuite) TearDownSuite() {
// 清理全局资源
// Clean up global resources
}
// SetupTest 在每个测试用例开始前执行
// SetupTest Executed before each test case starts
func (s *AutoTraderTestSuite) SetupTest() {
// 初始化 patches
// Initialize patches
s.patches = gomonkey.NewPatches()
// 创建 mock 对象
// Create mock objects
s.mockTrader = &MockTrader{
balance: map[string]interface{}{
"totalWalletBalance": 10000.0,
@@ -65,10 +65,10 @@ func (s *AutoTraderTestSuite) SetupTest() {
}
// 创建临时store(使用nil表示测试中不需要实际的store)
// Create temporary store (using nil means no actual store needed in test)
s.mockStore = nil
// 设置默认配置
// Set default configuration
s.config = AutoTraderConfig{
ID: "test_trader",
Name: "Test Trader",
@@ -82,7 +82,7 @@ func (s *AutoTraderTestSuite) SetupTest() {
IsCrossMargin: true,
}
// 创建 AutoTrader 实例(直接构造,不调用 NewAutoTrader 以避免外部依赖)
// Create AutoTrader instance (direct construction, don't call NewAutoTrader to avoid external dependencies)
s.autoTrader = &AutoTrader{
id: s.config.ID,
name: s.config.Name,
@@ -90,7 +90,7 @@ func (s *AutoTraderTestSuite) SetupTest() {
exchange: s.config.Exchange,
config: s.config,
trader: s.mockTrader,
mcpClient: nil, // 测试中不需要实际的 MCP Client
mcpClient: nil, // No actual MCP Client needed in tests
store: s.mockStore,
initialBalance: s.config.InitialBalance,
systemPromptTemplate: s.config.SystemPromptTemplate,
@@ -108,16 +108,16 @@ func (s *AutoTraderTestSuite) SetupTest() {
}
}
// TearDownTest 在每个测试用例结束后执行
// TearDownTest Executed after each test case ends
func (s *AutoTraderTestSuite) TearDownTest() {
// 重置 gomonkey patches
// Reset gomonkey patches
if s.patches != nil {
s.patches.Reset()
}
}
// ============================================================
// 层次 1: 工具函数测试
// Level 1: Utility function tests
// ============================================================
func (s *AutoTraderTestSuite) TestSortDecisionsByPriority() {
@@ -126,7 +126,7 @@ func (s *AutoTraderTestSuite) TestSortDecisionsByPriority() {
input []decision.Decision
}{
{
name: "混合决策_验证优先级排序",
name: "Mixed decisions - verify priority sorting",
input: []decision.Decision{
{Action: "open_long", Symbol: "BTCUSDT"},
{Action: "close_short", Symbol: "ETHUSDT"},
@@ -141,9 +141,9 @@ func (s *AutoTraderTestSuite) TestSortDecisionsByPriority() {
s.Run(tt.name, func() {
result := sortDecisionsByPriority(tt.input)
s.Equal(len(tt.input), len(result), "结果长度应该相同")
s.Equal(len(tt.input), len(result), "Result length should be the same")
// 验证优先级是否递增
// Verify priority is increasing
getActionPriority := func(action string) int {
switch action {
case "close_long", "close_short":
@@ -160,7 +160,7 @@ func (s *AutoTraderTestSuite) TestSortDecisionsByPriority() {
for i := 0; i < len(result)-1; i++ {
currentPriority := getActionPriority(result[i].Action)
nextPriority := getActionPriority(result[i+1].Action)
s.LessOrEqual(currentPriority, nextPriority, "优先级应该递增")
s.LessOrEqual(currentPriority, nextPriority, "Priority should be increasing")
}
})
}
@@ -172,10 +172,10 @@ func (s *AutoTraderTestSuite) TestNormalizeSymbol() {
input string
expected string
}{
{"已经是标准格式", "BTCUSDT", "BTCUSDT"},
{"小写转大写", "btcusdt", "BTCUSDT"},
{"只有币种名称_添加USDT", "BTC", "BTCUSDT"},
{"带空格_去除空格", " BTC ", "BTCUSDT"},
{"Already standard format", "BTCUSDT", "BTCUSDT"},
{"Lowercase to uppercase", "btcusdt", "BTCUSDT"},
{"Coin name only - add USDT", "BTC", "BTCUSDT"},
{"With spaces - remove spaces", " BTC ", "BTCUSDT"},
}
for _, tt := range tests {
@@ -187,7 +187,7 @@ func (s *AutoTraderTestSuite) TestNormalizeSymbol() {
}
// ============================================================
// 层次 2: Getter/Setter 测试
// Level 2: Getter/Setter tests
// ============================================================
func (s *AutoTraderTestSuite) TestGettersAndSetters() {
@@ -211,38 +211,38 @@ func (s *AutoTraderTestSuite) TestGettersAndSetters() {
}
// ============================================================
// 层次 3: PeakPnL 缓存测试
// Level 3: PeakPnL cache tests
// ============================================================
func (s *AutoTraderTestSuite) TestPeakPnLCache() {
s.Run("UpdatePeakPnL_首次记录", func() {
s.Run("UpdatePeakPnL_first record", func() {
s.autoTrader.UpdatePeakPnL("BTCUSDT", "long", 10.5)
cache := s.autoTrader.GetPeakPnLCache()
s.Equal(10.5, cache["BTCUSDT_long"])
})
s.Run("UpdatePeakPnL_更新为更高值", func() {
s.Run("UpdatePeakPnL_update to higher value", func() {
s.autoTrader.UpdatePeakPnL("BTCUSDT", "long", 15.0)
cache := s.autoTrader.GetPeakPnLCache()
s.Equal(15.0, cache["BTCUSDT_long"])
})
s.Run("UpdatePeakPnL_不更新为更低值", func() {
s.Run("UpdatePeakPnL_do not update to lower value", func() {
s.autoTrader.UpdatePeakPnL("BTCUSDT", "long", 12.0)
cache := s.autoTrader.GetPeakPnLCache()
s.Equal(15.0, cache["BTCUSDT_long"], "峰值应保持不变")
s.Equal(15.0, cache["BTCUSDT_long"], "Peak value should remain unchanged")
})
s.Run("ClearPeakPnLCache", func() {
s.autoTrader.ClearPeakPnLCache("BTCUSDT", "long")
cache := s.autoTrader.GetPeakPnLCache()
_, exists := cache["BTCUSDT_long"]
s.False(exists, "应该被清除")
s.False(exists, "Should be cleared")
})
}
// ============================================================
// 层次 4: GetStatus 测试
// Level 4: GetStatus tests
// ============================================================
func (s *AutoTraderTestSuite) TestGetStatus() {
@@ -261,7 +261,7 @@ func (s *AutoTraderTestSuite) TestGetStatus() {
}
// ============================================================
// 层次 5: GetAccountInfo 测试
// Level 5: GetAccountInfo tests
// ============================================================
func (s *AutoTraderTestSuite) TestGetAccountInfo() {
@@ -270,29 +270,29 @@ func (s *AutoTraderTestSuite) TestGetAccountInfo() {
s.NoError(err)
s.NotNil(accountInfo)
// 验证核心字段和数值
// Verify core fields and values
s.Equal(10100.0, accountInfo["total_equity"]) // 10000 + 100
s.Equal(8000.0, accountInfo["available_balance"])
s.Equal(100.0, accountInfo["total_pnl"]) // 10100 - 10000
}
// ============================================================
// 层次 6: GetPositions 测试
// Level 6: GetPositions tests
// ============================================================
func (s *AutoTraderTestSuite) TestGetPositions() {
s.Run("空持仓", func() {
s.Run("No positions", func() {
positions, err := s.autoTrader.GetPositions()
s.NoError(err)
// positions 可能是 nil 或空数组,两者都是有效的
// positions may be nil or empty array, both are valid
if positions != nil {
s.Equal(0, len(positions))
}
})
s.Run("有持仓", func() {
// 设置 mock 持仓
s.Run("Has positions", func() {
// Set mock positions
s.mockTrader.positions = []map[string]interface{}{
{
"symbol": "BTCUSDT",
@@ -320,13 +320,13 @@ func (s *AutoTraderTestSuite) TestGetPositions() {
}
// ============================================================
// 层次 7: getCandidateCoins 测试
// Level 7: getCandidateCoins tests
// ============================================================
func (s *AutoTraderTestSuite) TestGetCandidateCoins() {
s.Run("使用数据库默认币种", func() {
s.Run("Use database default coins", func() {
s.autoTrader.defaultCoins = []string{"BTC", "ETH", "BNB"}
s.autoTrader.tradingCoins = []string{} // 空的自定义币种
s.autoTrader.tradingCoins = []string{} // Empty custom coins
coins, err := s.autoTrader.getCandidateCoins()
@@ -338,7 +338,7 @@ func (s *AutoTraderTestSuite) TestGetCandidateCoins() {
s.Contains(coins[0].Sources, "default")
})
s.Run("使用自定义币种", func() {
s.Run("Use custom coins", func() {
s.autoTrader.tradingCoins = []string{"SOL", "AVAX"}
coins, err := s.autoTrader.getCandidateCoins()
@@ -350,9 +350,9 @@ func (s *AutoTraderTestSuite) TestGetCandidateCoins() {
s.Contains(coins[0].Sources, "custom")
})
s.Run("使用AI500+OI作为fallback", func() {
s.autoTrader.defaultCoins = []string{} // 空的默认币种
s.autoTrader.tradingCoins = []string{} // 空的自定义币种
s.Run("Use AI500+OI as fallback", func() {
s.autoTrader.defaultCoins = []string{} // Empty default coins
s.autoTrader.tradingCoins = []string{} // Empty custom coins
// Mock pool.GetMergedCoinPool
s.patches.ApplyFunc(pool.GetMergedCoinPool, func(ai500Limit int) (*pool.MergedCoinPool, error) {
@@ -373,7 +373,7 @@ func (s *AutoTraderTestSuite) TestGetCandidateCoins() {
}
// ============================================================
// 层次 8: buildTradingContext 测试
// Level 8: buildTradingContext tests
// ============================================================
func (s *AutoTraderTestSuite) TestBuildTradingContext() {
@@ -387,7 +387,7 @@ func (s *AutoTraderTestSuite) TestBuildTradingContext() {
s.NoError(err)
s.NotNil(ctx)
// 验证核心字段
// Verify core fields
s.Equal(10100.0, ctx.Account.TotalEquity) // 10000 + 100
s.Equal(8000.0, ctx.Account.AvailableBalance)
s.Equal(10, ctx.BTCETHLeverage)
@@ -395,10 +395,10 @@ func (s *AutoTraderTestSuite) TestBuildTradingContext() {
}
// ============================================================
// 层次 9: 交易执行测试
// Level 9: Trade execution tests
// ============================================================
// TestExecuteOpenPosition 测试开仓操作(多空通用)
// TestExecuteOpenPosition Test open position operation (common for long and short)
func (s *AutoTraderTestSuite) TestExecuteOpenPosition() {
tests := []struct {
name string
@@ -410,7 +410,7 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() {
executeFn func(*decision.Decision, *store.DecisionAction) error
}{
{
name: "成功开多仓",
name: "Successfully open long",
action: "open_long",
expectedOrder: 123456,
availBalance: 8000.0,
@@ -419,7 +419,7 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() {
},
},
{
name: "成功开空仓",
name: "Successfully open short",
action: "open_short",
expectedOrder: 123457,
availBalance: 8000.0,
@@ -428,39 +428,39 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() {
},
},
{
name: "多仓_保证金不足",
name: "Long - insufficient margin",
action: "open_long",
availBalance: 0.0,
expectedErr: "保证金不足",
expectedErr: "Insufficient margin",
executeFn: func(d *decision.Decision, a *store.DecisionAction) error {
return s.autoTrader.executeOpenLongWithRecord(d, a)
},
},
{
name: "空仓_保证金不足",
name: "Short - insufficient margin",
action: "open_short",
availBalance: 0.0,
expectedErr: "保证金不足",
expectedErr: "Insufficient margin",
executeFn: func(d *decision.Decision, a *store.DecisionAction) error {
return s.autoTrader.executeOpenShortWithRecord(d, a)
},
},
{
name: "多仓_已有同方向持仓",
name: "Long - already has same side position",
action: "open_long",
existingSide: "long",
availBalance: 8000.0,
expectedErr: "已有多仓",
expectedErr: "Already has long position",
executeFn: func(d *decision.Decision, a *store.DecisionAction) error {
return s.autoTrader.executeOpenLongWithRecord(d, a)
},
},
{
name: "空仓_已有同方向持仓",
name: "Short - already has same side position",
action: "open_short",
existingSide: "short",
availBalance: 8000.0,
expectedErr: "已有空仓",
expectedErr: "Already has short position",
executeFn: func(d *decision.Decision, a *store.DecisionAction) error {
return s.autoTrader.executeOpenShortWithRecord(d, a)
},
@@ -496,14 +496,14 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() {
s.Equal(50000.0, actionRecord.Price)
}
// 恢复默认状态
// Restore default state
s.mockTrader.balance["availableBalance"] = 8000.0
s.mockTrader.positions = []map[string]interface{}{}
})
}
}
// TestExecuteClosePosition 测试平仓操作(多空通用)
// TestExecuteClosePosition Test close position operation (common for long and short)
func (s *AutoTraderTestSuite) TestExecuteClosePosition() {
tests := []struct {
name string
@@ -513,7 +513,7 @@ func (s *AutoTraderTestSuite) TestExecuteClosePosition() {
executeFn func(*decision.Decision, *store.DecisionAction) error
}{
{
name: "成功平多仓",
name: "Successfully close long",
action: "close_long",
currentPrice: 51000.0,
expectedOrder: 123458,
@@ -522,7 +522,7 @@ func (s *AutoTraderTestSuite) TestExecuteClosePosition() {
},
},
{
name: "成功平空仓",
name: "Successfully close short",
action: "close_short",
currentPrice: 49000.0,
expectedOrder: 123459,
@@ -552,7 +552,7 @@ func (s *AutoTraderTestSuite) TestExecuteClosePosition() {
}
// ============================================================
// 层次 10: executeDecisionWithRecord 路由测试
// Level 10: executeDecisionWithRecord routing tests
// ============================================================
func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() {
@@ -564,7 +564,7 @@ func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() {
}, nil
})
s.Run("路由到open_long", func() {
s.Run("Route to open_long", func() {
decision := &decision.Decision{
Action: "open_long",
Symbol: "BTCUSDT",
@@ -577,7 +577,7 @@ func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() {
s.NoError(err)
})
s.Run("路由到close_long", func() {
s.Run("Route to close_long", func() {
decision := &decision.Decision{
Action: "close_long",
Symbol: "BTCUSDT",
@@ -588,7 +588,7 @@ func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() {
s.NoError(err)
})
s.Run("路由到hold_不执行", func() {
s.Run("Route to hold - no execution", func() {
decision := &decision.Decision{
Action: "hold",
Symbol: "BTCUSDT",
@@ -599,7 +599,7 @@ func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() {
s.NoError(err)
})
s.Run("未知action返回错误", func() {
s.Run("Unknown action returns error", func() {
decision := &decision.Decision{
Action: "unknown_action",
Symbol: "BTCUSDT",
@@ -608,7 +608,7 @@ func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() {
err := s.autoTrader.executeDecisionWithRecord(decision, actionRecord)
s.Error(err)
s.Contains(err.Error(), "未知的action")
s.Contains(err.Error(), "Unknown action")
})
}
@@ -624,18 +624,18 @@ func (s *AutoTraderTestSuite) TestCheckPositionDrawdown() {
skipCacheCheck bool
}{
{
name: "获取持仓失败_不panic",
name: "Get positions failed - no panic",
setupFailures: func() { s.mockTrader.shouldFailPositions = true },
cleanupFailures: func() { s.mockTrader.shouldFailPositions = false },
skipCacheCheck: true,
},
{
name: "无持仓_不panic",
name: "No positions - no panic",
setupPositions: func() { s.mockTrader.positions = []map[string]interface{}{} },
skipCacheCheck: true,
},
{
name: "收益不足5%_不触发平仓",
name: "Profit less than 5% - no close",
setupPositions: func() {
s.mockTrader.positions = []map[string]interface{}{
{"symbol": "BTCUSDT", "side": "long", "positionAmt": 0.1, "entryPrice": 50000.0, "markPrice": 50150.0, "leverage": 10.0},
@@ -645,7 +645,7 @@ func (s *AutoTraderTestSuite) TestCheckPositionDrawdown() {
skipCacheCheck: true,
},
{
name: "回撤不足40%_不触发平仓",
name: "Drawdown less than 40% - no close",
setupPositions: func() {
s.mockTrader.positions = []map[string]interface{}{
{"symbol": "BTCUSDT", "side": "long", "positionAmt": 0.1, "entryPrice": 50000.0, "markPrice": 50400.0, "leverage": 10.0},
@@ -655,7 +655,7 @@ func (s *AutoTraderTestSuite) TestCheckPositionDrawdown() {
skipCacheCheck: true,
},
{
name: "多头_触发回撤平仓",
name: "Long - trigger drawdown close",
setupPositions: func() {
s.mockTrader.positions = []map[string]interface{}{
{"symbol": "BTCUSDT", "side": "long", "positionAmt": 0.1, "entryPrice": 50000.0, "markPrice": 50300.0, "leverage": 10.0},
@@ -666,7 +666,7 @@ func (s *AutoTraderTestSuite) TestCheckPositionDrawdown() {
shouldClearCache: true,
},
{
name: "空头_触发回撤平仓",
name: "Short - trigger drawdown close",
setupPositions: func() {
s.mockTrader.positions = []map[string]interface{}{
{"symbol": "ETHUSDT", "side": "short", "positionAmt": -0.5, "entryPrice": 3000.0, "markPrice": 2982.0, "leverage": 10.0},
@@ -677,7 +677,7 @@ func (s *AutoTraderTestSuite) TestCheckPositionDrawdown() {
shouldClearCache: true,
},
{
name: "多头_平仓失败_保留缓存",
name: "Long - close failed - keep cache",
setupPositions: func() {
s.mockTrader.positions = []map[string]interface{}{
{"symbol": "BTCUSDT", "side": "long", "positionAmt": 0.1, "entryPrice": 50000.0, "markPrice": 50300.0, "leverage": 10.0},
@@ -690,7 +690,7 @@ func (s *AutoTraderTestSuite) TestCheckPositionDrawdown() {
shouldClearCache: false,
},
{
name: "空头_平仓失败_保留缓存",
name: "Short - close failed - keep cache",
setupPositions: func() {
s.mockTrader.positions = []map[string]interface{}{
{"symbol": "ETHUSDT", "side": "short", "positionAmt": -0.5, "entryPrice": 3000.0, "markPrice": 2982.0, "leverage": 10.0},
@@ -725,23 +725,23 @@ func (s *AutoTraderTestSuite) TestCheckPositionDrawdown() {
cache := s.autoTrader.GetPeakPnLCache()
_, exists := cache[tt.expectedCacheKey]
if tt.shouldClearCache {
s.False(exists, "峰值缓存应该被清理")
s.False(exists, "Peak PnL cache should be cleared")
} else {
s.True(exists, "峰值缓存不应该被清理")
s.True(exists, "Peak PnL cache should not be cleared")
}
}
// 清理状态
// Clean up state
s.mockTrader.positions = []map[string]interface{}{}
})
}
}
// ============================================================
// Mock 实现
// Mock implementations
// ============================================================
// MockDatabase 模拟数据库
// MockDatabase Mock database
type MockDatabase struct {
shouldFail bool
}
@@ -753,7 +753,7 @@ func (m *MockDatabase) UpdateTraderInitialBalance(userID, traderID string, newBa
return nil
}
// MockTrader 增强版(添加错误控制)
// MockTrader Enhanced version (with error control)
type MockTrader struct {
balance map[string]interface{}
positions []map[string]interface{}
@@ -866,16 +866,16 @@ func (m *MockTrader) FormatQuantity(symbol string, quantity float64) (string, er
}
// ============================================================
// 测试套件入口
// Test suite entry point
// ============================================================
// TestAutoTraderTestSuite 运行 AutoTrader 测试套件
// TestAutoTraderTestSuite Run AutoTrader test suite
func TestAutoTraderTestSuite(t *testing.T) {
suite.Run(t, new(AutoTraderTestSuite))
}
// ============================================================
// 独立的单元测试 - calculatePnLPercentage 函数测试
// Independent unit tests - calculatePnLPercentage function tests
// ============================================================
func TestCalculatePnLPercentage(t *testing.T) {
@@ -886,58 +886,58 @@ func TestCalculatePnLPercentage(t *testing.T) {
expected float64
}{
{
name: "正常盈利 - 10倍杠杆",
unrealizedPnl: 100.0, // 盈利 100 USDT
marginUsed: 1000.0, // 保证金 1000 USDT
expected: 10.0, // 10% 收益率
name: "Normal profit - 10x leverage",
unrealizedPnl: 100.0, // 100 USDT profit
marginUsed: 1000.0, // 1000 USDT margin
expected: 10.0, // 10% return
},
{
name: "正常亏损 - 10倍杠杆",
unrealizedPnl: -50.0, // 亏损 50 USDT
marginUsed: 1000.0, // 保证金 1000 USDT
expected: -5.0, // -5% 收益率
name: "Normal loss - 10x leverage",
unrealizedPnl: -50.0, // 50 USDT loss
marginUsed: 1000.0, // 1000 USDT margin
expected: -5.0, // -5% return
},
{
name: "高杠杆盈利 - 价格上涨1%,20倍杠杆",
unrealizedPnl: 200.0, // 盈利 200 USDT
marginUsed: 1000.0, // 保证金 1000 USDT
expected: 20.0, // 20% 收益率
name: "High leverage profit - 1% price increase, 20x leverage",
unrealizedPnl: 200.0, // 200 USDT profit
marginUsed: 1000.0, // 1000 USDT margin
expected: 20.0, // 20% return
},
{
name: "保证金为0 - 边界情况",
name: "Zero margin - edge case",
unrealizedPnl: 100.0,
marginUsed: 0.0,
expected: 0.0, // 应该返回 0 而不是除以零错误
expected: 0.0, // Should return 0 instead of division by zero error
},
{
name: "负保证金 - 边界情况",
name: "Negative margin - edge case",
unrealizedPnl: 100.0,
marginUsed: -1000.0,
expected: 0.0, // 应该返回 0(异常情况)
expected: 0.0, // Should return 0 (abnormal case)
},
{
name: "盈亏为0",
name: "Zero PnL",
unrealizedPnl: 0.0,
marginUsed: 1000.0,
expected: 0.0,
},
{
name: "小额交易",
name: "Small trade",
unrealizedPnl: 0.5,
marginUsed: 10.0,
expected: 5.0,
},
{
name: "大额盈利",
name: "Large profit",
unrealizedPnl: 5000.0,
marginUsed: 10000.0,
expected: 50.0,
},
{
name: "极小保证金",
name: "Tiny margin",
unrealizedPnl: 1.0,
marginUsed: 0.01,
expected: 10000.0, // 100倍收益率
expected: 10000.0, // 100x return
},
}
@@ -945,7 +945,7 @@ func TestCalculatePnLPercentage(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
result := calculatePnLPercentage(tt.unrealizedPnl, tt.marginUsed)
// 使用精度比较,避免浮点数误差
// Use precision comparison to avoid floating point errors
if math.Abs(result-tt.expected) > 0.0001 {
t.Errorf("calculatePnLPercentage(%v, %v) = %v, want %v",
tt.unrealizedPnl, tt.marginUsed, result, tt.expected)
@@ -954,38 +954,38 @@ func TestCalculatePnLPercentage(t *testing.T) {
}
}
// TestCalculatePnLPercentage_RealWorldScenarios 真实场景测试
// TestCalculatePnLPercentage_RealWorldScenarios Real world scenario tests
func TestCalculatePnLPercentage_RealWorldScenarios(t *testing.T) {
t.Run("BTC 10倍杠杆,价格上涨2%", func(t *testing.T) {
// 开仓:1000 USDT 保证金,10倍杠杆 = 10000 USDT 仓位
// 价格上涨 2% = 200 USDT 盈利
// 收益率 = 200 / 1000 = 20%
t.Run("BTC 10x leverage, 2% price increase", func(t *testing.T) {
// Open: 1000 USDT margin, 10x leverage = 10000 USDT position
// 2% price increase = 200 USDT profit
// Return = 200 / 1000 = 20%
result := calculatePnLPercentage(200.0, 1000.0)
expected := 20.0
if math.Abs(result-expected) > 0.0001 {
t.Errorf("BTC场景: got %v, want %v", result, expected)
t.Errorf("BTC scenario: got %v, want %v", result, expected)
}
})
t.Run("ETH 5倍杠杆,价格下跌3%", func(t *testing.T) {
// 开仓:2000 USDT 保证金,5倍杠杆 = 10000 USDT 仓位
// 价格下跌 3% = -300 USDT 亏损
// 收益率 = -300 / 2000 = -15%
t.Run("ETH 5x leverage, 3% price decrease", func(t *testing.T) {
// Open: 2000 USDT margin, 5x leverage = 10000 USDT position
// 3% price decrease = -300 USDT loss
// Return = -300 / 2000 = -15%
result := calculatePnLPercentage(-300.0, 2000.0)
expected := -15.0
if math.Abs(result-expected) > 0.0001 {
t.Errorf("ETH场景: got %v, want %v", result, expected)
t.Errorf("ETH scenario: got %v, want %v", result, expected)
}
})
t.Run("SOL 20倍杠杆,价格上涨0.5%", func(t *testing.T) {
// 开仓:500 USDT 保证金,20倍杠杆 = 10000 USDT 仓位
// 价格上涨 0.5% = 50 USDT 盈利
// 收益率 = 50 / 500 = 10%
t.Run("SOL 20x leverage, 0.5% price increase", func(t *testing.T) {
// Open: 500 USDT margin, 20x leverage = 10000 USDT position
// 0.5% price increase = 50 USDT profit
// Return = 50 / 500 = 10%
result := calculatePnLPercentage(50.0, 500.0)
expected := 10.0
if math.Abs(result-expected) > 0.0001 {
t.Errorf("SOL场景: got %v, want %v", result, expected)
t.Errorf("SOL scenario: got %v, want %v", result, expected)
}
})
}
+202 -202
View File
File diff suppressed because it is too large Load Diff
+42 -42
View File
@@ -14,21 +14,21 @@ import (
)
// ============================================================
// 一、BinanceFuturesTestSuite - 继承 base test suite
// 1. BinanceFuturesTestSuite - Inherits base test suite
// ============================================================
// BinanceFuturesTestSuite 币安合约交易器测试套件
// 继承 TraderTestSuite 并添加 Binance Futures 特定的 mock 逻辑
// BinanceFuturesTestSuite Binance Futures trader test suite
// Inherits TraderTestSuite and adds Binance Futures specific mock logic
type BinanceFuturesTestSuite struct {
*TraderTestSuite // 嵌入基础测试套件
*TraderTestSuite // Embeds base test suite
mockServer *httptest.Server
}
// NewBinanceFuturesTestSuite 创建币安合约测试套件
// NewBinanceFuturesTestSuite Creates Binance Futures test suite
func NewBinanceFuturesTestSuite(t *testing.T) *BinanceFuturesTestSuite {
// 创建 mock HTTP 服务器
// Create mock HTTP server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 根据不同的 URL 路径返回不同的 mock 响应
// Return different mock responses based on URL path
path := r.URL.Path
var respBody interface{}
@@ -91,13 +91,13 @@ func NewBinanceFuturesTestSuite(t *testing.T) *BinanceFuturesTestSuite {
case path == "/fapi/v1/ticker/price" || path == "/fapi/v2/ticker/price":
symbol := r.URL.Query().Get("symbol")
if symbol == "" {
// 返回所有价格
// Return all prices
respBody = []map[string]interface{}{
{"Symbol": "BTCUSDT", "Price": "50000.00", "Time": 1234567890},
{"Symbol": "ETHUSDT", "Price": "3000.00", "Time": 1234567890},
}
} else if symbol == "INVALIDUSDT" {
// 返回错误
// Return error
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]interface{}{
"code": -1121,
@@ -105,7 +105,7 @@ func NewBinanceFuturesTestSuite(t *testing.T) *BinanceFuturesTestSuite {
})
return
} else {
// 返回单个价格(注意:即使有 symbol 参数,也要返回数组)
// Return single price (note: even with symbol parameter, return array)
price := "50000.00"
if symbol == "ETHUSDT" {
price = "3000.00"
@@ -221,11 +221,11 @@ func NewBinanceFuturesTestSuite(t *testing.T) *BinanceFuturesTestSuite {
// Mock SetLeverage - /fapi/v1/leverage
case path == "/fapi/v1/leverage":
// 将字符串转换为整数
// Convert string to integer
leverageStr := r.FormValue("leverage")
leverage := 10 // 默认值
leverage := 10 // default value
if leverageStr != "" {
// 注意:这里我们直接返回整数,而不是字符串
// Note: here we return an integer directly, not a string
fmt.Sscanf(leverageStr, "%d", &leverage)
}
respBody = map[string]interface{}{
@@ -259,23 +259,23 @@ func NewBinanceFuturesTestSuite(t *testing.T) *BinanceFuturesTestSuite {
respBody = map[string]interface{}{}
}
// 序列化响应
// Serialize response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
}))
// 创建 futures.Client 并设置为使用 mock 服务器
// Create futures.Client and configure to use mock server
client := futures.NewClient("test_api_key", "test_secret_key")
client.BaseURL = mockServer.URL
client.HTTPClient = mockServer.Client()
// 创建 FuturesTrader
// Create FuturesTrader
trader := &FuturesTrader{
client: client,
cacheDuration: 0, // 禁用缓存以便测试
cacheDuration: 0, // disable cache for testing
}
// 创建基础套件
// Create base suite
baseSuite := NewTraderTestSuite(t, trader)
return &BinanceFuturesTestSuite{
@@ -284,7 +284,7 @@ func NewBinanceFuturesTestSuite(t *testing.T) *BinanceFuturesTestSuite {
}
}
// Cleanup 清理资源
// Cleanup cleans up resources
func (s *BinanceFuturesTestSuite) Cleanup() {
if s.mockServer != nil {
s.mockServer.Close()
@@ -293,31 +293,31 @@ func (s *BinanceFuturesTestSuite) Cleanup() {
}
// ============================================================
// 二、使用 BinanceFuturesTestSuite 运行通用测试
// 2. Run common tests using BinanceFuturesTestSuite
// ============================================================
// TestFuturesTrader_InterfaceCompliance 测试接口兼容性
// TestFuturesTrader_InterfaceCompliance tests interface compliance
func TestFuturesTrader_InterfaceCompliance(t *testing.T) {
var _ Trader = (*FuturesTrader)(nil)
}
// TestFuturesTrader_CommonInterface 使用测试套件运行所有通用接口测试
// TestFuturesTrader_CommonInterface runs all common interface tests using test suite
func TestFuturesTrader_CommonInterface(t *testing.T) {
// 创建测试套件
// Create test suite
suite := NewBinanceFuturesTestSuite(t)
defer suite.Cleanup()
// 运行所有通用接口测试
// Run all common interface tests
suite.RunAllTests()
}
// ============================================================
// 三、币安合约特定功能的单元测试
// 3. Binance Futures specific unit tests
// ============================================================
// TestNewFuturesTrader 测试创建币安合约交易器
// TestNewFuturesTrader tests creating Binance Futures trader
func TestNewFuturesTrader(t *testing.T) {
// 创建 mock HTTP 服务器
// Create mock HTTP server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
@@ -342,10 +342,10 @@ func TestNewFuturesTrader(t *testing.T) {
}))
defer mockServer.Close()
// 测试成功创建
// Test successful creation
trader := NewFuturesTrader("test_api_key", "test_secret_key", "test_user")
// 修改 client 使用 mock server
// Modify client to use mock server
trader.client.BaseURL = mockServer.URL
trader.client.HTTPClient = mockServer.Client()
@@ -354,7 +354,7 @@ func TestNewFuturesTrader(t *testing.T) {
assert.Equal(t, 15*time.Second, trader.cacheDuration)
}
// TestCalculatePositionSize 测试仓位计算
// TestCalculatePositionSize tests position size calculation
func TestCalculatePositionSize(t *testing.T) {
trader := &FuturesTrader{}
@@ -367,7 +367,7 @@ func TestCalculatePositionSize(t *testing.T) {
wantQuantity float64
}{
{
name: "正常计算",
name: "normal calculation",
balance: 10000,
riskPercent: 2,
price: 50000,
@@ -375,7 +375,7 @@ func TestCalculatePositionSize(t *testing.T) {
wantQuantity: 0.04, // (10000 * 0.02 * 10) / 50000 = 0.04
},
{
name: "高杠杆",
name: "high leverage",
balance: 10000,
riskPercent: 1,
price: 3000,
@@ -383,7 +383,7 @@ func TestCalculatePositionSize(t *testing.T) {
wantQuantity: 0.6667, // (10000 * 0.01 * 20) / 3000 = 0.6667
},
{
name: "低风险",
name: "low risk",
balance: 5000,
riskPercent: 0.5,
price: 50000,
@@ -395,26 +395,26 @@ func TestCalculatePositionSize(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
quantity := trader.CalculatePositionSize(tt.balance, tt.riskPercent, tt.price, tt.leverage)
assert.InDelta(t, tt.wantQuantity, quantity, 0.0001, "计算的仓位数量不正确")
assert.InDelta(t, tt.wantQuantity, quantity, 0.0001, "calculated position size is incorrect")
})
}
}
// TestGetBrOrderID 测试订单ID生成
// TestGetBrOrderID tests order ID generation
func TestGetBrOrderID(t *testing.T) {
// 测试3次,确保每次生成的ID都不同
// Test 3 times to ensure each generated ID is unique
ids := make(map[string]bool)
for i := 0; i < 3; i++ {
id := getBrOrderID()
// 检查格式
assert.True(t, strings.HasPrefix(id, "x-KzrpZaP9"), "订单ID应以x-KzrpZaP9开头")
// Check format
assert.True(t, strings.HasPrefix(id, "x-KzrpZaP9"), "order ID should start with x-KzrpZaP9")
// 检查长度(应该 <= 32
assert.LessOrEqual(t, len(id), 32, "订单ID长度不应超过32字符")
// Check length (should be <= 32)
assert.LessOrEqual(t, len(id), 32, "order ID length should not exceed 32 characters")
// 检查唯一性
assert.False(t, ids[id], "订单ID应该唯一")
// Check uniqueness
assert.False(t, ids[id], "order ID should be unique")
ids[id] = true
}
}
+130 -130
View File
@@ -16,35 +16,35 @@ import (
bybit "github.com/bybit-exchange/bybit.go.api"
)
// BybitTrader Bybit USDT 永續合約交易器
// BybitTrader Bybit USDT Perpetual Futures Trader
type BybitTrader struct {
client *bybit.Client
// 余额缓存
// Balance cache
cachedBalance map[string]interface{}
balanceCacheTime time.Time
balanceCacheMutex sync.RWMutex
// 持仓缓存
// Position cache
cachedPositions []map[string]interface{}
positionsCacheTime time.Time
positionsCacheMutex sync.RWMutex
// 交易对精度缓存 (symbol -> qtyStep)
// Trading pair precision cache (symbol -> qtyStep)
qtyStepCache map[string]float64
qtyStepCacheMutex sync.RWMutex
// 缓存有效期(15秒)
// Cache duration (15 seconds)
cacheDuration time.Duration
}
// NewBybitTrader 创建 Bybit 交易器
// NewBybitTrader creates a Bybit trader
func NewBybitTrader(apiKey, secretKey string) *BybitTrader {
const src = "Up000938"
client := bybit.NewBybitHttpClient(apiKey, secretKey, bybit.WithBaseURL(bybit.MAINNET))
// 设置 HTTP 传输
// Set HTTP transport
if client != nil && client.HTTPClient != nil {
defaultTransport := client.HTTPClient.Transport
if defaultTransport == nil {
@@ -63,12 +63,12 @@ func NewBybitTrader(apiKey, secretKey string) *BybitTrader {
qtyStepCache: make(map[string]float64),
}
logger.Infof("🔵 [Bybit] 交易器已初始化")
logger.Infof("🔵 [Bybit] Trader initialized")
return trader
}
// headerRoundTripper 用于添加自定义 header 的 HTTP RoundTripper
// headerRoundTripper HTTP RoundTripper for adding custom headers
type headerRoundTripper struct {
base http.RoundTripper
refererID string
@@ -79,9 +79,9 @@ func (h *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
return h.base.RoundTrip(req)
}
// GetBalance 获取账户余额
// GetBalance retrieves account balance
func (t *BybitTrader) GetBalance() (map[string]interface{}, error) {
// 检查缓存
// Check cache
t.balanceCacheMutex.RLock()
if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration {
balance := t.cachedBalance
@@ -90,24 +90,24 @@ func (t *BybitTrader) GetBalance() (map[string]interface{}, error) {
}
t.balanceCacheMutex.RUnlock()
// 调用 API
// Call API
params := map[string]interface{}{
"accountType": "UNIFIED",
}
result, err := t.client.NewUtaBybitServiceWithParams(params).GetAccountWallet(context.Background())
if err != nil {
return nil, fmt.Errorf("获取 Bybit 余额失败: %w", err)
return nil, fmt.Errorf("failed to get Bybit balance: %w", err)
}
if result.RetCode != 0 {
return nil, fmt.Errorf("Bybit API 错误: %s", result.RetMsg)
return nil, fmt.Errorf("Bybit API error: %s", result.RetMsg)
}
// 提取余额信息
// Extract balance information
resultData, ok := result.Result.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("Bybit 余额返回格式错误")
return nil, fmt.Errorf("Bybit balance return format error")
}
list, _ := resultData["list"].([]interface{})
@@ -122,17 +122,17 @@ func (t *BybitTrader) GetBalance() (map[string]interface{}, error) {
if availStr, ok := account["totalAvailableBalance"].(string); ok {
availableBalance, _ = strconv.ParseFloat(availStr, 64)
}
// Bybit UNIFIED 账户的钱包余额字段
// Bybit UNIFIED account wallet balance field
if walletStr, ok := account["totalWalletBalance"].(string); ok {
totalWalletBalance, _ = strconv.ParseFloat(walletStr, 64)
}
// Bybit 永续合约未实现盈亏
// Bybit perpetual contract unrealized PnL
if uplStr, ok := account["totalPerpUPL"].(string); ok {
totalPerpUPL, _ = strconv.ParseFloat(uplStr, 64)
}
}
// 如果没有 totalWalletBalance,使用 totalEquity
// If no totalWalletBalance, use totalEquity
if totalWalletBalance == 0 {
totalWalletBalance = totalEquity
}
@@ -142,10 +142,10 @@ func (t *BybitTrader) GetBalance() (map[string]interface{}, error) {
"totalWalletBalance": totalWalletBalance,
"availableBalance": availableBalance,
"totalUnrealizedProfit": totalPerpUPL,
"balance": totalEquity, // 兼容其他交易所格式
"balance": totalEquity, // Compatible with other exchange formats
}
// 更新缓存
// Update cache
t.balanceCacheMutex.Lock()
t.cachedBalance = balance
t.balanceCacheTime = time.Now()
@@ -154,9 +154,9 @@ func (t *BybitTrader) GetBalance() (map[string]interface{}, error) {
return balance, nil
}
// GetPositions 获取所有持仓
// GetPositions retrieves all positions
func (t *BybitTrader) GetPositions() ([]map[string]interface{}, error) {
// 检查缓存
// Check cache
t.positionsCacheMutex.RLock()
if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration {
positions := t.cachedPositions
@@ -165,7 +165,7 @@ func (t *BybitTrader) GetPositions() ([]map[string]interface{}, error) {
}
t.positionsCacheMutex.RUnlock()
// 调用 API
// Call API
params := map[string]interface{}{
"category": "linear",
"settleCoin": "USDT",
@@ -173,16 +173,16 @@ func (t *BybitTrader) GetPositions() ([]map[string]interface{}, error) {
result, err := t.client.NewUtaBybitServiceWithParams(params).GetPositionList(context.Background())
if err != nil {
return nil, fmt.Errorf("获取 Bybit 持仓失败: %w", err)
return nil, fmt.Errorf("failed to get Bybit positions: %w", err)
}
if result.RetCode != 0 {
return nil, fmt.Errorf("Bybit API 错误: %s", result.RetMsg)
return nil, fmt.Errorf("Bybit API error: %s", result.RetMsg)
}
resultData, ok := result.Result.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("Bybit 持仓返回格式错误")
return nil, fmt.Errorf("Bybit positions return format error")
}
list, _ := resultData["list"].([]interface{})
@@ -198,7 +198,7 @@ func (t *BybitTrader) GetPositions() ([]map[string]interface{}, error) {
sizeStr, _ := pos["size"].(string)
size, _ := strconv.ParseFloat(sizeStr, 64)
// 跳过空仓位
// Skip empty positions
if size == 0 {
continue
}
@@ -212,17 +212,17 @@ func (t *BybitTrader) GetPositions() ([]map[string]interface{}, error) {
leverageStr, _ := pos["leverage"].(string)
leverage, _ := strconv.ParseFloat(leverageStr, 64)
// 标记价格
// Mark price
markPriceStr, _ := pos["markPrice"].(string)
markPrice, _ := strconv.ParseFloat(markPriceStr, 64)
// 强平价格
// Liquidation price
liqPriceStr, _ := pos["liqPrice"].(string)
liqPrice, _ := strconv.ParseFloat(liqPriceStr, 64)
positionSide, _ := pos["side"].(string) // Buy = LONG, Sell = SHORT
// 转换为统一格式
// Convert to unified format
side := "LONG"
positionAmt := size
if positionSide == "Sell" {
@@ -245,7 +245,7 @@ func (t *BybitTrader) GetPositions() ([]map[string]interface{}, error) {
positions = append(positions, position)
}
// 更新缓存
// Update cache
t.positionsCacheMutex.Lock()
t.cachedPositions = positions
t.positionsCacheTime = time.Now()
@@ -254,14 +254,14 @@ func (t *BybitTrader) GetPositions() ([]map[string]interface{}, error) {
return positions, nil
}
// OpenLong 开多仓
// OpenLong opens a long position
func (t *BybitTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
// 先设置杠杆
// Set leverage first
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Infof("⚠️ [Bybit] 设置杠杆失败: %v", err)
logger.Infof("⚠️ [Bybit] Failed to set leverage: %v", err)
}
// 使用 FormatQuantity 格式化数量
// Use FormatQuantity to format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
params := map[string]interface{}{
@@ -270,28 +270,28 @@ func (t *BybitTrader) OpenLong(symbol string, quantity float64, leverage int) (m
"side": "Buy",
"orderType": "Market",
"qty": qtyStr,
"positionIdx": 0, // 单向持仓模式
"positionIdx": 0, // One-way position mode
}
result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background())
if err != nil {
return nil, fmt.Errorf("Bybit 开多失败: %w", err)
return nil, fmt.Errorf("Bybit open long failed: %w", err)
}
// 清除缓存
// Clear cache
t.clearCache()
return t.parseOrderResult(result)
}
// OpenShort 开空仓
// OpenShort opens a short position
func (t *BybitTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
// 先设置杠杆
// Set leverage first
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Infof("⚠️ [Bybit] 设置杠杆失败: %v", err)
logger.Infof("⚠️ [Bybit] Failed to set leverage: %v", err)
}
// 使用 FormatQuantity 格式化数量
// Use FormatQuantity to format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
params := map[string]interface{}{
@@ -300,23 +300,23 @@ func (t *BybitTrader) OpenShort(symbol string, quantity float64, leverage int) (
"side": "Sell",
"orderType": "Market",
"qty": qtyStr,
"positionIdx": 0, // 单向持仓模式
"positionIdx": 0, // One-way position mode
}
result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background())
if err != nil {
return nil, fmt.Errorf("Bybit 开空失败: %w", err)
return nil, fmt.Errorf("Bybit open short failed: %w", err)
}
// 清除缓存
// Clear cache
t.clearCache()
return t.parseOrderResult(result)
}
// CloseLong 平多仓
// CloseLong closes a long position
func (t *BybitTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
// 如果 quantity = 0,获取当前持仓数量
// If quantity = 0, get current position quantity
if quantity == 0 {
positions, err := t.GetPositions()
if err != nil {
@@ -331,16 +331,16 @@ func (t *BybitTrader) CloseLong(symbol string, quantity float64) (map[string]int
}
if quantity <= 0 {
return nil, fmt.Errorf("没有多仓可平")
return nil, fmt.Errorf("no long position to close")
}
// 使用 FormatQuantity 格式化数量
// Use FormatQuantity to format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"side": "Sell", // 平多用 Sell
"side": "Sell", // Close long with Sell
"orderType": "Market",
"qty": qtyStr,
"positionIdx": 0,
@@ -349,18 +349,18 @@ func (t *BybitTrader) CloseLong(symbol string, quantity float64) (map[string]int
result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background())
if err != nil {
return nil, fmt.Errorf("Bybit 平多失败: %w", err)
return nil, fmt.Errorf("Bybit close long failed: %w", err)
}
// 清除缓存
// Clear cache
t.clearCache()
return t.parseOrderResult(result)
}
// CloseShort 平空仓
// CloseShort closes a short position
func (t *BybitTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
// 如果 quantity = 0,获取当前持仓数量
// If quantity = 0, get current position quantity
if quantity == 0 {
positions, err := t.GetPositions()
if err != nil {
@@ -368,23 +368,23 @@ func (t *BybitTrader) CloseShort(symbol string, quantity float64) (map[string]in
}
for _, pos := range positions {
if pos["symbol"] == symbol && pos["side"] == "SHORT" {
quantity = -pos["positionAmt"].(float64) // 空仓是负数
quantity = -pos["positionAmt"].(float64) // Short position is negative
break
}
}
}
if quantity <= 0 {
return nil, fmt.Errorf("没有空仓可平")
return nil, fmt.Errorf("no short position to close")
}
// 使用 FormatQuantity 格式化数量
// Use FormatQuantity to format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"side": "Buy", // 平空用 Buy
"side": "Buy", // Close short with Buy
"orderType": "Market",
"qty": qtyStr,
"positionIdx": 0,
@@ -393,16 +393,16 @@ func (t *BybitTrader) CloseShort(symbol string, quantity float64) (map[string]in
result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background())
if err != nil {
return nil, fmt.Errorf("Bybit 平空失败: %w", err)
return nil, fmt.Errorf("Bybit close short failed: %w", err)
}
// 清除缓存
// Clear cache
t.clearCache()
return t.parseOrderResult(result)
}
// SetLeverage 设置杠杆
// SetLeverage sets leverage
func (t *BybitTrader) SetLeverage(symbol string, leverage int) error {
params := map[string]interface{}{
"category": "linear",
@@ -413,25 +413,25 @@ func (t *BybitTrader) SetLeverage(symbol string, leverage int) error {
result, err := t.client.NewUtaBybitServiceWithParams(params).SetPositionLeverage(context.Background())
if err != nil {
// 如果杠杆已经是目标值,Bybit 会返回错误,忽略这种情况
// If leverage is already at target value, Bybit will return an error, ignore this case
if strings.Contains(err.Error(), "leverage not modified") {
return nil
}
return fmt.Errorf("设置杠杆失败: %w", err)
return fmt.Errorf("failed to set leverage: %w", err)
}
if result.RetCode != 0 && result.RetCode != 110043 { // 110043 = leverage not modified
return fmt.Errorf("设置杠杆失败: %s", result.RetMsg)
return fmt.Errorf("failed to set leverage: %s", result.RetMsg)
}
return nil
}
// SetMarginMode 设置仓位模式
// SetMarginMode sets position margin mode
func (t *BybitTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
tradeMode := 1 // 逐仓
tradeMode := 1 // Isolated margin
if isCrossMargin {
tradeMode = 0 // 全仓
tradeMode = 0 // Cross margin
}
params := map[string]interface{}{
@@ -445,17 +445,17 @@ func (t *BybitTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
if strings.Contains(err.Error(), "Cross/isolated margin mode is not modified") {
return nil
}
return fmt.Errorf("设置保证金模式失败: %w", err)
return fmt.Errorf("failed to set margin mode: %w", err)
}
if result.RetCode != 0 && result.RetCode != 110026 { // already in target mode
return fmt.Errorf("设置保证金模式失败: %s", result.RetMsg)
return fmt.Errorf("failed to set margin mode: %s", result.RetMsg)
}
return nil
}
// GetMarketPrice 获取市场价格
// GetMarketPrice retrieves market price
func (t *BybitTrader) GetMarketPrice(symbol string) (float64, error) {
params := map[string]interface{}{
"category": "linear",
@@ -464,53 +464,53 @@ func (t *BybitTrader) GetMarketPrice(symbol string) (float64, error) {
result, err := t.client.NewUtaBybitServiceWithParams(params).GetMarketTickers(context.Background())
if err != nil {
return 0, fmt.Errorf("获取市场价格失败: %w", err)
return 0, fmt.Errorf("failed to get market price: %w", err)
}
if result.RetCode != 0 {
return 0, fmt.Errorf("API 错误: %s", result.RetMsg)
return 0, fmt.Errorf("API error: %s", result.RetMsg)
}
resultData, ok := result.Result.(map[string]interface{})
if !ok {
return 0, fmt.Errorf("返回格式错误")
return 0, fmt.Errorf("return format error")
}
list, _ := resultData["list"].([]interface{})
if len(list) == 0 {
return 0, fmt.Errorf("未找到 %s 的价格数据", symbol)
return 0, fmt.Errorf("price data not found for %s", symbol)
}
ticker, _ := list[0].(map[string]interface{})
lastPriceStr, _ := ticker["lastPrice"].(string)
lastPrice, err := strconv.ParseFloat(lastPriceStr, 64)
if err != nil {
return 0, fmt.Errorf("解析价格失败: %w", err)
return 0, fmt.Errorf("failed to parse price: %w", err)
}
return lastPrice, nil
}
// SetStopLoss 设置止损单
// SetStopLoss sets stop loss order
func (t *BybitTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
side := "Sell" // LONG 止损用 Sell
side := "Sell" // LONG stop loss uses Sell
if positionSide == "SHORT" {
side = "Buy" // SHORT 止损用 Buy
side = "Buy" // SHORT stop loss uses Buy
}
// 获取当前价格来确定 triggerDirection
// Get current price to determine triggerDirection
currentPrice, err := t.GetMarketPrice(symbol)
if err != nil {
return err
}
triggerDirection := 2 // 价格下跌触发(默认多单止损)
triggerDirection := 2 // Price fall trigger (default long stop loss)
if stopPrice > currentPrice {
triggerDirection = 1 // 价格上涨触发(空单止损)
triggerDirection = 1 // Price rise trigger (short stop loss)
}
// 使用 FormatQuantity 格式化数量
// Use FormatQuantity to format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
params := map[string]interface{}{
@@ -527,36 +527,36 @@ func (t *BybitTrader) SetStopLoss(symbol string, positionSide string, quantity,
result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background())
if err != nil {
return fmt.Errorf("设置止损失败: %w", err)
return fmt.Errorf("failed to set stop loss: %w", err)
}
if result.RetCode != 0 {
return fmt.Errorf("设置止损失败: %s", result.RetMsg)
return fmt.Errorf("failed to set stop loss: %s", result.RetMsg)
}
logger.Infof(" ✓ [Bybit] 止损单已设置: %s @ %.2f", symbol, stopPrice)
logger.Infof(" ✓ [Bybit] Stop loss order set: %s @ %.2f", symbol, stopPrice)
return nil
}
// SetTakeProfit 设置止盈单
// SetTakeProfit sets take profit order
func (t *BybitTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
side := "Sell" // LONG 止盈用 Sell
side := "Sell" // LONG take profit uses Sell
if positionSide == "SHORT" {
side = "Buy" // SHORT 止盈用 Buy
side = "Buy" // SHORT take profit uses Buy
}
// 获取当前价格来确定 triggerDirection
// Get current price to determine triggerDirection
currentPrice, err := t.GetMarketPrice(symbol)
if err != nil {
return err
}
triggerDirection := 1 // 价格上涨触发(默认多单止盈)
triggerDirection := 1 // Price rise trigger (default long take profit)
if takeProfitPrice < currentPrice {
triggerDirection = 2 // 价格下跌触发(空单止盈)
triggerDirection = 2 // Price fall trigger (short take profit)
}
// 使用 FormatQuantity 格式化数量
// Use FormatQuantity to format quantity
qtyStr, _ := t.FormatQuantity(symbol, quantity)
params := map[string]interface{}{
@@ -573,28 +573,28 @@ func (t *BybitTrader) SetTakeProfit(symbol string, positionSide string, quantity
result, err := t.client.NewUtaBybitServiceWithParams(params).PlaceOrder(context.Background())
if err != nil {
return fmt.Errorf("设置止盈失败: %w", err)
return fmt.Errorf("failed to set take profit: %w", err)
}
if result.RetCode != 0 {
return fmt.Errorf("设置止盈失败: %s", result.RetMsg)
return fmt.Errorf("failed to set take profit: %s", result.RetMsg)
}
logger.Infof(" ✓ [Bybit] 止盈单已设置: %s @ %.2f", symbol, takeProfitPrice)
logger.Infof(" ✓ [Bybit] Take profit order set: %s @ %.2f", symbol, takeProfitPrice)
return nil
}
// CancelStopLossOrders 取消止损单
// CancelStopLossOrders cancels stop loss orders
func (t *BybitTrader) CancelStopLossOrders(symbol string) error {
return t.cancelConditionalOrders(symbol, "StopLoss")
}
// CancelTakeProfitOrders 取消止盈单
// CancelTakeProfitOrders cancels take profit orders
func (t *BybitTrader) CancelTakeProfitOrders(symbol string) error {
return t.cancelConditionalOrders(symbol, "TakeProfit")
}
// CancelAllOrders 取消所有挂单
// CancelAllOrders cancels all pending orders
func (t *BybitTrader) CancelAllOrders(symbol string) error {
params := map[string]interface{}{
"category": "linear",
@@ -603,26 +603,26 @@ func (t *BybitTrader) CancelAllOrders(symbol string) error {
_, err := t.client.NewUtaBybitServiceWithParams(params).CancelAllOrders(context.Background())
if err != nil {
return fmt.Errorf("取消所有订单失败: %w", err)
return fmt.Errorf("failed to cancel all orders: %w", err)
}
return nil
}
// CancelStopOrders 取消所有止盈止损单
// CancelStopOrders cancels all stop loss and take profit orders
func (t *BybitTrader) CancelStopOrders(symbol string) error {
if err := t.CancelStopLossOrders(symbol); err != nil {
logger.Infof("⚠️ [Bybit] 取消止损单失败: %v", err)
logger.Infof("⚠️ [Bybit] Failed to cancel stop loss orders: %v", err)
}
if err := t.CancelTakeProfitOrders(symbol); err != nil {
logger.Infof("⚠️ [Bybit] 取消止盈单失败: %v", err)
logger.Infof("⚠️ [Bybit] Failed to cancel take profit orders: %v", err)
}
return nil
}
// getQtyStep 获取交易对的数量步长
// getQtyStep retrieves the quantity step for a trading pair
func (t *BybitTrader) getQtyStep(symbol string) float64 {
// 先检查缓存
// Check cache first
t.qtyStepCacheMutex.RLock()
if step, ok := t.qtyStepCache[symbol]; ok {
t.qtyStepCacheMutex.RUnlock()
@@ -630,12 +630,12 @@ func (t *BybitTrader) getQtyStep(symbol string) float64 {
}
t.qtyStepCacheMutex.RUnlock()
// 直接调用公开 API 获取合约信息
// Call public API directly to get contract information
url := fmt.Sprintf("https://api.bybit.com/v5/market/instruments-info?category=linear&symbol=%s", symbol)
resp, err := http.Get(url)
if err != nil {
logger.Infof("⚠️ [Bybit] 获取 %s 精度信息失败: %v", symbol, err)
return 1 // 默认整数
logger.Infof("⚠️ [Bybit] Failed to get precision info for %s: %v", symbol, err)
return 1 // Default to integer
}
defer resp.Body.Close()
@@ -668,7 +668,7 @@ func (t *BybitTrader) getQtyStep(symbol string) float64 {
qtyStep = 1
}
// 缓存结果
// Cache result
t.qtyStepCacheMutex.Lock()
t.qtyStepCache[symbol] = qtyStep
t.qtyStepCacheMutex.Unlock()
@@ -678,15 +678,15 @@ func (t *BybitTrader) getQtyStep(symbol string) float64 {
return qtyStep
}
// FormatQuantity 格式化数量
// FormatQuantity formats quantity
func (t *BybitTrader) FormatQuantity(symbol string, quantity float64) (string, error) {
// 获取该币种的 qtyStep
// Get qtyStep for this symbol
qtyStep := t.getQtyStep(symbol)
// 根据 qtyStep 对齐数量(向下取整到最近的 step
// Align quantity according to qtyStep (round down to nearest step)
alignedQty := math.Floor(quantity/qtyStep) * qtyStep
// 计算需要的小数位数
// Calculate required decimal places
decimals := 0
if qtyStep < 1 {
stepStr := strconv.FormatFloat(qtyStep, 'f', -1, 64)
@@ -695,14 +695,14 @@ func (t *BybitTrader) FormatQuantity(symbol string, quantity float64) (string, e
}
}
// 格式化
// Format
format := fmt.Sprintf("%%.%df", decimals)
formatted := fmt.Sprintf(format, alignedQty)
return formatted, nil
}
// 辅助方法
// Helper methods
func (t *BybitTrader) clearCache() {
t.balanceCacheMutex.Lock()
@@ -716,12 +716,12 @@ func (t *BybitTrader) clearCache() {
func (t *BybitTrader) parseOrderResult(result *bybit.ServerResponse) (map[string]interface{}, error) {
if result.RetCode != 0 {
return nil, fmt.Errorf("下单失败: %s", result.RetMsg)
return nil, fmt.Errorf("order placement failed: %s", result.RetMsg)
}
resultData, ok := result.Result.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("返回格式错误")
return nil, fmt.Errorf("return format error")
}
orderId, _ := resultData["orderId"].(string)
@@ -732,7 +732,7 @@ func (t *BybitTrader) parseOrderResult(result *bybit.ServerResponse) (map[string
}, nil
}
// GetOrderStatus 获取订单状态
// GetOrderStatus retrieves order status
func (t *BybitTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
params := map[string]interface{}{
"category": "linear",
@@ -742,26 +742,26 @@ func (t *BybitTrader) GetOrderStatus(symbol string, orderID string) (map[string]
result, err := t.client.NewUtaBybitServiceWithParams(params).GetOrderHistory(context.Background())
if err != nil {
return nil, fmt.Errorf("获取订单状态失败: %w", err)
return nil, fmt.Errorf("failed to get order status: %w", err)
}
if result.RetCode != 0 {
return nil, fmt.Errorf("API 错误: %s", result.RetMsg)
return nil, fmt.Errorf("API error: %s", result.RetMsg)
}
resultData, ok := result.Result.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("返回格式错误")
return nil, fmt.Errorf("return format error")
}
list, _ := resultData["list"].([]interface{})
if len(list) == 0 {
return nil, fmt.Errorf("未找到订单 %s", orderID)
return nil, fmt.Errorf("order %s not found", orderID)
}
order, _ := list[0].(map[string]interface{})
// 解析订单数据
// Parse order data
status, _ := order["orderStatus"].(string)
avgPriceStr, _ := order["avgPrice"].(string)
cumExecQtyStr, _ := order["cumExecQty"].(string)
@@ -771,7 +771,7 @@ func (t *BybitTrader) GetOrderStatus(symbol string, orderID string) (map[string]
executedQty, _ := strconv.ParseFloat(cumExecQtyStr, 64)
commission, _ := strconv.ParseFloat(cumExecFeeStr, 64)
// 转换状态为统一格式
// Convert status to unified format
unifiedStatus := status
switch status {
case "Filled":
@@ -794,20 +794,20 @@ func (t *BybitTrader) GetOrderStatus(symbol string, orderID string) (map[string]
}
func (t *BybitTrader) cancelConditionalOrders(symbol string, orderType string) error {
// 先获取所有条件单
// First get all conditional orders
params := map[string]interface{}{
"category": "linear",
"symbol": symbol,
"orderFilter": "StopOrder", // 条件单
"orderFilter": "StopOrder", // Conditional orders
}
result, err := t.client.NewUtaBybitServiceWithParams(params).GetOpenOrders(context.Background())
if err != nil {
return fmt.Errorf("获取条件单失败: %w", err)
return fmt.Errorf("failed to get conditional orders: %w", err)
}
if result.RetCode != 0 {
return nil // 没有订单
return nil // No orders
}
resultData, ok := result.Result.(map[string]interface{})
@@ -817,7 +817,7 @@ func (t *BybitTrader) cancelConditionalOrders(symbol string, orderType string) e
list, _ := resultData["list"].([]interface{})
// 取消匹配的订单
// Cancel matching orders
for _, item := range list {
order, ok := item.(map[string]interface{})
if !ok {
@@ -827,7 +827,7 @@ func (t *BybitTrader) cancelConditionalOrders(symbol string, orderType string) e
orderId, _ := order["orderId"].(string)
stopOrderType, _ := order["stopOrderType"].(string)
// 根据类型筛选
// Filter by type
shouldCancel := false
if orderType == "StopLoss" && (stopOrderType == "StopLoss" || stopOrderType == "Stop") {
shouldCancel = true
+52 -52
View File
@@ -12,21 +12,21 @@ import (
)
// ============================================================
// 一、BybitTraderTestSuite - 继承 base test suite
// Part 1: BybitTraderTestSuite - Inherits base test suite
// ============================================================
// BybitTraderTestSuite Bybit交易器测试套件
// 继承 TraderTestSuite 并添加 Bybit 特定的 mock 逻辑
// BybitTraderTestSuite Bybit trader test suite
// Inherits TraderTestSuite and adds Bybit-specific mock logic
type BybitTraderTestSuite struct {
*TraderTestSuite // 嵌入基础测试套件
*TraderTestSuite // Embeds base test suite
mockServer *httptest.Server
}
// NewBybitTraderTestSuite 创建 Bybit 测试套件
// 注意:由于 Bybit SDK 封装设计,无法轻松注入 mock HTTP client
// 因此这里的测试套件主要用于接口合规性验证,而非 API 调用测试
// NewBybitTraderTestSuite Create Bybit test suite
// Note: Due to Bybit SDK encapsulation design, cannot easily inject mock HTTP client
// Therefore this test suite is mainly used for interface compliance verification, not API call testing
func NewBybitTraderTestSuite(t *testing.T) *BybitTraderTestSuite {
// 创建 mock HTTP 服务器(用于验证响应格式)
// Create mock HTTP server (for response format verification)
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
var respBody interface{}
@@ -65,10 +65,10 @@ func NewBybitTraderTestSuite(t *testing.T) *BybitTraderTestSuite {
json.NewEncoder(w).Encode(respBody)
}))
// 创建真实的 Bybit trader(用于接口合规性测试)
// Create real Bybit trader (for interface compliance testing)
trader := NewBybitTrader("test_api_key", "test_secret_key")
// 创建基础套件
// Create base suite
baseSuite := NewTraderTestSuite(t, trader)
return &BybitTraderTestSuite{
@@ -77,7 +77,7 @@ func NewBybitTraderTestSuite(t *testing.T) *BybitTraderTestSuite {
}
}
// Cleanup 清理资源
// Cleanup Clean up resources
func (s *BybitTraderTestSuite) Cleanup() {
if s.mockServer != nil {
s.mockServer.Close()
@@ -86,19 +86,19 @@ func (s *BybitTraderTestSuite) Cleanup() {
}
// ============================================================
// 二、接口兼容性测试
// Part 2: Interface compliance tests
// ============================================================
// TestBybitTrader_InterfaceCompliance 测试接口兼容性
// TestBybitTrader_InterfaceCompliance Test interface compliance
func TestBybitTrader_InterfaceCompliance(t *testing.T) {
var _ Trader = (*BybitTrader)(nil)
}
// ============================================================
// 三、Bybit 特定功能的单元测试
// Part 3: Bybit-specific feature unit tests
// ============================================================
// TestNewBybitTrader 测试创建 Bybit 交易器
// TestNewBybitTrader Test creating Bybit trader
func TestNewBybitTrader(t *testing.T) {
tests := []struct {
name string
@@ -107,19 +107,19 @@ func TestNewBybitTrader(t *testing.T) {
wantNil bool
}{
{
name: "成功创建",
name: "Successfully create",
apiKey: "test_api_key",
secretKey: "test_secret_key",
wantNil: false,
},
{
name: "API Key仍可创建",
name: "Empty API Key can still create",
apiKey: "",
secretKey: "test_secret_key",
wantNil: false,
},
{
name: "Secret Key仍可创建",
name: "Empty Secret Key can still create",
apiKey: "test_api_key",
secretKey: "",
wantNil: false,
@@ -140,26 +140,26 @@ func TestNewBybitTrader(t *testing.T) {
}
}
// TestBybitTrader_SymbolFormat 测试符号格式
// TestBybitTrader_SymbolFormat Test symbol format
func TestBybitTrader_SymbolFormat(t *testing.T) {
// Bybit 使用大写符号格式(如 BTCUSDT
// Bybit uses uppercase symbol format (e.g. BTCUSDT)
tests := []struct {
name string
symbol string
isValid bool
}{
{
name: "标准USDT合约",
name: "Standard USDT contract",
symbol: "BTCUSDT",
isValid: true,
},
{
name: "ETH合约",
name: "ETH contract",
symbol: "ETHUSDT",
isValid: true,
},
{
name: "SOL合约",
name: "SOL contract",
symbol: "SOLUSDT",
isValid: true,
},
@@ -167,14 +167,14 @@ func TestBybitTrader_SymbolFormat(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 验证符号格式正确(全大写,以USDT结尾)
// Verify symbol format is correct (all uppercase, ends with USDT)
assert.True(t, tt.symbol == strings.ToUpper(tt.symbol))
assert.True(t, strings.HasSuffix(tt.symbol, "USDT"))
})
}
}
// TestBybitTrader_FormatQuantity 测试数量格式化
// TestBybitTrader_FormatQuantity Test quantity formatting
func TestBybitTrader_FormatQuantity(t *testing.T) {
trader := NewBybitTrader("test", "test")
@@ -186,21 +186,21 @@ func TestBybitTrader_FormatQuantity(t *testing.T) {
hasError bool
}{
{
name: "BTC数量格式化",
name: "BTC quantity formatting",
symbol: "BTCUSDT",
quantity: 0.12345,
expected: "0.123", // Bybit 默认使用 3 位小数
expected: "0.123", // Bybit defaults to 3 decimal places
hasError: false,
},
{
name: "ETH数量格式化",
name: "ETH quantity formatting",
symbol: "ETHUSDT",
quantity: 1.2345,
expected: "1.234",
hasError: false,
},
{
name: "整数数量",
name: "Integer quantity",
symbol: "SOLUSDT",
quantity: 10.0,
expected: "10.000",
@@ -221,7 +221,7 @@ func TestBybitTrader_FormatQuantity(t *testing.T) {
}
}
// TestBybitTrader_ParseResponse 测试响应解析
// TestBybitTrader_ParseResponse Test response parsing
func TestBybitTrader_ParseResponse(t *testing.T) {
tests := []struct {
name string
@@ -231,20 +231,20 @@ func TestBybitTrader_ParseResponse(t *testing.T) {
errContain string
}{
{
name: "成功响应",
name: "Success response",
retCode: 0,
retMsg: "OK",
expectErr: false,
},
{
name: "API错误",
name: "API error",
retCode: 10001,
retMsg: "Invalid symbol",
expectErr: true,
errContain: "Invalid symbol",
},
{
name: "权限错误",
name: "Permission error",
retCode: 10003,
retMsg: "Invalid API key",
expectErr: true,
@@ -267,7 +267,7 @@ func TestBybitTrader_ParseResponse(t *testing.T) {
}
}
// checkBybitResponse 检查 Bybit API 响应是否有错误
// checkBybitResponse Check if Bybit API response has errors
func checkBybitResponse(retCode int, retMsg string) error {
if retCode != 0 {
return &BybitAPIError{
@@ -278,7 +278,7 @@ func checkBybitResponse(retCode int, retMsg string) error {
return nil
}
// BybitAPIError Bybit API 错误类型
// BybitAPIError Bybit API error type
type BybitAPIError struct {
Code int
Message string
@@ -288,7 +288,7 @@ func (e *BybitAPIError) Error() string {
return e.Message
}
// TestBybitTrader_PositionSideConversion 测试仓位方向转换
// TestBybitTrader_PositionSideConversion Test position side conversion
func TestBybitTrader_PositionSideConversion(t *testing.T) {
tests := []struct {
name string
@@ -296,17 +296,17 @@ func TestBybitTrader_PositionSideConversion(t *testing.T) {
expected string
}{
{
name: "BuyLong",
name: "Buy to Long",
side: "Buy",
expected: "long",
},
{
name: "SellShort",
name: "Sell to Short",
side: "Sell",
expected: "short",
},
{
name: "其他值保持不变",
name: "Other values remain unchanged",
side: "Unknown",
expected: "unknown",
},
@@ -320,7 +320,7 @@ func TestBybitTrader_PositionSideConversion(t *testing.T) {
}
}
// convertBybitSide 转换 Bybit 仓位方向
// convertBybitSide Convert Bybit position side
func convertBybitSide(side string) string {
switch side {
case "Buy":
@@ -332,29 +332,29 @@ func convertBybitSide(side string) string {
}
}
// TestBybitTrader_CategoryLinear 测试只使用 linear 类别
// TestBybitTrader_CategoryLinear Test using only linear category
func TestBybitTrader_CategoryLinear(t *testing.T) {
// Bybit trader 应该只使用 linear 类别(USDT永续合约)
// Bybit trader should only use linear category (USDT perpetual contracts)
trader := NewBybitTrader("test", "test")
assert.NotNil(t, trader)
// 验证默认配置
// Verify default configuration
assert.NotNil(t, trader.client)
}
// TestBybitTrader_CacheDuration 测试缓存持续时间
// TestBybitTrader_CacheDuration Test cache duration
func TestBybitTrader_CacheDuration(t *testing.T) {
trader := NewBybitTrader("test", "test")
// 验证默认缓存时间为15秒
// Verify default cache time is 15 seconds
assert.Equal(t, 15*time.Second, trader.cacheDuration)
}
// ============================================================
// 四、Mock 服务器集成测试
// Part 4: Mock server integration tests
// ============================================================
// TestBybitTrader_MockServerGetBalance 测试通过 Mock 服务器获取余额
// TestBybitTrader_MockServerGetBalance Test getting balance through Mock server
func TestBybitTrader_MockServerGetBalance(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v5/account/wallet-balance" {
@@ -386,12 +386,12 @@ func TestBybitTrader_MockServerGetBalance(t *testing.T) {
}))
defer mockServer.Close()
// 由于 Bybit SDK 封装,无法直接注入 mock URL
// 这个测试验证 mock 服务器响应格式正确
// Due to Bybit SDK encapsulation, cannot directly inject mock URL
// This test verifies mock server response format is correct
assert.NotNil(t, mockServer)
}
// TestBybitTrader_MockServerGetPositions 测试通过 Mock 服务器获取持仓
// TestBybitTrader_MockServerGetPositions Test getting positions through Mock server
func TestBybitTrader_MockServerGetPositions(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v5/position/list" {
@@ -425,7 +425,7 @@ func TestBybitTrader_MockServerGetPositions(t *testing.T) {
assert.NotNil(t, mockServer)
}
// TestBybitTrader_MockServerPlaceOrder 测试通过 Mock 服务器下单
// TestBybitTrader_MockServerPlaceOrder Test placing order through Mock server
func TestBybitTrader_MockServerPlaceOrder(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v5/order/create" && r.Method == "POST" {
@@ -448,7 +448,7 @@ func TestBybitTrader_MockServerPlaceOrder(t *testing.T) {
assert.NotNil(t, mockServer)
}
// TestBybitTrader_MockServerSetLeverage 测试通过 Mock 服务器设置杠杆
// TestBybitTrader_MockServerSetLeverage Test setting leverage through Mock server
func TestBybitTrader_MockServerSetLeverage(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v5/position/set-leverage" && r.Method == "POST" {
+4 -4
View File
@@ -5,7 +5,7 @@ import (
"strconv"
)
// SafeFloat64 从map中安全提取float64值
// SafeFloat64 Safely extract float64 value from map
func SafeFloat64(data map[string]interface{}, key string) (float64, error) {
value, ok := data[key]
if !ok {
@@ -22,7 +22,7 @@ func SafeFloat64(data map[string]interface{}, key string) (float64, error) {
case int64:
return float64(v), nil
case string:
// 尝试解析字符串为float64
// Try to parse string as float64
parsed, err := strconv.ParseFloat(v, 64)
if err != nil {
return 0, fmt.Errorf("cannot parse string '%s' as float64: %w", v, err)
@@ -33,7 +33,7 @@ func SafeFloat64(data map[string]interface{}, key string) (float64, error) {
}
}
// SafeString 从map中安全提取字符串值
// SafeString Safely extract string value from map
func SafeString(data map[string]interface{}, key string) (string, error) {
value, ok := data[key]
if !ok {
@@ -50,7 +50,7 @@ func SafeString(data map[string]interface{}, key string) (string, error) {
}
}
// SafeInt 从map中安全提取int值
// SafeInt Safely extract int value from map
func SafeInt(data map[string]interface{}, key string) (int, error) {
value, ok := data[key]
if !ok {
File diff suppressed because it is too large Load Diff
+78 -78
View File
@@ -14,32 +14,32 @@ import (
)
// ============================================================
// 一、HyperliquidTestSuite - 继承 base test suite
// Part 1: HyperliquidTestSuite - Inherits base test suite
// ============================================================
// HyperliquidTestSuite Hyperliquid 交易器测试套件
// 继承 TraderTestSuite 并添加 Hyperliquid 特定的 mock 逻辑
// HyperliquidTestSuite Hyperliquid trader test suite
// Inherits TraderTestSuite and adds Hyperliquid-specific mock logic
type HyperliquidTestSuite struct {
*TraderTestSuite // 嵌入基础测试套件
*TraderTestSuite // Embeds base test suite
mockServer *httptest.Server
privateKey *ecdsa.PrivateKey
}
// NewHyperliquidTestSuite 创建 Hyperliquid 测试套件
// NewHyperliquidTestSuite Create Hyperliquid test suite
func NewHyperliquidTestSuite(t *testing.T) *HyperliquidTestSuite {
// 创建测试用私钥
// Create test private key
privateKey, err := crypto.HexToECDSA("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
if err != nil {
t.Fatalf("创建测试私钥失败: %v", err)
t.Fatalf("Failed to create test private key: %v", err)
}
// 创建 mock HTTP 服务器
// Create mock HTTP server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 根据不同的请求路径返回不同的 mock 响应
// Return different mock responses based on request path
var respBody interface{}
// Hyperliquid API 使用 POST 请求,请求体是 JSON
// 我们需要根据请求体中的 "type" 字段来区分不同的请求
// Hyperliquid API uses POST requests with JSON body
// We need to distinguish different requests by the "type" field in request body
var reqBody map[string]interface{}
if r.Method == "POST" {
json.NewDecoder(r.Body).Decode(&reqBody)
@@ -54,7 +54,7 @@ func NewHyperliquidTestSuite(t *testing.T) *HyperliquidTestSuite {
}
switch reqType {
// Mock Meta - 获取市场元数据
// Mock Meta - Get market metadata
case "meta":
respBody = map[string]interface{}{
"universe": []map[string]interface{}{
@@ -78,14 +78,14 @@ func NewHyperliquidTestSuite(t *testing.T) *HyperliquidTestSuite {
"marginTables": []interface{}{},
}
// Mock UserState - 获取用户账户状态(用于 GetBalance GetPositions
// Mock UserState - Get user account state (for GetBalance and GetPositions)
case "clearinghouseState":
user, _ := reqBody["user"].(string)
// 检查是否是查询 Agent 钱包余额(用于安全检查)
// Check if querying Agent wallet balance (for security check)
agentAddr := crypto.PubkeyToAddress(privateKey.PublicKey).Hex()
if user == agentAddr {
// Agent 钱包余额应该很低
// Agent wallet balance should be low
respBody = map[string]interface{}{
"crossMarginSummary": map[string]interface{}{
"accountValue": "5.00",
@@ -95,7 +95,7 @@ func NewHyperliquidTestSuite(t *testing.T) *HyperliquidTestSuite {
"assetPositions": []interface{}{},
}
} else {
// 主钱包账户状态
// Main wallet account state
respBody = map[string]interface{}{
"crossMarginSummary": map[string]interface{}{
"accountValue": "10000.00",
@@ -121,7 +121,7 @@ func NewHyperliquidTestSuite(t *testing.T) *HyperliquidTestSuite {
}
}
// Mock SpotUserState - 获取现货账户状态
// Mock SpotUserState - Get spot account state
case "spotClearinghouseState":
respBody = map[string]interface{}{
"balances": []map[string]interface{}{
@@ -132,25 +132,25 @@ func NewHyperliquidTestSuite(t *testing.T) *HyperliquidTestSuite {
},
}
// Mock SpotMeta - 获取现货市场元数据
// Mock SpotMeta - Get spot market metadata
case "spotMeta":
respBody = map[string]interface{}{
"universe": []map[string]interface{}{},
"tokens": []map[string]interface{}{},
}
// Mock AllMids - 获取所有市场价格
// Mock AllMids - Get all market prices
case "allMids":
respBody = map[string]string{
"BTC": "50000.00",
"ETH": "3000.00",
}
// Mock OpenOrders - 获取挂单列表
// Mock OpenOrders - Get open orders list
case "openOrders":
respBody = []interface{}{}
// Mock Order - 创建订单(开仓、平仓、止损、止盈)
// Mock Order - Create order (open, close, stop-loss, take-profit)
case "order":
respBody = map[string]interface{}{
"status": "ok",
@@ -169,46 +169,46 @@ func NewHyperliquidTestSuite(t *testing.T) *HyperliquidTestSuite {
},
}
// Mock UpdateLeverage - 设置杠杆
// Mock UpdateLeverage - Set leverage
case "updateLeverage":
respBody = map[string]interface{}{
"status": "ok",
}
// Mock Cancel - 取消订单
// Mock Cancel - Cancel order
case "cancel":
respBody = map[string]interface{}{
"status": "ok",
}
default:
// 默认返回成功响应
// Default return success response
respBody = map[string]interface{}{
"status": "ok",
}
}
// 序列化响应
// Serialize response
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
}))
// 创建 HyperliquidTrader,使用 mock 服务器 URL
// Create HyperliquidTrader, using mock server URL
walletAddr := "0x9999999999999999999999999999999999999999"
ctx := context.Background()
// 创建 Exchange 客户端,指向 mock 服务器
// Create Exchange client, pointing to mock server
exchange := hyperliquid.NewExchange(
ctx,
privateKey,
mockServer.URL, // 使用 mock 服务器 URL
mockServer.URL, // Use mock server URL
nil,
"",
walletAddr,
nil,
)
// 创建 meta(模拟获取成功)
// Create meta (simulate successful fetch)
meta := &hyperliquid.Meta{
Universe: []hyperliquid.AssetInfo{
{Name: "BTC", SzDecimals: 4},
@@ -224,7 +224,7 @@ func NewHyperliquidTestSuite(t *testing.T) *HyperliquidTestSuite {
isCrossMargin: true,
}
// 创建基础套件
// Create base suite
baseSuite := NewTraderTestSuite(t, trader)
return &HyperliquidTestSuite{
@@ -234,7 +234,7 @@ func NewHyperliquidTestSuite(t *testing.T) *HyperliquidTestSuite {
}
}
// Cleanup 清理资源
// Cleanup Clean up resources
func (s *HyperliquidTestSuite) Cleanup() {
if s.mockServer != nil {
s.mockServer.Close()
@@ -243,29 +243,29 @@ func (s *HyperliquidTestSuite) Cleanup() {
}
// ============================================================
// 二、使用 HyperliquidTestSuite 运行通用测试
// Part 2: Run common tests using HyperliquidTestSuite
// ============================================================
// TestHyperliquidTrader_InterfaceCompliance 测试接口兼容性
// TestHyperliquidTrader_InterfaceCompliance Test interface compliance
func TestHyperliquidTrader_InterfaceCompliance(t *testing.T) {
var _ Trader = (*HyperliquidTrader)(nil)
}
// TestHyperliquidTrader_CommonInterface 使用测试套件运行所有通用接口测试
// TestHyperliquidTrader_CommonInterface Run all common interface tests using test suite
func TestHyperliquidTrader_CommonInterface(t *testing.T) {
// 创建测试套件
// Create test suite
suite := NewHyperliquidTestSuite(t)
defer suite.Cleanup()
// 运行所有通用接口测试
// Run all common interface tests
suite.RunAllTests()
}
// ============================================================
// 三、Hyperliquid 特定功能的单元测试
// Part 3: Hyperliquid-specific feature unit tests
// ============================================================
// TestNewHyperliquidTrader 测试创建 Hyperliquid 交易器
// TestNewHyperliquidTrader Test creating Hyperliquid trader
func TestNewHyperliquidTrader(t *testing.T) {
tests := []struct {
name string
@@ -276,15 +276,15 @@ func TestNewHyperliquidTrader(t *testing.T) {
errorContains string
}{
{
name: "无效私钥格式",
name: "Invalid private key format",
privateKeyHex: "invalid_key",
walletAddr: "0x1234567890123456789012345678901234567890",
testnet: true,
wantError: true,
errorContains: "解析私钥失败",
errorContains: "Failed to parse private key",
},
{
name: "钱包地址为空",
name: "Empty wallet address",
privateKeyHex: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
walletAddr: "",
testnet: true,
@@ -315,13 +315,13 @@ func TestNewHyperliquidTrader(t *testing.T) {
}
}
// TestNewHyperliquidTrader_Success 测试成功创建交易器(需要 mock HTTP
// TestNewHyperliquidTrader_Success Test successfully creating trader (requires mock HTTP)
func TestNewHyperliquidTrader_Success(t *testing.T) {
// 创建测试用私钥
// Create test private key
privateKey, _ := crypto.HexToECDSA("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
agentAddr := crypto.PubkeyToAddress(privateKey.PublicKey).Hex()
// 创建 mock HTTP 服务器
// Create mock HTTP server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var reqBody map[string]interface{}
json.NewDecoder(r.Body).Decode(&reqBody)
@@ -346,7 +346,7 @@ func TestNewHyperliquidTrader_Success(t *testing.T) {
case "clearinghouseState":
user, _ := reqBody["user"].(string)
if user == agentAddr {
// Agent 钱包余额低
// Agent wallet low balance
respBody = map[string]interface{}{
"crossMarginSummary": map[string]interface{}{
"accountValue": "5.00",
@@ -354,7 +354,7 @@ func TestNewHyperliquidTrader_Success(t *testing.T) {
"assetPositions": []interface{}{},
}
} else {
// 主钱包
// Main wallet
respBody = map[string]interface{}{
"crossMarginSummary": map[string]interface{}{
"accountValue": "10000.00",
@@ -371,17 +371,17 @@ func TestNewHyperliquidTrader_Success(t *testing.T) {
}))
defer mockServer.Close()
// 注意:这个测试会真正调用 NewHyperliquidTrader,但会失败
// 因为 hyperliquid SDK 不允许我们在构造函数中注入自定义 URL
// 所以这个测试仅用于验证参数处理逻辑
t.Skip("跳过此测试:hyperliquid SDK 在构造时会调用真实 API,无法注入 mock URL")
// Note: This test would actually call NewHyperliquidTrader, but will fail
// Because hyperliquid SDK doesn't allow us to inject custom URL in constructor
// So this test is only for verifying parameter handling logic
t.Skip("Skip this test: hyperliquid SDK calls real API during construction, cannot inject mock URL")
}
// ============================================================
// 四、工具函数单元测试(Hyperliquid 特有)
// Part 4: Utility function unit tests (Hyperliquid-specific)
// ============================================================
// TestConvertSymbolToHyperliquid 测试 symbol 转换函数
// TestConvertSymbolToHyperliquid Test symbol conversion function
func TestConvertSymbolToHyperliquid(t *testing.T) {
tests := []struct {
name string
@@ -389,17 +389,17 @@ func TestConvertSymbolToHyperliquid(t *testing.T) {
expected string
}{
{
name: "BTCUSDT转换",
name: "BTCUSDT conversion",
symbol: "BTCUSDT",
expected: "BTC",
},
{
name: "ETHUSDT转换",
name: "ETHUSDT conversion",
symbol: "ETHUSDT",
expected: "ETH",
},
{
name: "USDT后缀",
name: "No USDT suffix",
symbol: "BTC",
expected: "BTC",
},
@@ -413,7 +413,7 @@ func TestConvertSymbolToHyperliquid(t *testing.T) {
}
}
// TestAbsFloat 测试绝对值函数
// TestAbsFloat Test absolute value function
func TestAbsFloat(t *testing.T) {
tests := []struct {
name string
@@ -421,17 +421,17 @@ func TestAbsFloat(t *testing.T) {
expected float64
}{
{
name: "正数",
name: "Positive number",
input: 10.5,
expected: 10.5,
},
{
name: "负数",
name: "Negative number",
input: -10.5,
expected: 10.5,
},
{
name: "",
name: "Zero",
input: 0,
expected: 0,
},
@@ -445,7 +445,7 @@ func TestAbsFloat(t *testing.T) {
}
}
// TestHyperliquidTrader_RoundToSzDecimals 测试数量精度处理
// TestHyperliquidTrader_RoundToSzDecimals Test quantity precision handling
func TestHyperliquidTrader_RoundToSzDecimals(t *testing.T) {
trader := &HyperliquidTrader{
meta: &hyperliquid.Meta{
@@ -463,19 +463,19 @@ func TestHyperliquidTrader_RoundToSzDecimals(t *testing.T) {
expected float64
}{
{
name: "BTC_四舍五入到4位",
name: "BTC - round to 4 decimals",
coin: "BTC",
quantity: 1.23456789,
expected: 1.2346,
},
{
name: "ETH_四舍五入到3位",
name: "ETH - round to 3 decimals",
coin: "ETH",
quantity: 10.12345,
expected: 10.123,
},
{
name: "未知币种_使用默认精度4位",
name: "Unknown coin - use default 4 decimals",
coin: "UNKNOWN",
quantity: 1.23456789,
expected: 1.2346,
@@ -490,7 +490,7 @@ func TestHyperliquidTrader_RoundToSzDecimals(t *testing.T) {
}
}
// TestHyperliquidTrader_RoundPriceToSigfigs 测试价格有效数字处理
// TestHyperliquidTrader_RoundPriceToSigfigs Test price significant figures handling
func TestHyperliquidTrader_RoundPriceToSigfigs(t *testing.T) {
trader := &HyperliquidTrader{}
@@ -500,17 +500,17 @@ func TestHyperliquidTrader_RoundPriceToSigfigs(t *testing.T) {
expected float64
}{
{
name: "BTC价格_5位有效数字",
name: "BTC price - 5 significant figures",
price: 50123.456789,
expected: 50123.0,
},
{
name: "小数价格_5位有效数字",
name: "Decimal price - 5 significant figures",
price: 0.0012345678,
expected: 0.0012346,
},
{
name: "零价格",
name: "Zero price",
price: 0,
expected: 0,
},
@@ -524,7 +524,7 @@ func TestHyperliquidTrader_RoundPriceToSigfigs(t *testing.T) {
}
}
// TestHyperliquidTrader_GetSzDecimals 测试获取精度
// TestHyperliquidTrader_GetSzDecimals Test getting precision
func TestHyperliquidTrader_GetSzDecimals(t *testing.T) {
tests := []struct {
name string
@@ -533,13 +533,13 @@ func TestHyperliquidTrader_GetSzDecimals(t *testing.T) {
expected int
}{
{
name: "meta为nil_返回默认精度",
name: "meta is nil - return default precision",
meta: nil,
coin: "BTC",
expected: 4,
},
{
name: "找到BTC_返回正确精度",
name: "Found BTC - return correct precision",
meta: &hyperliquid.Meta{
Universe: []hyperliquid.AssetInfo{
{Name: "BTC", SzDecimals: 5},
@@ -549,7 +549,7 @@ func TestHyperliquidTrader_GetSzDecimals(t *testing.T) {
expected: 5,
},
{
name: "未找到币种_返回默认精度",
name: "Coin not found - return default precision",
meta: &hyperliquid.Meta{
Universe: []hyperliquid.AssetInfo{
{Name: "ETH", SzDecimals: 3},
@@ -569,7 +569,7 @@ func TestHyperliquidTrader_GetSzDecimals(t *testing.T) {
}
}
// TestHyperliquidTrader_SetMarginMode 测试设置保证金模式
// TestHyperliquidTrader_SetMarginMode Test setting margin mode
func TestHyperliquidTrader_SetMarginMode(t *testing.T) {
trader := &HyperliquidTrader{
ctx: context.Background(),
@@ -583,13 +583,13 @@ func TestHyperliquidTrader_SetMarginMode(t *testing.T) {
wantError bool
}{
{
name: "设置为全仓模式",
name: "Set to cross margin mode",
symbol: "BTCUSDT",
isCrossMargin: true,
wantError: false,
},
{
name: "设置为逐仓模式",
name: "Set to isolated margin mode",
symbol: "ETHUSDT",
isCrossMargin: false,
wantError: false,
@@ -610,7 +610,7 @@ func TestHyperliquidTrader_SetMarginMode(t *testing.T) {
}
}
// TestNewHyperliquidTrader_PrivateKeyProcessing 测试私钥处理
// TestNewHyperliquidTrader_PrivateKeyProcessing Test private key processing
func TestNewHyperliquidTrader_PrivateKeyProcessing(t *testing.T) {
tests := []struct {
name string
@@ -619,13 +619,13 @@ func TestNewHyperliquidTrader_PrivateKeyProcessing(t *testing.T) {
expectedLength int
}{
{
name: "带0x前缀的私钥",
name: "Private key with 0x prefix",
privateKeyHex: "0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
shouldStripOx: true,
expectedLength: 64,
},
{
name: "无前缀的私钥",
name: "Private key without prefix",
privateKeyHex: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
shouldStripOx: false,
expectedLength: 64,
@@ -634,7 +634,7 @@ func TestNewHyperliquidTrader_PrivateKeyProcessing(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 测试私钥前缀处理逻辑(不实际创建 trader
// Test private key prefix handling logic (without actually creating trader)
processed := tt.privateKeyHex
if len(processed) > 2 && (processed[:2] == "0x" || processed[:2] == "0X") {
processed = processed[2:]
+20 -20
View File
@@ -1,57 +1,57 @@
package trader
// Trader 交易器统一接口
// 支持多个交易平台(币安、Hyperliquid等)
// Trader Unified trader interface
// Supports multiple trading platforms (Binance, Hyperliquid, etc.)
type Trader interface {
// GetBalance 获取账户余额
// GetBalance Get account balance
GetBalance() (map[string]interface{}, error)
// GetPositions 获取所有持仓
// GetPositions Get all positions
GetPositions() ([]map[string]interface{}, error)
// OpenLong 开多仓
// OpenLong Open long position
OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error)
// OpenShort 开空仓
// OpenShort Open short position
OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error)
// CloseLong 平多仓(quantity=0表示全部平仓)
// CloseLong Close long position (quantity=0 means close all)
CloseLong(symbol string, quantity float64) (map[string]interface{}, error)
// CloseShort 平空仓(quantity=0表示全部平仓)
// CloseShort Close short position (quantity=0 means close all)
CloseShort(symbol string, quantity float64) (map[string]interface{}, error)
// SetLeverage 设置杠杆
// SetLeverage Set leverage
SetLeverage(symbol string, leverage int) error
// SetMarginMode 设置仓位模式 (true=全仓, false=逐仓)
// SetMarginMode Set position mode (true=cross margin, false=isolated margin)
SetMarginMode(symbol string, isCrossMargin bool) error
// GetMarketPrice 获取市场价格
// GetMarketPrice Get market price
GetMarketPrice(symbol string) (float64, error)
// SetStopLoss 设置止损单
// SetStopLoss Set stop-loss order
SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error
// SetTakeProfit 设置止盈单
// SetTakeProfit Set take-profit order
SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error
// CancelStopLossOrders 仅取消止损单(修复 BUG:调整止损时不删除止盈)
// CancelStopLossOrders Cancel only stop-loss orders (BUG fix: don't delete take-profit when adjusting stop-loss)
CancelStopLossOrders(symbol string) error
// CancelTakeProfitOrders 仅取消止盈单(修复 BUG:调整止盈时不删除止损)
// CancelTakeProfitOrders Cancel only take-profit orders (BUG fix: don't delete stop-loss when adjusting take-profit)
CancelTakeProfitOrders(symbol string) error
// CancelAllOrders 取消该币种的所有挂单
// CancelAllOrders Cancel all pending orders for this symbol
CancelAllOrders(symbol string) error
// CancelStopOrders 取消该币种的止盈/止损单(用于调整止盈止损位置)
// CancelStopOrders Cancel stop-loss/take-profit orders for this symbol (for adjusting stop-loss/take-profit positions)
CancelStopOrders(symbol string) error
// FormatQuantity 格式化数量到正确的精度
// FormatQuantity Format quantity to correct precision
FormatQuantity(symbol string, quantity float64) (string, error)
// GetOrderStatus 获取订单状态
// 返回: status(FILLED/NEW/CANCELED), avgPrice, executedQty, commission
// GetOrderStatus Get order status
// Returns: status(FILLED/NEW/CANCELED), avgPrice, executedQty, commission
GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error)
}
+47 -47
View File
@@ -7,29 +7,29 @@ import (
"net/http"
)
// AccountBalance 账户余额信息
// AccountBalance Account balance information
type AccountBalance struct {
TotalEquity float64 `json:"total_equity"` // 总权益
AvailableBalance float64 `json:"available_balance"` // 可用余额
MarginUsed float64 `json:"margin_used"` // 已用保证金
UnrealizedPnL float64 `json:"unrealized_pnl"` // 未实现盈亏
MaintenanceMargin float64 `json:"maintenance_margin"` // 维持保证金
TotalEquity float64 `json:"total_equity"` // Total equity
AvailableBalance float64 `json:"available_balance"` // Available balance
MarginUsed float64 `json:"margin_used"` // Used margin
UnrealizedPnL float64 `json:"unrealized_pnl"` // Unrealized PnL
MaintenanceMargin float64 `json:"maintenance_margin"` // Maintenance margin
}
// Position 持仓信息
// Position Position information
type Position struct {
Symbol string `json:"symbol"` // 交易对
Side string `json:"side"` // "long" "short"
Size float64 `json:"size"` // 持仓大小
EntryPrice float64 `json:"entry_price"` // 开仓均价
MarkPrice float64 `json:"mark_price"` // 标记价格
LiquidationPrice float64 `json:"liquidation_price"` // 强平价格
UnrealizedPnL float64 `json:"unrealized_pnl"` // 未实现盈亏
Leverage float64 `json:"leverage"` // 杠杆倍数
MarginUsed float64 `json:"margin_used"` // 已用保证金
Symbol string `json:"symbol"` // Trading pair
Side string `json:"side"` // "long" or "short"
Size float64 `json:"size"` // Position size
EntryPrice float64 `json:"entry_price"` // Average entry price
MarkPrice float64 `json:"mark_price"` // Mark price
LiquidationPrice float64 `json:"liquidation_price"` // Liquidation price
UnrealizedPnL float64 `json:"unrealized_pnl"` // Unrealized PnL
Leverage float64 `json:"leverage"` // Leverage multiplier
MarginUsed float64 `json:"margin_used"` // Used margin
}
// GetBalance 获取账户余额(实现 Trader 接口)
// GetBalance Get account balance (implements Trader interface)
func (t *LighterTrader) GetBalance() (map[string]interface{}, error) {
balance, err := t.GetAccountBalance()
if err != nil {
@@ -45,10 +45,10 @@ func (t *LighterTrader) GetBalance() (map[string]interface{}, error) {
}, nil
}
// GetAccountBalance 获取账户详细余额信息
// GetAccountBalance Get detailed account balance information
func (t *LighterTrader) GetAccountBalance() (*AccountBalance, error) {
if err := t.ensureAuthToken(); err != nil {
return nil, fmt.Errorf("认证令牌无效: %w", err)
return nil, fmt.Errorf("invalid auth token: %w", err)
}
t.accountMutex.RLock()
@@ -62,7 +62,7 @@ func (t *LighterTrader) GetAccountBalance() (*AccountBalance, error) {
return nil, err
}
// 添加认证头
// Add auth header
t.accountMutex.RLock()
req.Header.Set("Authorization", t.authToken)
t.accountMutex.RUnlock()
@@ -79,21 +79,21 @@ func (t *LighterTrader) GetAccountBalance() (*AccountBalance, error) {
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("获取余额失败 (status %d): %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("failed to get balance (status %d): %s", resp.StatusCode, string(body))
}
var balance AccountBalance
if err := json.Unmarshal(body, &balance); err != nil {
return nil, fmt.Errorf("解析余额响应失败: %w", err)
return nil, fmt.Errorf("failed to parse balance response: %w", err)
}
return &balance, nil
}
// GetPositionsRaw 获取所有持仓(返回原始类型)
// GetPositionsRaw Get all positions (returns raw type)
func (t *LighterTrader) GetPositionsRaw(symbol string) ([]Position, error) {
if err := t.ensureAuthToken(); err != nil {
return nil, fmt.Errorf("认证令牌无效: %w", err)
return nil, fmt.Errorf("invalid auth token: %w", err)
}
t.accountMutex.RLock()
@@ -110,7 +110,7 @@ func (t *LighterTrader) GetPositionsRaw(symbol string) ([]Position, error) {
return nil, err
}
// 添加认证头
// Add auth header
t.accountMutex.RLock()
req.Header.Set("Authorization", t.authToken)
t.accountMutex.RUnlock()
@@ -127,18 +127,18 @@ func (t *LighterTrader) GetPositionsRaw(symbol string) ([]Position, error) {
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("获取持仓失败 (status %d): %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("failed to get positions (status %d): %s", resp.StatusCode, string(body))
}
var positions []Position
if err := json.Unmarshal(body, &positions); err != nil {
return nil, fmt.Errorf("解析持仓响应失败: %w", err)
return nil, fmt.Errorf("failed to parse positions response: %w", err)
}
return positions, nil
}
// GetPositions 获取所有持仓(实现 Trader 接口)
// GetPositions Get all positions (implements Trader interface)
func (t *LighterTrader) GetPositions() ([]map[string]interface{}, error) {
positions, err := t.GetPositionsRaw("")
if err != nil {
@@ -163,25 +163,25 @@ func (t *LighterTrader) GetPositions() ([]map[string]interface{}, error) {
return result, nil
}
// GetPosition 获取指定币种的持仓
// GetPosition Get position for specified symbol
func (t *LighterTrader) GetPosition(symbol string) (*Position, error) {
positions, err := t.GetPositionsRaw(symbol)
if err != nil {
return nil, err
}
// 找到指定币种的持仓
// Find position for specified symbol
for _, pos := range positions {
if pos.Symbol == symbol && pos.Size > 0 {
return &pos, nil
}
}
// 无持仓
// No position
return nil, nil
}
// GetMarketPrice 获取市场价格
// GetMarketPrice Get market price
func (t *LighterTrader) GetMarketPrice(symbol string) (float64, error) {
endpoint := fmt.Sprintf("%s/api/v1/market/ticker?symbol=%s", t.baseURL, symbol)
@@ -202,24 +202,24 @@ func (t *LighterTrader) GetMarketPrice(symbol string) (float64, error) {
}
if resp.StatusCode != http.StatusOK {
return 0, fmt.Errorf("获取市场价格失败 (status %d): %s", resp.StatusCode, string(body))
return 0, fmt.Errorf("failed to get market price (status %d): %s", resp.StatusCode, string(body))
}
var ticker map[string]interface{}
if err := json.Unmarshal(body, &ticker); err != nil {
return 0, fmt.Errorf("解析价格响应失败: %w", err)
return 0, fmt.Errorf("failed to parse price response: %w", err)
}
// 提取最新价格
// Extract latest price
price, err := SafeFloat64(ticker, "last_price")
if err != nil {
return 0, fmt.Errorf("无法获取价格: %w", err)
return 0, fmt.Errorf("failed to get price: %w", err)
}
return price, nil
}
// GetAccountInfo 获取账户完整信息(用于AutoTrader
// GetAccountInfo Get complete account information (for AutoTrader)
func (t *LighterTrader) GetAccountInfo() (map[string]interface{}, error) {
balance, err := t.GetAccountBalance()
if err != nil {
@@ -231,7 +231,7 @@ func (t *LighterTrader) GetAccountInfo() (map[string]interface{}, error) {
return nil, err
}
// 构建返回信息
// Build return information
info := map[string]interface{}{
"total_equity": balance.TotalEquity,
"available_balance": balance.AvailableBalance,
@@ -245,27 +245,27 @@ func (t *LighterTrader) GetAccountInfo() (map[string]interface{}, error) {
return info, nil
}
// SetLeverage 设置杠杆倍数
// SetLeverage Set leverage multiplier
func (t *LighterTrader) SetLeverage(symbol string, leverage int) error {
if err := t.ensureAuthToken(); err != nil {
return fmt.Errorf("认证令牌无效: %w", err)
return fmt.Errorf("invalid auth token: %w", err)
}
// TODO: 实现设置杠杆的API调用
// LIGHTER可能需要签名交易来设置杠杆
// TODO: Implement set leverage API call
// LIGHTER may require signed transaction to set leverage
return fmt.Errorf("SetLeverage未实现")
return fmt.Errorf("SetLeverage not implemented")
}
// GetMaxLeverage 获取最大杠杆倍数
// GetMaxLeverage Get maximum leverage multiplier
func (t *LighterTrader) GetMaxLeverage(symbol string) (int, error) {
// LIGHTER支持BTC/ETH最高50x杠杆
// TODO: 从API获取实际限制
// LIGHTER supports up to 50x leverage for BTC/ETH
// TODO: Get actual limits from API
if symbol == "BTC-PERP" || symbol == "ETH-PERP" {
return 50, nil
}
// 其他币种默认20x
// Default 20x for other symbols
return 20, nil
}
+59 -59
View File
@@ -9,19 +9,19 @@ import (
"net/http"
)
// CreateOrderRequest 创建订单请求
// CreateOrderRequest Create order request
type CreateOrderRequest struct {
Symbol string `json:"symbol"` // 交易对,如 "BTC-PERP"
Side string `json:"side"` // "buy" "sell"
OrderType string `json:"order_type"` // "market" "limit"
Quantity float64 `json:"quantity"` // 数量
Price float64 `json:"price"` // 价格(限价单必填)
ReduceOnly bool `json:"reduce_only"` // 是否只减仓
Symbol string `json:"symbol"` // Trading pair, e.g. "BTC-PERP"
Side string `json:"side"` // "buy" or "sell"
OrderType string `json:"order_type"` // "market" or "limit"
Quantity float64 `json:"quantity"` // Quantity
Price float64 `json:"price"` // Price (required for limit orders)
ReduceOnly bool `json:"reduce_only"` // Reduce-only flag
TimeInForce string `json:"time_in_force"` // "GTC", "IOC", "FOK"
PostOnly bool `json:"post_only"` // 是否只做Maker
PostOnly bool `json:"post_only"` // Post-only (maker only)
}
// OrderResponse 订单响应
// OrderResponse Order response
type OrderResponse struct {
OrderID string `json:"order_id"`
Symbol string `json:"symbol"`
@@ -35,13 +35,13 @@ type OrderResponse struct {
CreateTime int64 `json:"create_time"`
}
// CreateOrder 创建订单(市价或限价)
// CreateOrder Create order (market or limit)
func (t *LighterTrader) CreateOrder(symbol, side string, quantity, price float64, orderType string) (string, error) {
if err := t.ensureAuthToken(); err != nil {
return "", fmt.Errorf("认证令牌无效: %w", err)
return "", fmt.Errorf("invalid auth token: %w", err)
}
// 构建订单请求
// Build order request
req := CreateOrderRequest{
Symbol: symbol,
Side: side,
@@ -56,41 +56,41 @@ func (t *LighterTrader) CreateOrder(symbol, side string, quantity, price float64
req.Price = price
}
// 发送订单
// Send order
orderResp, err := t.sendOrder(req)
if err != nil {
return "", err
}
logger.Infof("✓ LIGHTER订单已创建 - ID: %s, Symbol: %s, Side: %s, Qty: %.4f",
logger.Infof("✓ LIGHTER order created - ID: %s, Symbol: %s, Side: %s, Qty: %.4f",
orderResp.OrderID, symbol, side, quantity)
return orderResp.OrderID, nil
}
// sendOrder 发送订单到LIGHTER API
// sendOrder Send order to LIGHTER API
func (t *LighterTrader) sendOrder(orderReq CreateOrderRequest) (*OrderResponse, error) {
endpoint := fmt.Sprintf("%s/api/v1/order", t.baseURL)
// 序列化请求
// Serialize request
jsonData, err := json.Marshal(orderReq)
if err != nil {
return nil, err
}
// 创建HTTP请求
// Create HTTP request
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
// 添加请求头
// Add request headers
req.Header.Set("Content-Type", "application/json")
t.accountMutex.RLock()
req.Header.Set("Authorization", t.authToken)
t.accountMutex.RUnlock()
// 发送请求
// Send request
resp, err := t.client.Do(req)
if err != nil {
return nil, err
@@ -103,21 +103,21 @@ func (t *LighterTrader) sendOrder(orderReq CreateOrderRequest) (*OrderResponse,
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("创建订单失败 (status %d): %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("failed to create order (status %d): %s", resp.StatusCode, string(body))
}
var orderResp OrderResponse
if err := json.Unmarshal(body, &orderResp); err != nil {
return nil, fmt.Errorf("解析订单响应失败: %w", err)
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
return &orderResp, nil
}
// CancelOrder 取消订单
// CancelOrder Cancel order
func (t *LighterTrader) CancelOrder(symbol, orderID string) error {
if err := t.ensureAuthToken(); err != nil {
return fmt.Errorf("认证令牌无效: %w", err)
return fmt.Errorf("invalid auth token: %w", err)
}
endpoint := fmt.Sprintf("%s/api/v1/order/%s", t.baseURL, orderID)
@@ -127,7 +127,7 @@ func (t *LighterTrader) CancelOrder(symbol, orderID string) error {
return err
}
// 添加认证头
// Add auth header
t.accountMutex.RLock()
req.Header.Set("Authorization", t.authToken)
t.accountMutex.RUnlock()
@@ -140,45 +140,45 @@ func (t *LighterTrader) CancelOrder(symbol, orderID string) error {
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("取消订单失败 (status %d): %s", resp.StatusCode, string(body))
return fmt.Errorf("failed to cancel order (status %d): %s", resp.StatusCode, string(body))
}
logger.Infof("✓ LIGHTER订单已取消 - ID: %s", orderID)
logger.Infof("✓ LIGHTER order cancelled - ID: %s", orderID)
return nil
}
// CancelAllOrders 取消所有订单
// CancelAllOrders Cancel all orders
func (t *LighterTrader) CancelAllOrders(symbol string) error {
if err := t.ensureAuthToken(); err != nil {
return fmt.Errorf("认证令牌无效: %w", err)
return fmt.Errorf("invalid auth token: %w", err)
}
// 获取所有活跃订单
// Get all active orders
orders, err := t.GetActiveOrders(symbol)
if err != nil {
return fmt.Errorf("获取活跃订单失败: %w", err)
return fmt.Errorf("failed to get active orders: %w", err)
}
if len(orders) == 0 {
logger.Infof("✓ LIGHTER - 无需取消订单(无活跃订单)")
logger.Infof("✓ LIGHTER - no orders to cancel (no active orders)")
return nil
}
// 批量取消
// Cancel in batch
for _, order := range orders {
if err := t.CancelOrder(symbol, order.OrderID); err != nil {
logger.Infof("⚠️ 取消订单失败 (ID: %s): %v", order.OrderID, err)
logger.Infof("⚠️ Failed to cancel order (ID: %s): %v", order.OrderID, err)
}
}
logger.Infof("✓ LIGHTER - 已取消 %d 个订单", len(orders))
logger.Infof("✓ LIGHTER - cancelled %d orders", len(orders))
return nil
}
// GetActiveOrders 获取活跃订单
// GetActiveOrders Get active orders
func (t *LighterTrader) GetActiveOrders(symbol string) ([]OrderResponse, error) {
if err := t.ensureAuthToken(); err != nil {
return nil, fmt.Errorf("认证令牌无效: %w", err)
return nil, fmt.Errorf("invalid auth token: %w", err)
}
t.accountMutex.RLock()
@@ -195,7 +195,7 @@ func (t *LighterTrader) GetActiveOrders(symbol string) ([]OrderResponse, error)
return nil, err
}
// 添加认证头
// Add auth header
t.accountMutex.RLock()
req.Header.Set("Authorization", t.authToken)
t.accountMutex.RUnlock()
@@ -212,21 +212,21 @@ func (t *LighterTrader) GetActiveOrders(symbol string) ([]OrderResponse, error)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("获取活跃订单失败 (status %d): %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("failed to get active orders (status %d): %s", resp.StatusCode, string(body))
}
var orders []OrderResponse
if err := json.Unmarshal(body, &orders); err != nil {
return nil, fmt.Errorf("解析订单列表失败: %w", err)
return nil, fmt.Errorf("failed to parse order list: %w", err)
}
return orders, nil
}
// GetOrderStatus 获取订单状态(实现 Trader 接口)
// GetOrderStatus Get order status (implements Trader interface)
func (t *LighterTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
if err := t.ensureAuthToken(); err != nil {
return nil, fmt.Errorf("认证令牌无效: %w", err)
return nil, fmt.Errorf("invalid auth token: %w", err)
}
endpoint := fmt.Sprintf("%s/api/v1/order/%s", t.baseURL, orderID)
@@ -236,7 +236,7 @@ func (t *LighterTrader) GetOrderStatus(symbol string, orderID string) (map[strin
return nil, err
}
// 添加认证头
// Add auth header
t.accountMutex.RLock()
req.Header.Set("Authorization", t.authToken)
t.accountMutex.RUnlock()
@@ -253,15 +253,15 @@ func (t *LighterTrader) GetOrderStatus(symbol string, orderID string) (map[strin
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("获取订单状态失败 (status %d): %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("failed to get order status (status %d): %s", resp.StatusCode, string(body))
}
var order OrderResponse
if err := json.Unmarshal(body, &order); err != nil {
return nil, fmt.Errorf("解析订单响应失败: %w", err)
return nil, fmt.Errorf("failed to parse order response: %w", err)
}
// 转换状态为统一格式
// Convert status to unified format
unifiedStatus := order.Status
switch order.Status {
case "filled":
@@ -281,43 +281,43 @@ func (t *LighterTrader) GetOrderStatus(symbol string, orderID string) (map[strin
}, nil
}
// CancelStopLossOrders 仅取消止损单(LIGHTER 暂无法区分,取消所有止盈止损单)
// CancelStopLossOrders Cancel stop-loss orders only (LIGHTER cannot distinguish, cancels all TP/SL orders)
func (t *LighterTrader) CancelStopLossOrders(symbol string) error {
// LIGHTER 暂时无法区分止损和止盈单,取消所有止盈止损单
logger.Infof(" ⚠️ LIGHTER 无法区分止损/止盈单,将取消所有止盈止损单")
// LIGHTER currently cannot distinguish between stop-loss and take-profit orders, cancel all TP/SL orders
logger.Infof(" ⚠️ LIGHTER cannot distinguish SL/TP orders, will cancel all TP/SL orders")
return t.CancelStopOrders(symbol)
}
// CancelTakeProfitOrders 仅取消止盈单(LIGHTER 暂无法区分,取消所有止盈止损单)
// CancelTakeProfitOrders Cancel take-profit orders only (LIGHTER cannot distinguish, cancels all TP/SL orders)
func (t *LighterTrader) CancelTakeProfitOrders(symbol string) error {
// LIGHTER 暂时无法区分止损和止盈单,取消所有止盈止损单
logger.Infof(" ⚠️ LIGHTER 无法区分止损/止盈单,将取消所有止盈止损单")
// LIGHTER currently cannot distinguish between stop-loss and take-profit orders, cancel all TP/SL orders
logger.Infof(" ⚠️ LIGHTER cannot distinguish SL/TP orders, will cancel all TP/SL orders")
return t.CancelStopOrders(symbol)
}
// CancelStopOrders 取消该币种的止盈/止损单
// CancelStopOrders Cancel take-profit/stop-loss orders for this symbol
func (t *LighterTrader) CancelStopOrders(symbol string) error {
if err := t.ensureAuthToken(); err != nil {
return fmt.Errorf("认证令牌无效: %w", err)
return fmt.Errorf("invalid auth token: %w", err)
}
// 获取活跃订单
// Get active orders
orders, err := t.GetActiveOrders(symbol)
if err != nil {
return fmt.Errorf("获取活跃订单失败: %w", err)
return fmt.Errorf("failed to get active orders: %w", err)
}
canceledCount := 0
for _, order := range orders {
// TODO: 需要检查订单类型,只取消止盈止损单
// 暂时取消所有订单
// TODO: Need to check order type, only cancel TP/SL orders
// Currently cancelling all orders
if err := t.CancelOrder(symbol, order.OrderID); err != nil {
logger.Infof("⚠️ 取消订单失败 (ID: %s): %v", order.OrderID, err)
logger.Infof("⚠️ Failed to cancel order (ID: %s): %v", order.OrderID, err)
} else {
canceledCount++
}
}
logger.Infof("✓ LIGHTER - 已取消 %d 个止盈止损单", canceledCount)
logger.Infof("✓ LIGHTER - cancelled %d TP/SL orders", canceledCount)
return nil
}
+48 -48
View File
@@ -16,56 +16,56 @@ import (
"github.com/ethereum/go-ethereum/crypto"
)
// LighterTrader LIGHTER DEX交易器
// LIGHTER是基于Ethereum L2的永续合约DEX,使用zk-rollup技术
// LighterTrader LIGHTER DEX trader
// LIGHTER is an Ethereum L2-based perpetual contract DEX using zk-rollup technology
type LighterTrader struct {
ctx context.Context
privateKey *ecdsa.PrivateKey
walletAddr string // Ethereum钱包地址
walletAddr string // Ethereum wallet address
client *http.Client
baseURL string
testnet bool
// 账户信息缓存
accountIndex int // LIGHTER账户索引
apiKey string // API密钥(从私钥派生)
authToken string // 认证令牌(8小时有效期)
// Account information cache
accountIndex int // LIGHTER account index
apiKey string // API key (derived from private key)
authToken string // Authentication token (8-hour validity)
tokenExpiry time.Time
accountMutex sync.RWMutex
// 市场信息缓存
// Market information cache
symbolPrecision map[string]SymbolPrecision
precisionMutex sync.RWMutex
}
// LighterConfig LIGHTER配置
// LighterConfig LIGHTER configuration
type LighterConfig struct {
PrivateKeyHex string
WalletAddr string
Testnet bool
}
// NewLighterTrader 创建LIGHTER交易器
// NewLighterTrader Create LIGHTER trader
func NewLighterTrader(privateKeyHex string, walletAddr string, testnet bool) (*LighterTrader, error) {
// 去掉私钥的 0x 前缀(如果有)
// Remove 0x prefix from private key (if present)
privateKeyHex = strings.TrimPrefix(strings.ToLower(privateKeyHex), "0x")
// 解析私钥
// Parse private key
privateKey, err := crypto.HexToECDSA(privateKeyHex)
if err != nil {
return nil, fmt.Errorf("解析私钥失败: %w", err)
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
// 从私钥派生钱包地址(如果未提供)
// Derive wallet address from private key (if not provided)
if walletAddr == "" {
walletAddr = crypto.PubkeyToAddress(*privateKey.Public().(*ecdsa.PublicKey)).Hex()
logger.Infof("✓ 从私钥派生钱包地址: %s", walletAddr)
logger.Infof("✓ Derived wallet address from private key: %s", walletAddr)
}
// 选择API URL
// Select API URL
baseURL := "https://mainnet.zklighter.elliot.ai"
if testnet {
baseURL = "https://testnet.zklighter.elliot.ai" // TODO: 确认testnet URL
baseURL = "https://testnet.zklighter.elliot.ai" // TODO: Confirm testnet URL
}
trader := &LighterTrader{
@@ -78,39 +78,39 @@ func NewLighterTrader(privateKeyHex string, walletAddr string, testnet bool) (*L
symbolPrecision: make(map[string]SymbolPrecision),
}
logger.Infof("✓ LIGHTER交易器初始化成功 (testnet=%v, wallet=%s)", testnet, walletAddr)
logger.Infof("✓ LIGHTER trader initialized successfully (testnet=%v, wallet=%s)", testnet, walletAddr)
// 初始化账户信息(获取账户索引和API密钥)
// Initialize account information (get account index and API key)
if err := trader.initializeAccount(); err != nil {
return nil, fmt.Errorf("初始化账户失败: %w", err)
return nil, fmt.Errorf("failed to initialize account: %w", err)
}
return trader, nil
}
// initializeAccount 初始化账户信息
// initializeAccount Initialize account information
func (t *LighterTrader) initializeAccount() error {
// 1. 获取账户信息(通过L1地址)
// 1. Get account information (by L1 address)
accountInfo, err := t.getAccountByL1Address()
if err != nil {
return fmt.Errorf("获取账户信息失败: %w", err)
return fmt.Errorf("failed to get account information: %w", err)
}
t.accountMutex.Lock()
t.accountIndex = accountInfo["index"].(int)
t.accountMutex.Unlock()
logger.Infof("✓ LIGHTER账户索引: %d", t.accountIndex)
logger.Infof("✓ LIGHTER account index: %d", t.accountIndex)
// 2. 生成认证令牌(有效期8小时)
// 2. Generate authentication token (8-hour validity)
if err := t.refreshAuthToken(); err != nil {
return fmt.Errorf("生成认证令牌失败: %w", err)
return fmt.Errorf("failed to generate auth token: %w", err)
}
return nil
}
// getAccountByL1Address 通过Ethereum地址获取LIGHTER账户信息
// getAccountByL1Address Get LIGHTER account information by Ethereum address
func (t *LighterTrader) getAccountByL1Address() (map[string]interface{}, error) {
endpoint := fmt.Sprintf("%s/api/v1/account/by/l1/%s", t.baseURL, t.walletAddr)
@@ -131,50 +131,50 @@ func (t *LighterTrader) getAccountByL1Address() (map[string]interface{}, error)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API错误 (status %d): %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return result, nil
}
// refreshAuthToken 刷新认证令牌
// refreshAuthToken Refresh authentication token
func (t *LighterTrader) refreshAuthToken() error {
// TODO: 实现认证令牌生成逻辑
// 参考 lighter-python SDK 的实现
// 需要签名特定消息并提交到API
// TODO: Implement authentication token generation logic
// Reference lighter-python SDK implementation
// Need to sign specific message and submit to API
t.accountMutex.Lock()
defer t.accountMutex.Unlock()
// 临时实现:设置过期时间为8小时后
// Temporary implementation: set expiry time to 8 hours from now
t.tokenExpiry = time.Now().Add(8 * time.Hour)
logger.Infof("✓ 认证令牌已生成(有效期至: %s", t.tokenExpiry.Format(time.RFC3339))
logger.Infof("✓ Auth token generated (valid until: %s)", t.tokenExpiry.Format(time.RFC3339))
return nil
}
// ensureAuthToken 确保认证令牌有效
// ensureAuthToken Ensure authentication token is valid
func (t *LighterTrader) ensureAuthToken() error {
t.accountMutex.RLock()
expired := time.Now().After(t.tokenExpiry.Add(-30 * time.Minute)) // 提前30分钟刷新
expired := time.Now().After(t.tokenExpiry.Add(-30 * time.Minute)) // Refresh 30 minutes early
t.accountMutex.RUnlock()
if expired {
logger.Info("🔄 认证令牌即将过期,刷新中...")
logger.Info("🔄 Auth token expiring soon, refreshing...")
return t.refreshAuthToken()
}
return nil
}
// signMessage 签名消息(Ethereum签名)
// signMessage Sign message (Ethereum signature)
func (t *LighterTrader) signMessage(message []byte) (string, error) {
// 使用Ethereum个人签名格式
// Use Ethereum personal sign format
prefix := fmt.Sprintf("\x19Ethereum Signed Message:\n%d", len(message))
prefixedMessage := append([]byte(prefix), message...)
@@ -184,7 +184,7 @@ func (t *LighterTrader) signMessage(message []byte) (string, error) {
return "", err
}
// 调整v值(Ethereum格式)
// Adjust v value (Ethereum format)
if signature[64] < 27 {
signature[64] += 27
}
@@ -192,24 +192,24 @@ func (t *LighterTrader) signMessage(message []byte) (string, error) {
return "0x" + hex.EncodeToString(signature), nil
}
// GetName 获取交易器名称
// GetName Get trader name
func (t *LighterTrader) GetName() string {
return "LIGHTER"
}
// GetExchangeType 获取交易所类型
// GetExchangeType Get exchange type
func (t *LighterTrader) GetExchangeType() string {
return "lighter"
}
// Close 关闭交易器
// Close Close trader
func (t *LighterTrader) Close() error {
logger.Info("✓ LIGHTER交易器已关闭")
logger.Info("✓ LIGHTER trader closed")
return nil
}
// Run 运行交易器(实现Trader接口)
// Run Run trader (implements Trader interface)
func (t *LighterTrader) Run() error {
logger.Info("⚠️ LIGHTER交易器的Run方法应由AutoTrader调用")
return fmt.Errorf("请使用AutoTrader管理交易器生命周期")
logger.Info("⚠️ LIGHTER trader's Run method should be called by AutoTrader")
return fmt.Errorf("please use AutoTrader to manage trader lifecycle")
}
+20 -20
View File
@@ -12,20 +12,20 @@ import (
)
// ============================================================
// LIGHTER V1 测试套件
// LIGHTER V1 Test Suite
// ============================================================
// TestLighterTrader_NewTrader 测试创建LIGHTER交易器
// TestLighterTrader_NewTrader Test creating LIGHTER trader
func TestLighterTrader_NewTrader(t *testing.T) {
t.Run("无效私钥", func(t *testing.T) {
t.Run("Invalid private key", func(t *testing.T) {
trader, err := NewLighterTrader("invalid_key", "", true)
assert.Error(t, err)
assert.Nil(t, trader)
t.Logf("✅ Invalid private key correctly rejected")
})
t.Run("有效私钥格式验证", func(t *testing.T) {
// 只验证私钥解析,不调用真实 API
t.Run("Valid private key format verification", func(t *testing.T) {
// Only verify private key parsing, don't call real API
testL1Key := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
privateKey, err := crypto.HexToECDSA(testL1Key)
assert.NoError(t, err)
@@ -37,7 +37,7 @@ func TestLighterTrader_NewTrader(t *testing.T) {
})
}
// createMockLighterServer 创建 mock LIGHTER API 服务器
// createMockLighterServer Create mock LIGHTER API server
func createMockLighterServer() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
@@ -106,7 +106,7 @@ func createMockLighterServer() *httptest.Server {
}))
}
// createMockLighterTrader 创建带 mock server 的 LIGHTER trader
// createMockLighterTrader Create LIGHTER trader with mock server
func createMockLighterTrader(t *testing.T, mockServer *httptest.Server) *LighterTrader {
testL1Key := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
privateKey, err := crypto.HexToECDSA(testL1Key)
@@ -125,7 +125,7 @@ func createMockLighterTrader(t *testing.T, mockServer *httptest.Server) *Lighter
return trader
}
// TestLighterTrader_GetBalance 测试获取余额
// TestLighterTrader_GetBalance Test getting balance
func TestLighterTrader_GetBalance(t *testing.T) {
t.Skip("Skipping Lighter tests until mock server endpoints are completed")
mockServer := createMockLighterServer()
@@ -140,7 +140,7 @@ func TestLighterTrader_GetBalance(t *testing.T) {
t.Logf("✅ GetBalance: %+v", balance)
}
// TestLighterTrader_GetPositions 测试获取持仓
// TestLighterTrader_GetPositions Test getting positions
func TestLighterTrader_GetPositions(t *testing.T) {
t.Skip("Skipping Lighter tests until mock server endpoints are completed")
mockServer := createMockLighterServer()
@@ -155,7 +155,7 @@ func TestLighterTrader_GetPositions(t *testing.T) {
t.Logf("✅ GetPositions: found %d positions", len(positions))
}
// TestLighterTrader_GetMarketPrice 测试获取市场价格
// TestLighterTrader_GetMarketPrice Test getting market price
func TestLighterTrader_GetMarketPrice(t *testing.T) {
t.Skip("Skipping Lighter tests until mock server endpoints are completed")
mockServer := createMockLighterServer()
@@ -170,7 +170,7 @@ func TestLighterTrader_GetMarketPrice(t *testing.T) {
t.Logf("✅ GetMarketPrice(BTC): %.2f", price)
}
// TestLighterTrader_FormatQuantity 测试格式化数量
// TestLighterTrader_FormatQuantity Test formatting quantity
func TestLighterTrader_FormatQuantity(t *testing.T) {
mockServer := createMockLighterServer()
defer mockServer.Close()
@@ -184,7 +184,7 @@ func TestLighterTrader_FormatQuantity(t *testing.T) {
t.Logf("✅ FormatQuantity: %s", result)
}
// TestLighterTrader_GetExchangeType 测试获取交易所类型
// TestLighterTrader_GetExchangeType Test getting exchange type
func TestLighterTrader_GetExchangeType(t *testing.T) {
mockServer := createMockLighterServer()
defer mockServer.Close()
@@ -197,45 +197,45 @@ func TestLighterTrader_GetExchangeType(t *testing.T) {
t.Logf("✅ GetExchangeType: %s", exchangeType)
}
// TestLighterTrader_InvalidQuantity 测试无效数量验证
// TestLighterTrader_InvalidQuantity Test invalid quantity validation
func TestLighterTrader_InvalidQuantity(t *testing.T) {
mockServer := createMockLighterServer()
defer mockServer.Close()
trader := createMockLighterTrader(t, mockServer)
// 测试零数量
// Test zero quantity
_, err := trader.OpenLong("BTC", 0, 10)
assert.Error(t, err)
// 测试负数量
// Test negative quantity
_, err = trader.OpenLong("BTC", -0.1, 10)
assert.Error(t, err)
t.Logf("✅ Invalid quantity validation working")
}
// TestLighterTrader_InvalidLeverage 测试无效杠杆验证
// TestLighterTrader_InvalidLeverage Test invalid leverage validation
func TestLighterTrader_InvalidLeverage(t *testing.T) {
mockServer := createMockLighterServer()
defer mockServer.Close()
trader := createMockLighterTrader(t, mockServer)
// 测试零杠杆
// Test zero leverage
_, err := trader.OpenLong("BTC", 0.1, 0)
assert.Error(t, err)
// 测试负杠杆
// Test negative leverage
_, err = trader.OpenLong("BTC", 0.1, -10)
assert.Error(t, err)
t.Logf("✅ Invalid leverage validation working")
}
// TestLighterTrader_HelperFunctions 测试辅助函数
// TestLighterTrader_HelperFunctions Test helper functions
func TestLighterTrader_HelperFunctions(t *testing.T) {
// 测试 SafeFloat64
// Test SafeFloat64
data := map[string]interface{}{
"float_val": 123.45,
"string_val": "678.90",
+71 -71
View File
@@ -18,68 +18,68 @@ import (
"github.com/ethereum/go-ethereum/crypto"
)
// AccountInfo LIGHTER 賬戶信息
// AccountInfo LIGHTER account information
type AccountInfo struct {
AccountIndex int64 `json:"account_index"`
L1Address string `json:"l1_address"`
// 其他字段可以根據實際 API 響應添加
// Other fields can be added based on actual API response
}
// LighterTraderV2 使用官方 lighter-go SDK 的新實現
// LighterTraderV2 New implementation using official lighter-go SDK
type LighterTraderV2 struct {
ctx context.Context
privateKey *ecdsa.PrivateKey // L1 錢包私鑰(用於識別賬戶)
walletAddr string // Ethereum 錢包地址
privateKey *ecdsa.PrivateKey // L1 wallet private key (for account identification)
walletAddr string // Ethereum wallet address
client *http.Client
baseURL string
testnet bool
chainID uint32
// SDK 客戶端
// SDK clients
httpClient lighterClient.MinimalHTTPClient
txClient *lighterClient.TxClient
// API Key 管理
apiKeyPrivateKey string // 40字節的 API Key 私鑰(用於簽名交易)
apiKeyIndex uint8 // API Key 索引(默認 0
accountIndex int64 // 賬戶索引
// API Key management
apiKeyPrivateKey string // 40-byte API Key private key (for signing transactions)
apiKeyIndex uint8 // API Key index (default 0)
accountIndex int64 // Account index
// 認證令牌
// Authentication token
authToken string
tokenExpiry time.Time
accountMutex sync.RWMutex
// 市場信息緩存
// Market info cache
symbolPrecision map[string]SymbolPrecision
precisionMutex sync.RWMutex
// 市場索引緩存
// Market index cache
marketIndexMap map[string]uint8 // symbol -> market_id
marketMutex sync.RWMutex
}
// NewLighterTraderV2 創建新的 LIGHTER 交易器(使用官方 SDK
// 參數說明:
// - l1PrivateKeyHex: L1 錢包私鑰(32字節,標準以太坊私鑰)
// - walletAddr: 以太坊錢包地址(可選,會從私鑰自動派生)
// - apiKeyPrivateKeyHex: API Key 私鑰(40字節,用於簽名交易)如果為空則需要生成
// - testnet: 是否使用測試網
// NewLighterTraderV2 Create new LIGHTER trader (using official SDK)
// Parameters:
// - l1PrivateKeyHex: L1 wallet private key (32 bytes, standard Ethereum private key)
// - walletAddr: Ethereum wallet address (optional, will be derived from private key if empty)
// - apiKeyPrivateKeyHex: API Key private key (40 bytes, for signing transactions) - needs generation if empty
// - testnet: Whether to use testnet
func NewLighterTraderV2(l1PrivateKeyHex, walletAddr, apiKeyPrivateKeyHex string, testnet bool) (*LighterTraderV2, error) {
// 1. 解析 L1 私鑰
// 1. Parse L1 private key
l1PrivateKeyHex = strings.TrimPrefix(strings.ToLower(l1PrivateKeyHex), "0x")
l1PrivateKey, err := crypto.HexToECDSA(l1PrivateKeyHex)
if err != nil {
return nil, fmt.Errorf("無效的 L1 私鑰: %w", err)
return nil, fmt.Errorf("invalid L1 private key: %w", err)
}
// 2. 如果沒有提供錢包地址,從私鑰派生
// 2. If wallet address not provided, derive from private key
if walletAddr == "" {
walletAddr = crypto.PubkeyToAddress(*l1PrivateKey.Public().(*ecdsa.PublicKey)).Hex()
logger.Infof("✓ 從私鑰派生錢包地址: %s", walletAddr)
logger.Infof("✓ Derived wallet address from private key: %s", walletAddr)
}
// 3. 確定 API URL Chain ID
// 3. Determine API URL and Chain ID
baseURL := "https://mainnet.zklighter.elliot.ai"
chainID := uint32(42766) // Mainnet Chain ID
if testnet {
@@ -87,7 +87,7 @@ func NewLighterTraderV2(l1PrivateKeyHex, walletAddr, apiKeyPrivateKeyHex string,
chainID = uint32(42069) // Testnet Chain ID
}
// 4. 創建 HTTP 客戶端
// 4. Create HTTP client
httpClient := lighterHTTP.NewClient(baseURL)
trader := &LighterTraderV2{
@@ -100,24 +100,24 @@ func NewLighterTraderV2(l1PrivateKeyHex, walletAddr, apiKeyPrivateKeyHex string,
chainID: chainID,
httpClient: httpClient,
apiKeyPrivateKey: apiKeyPrivateKeyHex,
apiKeyIndex: 0, // 默認使用索引 0
apiKeyIndex: 0, // Default to index 0
symbolPrecision: make(map[string]SymbolPrecision),
marketIndexMap: make(map[string]uint8),
}
// 5. 初始化賬戶(獲取賬戶索引)
// 5. Initialize account (get account index)
if err := trader.initializeAccount(); err != nil {
return nil, fmt.Errorf("初始化賬戶失敗: %w", err)
return nil, fmt.Errorf("failed to initialize account: %w", err)
}
// 6. 如果沒有 API Key,提示用戶需要生成
// 6. If no API Key, prompt user to generate one
if apiKeyPrivateKeyHex == "" {
logger.Infof("⚠️ 未提供 API Key 私鑰,請調用 GenerateAndRegisterAPIKey() 生成")
logger.Infof(" 或者從 LIGHTER 官網獲取現有的 API Key")
logger.Infof("⚠️ No API Key private key provided, please call GenerateAndRegisterAPIKey() to generate")
logger.Infof(" Or get an existing API Key from LIGHTER website")
return trader, nil
}
// 7. 創建 TxClient(用於簽名交易)
// 7. Create TxClient (for signing transactions)
txClient, err := lighterClient.NewTxClient(
httpClient,
apiKeyPrivateKeyHex,
@@ -126,41 +126,41 @@ func NewLighterTraderV2(l1PrivateKeyHex, walletAddr, apiKeyPrivateKeyHex string,
trader.chainID,
)
if err != nil {
return nil, fmt.Errorf("創建 TxClient 失敗: %w", err)
return nil, fmt.Errorf("failed to create TxClient: %w", err)
}
trader.txClient = txClient
// 8. 驗證 API Key 是否正確
// 8. Verify API Key is correct
if err := trader.checkClient(); err != nil {
logger.Infof("⚠️ API Key 驗證失敗: %v", err)
logger.Infof(" 您可能需要重新生成 API Key 或檢查配置")
logger.Infof("⚠️ API Key verification failed: %v", err)
logger.Infof(" You may need to regenerate API Key or check configuration")
return trader, err
}
logger.Infof("✓ LIGHTER 交易器初始化成功 (account=%d, apiKey=%d, testnet=%v)",
logger.Infof("✓ LIGHTER trader initialized successfully (account=%d, apiKey=%d, testnet=%v)",
trader.accountIndex, trader.apiKeyIndex, testnet)
return trader, nil
}
// initializeAccount 初始化賬戶信息(獲取賬戶索引)
// initializeAccount Initialize account information (get account index)
func (t *LighterTraderV2) initializeAccount() error {
// 通過 L1 地址獲取賬戶信息
// Get account info by L1 address
accountInfo, err := t.getAccountByL1Address()
if err != nil {
return fmt.Errorf("獲取賬戶信息失敗: %w", err)
return fmt.Errorf("failed to get account info: %w", err)
}
t.accountMutex.Lock()
t.accountIndex = accountInfo.AccountIndex
t.accountMutex.Unlock()
logger.Infof("✓ 賬戶索引: %d", t.accountIndex)
logger.Infof("✓ Account index: %d", t.accountIndex)
return nil
}
// getAccountByL1Address 通過 L1 錢包地址獲取 LIGHTER 賬戶信息
// getAccountByL1Address Get LIGHTER account info by L1 wallet address
func (t *LighterTraderV2) getAccountByL1Address() (*AccountInfo, error) {
endpoint := fmt.Sprintf("%s/api/v1/account?by=address&value=%s", t.baseURL, t.walletAddr)
@@ -181,67 +181,67 @@ func (t *LighterTraderV2) getAccountByL1Address() (*AccountInfo, error) {
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("獲取賬戶失敗 (status %d): %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("failed to get account (status %d): %s", resp.StatusCode, string(body))
}
var accountInfo AccountInfo
if err := json.Unmarshal(body, &accountInfo); err != nil {
return nil, fmt.Errorf("解析賬戶響應失敗: %w", err)
return nil, fmt.Errorf("failed to parse account response: %w", err)
}
return &accountInfo, nil
}
// checkClient 驗證 API Key 是否正確
// checkClient Verify if API Key is correct
func (t *LighterTraderV2) checkClient() error {
if t.txClient == nil {
return fmt.Errorf("TxClient 未初始化")
return fmt.Errorf("TxClient not initialized")
}
// 獲取服務器上註冊的 API Key 公鑰
// Get API Key public key registered on server
publicKey, err := t.httpClient.GetApiKey(t.accountIndex, t.apiKeyIndex)
if err != nil {
return fmt.Errorf("獲取 API Key 失敗: %w", err)
return fmt.Errorf("failed to get API Key: %w", err)
}
// 獲取本地 API Key 的公鑰
// Get local API Key public key
pubKeyBytes := t.txClient.GetKeyManager().PubKeyBytes()
localPubKey := hexutil.Encode(pubKeyBytes[:])
localPubKey = strings.Replace(localPubKey, "0x", "", 1)
// 比對公鑰
// Compare public keys
if publicKey != localPubKey {
return fmt.Errorf("API Key 不匹配:本地=%s, 服務器=%s", localPubKey, publicKey)
return fmt.Errorf("API Key mismatch: local=%s, server=%s", localPubKey, publicKey)
}
logger.Infof("✓ API Key 驗證通過")
logger.Infof("✓ API Key verification passed")
return nil
}
// GenerateAndRegisterAPIKey 生成新的 API Key 並註冊到 LIGHTER
// 注意:這需要 L1 私鑰簽名,所以必須在有 L1 私鑰的情況下調用
// GenerateAndRegisterAPIKey Generate new API Key and register to LIGHTER
// Note: This requires L1 private key signature, so must be called with L1 private key available
func (t *LighterTraderV2) GenerateAndRegisterAPIKey(seed string) (privateKey, publicKey string, err error) {
// 這個功能需要調用官方 SDK GenerateAPIKey 函數
// 但這是在 sharedlib 中的 CGO 函數,無法直接在純 Go 代碼中調用
// This function needs to call the official SDK's GenerateAPIKey function
// But this is a CGO function in sharedlib, cannot be called directly in pure Go code
//
// 解決方案:
// 1. 讓用戶從 LIGHTER 官網生成 API Key
// 2. 或者我們可以實現一個簡單的 API Key 生成包裝器
// Solutions:
// 1. Let users generate API Key from LIGHTER website
// 2. Or we can implement a simple API Key generation wrapper
return "", "", fmt.Errorf("GenerateAndRegisterAPIKey 功能待實現,請從 LIGHTER 官網生成 API Key")
return "", "", fmt.Errorf("GenerateAndRegisterAPIKey feature not implemented yet, please generate API Key from LIGHTER website")
}
// refreshAuthToken 刷新認證令牌(使用官方 SDK
// refreshAuthToken Refresh authentication token (using official SDK)
func (t *LighterTraderV2) refreshAuthToken() error {
if t.txClient == nil {
return fmt.Errorf("TxClient 未初始化,請先設置 API Key")
return fmt.Errorf("TxClient not initialized, please set API Key first")
}
// 使用官方 SDK 生成認證令牌(有效期 7 小時)
// Generate auth token using official SDK (valid for 7 hours)
deadline := time.Now().Add(7 * time.Hour)
authToken, err := t.txClient.GetAuthToken(deadline)
if err != nil {
return fmt.Errorf("生成認證令牌失敗: %w", err)
return fmt.Errorf("failed to generate auth token: %w", err)
}
t.accountMutex.Lock()
@@ -249,31 +249,31 @@ func (t *LighterTraderV2) refreshAuthToken() error {
t.tokenExpiry = deadline
t.accountMutex.Unlock()
logger.Infof("✓ 認證令牌已生成(有效期至: %s", t.tokenExpiry.Format(time.RFC3339))
logger.Infof("✓ Auth token generated (valid until: %s)", t.tokenExpiry.Format(time.RFC3339))
return nil
}
// ensureAuthToken 確保認證令牌有效
// ensureAuthToken Ensure authentication token is valid
func (t *LighterTraderV2) ensureAuthToken() error {
t.accountMutex.RLock()
expired := time.Now().After(t.tokenExpiry.Add(-30 * time.Minute)) // 提前 30 分鐘刷新
expired := time.Now().After(t.tokenExpiry.Add(-30 * time.Minute)) // Refresh 30 minutes early
t.accountMutex.RUnlock()
if expired {
logger.Info("🔄 認證令牌即將過期,刷新中...")
logger.Info("🔄 Auth token about to expire, refreshing...")
return t.refreshAuthToken()
}
return nil
}
// GetExchangeType 獲取交易所類型
// GetExchangeType Get exchange type
func (t *LighterTraderV2) GetExchangeType() string {
return "lighter"
}
// Cleanup 清理資源
// Cleanup Clean up resources
func (t *LighterTraderV2) Cleanup() error {
logger.Info("⏹ LIGHTER 交易器清理完成")
logger.Info("⏹ LIGHTER trader cleanup completed")
return nil
}
+20 -20
View File
@@ -7,7 +7,7 @@ import (
"net/http"
)
// GetBalance 獲取賬戶余額(實現 Trader 接口)
// GetBalance Get account balance (implements Trader interface)
func (t *LighterTraderV2) GetBalance() (map[string]interface{}, error) {
balance, err := t.GetAccountBalance()
if err != nil {
@@ -23,10 +23,10 @@ func (t *LighterTraderV2) GetBalance() (map[string]interface{}, error) {
}, nil
}
// GetAccountBalance 獲取賬戶詳細余額信息
// GetAccountBalance Get detailed account balance information
func (t *LighterTraderV2) GetAccountBalance() (*AccountBalance, error) {
if err := t.ensureAuthToken(); err != nil {
return nil, fmt.Errorf("認證令牌無效: %w", err)
return nil, fmt.Errorf("invalid auth token: %w", err)
}
t.accountMutex.RLock()
@@ -41,7 +41,7 @@ func (t *LighterTraderV2) GetAccountBalance() (*AccountBalance, error) {
return nil, err
}
// 添加認證頭
// Add authentication header
req.Header.Set("Authorization", authToken)
resp, err := t.client.Do(req)
@@ -56,18 +56,18 @@ func (t *LighterTraderV2) GetAccountBalance() (*AccountBalance, error) {
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("獲取余額失敗 (status %d): %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("failed to get balance (status %d): %s", resp.StatusCode, string(body))
}
var balance AccountBalance
if err := json.Unmarshal(body, &balance); err != nil {
return nil, fmt.Errorf("解析余額響應失敗: %w", err)
return nil, fmt.Errorf("failed to parse balance response: %w", err)
}
return &balance, nil
}
// GetPositions 獲取所有持倉(實現 Trader 接口)
// GetPositions Get all positions (implements Trader interface)
func (t *LighterTraderV2) GetPositions() ([]map[string]interface{}, error) {
positions, err := t.GetPositionsRaw("")
if err != nil {
@@ -92,10 +92,10 @@ func (t *LighterTraderV2) GetPositions() ([]map[string]interface{}, error) {
return result, nil
}
// GetPositionsRaw 獲取所有持倉(返回原始類型)
// GetPositionsRaw Get all positions (returns raw type)
func (t *LighterTraderV2) GetPositionsRaw(symbol string) ([]Position, error) {
if err := t.ensureAuthToken(); err != nil {
return nil, fmt.Errorf("認證令牌無效: %w", err)
return nil, fmt.Errorf("invalid auth token: %w", err)
}
t.accountMutex.RLock()
@@ -127,18 +127,18 @@ func (t *LighterTraderV2) GetPositionsRaw(symbol string) ([]Position, error) {
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("獲取持倉失敗 (status %d): %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("failed to get positions (status %d): %s", resp.StatusCode, string(body))
}
var positions []Position
if err := json.Unmarshal(body, &positions); err != nil {
return nil, fmt.Errorf("解析持倉響應失敗: %w", err)
return nil, fmt.Errorf("failed to parse positions response: %w", err)
}
return positions, nil
}
// GetPosition 獲取指定幣種的持倉
// GetPosition Get position for specified symbol
func (t *LighterTraderV2) GetPosition(symbol string) (*Position, error) {
positions, err := t.GetPositionsRaw(symbol)
if err != nil {
@@ -151,10 +151,10 @@ func (t *LighterTraderV2) GetPosition(symbol string) (*Position, error) {
}
}
return nil, nil // 無持倉
return nil, nil // No position
}
// GetMarketPrice 獲取市場價格(實現 Trader 接口)
// GetMarketPrice Get market price (implements Trader interface)
func (t *LighterTraderV2) GetMarketPrice(symbol string) (float64, error) {
endpoint := fmt.Sprintf("%s/api/v1/market/ticker?symbol=%s", t.baseURL, symbol)
@@ -175,25 +175,25 @@ func (t *LighterTraderV2) GetMarketPrice(symbol string) (float64, error) {
}
if resp.StatusCode != http.StatusOK {
return 0, fmt.Errorf("獲取市場價格失敗 (status %d): %s", resp.StatusCode, string(body))
return 0, fmt.Errorf("failed to get market price (status %d): %s", resp.StatusCode, string(body))
}
var ticker map[string]interface{}
if err := json.Unmarshal(body, &ticker); err != nil {
return 0, fmt.Errorf("解析價格響應失敗: %w", err)
return 0, fmt.Errorf("failed to parse price response: %w", err)
}
price, err := SafeFloat64(ticker, "last_price")
if err != nil {
return 0, fmt.Errorf("無法獲取價格: %w", err)
return 0, fmt.Errorf("failed to get price: %w", err)
}
return price, nil
}
// FormatQuantity 格式化數量到正確的精度(實現 Trader 接口)
// FormatQuantity Format quantity to correct precision (implements Trader interface)
func (t *LighterTraderV2) FormatQuantity(symbol string, quantity float64) (string, error) {
// TODO: 從 API 獲取幣種精度
// 暫時使用默認精度
// TODO: Get symbol precision from API
// Using default precision for now
return fmt.Sprintf("%.4f", quantity), nil
}
+84 -84
View File
@@ -12,92 +12,92 @@ import (
"github.com/elliottech/lighter-go/types"
)
// SetStopLoss 設置止損單(實現 Trader 接口)
// SetStopLoss Set stop-loss order (implements Trader interface)
func (t *LighterTraderV2) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
if t.txClient == nil {
return fmt.Errorf("TxClient 未初始化")
return fmt.Errorf("TxClient not initialized")
}
logger.Infof("🛑 LIGHTER 設置止損: %s %s qty=%.4f, stop=%.2f", symbol, positionSide, quantity, stopPrice)
logger.Infof("🛑 LIGHTER Setting stop-loss: %s %s qty=%.4f, stop=%.2f", symbol, positionSide, quantity, stopPrice)
// 確定訂單方向(做空止損用買單,做多止損用賣單)
// Determine order direction (short position uses buy order, long position uses sell order)
isAsk := (positionSide == "LONG" || positionSide == "long")
// 創建限價止損單
// Create limit stop-loss order
_, err := t.CreateOrder(symbol, isAsk, quantity, stopPrice, "limit")
if err != nil {
return fmt.Errorf("設置止損失敗: %w", err)
return fmt.Errorf("failed to set stop-loss: %w", err)
}
logger.Infof("✓ LIGHTER 止損已設置: %.2f", stopPrice)
logger.Infof("✓ LIGHTER stop-loss set: %.2f", stopPrice)
return nil
}
// SetTakeProfit 設置止盈單(實現 Trader 接口)
// SetTakeProfit Set take-profit order (implements Trader interface)
func (t *LighterTraderV2) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
if t.txClient == nil {
return fmt.Errorf("TxClient 未初始化")
return fmt.Errorf("TxClient not initialized")
}
logger.Infof("🎯 LIGHTER 設置止盈: %s %s qty=%.4f, tp=%.2f", symbol, positionSide, quantity, takeProfitPrice)
logger.Infof("🎯 LIGHTER Setting take-profit: %s %s qty=%.4f, tp=%.2f", symbol, positionSide, quantity, takeProfitPrice)
// 確定訂單方向(做空止盈用買單,做多止盈用賣單)
// Determine order direction (short position uses buy order, long position uses sell order)
isAsk := (positionSide == "LONG" || positionSide == "long")
// 創建限價止盈單
// Create limit take-profit order
_, err := t.CreateOrder(symbol, isAsk, quantity, takeProfitPrice, "limit")
if err != nil {
return fmt.Errorf("設置止盈失敗: %w", err)
return fmt.Errorf("failed to set take-profit: %w", err)
}
logger.Infof("✓ LIGHTER 止盈已設置: %.2f", takeProfitPrice)
logger.Infof("✓ LIGHTER take-profit set: %.2f", takeProfitPrice)
return nil
}
// CancelAllOrders 取消所有訂單(實現 Trader 接口)
// CancelAllOrders Cancel all orders (implements Trader interface)
func (t *LighterTraderV2) CancelAllOrders(symbol string) error {
if t.txClient == nil {
return fmt.Errorf("TxClient 未初始化")
return fmt.Errorf("TxClient not initialized")
}
if err := t.ensureAuthToken(); err != nil {
return fmt.Errorf("認證令牌無效: %w", err)
return fmt.Errorf("invalid auth token: %w", err)
}
// 獲取所有活躍訂單
// Get all active orders
orders, err := t.GetActiveOrders(symbol)
if err != nil {
return fmt.Errorf("獲取活躍訂單失敗: %w", err)
return fmt.Errorf("failed to get active orders: %w", err)
}
if len(orders) == 0 {
logger.Infof("✓ LIGHTER - 無需取消訂單(無活躍訂單)")
logger.Infof("✓ LIGHTER - No orders to cancel (no active orders)")
return nil
}
// 批量取消
// Batch cancel
canceledCount := 0
for _, order := range orders {
if err := t.CancelOrder(symbol, order.OrderID); err != nil {
logger.Infof("⚠️ 取消訂單失敗 (ID: %s): %v", order.OrderID, err)
logger.Infof("⚠️ Failed to cancel order (ID: %s): %v", order.OrderID, err)
} else {
canceledCount++
}
}
logger.Infof("✓ LIGHTER - 已取消 %d 個訂單", canceledCount)
logger.Infof("✓ LIGHTER - Canceled %d orders", canceledCount)
return nil
}
// GetOrderStatus 獲取訂單狀態(實現 Trader 接口)
// GetOrderStatus Get order status (implements Trader interface)
func (t *LighterTraderV2) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
// LIGHTER 使用市價單通常立即成交
// 嘗試查詢訂單狀態
// LIGHTER market orders are usually filled immediately
// Try to query order status
if err := t.ensureAuthToken(); err != nil {
return nil, fmt.Errorf("認證令牌無效: %w", err)
return nil, fmt.Errorf("invalid auth token: %w", err)
}
// 構建請求 URL
// Build request URL
endpoint := fmt.Sprintf("%s/api/v1/order/%s", t.baseURL, orderID)
req, err := http.NewRequest("GET", endpoint, nil)
@@ -110,7 +110,7 @@ func (t *LighterTraderV2) GetOrderStatus(symbol string, orderID string) (map[str
resp, err := t.client.Do(req)
if err != nil {
// 如果查詢失敗,假設訂單已完成
// If query fails, assume order is filled
return map[string]interface{}{
"orderId": orderID,
"status": "FILLED",
@@ -143,7 +143,7 @@ func (t *LighterTraderV2) GetOrderStatus(symbol string, orderID string) (map[str
}, nil
}
// 轉換狀態為統一格式
// Convert status to unified format
unifiedStatus := order.Status
switch order.Status {
case "filled":
@@ -163,89 +163,89 @@ func (t *LighterTraderV2) GetOrderStatus(symbol string, orderID string) (map[str
}, nil
}
// CancelStopLossOrders 僅取消止損單(實現 Trader 接口)
// CancelStopLossOrders Cancel only stop-loss orders (implements Trader interface)
func (t *LighterTraderV2) CancelStopLossOrders(symbol string) error {
// LIGHTER 暫時無法區分止損和止盈單,取消所有止盈止損單
logger.Infof("⚠️ LIGHTER 無法區分止損/止盈單,將取消所有止盈止損單")
// LIGHTER cannot distinguish between stop-loss and take-profit orders yet, will cancel all stop orders
logger.Infof("⚠️ LIGHTER cannot distinguish stop-loss/take-profit orders, will cancel all stop orders")
return t.CancelStopOrders(symbol)
}
// CancelTakeProfitOrders 僅取消止盈單(實現 Trader 接口)
// CancelTakeProfitOrders Cancel only take-profit orders (implements Trader interface)
func (t *LighterTraderV2) CancelTakeProfitOrders(symbol string) error {
// LIGHTER 暫時無法區分止損和止盈單,取消所有止盈止損單
logger.Infof("⚠️ LIGHTER 無法區分止損/止盈單,將取消所有止盈止損單")
// LIGHTER cannot distinguish between stop-loss and take-profit orders yet, will cancel all stop orders
logger.Infof("⚠️ LIGHTER cannot distinguish stop-loss/take-profit orders, will cancel all stop orders")
return t.CancelStopOrders(symbol)
}
// CancelStopOrders 取消該幣種的止盈/止損單(實現 Trader 接口)
// CancelStopOrders Cancel stop-loss/take-profit orders for this symbol (implements Trader interface)
func (t *LighterTraderV2) CancelStopOrders(symbol string) error {
if t.txClient == nil {
return fmt.Errorf("TxClient 未初始化")
return fmt.Errorf("TxClient not initialized")
}
if err := t.ensureAuthToken(); err != nil {
return fmt.Errorf("認證令牌無效: %w", err)
return fmt.Errorf("invalid auth token: %w", err)
}
// 獲取活躍訂單
// Get active orders
orders, err := t.GetActiveOrders(symbol)
if err != nil {
return fmt.Errorf("獲取活躍訂單失敗: %w", err)
return fmt.Errorf("failed to get active orders: %w", err)
}
canceledCount := 0
for _, order := range orders {
// TODO: 檢查訂單類型,只取消止盈止損單
// 暫時取消所有訂單
// TODO: Check order type, only cancel stop orders
// For now, cancel all orders
if err := t.CancelOrder(symbol, order.OrderID); err != nil {
logger.Infof("⚠️ 取消訂單失敗 (ID: %s): %v", order.OrderID, err)
logger.Infof("⚠️ Failed to cancel order (ID: %s): %v", order.OrderID, err)
} else {
canceledCount++
}
}
logger.Infof("✓ LIGHTER - 已取消 %d 個止盈止損單", canceledCount)
logger.Infof("✓ LIGHTER - Canceled %d stop orders", canceledCount)
return nil
}
// GetActiveOrders 獲取活躍訂單
// GetActiveOrders Get active orders
func (t *LighterTraderV2) GetActiveOrders(symbol string) ([]OrderResponse, error) {
if err := t.ensureAuthToken(); err != nil {
return nil, fmt.Errorf("認證令牌無效: %w", err)
return nil, fmt.Errorf("invalid auth token: %w", err)
}
// 獲取市場索引
// Get market index
marketIndex, err := t.getMarketIndex(symbol)
if err != nil {
return nil, fmt.Errorf("獲取市場索引失敗: %w", err)
return nil, fmt.Errorf("failed to get market index: %w", err)
}
// 構建請求 URL
// Build request URL
endpoint := fmt.Sprintf("%s/api/v1/accountActiveOrders?account_index=%d&market_id=%d",
t.baseURL, t.accountIndex, marketIndex)
// 發送 GET 請求
// Send GET request
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("創建請求失敗: %w", err)
return nil, fmt.Errorf("failed to create request: %w", err)
}
// 添加認證頭
// Add authentication header
req.Header.Set("Authorization", t.authToken)
req.Header.Set("Content-Type", "application/json")
resp, err := t.client.Do(req)
if err != nil {
return nil, fmt.Errorf("請求失敗: %w", err)
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("讀取響應失敗: %w", err)
return nil, fmt.Errorf("failed to read response: %w", err)
}
// 解析響應
// Parse response
var apiResp struct {
Code int `json:"code"`
Message string `json:"message"`
@@ -253,83 +253,83 @@ func (t *LighterTraderV2) GetActiveOrders(symbol string) ([]OrderResponse, error
}
if err := json.Unmarshal(body, &apiResp); err != nil {
return nil, fmt.Errorf("解析響應失敗: %w, body: %s", err, string(body))
return nil, fmt.Errorf("failed to parse response: %w, body: %s", err, string(body))
}
if apiResp.Code != 200 {
return nil, fmt.Errorf("獲取活躍訂單失敗 (code %d): %s", apiResp.Code, apiResp.Message)
return nil, fmt.Errorf("failed to get active orders (code %d): %s", apiResp.Code, apiResp.Message)
}
logger.Infof("✓ LIGHTER - 獲取到 %d 個活躍訂單", len(apiResp.Data))
logger.Infof("✓ LIGHTER - Retrieved %d active orders", len(apiResp.Data))
return apiResp.Data, nil
}
// CancelOrder 取消單個訂單
// CancelOrder Cancel a single order
func (t *LighterTraderV2) CancelOrder(symbol, orderID string) error {
if t.txClient == nil {
return fmt.Errorf("TxClient 未初始化")
return fmt.Errorf("TxClient not initialized")
}
// 獲取市場索引
// Get market index
marketIndex, err := t.getMarketIndex(symbol)
if err != nil {
return fmt.Errorf("獲取市場索引失敗: %w", err)
return fmt.Errorf("failed to get market index: %w", err)
}
// orderID 轉換為 int64
// Convert orderID to int64
orderIndex, err := strconv.ParseInt(orderID, 10, 64)
if err != nil {
return fmt.Errorf("無效的訂單ID: %w", err)
return fmt.Errorf("invalid order ID: %w", err)
}
// 構建取消訂單請求
// Build cancel order request
txReq := &types.CancelOrderTxReq{
MarketIndex: marketIndex,
Index: orderIndex,
}
// 使用 SDK 簽名交易
nonce := int64(-1) // -1 表示自動獲取
// Sign transaction using SDK
nonce := int64(-1) // -1 means auto-fetch
tx, err := t.txClient.GetCancelOrderTransaction(txReq, &types.TransactOpts{
Nonce: &nonce,
})
if err != nil {
return fmt.Errorf("簽名取消訂單失敗: %w", err)
return fmt.Errorf("failed to sign cancel order: %w", err)
}
// 序列化交易
// Serialize transaction
txBytes, err := json.Marshal(tx)
if err != nil {
return fmt.Errorf("序列化交易失敗: %w", err)
return fmt.Errorf("failed to serialize transaction: %w", err)
}
// 提交取消訂單到 LIGHTER API
// Submit cancel order to LIGHTER API
_, err = t.submitCancelOrder(txBytes)
if err != nil {
return fmt.Errorf("提交取消訂單失敗: %w", err)
return fmt.Errorf("failed to submit cancel order: %w", err)
}
logger.Infof("✓ LIGHTER訂單已取消 - ID: %s", orderID)
logger.Infof("✓ LIGHTER order canceled - ID: %s", orderID)
return nil
}
// submitCancelOrder 提交已簽名的取消訂單到 LIGHTER API
// submitCancelOrder Submit signed cancel order to LIGHTER API
func (t *LighterTraderV2) submitCancelOrder(signedTx []byte) (map[string]interface{}, error) {
const TX_TYPE_CANCEL_ORDER = 15
// 構建請求
// Build request
req := SendTxRequest{
TxType: TX_TYPE_CANCEL_ORDER,
TxInfo: string(signedTx),
PriceProtection: false, // 取消訂單不需要價格保護
PriceProtection: false, // Cancel order doesn't need price protection
}
reqBody, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("序列化請求失敗: %w", err)
return nil, fmt.Errorf("failed to serialize request: %w", err)
}
// 發送 POST 請求到 /api/v1/sendTx
// Send POST request to /api/v1/sendTx
endpoint := fmt.Sprintf("%s/api/v1/sendTx", t.baseURL)
httpReq, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(reqBody))
if err != nil {
@@ -349,15 +349,15 @@ func (t *LighterTraderV2) submitCancelOrder(signedTx []byte) (map[string]interfa
return nil, err
}
// 解析響應
// Parse response
var sendResp SendTxResponse
if err := json.Unmarshal(body, &sendResp); err != nil {
return nil, fmt.Errorf("解析響應失敗: %w, body: %s", err, string(body))
return nil, fmt.Errorf("failed to parse response: %w, body: %s", err, string(body))
}
// 檢查響應碼
// Check response code
if sendResp.Code != 200 {
return nil, fmt.Errorf("提交取消訂單失敗 (code %d): %s", sendResp.Code, sendResp.Message)
return nil, fmt.Errorf("failed to submit cancel order (code %d): %s", sendResp.Code, sendResp.Message)
}
result := map[string]interface{}{
@@ -365,6 +365,6 @@ func (t *LighterTraderV2) submitCancelOrder(signedTx []byte) (map[string]interfa
"status": "cancelled",
}
logger.Infof("✓ 取消訂單已提交到 LIGHTER - tx_hash: %v", sendResp.Data["tx_hash"])
logger.Infof("✓ Cancel order submitted to LIGHTER - tx_hash: %v", sendResp.Data["tx_hash"])
return result, nil
}
+106 -106
View File
@@ -12,32 +12,32 @@ import (
"github.com/elliottech/lighter-go/types"
)
// OpenLong 開多倉(實現 Trader 接口)
// OpenLong Open long position (implements Trader interface)
func (t *LighterTraderV2) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
if t.txClient == nil {
return nil, fmt.Errorf("TxClient 未初始化,請先設置 API Key")
return nil, fmt.Errorf("TxClient not initialized, please set API Key first")
}
logger.Infof("📈 LIGHTER 開多倉: %s, qty=%.4f, leverage=%dx", symbol, quantity, leverage)
logger.Infof("📈 LIGHTER opening long: %s, qty=%.4f, leverage=%dx", symbol, quantity, leverage)
// 1. 設置杠杆(如果需要)
// 1. Set leverage (if needed)
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Infof("⚠️ 設置杠杆失敗: %v", err)
logger.Infof("⚠️ Failed to set leverage: %v", err)
}
// 2. 獲取市場價格
// 2. Get market price
marketPrice, err := t.GetMarketPrice(symbol)
if err != nil {
return nil, fmt.Errorf("獲取市場價格失敗: %w", err)
return nil, fmt.Errorf("failed to get market price: %w", err)
}
// 3. 創建市價買入單(開多)
// 3. Create market buy order (open long)
orderResult, err := t.CreateOrder(symbol, false, quantity, 0, "market")
if err != nil {
return nil, fmt.Errorf("開多倉失敗: %w", err)
return nil, fmt.Errorf("failed to open long: %w", err)
}
logger.Infof("✓ LIGHTER 開多倉成功: %s @ %.2f", symbol, marketPrice)
logger.Infof("✓ LIGHTER opened long successfully: %s @ %.2f", symbol, marketPrice)
return map[string]interface{}{
"orderId": orderResult["orderId"],
@@ -48,32 +48,32 @@ func (t *LighterTraderV2) OpenLong(symbol string, quantity float64, leverage int
}, nil
}
// OpenShort 開空倉(實現 Trader 接口)
// OpenShort Open short position (implements Trader interface)
func (t *LighterTraderV2) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
if t.txClient == nil {
return nil, fmt.Errorf("TxClient 未初始化,請先設置 API Key")
return nil, fmt.Errorf("TxClient not initialized, please set API Key first")
}
logger.Infof("📉 LIGHTER 開空倉: %s, qty=%.4f, leverage=%dx", symbol, quantity, leverage)
logger.Infof("📉 LIGHTER opening short: %s, qty=%.4f, leverage=%dx", symbol, quantity, leverage)
// 1. 設置杠杆
// 1. Set leverage
if err := t.SetLeverage(symbol, leverage); err != nil {
logger.Infof("⚠️ 設置杠杆失敗: %v", err)
logger.Infof("⚠️ Failed to set leverage: %v", err)
}
// 2. 獲取市場價格
// 2. Get market price
marketPrice, err := t.GetMarketPrice(symbol)
if err != nil {
return nil, fmt.Errorf("獲取市場價格失敗: %w", err)
return nil, fmt.Errorf("failed to get market price: %w", err)
}
// 3. 創建市價賣出單(開空)
// 3. Create market sell order (open short)
orderResult, err := t.CreateOrder(symbol, true, quantity, 0, "market")
if err != nil {
return nil, fmt.Errorf("開空倉失敗: %w", err)
return nil, fmt.Errorf("failed to open short: %w", err)
}
logger.Infof("✓ LIGHTER 開空倉成功: %s @ %.2f", symbol, marketPrice)
logger.Infof("✓ LIGHTER opened short successfully: %s @ %.2f", symbol, marketPrice)
return map[string]interface{}{
"orderId": orderResult["orderId"],
@@ -84,17 +84,17 @@ func (t *LighterTraderV2) OpenShort(symbol string, quantity float64, leverage in
}, nil
}
// CloseLong 平多倉(實現 Trader 接口)
// CloseLong Close long position (implements Trader interface)
func (t *LighterTraderV2) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
if t.txClient == nil {
return nil, fmt.Errorf("TxClient 未初始化")
return nil, fmt.Errorf("TxClient not initialized")
}
// 如果 quantity=0,獲取當前持倉數量
// If quantity=0, get current position quantity
if quantity == 0 {
pos, err := t.GetPosition(symbol)
if err != nil {
return nil, fmt.Errorf("獲取持倉失敗: %w", err)
return nil, fmt.Errorf("failed to get position: %w", err)
}
if pos == nil || pos.Size == 0 {
return map[string]interface{}{
@@ -105,20 +105,20 @@ func (t *LighterTraderV2) CloseLong(symbol string, quantity float64) (map[string
quantity = pos.Size
}
logger.Infof("🔻 LIGHTER 平多倉: %s, qty=%.4f", symbol, quantity)
logger.Infof("🔻 LIGHTER closing long: %s, qty=%.4f", symbol, quantity)
// 創建市價賣出單平倉(reduceOnly=true
// Create market sell order to close (reduceOnly=true)
orderResult, err := t.CreateOrder(symbol, true, quantity, 0, "market")
if err != nil {
return nil, fmt.Errorf("平多倉失敗: %w", err)
return nil, fmt.Errorf("failed to close long: %w", err)
}
// 平倉後取消所有掛單
// Cancel all open orders after closing position
if err := t.CancelAllOrders(symbol); err != nil {
logger.Infof("⚠️ 取消掛單失敗: %v", err)
logger.Infof("⚠️ Failed to cancel orders: %v", err)
}
logger.Infof("✓ LIGHTER 平多倉成功: %s", symbol)
logger.Infof("✓ LIGHTER closed long successfully: %s", symbol)
return map[string]interface{}{
"orderId": orderResult["orderId"],
@@ -127,17 +127,17 @@ func (t *LighterTraderV2) CloseLong(symbol string, quantity float64) (map[string
}, nil
}
// CloseShort 平空倉(實現 Trader 接口)
// CloseShort Close short position (implements Trader interface)
func (t *LighterTraderV2) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
if t.txClient == nil {
return nil, fmt.Errorf("TxClient 未初始化")
return nil, fmt.Errorf("TxClient not initialized")
}
// 如果 quantity=0,獲取當前持倉數量
// If quantity=0, get current position quantity
if quantity == 0 {
pos, err := t.GetPosition(symbol)
if err != nil {
return nil, fmt.Errorf("獲取持倉失敗: %w", err)
return nil, fmt.Errorf("failed to get position: %w", err)
}
if pos == nil || pos.Size == 0 {
return map[string]interface{}{
@@ -148,20 +148,20 @@ func (t *LighterTraderV2) CloseShort(symbol string, quantity float64) (map[strin
quantity = pos.Size
}
logger.Infof("🔺 LIGHTER 平空倉: %s, qty=%.4f", symbol, quantity)
logger.Infof("🔺 LIGHTER closing short: %s, qty=%.4f", symbol, quantity)
// 創建市價買入單平倉(reduceOnly=true
// Create market buy order to close (reduceOnly=true)
orderResult, err := t.CreateOrder(symbol, false, quantity, 0, "market")
if err != nil {
return nil, fmt.Errorf("平空倉失敗: %w", err)
return nil, fmt.Errorf("failed to close short: %w", err)
}
// 平倉後取消所有掛單
// Cancel all open orders after closing position
if err := t.CancelAllOrders(symbol); err != nil {
logger.Infof("⚠️ 取消掛單失敗: %v", err)
logger.Infof("⚠️ Failed to cancel orders: %v", err)
}
logger.Infof("✓ LIGHTER 平空倉成功: %s", symbol)
logger.Infof("✓ LIGHTER closed short successfully: %s", symbol)
return map[string]interface{}{
"orderId": orderResult["orderId"],
@@ -170,31 +170,31 @@ func (t *LighterTraderV2) CloseShort(symbol string, quantity float64) (map[strin
}, nil
}
// CreateOrder 創建訂單(市價或限價)- 使用官方 SDK 簽名
// CreateOrder Create order (market or limit) - uses official SDK for signing
func (t *LighterTraderV2) CreateOrder(symbol string, isAsk bool, quantity float64, price float64, orderType string) (map[string]interface{}, error) {
if t.txClient == nil {
return nil, fmt.Errorf("TxClient 未初始化")
return nil, fmt.Errorf("TxClient not initialized")
}
// 獲取市場索引(需要從 symbol 轉換)
// Get market index (convert from symbol)
marketIndex, err := t.getMarketIndex(symbol)
if err != nil {
return nil, fmt.Errorf("獲取市場索引失敗: %w", err)
return nil, fmt.Errorf("failed to get market index: %w", err)
}
// 構建訂單請求
clientOrderIndex := time.Now().UnixNano() // 使用時間戳作為客戶端訂單ID
// Build order request
clientOrderIndex := time.Now().UnixNano() // Use timestamp as client order ID
var orderTypeValue uint8 = 0 // 0=limit, 1=market
if orderType == "market" {
orderTypeValue = 1
}
// 將數量和價格轉換為LIGHTER格式(需要乘以精度)
baseAmount := int64(quantity * 1e8) // 8位小數精度
// Convert quantity and price to LIGHTER format (multiply by precision)
baseAmount := int64(quantity * 1e8) // 8 decimal precision
priceValue := uint32(0)
if orderType == "limit" {
priceValue = uint32(price * 1e2) // 價格精度
priceValue = uint32(price * 1e2) // Price precision
}
txReq := &types.CreateOrderTxReq{
@@ -205,60 +205,60 @@ func (t *LighterTraderV2) CreateOrder(symbol string, isAsk bool, quantity float6
IsAsk: boolToUint8(isAsk),
Type: orderTypeValue,
TimeInForce: 0, // GTC
ReduceOnly: 0, // 不只減倉
ReduceOnly: 0, // Not reduce-only
TriggerPrice: 0,
OrderExpiry: time.Now().Add(24 * 28 * time.Hour).UnixMilli(), // 28天後過期
OrderExpiry: time.Now().Add(24 * 28 * time.Hour).UnixMilli(), // Expires in 28 days
}
// 使用SDK簽名交易(nonce會自動獲取)
nonce := int64(-1) // -1表示自動獲取
// Sign transaction using SDK (nonce will be auto-fetched)
nonce := int64(-1) // -1 means auto-fetch
tx, err := t.txClient.GetCreateOrderTransaction(txReq, &types.TransactOpts{
Nonce: &nonce,
})
if err != nil {
return nil, fmt.Errorf("簽名訂單失敗: %w", err)
return nil, fmt.Errorf("failed to sign order: %w", err)
}
// 序列化交易
// Serialize transaction
txBytes, err := json.Marshal(tx)
if err != nil {
return nil, fmt.Errorf("序列化交易失敗: %w", err)
return nil, fmt.Errorf("failed to serialize transaction: %w", err)
}
// 提交訂單到LIGHTER API
// Submit order to LIGHTER API
orderResp, err := t.submitOrder(txBytes)
if err != nil {
return nil, fmt.Errorf("提交訂單失敗: %w", err)
return nil, fmt.Errorf("failed to submit order: %w", err)
}
side := "buy"
if isAsk {
side = "sell"
}
logger.Infof("✓ LIGHTER訂單已創建: %s %s qty=%.4f", symbol, side, quantity)
logger.Infof("✓ LIGHTER order created: %s %s qty=%.4f", symbol, side, quantity)
return orderResp, nil
}
// SendTxRequest 發送交易請求
// SendTxRequest Send transaction request
type SendTxRequest struct {
TxType int `json:"tx_type"`
TxInfo string `json:"tx_info"`
PriceProtection bool `json:"price_protection,omitempty"`
}
// SendTxResponse 發送交易響應
// SendTxResponse Send transaction response
type SendTxResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data map[string]interface{} `json:"data"`
}
// submitOrder 提交已簽名的訂單到LIGHTER API
// submitOrder Submit signed order to LIGHTER API
func (t *LighterTraderV2) submitOrder(signedTx []byte) (map[string]interface{}, error) {
const TX_TYPE_CREATE_ORDER = 14
// 構建請求
// Build request
req := SendTxRequest{
TxType: TX_TYPE_CREATE_ORDER,
TxInfo: string(signedTx),
@@ -267,10 +267,10 @@ func (t *LighterTraderV2) submitOrder(signedTx []byte) (map[string]interface{},
reqBody, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("序列化請求失敗: %w", err)
return nil, fmt.Errorf("failed to serialize request: %w", err)
}
// 發送 POST 請求到 /api/v1/sendTx
// Send POST request to /api/v1/sendTx
endpoint := fmt.Sprintf("%s/api/v1/sendTx", t.baseURL)
httpReq, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(reqBody))
if err != nil {
@@ -290,39 +290,39 @@ func (t *LighterTraderV2) submitOrder(signedTx []byte) (map[string]interface{},
return nil, err
}
// 解析響應
// Parse response
var sendResp SendTxResponse
if err := json.Unmarshal(body, &sendResp); err != nil {
return nil, fmt.Errorf("解析響應失敗: %w, body: %s", err, string(body))
return nil, fmt.Errorf("failed to parse response: %w, body: %s", err, string(body))
}
// 檢查響應碼
// Check response code
if sendResp.Code != 200 {
return nil, fmt.Errorf("提交訂單失敗 (code %d): %s", sendResp.Code, sendResp.Message)
return nil, fmt.Errorf("failed to submit order (code %d): %s", sendResp.Code, sendResp.Message)
}
// 提取交易哈希和訂單ID
// Extract transaction hash and order ID
result := map[string]interface{}{
"tx_hash": sendResp.Data["tx_hash"],
"status": "submitted",
}
// 如果有訂單ID,添加到結果中
// Add order ID to result if available
if orderID, ok := sendResp.Data["order_id"]; ok {
result["orderId"] = orderID
} else if txHash, ok := sendResp.Data["tx_hash"].(string); ok {
// 使用 tx_hash 作為 orderID
// Use tx_hash as orderID
result["orderId"] = txHash
}
logger.Infof("✓ 訂單已提交到 LIGHTER - tx_hash: %v", sendResp.Data["tx_hash"])
logger.Infof("✓ Order submitted to LIGHTER - tx_hash: %v", sendResp.Data["tx_hash"])
return result, nil
}
// getMarketIndex 獲取市場索引(從symbol轉換)- 動態從API獲取
// getMarketIndex Get market index (convert from symbol) - dynamically fetch from API
func (t *LighterTraderV2) getMarketIndex(symbol string) (uint8, error) {
// 1. 檢查緩存
// 1. Check cache
t.marketMutex.RLock()
if index, ok := t.marketIndexMap[symbol]; ok {
t.marketMutex.RUnlock()
@@ -330,62 +330,62 @@ func (t *LighterTraderV2) getMarketIndex(symbol string) (uint8, error) {
}
t.marketMutex.RUnlock()
// 2. 從 API 獲取市場列表
// 2. Fetch market list from API
markets, err := t.fetchMarketList()
if err != nil {
// 如果 API 失敗,回退到硬編碼映射
logger.Infof("⚠️ 從 API 獲取市場列表失敗,使用硬編碼映射: %v", err)
// If API fails, fallback to hardcoded mapping
logger.Infof("⚠️ Failed to fetch market list from API, using hardcoded mapping: %v", err)
return t.getFallbackMarketIndex(symbol)
}
// 3. 更新緩存
// 3. Update cache
t.marketMutex.Lock()
for _, market := range markets {
t.marketIndexMap[market.Symbol] = market.MarketID
}
t.marketMutex.Unlock()
// 4. 從緩存中獲取
// 4. Get from cache
t.marketMutex.RLock()
index, ok := t.marketIndexMap[symbol]
t.marketMutex.RUnlock()
if !ok {
return 0, fmt.Errorf("未知的市場符號: %s", symbol)
return 0, fmt.Errorf("unknown market symbol: %s", symbol)
}
return index, nil
}
// MarketInfo 市場信息
// MarketInfo Market information
type MarketInfo struct {
Symbol string `json:"symbol"`
MarketID uint8 `json:"market_id"`
}
// fetchMarketList 從 API 獲取市場列表
// fetchMarketList Fetch market list from API
func (t *LighterTraderV2) fetchMarketList() ([]MarketInfo, error) {
endpoint := fmt.Sprintf("%s/api/v1/orderBooks", t.baseURL)
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("創建請求失敗: %w", err)
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := t.client.Do(req)
if err != nil {
return nil, fmt.Errorf("請求失敗: %w", err)
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("讀取響應失敗: %w", err)
return nil, fmt.Errorf("failed to read response: %w", err)
}
// 解析響應
// Parse response
var apiResp struct {
Code int `json:"code"`
Message string `json:"message"`
@@ -396,14 +396,14 @@ func (t *LighterTraderV2) fetchMarketList() ([]MarketInfo, error) {
}
if err := json.Unmarshal(body, &apiResp); err != nil {
return nil, fmt.Errorf("解析響應失敗: %w", err)
return nil, fmt.Errorf("failed to parse response: %w", err)
}
if apiResp.Code != 200 {
return nil, fmt.Errorf("獲取市場列表失敗 (code %d): %s", apiResp.Code, apiResp.Message)
return nil, fmt.Errorf("failed to get market list (code %d): %s", apiResp.Code, apiResp.Message)
}
// 轉換為 MarketInfo 列表
// Convert to MarketInfo list
markets := make([]MarketInfo, len(apiResp.Data))
for i, market := range apiResp.Data {
markets[i] = MarketInfo{
@@ -412,11 +412,11 @@ func (t *LighterTraderV2) fetchMarketList() ([]MarketInfo, error) {
}
}
logger.Infof("✓ 獲取到 %d 個市場", len(markets))
logger.Infof("✓ Retrieved %d markets", len(markets))
return markets, nil
}
// getFallbackMarketIndex 硬編碼的回退映射
// getFallbackMarketIndex Hardcoded fallback mapping
func (t *LighterTraderV2) getFallbackMarketIndex(symbol string) (uint8, error) {
fallbackMap := map[string]uint8{
"BTC-PERP": 0,
@@ -428,43 +428,43 @@ func (t *LighterTraderV2) getFallbackMarketIndex(symbol string) (uint8, error) {
}
if index, ok := fallbackMap[symbol]; ok {
logger.Infof("✓ 使用硬編碼市場索引: %s -> %d", symbol, index)
logger.Infof("✓ Using hardcoded market index: %s -> %d", symbol, index)
return index, nil
}
return 0, fmt.Errorf("未知的市場符號: %s", symbol)
return 0, fmt.Errorf("unknown market symbol: %s", symbol)
}
// SetLeverage 設置杠杆(實現 Trader 接口)
// SetLeverage Set leverage (implements Trader interface)
func (t *LighterTraderV2) SetLeverage(symbol string, leverage int) error {
if t.txClient == nil {
return fmt.Errorf("TxClient 未初始化")
return fmt.Errorf("TxClient not initialized")
}
// TODO: 使用SDK簽名並提交SetLeverage交易
logger.Infof("⚙️ 設置杠杆: %s = %dx", symbol, leverage)
// TODO: Sign and submit SetLeverage transaction using SDK
logger.Infof("⚙️ Setting leverage: %s = %dx", symbol, leverage)
return nil // 暫時返回成功
return nil // Return success for now
}
// SetMarginMode 設置倉位模式(實現 Trader 接口)
// SetMarginMode Set margin mode (implements Trader interface)
func (t *LighterTraderV2) SetMarginMode(symbol string, isCrossMargin bool) error {
if t.txClient == nil {
return fmt.Errorf("TxClient 未初始化")
return fmt.Errorf("TxClient not initialized")
}
modeStr := "逐倉"
modeStr := "isolated"
if isCrossMargin {
modeStr = "全倉"
modeStr = "cross"
}
logger.Infof("⚙️ 設置倉位模式: %s = %s", symbol, modeStr)
logger.Infof("⚙️ Setting margin mode: %s = %s", symbol, modeStr)
// TODO: 使用SDK簽名並提交SetMarginMode交易
// TODO: Sign and submit SetMarginMode transaction using SDK
return nil
}
// boolToUint8 將布爾值轉換為uint8
// boolToUint8 Convert boolean to uint8
func boolToUint8(b bool) uint8 {
if b {
return 1
+46 -46
View File
@@ -5,15 +5,15 @@ import (
"nofx/logger"
)
// OpenLong 开多仓
// OpenLong Open long position
func (t *LighterTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
// TODO: 实现完整的开多仓逻辑
logger.Infof("🚧 LIGHTER OpenLong 暂未完全实现 (symbol=%s, qty=%.4f, leverage=%d)", symbol, quantity, leverage)
// TODO: Implement complete open long logic
logger.Infof("🚧 LIGHTER OpenLong not fully implemented (symbol=%s, qty=%.4f, leverage=%d)", symbol, quantity, leverage)
// 使用市价买入单
// Use market buy order
orderID, err := t.CreateOrder(symbol, "buy", quantity, 0, "market")
if err != nil {
return nil, fmt.Errorf("开多仓失败: %w", err)
return nil, fmt.Errorf("failed to open long: %w", err)
}
return map[string]interface{}{
@@ -23,15 +23,15 @@ func (t *LighterTrader) OpenLong(symbol string, quantity float64, leverage int)
}, nil
}
// OpenShort 开空仓
// OpenShort Open short position
func (t *LighterTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
// TODO: 实现完整的开空仓逻辑
logger.Infof("🚧 LIGHTER OpenShort 暂未完全实现 (symbol=%s, qty=%.4f, leverage=%d)", symbol, quantity, leverage)
// TODO: Implement complete open short logic
logger.Infof("🚧 LIGHTER OpenShort not fully implemented (symbol=%s, qty=%.4f, leverage=%d)", symbol, quantity, leverage)
// 使用市价卖出单
// Use market sell order
orderID, err := t.CreateOrder(symbol, "sell", quantity, 0, "market")
if err != nil {
return nil, fmt.Errorf("开空仓失败: %w", err)
return nil, fmt.Errorf("failed to open short: %w", err)
}
return map[string]interface{}{
@@ -41,13 +41,13 @@ func (t *LighterTrader) OpenShort(symbol string, quantity float64, leverage int)
}, nil
}
// CloseLong 平多仓(quantity=0表示全部平仓)
// CloseLong Close long position (quantity=0 means close all)
func (t *LighterTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
// 如果quantity=0,获取当前持仓数量
// If quantity=0, get current position size
if quantity == 0 {
pos, err := t.GetPosition(symbol)
if err != nil {
return nil, fmt.Errorf("获取持仓失败: %w", err)
return nil, fmt.Errorf("failed to get position: %w", err)
}
if pos == nil || pos.Size == 0 {
return map[string]interface{}{
@@ -58,15 +58,15 @@ func (t *LighterTrader) CloseLong(symbol string, quantity float64) (map[string]i
quantity = pos.Size
}
// 使用市价卖出单平仓
// Use market sell order to close
orderID, err := t.CreateOrder(symbol, "sell", quantity, 0, "market")
if err != nil {
return nil, fmt.Errorf("平多仓失败: %w", err)
return nil, fmt.Errorf("failed to close long: %w", err)
}
// 平仓后取消所有挂单
// Cancel all pending orders after closing
if err := t.CancelAllOrders(symbol); err != nil {
logger.Infof(" ⚠ 取消挂单失败: %v", err)
logger.Infof(" ⚠ Failed to cancel pending orders: %v", err)
}
return map[string]interface{}{
@@ -76,13 +76,13 @@ func (t *LighterTrader) CloseLong(symbol string, quantity float64) (map[string]i
}, nil
}
// CloseShort 平空仓(quantity=0表示全部平仓)
// CloseShort Close short position (quantity=0 means close all)
func (t *LighterTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
// 如果quantity=0,获取当前持仓数量
// If quantity=0, get current position size
if quantity == 0 {
pos, err := t.GetPosition(symbol)
if err != nil {
return nil, fmt.Errorf("获取持仓失败: %w", err)
return nil, fmt.Errorf("failed to get position: %w", err)
}
if pos == nil || pos.Size == 0 {
return map[string]interface{}{
@@ -93,15 +93,15 @@ func (t *LighterTrader) CloseShort(symbol string, quantity float64) (map[string]
quantity = pos.Size
}
// 使用市价买入单平仓
// Use market buy order to close
orderID, err := t.CreateOrder(symbol, "buy", quantity, 0, "market")
if err != nil {
return nil, fmt.Errorf("平空仓失败: %w", err)
return nil, fmt.Errorf("failed to close short: %w", err)
}
// 平仓后取消所有挂单
// Cancel all pending orders after closing
if err := t.CancelAllOrders(symbol); err != nil {
logger.Infof(" ⚠ 取消挂单失败: %v", err)
logger.Infof(" ⚠ Failed to cancel pending orders: %v", err)
}
return map[string]interface{}{
@@ -111,62 +111,62 @@ func (t *LighterTrader) CloseShort(symbol string, quantity float64) (map[string]
}, nil
}
// SetStopLoss 设置止损单
// SetStopLoss Set stop-loss order
func (t *LighterTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
// TODO: 实现完整的止损单逻辑
logger.Infof("🚧 LIGHTER SetStopLoss 暂未完全实现 (symbol=%s, side=%s, qty=%.4f, stop=%.2f)", symbol, positionSide, quantity, stopPrice)
// TODO: Implement complete stop-loss logic
logger.Infof("🚧 LIGHTER SetStopLoss not fully implemented (symbol=%s, side=%s, qty=%.4f, stop=%.2f)", symbol, positionSide, quantity, stopPrice)
// 确定订单方向(做空止损用买单,做多止损用卖单)
// Determine order side (short position uses buy, long position uses sell)
side := "sell"
if positionSide == "SHORT" {
side = "buy"
}
// 创建限价止损单
// Create limit stop-loss order
_, err := t.CreateOrder(symbol, side, quantity, stopPrice, "limit")
if err != nil {
return fmt.Errorf("设置止损失败: %w", err)
return fmt.Errorf("failed to set stop-loss: %w", err)
}
logger.Infof("✓ LIGHTER - 止损已设置: %.2f (side: %s)", stopPrice, side)
logger.Infof("✓ LIGHTER - stop-loss set: %.2f (side: %s)", stopPrice, side)
return nil
}
// SetTakeProfit 设置止盈单
// SetTakeProfit Set take-profit order
func (t *LighterTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
// TODO: 实现完整的止盈单逻辑
logger.Infof("🚧 LIGHTER SetTakeProfit 暂未完全实现 (symbol=%s, side=%s, qty=%.4f, tp=%.2f)", symbol, positionSide, quantity, takeProfitPrice)
// TODO: Implement complete take-profit logic
logger.Infof("🚧 LIGHTER SetTakeProfit not fully implemented (symbol=%s, side=%s, qty=%.4f, tp=%.2f)", symbol, positionSide, quantity, takeProfitPrice)
// 确定订单方向(做空止盈用买单,做多止盈用卖单)
// Determine order side (short position uses buy, long position uses sell)
side := "sell"
if positionSide == "SHORT" {
side = "buy"
}
// 创建限价止盈单
// Create limit take-profit order
_, err := t.CreateOrder(symbol, side, quantity, takeProfitPrice, "limit")
if err != nil {
return fmt.Errorf("设置止盈失败: %w", err)
return fmt.Errorf("failed to set take-profit: %w", err)
}
logger.Infof("✓ LIGHTER - 止盈已设置: %.2f (side: %s)", takeProfitPrice, side)
logger.Infof("✓ LIGHTER - take-profit set: %.2f (side: %s)", takeProfitPrice, side)
return nil
}
// SetMarginMode 设置仓位模式 (true=全仓, false=逐仓)
// SetMarginMode Set position mode (true=cross, false=isolated)
func (t *LighterTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
// TODO: 实现仓位模式设置
modeStr := "逐仓"
// TODO: Implement position mode setting
modeStr := "isolated"
if isCrossMargin {
modeStr = "全仓"
modeStr = "cross"
}
logger.Infof("🚧 LIGHTER SetMarginMode 暂未实现 (symbol=%s, mode=%s)", symbol, modeStr)
logger.Infof("🚧 LIGHTER SetMarginMode not implemented (symbol=%s, mode=%s)", symbol, modeStr)
return nil
}
// FormatQuantity 格式化数量到正确的精度
// FormatQuantity Format quantity to correct precision
func (t *LighterTrader) FormatQuantity(symbol string, quantity float64) (string, error) {
// TODO: 根据LIGHTER API获取币种精度
// 暂时使用默认精度
// TODO: Get symbol precision from LIGHTER API
// Using default precision for now
return fmt.Sprintf("%.4f", quantity), nil
}
+182 -182
View File
File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More