merge dev

This commit is contained in:
Icy
2025-11-12 23:40:58 +08:00
140 changed files with 33470 additions and 4481 deletions
+23 -186
View File
@@ -3,47 +3,10 @@ package config
import (
"encoding/json"
"fmt"
"log"
"os"
"time"
)
// TraderConfig 单个trader的配置
type TraderConfig struct {
ID string `json:"id"`
Name string `json:"name"`
Enabled bool `json:"enabled"` // 是否启用该trader
AIModel string `json:"ai_model"` // "qwen" or "deepseek"
// 交易平台选择(二选一)
Exchange string `json:"exchange"` // "binance" or "hyperliquid"
// 币安配置
BinanceAPIKey string `json:"binance_api_key,omitempty"`
BinanceSecretKey string `json:"binance_secret_key,omitempty"`
// Hyperliquid配置
HyperliquidPrivateKey string `json:"hyperliquid_private_key,omitempty"`
HyperliquidWalletAddr string `json:"hyperliquid_wallet_addr,omitempty"`
HyperliquidTestnet bool `json:"hyperliquid_testnet,omitempty"`
// Aster配置
AsterUser string `json:"aster_user,omitempty"` // Aster主钱包地址
AsterSigner string `json:"aster_signer,omitempty"` // Aster API钱包地址
AsterPrivateKey string `json:"aster_private_key,omitempty"` // Aster API钱包私钥
// AI配置
QwenKey string `json:"qwen_key,omitempty"`
DeepSeekKey string `json:"deepseek_key,omitempty"`
// 自定义AI API配置(支持任何OpenAI格式的API)
CustomAPIURL string `json:"custom_api_url,omitempty"`
CustomAPIKey string `json:"custom_api_key,omitempty"`
CustomModelName string `json:"custom_model_name,omitempty"`
InitialBalance float64 `json:"initial_balance"`
ScanIntervalMinutes int `json:"scan_interval_minutes"`
}
// LeverageConfig 杠杆配置
type LeverageConfig struct {
BTCETHLeverage int `json:"btc_eth_leverage"` // BTC和ETH的杠杆倍数(主账户建议5-50,子账户≤5)
@@ -66,166 +29,40 @@ type TelegramConfig struct {
// Config 总配置
type Config struct {
Traders []TraderConfig `json:"traders"`
UseDefaultCoins bool `json:"use_default_coins"` // 是否使用默认主流币种列表
DefaultCoins []string `json:"default_coins"` // 默认主流币种池
BetaMode bool `json:"beta_mode"`
APIServerPort int `json:"api_server_port"`
UseDefaultCoins bool `json:"use_default_coins"`
DefaultCoins []string `json:"default_coins"`
CoinPoolAPIURL string `json:"coin_pool_api_url"`
OITopAPIURL string `json:"oi_top_api_url"`
MaxDailyLoss float64 `json:"max_daily_loss"`
MaxDrawdown float64 `json:"max_drawdown"`
StopTradingMinutes int `json:"stop_trading_minutes"`
Leverage LeverageConfig `json:"leverage"` // 杠杆配置
Log *LogConfig `json:"log"` // 日志配置(可选)
Proxy *ProxyConfig `json:"proxy"` // HTTP 代理配置(可选)
Leverage LeverageConfig `json:"leverage"`
JWTSecret string `json:"jwt_secret"`
DataKLineTime string `json:"data_k_line_time"`
Log *LogConfig `json:"log"` // 日志配置
}
// ProxyConfig HTTP 代理配置
type ProxyConfig struct {
Enabled bool `json:"enabled"` // 是否启用代理
Mode string `json:"mode"` // 模式: "single", "pool", "brightdata"
Timeout int `json:"timeout"` // 超时时间(秒)
ProxyURL string `json:"proxy_url"` // 单个代理地址
ProxyList []string `json:"proxy_list"` // 代理列表
BrightDataEndpoint string `json:"brightdata_endpoint"` // Bright Data接口地址
BrightDataToken string `json:"brightdata_token"` // Bright Data访问令牌
BrightDataZone string `json:"brightdata_zone"` // Bright Data区域
ProxyHost string `json:"proxy_host"` // 代理主机
ProxyUser string `json:"proxy_user"` // 代理用户名模板
ProxyPassword string `json:"proxy_password"` // 代理密码
RefreshInterval int `json:"refresh_interval"` // 刷新间隔(秒)
BlacklistTTL int `json:"blacklist_ttl"` // 黑名单TTL
}
// LoadConfig 从文件加载配置
func LoadConfig(filename string) (*Config, error) {
// 检查filename是否存在
if _, err := os.Stat(filename); os.IsNotExist(err) {
log.Printf("📄 %s不存在,使用默认配置", filename)
return &Config{}, nil
}
// 读取 filename
data, err := os.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("读取配置文件失败: %w", err)
return nil, fmt.Errorf("读取%s失败: %w", filename, err)
}
var config Config
if err := json.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("解析配置文件失败: %w", err)
// 解析JSON
var configFile Config
if err := json.Unmarshal(data, &configFile); err != nil {
return nil, fmt.Errorf("解析%s失败: %w", filename, err)
}
// 设置默认值:确保使用默认币种列表
if !config.UseDefaultCoins {
config.UseDefaultCoins = true
}
// 设置默认币种池
if len(config.DefaultCoins) == 0 {
config.DefaultCoins = []string{
"BTCUSDT",
"ETHUSDT",
"SOLUSDT",
"BNBUSDT",
"XRPUSDT",
"DOGEUSDT",
"ADAUSDT",
"HYPEUSDT",
}
}
// 验证配置
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("配置验证失败: %w", err)
}
return &config, nil
}
// Validate 验证配置有效性
func (c *Config) Validate() error {
if len(c.Traders) == 0 {
return fmt.Errorf("至少需要配置一个trader")
}
traderIDs := make(map[string]bool)
for i, trader := range c.Traders {
if trader.ID == "" {
return fmt.Errorf("trader[%d]: ID不能为空", i)
}
if traderIDs[trader.ID] {
return fmt.Errorf("trader[%d]: ID '%s' 重复", i, trader.ID)
}
traderIDs[trader.ID] = true
if trader.Name == "" {
return fmt.Errorf("trader[%d]: Name不能为空", i)
}
if trader.AIModel != "qwen" && trader.AIModel != "deepseek" && trader.AIModel != "custom" {
return fmt.Errorf("trader[%d]: ai_model必须是 'qwen', 'deepseek' 或 'custom'", i)
}
// 验证交易平台配置
if trader.Exchange == "" {
trader.Exchange = "binance" // 默认使用币安
}
if trader.Exchange != "binance" && trader.Exchange != "hyperliquid" && trader.Exchange != "aster" {
return fmt.Errorf("trader[%d]: exchange必须是 'binance', 'hyperliquid' 或 'aster'", i)
}
// 根据平台验证对应的密钥
if trader.Exchange == "binance" {
if trader.BinanceAPIKey == "" || trader.BinanceSecretKey == "" {
return fmt.Errorf("trader[%d]: 使用币安时必须配置binance_api_key和binance_secret_key", i)
}
} else if trader.Exchange == "hyperliquid" {
if trader.HyperliquidPrivateKey == "" {
return fmt.Errorf("trader[%d]: 使用Hyperliquid时必须配置hyperliquid_private_key", i)
}
} else if trader.Exchange == "aster" {
if trader.AsterUser == "" || trader.AsterSigner == "" || trader.AsterPrivateKey == "" {
return fmt.Errorf("trader[%d]: 使用Aster时必须配置aster_user, aster_signer和aster_private_key", i)
}
}
if trader.AIModel == "qwen" && trader.QwenKey == "" {
return fmt.Errorf("trader[%d]: 使用Qwen时必须配置qwen_key", i)
}
if trader.AIModel == "deepseek" && trader.DeepSeekKey == "" {
return fmt.Errorf("trader[%d]: 使用DeepSeek时必须配置deepseek_key", i)
}
if trader.AIModel == "custom" {
if trader.CustomAPIURL == "" {
return fmt.Errorf("trader[%d]: 使用自定义API时必须配置custom_api_url", i)
}
if trader.CustomAPIKey == "" {
return fmt.Errorf("trader[%d]: 使用自定义API时必须配置custom_api_key", i)
}
if trader.CustomModelName == "" {
return fmt.Errorf("trader[%d]: 使用自定义API时必须配置custom_model_name", i)
}
}
if trader.InitialBalance <= 0 {
return fmt.Errorf("trader[%d]: initial_balance必须大于0", i)
}
if trader.ScanIntervalMinutes <= 0 {
trader.ScanIntervalMinutes = 3 // 默认3分钟
}
}
if c.APIServerPort <= 0 {
c.APIServerPort = 8080 // 默认8080端口
}
// 设置杠杆默认值(适配币安子账户限制,最大5倍)
if c.Leverage.BTCETHLeverage <= 0 {
c.Leverage.BTCETHLeverage = 5 // 默认5倍(安全值,适配子账户)
}
if c.Leverage.BTCETHLeverage > 5 {
fmt.Printf("⚠️ 警告: BTC/ETH杠杆设置为%dx,如果使用子账户可能会失败(子账户限制≤5x)\n", c.Leverage.BTCETHLeverage)
}
if c.Leverage.AltcoinLeverage <= 0 {
c.Leverage.AltcoinLeverage = 5 // 默认5倍(安全值,适配子账户)
}
if c.Leverage.AltcoinLeverage > 5 {
fmt.Printf("⚠️ 警告: 山寨币杠杆设置为%dx,如果使用子账户可能会失败(子账户限制≤5x)\n", c.Leverage.AltcoinLeverage)
}
return nil
}
// GetScanInterval 获取扫描间隔
func (tc *TraderConfig) GetScanInterval() time.Duration {
return time.Duration(tc.ScanIntervalMinutes) * time.Minute
return &configFile, nil
}
+1243 -100
View File
File diff suppressed because it is too large Load Diff
+799
View File
@@ -0,0 +1,799 @@
package config
import (
"nofx/crypto"
"os"
"testing"
"time"
)
// TestUpdateExchange_EmptyValuesShouldNotOverwrite 测试空值不应覆盖现有数据
// 这是 Bug 的核心:当前实现会用空字符串覆盖现有的私钥
func TestUpdateExchange_EmptyValuesShouldNotOverwrite(t *testing.T) {
// 准备测试数据库
db, cleanup := setupTestDB(t)
defer cleanup()
userID := "test-user-001"
// 步骤 1: 创建初始配置(包含私钥)
initialAPIKey := "initial-api-key-12345"
initialSecretKey := "initial-secret-key-67890"
err := db.UpdateExchange(
userID,
"hyperliquid",
true, // enabled
initialAPIKey,
initialSecretKey,
false, // testnet
"0xWalletAddress",
"",
"",
"",
)
if err != nil {
t.Fatalf("初始化失败: %v", err)
}
// 步骤 2: 验证初始数据已保存
exchanges, err := db.GetExchanges(userID)
if err != nil {
t.Fatalf("获取配置失败: %v", err)
}
if len(exchanges) == 0 {
t.Fatal("未找到配置")
}
// 解密后应该能看到原始值
if exchanges[0].APIKey != initialAPIKey {
t.Errorf("初始 APIKey 不正确,期望 %s,实际 %s", initialAPIKey, exchanges[0].APIKey)
}
// 步骤 3: 用空值更新(模拟前端发送空值的场景)
// 🐛 Bug 重现:这应该 NOT 覆盖现有的私钥,但当前实现会覆盖
err = db.UpdateExchange(
userID,
"hyperliquid",
false, // 只改变 enabled 状态
"", // 空 apiKey - 不应该覆盖
"", // 空 secretKey - 不应该覆盖
true, // 改变 testnet 状态
"0xWalletAddress",
"",
"",
"", // 空 aster_private_key - 不应该覆盖
)
if err != nil {
t.Fatalf("更新失败: %v", err)
}
// 步骤 4: 验证私钥没有被空值覆盖
exchanges, err = db.GetExchanges(userID)
if err != nil {
t.Fatalf("获取更新后配置失败: %v", err)
}
// 🎯 关键断言:私钥应该保持不变
if exchanges[0].APIKey != initialAPIKey {
t.Errorf("❌ Bug 确认:APIKey 被空值覆盖了!期望 %s,实际 %s", initialAPIKey, exchanges[0].APIKey)
}
if exchanges[0].SecretKey != initialSecretKey {
t.Errorf("❌ Bug 确认:SecretKey 被空值覆盖了!期望 %s,实际 %s", initialSecretKey, exchanges[0].SecretKey)
}
// 验证非敏感字段正常更新
if exchanges[0].Enabled {
t.Error("enabled 应该被更新为 false")
}
if !exchanges[0].Testnet {
t.Error("testnet 应该被更新为 true")
}
}
// TestUpdateExchange_AsterEmptyValuesShouldNotOverwrite 测试 Aster 私钥不被空值覆盖
func TestUpdateExchange_AsterEmptyValuesShouldNotOverwrite(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
userID := "test-user-002"
// 步骤 1: 创建 Aster 配置
initialAsterKey := "aster-private-key-xyz123"
err := db.UpdateExchange(
userID,
"aster",
true,
"",
"",
false,
"",
"0xAsterUser",
"0xAsterSigner",
initialAsterKey,
)
if err != nil {
t.Fatalf("初始化 Aster 失败: %v", err)
}
// 步骤 2: 用空值更新
err = db.UpdateExchange(
userID,
"aster",
false, // 只改 enabled
"",
"",
false,
"",
"0xAsterUser",
"0xAsterSigner",
"", // 空 aster_private_key
)
if err != nil {
t.Fatalf("更新失败: %v", err)
}
// 步骤 3: 验证 aster_private_key 没有被覆盖
exchanges, err := db.GetExchanges(userID)
if err != nil {
t.Fatalf("获取配置失败: %v", err)
}
if exchanges[0].AsterPrivateKey != initialAsterKey {
t.Errorf("❌ Bug 确认:AsterPrivateKey 被空值覆盖了!期望 %s,实际 %s", initialAsterKey, exchanges[0].AsterPrivateKey)
}
}
// TestUpdateExchange_NonEmptyValuesShouldUpdate 测试非空值应该正常更新
func TestUpdateExchange_NonEmptyValuesShouldUpdate(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
userID := "test-user-003"
// 步骤 1: 创建初始配置
err := db.UpdateExchange(
userID,
"hyperliquid",
true,
"old-api-key",
"old-secret-key",
false,
"0xOldWallet",
"",
"",
"",
)
if err != nil {
t.Fatalf("初始化失败: %v", err)
}
// 步骤 2: 用非空值更新
newAPIKey := "new-api-key-456"
newSecretKey := "new-secret-key-789"
err = db.UpdateExchange(
userID,
"hyperliquid",
true,
newAPIKey,
newSecretKey,
false,
"0xNewWallet",
"",
"",
"",
)
if err != nil {
t.Fatalf("更新失败: %v", err)
}
// 步骤 3: 验证新值已更新
exchanges, err := db.GetExchanges(userID)
if err != nil {
t.Fatalf("获取配置失败: %v", err)
}
if exchanges[0].APIKey != newAPIKey {
t.Errorf("APIKey 未更新,期望 %s,实际 %s", newAPIKey, exchanges[0].APIKey)
}
if exchanges[0].SecretKey != newSecretKey {
t.Errorf("SecretKey 未更新,期望 %s,实际 %s", newSecretKey, exchanges[0].SecretKey)
}
if exchanges[0].HyperliquidWalletAddr != "0xNewWallet" {
t.Errorf("WalletAddr 未更新")
}
}
// TestUpdateExchange_PartialUpdateShouldWork 测试部分字段更新
func TestUpdateExchange_PartialUpdateShouldWork(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
userID := "test-user-005"
// 创建初始配置
err := db.UpdateExchange(
userID,
"hyperliquid",
true,
"api-key-123",
"secret-key-456",
false,
"0xWallet1",
"",
"",
"",
)
if err != nil {
t.Fatalf("初始化失败: %v", err)
}
// 只更新 enabled 和 testnet,私钥留空
err = db.UpdateExchange(
userID,
"hyperliquid",
false,
"", // 留空
"", // 留空
true,
"0xWallet2",
"",
"",
"",
)
if err != nil {
t.Fatalf("部分更新失败: %v", err)
}
// 验证
exchanges, err := db.GetExchanges(userID)
if err != nil {
t.Fatalf("获取配置失败: %v", err)
}
// 私钥应该保持不变
if exchanges[0].APIKey != "api-key-123" {
t.Errorf("APIKey 不应改变,期望 api-key-123,实际 %s", exchanges[0].APIKey)
}
if exchanges[0].SecretKey != "secret-key-456" {
t.Errorf("SecretKey 不应改变,期望 secret-key-456,实际 %s", exchanges[0].SecretKey)
}
// 其他字段应该更新
if exchanges[0].Enabled {
t.Error("enabled 应该更新为 false")
}
if !exchanges[0].Testnet {
t.Error("testnet 应该更新为 true")
}
if exchanges[0].HyperliquidWalletAddr != "0xWallet2" {
t.Error("wallet 地址应该更新")
}
}
// TestUpdateExchange_MultipleExchangeTypes 测试不同交易所类型
func TestUpdateExchange_MultipleExchangeTypes(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
userID := "test-user-006"
testCases := []struct {
exchangeID string
name string
typ string
}{
{"binance", "Binance Futures", "cex"},
{"hyperliquid", "Hyperliquid", "dex"},
{"aster", "Aster DEX", "dex"},
{"unknown-exchange", "unknown-exchange Exchange", "cex"},
}
for _, tc := range testCases {
t.Run(tc.exchangeID, func(t *testing.T) {
err := db.UpdateExchange(
userID,
tc.exchangeID,
true,
"api-key-"+tc.exchangeID,
"secret-key-"+tc.exchangeID,
false,
"",
"",
"",
"",
)
if err != nil {
t.Fatalf("创建 %s 失败: %v", tc.exchangeID, err)
}
// 验证创建成功
exchanges, err := db.GetExchanges(userID)
if err != nil {
t.Fatalf("获取配置失败: %v", err)
}
found := false
for _, ex := range exchanges {
if ex.ID == tc.exchangeID {
found = true
if ex.Name != tc.name {
t.Errorf("交易所名称不正确,期望 %s,实际 %s", tc.name, ex.Name)
}
if ex.Type != tc.typ {
t.Errorf("交易所类型不正确,期望 %s,实际 %s", tc.typ, ex.Type)
}
if ex.APIKey != "api-key-"+tc.exchangeID {
t.Errorf("APIKey 不正确")
}
break
}
}
if !found {
t.Errorf("未找到交易所 %s", tc.exchangeID)
}
})
}
}
// TestUpdateExchange_MixedSensitiveFields 测试混合更新敏感和非敏感字段
func TestUpdateExchange_MixedSensitiveFields(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
userID := "test-user-007"
// 创建初始配置
err := db.UpdateExchange(
userID,
"hyperliquid",
true,
"old-api-key",
"old-secret-key",
false,
"0xOldWallet",
"",
"",
"",
)
if err != nil {
t.Fatalf("初始化失败: %v", err)
}
// 场景1: 只更新 apiKeysecretKey 留空
err = db.UpdateExchange(
userID,
"hyperliquid",
false,
"new-api-key",
"", // 留空
true,
"0xNewWallet",
"",
"",
"",
)
if err != nil {
t.Fatalf("更新1失败: %v", err)
}
exchanges, _ := db.GetExchanges(userID)
if exchanges[0].APIKey != "new-api-key" {
t.Error("APIKey 应该更新")
}
if exchanges[0].SecretKey != "old-secret-key" {
t.Error("SecretKey 应该保持不变")
}
// 场景2: 只更新 secretKeyapiKey 留空
err = db.UpdateExchange(
userID,
"hyperliquid",
true,
"", // 留空
"new-secret-key",
false,
"0xFinalWallet",
"",
"",
"",
)
if err != nil {
t.Fatalf("更新2失败: %v", err)
}
exchanges, _ = db.GetExchanges(userID)
if exchanges[0].APIKey != "new-api-key" {
t.Error("APIKey 应该保持不变")
}
if exchanges[0].SecretKey != "new-secret-key" {
t.Error("SecretKey 应该更新")
}
if exchanges[0].Enabled != true {
t.Error("Enabled 应该更新为 true")
}
if exchanges[0].HyperliquidWalletAddr != "0xFinalWallet" {
t.Error("WalletAddr 应该更新")
}
}
// TestUpdateExchange_OnlyNonSensitiveFields 测试只更新非敏感字段
func TestUpdateExchange_OnlyNonSensitiveFields(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
userID := "test-user-008"
// 创建初始配置(包含所有私钥)
err := db.UpdateExchange(
userID,
"aster",
true,
"binance-api",
"binance-secret",
false,
"",
"0xUser1",
"0xSigner1",
"aster-private-key-1",
)
if err != nil {
t.Fatalf("初始化失败: %v", err)
}
// 只更新非敏感字段(所有私钥字段留空)
err = db.UpdateExchange(
userID,
"aster",
false,
"",
"",
true,
"",
"0xUser2",
"0xSigner2",
"",
)
if err != nil {
t.Fatalf("更新失败: %v", err)
}
// 验证所有私钥保持不变
exchanges, _ := db.GetExchanges(userID)
if exchanges[0].APIKey != "binance-api" {
t.Errorf("APIKey 应该保持不变,实际 %s", exchanges[0].APIKey)
}
if exchanges[0].SecretKey != "binance-secret" {
t.Errorf("SecretKey 应该保持不变,实际 %s", exchanges[0].SecretKey)
}
if exchanges[0].AsterPrivateKey != "aster-private-key-1" {
t.Errorf("AsterPrivateKey 应该保持不变,实际 %s", exchanges[0].AsterPrivateKey)
}
// 验证非敏感字段已更新
if exchanges[0].Enabled != false {
t.Error("Enabled 应该更新为 false")
}
if exchanges[0].Testnet != true {
t.Error("Testnet 应该更新为 true")
}
if exchanges[0].AsterUser != "0xUser2" {
t.Error("AsterUser 应该更新")
}
if exchanges[0].AsterSigner != "0xSigner2" {
t.Error("AsterSigner 应该更新")
}
}
// TestUpdateExchange_AllSensitiveFieldsUpdate 测试同时更新所有敏感字段
func TestUpdateExchange_AllSensitiveFieldsUpdate(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
userID := "test-user-009"
// 创建初始配置
err := db.UpdateExchange(
userID,
"binance",
true,
"old-api",
"old-secret",
false,
"",
"",
"",
"old-aster-key",
)
if err != nil {
t.Fatalf("初始化失败: %v", err)
}
// 同时更新所有敏感字段
err = db.UpdateExchange(
userID,
"binance",
false,
"new-api",
"new-secret",
true,
"0xWallet",
"0xUser",
"0xSigner",
"new-aster-key",
)
if err != nil {
t.Fatalf("更新失败: %v", err)
}
// 验证所有字段都更新了
exchanges, _ := db.GetExchanges(userID)
if exchanges[0].APIKey != "new-api" {
t.Error("APIKey 应该更新")
}
if exchanges[0].SecretKey != "new-secret" {
t.Error("SecretKey 应该更新")
}
if exchanges[0].AsterPrivateKey != "new-aster-key" {
t.Error("AsterPrivateKey 应该更新")
}
if !exchanges[0].Testnet {
t.Error("Testnet 应该更新为 true")
}
}
// setupTestDB 创建测试数据库
func setupTestDB(t *testing.T) (*Database, func()) {
// 创建临时数据库文件
tmpFile := t.TempDir() + "/test.db"
db, err := NewDatabase(tmpFile)
if err != nil {
t.Fatalf("创建测试数据库失败: %v", err)
}
// 创建测试用户
testUsers := []string{"test-user-001", "test-user-002", "test-user-003", "test-user-004", "test-user-005", "test-user-006", "test-user-007", "test-user-008", "test-user-009"}
for _, userID := range testUsers {
user := &User{
ID: userID,
Email: userID + "@test.com",
PasswordHash: "hash",
OTPSecret: "",
OTPVerified: false,
}
_ = db.CreateUser(user)
}
// 设置加密服务(用于测试加密功能)
// 创建临时 RSA 密钥
rsaKeyPath := t.TempDir() + "/test_rsa_key"
cryptoService, err := crypto.NewCryptoService(rsaKeyPath)
if err != nil {
// 如果创建失败,继续测试但不使用加密
t.Logf("警告:无法创建加密服务,将在无加密模式下测试: %v", err)
} else {
db.SetCryptoService(cryptoService)
}
cleanup := func() {
db.Close()
os.RemoveAll(tmpFile)
os.RemoveAll(rsaKeyPath)
}
return db, cleanup
}
// TestWALModeEnabled 测试 WAL 模式是否启用
// TDD: 这个测试应该失败,因为当前代码没有启用 WAL 模式
func TestWALModeEnabled(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
// 查询当前的 journal_mode
var journalMode string
err := db.db.QueryRow("PRAGMA journal_mode").Scan(&journalMode)
if err != nil {
t.Fatalf("查询 journal_mode 失败: %v", err)
}
// 期望是 WAL 模式
if journalMode != "wal" {
t.Errorf("期望 journal_mode=wal,实际是 %s", journalMode)
}
}
// TestSynchronousMode 测试 synchronous 模式设置
// TDD: 验证数据持久性设置
func TestSynchronousMode(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
// 查询 synchronous 设置
var synchronous int
err := db.db.QueryRow("PRAGMA synchronous").Scan(&synchronous)
if err != nil {
t.Fatalf("查询 synchronous 失败: %v", err)
}
// 期望是 FULL (2) 以确保数据持久性
if synchronous != 2 {
t.Errorf("期望 synchronous=2 (FULL),实际是 %d", synchronous)
}
}
// TestDataPersistenceAcrossReopen 测试数据在数据库关闭并重新打开后是否持久化
// TDD: 模拟 Docker restart 场景
func TestDataPersistenceAcrossReopen(t *testing.T) {
// 创建临时数据库文件
tmpFile, err := os.CreateTemp("", "test_persistence_*.db")
if err != nil {
t.Fatalf("创建临时文件失败: %v", err)
}
tmpFile.Close()
dbPath := tmpFile.Name()
defer os.Remove(dbPath)
// 设置加密服务
rsaKeyPath := "test_rsa_key.pem"
cryptoService, err := crypto.NewCryptoService(rsaKeyPath)
if err != nil {
t.Fatalf("初始化加密服务失败: %v", err)
}
defer os.RemoveAll(rsaKeyPath)
userID := "test-user-persistence"
testAPIKey := "test-api-key-should-persist"
testSecretKey := "test-secret-key-should-persist"
// 第一次打开数据库并写入数据
{
db, err := NewDatabase(dbPath)
if err != nil {
t.Fatalf("第一次创建数据库失败: %v", err)
}
db.SetCryptoService(cryptoService)
// 写入交易所配置
err = db.UpdateExchange(
userID,
"binance",
true,
testAPIKey,
testSecretKey,
false,
"",
"",
"",
"",
)
if err != nil {
t.Fatalf("写入数据失败: %v", err)
}
// 模拟正常关闭
if err := db.Close(); err != nil {
t.Fatalf("关闭数据库失败: %v", err)
}
}
// 第二次打开数据库并验证数据是否还在
{
db, err := NewDatabase(dbPath)
if err != nil {
t.Fatalf("第二次打开数据库失败: %v", err)
}
db.SetCryptoService(cryptoService)
defer db.Close()
// 读取数据
exchanges, err := db.GetExchanges(userID)
if err != nil {
t.Fatalf("读取数据失败: %v", err)
}
if len(exchanges) == 0 {
t.Fatal("数据丢失:没有找到任何交易所配置")
}
// 验证数据完整性
found := false
for _, ex := range exchanges {
if ex.ID == "binance" {
found = true
if ex.APIKey != testAPIKey {
t.Errorf("API Key 丢失或损坏,期望 %s,实际 %s", testAPIKey, ex.APIKey)
}
if ex.SecretKey != testSecretKey {
t.Errorf("Secret Key 丢失或损坏,期望 %s,实际 %s", testSecretKey, ex.SecretKey)
}
}
}
if !found {
t.Error("数据丢失:找不到 binance 配置")
}
}
}
// TestConcurrentWritesWithWAL 测试 WAL 模式下的并发写入
// TDD: WAL 模式应该支持更好的并发性能
func TestConcurrentWritesWithWAL(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
// 这个测试验证多个并发写入可以成功
// WAL 模式下并发性能更好,但 SQLite 仍然可能出现短暂的锁
done := make(chan bool, 2)
errors := make(chan error, 10)
// 并发写入1
go func() {
for i := 0; i < 3; i++ {
err := db.UpdateExchange(
"user1",
"binance",
true,
"key1",
"secret1",
false,
"",
"",
"",
"",
)
if err != nil {
errors <- err
}
// 小延迟减少锁冲突
time.Sleep(10 * time.Millisecond)
}
done <- true
}()
// 并发写入2
go func() {
for i := 0; i < 3; i++ {
err := db.UpdateExchange(
"user2",
"hyperliquid",
true,
"key2",
"secret2",
false,
"0xWallet",
"",
"",
"",
)
if err != nil {
errors <- err
}
// 小延迟减少锁冲突
time.Sleep(10 * time.Millisecond)
}
done <- true
}()
// 等待两个 goroutine 完成
<-done
<-done
close(errors)
// 检查是否有错误
errorCount := 0
for err := range errors {
t.Logf("并发写入错误: %v", err)
errorCount++
}
// WAL 模式下应该能处理并发,但可能有少量锁错误
// 我们允许最多 2 个错误
if errorCount > 2 {
t.Errorf("并发写入失败次数过多: %d", errorCount)
}
}
+9
View File
@@ -0,0 +1,9 @@
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4Y666RzY5LLi6PiYL+vC
7+fcr122Fd8BC7IdqUSYKQ33Nsi9J7J5fDgcMf7ZAnIBpxMV7+e1KEoiwtGmxwHj
mYo0ZV0E6JXdiK26S052+Shquri0IXkwGFraDuNKqmGrj6vZuXtq2L2gdSyZCxrI
veN9g6LxBvLBP1Rx7UEmZeyokRYvChcxAQXuS/0br44BOHGtwAElk6AGLISz55AG
oM40b3ktiza+8THKMz3GiylQQYpBltbM3yAXPlnXJ2MtUZiaHNhEQI4++PMvEErN
Izm8cIgcvUAXJ5vBfa4kD0kSgBJFuEQ2im3qcWTuEPRKztEeJDY7XAVHc1Xy6d4N
vQIDAQAB
-----END PUBLIC KEY-----