mirror of
https://github.com/laoxong/nofx.git
synced 2026-06-04 01:48:22 +08:00
feat: migrate store layer to GORM with PostgreSQL support
- Migrate all store packages from raw database/sql to GORM ORM - Add PostgreSQL support alongside SQLite - Move EncryptedString type to crypto package for cleaner architecture - Add automatic encryption/decryption for sensitive fields (API keys, secrets) - Fix PostgreSQL AutoMigrate conflicts by skipping existing tables - Fix duplicate /klines route registration - Update tests to use GORM database connections - Add database configuration support in config package
This commit is contained in:
@@ -55,3 +55,16 @@ TRANSPORT_ENCRYPTION=false
|
|||||||
# Telegram notifications (optional)
|
# Telegram notifications (optional)
|
||||||
# TELEGRAM_BOT_TOKEN=your-bot-token
|
# TELEGRAM_BOT_TOKEN=your-bot-token
|
||||||
# TELEGRAM_CHAT_ID=your-chat-id
|
# TELEGRAM_CHAT_ID=your-chat-id
|
||||||
|
|
||||||
|
DB_TYPE=postgres
|
||||||
|
DB_HOST=10.
|
||||||
|
DB_PORT=5432
|
||||||
|
DB_USER=nofx_user
|
||||||
|
DB_PASSWORD=
|
||||||
|
DB_NAME=nofx
|
||||||
|
DB_SSLMODE=disable
|
||||||
|
|
||||||
|
|
||||||
|
# 数据库配置 - SQLite(默认)
|
||||||
|
DB_TYPE=sqlite
|
||||||
|
DB_PATH=data/data.db
|
||||||
+1
-1
@@ -824,7 +824,7 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error {
|
|||||||
return fmt.Errorf("AI model %s is not enabled yet", model.Name)
|
return fmt.Errorf("AI model %s is not enabled yet", model.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
apiKey := strings.TrimSpace(model.APIKey)
|
apiKey := strings.TrimSpace(string(model.APIKey))
|
||||||
if apiKey == "" {
|
if apiKey == "" {
|
||||||
return fmt.Errorf("AI model %s is missing API Key, please configure it in the system first", model.Name)
|
return fmt.Errorf("AI model %s is missing API Key, please configure it in the system first", model.Name)
|
||||||
}
|
}
|
||||||
|
|||||||
+43
-41
@@ -54,7 +54,7 @@ func NewServer(traderManager *manager.TraderManager, st *store.Store, cryptoServ
|
|||||||
cryptoHandler := NewCryptoHandler(cryptoService)
|
cryptoHandler := NewCryptoHandler(cryptoService)
|
||||||
|
|
||||||
// Create debate store and handler
|
// Create debate store and handler
|
||||||
debateStore := store.NewDebateStore(st.DB())
|
debateStore := store.NewDebateStore(st.GormDB())
|
||||||
if err := debateStore.InitSchema(); err != nil {
|
if err := debateStore.InitSchema(); err != nil {
|
||||||
logger.Errorf("Failed to initialize debate schema: %v", err)
|
logger.Errorf("Failed to initialize debate schema: %v", err)
|
||||||
}
|
}
|
||||||
@@ -125,7 +125,6 @@ func (s *Server) setupRoutes() {
|
|||||||
|
|
||||||
// Market data (no authentication required)
|
// Market data (no authentication required)
|
||||||
api.GET("/klines", s.handleKlines)
|
api.GET("/klines", s.handleKlines)
|
||||||
api.GET("/klines", s.handleKlines)
|
|
||||||
api.GET("/symbols", s.handleSymbols)
|
api.GET("/symbols", s.handleSymbols)
|
||||||
|
|
||||||
// Authentication related routes (no authentication required)
|
// Authentication related routes (no authentication required)
|
||||||
@@ -576,12 +575,13 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
|
|||||||
var createErr error
|
var createErr error
|
||||||
|
|
||||||
// Use ExchangeType (e.g., "binance") instead of ID (UUID)
|
// Use ExchangeType (e.g., "binance") instead of ID (UUID)
|
||||||
|
// Convert EncryptedString fields to string
|
||||||
switch exchangeCfg.ExchangeType {
|
switch exchangeCfg.ExchangeType {
|
||||||
case "binance":
|
case "binance":
|
||||||
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID)
|
tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID)
|
||||||
case "hyperliquid":
|
case "hyperliquid":
|
||||||
tempTrader, createErr = trader.NewHyperliquidTrader(
|
tempTrader, createErr = trader.NewHyperliquidTrader(
|
||||||
exchangeCfg.APIKey, // private key
|
string(exchangeCfg.APIKey), // private key
|
||||||
exchangeCfg.HyperliquidWalletAddr,
|
exchangeCfg.HyperliquidWalletAddr,
|
||||||
exchangeCfg.Testnet,
|
exchangeCfg.Testnet,
|
||||||
)
|
)
|
||||||
@@ -589,31 +589,31 @@ func (s *Server) handleCreateTrader(c *gin.Context) {
|
|||||||
tempTrader, createErr = trader.NewAsterTrader(
|
tempTrader, createErr = trader.NewAsterTrader(
|
||||||
exchangeCfg.AsterUser,
|
exchangeCfg.AsterUser,
|
||||||
exchangeCfg.AsterSigner,
|
exchangeCfg.AsterSigner,
|
||||||
exchangeCfg.AsterPrivateKey,
|
string(exchangeCfg.AsterPrivateKey),
|
||||||
)
|
)
|
||||||
case "bybit":
|
case "bybit":
|
||||||
tempTrader = trader.NewBybitTrader(
|
tempTrader = trader.NewBybitTrader(
|
||||||
exchangeCfg.APIKey,
|
string(exchangeCfg.APIKey),
|
||||||
exchangeCfg.SecretKey,
|
string(exchangeCfg.SecretKey),
|
||||||
)
|
)
|
||||||
case "okx":
|
case "okx":
|
||||||
tempTrader = trader.NewOKXTrader(
|
tempTrader = trader.NewOKXTrader(
|
||||||
exchangeCfg.APIKey,
|
string(exchangeCfg.APIKey),
|
||||||
exchangeCfg.SecretKey,
|
string(exchangeCfg.SecretKey),
|
||||||
exchangeCfg.Passphrase,
|
string(exchangeCfg.Passphrase),
|
||||||
)
|
)
|
||||||
case "bitget":
|
case "bitget":
|
||||||
tempTrader = trader.NewBitgetTrader(
|
tempTrader = trader.NewBitgetTrader(
|
||||||
exchangeCfg.APIKey,
|
string(exchangeCfg.APIKey),
|
||||||
exchangeCfg.SecretKey,
|
string(exchangeCfg.SecretKey),
|
||||||
exchangeCfg.Passphrase,
|
string(exchangeCfg.Passphrase),
|
||||||
)
|
)
|
||||||
case "lighter":
|
case "lighter":
|
||||||
if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" {
|
if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" {
|
||||||
// Lighter only supports mainnet
|
// Lighter only supports mainnet
|
||||||
tempTrader, createErr = trader.NewLighterTraderV2(
|
tempTrader, createErr = trader.NewLighterTraderV2(
|
||||||
exchangeCfg.LighterWalletAddr,
|
exchangeCfg.LighterWalletAddr,
|
||||||
exchangeCfg.LighterAPIKeyPrivateKey,
|
string(exchangeCfg.LighterAPIKeyPrivateKey),
|
||||||
exchangeCfg.LighterAPIKeyIndex,
|
exchangeCfg.LighterAPIKeyIndex,
|
||||||
false, // Always use mainnet for Lighter
|
false, // Always use mainnet for Lighter
|
||||||
)
|
)
|
||||||
@@ -1095,12 +1095,13 @@ func (s *Server) handleSyncBalance(c *gin.Context) {
|
|||||||
var createErr error
|
var createErr error
|
||||||
|
|
||||||
// Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID)
|
// Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID)
|
||||||
|
// Convert EncryptedString fields to string
|
||||||
switch exchangeCfg.ExchangeType {
|
switch exchangeCfg.ExchangeType {
|
||||||
case "binance":
|
case "binance":
|
||||||
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID)
|
tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID)
|
||||||
case "hyperliquid":
|
case "hyperliquid":
|
||||||
tempTrader, createErr = trader.NewHyperliquidTrader(
|
tempTrader, createErr = trader.NewHyperliquidTrader(
|
||||||
exchangeCfg.APIKey,
|
string(exchangeCfg.APIKey),
|
||||||
exchangeCfg.HyperliquidWalletAddr,
|
exchangeCfg.HyperliquidWalletAddr,
|
||||||
exchangeCfg.Testnet,
|
exchangeCfg.Testnet,
|
||||||
)
|
)
|
||||||
@@ -1108,31 +1109,31 @@ func (s *Server) handleSyncBalance(c *gin.Context) {
|
|||||||
tempTrader, createErr = trader.NewAsterTrader(
|
tempTrader, createErr = trader.NewAsterTrader(
|
||||||
exchangeCfg.AsterUser,
|
exchangeCfg.AsterUser,
|
||||||
exchangeCfg.AsterSigner,
|
exchangeCfg.AsterSigner,
|
||||||
exchangeCfg.AsterPrivateKey,
|
string(exchangeCfg.AsterPrivateKey),
|
||||||
)
|
)
|
||||||
case "bybit":
|
case "bybit":
|
||||||
tempTrader = trader.NewBybitTrader(
|
tempTrader = trader.NewBybitTrader(
|
||||||
exchangeCfg.APIKey,
|
string(exchangeCfg.APIKey),
|
||||||
exchangeCfg.SecretKey,
|
string(exchangeCfg.SecretKey),
|
||||||
)
|
)
|
||||||
case "okx":
|
case "okx":
|
||||||
tempTrader = trader.NewOKXTrader(
|
tempTrader = trader.NewOKXTrader(
|
||||||
exchangeCfg.APIKey,
|
string(exchangeCfg.APIKey),
|
||||||
exchangeCfg.SecretKey,
|
string(exchangeCfg.SecretKey),
|
||||||
exchangeCfg.Passphrase,
|
string(exchangeCfg.Passphrase),
|
||||||
)
|
)
|
||||||
case "bitget":
|
case "bitget":
|
||||||
tempTrader = trader.NewBitgetTrader(
|
tempTrader = trader.NewBitgetTrader(
|
||||||
exchangeCfg.APIKey,
|
string(exchangeCfg.APIKey),
|
||||||
exchangeCfg.SecretKey,
|
string(exchangeCfg.SecretKey),
|
||||||
exchangeCfg.Passphrase,
|
string(exchangeCfg.Passphrase),
|
||||||
)
|
)
|
||||||
case "lighter":
|
case "lighter":
|
||||||
if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" {
|
if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" {
|
||||||
// Lighter only supports mainnet
|
// Lighter only supports mainnet
|
||||||
tempTrader, createErr = trader.NewLighterTraderV2(
|
tempTrader, createErr = trader.NewLighterTraderV2(
|
||||||
exchangeCfg.LighterWalletAddr,
|
exchangeCfg.LighterWalletAddr,
|
||||||
exchangeCfg.LighterAPIKeyPrivateKey,
|
string(exchangeCfg.LighterAPIKeyPrivateKey),
|
||||||
exchangeCfg.LighterAPIKeyIndex,
|
exchangeCfg.LighterAPIKeyIndex,
|
||||||
false, // Always use mainnet for Lighter
|
false, // Always use mainnet for Lighter
|
||||||
)
|
)
|
||||||
@@ -1246,12 +1247,13 @@ func (s *Server) handleClosePosition(c *gin.Context) {
|
|||||||
var createErr error
|
var createErr error
|
||||||
|
|
||||||
// Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID)
|
// Use ExchangeType (e.g., "binance") instead of ExchangeID (which is now UUID)
|
||||||
|
// Convert EncryptedString fields to string
|
||||||
switch exchangeCfg.ExchangeType {
|
switch exchangeCfg.ExchangeType {
|
||||||
case "binance":
|
case "binance":
|
||||||
tempTrader = trader.NewFuturesTrader(exchangeCfg.APIKey, exchangeCfg.SecretKey, userID)
|
tempTrader = trader.NewFuturesTrader(string(exchangeCfg.APIKey), string(exchangeCfg.SecretKey), userID)
|
||||||
case "hyperliquid":
|
case "hyperliquid":
|
||||||
tempTrader, createErr = trader.NewHyperliquidTrader(
|
tempTrader, createErr = trader.NewHyperliquidTrader(
|
||||||
exchangeCfg.APIKey,
|
string(exchangeCfg.APIKey),
|
||||||
exchangeCfg.HyperliquidWalletAddr,
|
exchangeCfg.HyperliquidWalletAddr,
|
||||||
exchangeCfg.Testnet,
|
exchangeCfg.Testnet,
|
||||||
)
|
)
|
||||||
@@ -1259,31 +1261,31 @@ func (s *Server) handleClosePosition(c *gin.Context) {
|
|||||||
tempTrader, createErr = trader.NewAsterTrader(
|
tempTrader, createErr = trader.NewAsterTrader(
|
||||||
exchangeCfg.AsterUser,
|
exchangeCfg.AsterUser,
|
||||||
exchangeCfg.AsterSigner,
|
exchangeCfg.AsterSigner,
|
||||||
exchangeCfg.AsterPrivateKey,
|
string(exchangeCfg.AsterPrivateKey),
|
||||||
)
|
)
|
||||||
case "bybit":
|
case "bybit":
|
||||||
tempTrader = trader.NewBybitTrader(
|
tempTrader = trader.NewBybitTrader(
|
||||||
exchangeCfg.APIKey,
|
string(exchangeCfg.APIKey),
|
||||||
exchangeCfg.SecretKey,
|
string(exchangeCfg.SecretKey),
|
||||||
)
|
)
|
||||||
case "okx":
|
case "okx":
|
||||||
tempTrader = trader.NewOKXTrader(
|
tempTrader = trader.NewOKXTrader(
|
||||||
exchangeCfg.APIKey,
|
string(exchangeCfg.APIKey),
|
||||||
exchangeCfg.SecretKey,
|
string(exchangeCfg.SecretKey),
|
||||||
exchangeCfg.Passphrase,
|
string(exchangeCfg.Passphrase),
|
||||||
)
|
)
|
||||||
case "bitget":
|
case "bitget":
|
||||||
tempTrader = trader.NewBitgetTrader(
|
tempTrader = trader.NewBitgetTrader(
|
||||||
exchangeCfg.APIKey,
|
string(exchangeCfg.APIKey),
|
||||||
exchangeCfg.SecretKey,
|
string(exchangeCfg.SecretKey),
|
||||||
exchangeCfg.Passphrase,
|
string(exchangeCfg.Passphrase),
|
||||||
)
|
)
|
||||||
case "lighter":
|
case "lighter":
|
||||||
if exchangeCfg.LighterWalletAddr != "" && exchangeCfg.LighterAPIKeyPrivateKey != "" {
|
if exchangeCfg.LighterWalletAddr != "" && string(exchangeCfg.LighterAPIKeyPrivateKey) != "" {
|
||||||
// Lighter only supports mainnet
|
// Lighter only supports mainnet
|
||||||
tempTrader, createErr = trader.NewLighterTraderV2(
|
tempTrader, createErr = trader.NewLighterTraderV2(
|
||||||
exchangeCfg.LighterWalletAddr,
|
exchangeCfg.LighterWalletAddr,
|
||||||
exchangeCfg.LighterAPIKeyPrivateKey,
|
string(exchangeCfg.LighterAPIKeyPrivateKey),
|
||||||
exchangeCfg.LighterAPIKeyIndex,
|
exchangeCfg.LighterAPIKeyIndex,
|
||||||
false, // Always use mainnet for Lighter
|
false, // Always use mainnet for Lighter
|
||||||
)
|
)
|
||||||
|
|||||||
+10
-8
@@ -549,32 +549,34 @@ func (s *Server) runRealAITest(userID, modelID, systemPrompt, userPrompt string)
|
|||||||
var aiClient mcp.AIClient
|
var aiClient mcp.AIClient
|
||||||
provider := model.Provider
|
provider := model.Provider
|
||||||
|
|
||||||
|
// Convert EncryptedString to string for API key
|
||||||
|
apiKey := string(model.APIKey)
|
||||||
switch provider {
|
switch provider {
|
||||||
case "qwen":
|
case "qwen":
|
||||||
aiClient = mcp.NewQwenClient()
|
aiClient = mcp.NewQwenClient()
|
||||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||||
case "deepseek":
|
case "deepseek":
|
||||||
aiClient = mcp.NewDeepSeekClient()
|
aiClient = mcp.NewDeepSeekClient()
|
||||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||||
case "claude":
|
case "claude":
|
||||||
aiClient = mcp.NewClaudeClient()
|
aiClient = mcp.NewClaudeClient()
|
||||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||||
case "kimi":
|
case "kimi":
|
||||||
aiClient = mcp.NewKimiClient()
|
aiClient = mcp.NewKimiClient()
|
||||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||||
case "gemini":
|
case "gemini":
|
||||||
aiClient = mcp.NewGeminiClient()
|
aiClient = mcp.NewGeminiClient()
|
||||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||||
case "grok":
|
case "grok":
|
||||||
aiClient = mcp.NewGrokClient()
|
aiClient = mcp.NewGrokClient()
|
||||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||||
case "openai":
|
case "openai":
|
||||||
aiClient = mcp.NewOpenAIClient()
|
aiClient = mcp.NewOpenAIClient()
|
||||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||||
default:
|
default:
|
||||||
// Use generic client
|
// Use generic client
|
||||||
aiClient = mcp.NewClient()
|
aiClient = mcp.NewClient()
|
||||||
aiClient.SetAPIKey(model.APIKey, model.CustomAPIURL, model.CustomModelName)
|
aiClient.SetAPIKey(apiKey, model.CustomAPIURL, model.CustomModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call AI API
|
// Call AI API
|
||||||
|
|||||||
@@ -76,8 +76,8 @@ func enforceRetentionDB(maxRuns int) {
|
|||||||
query := `
|
query := `
|
||||||
SELECT run_id FROM backtest_runs
|
SELECT run_id FROM backtest_runs
|
||||||
WHERE state IN (?, ?, ?, ?)
|
WHERE state IN (?, ?, ?, ?)
|
||||||
ORDER BY datetime(updated_at) DESC
|
ORDER BY updated_at DESC
|
||||||
LIMIT -1 OFFSET ?
|
OFFSET ?
|
||||||
`
|
`
|
||||||
rows, err := persistenceDB.Query(query,
|
rows, err := persistenceDB.Query(query,
|
||||||
finalStates[0], finalStates[1], finalStates[2], finalStates[3], maxRuns)
|
finalStates[0], finalStates[1], finalStates[2], finalStates[3], maxRuns)
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ func loadRunMetadataDB(runID string) (*RunMetadata, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func loadRunIDsDB() ([]string, error) {
|
func loadRunIDsDB() ([]string, error) {
|
||||||
rows, err := persistenceDB.Query(`SELECT run_id FROM backtest_runs ORDER BY datetime(updated_at) DESC`)
|
rows, err := persistenceDB.Query(`SELECT run_id FROM backtest_runs ORDER BY updated_at DESC`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -278,9 +278,9 @@ func loadDecisionTraceDB(runID string, cycle int) (*store.DecisionRecord, error)
|
|||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
var err error
|
var err error
|
||||||
if cycle > 0 {
|
if cycle > 0 {
|
||||||
rows, err = persistenceDB.Query(query+` AND cycle = ? ORDER BY datetime(created_at) DESC LIMIT 1`, runID, cycle)
|
rows, err = persistenceDB.Query(query+` AND cycle = ? ORDER BY created_at DESC LIMIT 1`, runID, cycle)
|
||||||
} else {
|
} else {
|
||||||
rows, err = persistenceDB.Query(query+` ORDER BY datetime(created_at) DESC LIMIT 1`, runID)
|
rows, err = persistenceDB.Query(query+` ORDER BY created_at DESC LIMIT 1`, runID)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -461,7 +461,7 @@ func listIndexEntriesDB() ([]RunIndexEntry, error) {
|
|||||||
rows, err := persistenceDB.Query(`
|
rows, err := persistenceDB.Query(`
|
||||||
SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct, created_at, updated_at, config_json
|
SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct, created_at, updated_at, config_json
|
||||||
FROM backtest_runs
|
FROM backtest_runs
|
||||||
ORDER BY datetime(updated_at) DESC
|
ORDER BY updated_at DESC
|
||||||
`)
|
`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -20,6 +20,16 @@ type Config struct {
|
|||||||
RegistrationEnabled bool
|
RegistrationEnabled bool
|
||||||
MaxUsers int // Maximum number of users allowed (0 = unlimited, default = 10)
|
MaxUsers int // Maximum number of users allowed (0 = unlimited, default = 10)
|
||||||
|
|
||||||
|
// Database configuration
|
||||||
|
DBType string // sqlite or postgres
|
||||||
|
DBPath string // SQLite database file path
|
||||||
|
DBHost string // PostgreSQL host
|
||||||
|
DBPort int // PostgreSQL port
|
||||||
|
DBUser string // PostgreSQL user
|
||||||
|
DBPassword string // PostgreSQL password
|
||||||
|
DBName string // PostgreSQL database name
|
||||||
|
DBSSLMode string // PostgreSQL SSL mode
|
||||||
|
|
||||||
// Security configuration
|
// Security configuration
|
||||||
// TransportEncryption enables browser-side encryption for API keys
|
// TransportEncryption enables browser-side encryption for API keys
|
||||||
// Requires HTTPS or localhost. Set to false for HTTP access via IP.
|
// Requires HTTPS or localhost. Set to false for HTTP access via IP.
|
||||||
@@ -43,6 +53,14 @@ func Init() {
|
|||||||
RegistrationEnabled: true,
|
RegistrationEnabled: true,
|
||||||
MaxUsers: 10, // Default: 10 users allowed
|
MaxUsers: 10, // Default: 10 users allowed
|
||||||
ExperienceImprovement: true, // Default: enabled to help improve the product
|
ExperienceImprovement: true, // Default: enabled to help improve the product
|
||||||
|
// Database defaults
|
||||||
|
DBType: "sqlite",
|
||||||
|
DBPath: "data/data.db",
|
||||||
|
DBHost: "localhost",
|
||||||
|
DBPort: 5432,
|
||||||
|
DBUser: "postgres",
|
||||||
|
DBName: "nofx",
|
||||||
|
DBSSLMode: "disable",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load from environment variables
|
// Load from environment variables
|
||||||
@@ -86,6 +104,34 @@ func Init() {
|
|||||||
cfg.AlpacaSecretKey = os.Getenv("ALPACA_SECRET_KEY")
|
cfg.AlpacaSecretKey = os.Getenv("ALPACA_SECRET_KEY")
|
||||||
cfg.TwelveDataKey = os.Getenv("TWELVEDATA_API_KEY")
|
cfg.TwelveDataKey = os.Getenv("TWELVEDATA_API_KEY")
|
||||||
|
|
||||||
|
// Database configuration
|
||||||
|
if v := os.Getenv("DB_TYPE"); v != "" {
|
||||||
|
cfg.DBType = strings.ToLower(v)
|
||||||
|
}
|
||||||
|
if v := os.Getenv("DB_PATH"); v != "" {
|
||||||
|
cfg.DBPath = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("DB_HOST"); v != "" {
|
||||||
|
cfg.DBHost = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("DB_PORT"); v != "" {
|
||||||
|
if port, err := strconv.Atoi(v); err == nil && port > 0 {
|
||||||
|
cfg.DBPort = port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v := os.Getenv("DB_USER"); v != "" {
|
||||||
|
cfg.DBUser = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("DB_PASSWORD"); v != "" {
|
||||||
|
cfg.DBPassword = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("DB_NAME"); v != "" {
|
||||||
|
cfg.DBName = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("DB_SSLMODE"); v != "" {
|
||||||
|
cfg.DBSSLMode = v
|
||||||
|
}
|
||||||
|
|
||||||
global = cfg
|
global = cfg
|
||||||
|
|
||||||
// Initialize experience improvement (installation ID will be set after database init)
|
// Initialize experience improvement (installation ID will be set after database init)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"database/sql/driver"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -392,3 +393,77 @@ func GenerateDataKey() (string, error) {
|
|||||||
}
|
}
|
||||||
return base64.StdEncoding.EncodeToString(key), nil
|
return base64.StdEncoding.EncodeToString(key), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// EncryptedString - GORM custom type for automatic encryption/decryption
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
// Global crypto service for EncryptedString
|
||||||
|
var globalCryptoService *CryptoService
|
||||||
|
|
||||||
|
// SetGlobalCryptoService sets the global crypto service for EncryptedString
|
||||||
|
func SetGlobalCryptoService(cs *CryptoService) {
|
||||||
|
globalCryptoService = cs
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncryptedString is a custom type that automatically encrypts on save and decrypts on load
|
||||||
|
// Usage: Use EncryptedString instead of string for sensitive fields in GORM models
|
||||||
|
type EncryptedString string
|
||||||
|
|
||||||
|
// Scan implements sql.Scanner - called when reading from database
|
||||||
|
// Automatically decrypts the value
|
||||||
|
func (es *EncryptedString) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
*es = ""
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var str string
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
str = v
|
||||||
|
case []byte:
|
||||||
|
str = string(v)
|
||||||
|
default:
|
||||||
|
*es = ""
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt if crypto service is set
|
||||||
|
if globalCryptoService != nil && str != "" && globalCryptoService.IsEncryptedStorageValue(str) {
|
||||||
|
decrypted, err := globalCryptoService.DecryptFromStorage(str)
|
||||||
|
if err != nil {
|
||||||
|
// If decryption fails, return the original value
|
||||||
|
*es = EncryptedString(str)
|
||||||
|
} else {
|
||||||
|
*es = EncryptedString(decrypted)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
*es = EncryptedString(str)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements driver.Valuer - called when writing to database
|
||||||
|
// Automatically encrypts the value
|
||||||
|
func (es EncryptedString) Value() (driver.Value, error) {
|
||||||
|
if es == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt if crypto service is set
|
||||||
|
if globalCryptoService != nil {
|
||||||
|
encrypted, err := globalCryptoService.EncryptForStorage(string(es))
|
||||||
|
if err != nil {
|
||||||
|
// If encryption fails, return the original value
|
||||||
|
return string(es), nil
|
||||||
|
}
|
||||||
|
return encrypted, nil
|
||||||
|
}
|
||||||
|
return string(es), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the plaintext string value
|
||||||
|
func (es EncryptedString) String() string {
|
||||||
|
return string(es)
|
||||||
|
}
|
||||||
|
|||||||
+2
-2
@@ -101,8 +101,8 @@ func (e *DebateEngine) InitializeClients(participants []*store.DebateParticipant
|
|||||||
client = mcp.New()
|
client = mcp.New()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure client
|
// Configure client (convert EncryptedString to string)
|
||||||
client.SetAPIKey(aiModel.APIKey, aiModel.CustomAPIURL, aiModel.CustomModelName)
|
client.SetAPIKey(string(aiModel.APIKey), aiModel.CustomAPIURL, aiModel.CustomModelName)
|
||||||
|
|
||||||
e.clients[p.AIModelID] = client
|
e.clients[p.AIModelID] = client
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,6 +51,12 @@ require (
|
|||||||
github.com/goccy/go-json v0.10.4 // indirect
|
github.com/goccy/go-json v0.10.4 // indirect
|
||||||
github.com/goccy/go-yaml v1.18.0 // indirect
|
github.com/goccy/go-yaml v1.18.0 // indirect
|
||||||
github.com/holiman/uint256 v1.3.2 // indirect
|
github.com/holiman/uint256 v1.3.2 // indirect
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
|
github.com/jackc/pgx/v5 v5.6.0 // indirect
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/josharian/intern v1.0.0 // indirect
|
github.com/josharian/intern v1.0.0 // indirect
|
||||||
github.com/jpillora/backoff v1.0.0 // indirect
|
github.com/jpillora/backoff v1.0.0 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
@@ -94,6 +100,9 @@ require (
|
|||||||
golang.org/x/tools v0.36.0 // indirect
|
golang.org/x/tools v0.36.0 // indirect
|
||||||
google.golang.org/protobuf v1.36.9 // indirect
|
google.golang.org/protobuf v1.36.9 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
gorm.io/driver/postgres v1.6.0 // indirect
|
||||||
|
gorm.io/driver/sqlite v1.6.0 // indirect
|
||||||
|
gorm.io/gorm v1.31.1 // indirect
|
||||||
howett.net/plist v1.0.1 // indirect
|
howett.net/plist v1.0.1 // indirect
|
||||||
modernc.org/libc v1.66.10 // indirect
|
modernc.org/libc v1.66.10 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
|
|||||||
@@ -111,7 +111,19 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
|
|||||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/holiman/uint256 v1.3.2 h1:a9EgMPSC1AAaj1SZL5zIQD3WbwTuHrMGOerLjGmM/TA=
|
github.com/holiman/uint256 v1.3.2 h1:a9EgMPSC1AAaj1SZL5zIQD3WbwTuHrMGOerLjGmM/TA=
|
||||||
github.com/holiman/uint256 v1.3.2/go.mod h1:EOMSn4q6Nyt9P6efbI3bueV4e1b3dGlUCXeiRV4ng7E=
|
github.com/holiman/uint256 v1.3.2/go.mod h1:EOMSn4q6Nyt9P6efbI3bueV4e1b3dGlUCXeiRV4ng7E=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
|
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
|
||||||
|
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||||
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
|
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||||
@@ -284,6 +296,12 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
|||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
|
||||||
|
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
|
||||||
|
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||||
|
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||||
|
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||||
|
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||||
howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM=
|
howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM=
|
||||||
howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
||||||
modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4=
|
modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4=
|
||||||
|
|||||||
@@ -36,30 +36,44 @@ func main() {
|
|||||||
cfg := config.Get()
|
cfg := config.Get()
|
||||||
logger.Info("✅ Configuration loaded")
|
logger.Info("✅ Configuration loaded")
|
||||||
|
|
||||||
// Initialize database from environment variables
|
// Initialize encryption service BEFORE database (so EncryptedString can decrypt on read)
|
||||||
// DB_TYPE: sqlite (default) or postgres
|
logger.Info("🔐 Initializing encryption service...")
|
||||||
// For SQLite: DB_PATH (default: data/data.db)
|
cryptoService, err := crypto.NewCryptoService()
|
||||||
// For PostgreSQL: DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME, DB_SSLMODE
|
if err != nil {
|
||||||
dbPath := os.Getenv("DB_PATH")
|
logger.Fatalf("❌ Failed to initialize encryption service: %v", err)
|
||||||
if dbPath == "" {
|
|
||||||
dbPath = "data/data.db"
|
|
||||||
}
|
}
|
||||||
// For backward compatibility: command line arg overrides env var (SQLite only)
|
crypto.SetGlobalCryptoService(cryptoService)
|
||||||
|
logger.Info("✅ Encryption service initialized successfully")
|
||||||
|
|
||||||
|
// Initialize database from configuration
|
||||||
|
// For backward compatibility: command line arg overrides config (SQLite only)
|
||||||
if len(os.Args) > 1 {
|
if len(os.Args) > 1 {
|
||||||
dbPath = os.Args[1]
|
cfg.DBPath = os.Args[1]
|
||||||
os.Setenv("DB_PATH", dbPath)
|
|
||||||
}
|
}
|
||||||
// Ensure data directory exists (for SQLite)
|
// Ensure data directory exists (for SQLite)
|
||||||
if os.Getenv("DB_TYPE") == "" || os.Getenv("DB_TYPE") == "sqlite" {
|
if cfg.DBType == "sqlite" {
|
||||||
if dir := filepath.Dir(dbPath); dir != "." {
|
if dir := filepath.Dir(cfg.DBPath); dir != "." {
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
logger.Errorf("Failed to create data directory: %v", err)
|
logger.Errorf("Failed to create data directory: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("📋 Initializing database...")
|
logger.Infof("📋 Initializing database (%s)...", cfg.DBType)
|
||||||
st, err := store.NewFromEnv()
|
dbType := store.DBTypeSQLite
|
||||||
|
if cfg.DBType == "postgres" {
|
||||||
|
dbType = store.DBTypePostgres
|
||||||
|
}
|
||||||
|
st, err := store.NewWithConfig(store.DBConfig{
|
||||||
|
Type: dbType,
|
||||||
|
Path: cfg.DBPath,
|
||||||
|
Host: cfg.DBHost,
|
||||||
|
Port: cfg.DBPort,
|
||||||
|
User: cfg.DBUser,
|
||||||
|
Password: cfg.DBPassword,
|
||||||
|
DBName: cfg.DBName,
|
||||||
|
SSLMode: cfg.DBSSLMode,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("❌ Failed to initialize database: %v", err)
|
logger.Fatalf("❌ Failed to initialize database: %v", err)
|
||||||
}
|
}
|
||||||
@@ -69,40 +83,6 @@ func main() {
|
|||||||
// Initialize installation ID for experience improvement (anonymous statistics)
|
// Initialize installation ID for experience improvement (anonymous statistics)
|
||||||
initInstallationID(st)
|
initInstallationID(st)
|
||||||
|
|
||||||
// Initialize encryption service
|
|
||||||
logger.Info("🔐 Initializing encryption service...")
|
|
||||||
cryptoService, err := crypto.NewCryptoService()
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatalf("❌ Failed to initialize encryption service: %v", err)
|
|
||||||
}
|
|
||||||
encryptFunc := func(plaintext string) string {
|
|
||||||
if plaintext == "" {
|
|
||||||
return plaintext
|
|
||||||
}
|
|
||||||
encrypted, err := cryptoService.EncryptForStorage(plaintext)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warnf("⚠️ Encryption failed: %v", err)
|
|
||||||
return plaintext
|
|
||||||
}
|
|
||||||
return encrypted
|
|
||||||
}
|
|
||||||
decryptFunc := func(encrypted string) string {
|
|
||||||
if encrypted == "" {
|
|
||||||
return encrypted
|
|
||||||
}
|
|
||||||
if !cryptoService.IsEncryptedStorageValue(encrypted) {
|
|
||||||
return encrypted
|
|
||||||
}
|
|
||||||
decrypted, err := cryptoService.DecryptFromStorage(encrypted)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warnf("⚠️ Decryption failed: %v", err)
|
|
||||||
return encrypted
|
|
||||||
}
|
|
||||||
return decrypted
|
|
||||||
}
|
|
||||||
st.SetCryptoFuncs(encryptFunc, decryptFunc)
|
|
||||||
logger.Info("✅ Encryption service initialized successfully")
|
|
||||||
|
|
||||||
// Set JWT secret
|
// Set JWT secret
|
||||||
auth.SetJWTSecret(cfg.JWTSecret)
|
auth.SetJWTSecret(cfg.JWTSecret)
|
||||||
logger.Info("🔑 JWT secret configured")
|
logger.Info("🔑 JWT secret configured")
|
||||||
|
|||||||
+19
-19
@@ -664,46 +664,46 @@ func (tm *TraderManager) addTraderFromStore(traderCfg *store.Trader, aiModelCfg
|
|||||||
StrategyConfig: strategyConfig,
|
StrategyConfig: strategyConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set API keys based on exchange type
|
// Set API keys based on exchange type (convert EncryptedString to string)
|
||||||
switch exchangeCfg.ExchangeType {
|
switch exchangeCfg.ExchangeType {
|
||||||
case "binance":
|
case "binance":
|
||||||
traderConfig.BinanceAPIKey = exchangeCfg.APIKey
|
traderConfig.BinanceAPIKey = string(exchangeCfg.APIKey)
|
||||||
traderConfig.BinanceSecretKey = exchangeCfg.SecretKey
|
traderConfig.BinanceSecretKey = string(exchangeCfg.SecretKey)
|
||||||
case "bybit":
|
case "bybit":
|
||||||
traderConfig.BybitAPIKey = exchangeCfg.APIKey
|
traderConfig.BybitAPIKey = string(exchangeCfg.APIKey)
|
||||||
traderConfig.BybitSecretKey = exchangeCfg.SecretKey
|
traderConfig.BybitSecretKey = string(exchangeCfg.SecretKey)
|
||||||
case "okx":
|
case "okx":
|
||||||
traderConfig.OKXAPIKey = exchangeCfg.APIKey
|
traderConfig.OKXAPIKey = string(exchangeCfg.APIKey)
|
||||||
traderConfig.OKXSecretKey = exchangeCfg.SecretKey
|
traderConfig.OKXSecretKey = string(exchangeCfg.SecretKey)
|
||||||
traderConfig.OKXPassphrase = exchangeCfg.Passphrase
|
traderConfig.OKXPassphrase = string(exchangeCfg.Passphrase)
|
||||||
case "bitget":
|
case "bitget":
|
||||||
traderConfig.BitgetAPIKey = exchangeCfg.APIKey
|
traderConfig.BitgetAPIKey = string(exchangeCfg.APIKey)
|
||||||
traderConfig.BitgetSecretKey = exchangeCfg.SecretKey
|
traderConfig.BitgetSecretKey = string(exchangeCfg.SecretKey)
|
||||||
traderConfig.BitgetPassphrase = exchangeCfg.Passphrase
|
traderConfig.BitgetPassphrase = string(exchangeCfg.Passphrase)
|
||||||
case "hyperliquid":
|
case "hyperliquid":
|
||||||
traderConfig.HyperliquidPrivateKey = exchangeCfg.APIKey
|
traderConfig.HyperliquidPrivateKey = string(exchangeCfg.APIKey)
|
||||||
traderConfig.HyperliquidWalletAddr = exchangeCfg.HyperliquidWalletAddr
|
traderConfig.HyperliquidWalletAddr = exchangeCfg.HyperliquidWalletAddr
|
||||||
case "aster":
|
case "aster":
|
||||||
traderConfig.AsterUser = exchangeCfg.AsterUser
|
traderConfig.AsterUser = exchangeCfg.AsterUser
|
||||||
traderConfig.AsterSigner = exchangeCfg.AsterSigner
|
traderConfig.AsterSigner = exchangeCfg.AsterSigner
|
||||||
traderConfig.AsterPrivateKey = exchangeCfg.AsterPrivateKey
|
traderConfig.AsterPrivateKey = string(exchangeCfg.AsterPrivateKey)
|
||||||
case "lighter":
|
case "lighter":
|
||||||
traderConfig.LighterPrivateKey = exchangeCfg.LighterPrivateKey
|
traderConfig.LighterPrivateKey = string(exchangeCfg.LighterPrivateKey)
|
||||||
traderConfig.LighterWalletAddr = exchangeCfg.LighterWalletAddr
|
traderConfig.LighterWalletAddr = exchangeCfg.LighterWalletAddr
|
||||||
traderConfig.LighterAPIKeyPrivateKey = exchangeCfg.LighterAPIKeyPrivateKey
|
traderConfig.LighterAPIKeyPrivateKey = string(exchangeCfg.LighterAPIKeyPrivateKey)
|
||||||
traderConfig.LighterAPIKeyIndex = exchangeCfg.LighterAPIKeyIndex
|
traderConfig.LighterAPIKeyIndex = exchangeCfg.LighterAPIKeyIndex
|
||||||
traderConfig.LighterTestnet = exchangeCfg.Testnet
|
traderConfig.LighterTestnet = exchangeCfg.Testnet
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set API keys based on AI model
|
// Set API keys based on AI model (convert EncryptedString to string)
|
||||||
switch aiModelCfg.Provider {
|
switch aiModelCfg.Provider {
|
||||||
case "qwen":
|
case "qwen":
|
||||||
traderConfig.QwenKey = aiModelCfg.APIKey
|
traderConfig.QwenKey = string(aiModelCfg.APIKey)
|
||||||
case "deepseek":
|
case "deepseek":
|
||||||
traderConfig.DeepSeekKey = aiModelCfg.APIKey
|
traderConfig.DeepSeekKey = string(aiModelCfg.APIKey)
|
||||||
default:
|
default:
|
||||||
// For other providers (grok, openai, claude, gemini, kimi, etc.), use CustomAPIKey
|
// For other providers (grok, openai, claude, gemini, kimi, etc.), use CustomAPIKey
|
||||||
traderConfig.CustomAPIKey = aiModelCfg.APIKey
|
traderConfig.CustomAPIKey = string(aiModelCfg.APIKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create trader instance
|
// Create trader instance
|
||||||
|
|||||||
+90
-174
@@ -1,71 +1,52 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"nofx/crypto"
|
||||||
"nofx/logger"
|
"nofx/logger"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AIModelStore AI model storage
|
// AIModelStore AI model storage
|
||||||
type AIModelStore struct {
|
type AIModelStore struct {
|
||||||
db *sql.DB
|
db *gorm.DB
|
||||||
encryptFunc func(string) string
|
|
||||||
decryptFunc func(string) string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AIModel AI model configuration
|
// AIModel AI model configuration
|
||||||
type AIModel struct {
|
type AIModel struct {
|
||||||
ID string `json:"id"`
|
ID string `gorm:"primaryKey" json:"id"`
|
||||||
UserID string `json:"user_id"`
|
UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"`
|
||||||
Name string `json:"name"`
|
Name string `gorm:"not null" json:"name"`
|
||||||
Provider string `json:"provider"`
|
Provider string `gorm:"not null" json:"provider"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `gorm:"default:false" json:"enabled"`
|
||||||
APIKey string `json:"apiKey"`
|
APIKey crypto.EncryptedString `gorm:"column:api_key;default:''" json:"apiKey"`
|
||||||
CustomAPIURL string `json:"customApiUrl"`
|
CustomAPIURL string `gorm:"column:custom_api_url;default:''" json:"customApiUrl"`
|
||||||
CustomModelName string `json:"customModelName"`
|
CustomModelName string `gorm:"column:custom_model_name;default:''" json:"customModelName"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (AIModel) TableName() string { return "ai_models" }
|
||||||
|
|
||||||
|
// NewAIModelStore creates a new AIModelStore
|
||||||
|
func NewAIModelStore(db *gorm.DB) *AIModelStore {
|
||||||
|
return &AIModelStore{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AIModelStore) initTables() error {
|
func (s *AIModelStore) initTables() error {
|
||||||
_, err := s.db.Exec(`
|
// For PostgreSQL with existing table, skip AutoMigrate
|
||||||
CREATE TABLE IF NOT EXISTS ai_models (
|
if s.db.Dialector.Name() == "postgres" {
|
||||||
id TEXT PRIMARY KEY,
|
var tableExists int64
|
||||||
user_id TEXT NOT NULL DEFAULT 'default',
|
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'ai_models'`).Scan(&tableExists)
|
||||||
name TEXT NOT NULL,
|
if tableExists > 0 {
|
||||||
provider TEXT NOT NULL,
|
return nil
|
||||||
enabled BOOLEAN DEFAULT 0,
|
}
|
||||||
api_key TEXT DEFAULT '',
|
|
||||||
custom_api_url TEXT DEFAULT '',
|
|
||||||
custom_model_name TEXT DEFAULT '',
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
return s.db.AutoMigrate(&AIModel{})
|
||||||
// Trigger
|
|
||||||
_, err = s.db.Exec(`
|
|
||||||
CREATE TRIGGER IF NOT EXISTS update_ai_models_updated_at
|
|
||||||
AFTER UPDATE ON ai_models
|
|
||||||
BEGIN
|
|
||||||
UPDATE ai_models SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
|
||||||
END
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Backward compatibility: add potentially missing columns
|
|
||||||
s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_api_url TEXT DEFAULT ''`)
|
|
||||||
s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_model_name TEXT DEFAULT ''`)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AIModelStore) initDefaultData() error {
|
func (s *AIModelStore) initDefaultData() error {
|
||||||
@@ -73,51 +54,13 @@ func (s *AIModelStore) initDefaultData() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AIModelStore) encrypt(plaintext string) string {
|
|
||||||
if s.encryptFunc != nil {
|
|
||||||
return s.encryptFunc(plaintext)
|
|
||||||
}
|
|
||||||
return plaintext
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *AIModelStore) decrypt(encrypted string) string {
|
|
||||||
if s.decryptFunc != nil {
|
|
||||||
return s.decryptFunc(encrypted)
|
|
||||||
}
|
|
||||||
return encrypted
|
|
||||||
}
|
|
||||||
|
|
||||||
// List retrieves user's AI model list
|
// List retrieves user's AI model list
|
||||||
func (s *AIModelStore) List(userID string) ([]*AIModel, error) {
|
func (s *AIModelStore) List(userID string) ([]*AIModel, error) {
|
||||||
rows, err := s.db.Query(`
|
var models []*AIModel
|
||||||
SELECT id, user_id, name, provider, enabled, api_key,
|
err := s.db.Where("user_id = ?", userID).Order("id").Find(&models).Error
|
||||||
COALESCE(custom_api_url, '') as custom_api_url,
|
|
||||||
COALESCE(custom_model_name, '') as custom_model_name,
|
|
||||||
created_at, updated_at
|
|
||||||
FROM ai_models WHERE user_id = ? ORDER BY id
|
|
||||||
`, userID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
models := make([]*AIModel, 0)
|
|
||||||
for rows.Next() {
|
|
||||||
var model AIModel
|
|
||||||
var createdAt, updatedAt string
|
|
||||||
err := rows.Scan(
|
|
||||||
&model.ID, &model.UserID, &model.Name, &model.Provider,
|
|
||||||
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
|
|
||||||
&createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
model.APIKey = s.decrypt(model.APIKey)
|
|
||||||
models = append(models, &model)
|
|
||||||
}
|
|
||||||
return models, nil
|
return models, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,27 +83,15 @@ func (s *AIModelStore) Get(userID, modelID string) (*AIModel, error) {
|
|||||||
|
|
||||||
for _, uid := range candidates {
|
for _, uid := range candidates {
|
||||||
var model AIModel
|
var model AIModel
|
||||||
var createdAt, updatedAt string
|
err := s.db.Where("user_id = ? AND id = ?", uid, modelID).First(&model).Error
|
||||||
err := s.db.QueryRow(`
|
|
||||||
SELECT id, user_id, name, provider, enabled, api_key,
|
|
||||||
COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at
|
|
||||||
FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1
|
|
||||||
`, uid, modelID).Scan(
|
|
||||||
&model.ID, &model.UserID, &model.Name, &model.Provider,
|
|
||||||
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
|
|
||||||
&createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
model.APIKey = s.decrypt(model.APIKey)
|
|
||||||
return &model, nil
|
return &model, nil
|
||||||
}
|
}
|
||||||
if !errors.Is(err, sql.ErrNoRows) {
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, sql.ErrNoRows
|
return nil, gorm.ErrRecordNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByID retrieves an AI model by ID only (for debate engine)
|
// GetByID retrieves an AI model by ID only (for debate engine)
|
||||||
@@ -170,22 +101,10 @@ func (s *AIModelStore) GetByID(modelID string) (*AIModel, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var model AIModel
|
var model AIModel
|
||||||
var createdAt, updatedAt string
|
err := s.db.Where("id = ?", modelID).First(&model).Error
|
||||||
err := s.db.QueryRow(`
|
|
||||||
SELECT id, user_id, name, provider, enabled, api_key,
|
|
||||||
COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at
|
|
||||||
FROM ai_models WHERE id = ? LIMIT 1
|
|
||||||
`, modelID).Scan(
|
|
||||||
&model.ID, &model.UserID, &model.Name, &model.Provider,
|
|
||||||
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
|
|
||||||
&createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
model.APIKey = s.decrypt(model.APIKey)
|
|
||||||
return &model, nil
|
return &model, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,7 +117,7 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return model, nil
|
return model, nil
|
||||||
}
|
}
|
||||||
if !errors.Is(err, sql.ErrNoRows) {
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if userID != "default" {
|
if userID != "default" {
|
||||||
@@ -209,23 +128,12 @@ func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) {
|
|||||||
|
|
||||||
func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) {
|
func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) {
|
||||||
var model AIModel
|
var model AIModel
|
||||||
var createdAt, updatedAt string
|
err := s.db.Where("user_id = ? AND enabled = ?", userID, true).
|
||||||
err := s.db.QueryRow(`
|
Order("updated_at DESC, id ASC").
|
||||||
SELECT id, user_id, name, provider, enabled, api_key,
|
First(&model).Error
|
||||||
COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at
|
|
||||||
FROM ai_models WHERE user_id = ? AND enabled = 1
|
|
||||||
ORDER BY datetime(updated_at) DESC, id ASC LIMIT 1
|
|
||||||
`, userID).Scan(
|
|
||||||
&model.ID, &model.UserID, &model.Name, &model.Provider,
|
|
||||||
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
|
|
||||||
&createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
model.APIKey = s.decrypt(model.APIKey)
|
|
||||||
return &model, nil
|
return &model, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -233,44 +141,38 @@ func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) {
|
|||||||
// IMPORTANT: If apiKey is empty string, the existing API key will be preserved (not overwritten)
|
// IMPORTANT: If apiKey is empty string, the existing API key will be preserved (not overwritten)
|
||||||
func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error {
|
func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error {
|
||||||
// Try exact ID match first
|
// Try exact ID match first
|
||||||
var existingID string
|
var existingModel AIModel
|
||||||
err := s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1`, userID, id).Scan(&existingID)
|
err := s.db.Where("user_id = ? AND id = ?", userID, id).First(&existingModel).Error
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// If apiKey is empty, preserve the existing API key
|
// Update existing model
|
||||||
if apiKey == "" {
|
updates := map[string]interface{}{
|
||||||
_, err = s.db.Exec(`
|
"enabled": enabled,
|
||||||
UPDATE ai_models SET enabled = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
|
"custom_api_url": customAPIURL,
|
||||||
WHERE id = ? AND user_id = ?
|
"custom_model_name": customModelName,
|
||||||
`, enabled, customAPIURL, customModelName, existingID, userID)
|
"updated_at": time.Now(),
|
||||||
} else {
|
|
||||||
encryptedAPIKey := s.encrypt(apiKey)
|
|
||||||
_, err = s.db.Exec(`
|
|
||||||
UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
|
|
||||||
WHERE id = ? AND user_id = ?
|
|
||||||
`, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID)
|
|
||||||
}
|
}
|
||||||
return err
|
// If apiKey is not empty, update it (encryption handled by crypto.EncryptedString)
|
||||||
|
if apiKey != "" {
|
||||||
|
updates["api_key"] = crypto.EncryptedString(apiKey)
|
||||||
|
}
|
||||||
|
return s.db.Model(&existingModel).Updates(updates).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try legacy logic compatibility: use id as provider to search
|
// Try legacy logic compatibility: use id as provider to search
|
||||||
provider := id
|
provider := id
|
||||||
err = s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND provider = ? LIMIT 1`, userID, provider).Scan(&existingID)
|
err = s.db.Where("user_id = ? AND provider = ?", userID, provider).First(&existingModel).Error
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logger.Warnf("⚠️ Using legacy provider matching to update model: %s -> %s", provider, existingID)
|
logger.Warnf("⚠️ Using legacy provider matching to update model: %s -> %s", provider, existingModel.ID)
|
||||||
// If apiKey is empty, preserve the existing API key
|
updates := map[string]interface{}{
|
||||||
if apiKey == "" {
|
"enabled": enabled,
|
||||||
_, err = s.db.Exec(`
|
"custom_api_url": customAPIURL,
|
||||||
UPDATE ai_models SET enabled = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
|
"custom_model_name": customModelName,
|
||||||
WHERE id = ? AND user_id = ?
|
"updated_at": time.Now(),
|
||||||
`, enabled, customAPIURL, customModelName, existingID, userID)
|
|
||||||
} else {
|
|
||||||
encryptedAPIKey := s.encrypt(apiKey)
|
|
||||||
_, err = s.db.Exec(`
|
|
||||||
UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
|
|
||||||
WHERE id = ? AND user_id = ?
|
|
||||||
`, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID)
|
|
||||||
}
|
}
|
||||||
return err
|
if apiKey != "" {
|
||||||
|
updates["api_key"] = crypto.EncryptedString(apiKey)
|
||||||
|
}
|
||||||
|
return s.db.Model(&existingModel).Updates(updates).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new record
|
// Create new record
|
||||||
@@ -285,9 +187,12 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to get name from existing model with same provider
|
||||||
|
var refModel AIModel
|
||||||
var name string
|
var name string
|
||||||
err = s.db.QueryRow(`SELECT name FROM ai_models WHERE provider = ? LIMIT 1`, provider).Scan(&name)
|
if err := s.db.Where("provider = ?", provider).First(&refModel).Error; err == nil {
|
||||||
if err != nil {
|
name = refModel.Name
|
||||||
|
} else {
|
||||||
if provider == "deepseek" {
|
if provider == "deepseek" {
|
||||||
name = "DeepSeek AI"
|
name = "DeepSeek AI"
|
||||||
} else if provider == "qwen" {
|
} else if provider == "qwen" {
|
||||||
@@ -303,19 +208,30 @@ func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPI
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Infof("✓ Creating new AI model configuration: ID=%s, Provider=%s, Name=%s", newModelID, provider, name)
|
logger.Infof("✓ Creating new AI model configuration: ID=%s, Provider=%s, Name=%s", newModelID, provider, name)
|
||||||
encryptedAPIKey := s.encrypt(apiKey)
|
newModel := &AIModel{
|
||||||
_, err = s.db.Exec(`
|
ID: newModelID,
|
||||||
INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, custom_model_name, created_at, updated_at)
|
UserID: userID,
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))
|
Name: name,
|
||||||
`, newModelID, userID, name, provider, enabled, encryptedAPIKey, customAPIURL, customModelName)
|
Provider: provider,
|
||||||
return err
|
Enabled: enabled,
|
||||||
|
APIKey: crypto.EncryptedString(apiKey),
|
||||||
|
CustomAPIURL: customAPIURL,
|
||||||
|
CustomModelName: customModelName,
|
||||||
|
}
|
||||||
|
return s.db.Create(newModel).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create creates an AI model
|
// Create creates an AI model
|
||||||
func (s *AIModelStore) Create(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error {
|
func (s *AIModelStore) Create(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error {
|
||||||
_, err := s.db.Exec(`
|
model := &AIModel{
|
||||||
INSERT OR IGNORE INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url)
|
ID: id,
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
UserID: userID,
|
||||||
`, id, userID, name, provider, enabled, apiKey, customAPIURL)
|
Name: name,
|
||||||
return err
|
Provider: provider,
|
||||||
|
Enabled: enabled,
|
||||||
|
APIKey: crypto.EncryptedString(apiKey),
|
||||||
|
CustomAPIURL: customAPIURL,
|
||||||
|
}
|
||||||
|
// Use FirstOrCreate to ignore if already exists
|
||||||
|
return s.db.Where("id = ?", id).FirstOrCreate(model).Error
|
||||||
}
|
}
|
||||||
|
|||||||
+339
-351
@@ -1,15 +1,26 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BacktestStore backtest data storage
|
// BacktestStore backtest data storage
|
||||||
type BacktestStore struct {
|
type BacktestStore struct {
|
||||||
db *sql.DB
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBacktestStore creates a new backtest store
|
||||||
|
func NewBacktestStore(db *gorm.DB) *BacktestStore {
|
||||||
|
return &BacktestStore{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPostgres checks if the database is PostgreSQL
|
||||||
|
func (s *BacktestStore) isPostgres() bool {
|
||||||
|
return s.db.Dialector.Name() == "postgres"
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunState backtest state
|
// RunState backtest state
|
||||||
@@ -92,492 +103,469 @@ type RunIndexEntry struct {
|
|||||||
UpdatedAtISO string `json:"updated_at"`
|
UpdatedAtISO string `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BacktestRun GORM model for backtest_runs table
|
||||||
|
type BacktestRun struct {
|
||||||
|
RunID string `gorm:"column:run_id;primaryKey"`
|
||||||
|
UserID string `gorm:"column:user_id;not null;default:''"`
|
||||||
|
ConfigJSON []byte `gorm:"column:config_json"`
|
||||||
|
State string `gorm:"column:state;not null;default:created"`
|
||||||
|
Label string `gorm:"column:label;default:''"`
|
||||||
|
SymbolCount int `gorm:"column:symbol_count;default:0"`
|
||||||
|
DecisionTF string `gorm:"column:decision_tf;default:''"`
|
||||||
|
ProcessedBars int `gorm:"column:processed_bars;default:0"`
|
||||||
|
ProgressPct float64 `gorm:"column:progress_pct;default:0"`
|
||||||
|
EquityLast float64 `gorm:"column:equity_last;default:0"`
|
||||||
|
MaxDrawdownPct float64 `gorm:"column:max_drawdown_pct;default:0"`
|
||||||
|
Liquidated bool `gorm:"column:liquidated;default:false"`
|
||||||
|
LiquidationNote string `gorm:"column:liquidation_note;default:''"`
|
||||||
|
PromptTemplate string `gorm:"column:prompt_template;default:''"`
|
||||||
|
CustomPrompt string `gorm:"column:custom_prompt;default:''"`
|
||||||
|
OverridePrompt bool `gorm:"column:override_prompt;default:false"`
|
||||||
|
AIProvider string `gorm:"column:ai_provider;default:''"`
|
||||||
|
AIModel string `gorm:"column:ai_model;default:''"`
|
||||||
|
LastError string `gorm:"column:last_error;default:''"`
|
||||||
|
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"`
|
||||||
|
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (BacktestRun) TableName() string {
|
||||||
|
return "backtest_runs"
|
||||||
|
}
|
||||||
|
|
||||||
|
// BacktestCheckpoint GORM model
|
||||||
|
type BacktestCheckpoint struct {
|
||||||
|
RunID string `gorm:"column:run_id;primaryKey"`
|
||||||
|
Payload []byte `gorm:"column:payload;not null"`
|
||||||
|
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (BacktestCheckpoint) TableName() string {
|
||||||
|
return "backtest_checkpoints"
|
||||||
|
}
|
||||||
|
|
||||||
|
// BacktestEquity GORM model
|
||||||
|
type BacktestEquity struct {
|
||||||
|
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||||
|
RunID string `gorm:"column:run_id;not null;index:idx_backtest_equity_run_ts"`
|
||||||
|
TS int64 `gorm:"column:ts;not null;index:idx_backtest_equity_run_ts"`
|
||||||
|
Equity float64 `gorm:"column:equity;not null"`
|
||||||
|
Available float64 `gorm:"column:available;not null"`
|
||||||
|
PnL float64 `gorm:"column:pnl;not null"`
|
||||||
|
PnLPct float64 `gorm:"column:pnl_pct;not null"`
|
||||||
|
DDPct float64 `gorm:"column:dd_pct;not null"`
|
||||||
|
Cycle int `gorm:"column:cycle;not null"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (BacktestEquity) TableName() string {
|
||||||
|
return "backtest_equity"
|
||||||
|
}
|
||||||
|
|
||||||
|
// BacktestTrade GORM model
|
||||||
|
type BacktestTrade struct {
|
||||||
|
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||||
|
RunID string `gorm:"column:run_id;not null;index:idx_backtest_trades_run_ts"`
|
||||||
|
TS int64 `gorm:"column:ts;not null;index:idx_backtest_trades_run_ts"`
|
||||||
|
Symbol string `gorm:"column:symbol;not null"`
|
||||||
|
Action string `gorm:"column:action;not null"`
|
||||||
|
Side string `gorm:"column:side;default:''"`
|
||||||
|
Qty float64 `gorm:"column:qty;default:0"`
|
||||||
|
Price float64 `gorm:"column:price;default:0"`
|
||||||
|
Fee float64 `gorm:"column:fee;default:0"`
|
||||||
|
Slippage float64 `gorm:"column:slippage;default:0"`
|
||||||
|
OrderValue float64 `gorm:"column:order_value;default:0"`
|
||||||
|
RealizedPnL float64 `gorm:"column:realized_pnl;default:0"`
|
||||||
|
Leverage int `gorm:"column:leverage;default:0"`
|
||||||
|
Cycle int `gorm:"column:cycle;default:0"`
|
||||||
|
PositionAfter float64 `gorm:"column:position_after;default:0"`
|
||||||
|
Liquidation bool `gorm:"column:liquidation;default:false"`
|
||||||
|
Note string `gorm:"column:note;default:''"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (BacktestTrade) TableName() string {
|
||||||
|
return "backtest_trades"
|
||||||
|
}
|
||||||
|
|
||||||
|
// BacktestMetrics GORM model
|
||||||
|
type BacktestMetrics struct {
|
||||||
|
RunID string `gorm:"column:run_id;primaryKey"`
|
||||||
|
Payload []byte `gorm:"column:payload;not null"`
|
||||||
|
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (BacktestMetrics) TableName() string {
|
||||||
|
return "backtest_metrics"
|
||||||
|
}
|
||||||
|
|
||||||
|
// BacktestDecision GORM model
|
||||||
|
type BacktestDecision struct {
|
||||||
|
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||||
|
RunID string `gorm:"column:run_id;not null;index:idx_backtest_decisions_run_cycle"`
|
||||||
|
Cycle int `gorm:"column:cycle;not null;index:idx_backtest_decisions_run_cycle"`
|
||||||
|
Payload []byte `gorm:"column:payload;not null"`
|
||||||
|
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (BacktestDecision) TableName() string {
|
||||||
|
return "backtest_decisions"
|
||||||
|
}
|
||||||
|
|
||||||
// initTables initializes backtest related tables
|
// initTables initializes backtest related tables
|
||||||
func (s *BacktestStore) initTables() error {
|
func (s *BacktestStore) initTables() error {
|
||||||
queries := []string{
|
// For PostgreSQL with existing tables, skip AutoMigrate to avoid type conflicts
|
||||||
// Backtest runs main table
|
if s.db.Dialector.Name() == "postgres" {
|
||||||
`CREATE TABLE IF NOT EXISTS backtest_runs (
|
var tableExists int64
|
||||||
run_id TEXT PRIMARY KEY,
|
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'backtest_runs'`).Scan(&tableExists)
|
||||||
user_id TEXT NOT NULL DEFAULT '',
|
|
||||||
config_json TEXT NOT NULL DEFAULT '',
|
|
||||||
state TEXT NOT NULL DEFAULT 'created',
|
|
||||||
label TEXT DEFAULT '',
|
|
||||||
symbol_count INTEGER DEFAULT 0,
|
|
||||||
decision_tf TEXT DEFAULT '',
|
|
||||||
processed_bars INTEGER DEFAULT 0,
|
|
||||||
progress_pct REAL DEFAULT 0,
|
|
||||||
equity_last REAL DEFAULT 0,
|
|
||||||
max_drawdown_pct REAL DEFAULT 0,
|
|
||||||
liquidated BOOLEAN DEFAULT 0,
|
|
||||||
liquidation_note TEXT DEFAULT '',
|
|
||||||
prompt_template TEXT DEFAULT '',
|
|
||||||
custom_prompt TEXT DEFAULT '',
|
|
||||||
override_prompt BOOLEAN DEFAULT 0,
|
|
||||||
ai_provider TEXT DEFAULT '',
|
|
||||||
ai_model TEXT DEFAULT '',
|
|
||||||
last_error TEXT DEFAULT '',
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)`,
|
|
||||||
|
|
||||||
// Backtest checkpoints
|
if tableExists > 0 {
|
||||||
`CREATE TABLE IF NOT EXISTS backtest_checkpoints (
|
// Tables exist - just ensure indexes exist
|
||||||
run_id TEXT PRIMARY KEY,
|
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`)
|
||||||
payload BLOB NOT NULL,
|
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`)
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`)
|
||||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
return nil
|
||||||
)`,
|
|
||||||
|
|
||||||
// Backtest equity curve
|
|
||||||
`CREATE TABLE IF NOT EXISTS backtest_equity (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
run_id TEXT NOT NULL,
|
|
||||||
ts INTEGER NOT NULL,
|
|
||||||
equity REAL NOT NULL,
|
|
||||||
available REAL NOT NULL,
|
|
||||||
pnl REAL NOT NULL,
|
|
||||||
pnl_pct REAL NOT NULL,
|
|
||||||
dd_pct REAL NOT NULL,
|
|
||||||
cycle INTEGER NOT NULL,
|
|
||||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
|
||||||
)`,
|
|
||||||
|
|
||||||
// Backtest trade records
|
|
||||||
`CREATE TABLE IF NOT EXISTS backtest_trades (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
run_id TEXT NOT NULL,
|
|
||||||
ts INTEGER NOT NULL,
|
|
||||||
symbol TEXT NOT NULL,
|
|
||||||
action TEXT NOT NULL,
|
|
||||||
side TEXT DEFAULT '',
|
|
||||||
qty REAL DEFAULT 0,
|
|
||||||
price REAL DEFAULT 0,
|
|
||||||
fee REAL DEFAULT 0,
|
|
||||||
slippage REAL DEFAULT 0,
|
|
||||||
order_value REAL DEFAULT 0,
|
|
||||||
realized_pnl REAL DEFAULT 0,
|
|
||||||
leverage INTEGER DEFAULT 0,
|
|
||||||
cycle INTEGER DEFAULT 0,
|
|
||||||
position_after REAL DEFAULT 0,
|
|
||||||
liquidation BOOLEAN DEFAULT 0,
|
|
||||||
note TEXT DEFAULT '',
|
|
||||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
|
||||||
)`,
|
|
||||||
|
|
||||||
// Backtest metrics
|
|
||||||
`CREATE TABLE IF NOT EXISTS backtest_metrics (
|
|
||||||
run_id TEXT PRIMARY KEY,
|
|
||||||
payload BLOB NOT NULL,
|
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
|
||||||
)`,
|
|
||||||
|
|
||||||
// Backtest decision logs
|
|
||||||
`CREATE TABLE IF NOT EXISTS backtest_decisions (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
run_id TEXT NOT NULL,
|
|
||||||
cycle INTEGER NOT NULL,
|
|
||||||
payload BLOB NOT NULL,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
|
||||||
)`,
|
|
||||||
|
|
||||||
// Indexes
|
|
||||||
`CREATE INDEX IF NOT EXISTS idx_backtest_runs_state ON backtest_runs(state, updated_at)`,
|
|
||||||
`CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`,
|
|
||||||
`CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`,
|
|
||||||
`CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, query := range queries {
|
|
||||||
if _, err := s.db.Exec(query); err != nil {
|
|
||||||
return fmt.Errorf("failed to execute SQL: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add potentially missing columns (backward compatibility)
|
// AutoMigrate all backtest tables
|
||||||
s.addColumnIfNotExists("backtest_runs", "label", "TEXT DEFAULT ''")
|
if err := s.db.AutoMigrate(
|
||||||
s.addColumnIfNotExists("backtest_runs", "last_error", "TEXT DEFAULT ''")
|
&BacktestRun{},
|
||||||
s.addColumnIfNotExists("backtest_trades", "leverage", "INTEGER DEFAULT 0")
|
&BacktestCheckpoint{},
|
||||||
|
&BacktestEquity{},
|
||||||
|
&BacktestTrade{},
|
||||||
|
&BacktestMetrics{},
|
||||||
|
&BacktestDecision{},
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("failed to migrate backtest tables: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BacktestStore) addColumnIfNotExists(table, column, definition string) {
|
|
||||||
rows, err := s.db.Query(fmt.Sprintf("PRAGMA table_info(%s)", table))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
var cid int
|
|
||||||
var name, ctype string
|
|
||||||
var notnull, pk int
|
|
||||||
var dflt interface{}
|
|
||||||
if err := rows.Scan(&cid, &name, &ctype, ¬null, &dflt, &pk); err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if name == column {
|
|
||||||
return // Column already exists
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, definition))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveCheckpoint saves checkpoint
|
// SaveCheckpoint saves checkpoint
|
||||||
func (s *BacktestStore) SaveCheckpoint(runID string, payload []byte) error {
|
func (s *BacktestStore) SaveCheckpoint(runID string, payload []byte) error {
|
||||||
_, err := s.db.Exec(`
|
checkpoint := BacktestCheckpoint{
|
||||||
INSERT INTO backtest_checkpoints (run_id, payload, updated_at)
|
RunID: runID,
|
||||||
VALUES (?, ?, CURRENT_TIMESTAMP)
|
Payload: payload,
|
||||||
ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP
|
}
|
||||||
`, runID, payload)
|
return s.db.Save(&checkpoint).Error
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadCheckpoint loads checkpoint
|
// LoadCheckpoint loads checkpoint
|
||||||
func (s *BacktestStore) LoadCheckpoint(runID string) ([]byte, error) {
|
func (s *BacktestStore) LoadCheckpoint(runID string) ([]byte, error) {
|
||||||
var payload []byte
|
var checkpoint BacktestCheckpoint
|
||||||
err := s.db.QueryRow(`SELECT payload FROM backtest_checkpoints WHERE run_id = ?`, runID).Scan(&payload)
|
err := s.db.Where("run_id = ?", runID).First(&checkpoint).Error
|
||||||
return payload, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return checkpoint.Payload, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveRunMetadata saves run metadata
|
// SaveRunMetadata saves run metadata
|
||||||
func (s *BacktestStore) SaveRunMetadata(meta *RunMetadata) error {
|
func (s *BacktestStore) SaveRunMetadata(meta *RunMetadata) error {
|
||||||
created := meta.CreatedAt.UTC().Format(time.RFC3339)
|
run := BacktestRun{
|
||||||
updated := meta.UpdatedAt.UTC().Format(time.RFC3339)
|
RunID: meta.RunID,
|
||||||
userID := meta.UserID
|
UserID: meta.UserID,
|
||||||
|
State: string(meta.State),
|
||||||
if _, err := s.db.Exec(`
|
Label: meta.Label,
|
||||||
INSERT INTO backtest_runs (run_id, user_id, label, last_error, created_at, updated_at)
|
LastError: meta.LastError,
|
||||||
VALUES (?, ?, ?, ?, ?, ?)
|
SymbolCount: meta.Summary.SymbolCount,
|
||||||
ON CONFLICT(run_id) DO NOTHING
|
DecisionTF: meta.Summary.DecisionTF,
|
||||||
`, meta.RunID, userID, meta.Label, meta.LastError, created, updated); err != nil {
|
ProcessedBars: meta.Summary.ProcessedBars,
|
||||||
return err
|
ProgressPct: meta.Summary.ProgressPct,
|
||||||
|
EquityLast: meta.Summary.EquityLast,
|
||||||
|
MaxDrawdownPct: meta.Summary.MaxDrawdownPct,
|
||||||
|
Liquidated: meta.Summary.Liquidated,
|
||||||
|
LiquidationNote: meta.Summary.LiquidationNote,
|
||||||
|
CreatedAt: meta.CreatedAt,
|
||||||
|
UpdatedAt: meta.UpdatedAt,
|
||||||
}
|
}
|
||||||
|
return s.db.Save(&run).Error
|
||||||
_, err := s.db.Exec(`
|
|
||||||
UPDATE backtest_runs
|
|
||||||
SET user_id = ?, state = ?, symbol_count = ?, decision_tf = ?, processed_bars = ?,
|
|
||||||
progress_pct = ?, equity_last = ?, max_drawdown_pct = ?, liquidated = ?,
|
|
||||||
liquidation_note = ?, label = ?, last_error = ?, updated_at = ?
|
|
||||||
WHERE run_id = ?
|
|
||||||
`, userID, string(meta.State), meta.Summary.SymbolCount, meta.Summary.DecisionTF,
|
|
||||||
meta.Summary.ProcessedBars, meta.Summary.ProgressPct, meta.Summary.EquityLast,
|
|
||||||
meta.Summary.MaxDrawdownPct, meta.Summary.Liquidated, meta.Summary.LiquidationNote,
|
|
||||||
meta.Label, meta.LastError, updated, meta.RunID)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadRunMetadata loads run metadata
|
// LoadRunMetadata loads run metadata
|
||||||
func (s *BacktestStore) LoadRunMetadata(runID string) (*RunMetadata, error) {
|
func (s *BacktestStore) LoadRunMetadata(runID string) (*RunMetadata, error) {
|
||||||
var (
|
var run BacktestRun
|
||||||
userID string
|
err := s.db.Where("run_id = ?", runID).First(&run).Error
|
||||||
state string
|
|
||||||
label string
|
|
||||||
lastErr string
|
|
||||||
symbolCount int
|
|
||||||
decisionTF string
|
|
||||||
processedBars int
|
|
||||||
progressPct float64
|
|
||||||
equityLast float64
|
|
||||||
maxDD float64
|
|
||||||
liquidated bool
|
|
||||||
liquidationNote string
|
|
||||||
createdISO string
|
|
||||||
updatedISO string
|
|
||||||
)
|
|
||||||
|
|
||||||
err := s.db.QueryRow(`
|
|
||||||
SELECT user_id, state, label, last_error, symbol_count, decision_tf, processed_bars,
|
|
||||||
progress_pct, equity_last, max_drawdown_pct, liquidated, liquidation_note,
|
|
||||||
created_at, updated_at
|
|
||||||
FROM backtest_runs WHERE run_id = ?
|
|
||||||
`, runID).Scan(&userID, &state, &label, &lastErr, &symbolCount, &decisionTF,
|
|
||||||
&processedBars, &progressPct, &equityLast, &maxDD, &liquidated, &liquidationNote,
|
|
||||||
&createdISO, &updatedISO)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
meta := &RunMetadata{
|
return &RunMetadata{
|
||||||
RunID: runID,
|
RunID: run.RunID,
|
||||||
UserID: userID,
|
UserID: run.UserID,
|
||||||
Version: 1,
|
Version: 1,
|
||||||
State: RunState(state),
|
State: RunState(run.State),
|
||||||
Label: label,
|
Label: run.Label,
|
||||||
LastError: lastErr,
|
LastError: run.LastError,
|
||||||
Summary: RunSummary{
|
Summary: RunSummary{
|
||||||
SymbolCount: symbolCount,
|
SymbolCount: run.SymbolCount,
|
||||||
DecisionTF: decisionTF,
|
DecisionTF: run.DecisionTF,
|
||||||
ProcessedBars: processedBars,
|
ProcessedBars: run.ProcessedBars,
|
||||||
ProgressPct: progressPct,
|
ProgressPct: run.ProgressPct,
|
||||||
EquityLast: equityLast,
|
EquityLast: run.EquityLast,
|
||||||
MaxDrawdownPct: maxDD,
|
MaxDrawdownPct: run.MaxDrawdownPct,
|
||||||
Liquidated: liquidated,
|
Liquidated: run.Liquidated,
|
||||||
LiquidationNote: liquidationNote,
|
LiquidationNote: run.LiquidationNote,
|
||||||
},
|
},
|
||||||
}
|
CreatedAt: run.CreatedAt,
|
||||||
|
UpdatedAt: run.UpdatedAt,
|
||||||
meta.CreatedAt, _ = time.Parse(time.RFC3339, createdISO)
|
}, nil
|
||||||
meta.UpdatedAt, _ = time.Parse(time.RFC3339, updatedISO)
|
|
||||||
|
|
||||||
return meta, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListRunIDs lists all run IDs
|
// ListRunIDs lists all run IDs
|
||||||
func (s *BacktestStore) ListRunIDs() ([]string, error) {
|
func (s *BacktestStore) ListRunIDs() ([]string, error) {
|
||||||
rows, err := s.db.Query(`SELECT run_id FROM backtest_runs ORDER BY datetime(updated_at) DESC`)
|
var runs []BacktestRun
|
||||||
|
err := s.db.Order("updated_at DESC").Find(&runs).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var ids []string
|
ids := make([]string, len(runs))
|
||||||
for rows.Next() {
|
for i, run := range runs {
|
||||||
var runID string
|
ids[i] = run.RunID
|
||||||
if err := rows.Scan(&runID); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ids = append(ids, runID)
|
|
||||||
}
|
}
|
||||||
return ids, rows.Err()
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AppendEquityPoint appends equity point
|
// AppendEquityPoint appends equity point
|
||||||
func (s *BacktestStore) AppendEquityPoint(runID string, point EquityPoint) error {
|
func (s *BacktestStore) AppendEquityPoint(runID string, point EquityPoint) error {
|
||||||
_, err := s.db.Exec(`
|
eq := BacktestEquity{
|
||||||
INSERT INTO backtest_equity (run_id, ts, equity, available, pnl, pnl_pct, dd_pct, cycle)
|
RunID: runID,
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
TS: point.Timestamp,
|
||||||
`, runID, point.Timestamp, point.Equity, point.Available, point.PnL,
|
Equity: point.Equity,
|
||||||
point.PnLPct, point.DrawdownPct, point.Cycle)
|
Available: point.Available,
|
||||||
return err
|
PnL: point.PnL,
|
||||||
|
PnLPct: point.PnLPct,
|
||||||
|
DDPct: point.DrawdownPct,
|
||||||
|
Cycle: point.Cycle,
|
||||||
|
}
|
||||||
|
return s.db.Create(&eq).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadEquityPoints loads equity points
|
// LoadEquityPoints loads equity points
|
||||||
func (s *BacktestStore) LoadEquityPoints(runID string) ([]EquityPoint, error) {
|
func (s *BacktestStore) LoadEquityPoints(runID string) ([]EquityPoint, error) {
|
||||||
rows, err := s.db.Query(`
|
var eqs []BacktestEquity
|
||||||
SELECT ts, equity, available, pnl, pnl_pct, dd_pct, cycle
|
err := s.db.Where("run_id = ?", runID).Order("ts ASC").Find(&eqs).Error
|
||||||
FROM backtest_equity WHERE run_id = ? ORDER BY ts ASC
|
|
||||||
`, runID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
points := make([]EquityPoint, 0)
|
points := make([]EquityPoint, len(eqs))
|
||||||
for rows.Next() {
|
for i, eq := range eqs {
|
||||||
var point EquityPoint
|
points[i] = EquityPoint{
|
||||||
if err := rows.Scan(&point.Timestamp, &point.Equity, &point.Available,
|
Timestamp: eq.TS,
|
||||||
&point.PnL, &point.PnLPct, &point.DrawdownPct, &point.Cycle); err != nil {
|
Equity: eq.Equity,
|
||||||
return nil, err
|
Available: eq.Available,
|
||||||
|
PnL: eq.PnL,
|
||||||
|
PnLPct: eq.PnLPct,
|
||||||
|
DrawdownPct: eq.DDPct,
|
||||||
|
Cycle: eq.Cycle,
|
||||||
}
|
}
|
||||||
points = append(points, point)
|
|
||||||
}
|
}
|
||||||
return points, rows.Err()
|
return points, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AppendTradeEvent appends trade event
|
// AppendTradeEvent appends trade event
|
||||||
func (s *BacktestStore) AppendTradeEvent(runID string, event TradeEvent) error {
|
func (s *BacktestStore) AppendTradeEvent(runID string, event TradeEvent) error {
|
||||||
_, err := s.db.Exec(`
|
trade := BacktestTrade{
|
||||||
INSERT INTO backtest_trades (run_id, ts, symbol, action, side, qty, price, fee,
|
RunID: runID,
|
||||||
slippage, order_value, realized_pnl, leverage, cycle,
|
TS: event.Timestamp,
|
||||||
position_after, liquidation, note)
|
Symbol: event.Symbol,
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
Action: event.Action,
|
||||||
`, runID, event.Timestamp, event.Symbol, event.Action, event.Side, event.Quantity,
|
Side: event.Side,
|
||||||
event.Price, event.Fee, event.Slippage, event.OrderValue, event.RealizedPnL,
|
Qty: event.Quantity,
|
||||||
event.Leverage, event.Cycle, event.PositionAfter, event.LiquidationFlag, event.Note)
|
Price: event.Price,
|
||||||
return err
|
Fee: event.Fee,
|
||||||
|
Slippage: event.Slippage,
|
||||||
|
OrderValue: event.OrderValue,
|
||||||
|
RealizedPnL: event.RealizedPnL,
|
||||||
|
Leverage: event.Leverage,
|
||||||
|
Cycle: event.Cycle,
|
||||||
|
PositionAfter: event.PositionAfter,
|
||||||
|
Liquidation: event.LiquidationFlag,
|
||||||
|
Note: event.Note,
|
||||||
|
}
|
||||||
|
return s.db.Create(&trade).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadTradeEvents loads trade events
|
// LoadTradeEvents loads trade events
|
||||||
func (s *BacktestStore) LoadTradeEvents(runID string) ([]TradeEvent, error) {
|
func (s *BacktestStore) LoadTradeEvents(runID string) ([]TradeEvent, error) {
|
||||||
rows, err := s.db.Query(`
|
var trades []BacktestTrade
|
||||||
SELECT ts, symbol, action, side, qty, price, fee, slippage, order_value,
|
err := s.db.Where("run_id = ?", runID).Order("ts ASC").Find(&trades).Error
|
||||||
realized_pnl, leverage, cycle, position_after, liquidation, note
|
|
||||||
FROM backtest_trades WHERE run_id = ? ORDER BY ts ASC
|
|
||||||
`, runID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
events := make([]TradeEvent, 0)
|
events := make([]TradeEvent, len(trades))
|
||||||
for rows.Next() {
|
for i, trade := range trades {
|
||||||
var event TradeEvent
|
events[i] = TradeEvent{
|
||||||
if err := rows.Scan(&event.Timestamp, &event.Symbol, &event.Action, &event.Side,
|
Timestamp: trade.TS,
|
||||||
&event.Quantity, &event.Price, &event.Fee, &event.Slippage, &event.OrderValue,
|
Symbol: trade.Symbol,
|
||||||
&event.RealizedPnL, &event.Leverage, &event.Cycle, &event.PositionAfter,
|
Action: trade.Action,
|
||||||
&event.LiquidationFlag, &event.Note); err != nil {
|
Side: trade.Side,
|
||||||
return nil, err
|
Quantity: trade.Qty,
|
||||||
|
Price: trade.Price,
|
||||||
|
Fee: trade.Fee,
|
||||||
|
Slippage: trade.Slippage,
|
||||||
|
OrderValue: trade.OrderValue,
|
||||||
|
RealizedPnL: trade.RealizedPnL,
|
||||||
|
Leverage: trade.Leverage,
|
||||||
|
Cycle: trade.Cycle,
|
||||||
|
PositionAfter: trade.PositionAfter,
|
||||||
|
LiquidationFlag: trade.Liquidation,
|
||||||
|
Note: trade.Note,
|
||||||
}
|
}
|
||||||
events = append(events, event)
|
|
||||||
}
|
}
|
||||||
return events, rows.Err()
|
return events, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveMetrics saves metrics
|
// SaveMetrics saves metrics
|
||||||
func (s *BacktestStore) SaveMetrics(runID string, payload []byte) error {
|
func (s *BacktestStore) SaveMetrics(runID string, payload []byte) error {
|
||||||
_, err := s.db.Exec(`
|
metrics := BacktestMetrics{
|
||||||
INSERT INTO backtest_metrics (run_id, payload, updated_at)
|
RunID: runID,
|
||||||
VALUES (?, ?, CURRENT_TIMESTAMP)
|
Payload: payload,
|
||||||
ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP
|
}
|
||||||
`, runID, payload)
|
return s.db.Save(&metrics).Error
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadMetrics loads metrics
|
// LoadMetrics loads metrics
|
||||||
func (s *BacktestStore) LoadMetrics(runID string) ([]byte, error) {
|
func (s *BacktestStore) LoadMetrics(runID string) ([]byte, error) {
|
||||||
var payload []byte
|
var metrics BacktestMetrics
|
||||||
err := s.db.QueryRow(`SELECT payload FROM backtest_metrics WHERE run_id = ?`, runID).Scan(&payload)
|
err := s.db.Where("run_id = ?", runID).First(&metrics).Error
|
||||||
return payload, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return metrics.Payload, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveDecisionRecord saves decision record
|
// SaveDecisionRecord saves decision record
|
||||||
func (s *BacktestStore) SaveDecisionRecord(runID string, cycle int, payload []byte) error {
|
func (s *BacktestStore) SaveDecisionRecord(runID string, cycle int, payload []byte) error {
|
||||||
_, err := s.db.Exec(`
|
decision := BacktestDecision{
|
||||||
INSERT INTO backtest_decisions (run_id, cycle, payload)
|
RunID: runID,
|
||||||
VALUES (?, ?, ?)
|
Cycle: cycle,
|
||||||
`, runID, cycle, payload)
|
Payload: payload,
|
||||||
return err
|
}
|
||||||
|
return s.db.Create(&decision).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadDecisionRecords loads decision records
|
// LoadDecisionRecords loads decision records
|
||||||
func (s *BacktestStore) LoadDecisionRecords(runID string, limit, offset int) ([]json.RawMessage, error) {
|
func (s *BacktestStore) LoadDecisionRecords(runID string, limit, offset int) ([]json.RawMessage, error) {
|
||||||
rows, err := s.db.Query(`
|
var decisions []BacktestDecision
|
||||||
SELECT payload FROM backtest_decisions
|
err := s.db.Where("run_id = ?", runID).
|
||||||
WHERE run_id = ?
|
Order("id DESC").
|
||||||
ORDER BY id DESC
|
Limit(limit).
|
||||||
LIMIT ? OFFSET ?
|
Offset(offset).
|
||||||
`, runID, limit, offset)
|
Find(&decisions).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
records := make([]json.RawMessage, 0, limit)
|
records := make([]json.RawMessage, len(decisions))
|
||||||
for rows.Next() {
|
for i, d := range decisions {
|
||||||
var payload []byte
|
records[i] = json.RawMessage(d.Payload)
|
||||||
if err := rows.Scan(&payload); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
records = append(records, json.RawMessage(payload))
|
|
||||||
}
|
}
|
||||||
return records, rows.Err()
|
return records, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadLatestDecision loads latest decision
|
// LoadLatestDecision loads latest decision
|
||||||
func (s *BacktestStore) LoadLatestDecision(runID string, cycle int) ([]byte, error) {
|
func (s *BacktestStore) LoadLatestDecision(runID string, cycle int) ([]byte, error) {
|
||||||
var query string
|
var decision BacktestDecision
|
||||||
var args []interface{}
|
query := s.db.Where("run_id = ?", runID)
|
||||||
|
|
||||||
if cycle > 0 {
|
if cycle > 0 {
|
||||||
query = `SELECT payload FROM backtest_decisions WHERE run_id = ? AND cycle = ? ORDER BY datetime(created_at) DESC LIMIT 1`
|
query = query.Where("cycle = ?", cycle)
|
||||||
args = []interface{}{runID, cycle}
|
|
||||||
} else {
|
|
||||||
query = `SELECT payload FROM backtest_decisions WHERE run_id = ? ORDER BY datetime(created_at) DESC LIMIT 1`
|
|
||||||
args = []interface{}{runID}
|
|
||||||
}
|
}
|
||||||
|
err := query.Order("created_at DESC").First(&decision).Error
|
||||||
var payload []byte
|
if err != nil {
|
||||||
err := s.db.QueryRow(query, args...).Scan(&payload)
|
return nil, err
|
||||||
return payload, err
|
}
|
||||||
|
return decision.Payload, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProgress updates progress
|
// UpdateProgress updates progress
|
||||||
func (s *BacktestStore) UpdateProgress(runID string, progressPct, equity float64, barIndex int, liquidated bool) error {
|
func (s *BacktestStore) UpdateProgress(runID string, progressPct, equity float64, barIndex int, liquidated bool) error {
|
||||||
_, err := s.db.Exec(`
|
return s.db.Model(&BacktestRun{}).Where("run_id = ?", runID).Updates(map[string]interface{}{
|
||||||
UPDATE backtest_runs
|
"progress_pct": progressPct,
|
||||||
SET progress_pct = ?, equity_last = ?, processed_bars = ?, liquidated = ?, updated_at = CURRENT_TIMESTAMP
|
"equity_last": equity,
|
||||||
WHERE run_id = ?
|
"processed_bars": barIndex,
|
||||||
`, progressPct, equity, barIndex, liquidated, runID)
|
"liquidated": liquidated,
|
||||||
return err
|
}).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListIndexEntries lists index entries
|
// ListIndexEntries lists index entries
|
||||||
func (s *BacktestStore) ListIndexEntries() ([]RunIndexEntry, error) {
|
func (s *BacktestStore) ListIndexEntries() ([]RunIndexEntry, error) {
|
||||||
rows, err := s.db.Query(`
|
var runs []BacktestRun
|
||||||
SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct,
|
err := s.db.Order("updated_at DESC").Find(&runs).Error
|
||||||
created_at, updated_at, config_json
|
|
||||||
FROM backtest_runs
|
|
||||||
ORDER BY datetime(updated_at) DESC
|
|
||||||
`)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var entries []RunIndexEntry
|
entries := make([]RunIndexEntry, len(runs))
|
||||||
for rows.Next() {
|
for i, run := range runs {
|
||||||
var entry RunIndexEntry
|
entry := RunIndexEntry{
|
||||||
var symbolCnt int
|
RunID: run.RunID,
|
||||||
var cfgJSON []byte
|
State: run.State,
|
||||||
var createdISO, updatedISO string
|
DecisionTF: run.DecisionTF,
|
||||||
|
EquityLast: run.EquityLast,
|
||||||
if err := rows.Scan(&entry.RunID, &entry.State, &symbolCnt, &entry.DecisionTF,
|
MaxDrawdownPct: run.MaxDrawdownPct,
|
||||||
&entry.EquityLast, &entry.MaxDrawdownPct, &createdISO, &updatedISO, &cfgJSON); err != nil {
|
CreatedAtISO: run.CreatedAt.Format(time.RFC3339),
|
||||||
return nil, err
|
UpdatedAtISO: run.UpdatedAt.Format(time.RFC3339),
|
||||||
|
Symbols: make([]string, 0, run.SymbolCount),
|
||||||
}
|
}
|
||||||
|
|
||||||
entry.CreatedAtISO = createdISO
|
if len(run.ConfigJSON) > 0 {
|
||||||
entry.UpdatedAtISO = updatedISO
|
|
||||||
entry.Symbols = make([]string, 0, symbolCnt)
|
|
||||||
|
|
||||||
// Try to extract more information from config
|
|
||||||
if len(cfgJSON) > 0 {
|
|
||||||
var cfg struct {
|
var cfg struct {
|
||||||
Symbols []string `json:"symbols"`
|
Symbols []string `json:"symbols"`
|
||||||
StartTS int64 `json:"start_ts"`
|
StartTS int64 `json:"start_ts"`
|
||||||
EndTS int64 `json:"end_ts"`
|
EndTS int64 `json:"end_ts"`
|
||||||
}
|
}
|
||||||
if json.Unmarshal(cfgJSON, &cfg) == nil {
|
if json.Unmarshal(run.ConfigJSON, &cfg) == nil {
|
||||||
entry.Symbols = cfg.Symbols
|
entry.Symbols = cfg.Symbols
|
||||||
entry.StartTS = cfg.StartTS
|
entry.StartTS = cfg.StartTS
|
||||||
entry.EndTS = cfg.EndTS
|
entry.EndTS = cfg.EndTS
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
entries = append(entries, entry)
|
entries[i] = entry
|
||||||
}
|
}
|
||||||
return entries, rows.Err()
|
return entries, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRun deletes run
|
// DeleteRun deletes run
|
||||||
func (s *BacktestStore) DeleteRun(runID string) error {
|
func (s *BacktestStore) DeleteRun(runID string) error {
|
||||||
_, err := s.db.Exec(`DELETE FROM backtest_runs WHERE run_id = ?`, runID)
|
// Delete related records first (cascade may not work in all cases)
|
||||||
return err
|
s.db.Where("run_id = ?", runID).Delete(&BacktestCheckpoint{})
|
||||||
|
s.db.Where("run_id = ?", runID).Delete(&BacktestEquity{})
|
||||||
|
s.db.Where("run_id = ?", runID).Delete(&BacktestTrade{})
|
||||||
|
s.db.Where("run_id = ?", runID).Delete(&BacktestMetrics{})
|
||||||
|
s.db.Where("run_id = ?", runID).Delete(&BacktestDecision{})
|
||||||
|
|
||||||
|
return s.db.Where("run_id = ?", runID).Delete(&BacktestRun{}).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveConfig saves config
|
// SaveConfig saves config
|
||||||
func (s *BacktestStore) SaveConfig(runID, userID, template, customPrompt, provider, model string, override bool, configJSON []byte) error {
|
func (s *BacktestStore) SaveConfig(runID, userID, template, customPrompt, provider, model string, override bool, configJSON []byte) error {
|
||||||
now := time.Now().UTC().Format(time.RFC3339)
|
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
userID = "default"
|
userID = "default"
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := s.db.Exec(`
|
run := BacktestRun{
|
||||||
INSERT INTO backtest_runs (run_id, user_id, config_json, prompt_template, custom_prompt,
|
RunID: runID,
|
||||||
override_prompt, ai_provider, ai_model, created_at, updated_at)
|
UserID: userID,
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
ConfigJSON: configJSON,
|
||||||
ON CONFLICT(run_id) DO NOTHING
|
PromptTemplate: template,
|
||||||
`, runID, userID, configJSON, template, customPrompt, override, provider, model, now, now)
|
CustomPrompt: customPrompt,
|
||||||
if err != nil {
|
OverridePrompt: override,
|
||||||
return err
|
AIProvider: provider,
|
||||||
|
AIModel: model,
|
||||||
}
|
}
|
||||||
|
return s.db.Save(&run).Error
|
||||||
_, err = s.db.Exec(`
|
|
||||||
UPDATE backtest_runs
|
|
||||||
SET user_id = ?, config_json = ?, prompt_template = ?, custom_prompt = ?,
|
|
||||||
override_prompt = ?, ai_provider = ?, ai_model = ?, updated_at = CURRENT_TIMESTAMP
|
|
||||||
WHERE run_id = ?
|
|
||||||
`, userID, configJSON, template, customPrompt, override, provider, model, runID)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadConfig loads config
|
// LoadConfig loads config
|
||||||
func (s *BacktestStore) LoadConfig(runID string) ([]byte, error) {
|
func (s *BacktestStore) LoadConfig(runID string) ([]byte, error) {
|
||||||
var payload []byte
|
var run BacktestRun
|
||||||
err := s.db.QueryRow(`SELECT config_json FROM backtest_runs WHERE run_id = ?`, runID).Scan(&payload)
|
err := s.db.Where("run_id = ?", runID).First(&run).Error
|
||||||
return payload, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return run.ConfigJSON, nil
|
||||||
}
|
}
|
||||||
|
|||||||
+225
-491
@@ -1,12 +1,11 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DebateStatus represents the status of a debate session
|
// DebateStatus represents the status of a debate session
|
||||||
@@ -49,30 +48,6 @@ var PersonalityEmojis = map[DebatePersonality]string{
|
|||||||
PersonalityRiskManager: "🛡️",
|
PersonalityRiskManager: "🛡️",
|
||||||
}
|
}
|
||||||
|
|
||||||
// DebateSession represents a debate session
|
|
||||||
type DebateSession struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
UserID string `json:"user_id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
StrategyID string `json:"strategy_id"`
|
|
||||||
Status DebateStatus `json:"status"`
|
|
||||||
Symbol string `json:"symbol"` // Primary symbol (for backward compat, may be empty for multi-coin)
|
|
||||||
MaxRounds int `json:"max_rounds"`
|
|
||||||
CurrentRound int `json:"current_round"`
|
|
||||||
IntervalMinutes int `json:"interval_minutes"` // Debate interval (5, 15, 30, 60 minutes)
|
|
||||||
PromptVariant string `json:"prompt_variant"` // balanced/aggressive/conservative/scalping
|
|
||||||
FinalDecision *DebateDecision `json:"final_decision,omitempty"` // Single decision (backward compat)
|
|
||||||
FinalDecisions []*DebateDecision `json:"final_decisions,omitempty"` // Multi-coin decisions
|
|
||||||
AutoExecute bool `json:"auto_execute"`
|
|
||||||
TraderID string `json:"trader_id,omitempty"` // Trader to use for auto-execute
|
|
||||||
// OI Ranking data options
|
|
||||||
EnableOIRanking bool `json:"enable_oi_ranking"` // Whether to include OI ranking data
|
|
||||||
OIRankingLimit int `json:"oi_ranking_limit"` // Number of OI ranking entries (default 10)
|
|
||||||
OIDuration string `json:"oi_duration"` // Duration for OI data (1h, 4h, 24h, etc.)
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// DebateDecision represents a trading decision from the debate
|
// DebateDecision represents a trading decision from the debate
|
||||||
type DebateDecision struct {
|
type DebateDecision struct {
|
||||||
Action string `json:"action"` // open_long/open_short/close_long/close_short/hold/wait
|
Action string `json:"action"` // open_long/open_short/close_long/close_short/hold/wait
|
||||||
@@ -86,178 +61,187 @@ type DebateDecision struct {
|
|||||||
Reasoning string `json:"reasoning"` // Brief reasoning
|
Reasoning string `json:"reasoning"` // Brief reasoning
|
||||||
|
|
||||||
// Execution tracking
|
// Execution tracking
|
||||||
Executed bool `json:"executed"` // Whether this decision was executed
|
Executed bool `json:"executed"` // Whether this decision was executed
|
||||||
ExecutedAt time.Time `json:"executed_at,omitempty"` // When it was executed
|
ExecutedAt time.Time `json:"executed_at,omitempty"` // When it was executed
|
||||||
OrderID string `json:"order_id,omitempty"` // Exchange order ID
|
OrderID string `json:"order_id,omitempty"` // Exchange order ID
|
||||||
Error string `json:"error,omitempty"` // Execution error if any
|
Error string `json:"error,omitempty"` // Execution error if any
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DebateSession represents a debate session (API struct)
|
||||||
|
type DebateSession struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
StrategyID string `json:"strategy_id"`
|
||||||
|
Status DebateStatus `json:"status"`
|
||||||
|
Symbol string `json:"symbol"` // Primary symbol (for backward compat, may be empty for multi-coin)
|
||||||
|
MaxRounds int `json:"max_rounds"`
|
||||||
|
CurrentRound int `json:"current_round"`
|
||||||
|
IntervalMinutes int `json:"interval_minutes"` // Debate interval (5, 15, 30, 60 minutes)
|
||||||
|
PromptVariant string `json:"prompt_variant"` // balanced/aggressive/conservative/scalping
|
||||||
|
FinalDecision *DebateDecision `json:"final_decision,omitempty"` // Single decision (backward compat)
|
||||||
|
FinalDecisions []*DebateDecision `json:"final_decisions,omitempty"` // Multi-coin decisions
|
||||||
|
AutoExecute bool `json:"auto_execute"`
|
||||||
|
TraderID string `json:"trader_id,omitempty"` // Trader to use for auto-execute
|
||||||
|
// OI Ranking data options
|
||||||
|
EnableOIRanking bool `json:"enable_oi_ranking"` // Whether to include OI ranking data
|
||||||
|
OIRankingLimit int `json:"oi_ranking_limit"` // Number of OI ranking entries (default 10)
|
||||||
|
OIDuration string `json:"oi_duration"` // Duration for OI data (1h, 4h, 24h, etc.)
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DebateSessionDB is the GORM model for debate_sessions
|
||||||
|
type DebateSessionDB struct {
|
||||||
|
ID string `gorm:"column:id;primaryKey"`
|
||||||
|
UserID string `gorm:"column:user_id;not null;index"`
|
||||||
|
Name string `gorm:"column:name;not null"`
|
||||||
|
StrategyID string `gorm:"column:strategy_id;not null"`
|
||||||
|
Status DebateStatus `gorm:"column:status;not null;default:pending;index"`
|
||||||
|
Symbol string `gorm:"column:symbol;not null"`
|
||||||
|
MaxRounds int `gorm:"column:max_rounds;default:3"`
|
||||||
|
CurrentRound int `gorm:"column:current_round;default:0"`
|
||||||
|
IntervalMinutes int `gorm:"column:interval_minutes;default:5"`
|
||||||
|
PromptVariant string `gorm:"column:prompt_variant;default:balanced"`
|
||||||
|
FinalDecision string `gorm:"column:final_decision"` // JSON string
|
||||||
|
AutoExecute bool `gorm:"column:auto_execute;default:false"`
|
||||||
|
TraderID string `gorm:"column:trader_id"`
|
||||||
|
EnableOIRanking bool `gorm:"column:enable_oi_ranking;default:false"`
|
||||||
|
OIRankingLimit int `gorm:"column:oi_ranking_limit;default:10"`
|
||||||
|
OIDuration string `gorm:"column:oi_duration;default:1h"`
|
||||||
|
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"`
|
||||||
|
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (DebateSessionDB) TableName() string {
|
||||||
|
return "debate_sessions"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DebateSessionDB) toSession() *DebateSession {
|
||||||
|
s := &DebateSession{
|
||||||
|
ID: db.ID,
|
||||||
|
UserID: db.UserID,
|
||||||
|
Name: db.Name,
|
||||||
|
StrategyID: db.StrategyID,
|
||||||
|
Status: db.Status,
|
||||||
|
Symbol: db.Symbol,
|
||||||
|
MaxRounds: db.MaxRounds,
|
||||||
|
CurrentRound: db.CurrentRound,
|
||||||
|
IntervalMinutes: db.IntervalMinutes,
|
||||||
|
PromptVariant: db.PromptVariant,
|
||||||
|
AutoExecute: db.AutoExecute,
|
||||||
|
TraderID: db.TraderID,
|
||||||
|
EnableOIRanking: db.EnableOIRanking,
|
||||||
|
OIRankingLimit: db.OIRankingLimit,
|
||||||
|
OIDuration: db.OIDuration,
|
||||||
|
CreatedAt: db.CreatedAt,
|
||||||
|
UpdatedAt: db.UpdatedAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set defaults
|
||||||
|
if s.IntervalMinutes == 0 {
|
||||||
|
s.IntervalMinutes = 5
|
||||||
|
}
|
||||||
|
if s.PromptVariant == "" {
|
||||||
|
s.PromptVariant = "balanced"
|
||||||
|
}
|
||||||
|
if s.OIRankingLimit == 0 {
|
||||||
|
s.OIRankingLimit = 10
|
||||||
|
}
|
||||||
|
if s.OIDuration == "" {
|
||||||
|
s.OIDuration = "1h"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse final decision
|
||||||
|
if db.FinalDecision != "" {
|
||||||
|
var decision DebateDecision
|
||||||
|
if json.Unmarshal([]byte(db.FinalDecision), &decision) == nil {
|
||||||
|
s.FinalDecision = &decision
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
// DebateParticipant represents an AI participant in a debate
|
// DebateParticipant represents an AI participant in a debate
|
||||||
type DebateParticipant struct {
|
type DebateParticipant struct {
|
||||||
ID string `json:"id"`
|
ID string `gorm:"column:id;primaryKey" json:"id"`
|
||||||
SessionID string `json:"session_id"`
|
SessionID string `gorm:"column:session_id;not null;index" json:"session_id"`
|
||||||
AIModelID string `json:"ai_model_id"`
|
AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
|
||||||
AIModelName string `json:"ai_model_name"`
|
AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"`
|
||||||
Provider string `json:"provider"`
|
Provider string `gorm:"column:provider;not null" json:"provider"`
|
||||||
Personality DebatePersonality `json:"personality"`
|
Personality DebatePersonality `gorm:"column:personality;not null" json:"personality"`
|
||||||
Color string `json:"color"`
|
Color string `gorm:"column:color;not null" json:"color"`
|
||||||
SpeakOrder int `json:"speak_order"`
|
SpeakOrder int `gorm:"column:speak_order;default:0" json:"speak_order"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (DebateParticipant) TableName() string {
|
||||||
|
return "debate_participants"
|
||||||
}
|
}
|
||||||
|
|
||||||
// DebateMessage represents a message in the debate
|
// DebateMessage represents a message in the debate
|
||||||
type DebateMessage struct {
|
type DebateMessage struct {
|
||||||
ID string `json:"id"`
|
ID string `gorm:"column:id;primaryKey" json:"id"`
|
||||||
SessionID string `json:"session_id"`
|
SessionID string `gorm:"column:session_id;not null;index" json:"session_id"`
|
||||||
Round int `json:"round"`
|
Round int `gorm:"column:round;not null" json:"round"`
|
||||||
AIModelID string `json:"ai_model_id"`
|
AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
|
||||||
AIModelName string `json:"ai_model_name"`
|
AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"`
|
||||||
Provider string `json:"provider"`
|
Provider string `gorm:"column:provider;not null" json:"provider"`
|
||||||
Personality DebatePersonality `json:"personality"`
|
Personality DebatePersonality `gorm:"column:personality;not null" json:"personality"`
|
||||||
MessageType string `json:"message_type"` // analysis/rebuttal/final/vote
|
MessageType string `gorm:"column:message_type;not null" json:"message_type"` // analysis/rebuttal/final/vote
|
||||||
Content string `json:"content"`
|
Content string `gorm:"column:content;not null" json:"content"`
|
||||||
Decision *DebateDecision `json:"decision,omitempty"` // Single decision (backward compat)
|
DecisionRaw string `gorm:"column:decision" json:"-"` // JSON string in DB
|
||||||
Decisions []*DebateDecision `json:"decisions,omitempty"` // Multi-coin decisions
|
Decision *DebateDecision `gorm:"-" json:"decision,omitempty"` // Parsed for API
|
||||||
Confidence int `json:"confidence"`
|
Decisions []*DebateDecision `gorm:"-" json:"decisions,omitempty"` // Multi-coin decisions
|
||||||
CreatedAt time.Time `json:"created_at"`
|
Confidence int `gorm:"column:confidence;default:0" json:"confidence"`
|
||||||
|
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (DebateMessage) TableName() string {
|
||||||
|
return "debate_messages"
|
||||||
}
|
}
|
||||||
|
|
||||||
// DebateVote represents a final vote from an AI (can contain multiple coin decisions)
|
// DebateVote represents a final vote from an AI (can contain multiple coin decisions)
|
||||||
type DebateVote struct {
|
type DebateVote struct {
|
||||||
ID string `json:"id"`
|
ID string `gorm:"column:id;primaryKey" json:"id"`
|
||||||
SessionID string `json:"session_id"`
|
SessionID string `gorm:"column:session_id;not null;index" json:"session_id"`
|
||||||
AIModelID string `json:"ai_model_id"`
|
AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
|
||||||
AIModelName string `json:"ai_model_name"`
|
AIModelName string `gorm:"column:ai_model_name;not null" json:"ai_model_name"`
|
||||||
Action string `json:"action"` // Primary action (backward compat)
|
Action string `gorm:"column:action;not null" json:"action"` // Primary action (backward compat)
|
||||||
Symbol string `json:"symbol"` // Primary symbol (backward compat)
|
Symbol string `gorm:"column:symbol;not null" json:"symbol"` // Primary symbol (backward compat)
|
||||||
Confidence int `json:"confidence"`
|
Confidence int `gorm:"column:confidence;default:0" json:"confidence"`
|
||||||
Leverage int `json:"leverage"`
|
Leverage int `gorm:"column:leverage;default:5" json:"leverage"`
|
||||||
PositionPct float64 `json:"position_pct"`
|
PositionPct float64 `gorm:"column:position_pct;default:0.2" json:"position_pct"`
|
||||||
StopLossPct float64 `json:"stop_loss_pct"`
|
StopLossPct float64 `gorm:"column:stop_loss_pct;default:0.03" json:"stop_loss_pct"`
|
||||||
TakeProfitPct float64 `json:"take_profit_pct"`
|
TakeProfitPct float64 `gorm:"column:take_profit_pct;default:0.06" json:"take_profit_pct"`
|
||||||
Reasoning string `json:"reasoning"`
|
Reasoning string `gorm:"column:reasoning" json:"reasoning"`
|
||||||
Decisions []*DebateDecision `json:"decisions,omitempty"` // Multi-coin decisions
|
Decisions []*DebateDecision `gorm:"-" json:"decisions,omitempty"` // Multi-coin decisions
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (DebateVote) TableName() string {
|
||||||
|
return "debate_votes"
|
||||||
}
|
}
|
||||||
|
|
||||||
// DebateStore handles database operations for debates
|
// DebateStore handles database operations for debates
|
||||||
type DebateStore struct {
|
type DebateStore struct {
|
||||||
db *sql.DB
|
db *gorm.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDebateStore creates a new DebateStore
|
// NewDebateStore creates a new DebateStore
|
||||||
func NewDebateStore(db *sql.DB) *DebateStore {
|
func NewDebateStore(db *gorm.DB) *DebateStore {
|
||||||
return &DebateStore{db: db}
|
return &DebateStore{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitSchema creates the debate tables
|
// InitSchema creates the debate tables using GORM AutoMigrate
|
||||||
func (s *DebateStore) InitSchema() error {
|
func (s *DebateStore) InitSchema() error {
|
||||||
schemas := []string{
|
return s.db.AutoMigrate(
|
||||||
`CREATE TABLE IF NOT EXISTS debate_sessions (
|
&DebateSessionDB{},
|
||||||
id TEXT PRIMARY KEY,
|
&DebateParticipant{},
|
||||||
user_id TEXT NOT NULL,
|
&DebateMessage{},
|
||||||
name TEXT NOT NULL,
|
&DebateVote{},
|
||||||
strategy_id TEXT NOT NULL,
|
)
|
||||||
status TEXT NOT NULL DEFAULT 'pending',
|
|
||||||
symbol TEXT NOT NULL,
|
|
||||||
max_rounds INTEGER DEFAULT 3,
|
|
||||||
current_round INTEGER DEFAULT 0,
|
|
||||||
interval_minutes INTEGER DEFAULT 5,
|
|
||||||
prompt_variant TEXT DEFAULT 'balanced',
|
|
||||||
final_decision TEXT,
|
|
||||||
auto_execute BOOLEAN DEFAULT 0,
|
|
||||||
trader_id TEXT,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)`,
|
|
||||||
`CREATE INDEX IF NOT EXISTS idx_debate_sessions_user_id ON debate_sessions(user_id)`,
|
|
||||||
`CREATE INDEX IF NOT EXISTS idx_debate_sessions_status ON debate_sessions(status)`,
|
|
||||||
|
|
||||||
`CREATE TABLE IF NOT EXISTS debate_participants (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
session_id TEXT NOT NULL,
|
|
||||||
ai_model_id TEXT NOT NULL,
|
|
||||||
ai_model_name TEXT NOT NULL,
|
|
||||||
provider TEXT NOT NULL,
|
|
||||||
personality TEXT NOT NULL,
|
|
||||||
color TEXT NOT NULL,
|
|
||||||
speak_order INTEGER DEFAULT 0,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE
|
|
||||||
)`,
|
|
||||||
`CREATE INDEX IF NOT EXISTS idx_debate_participants_session ON debate_participants(session_id)`,
|
|
||||||
|
|
||||||
`CREATE TABLE IF NOT EXISTS debate_messages (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
session_id TEXT NOT NULL,
|
|
||||||
round INTEGER NOT NULL,
|
|
||||||
ai_model_id TEXT NOT NULL,
|
|
||||||
ai_model_name TEXT NOT NULL,
|
|
||||||
provider TEXT NOT NULL,
|
|
||||||
personality TEXT NOT NULL,
|
|
||||||
message_type TEXT NOT NULL,
|
|
||||||
content TEXT NOT NULL,
|
|
||||||
decision TEXT,
|
|
||||||
confidence INTEGER DEFAULT 0,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE
|
|
||||||
)`,
|
|
||||||
`CREATE INDEX IF NOT EXISTS idx_debate_messages_session ON debate_messages(session_id)`,
|
|
||||||
`CREATE INDEX IF NOT EXISTS idx_debate_messages_round ON debate_messages(session_id, round)`,
|
|
||||||
|
|
||||||
`CREATE TABLE IF NOT EXISTS debate_votes (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
session_id TEXT NOT NULL,
|
|
||||||
ai_model_id TEXT NOT NULL,
|
|
||||||
ai_model_name TEXT NOT NULL,
|
|
||||||
action TEXT NOT NULL,
|
|
||||||
symbol TEXT NOT NULL,
|
|
||||||
confidence INTEGER DEFAULT 0,
|
|
||||||
leverage INTEGER DEFAULT 5,
|
|
||||||
position_pct REAL DEFAULT 0.2,
|
|
||||||
stop_loss_pct REAL DEFAULT 0.03,
|
|
||||||
take_profit_pct REAL DEFAULT 0.06,
|
|
||||||
reasoning TEXT,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
FOREIGN KEY (session_id) REFERENCES debate_sessions(id) ON DELETE CASCADE
|
|
||||||
)`,
|
|
||||||
`CREATE INDEX IF NOT EXISTS idx_debate_votes_session ON debate_votes(session_id)`,
|
|
||||||
|
|
||||||
// Trigger to update updated_at
|
|
||||||
`CREATE TRIGGER IF NOT EXISTS update_debate_sessions_timestamp
|
|
||||||
AFTER UPDATE ON debate_sessions
|
|
||||||
FOR EACH ROW
|
|
||||||
BEGIN
|
|
||||||
UPDATE debate_sessions SET updated_at = CURRENT_TIMESTAMP WHERE id = OLD.id;
|
|
||||||
END`,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, schema := range schemas {
|
|
||||||
if _, err := s.db.Exec(schema); err != nil {
|
|
||||||
return fmt.Errorf("failed to create debate schema: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Migrate: Add new columns to existing tables (ignore errors if columns already exist)
|
|
||||||
migrations := []string{
|
|
||||||
`ALTER TABLE debate_sessions ADD COLUMN interval_minutes INTEGER DEFAULT 5`,
|
|
||||||
`ALTER TABLE debate_sessions ADD COLUMN prompt_variant TEXT DEFAULT 'balanced'`,
|
|
||||||
`ALTER TABLE debate_sessions ADD COLUMN trader_id TEXT`,
|
|
||||||
`ALTER TABLE debate_sessions ADD COLUMN enable_oi_ranking BOOLEAN DEFAULT 0`,
|
|
||||||
`ALTER TABLE debate_sessions ADD COLUMN oi_ranking_limit INTEGER DEFAULT 10`,
|
|
||||||
`ALTER TABLE debate_sessions ADD COLUMN oi_duration TEXT DEFAULT '1h'`,
|
|
||||||
`ALTER TABLE debate_votes ADD COLUMN leverage INTEGER DEFAULT 5`,
|
|
||||||
`ALTER TABLE debate_votes ADD COLUMN position_pct REAL DEFAULT 0.2`,
|
|
||||||
`ALTER TABLE debate_votes ADD COLUMN stop_loss_pct REAL DEFAULT 0.03`,
|
|
||||||
`ALTER TABLE debate_votes ADD COLUMN take_profit_pct REAL DEFAULT 0.06`,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, migration := range migrations {
|
|
||||||
// Ignore errors - column may already exist
|
|
||||||
s.db.Exec(migration)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateSession creates a new debate session
|
// CreateSession creates a new debate session
|
||||||
@@ -279,227 +263,73 @@ func (s *DebateStore) CreateSession(session *DebateSession) error {
|
|||||||
if session.OIDuration == "" {
|
if session.OIDuration == "" {
|
||||||
session.OIDuration = "1h"
|
session.OIDuration = "1h"
|
||||||
}
|
}
|
||||||
session.CreatedAt = time.Now()
|
|
||||||
session.UpdatedAt = time.Now()
|
|
||||||
|
|
||||||
_, err := s.db.Exec(`
|
db := &DebateSessionDB{
|
||||||
INSERT INTO debate_sessions (id, user_id, name, strategy_id, status, symbol, max_rounds, current_round, interval_minutes, prompt_variant, auto_execute, trader_id, enable_oi_ranking, oi_ranking_limit, oi_duration, created_at, updated_at)
|
ID: session.ID,
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
UserID: session.UserID,
|
||||||
session.ID, session.UserID, session.Name, session.StrategyID, session.Status,
|
Name: session.Name,
|
||||||
session.Symbol, session.MaxRounds, session.CurrentRound, session.IntervalMinutes, session.PromptVariant,
|
StrategyID: session.StrategyID,
|
||||||
session.AutoExecute, session.TraderID, session.EnableOIRanking, session.OIRankingLimit, session.OIDuration,
|
Status: session.Status,
|
||||||
session.CreatedAt, session.UpdatedAt,
|
Symbol: session.Symbol,
|
||||||
)
|
MaxRounds: session.MaxRounds,
|
||||||
return err
|
CurrentRound: session.CurrentRound,
|
||||||
|
IntervalMinutes: session.IntervalMinutes,
|
||||||
|
PromptVariant: session.PromptVariant,
|
||||||
|
AutoExecute: session.AutoExecute,
|
||||||
|
TraderID: session.TraderID,
|
||||||
|
EnableOIRanking: session.EnableOIRanking,
|
||||||
|
OIRankingLimit: session.OIRankingLimit,
|
||||||
|
OIDuration: session.OIDuration,
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.db.Create(db).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSession gets a debate session by ID
|
// GetSession gets a debate session by ID
|
||||||
func (s *DebateStore) GetSession(id string) (*DebateSession, error) {
|
func (s *DebateStore) GetSession(id string) (*DebateSession, error) {
|
||||||
var session DebateSession
|
var db DebateSessionDB
|
||||||
var finalDecisionJSON sql.NullString
|
if err := s.db.Where("id = ?", id).First(&db).Error; err != nil {
|
||||||
var traderID sql.NullString
|
return nil, err
|
||||||
var intervalMinutes sql.NullInt64
|
|
||||||
var promptVariant sql.NullString
|
|
||||||
var enableOIRanking sql.NullBool
|
|
||||||
var oiRankingLimit sql.NullInt64
|
|
||||||
var oiDuration sql.NullString
|
|
||||||
|
|
||||||
// Try new schema first
|
|
||||||
err := s.db.QueryRow(`
|
|
||||||
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
|
|
||||||
interval_minutes, prompt_variant, final_decision, auto_execute, trader_id,
|
|
||||||
enable_oi_ranking, oi_ranking_limit, oi_duration, created_at, updated_at
|
|
||||||
FROM debate_sessions WHERE id = ?`, id,
|
|
||||||
).Scan(
|
|
||||||
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
|
|
||||||
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
|
|
||||||
&intervalMinutes, &promptVariant,
|
|
||||||
&finalDecisionJSON, &session.AutoExecute, &traderID,
|
|
||||||
&enableOIRanking, &oiRankingLimit, &oiDuration, &session.CreatedAt, &session.UpdatedAt,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Fallback to basic schema if new columns don't exist
|
|
||||||
if err != nil {
|
|
||||||
err = s.db.QueryRow(`
|
|
||||||
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
|
|
||||||
final_decision, auto_execute, created_at, updated_at
|
|
||||||
FROM debate_sessions WHERE id = ?`, id,
|
|
||||||
).Scan(
|
|
||||||
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
|
|
||||||
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
|
|
||||||
&finalDecisionJSON, &session.AutoExecute, &session.CreatedAt, &session.UpdatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// Set defaults for new fields
|
|
||||||
session.IntervalMinutes = 5
|
|
||||||
session.PromptVariant = "balanced"
|
|
||||||
session.OIRankingLimit = 10
|
|
||||||
session.OIDuration = "1h"
|
|
||||||
} else {
|
|
||||||
// Set defaults for nullable fields
|
|
||||||
session.IntervalMinutes = 5
|
|
||||||
if intervalMinutes.Valid {
|
|
||||||
session.IntervalMinutes = int(intervalMinutes.Int64)
|
|
||||||
}
|
|
||||||
session.PromptVariant = "balanced"
|
|
||||||
if promptVariant.Valid {
|
|
||||||
session.PromptVariant = promptVariant.String
|
|
||||||
}
|
|
||||||
if traderID.Valid {
|
|
||||||
session.TraderID = traderID.String
|
|
||||||
}
|
|
||||||
if enableOIRanking.Valid {
|
|
||||||
session.EnableOIRanking = enableOIRanking.Bool
|
|
||||||
}
|
|
||||||
session.OIRankingLimit = 10
|
|
||||||
if oiRankingLimit.Valid {
|
|
||||||
session.OIRankingLimit = int(oiRankingLimit.Int64)
|
|
||||||
}
|
|
||||||
session.OIDuration = "1h"
|
|
||||||
if oiDuration.Valid {
|
|
||||||
session.OIDuration = oiDuration.String
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return db.toSession(), nil
|
||||||
if finalDecisionJSON.Valid && finalDecisionJSON.String != "" {
|
|
||||||
var decision DebateDecision
|
|
||||||
if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil {
|
|
||||||
session.FinalDecision = &decision
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &session, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSessionsByUser gets all debate sessions for a user
|
// GetSessionsByUser gets all debate sessions for a user
|
||||||
func (s *DebateStore) GetSessionsByUser(userID string) ([]*DebateSession, error) {
|
func (s *DebateStore) GetSessionsByUser(userID string) ([]*DebateSession, error) {
|
||||||
// First try the new schema with all columns
|
var dbs []DebateSessionDB
|
||||||
rows, err := s.db.Query(`
|
if err := s.db.Where("user_id = ?", userID).Order("created_at DESC").Find(&dbs).Error; err != nil {
|
||||||
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
|
return nil, err
|
||||||
interval_minutes, prompt_variant, final_decision, auto_execute, trader_id, created_at, updated_at
|
|
||||||
FROM debate_sessions WHERE user_id = ? ORDER BY created_at DESC`, userID,
|
|
||||||
)
|
|
||||||
|
|
||||||
// If query fails (likely due to missing columns), try basic query
|
|
||||||
if err != nil {
|
|
||||||
return s.getSessionsByUserBasic(userID)
|
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var sessions []*DebateSession
|
sessions := make([]*DebateSession, len(dbs))
|
||||||
for rows.Next() {
|
for i, db := range dbs {
|
||||||
var session DebateSession
|
sessions[i] = db.toSession()
|
||||||
var finalDecisionJSON sql.NullString
|
|
||||||
var traderID sql.NullString
|
|
||||||
var intervalMinutes sql.NullInt64
|
|
||||||
var promptVariant sql.NullString
|
|
||||||
|
|
||||||
if err := rows.Scan(
|
|
||||||
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
|
|
||||||
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
|
|
||||||
&intervalMinutes, &promptVariant,
|
|
||||||
&finalDecisionJSON, &session.AutoExecute, &traderID, &session.CreatedAt, &session.UpdatedAt,
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set defaults for nullable fields
|
|
||||||
session.IntervalMinutes = 5
|
|
||||||
if intervalMinutes.Valid {
|
|
||||||
session.IntervalMinutes = int(intervalMinutes.Int64)
|
|
||||||
}
|
|
||||||
session.PromptVariant = "balanced"
|
|
||||||
if promptVariant.Valid {
|
|
||||||
session.PromptVariant = promptVariant.String
|
|
||||||
}
|
|
||||||
|
|
||||||
if finalDecisionJSON.Valid && finalDecisionJSON.String != "" {
|
|
||||||
var decision DebateDecision
|
|
||||||
if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil {
|
|
||||||
session.FinalDecision = &decision
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if traderID.Valid {
|
|
||||||
session.TraderID = traderID.String
|
|
||||||
}
|
|
||||||
|
|
||||||
sessions = append(sessions, &session)
|
|
||||||
}
|
}
|
||||||
return sessions, nil
|
return sessions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListAllSessions returns all debate sessions (for cleanup on startup)
|
// ListAllSessions returns all debate sessions (for cleanup on startup)
|
||||||
func (s *DebateStore) ListAllSessions() ([]*DebateSession, error) {
|
func (s *DebateStore) ListAllSessions() ([]*DebateSession, error) {
|
||||||
rows, err := s.db.Query(`SELECT id, status FROM debate_sessions`)
|
var dbs []DebateSessionDB
|
||||||
if err != nil {
|
if err := s.db.Select("id, status").Find(&dbs).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var sessions []*DebateSession
|
sessions := make([]*DebateSession, len(dbs))
|
||||||
for rows.Next() {
|
for i, db := range dbs {
|
||||||
var session DebateSession
|
sessions[i] = &DebateSession{ID: db.ID, Status: db.Status}
|
||||||
if err := rows.Scan(&session.ID, &session.Status); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sessions = append(sessions, &session)
|
|
||||||
}
|
|
||||||
return sessions, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getSessionsByUserBasic is a fallback for old schema without new columns
|
|
||||||
func (s *DebateStore) getSessionsByUserBasic(userID string) ([]*DebateSession, error) {
|
|
||||||
rows, err := s.db.Query(`
|
|
||||||
SELECT id, user_id, name, strategy_id, status, symbol, max_rounds, current_round,
|
|
||||||
final_decision, auto_execute, created_at, updated_at
|
|
||||||
FROM debate_sessions WHERE user_id = ? ORDER BY created_at DESC`, userID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var sessions []*DebateSession
|
|
||||||
for rows.Next() {
|
|
||||||
var session DebateSession
|
|
||||||
var finalDecisionJSON sql.NullString
|
|
||||||
|
|
||||||
if err := rows.Scan(
|
|
||||||
&session.ID, &session.UserID, &session.Name, &session.StrategyID,
|
|
||||||
&session.Status, &session.Symbol, &session.MaxRounds, &session.CurrentRound,
|
|
||||||
&finalDecisionJSON, &session.AutoExecute, &session.CreatedAt, &session.UpdatedAt,
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set defaults for new fields
|
|
||||||
session.IntervalMinutes = 5
|
|
||||||
session.PromptVariant = "balanced"
|
|
||||||
|
|
||||||
if finalDecisionJSON.Valid && finalDecisionJSON.String != "" {
|
|
||||||
var decision DebateDecision
|
|
||||||
if err := json.Unmarshal([]byte(finalDecisionJSON.String), &decision); err == nil {
|
|
||||||
session.FinalDecision = &decision
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sessions = append(sessions, &session)
|
|
||||||
}
|
}
|
||||||
return sessions, nil
|
return sessions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSessionStatus updates the status of a debate session
|
// UpdateSessionStatus updates the status of a debate session
|
||||||
func (s *DebateStore) UpdateSessionStatus(id string, status DebateStatus) error {
|
func (s *DebateStore) UpdateSessionStatus(id string, status DebateStatus) error {
|
||||||
_, err := s.db.Exec(`UPDATE debate_sessions SET status = ? WHERE id = ?`, status, id)
|
return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Update("status", status).Error
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSessionRound updates the current round of a debate session
|
// UpdateSessionRound updates the current round of a debate session
|
||||||
func (s *DebateStore) UpdateSessionRound(id string, round int) error {
|
func (s *DebateStore) UpdateSessionRound(id string, round int) error {
|
||||||
_, err := s.db.Exec(`UPDATE debate_sessions SET current_round = ? WHERE id = ?`, round, id)
|
return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Update("current_round", round).Error
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSessionFinalDecision updates the final decision of a debate session (single decision)
|
// UpdateSessionFinalDecision updates the final decision of a debate session (single decision)
|
||||||
@@ -508,30 +338,31 @@ func (s *DebateStore) UpdateSessionFinalDecision(id string, decision *DebateDeci
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = s.db.Exec(`UPDATE debate_sessions SET final_decision = ?, status = ? WHERE id = ?`,
|
return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||||||
string(decisionJSON), DebateStatusCompleted, id)
|
"final_decision": string(decisionJSON),
|
||||||
return err
|
"status": DebateStatusCompleted,
|
||||||
|
}).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSessionFinalDecisions updates both single and multi-coin final decisions
|
// UpdateSessionFinalDecisions updates both single and multi-coin final decisions
|
||||||
func (s *DebateStore) UpdateSessionFinalDecisions(id string, primaryDecision *DebateDecision, allDecisions []*DebateDecision) error {
|
func (s *DebateStore) UpdateSessionFinalDecisions(id string, primaryDecision *DebateDecision, allDecisions []*DebateDecision) error {
|
||||||
// Always store primary decision as a single object (for backward compat)
|
|
||||||
// This ensures GetSession can deserialize it correctly
|
|
||||||
primaryJSON, err := json.Marshal(primaryDecision)
|
primaryJSON, err := json.Marshal(primaryDecision)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return s.db.Model(&DebateSessionDB{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||||||
// Update final_decision with primary decision and set status to completed
|
"final_decision": string(primaryJSON),
|
||||||
_, err = s.db.Exec(`UPDATE debate_sessions SET final_decision = ?, status = ? WHERE id = ?`,
|
"status": DebateStatusCompleted,
|
||||||
string(primaryJSON), DebateStatusCompleted, id)
|
}).Error
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteSession deletes a debate session and all related data
|
// DeleteSession deletes a debate session and all related data
|
||||||
func (s *DebateStore) DeleteSession(id string) error {
|
func (s *DebateStore) DeleteSession(id string) error {
|
||||||
_, err := s.db.Exec(`DELETE FROM debate_sessions WHERE id = ?`, id)
|
// Delete related data first
|
||||||
return err
|
s.db.Where("session_id = ?", id).Delete(&DebateParticipant{})
|
||||||
|
s.db.Where("session_id = ?", id).Delete(&DebateMessage{})
|
||||||
|
s.db.Where("session_id = ?", id).Delete(&DebateVote{})
|
||||||
|
return s.db.Where("id = ?", id).Delete(&DebateSessionDB{}).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddParticipant adds a participant to a debate session
|
// AddParticipant adds a participant to a debate session
|
||||||
@@ -539,9 +370,6 @@ func (s *DebateStore) AddParticipant(participant *DebateParticipant) error {
|
|||||||
if participant.ID == "" {
|
if participant.ID == "" {
|
||||||
participant.ID = uuid.New().String()
|
participant.ID = uuid.New().String()
|
||||||
}
|
}
|
||||||
participant.CreatedAt = time.Now()
|
|
||||||
|
|
||||||
// Set color based on personality if not provided
|
|
||||||
if participant.Color == "" {
|
if participant.Color == "" {
|
||||||
if color, ok := PersonalityColors[participant.Personality]; ok {
|
if color, ok := PersonalityColors[participant.Personality]; ok {
|
||||||
participant.Color = color
|
participant.Color = color
|
||||||
@@ -549,39 +377,14 @@ func (s *DebateStore) AddParticipant(participant *DebateParticipant) error {
|
|||||||
participant.Color = "#6B7280" // Default gray
|
participant.Color = "#6B7280" // Default gray
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return s.db.Create(participant).Error
|
||||||
_, err := s.db.Exec(`
|
|
||||||
INSERT INTO debate_participants (id, session_id, ai_model_id, ai_model_name, provider, personality, color, speak_order, created_at)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
||||||
participant.ID, participant.SessionID, participant.AIModelID, participant.AIModelName,
|
|
||||||
participant.Provider, participant.Personality, participant.Color, participant.SpeakOrder, participant.CreatedAt,
|
|
||||||
)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetParticipants gets all participants for a debate session
|
// GetParticipants gets all participants for a debate session
|
||||||
func (s *DebateStore) GetParticipants(sessionID string) ([]*DebateParticipant, error) {
|
func (s *DebateStore) GetParticipants(sessionID string) ([]*DebateParticipant, error) {
|
||||||
rows, err := s.db.Query(`
|
|
||||||
SELECT id, session_id, ai_model_id, ai_model_name, provider, personality, color, speak_order, created_at
|
|
||||||
FROM debate_participants WHERE session_id = ? ORDER BY speak_order`, sessionID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var participants []*DebateParticipant
|
var participants []*DebateParticipant
|
||||||
for rows.Next() {
|
err := s.db.Where("session_id = ?", sessionID).Order("speak_order").Find(&participants).Error
|
||||||
var p DebateParticipant
|
return participants, err
|
||||||
if err := rows.Scan(
|
|
||||||
&p.ID, &p.SessionID, &p.AIModelID, &p.AIModelName,
|
|
||||||
&p.Provider, &p.Personality, &p.Color, &p.SpeakOrder, &p.CreatedAt,
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
participants = append(participants, &p)
|
|
||||||
}
|
|
||||||
return participants, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddMessage adds a message to a debate session
|
// AddMessage adds a message to a debate session
|
||||||
@@ -589,95 +392,52 @@ func (s *DebateStore) AddMessage(msg *DebateMessage) error {
|
|||||||
if msg.ID == "" {
|
if msg.ID == "" {
|
||||||
msg.ID = uuid.New().String()
|
msg.ID = uuid.New().String()
|
||||||
}
|
}
|
||||||
msg.CreatedAt = time.Now()
|
|
||||||
|
|
||||||
var decisionJSON sql.NullString
|
|
||||||
if msg.Decision != nil {
|
if msg.Decision != nil {
|
||||||
data, err := json.Marshal(msg.Decision)
|
data, err := json.Marshal(msg.Decision)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
decisionJSON = sql.NullString{String: string(data), Valid: true}
|
msg.DecisionRaw = string(data)
|
||||||
}
|
}
|
||||||
|
return s.db.Create(msg).Error
|
||||||
_, err := s.db.Exec(`
|
|
||||||
INSERT INTO debate_messages (id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
||||||
msg.ID, msg.SessionID, msg.Round, msg.AIModelID, msg.AIModelName,
|
|
||||||
msg.Provider, msg.Personality, msg.MessageType, msg.Content,
|
|
||||||
decisionJSON, msg.Confidence, msg.CreatedAt,
|
|
||||||
)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMessages gets all messages for a debate session
|
// GetMessages gets all messages for a debate session
|
||||||
func (s *DebateStore) GetMessages(sessionID string) ([]*DebateMessage, error) {
|
func (s *DebateStore) GetMessages(sessionID string) ([]*DebateMessage, error) {
|
||||||
rows, err := s.db.Query(`
|
var messages []*DebateMessage
|
||||||
SELECT id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at
|
err := s.db.Where("session_id = ?", sessionID).Order("round, created_at").Find(&messages).Error
|
||||||
FROM debate_messages WHERE session_id = ? ORDER BY round, created_at`, sessionID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var messages []*DebateMessage
|
// Parse decision JSON
|
||||||
for rows.Next() {
|
for _, msg := range messages {
|
||||||
var msg DebateMessage
|
if msg.DecisionRaw != "" {
|
||||||
var decisionJSON sql.NullString
|
|
||||||
|
|
||||||
if err := rows.Scan(
|
|
||||||
&msg.ID, &msg.SessionID, &msg.Round, &msg.AIModelID, &msg.AIModelName,
|
|
||||||
&msg.Provider, &msg.Personality, &msg.MessageType, &msg.Content,
|
|
||||||
&decisionJSON, &msg.Confidence, &msg.CreatedAt,
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if decisionJSON.Valid && decisionJSON.String != "" {
|
|
||||||
var decision DebateDecision
|
var decision DebateDecision
|
||||||
if err := json.Unmarshal([]byte(decisionJSON.String), &decision); err == nil {
|
if json.Unmarshal([]byte(msg.DecisionRaw), &decision) == nil {
|
||||||
msg.Decision = &decision
|
msg.Decision = &decision
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
messages = append(messages, &msg)
|
|
||||||
}
|
}
|
||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMessagesByRound gets messages for a specific round
|
// GetMessagesByRound gets messages for a specific round
|
||||||
func (s *DebateStore) GetMessagesByRound(sessionID string, round int) ([]*DebateMessage, error) {
|
func (s *DebateStore) GetMessagesByRound(sessionID string, round int) ([]*DebateMessage, error) {
|
||||||
rows, err := s.db.Query(`
|
var messages []*DebateMessage
|
||||||
SELECT id, session_id, round, ai_model_id, ai_model_name, provider, personality, message_type, content, decision, confidence, created_at
|
err := s.db.Where("session_id = ? AND round = ?", sessionID, round).Order("created_at").Find(&messages).Error
|
||||||
FROM debate_messages WHERE session_id = ? AND round = ? ORDER BY created_at`, sessionID, round,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var messages []*DebateMessage
|
// Parse decision JSON
|
||||||
for rows.Next() {
|
for _, msg := range messages {
|
||||||
var msg DebateMessage
|
if msg.DecisionRaw != "" {
|
||||||
var decisionJSON sql.NullString
|
|
||||||
|
|
||||||
if err := rows.Scan(
|
|
||||||
&msg.ID, &msg.SessionID, &msg.Round, &msg.AIModelID, &msg.AIModelName,
|
|
||||||
&msg.Provider, &msg.Personality, &msg.MessageType, &msg.Content,
|
|
||||||
&decisionJSON, &msg.Confidence, &msg.CreatedAt,
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if decisionJSON.Valid && decisionJSON.String != "" {
|
|
||||||
var decision DebateDecision
|
var decision DebateDecision
|
||||||
if err := json.Unmarshal([]byte(decisionJSON.String), &decision); err == nil {
|
if json.Unmarshal([]byte(msg.DecisionRaw), &decision) == nil {
|
||||||
msg.Decision = &decision
|
msg.Decision = &decision
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
messages = append(messages, &msg)
|
|
||||||
}
|
}
|
||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
@@ -687,40 +447,14 @@ func (s *DebateStore) AddVote(vote *DebateVote) error {
|
|||||||
if vote.ID == "" {
|
if vote.ID == "" {
|
||||||
vote.ID = uuid.New().String()
|
vote.ID = uuid.New().String()
|
||||||
}
|
}
|
||||||
vote.CreatedAt = time.Now()
|
return s.db.Create(vote).Error
|
||||||
|
|
||||||
_, err := s.db.Exec(`
|
|
||||||
INSERT INTO debate_votes (id, session_id, ai_model_id, ai_model_name, action, symbol, confidence, leverage, position_pct, stop_loss_pct, take_profit_pct, reasoning, created_at)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
||||||
vote.ID, vote.SessionID, vote.AIModelID, vote.AIModelName,
|
|
||||||
vote.Action, vote.Symbol, vote.Confidence, vote.Leverage, vote.PositionPct, vote.StopLossPct, vote.TakeProfitPct, vote.Reasoning, vote.CreatedAt,
|
|
||||||
)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetVotes gets all votes for a debate session
|
// GetVotes gets all votes for a debate session
|
||||||
func (s *DebateStore) GetVotes(sessionID string) ([]*DebateVote, error) {
|
func (s *DebateStore) GetVotes(sessionID string) ([]*DebateVote, error) {
|
||||||
rows, err := s.db.Query(`
|
|
||||||
SELECT id, session_id, ai_model_id, ai_model_name, action, symbol, confidence, leverage, position_pct, stop_loss_pct, take_profit_pct, reasoning, created_at
|
|
||||||
FROM debate_votes WHERE session_id = ? ORDER BY created_at`, sessionID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var votes []*DebateVote
|
var votes []*DebateVote
|
||||||
for rows.Next() {
|
err := s.db.Where("session_id = ?", sessionID).Order("created_at").Find(&votes).Error
|
||||||
var vote DebateVote
|
return votes, err
|
||||||
if err := rows.Scan(
|
|
||||||
&vote.ID, &vote.SessionID, &vote.AIModelID, &vote.AIModelName,
|
|
||||||
&vote.Action, &vote.Symbol, &vote.Confidence, &vote.Leverage, &vote.PositionPct, &vote.StopLossPct, &vote.TakeProfitPct, &vote.Reasoning, &vote.CreatedAt,
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
votes = append(votes, &vote)
|
|
||||||
}
|
|
||||||
return votes, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DebateSessionWithDetails combines session with participants and messages
|
// DebateSessionWithDetails combines session with participants and messages
|
||||||
|
|||||||
+135
-198
@@ -1,18 +1,41 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DecisionStore decision log storage
|
// DecisionStore decision log storage
|
||||||
type DecisionStore struct {
|
type DecisionStore struct {
|
||||||
db *sql.DB
|
db *gorm.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecisionRecord decision record
|
// DecisionRecordDB internal GORM model for decision_records table
|
||||||
|
type DecisionRecordDB struct {
|
||||||
|
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||||
|
TraderID string `gorm:"column:trader_id;not null;index:idx_decision_records_trader_time"`
|
||||||
|
CycleNumber int `gorm:"column:cycle_number;not null"`
|
||||||
|
Timestamp time.Time `gorm:"not null;index:idx_decision_records_trader_time,sort:desc;index:idx_decision_records_timestamp,sort:desc"`
|
||||||
|
SystemPrompt string `gorm:"column:system_prompt;default:''"`
|
||||||
|
InputPrompt string `gorm:"column:input_prompt;default:''"`
|
||||||
|
CoTTrace string `gorm:"column:cot_trace;default:''"`
|
||||||
|
DecisionJSON string `gorm:"column:decision_json;default:''"`
|
||||||
|
RawResponse string `gorm:"column:raw_response;default:''"`
|
||||||
|
CandidateCoins string `gorm:"column:candidate_coins;default:''"`
|
||||||
|
ExecutionLog string `gorm:"column:execution_log;default:''"`
|
||||||
|
Decisions string `gorm:"column:decisions;default:'[]'"`
|
||||||
|
Success bool `gorm:"default:false"`
|
||||||
|
ErrorMessage string `gorm:"column:error_message;default:''"`
|
||||||
|
AIRequestDurationMs int64 `gorm:"column:ai_request_duration_ms;default:0"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (DecisionRecordDB) TableName() string { return "decision_records" }
|
||||||
|
|
||||||
|
// DecisionRecord decision record (external API struct)
|
||||||
type DecisionRecord struct {
|
type DecisionRecord struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
TraderID string `json:"trader_id"`
|
TraderID string `json:"trader_id"`
|
||||||
@@ -81,49 +104,47 @@ type Statistics struct {
|
|||||||
TotalClosePositions int `json:"total_close_positions"`
|
TotalClosePositions int `json:"total_close_positions"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// initTables initializes AI decision log tables
|
// NewDecisionStore creates a new DecisionStore
|
||||||
// Note: Account equity curve data has been migrated to trader_equity_snapshots table (managed by EquityStore)
|
func NewDecisionStore(db *gorm.DB) *DecisionStore {
|
||||||
func (s *DecisionStore) initTables() error {
|
return &DecisionStore{db: db}
|
||||||
queries := []string{
|
|
||||||
// AI decision log table (records AI input/output, chain of thought, etc.)
|
|
||||||
`CREATE TABLE IF NOT EXISTS decision_records (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
trader_id TEXT NOT NULL,
|
|
||||||
cycle_number INTEGER NOT NULL,
|
|
||||||
timestamp DATETIME NOT NULL,
|
|
||||||
system_prompt TEXT DEFAULT '',
|
|
||||||
input_prompt TEXT DEFAULT '',
|
|
||||||
cot_trace TEXT DEFAULT '',
|
|
||||||
decision_json TEXT DEFAULT '',
|
|
||||||
raw_response TEXT DEFAULT '',
|
|
||||||
candidate_coins TEXT DEFAULT '',
|
|
||||||
execution_log TEXT DEFAULT '',
|
|
||||||
success BOOLEAN DEFAULT 0,
|
|
||||||
error_message TEXT DEFAULT '',
|
|
||||||
ai_request_duration_ms INTEGER DEFAULT 0,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)`,
|
|
||||||
// Indexes
|
|
||||||
`CREATE INDEX IF NOT EXISTS idx_decision_records_trader_time ON decision_records(trader_id, timestamp DESC)`,
|
|
||||||
`CREATE INDEX IF NOT EXISTS idx_decision_records_timestamp ON decision_records(timestamp DESC)`,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, query := range queries {
|
|
||||||
if _, err := s.db.Exec(query); err != nil {
|
|
||||||
return fmt.Errorf("failed to execute SQL: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Migration: add raw_response column if not exists
|
|
||||||
s.db.Exec(`ALTER TABLE decision_records ADD COLUMN raw_response TEXT DEFAULT ''`)
|
|
||||||
|
|
||||||
// Migration: add decisions column if not exists
|
|
||||||
s.db.Exec(`ALTER TABLE decision_records ADD COLUMN decisions TEXT DEFAULT '[]'`)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogDecision logs decision (only saves AI decision log, equity curve has been migrated to equity table)
|
// initTables initializes AI decision log tables
|
||||||
|
func (s *DecisionStore) initTables() error {
|
||||||
|
// For PostgreSQL with existing table, skip AutoMigrate
|
||||||
|
if s.db.Dialector.Name() == "postgres" {
|
||||||
|
var tableExists int64
|
||||||
|
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'decision_records'`).Scan(&tableExists)
|
||||||
|
if tableExists > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.db.AutoMigrate(&DecisionRecordDB{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// toRecord converts DB model to API struct
|
||||||
|
func (db *DecisionRecordDB) toRecord() *DecisionRecord {
|
||||||
|
record := &DecisionRecord{
|
||||||
|
ID: db.ID,
|
||||||
|
TraderID: db.TraderID,
|
||||||
|
CycleNumber: db.CycleNumber,
|
||||||
|
Timestamp: db.Timestamp,
|
||||||
|
SystemPrompt: db.SystemPrompt,
|
||||||
|
InputPrompt: db.InputPrompt,
|
||||||
|
CoTTrace: db.CoTTrace,
|
||||||
|
DecisionJSON: db.DecisionJSON,
|
||||||
|
RawResponse: db.RawResponse,
|
||||||
|
Success: db.Success,
|
||||||
|
ErrorMessage: db.ErrorMessage,
|
||||||
|
AIRequestDurationMs: db.AIRequestDurationMs,
|
||||||
|
}
|
||||||
|
json.Unmarshal([]byte(db.CandidateCoins), &record.CandidateCoins)
|
||||||
|
json.Unmarshal([]byte(db.ExecutionLog), &record.ExecutionLog)
|
||||||
|
json.Unmarshal([]byte(db.Decisions), &record.Decisions)
|
||||||
|
return record
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogDecision logs decision
|
||||||
func (s *DecisionStore) LogDecision(record *DecisionRecord) error {
|
func (s *DecisionStore) LogDecision(record *DecisionRecord) error {
|
||||||
if record.Timestamp.IsZero() {
|
if record.Timestamp.IsZero() {
|
||||||
record.Timestamp = time.Now().UTC()
|
record.Timestamp = time.Now().UTC()
|
||||||
@@ -131,65 +152,49 @@ func (s *DecisionStore) LogDecision(record *DecisionRecord) error {
|
|||||||
record.Timestamp = record.Timestamp.UTC()
|
record.Timestamp = record.Timestamp.UTC()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serialize candidate coins, execution log and decisions to JSON
|
// Serialize arrays to JSON
|
||||||
candidateCoinsJSON, _ := json.Marshal(record.CandidateCoins)
|
candidateCoinsJSON, _ := json.Marshal(record.CandidateCoins)
|
||||||
executionLogJSON, _ := json.Marshal(record.ExecutionLog)
|
executionLogJSON, _ := json.Marshal(record.ExecutionLog)
|
||||||
decisionsJSON, _ := json.Marshal(record.Decisions)
|
decisionsJSON, _ := json.Marshal(record.Decisions)
|
||||||
|
|
||||||
// Insert decision record main table (only save AI decision related content)
|
dbRecord := &DecisionRecordDB{
|
||||||
result, err := s.db.Exec(`
|
TraderID: record.TraderID,
|
||||||
INSERT INTO decision_records (
|
CycleNumber: record.CycleNumber,
|
||||||
trader_id, cycle_number, timestamp, system_prompt, input_prompt,
|
Timestamp: record.Timestamp,
|
||||||
cot_trace, decision_json, raw_response, candidate_coins, execution_log,
|
SystemPrompt: record.SystemPrompt,
|
||||||
decisions, success, error_message, ai_request_duration_ms
|
InputPrompt: record.InputPrompt,
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
CoTTrace: record.CoTTrace,
|
||||||
`,
|
DecisionJSON: record.DecisionJSON,
|
||||||
record.TraderID, record.CycleNumber, record.Timestamp.Format(time.RFC3339),
|
RawResponse: record.RawResponse,
|
||||||
record.SystemPrompt, record.InputPrompt, record.CoTTrace, record.DecisionJSON,
|
CandidateCoins: string(candidateCoinsJSON),
|
||||||
record.RawResponse, string(candidateCoinsJSON), string(executionLogJSON),
|
ExecutionLog: string(executionLogJSON),
|
||||||
string(decisionsJSON), record.Success, record.ErrorMessage, record.AIRequestDurationMs,
|
Decisions: string(decisionsJSON),
|
||||||
)
|
Success: record.Success,
|
||||||
if err != nil {
|
ErrorMessage: record.ErrorMessage,
|
||||||
|
AIRequestDurationMs: record.AIRequestDurationMs,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Create(dbRecord).Error; err != nil {
|
||||||
return fmt.Errorf("failed to insert decision record: %w", err)
|
return fmt.Errorf("failed to insert decision record: %w", err)
|
||||||
}
|
}
|
||||||
|
record.ID = dbRecord.ID
|
||||||
decisionID, err := result.LastInsertId()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get decision ID: %w", err)
|
|
||||||
}
|
|
||||||
record.ID = decisionID
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLatestRecords gets the latest N records for specified trader (sorted by time in ascending order: old to new)
|
// GetLatestRecords gets the latest N records for specified trader (sorted by time in ascending order: old to new)
|
||||||
func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRecord, error) {
|
func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRecord, error) {
|
||||||
rows, err := s.db.Query(`
|
var dbRecords []*DecisionRecordDB
|
||||||
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt,
|
err := s.db.Where("trader_id = ?", traderID).
|
||||||
cot_trace, decision_json, candidate_coins, execution_log,
|
Order("timestamp DESC").
|
||||||
COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms
|
Limit(n).
|
||||||
FROM decision_records
|
Find(&dbRecords).Error
|
||||||
WHERE trader_id = ?
|
|
||||||
ORDER BY timestamp DESC
|
|
||||||
LIMIT ?
|
|
||||||
`, traderID, n)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to query decision records: %w", err)
|
return nil, fmt.Errorf("failed to query decision records: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var records []*DecisionRecord
|
records := make([]*DecisionRecord, len(dbRecords))
|
||||||
for rows.Next() {
|
for i, db := range dbRecords {
|
||||||
record, err := s.scanDecisionRecord(rows)
|
records[i] = db.toRecord()
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
records = append(records, record)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fill associated data
|
|
||||||
for _, record := range records {
|
|
||||||
s.fillRecordDetails(record)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reverse array to sort time from old to new
|
// Reverse array to sort time from old to new
|
||||||
@@ -202,26 +207,15 @@ func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRec
|
|||||||
|
|
||||||
// GetAllLatestRecords gets the latest N records for all traders
|
// GetAllLatestRecords gets the latest N records for all traders
|
||||||
func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) {
|
func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) {
|
||||||
rows, err := s.db.Query(`
|
var dbRecords []*DecisionRecordDB
|
||||||
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt,
|
err := s.db.Order("timestamp DESC").Limit(n).Find(&dbRecords).Error
|
||||||
cot_trace, decision_json, candidate_coins, execution_log,
|
|
||||||
COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms
|
|
||||||
FROM decision_records
|
|
||||||
ORDER BY timestamp DESC
|
|
||||||
LIMIT ?
|
|
||||||
`, n)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to query decision records: %w", err)
|
return nil, fmt.Errorf("failed to query decision records: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var records []*DecisionRecord
|
records := make([]*DecisionRecord, len(dbRecords))
|
||||||
for rows.Next() {
|
for i, db := range dbRecords {
|
||||||
record, err := s.scanDecisionRecord(rows)
|
records[i] = db.toRecord()
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
records = append(records, record)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reverse array
|
// Reverse array
|
||||||
@@ -236,26 +230,17 @@ func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) {
|
|||||||
func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*DecisionRecord, error) {
|
func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*DecisionRecord, error) {
|
||||||
dateStr := date.Format("2006-01-02")
|
dateStr := date.Format("2006-01-02")
|
||||||
|
|
||||||
rows, err := s.db.Query(`
|
var dbRecords []*DecisionRecordDB
|
||||||
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt,
|
err := s.db.Where("trader_id = ? AND DATE(timestamp) = ?", traderID, dateStr).
|
||||||
cot_trace, decision_json, candidate_coins, execution_log,
|
Order("timestamp ASC").
|
||||||
COALESCE(decisions, '[]'), success, error_message, ai_request_duration_ms
|
Find(&dbRecords).Error
|
||||||
FROM decision_records
|
|
||||||
WHERE trader_id = ? AND DATE(timestamp) = ?
|
|
||||||
ORDER BY timestamp ASC
|
|
||||||
`, traderID, dateStr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to query decision records: %w", err)
|
return nil, fmt.Errorf("failed to query decision records: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var records []*DecisionRecord
|
records := make([]*DecisionRecord, len(dbRecords))
|
||||||
for rows.Next() {
|
for i, db := range dbRecords {
|
||||||
record, err := s.scanDecisionRecord(rows)
|
records[i] = db.toRecord()
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
records = append(records, record)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return records, nil
|
return records, nil
|
||||||
@@ -263,48 +248,31 @@ func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*De
|
|||||||
|
|
||||||
// CleanOldRecords cleans old records from N days ago
|
// CleanOldRecords cleans old records from N days ago
|
||||||
func (s *DecisionStore) CleanOldRecords(traderID string, days int) (int64, error) {
|
func (s *DecisionStore) CleanOldRecords(traderID string, days int) (int64, error) {
|
||||||
cutoffTime := time.Now().AddDate(0, 0, -days).Format(time.RFC3339)
|
cutoffTime := time.Now().AddDate(0, 0, -days)
|
||||||
|
|
||||||
result, err := s.db.Exec(`
|
result := s.db.Where("trader_id = ? AND timestamp < ?", traderID, cutoffTime).
|
||||||
DELETE FROM decision_records
|
Delete(&DecisionRecordDB{})
|
||||||
WHERE trader_id = ? AND timestamp < ?
|
if result.Error != nil {
|
||||||
`, traderID, cutoffTime)
|
return 0, fmt.Errorf("failed to clean old records: %w", result.Error)
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("failed to clean old records: %w", err)
|
|
||||||
}
|
}
|
||||||
|
return result.RowsAffected, nil
|
||||||
return result.RowsAffected()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStatistics gets statistics information for specified trader
|
// GetStatistics gets statistics information for specified trader
|
||||||
func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) {
|
func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) {
|
||||||
stats := &Statistics{}
|
stats := &Statistics{}
|
||||||
|
|
||||||
err := s.db.QueryRow(`
|
var totalCount, successCount int64
|
||||||
SELECT COUNT(*) FROM decision_records WHERE trader_id = ?
|
s.db.Model(&DecisionRecordDB{}).Where("trader_id = ?", traderID).Count(&totalCount)
|
||||||
`, traderID).Scan(&stats.TotalCycles)
|
s.db.Model(&DecisionRecordDB{}).Where("trader_id = ? AND success = ?", traderID, true).Count(&successCount)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to query total cycles: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.db.QueryRow(`
|
stats.TotalCycles = int(totalCount)
|
||||||
SELECT COUNT(*) FROM decision_records WHERE trader_id = ? AND success = 1
|
stats.SuccessfulCycles = int(successCount)
|
||||||
`, traderID).Scan(&stats.SuccessfulCycles)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to query successful cycles: %w", err)
|
|
||||||
}
|
|
||||||
stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles
|
stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles
|
||||||
|
|
||||||
// Count from trader_positions table
|
// Count from trader_positions table using raw query for cross-table
|
||||||
s.db.QueryRow(`
|
s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE trader_id = ?", traderID).Scan(&stats.TotalOpenPositions)
|
||||||
SELECT COUNT(*) FROM trader_positions
|
s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE trader_id = ? AND status = 'CLOSED'", traderID).Scan(&stats.TotalClosePositions)
|
||||||
WHERE trader_id = ?
|
|
||||||
`, traderID).Scan(&stats.TotalOpenPositions)
|
|
||||||
|
|
||||||
s.db.QueryRow(`
|
|
||||||
SELECT COUNT(*) FROM trader_positions
|
|
||||||
WHERE trader_id = ? AND status = 'CLOSED'
|
|
||||||
`, traderID).Scan(&stats.TotalClosePositions)
|
|
||||||
|
|
||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
@@ -313,64 +281,33 @@ func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) {
|
|||||||
func (s *DecisionStore) GetAllStatistics() (*Statistics, error) {
|
func (s *DecisionStore) GetAllStatistics() (*Statistics, error) {
|
||||||
stats := &Statistics{}
|
stats := &Statistics{}
|
||||||
|
|
||||||
s.db.QueryRow(`SELECT COUNT(*) FROM decision_records`).Scan(&stats.TotalCycles)
|
var totalCount, successCount int64
|
||||||
s.db.QueryRow(`SELECT COUNT(*) FROM decision_records WHERE success = 1`).Scan(&stats.SuccessfulCycles)
|
s.db.Model(&DecisionRecordDB{}).Count(&totalCount)
|
||||||
|
s.db.Model(&DecisionRecordDB{}).Where("success = ?", true).Count(&successCount)
|
||||||
|
|
||||||
|
stats.TotalCycles = int(totalCount)
|
||||||
|
stats.SuccessfulCycles = int(successCount)
|
||||||
stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles
|
stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles
|
||||||
|
|
||||||
// Count from trader_positions table
|
// Count from trader_positions table
|
||||||
s.db.QueryRow(`
|
s.db.Raw("SELECT COUNT(*) FROM trader_positions").Scan(&stats.TotalOpenPositions)
|
||||||
SELECT COUNT(*) FROM trader_positions
|
s.db.Raw("SELECT COUNT(*) FROM trader_positions WHERE status = 'CLOSED'").Scan(&stats.TotalClosePositions)
|
||||||
`).Scan(&stats.TotalOpenPositions)
|
|
||||||
|
|
||||||
s.db.QueryRow(`
|
|
||||||
SELECT COUNT(*) FROM trader_positions
|
|
||||||
WHERE status = 'CLOSED'
|
|
||||||
`).Scan(&stats.TotalClosePositions)
|
|
||||||
|
|
||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLastCycleNumber gets the last cycle number for specified trader
|
// GetLastCycleNumber gets the last cycle number for specified trader
|
||||||
func (s *DecisionStore) GetLastCycleNumber(traderID string) (int, error) {
|
func (s *DecisionStore) GetLastCycleNumber(traderID string) (int, error) {
|
||||||
var cycleNumber int
|
var cycleNumber *int
|
||||||
err := s.db.QueryRow(`
|
err := s.db.Model(&DecisionRecordDB{}).
|
||||||
SELECT COALESCE(MAX(cycle_number), 0) FROM decision_records WHERE trader_id = ?
|
Where("trader_id = ?", traderID).
|
||||||
`, traderID).Scan(&cycleNumber)
|
Select("MAX(cycle_number)").
|
||||||
|
Scan(&cycleNumber).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return cycleNumber, nil
|
if cycleNumber == nil {
|
||||||
}
|
return 0, nil
|
||||||
|
|
||||||
// scanDecisionRecord scans decision record from row
|
|
||||||
func (s *DecisionStore) scanDecisionRecord(rows *sql.Rows) (*DecisionRecord, error) {
|
|
||||||
var record DecisionRecord
|
|
||||||
var timestampStr string
|
|
||||||
var candidateCoinsJSON, executionLogJSON, decisionsJSON string
|
|
||||||
|
|
||||||
err := rows.Scan(
|
|
||||||
&record.ID, &record.TraderID, &record.CycleNumber, ×tampStr,
|
|
||||||
&record.SystemPrompt, &record.InputPrompt, &record.CoTTrace,
|
|
||||||
&record.DecisionJSON, &candidateCoinsJSON, &executionLogJSON,
|
|
||||||
&decisionsJSON, &record.Success, &record.ErrorMessage, &record.AIRequestDurationMs,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
return *cycleNumber, nil
|
||||||
record.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
|
|
||||||
json.Unmarshal([]byte(candidateCoinsJSON), &record.CandidateCoins)
|
|
||||||
json.Unmarshal([]byte(executionLogJSON), &record.ExecutionLog)
|
|
||||||
json.Unmarshal([]byte(decisionsJSON), &record.Decisions)
|
|
||||||
|
|
||||||
return &record, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// fillRecordDetails fills associated data for decision record (old associated tables removed, this function kept for compatibility)
|
|
||||||
// Note: Account snapshot, position snapshot, decision action data are no longer stored in decision related tables
|
|
||||||
// - For equity data use EquityStore.GetLatest()
|
|
||||||
// - For order data use OrderStore
|
|
||||||
func (s *DecisionStore) fillRecordDetails(record *DecisionRecord) {
|
|
||||||
// Old associated tables removed, no longer need to fill
|
|
||||||
// AccountState, Positions, Decisions fields will remain at zero values
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -238,3 +238,44 @@ func getEnv(key, defaultValue string) string {
|
|||||||
}
|
}
|
||||||
return defaultValue
|
return defaultValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// convertQuery converts ? placeholders to $1, $2 for PostgreSQL
|
||||||
|
// and handles other database-specific syntax differences
|
||||||
|
func convertQuery(query string, dbType DBType) string {
|
||||||
|
if dbType != DBTypePostgres {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
result := query
|
||||||
|
|
||||||
|
// Convert ? to $1, $2, etc. for PostgreSQL
|
||||||
|
index := 1
|
||||||
|
for strings.Contains(result, "?") {
|
||||||
|
result = strings.Replace(result, "?", fmt.Sprintf("$%d", index), 1)
|
||||||
|
index++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert datetime('now') to CURRENT_TIMESTAMP
|
||||||
|
result = strings.ReplaceAll(result, "datetime('now')", "CURRENT_TIMESTAMP")
|
||||||
|
|
||||||
|
// Remove datetime() wrapper for ORDER BY (PostgreSQL timestamps sort correctly)
|
||||||
|
// This handles patterns like "ORDER BY datetime(column) DESC"
|
||||||
|
result = strings.ReplaceAll(result, "datetime(updated_at)", "updated_at")
|
||||||
|
result = strings.ReplaceAll(result, "datetime(created_at)", "created_at")
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// boolDefault returns database-appropriate boolean default for COALESCE
|
||||||
|
// Use in queries like: COALESCE(column, %s)
|
||||||
|
func boolDefault(dbType DBType, value bool) string {
|
||||||
|
if dbType == DBTypePostgres {
|
||||||
|
if value {
|
||||||
|
return "TRUE"
|
||||||
|
}
|
||||||
|
return "FALSE"
|
||||||
|
}
|
||||||
|
if value {
|
||||||
|
return "1"
|
||||||
|
}
|
||||||
|
return "0"
|
||||||
|
}
|
||||||
|
|||||||
+61
-139
@@ -1,55 +1,48 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// EquityStore account equity storage (for plotting return curves)
|
// EquityStore account equity storage (for plotting return curves)
|
||||||
type EquityStore struct {
|
type EquityStore struct {
|
||||||
db *sql.DB
|
db *gorm.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// EquitySnapshot equity snapshot
|
// EquitySnapshot equity snapshot
|
||||||
type EquitySnapshot struct {
|
type EquitySnapshot struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||||
TraderID string `json:"trader_id"`
|
TraderID string `gorm:"column:trader_id;not null;index:idx_equity_trader_time" json:"trader_id"`
|
||||||
Timestamp time.Time `json:"timestamp"`
|
Timestamp time.Time `gorm:"not null;index:idx_equity_trader_time,sort:desc;index:idx_equity_timestamp,sort:desc" json:"timestamp"`
|
||||||
TotalEquity float64 `json:"total_equity"` // Account equity (balance + unrealized PnL)
|
TotalEquity float64 `gorm:"column:total_equity;not null;default:0" json:"total_equity"`
|
||||||
Balance float64 `json:"balance"` // Account balance
|
Balance float64 `gorm:"not null;default:0" json:"balance"`
|
||||||
UnrealizedPnL float64 `json:"unrealized_pnl"` // Unrealized profit and loss
|
UnrealizedPnL float64 `gorm:"column:unrealized_pnl;not null;default:0" json:"unrealized_pnl"`
|
||||||
PositionCount int `json:"position_count"` // Position count
|
PositionCount int `gorm:"column:position_count;default:0" json:"position_count"`
|
||||||
MarginUsedPct float64 `json:"margin_used_pct"` // Margin usage percentage
|
MarginUsedPct float64 `gorm:"column:margin_used_pct;default:0" json:"margin_used_pct"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (EquitySnapshot) TableName() string { return "trader_equity_snapshots" }
|
||||||
|
|
||||||
|
// NewEquityStore creates a new EquityStore
|
||||||
|
func NewEquityStore(db *gorm.DB) *EquityStore {
|
||||||
|
return &EquityStore{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
// initTables initializes equity tables
|
// initTables initializes equity tables
|
||||||
func (s *EquityStore) initTables() error {
|
func (s *EquityStore) initTables() error {
|
||||||
queries := []string{
|
// For PostgreSQL with existing table, skip AutoMigrate
|
||||||
// Equity snapshot table - specifically for return curves
|
if s.db.Dialector.Name() == "postgres" {
|
||||||
`CREATE TABLE IF NOT EXISTS trader_equity_snapshots (
|
var tableExists int64
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_equity_snapshots'`).Scan(&tableExists)
|
||||||
trader_id TEXT NOT NULL,
|
if tableExists > 0 {
|
||||||
timestamp DATETIME NOT NULL,
|
return nil
|
||||||
total_equity REAL NOT NULL DEFAULT 0,
|
|
||||||
balance REAL NOT NULL DEFAULT 0,
|
|
||||||
unrealized_pnl REAL NOT NULL DEFAULT 0,
|
|
||||||
position_count INTEGER DEFAULT 0,
|
|
||||||
margin_used_pct REAL DEFAULT 0,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)`,
|
|
||||||
// Indexes
|
|
||||||
`CREATE INDEX IF NOT EXISTS idx_equity_trader_time ON trader_equity_snapshots(trader_id, timestamp DESC)`,
|
|
||||||
`CREATE INDEX IF NOT EXISTS idx_equity_timestamp ON trader_equity_snapshots(timestamp DESC)`,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, query := range queries {
|
|
||||||
if _, err := s.db.Exec(query); err != nil {
|
|
||||||
return fmt.Errorf("failed to execute SQL: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return s.db.AutoMigrate(&EquitySnapshot{})
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save saves equity snapshot
|
// Save saves equity snapshot
|
||||||
@@ -60,58 +53,22 @@ func (s *EquityStore) Save(snapshot *EquitySnapshot) error {
|
|||||||
snapshot.Timestamp = snapshot.Timestamp.UTC()
|
snapshot.Timestamp = snapshot.Timestamp.UTC()
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := s.db.Exec(`
|
if err := s.db.Create(snapshot).Error; err != nil {
|
||||||
INSERT INTO trader_equity_snapshots (
|
|
||||||
trader_id, timestamp, total_equity, balance,
|
|
||||||
unrealized_pnl, position_count, margin_used_pct
|
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?)
|
|
||||||
`,
|
|
||||||
snapshot.TraderID,
|
|
||||||
snapshot.Timestamp.Format(time.RFC3339),
|
|
||||||
snapshot.TotalEquity,
|
|
||||||
snapshot.Balance,
|
|
||||||
snapshot.UnrealizedPnL,
|
|
||||||
snapshot.PositionCount,
|
|
||||||
snapshot.MarginUsedPct,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to save equity snapshot: %w", err)
|
return fmt.Errorf("failed to save equity snapshot: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
id, _ := result.LastInsertId()
|
|
||||||
snapshot.ID = id
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLatest gets the latest N equity records for specified trader (sorted in ascending chronological order: old to new)
|
// GetLatest gets the latest N equity records for specified trader (sorted in ascending chronological order: old to new)
|
||||||
func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot, error) {
|
func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot, error) {
|
||||||
rows, err := s.db.Query(`
|
var snapshots []*EquitySnapshot
|
||||||
SELECT id, trader_id, timestamp, total_equity, balance,
|
err := s.db.Where("trader_id = ?", traderID).
|
||||||
unrealized_pnl, position_count, margin_used_pct
|
Order("timestamp DESC").
|
||||||
FROM trader_equity_snapshots
|
Limit(limit).
|
||||||
WHERE trader_id = ?
|
Find(&snapshots).Error
|
||||||
ORDER BY timestamp DESC
|
|
||||||
LIMIT ?
|
|
||||||
`, traderID, limit)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to query equity records: %w", err)
|
return nil, fmt.Errorf("failed to query equity records: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var snapshots []*EquitySnapshot
|
|
||||||
for rows.Next() {
|
|
||||||
snap := &EquitySnapshot{}
|
|
||||||
var timestampStr string
|
|
||||||
err := rows.Scan(
|
|
||||||
&snap.ID, &snap.TraderID, ×tampStr, &snap.TotalEquity,
|
|
||||||
&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
|
|
||||||
snapshots = append(snapshots, snap)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reverse the array to sort time from old to new (suitable for plotting curves)
|
// Reverse the array to sort time from old to new (suitable for plotting curves)
|
||||||
for i, j := 0, len(snapshots)-1; i < j; i, j = i+1, j-1 {
|
for i, j := 0, len(snapshots)-1; i < j; i, j = i+1, j-1 {
|
||||||
@@ -123,116 +80,81 @@ func (s *EquityStore) GetLatest(traderID string, limit int) ([]*EquitySnapshot,
|
|||||||
|
|
||||||
// GetByTimeRange gets equity records within specified time range
|
// GetByTimeRange gets equity records within specified time range
|
||||||
func (s *EquityStore) GetByTimeRange(traderID string, start, end time.Time) ([]*EquitySnapshot, error) {
|
func (s *EquityStore) GetByTimeRange(traderID string, start, end time.Time) ([]*EquitySnapshot, error) {
|
||||||
rows, err := s.db.Query(`
|
var snapshots []*EquitySnapshot
|
||||||
SELECT id, trader_id, timestamp, total_equity, balance,
|
err := s.db.Where("trader_id = ? AND timestamp >= ? AND timestamp <= ?", traderID, start, end).
|
||||||
unrealized_pnl, position_count, margin_used_pct
|
Order("timestamp ASC").
|
||||||
FROM trader_equity_snapshots
|
Find(&snapshots).Error
|
||||||
WHERE trader_id = ? AND timestamp >= ? AND timestamp <= ?
|
|
||||||
ORDER BY timestamp ASC
|
|
||||||
`, traderID, start.Format(time.RFC3339), end.Format(time.RFC3339))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to query equity records: %w", err)
|
return nil, fmt.Errorf("failed to query equity records: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var snapshots []*EquitySnapshot
|
|
||||||
for rows.Next() {
|
|
||||||
snap := &EquitySnapshot{}
|
|
||||||
var timestampStr string
|
|
||||||
err := rows.Scan(
|
|
||||||
&snap.ID, &snap.TraderID, ×tampStr, &snap.TotalEquity,
|
|
||||||
&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
|
|
||||||
snapshots = append(snapshots, snap)
|
|
||||||
}
|
|
||||||
|
|
||||||
return snapshots, nil
|
return snapshots, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllTradersLatest gets latest equity for all traders (for leaderboards)
|
// GetAllTradersLatest gets latest equity for all traders (for leaderboards)
|
||||||
func (s *EquityStore) GetAllTradersLatest() (map[string]*EquitySnapshot, error) {
|
func (s *EquityStore) GetAllTradersLatest() (map[string]*EquitySnapshot, error) {
|
||||||
rows, err := s.db.Query(`
|
// Use raw SQL for this complex query with subquery
|
||||||
|
var snapshots []*EquitySnapshot
|
||||||
|
err := s.db.Raw(`
|
||||||
SELECT e.id, e.trader_id, e.timestamp, e.total_equity, e.balance,
|
SELECT e.id, e.trader_id, e.timestamp, e.total_equity, e.balance,
|
||||||
e.unrealized_pnl, e.position_count, e.margin_used_pct
|
e.unrealized_pnl, e.position_count, e.margin_used_pct, e.created_at
|
||||||
FROM trader_equity_snapshots e
|
FROM trader_equity_snapshots e
|
||||||
INNER JOIN (
|
INNER JOIN (
|
||||||
SELECT trader_id, MAX(timestamp) as max_ts
|
SELECT trader_id, MAX(timestamp) as max_ts
|
||||||
FROM trader_equity_snapshots
|
FROM trader_equity_snapshots
|
||||||
GROUP BY trader_id
|
GROUP BY trader_id
|
||||||
) latest ON e.trader_id = latest.trader_id AND e.timestamp = latest.max_ts
|
) latest ON e.trader_id = latest.trader_id AND e.timestamp = latest.max_ts
|
||||||
`)
|
`).Scan(&snapshots).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to query latest equity: %w", err)
|
return nil, fmt.Errorf("failed to query latest equity: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
result := make(map[string]*EquitySnapshot)
|
result := make(map[string]*EquitySnapshot)
|
||||||
for rows.Next() {
|
for _, snap := range snapshots {
|
||||||
snap := &EquitySnapshot{}
|
|
||||||
var timestampStr string
|
|
||||||
err := rows.Scan(
|
|
||||||
&snap.ID, &snap.TraderID, ×tampStr, &snap.TotalEquity,
|
|
||||||
&snap.Balance, &snap.UnrealizedPnL, &snap.PositionCount, &snap.MarginUsedPct,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
snap.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
|
|
||||||
result[snap.TraderID] = snap
|
result[snap.TraderID] = snap
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanOldRecords cleans old records from N days ago
|
// CleanOldRecords cleans old records from N days ago
|
||||||
func (s *EquityStore) CleanOldRecords(traderID string, days int) (int64, error) {
|
func (s *EquityStore) CleanOldRecords(traderID string, days int) (int64, error) {
|
||||||
cutoffTime := time.Now().AddDate(0, 0, -days).Format(time.RFC3339)
|
cutoffTime := time.Now().AddDate(0, 0, -days)
|
||||||
|
|
||||||
result, err := s.db.Exec(`
|
result := s.db.Where("trader_id = ? AND timestamp < ?", traderID, cutoffTime).
|
||||||
DELETE FROM trader_equity_snapshots
|
Delete(&EquitySnapshot{})
|
||||||
WHERE trader_id = ? AND timestamp < ?
|
if result.Error != nil {
|
||||||
`, traderID, cutoffTime)
|
return 0, fmt.Errorf("failed to clean old records: %w", result.Error)
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("failed to clean old records: %w", err)
|
|
||||||
}
|
}
|
||||||
|
return result.RowsAffected, nil
|
||||||
return result.RowsAffected()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCount gets record count for specified trader
|
// GetCount gets record count for specified trader
|
||||||
func (s *EquityStore) GetCount(traderID string) (int, error) {
|
func (s *EquityStore) GetCount(traderID string) (int, error) {
|
||||||
var count int
|
var count int64
|
||||||
err := s.db.QueryRow(`
|
err := s.db.Model(&EquitySnapshot{}).Where("trader_id = ?", traderID).Count(&count).Error
|
||||||
SELECT COUNT(*) FROM trader_equity_snapshots WHERE trader_id = ?
|
return int(count), err
|
||||||
`, traderID).Scan(&count)
|
|
||||||
return count, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MigrateFromDecision migrates data from old decision_account_snapshots table
|
// MigrateFromDecision migrates data from old decision_account_snapshots table
|
||||||
func (s *EquityStore) MigrateFromDecision() (int64, error) {
|
func (s *EquityStore) MigrateFromDecision() (int64, error) {
|
||||||
// Check if migration is needed (whether new table is empty)
|
// Check if migration is needed (whether new table is empty)
|
||||||
var count int
|
var count int64
|
||||||
s.db.QueryRow(`SELECT COUNT(*) FROM trader_equity_snapshots`).Scan(&count)
|
s.db.Model(&EquitySnapshot{}).Count(&count)
|
||||||
if count > 0 {
|
if count > 0 {
|
||||||
return 0, nil // Already has data, skip migration
|
return 0, nil // Already has data, skip migration
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if old table exists
|
// Check if old table exists (SQLite specific check, but works for migration)
|
||||||
var tableName string
|
var tableName string
|
||||||
err := s.db.QueryRow(`
|
err := s.db.Raw(`
|
||||||
SELECT name FROM sqlite_master
|
SELECT name FROM sqlite_master
|
||||||
WHERE type='table' AND name='decision_account_snapshots'
|
WHERE type='table' AND name='decision_account_snapshots'
|
||||||
`).Scan(&tableName)
|
`).Scan(&tableName).Error
|
||||||
if err != nil {
|
if err != nil || tableName == "" {
|
||||||
return 0, nil // Old table doesn't exist, skip
|
return 0, nil // Old table doesn't exist, skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate data: join query from decision_records + decision_account_snapshots
|
// Migrate data: join query from decision_records + decision_account_snapshots
|
||||||
result, err := s.db.Exec(`
|
result := s.db.Exec(`
|
||||||
INSERT INTO trader_equity_snapshots (
|
INSERT INTO trader_equity_snapshots (
|
||||||
trader_id, timestamp, total_equity, balance,
|
trader_id, timestamp, total_equity, balance,
|
||||||
unrealized_pnl, position_count, margin_used_pct
|
unrealized_pnl, position_count, margin_used_pct
|
||||||
@@ -249,9 +171,9 @@ func (s *EquityStore) MigrateFromDecision() (int64, error) {
|
|||||||
JOIN decision_account_snapshots das ON dr.id = das.decision_id
|
JOIN decision_account_snapshots das ON dr.id = das.decision_id
|
||||||
ORDER BY dr.timestamp ASC
|
ORDER BY dr.timestamp ASC
|
||||||
`)
|
`)
|
||||||
if err != nil {
|
if result.Error != nil {
|
||||||
return 0, fmt.Errorf("failed to migrate data: %w", err)
|
return 0, fmt.Errorf("failed to migrate data: %w", result.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result.RowsAffected()
|
return result.RowsAffected, nil
|
||||||
}
|
}
|
||||||
|
|||||||
+153
-302
@@ -1,83 +1,68 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"nofx/crypto"
|
||||||
"nofx/logger"
|
"nofx/logger"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ExchangeStore exchange storage
|
// ExchangeStore exchange storage
|
||||||
type ExchangeStore struct {
|
type ExchangeStore struct {
|
||||||
db *sql.DB
|
db *gorm.DB
|
||||||
encryptFunc func(string) string
|
|
||||||
decryptFunc func(string) string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exchange exchange configuration
|
// Exchange exchange configuration
|
||||||
type Exchange struct {
|
type Exchange struct {
|
||||||
ID string `json:"id"` // UUID
|
ID string `gorm:"primaryKey" json:"id"`
|
||||||
ExchangeType string `json:"exchange_type"` // "binance", "bybit", "okx", "hyperliquid", "aster", "lighter"
|
ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"`
|
||||||
AccountName string `json:"account_name"` // User-defined account name
|
AccountName string `gorm:"column:account_name;not null;default:''" json:"account_name"`
|
||||||
UserID string `json:"user_id"`
|
UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"`
|
||||||
Name string `json:"name"` // Display name (auto-generated or user-defined)
|
Name string `gorm:"not null" json:"name"`
|
||||||
Type string `json:"type"` // "cex" or "dex"
|
Type string `gorm:"not null" json:"type"` // "cex" or "dex"
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `gorm:"default:false" json:"enabled"`
|
||||||
APIKey string `json:"apiKey"`
|
APIKey crypto.EncryptedString `gorm:"column:api_key;default:''" json:"apiKey"`
|
||||||
SecretKey string `json:"secretKey"`
|
SecretKey crypto.EncryptedString `gorm:"column:secret_key;default:''" json:"secretKey"`
|
||||||
Passphrase string `json:"passphrase"` // OKX-specific
|
Passphrase crypto.EncryptedString `gorm:"column:passphrase;default:''" json:"passphrase"`
|
||||||
Testnet bool `json:"testnet"`
|
Testnet bool `gorm:"default:false" json:"testnet"`
|
||||||
HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"`
|
HyperliquidWalletAddr string `gorm:"column:hyperliquid_wallet_addr;default:''" json:"hyperliquidWalletAddr"`
|
||||||
AsterUser string `json:"asterUser"`
|
AsterUser string `gorm:"column:aster_user;default:''" json:"asterUser"`
|
||||||
AsterSigner string `json:"asterSigner"`
|
AsterSigner string `gorm:"column:aster_signer;default:''" json:"asterSigner"`
|
||||||
AsterPrivateKey string `json:"asterPrivateKey"`
|
AsterPrivateKey crypto.EncryptedString `gorm:"column:aster_private_key;default:''" json:"asterPrivateKey"`
|
||||||
LighterWalletAddr string `json:"lighterWalletAddr"`
|
LighterWalletAddr string `gorm:"column:lighter_wallet_addr;default:''" json:"lighterWalletAddr"`
|
||||||
LighterPrivateKey string `json:"lighterPrivateKey"`
|
LighterPrivateKey crypto.EncryptedString `gorm:"column:lighter_private_key;default:''" json:"lighterPrivateKey"`
|
||||||
LighterAPIKeyPrivateKey string `json:"lighterAPIKeyPrivateKey"`
|
LighterAPIKeyPrivateKey crypto.EncryptedString `gorm:"column:lighter_api_key_private_key;default:''" json:"lighterAPIKeyPrivateKey"`
|
||||||
LighterAPIKeyIndex int `json:"lighterAPIKeyIndex"`
|
LighterAPIKeyIndex int `gorm:"column:lighter_api_key_index;default:0" json:"lighterAPIKeyIndex"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Exchange) TableName() string { return "exchanges" }
|
||||||
|
|
||||||
|
// NewExchangeStore creates a new ExchangeStore
|
||||||
|
func NewExchangeStore(db *gorm.DB) *ExchangeStore {
|
||||||
|
return &ExchangeStore{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ExchangeStore) initTables() error {
|
func (s *ExchangeStore) initTables() error {
|
||||||
// Create new table structure with UUID as primary key
|
// For PostgreSQL with existing table, skip AutoMigrate
|
||||||
_, err := s.db.Exec(`
|
if s.db.Dialector.Name() == "postgres" {
|
||||||
CREATE TABLE IF NOT EXISTS exchanges (
|
var tableExists int64
|
||||||
id TEXT PRIMARY KEY,
|
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'exchanges'`).Scan(&tableExists)
|
||||||
exchange_type TEXT NOT NULL DEFAULT '',
|
if tableExists > 0 {
|
||||||
account_name TEXT NOT NULL DEFAULT '',
|
// Still run data migrations
|
||||||
user_id TEXT NOT NULL DEFAULT 'default',
|
s.migrateToMultiAccount()
|
||||||
name TEXT NOT NULL,
|
s.db.Model(&Exchange{}).Where("account_name = '' OR account_name IS NULL").Update("account_name", "Default")
|
||||||
type TEXT NOT NULL,
|
return nil
|
||||||
enabled BOOLEAN DEFAULT 0,
|
}
|
||||||
api_key TEXT DEFAULT '',
|
|
||||||
secret_key TEXT DEFAULT '',
|
|
||||||
passphrase TEXT DEFAULT '',
|
|
||||||
testnet BOOLEAN DEFAULT 0,
|
|
||||||
hyperliquid_wallet_addr TEXT DEFAULT '',
|
|
||||||
aster_user TEXT DEFAULT '',
|
|
||||||
aster_signer TEXT DEFAULT '',
|
|
||||||
aster_private_key TEXT DEFAULT '',
|
|
||||||
lighter_wallet_addr TEXT DEFAULT '',
|
|
||||||
lighter_private_key TEXT DEFAULT '',
|
|
||||||
lighter_api_key_private_key TEXT DEFAULT '',
|
|
||||||
lighter_api_key_index INTEGER DEFAULT 0,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migration: add new columns if not exists
|
if err := s.db.AutoMigrate(&Exchange{}); err != nil {
|
||||||
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN passphrase TEXT DEFAULT ''`)
|
return err
|
||||||
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN exchange_type TEXT NOT NULL DEFAULT ''`)
|
}
|
||||||
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN account_name TEXT NOT NULL DEFAULT ''`)
|
|
||||||
s.db.Exec(`ALTER TABLE exchanges ADD COLUMN lighter_api_key_index INTEGER DEFAULT 0`)
|
|
||||||
|
|
||||||
// Run migration to multi-account if needed
|
// Run migration to multi-account if needed
|
||||||
if err := s.migrateToMultiAccount(); err != nil {
|
if err := s.migrateToMultiAccount(); err != nil {
|
||||||
@@ -85,120 +70,65 @@ func (s *ExchangeStore) initTables() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fix empty account_name for existing records
|
// Fix empty account_name for existing records
|
||||||
s.db.Exec(`UPDATE exchanges SET account_name = 'Default' WHERE account_name = '' OR account_name IS NULL`)
|
s.db.Model(&Exchange{}).Where("account_name = '' OR account_name IS NULL").Update("account_name", "Default")
|
||||||
|
|
||||||
// Update trigger for new schema
|
return nil
|
||||||
s.db.Exec(`DROP TRIGGER IF EXISTS update_exchanges_updated_at`)
|
|
||||||
_, err = s.db.Exec(`
|
|
||||||
CREATE TRIGGER IF NOT EXISTS update_exchanges_updated_at
|
|
||||||
AFTER UPDATE ON exchanges
|
|
||||||
BEGIN
|
|
||||||
UPDATE exchanges SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
|
||||||
END
|
|
||||||
`)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// migrateToMultiAccount migrates old schema (id=exchange_type) to new schema (id=UUID)
|
// migrateToMultiAccount migrates old schema (id=exchange_type) to new schema (id=UUID)
|
||||||
func (s *ExchangeStore) migrateToMultiAccount() error {
|
func (s *ExchangeStore) migrateToMultiAccount() error {
|
||||||
// Check if migration is needed by looking for old-style IDs (non-UUID)
|
// Check if migration is needed by looking for old-style IDs (non-UUID)
|
||||||
var count int
|
var count int64
|
||||||
err := s.db.QueryRow(`
|
err := s.db.Model(&Exchange{}).
|
||||||
SELECT COUNT(*) FROM exchanges
|
Where("exchange_type = '' AND id IN ?", []string{"binance", "bybit", "okx", "bitget", "hyperliquid", "aster", "lighter"}).
|
||||||
WHERE exchange_type = '' AND id IN ('binance', 'bybit', 'okx', 'bitget', 'hyperliquid', 'aster', 'lighter')
|
Count(&count).Error
|
||||||
`).Scan(&count)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
// No migration needed
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Infof("🔄 Migrating %d exchange records to multi-account schema...", count)
|
logger.Infof("🔄 Migrating %d exchange records to multi-account schema...", count)
|
||||||
|
|
||||||
// Get all old records
|
// Get all old records
|
||||||
rows, err := s.db.Query(`
|
var records []Exchange
|
||||||
SELECT id, user_id, name, type, enabled, api_key, secret_key,
|
err = s.db.Where("exchange_type = '' AND id IN ?", []string{"binance", "bybit", "okx", "bitget", "hyperliquid", "aster", "lighter"}).
|
||||||
COALESCE(passphrase, '') as passphrase, testnet,
|
Find(&records).Error
|
||||||
COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
|
|
||||||
COALESCE(aster_user, '') as aster_user,
|
|
||||||
COALESCE(aster_signer, '') as aster_signer,
|
|
||||||
COALESCE(aster_private_key, '') as aster_private_key,
|
|
||||||
COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr,
|
|
||||||
COALESCE(lighter_private_key, '') as lighter_private_key,
|
|
||||||
COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key
|
|
||||||
FROM exchanges
|
|
||||||
WHERE exchange_type = '' AND id IN ('binance', 'bybit', 'okx', 'bitget', 'hyperliquid', 'aster', 'lighter')
|
|
||||||
`)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
type oldRecord struct {
|
|
||||||
id, userID, name, typ string
|
|
||||||
enabled, testnet bool
|
|
||||||
apiKey, secretKey, passphrase string
|
|
||||||
hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string
|
|
||||||
lighterWalletAddr, lighterPrivateKey, lighterApiKeyPrivateKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
var records []oldRecord
|
|
||||||
for rows.Next() {
|
|
||||||
var r oldRecord
|
|
||||||
if err := rows.Scan(&r.id, &r.userID, &r.name, &r.typ, &r.enabled,
|
|
||||||
&r.apiKey, &r.secretKey, &r.passphrase, &r.testnet,
|
|
||||||
&r.hyperliquidWalletAddr, &r.asterUser, &r.asterSigner, &r.asterPrivateKey,
|
|
||||||
&r.lighterWalletAddr, &r.lighterPrivateKey, &r.lighterApiKeyPrivateKey); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
records = append(records, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Begin transaction
|
// Begin transaction
|
||||||
tx, err := s.db.Begin()
|
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
if err != nil {
|
for _, r := range records {
|
||||||
return err
|
newID := uuid.New().String()
|
||||||
}
|
oldID := r.ID // This is the exchange type (e.g., "binance")
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
// Migrate each record
|
// Update traders table to use new UUID
|
||||||
for _, r := range records {
|
if err := tx.Exec("UPDATE traders SET exchange_id = ? WHERE exchange_id = ? AND user_id = ?",
|
||||||
newID := uuid.New().String()
|
newID, oldID, r.UserID).Error; err != nil {
|
||||||
oldID := r.id // This is the exchange type (e.g., "binance")
|
logger.Errorf("Failed to update traders for exchange %s: %v", oldID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Update traders table to use new UUID
|
// Update the exchange record
|
||||||
_, err = tx.Exec(`UPDATE traders SET exchange_id = ? WHERE exchange_id = ? AND user_id = ?`,
|
if err := tx.Model(&Exchange{}).
|
||||||
newID, oldID, r.userID)
|
Where("id = ? AND user_id = ?", oldID, r.UserID).
|
||||||
if err != nil {
|
Updates(map[string]interface{}{
|
||||||
logger.Errorf("Failed to update traders for exchange %s: %v", oldID, err)
|
"id": newID,
|
||||||
return err
|
"exchange_type": oldID,
|
||||||
|
"account_name": "Default",
|
||||||
|
}).Error; err != nil {
|
||||||
|
logger.Errorf("Failed to migrate exchange %s: %v", oldID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("✅ Migrated exchange %s -> UUID %s for user %s", oldID, newID, r.UserID)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
// Update the exchange record
|
})
|
||||||
_, err = tx.Exec(`
|
|
||||||
UPDATE exchanges SET
|
|
||||||
id = ?,
|
|
||||||
exchange_type = ?,
|
|
||||||
account_name = ?
|
|
||||||
WHERE id = ? AND user_id = ?
|
|
||||||
`, newID, oldID, "Default", oldID, r.userID)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Failed to migrate exchange %s: %v", oldID, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Infof("✅ Migrated exchange %s -> UUID %s for user %s", oldID, newID, r.userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Infof("✅ Multi-account migration completed successfully")
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ExchangeStore) initDefaultData() error {
|
func (s *ExchangeStore) initDefaultData() error {
|
||||||
@@ -206,108 +136,24 @@ func (s *ExchangeStore) initDefaultData() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ExchangeStore) encrypt(plaintext string) string {
|
|
||||||
if s.encryptFunc != nil {
|
|
||||||
return s.encryptFunc(plaintext)
|
|
||||||
}
|
|
||||||
return plaintext
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ExchangeStore) decrypt(encrypted string) string {
|
|
||||||
if s.decryptFunc != nil {
|
|
||||||
return s.decryptFunc(encrypted)
|
|
||||||
}
|
|
||||||
return encrypted
|
|
||||||
}
|
|
||||||
|
|
||||||
// List gets user's exchange list
|
// List gets user's exchange list
|
||||||
func (s *ExchangeStore) List(userID string) ([]*Exchange, error) {
|
func (s *ExchangeStore) List(userID string) ([]*Exchange, error) {
|
||||||
rows, err := s.db.Query(`
|
var exchanges []*Exchange
|
||||||
SELECT id, COALESCE(exchange_type, '') as exchange_type, COALESCE(account_name, '') as account_name,
|
err := s.db.Where("user_id = ?", userID).Order("exchange_type, account_name").Find(&exchanges).Error
|
||||||
user_id, name, type, enabled, api_key, secret_key,
|
|
||||||
COALESCE(passphrase, '') as passphrase, testnet,
|
|
||||||
COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
|
|
||||||
COALESCE(aster_user, '') as aster_user,
|
|
||||||
COALESCE(aster_signer, '') as aster_signer,
|
|
||||||
COALESCE(aster_private_key, '') as aster_private_key,
|
|
||||||
COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr,
|
|
||||||
COALESCE(lighter_private_key, '') as lighter_private_key,
|
|
||||||
COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key,
|
|
||||||
COALESCE(lighter_api_key_index, 0) as lighter_api_key_index,
|
|
||||||
created_at, updated_at
|
|
||||||
FROM exchanges WHERE user_id = ? ORDER BY exchange_type, account_name
|
|
||||||
`, userID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
exchanges := make([]*Exchange, 0)
|
|
||||||
for rows.Next() {
|
|
||||||
var e Exchange
|
|
||||||
var createdAt, updatedAt string
|
|
||||||
err := rows.Scan(
|
|
||||||
&e.ID, &e.ExchangeType, &e.AccountName,
|
|
||||||
&e.UserID, &e.Name, &e.Type,
|
|
||||||
&e.Enabled, &e.APIKey, &e.SecretKey, &e.Passphrase, &e.Testnet,
|
|
||||||
&e.HyperliquidWalletAddr, &e.AsterUser, &e.AsterSigner, &e.AsterPrivateKey,
|
|
||||||
&e.LighterWalletAddr, &e.LighterPrivateKey, &e.LighterAPIKeyPrivateKey, &e.LighterAPIKeyIndex,
|
|
||||||
&createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
e.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
e.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
e.APIKey = s.decrypt(e.APIKey)
|
|
||||||
e.SecretKey = s.decrypt(e.SecretKey)
|
|
||||||
e.Passphrase = s.decrypt(e.Passphrase)
|
|
||||||
e.AsterPrivateKey = s.decrypt(e.AsterPrivateKey)
|
|
||||||
e.LighterPrivateKey = s.decrypt(e.LighterPrivateKey)
|
|
||||||
e.LighterAPIKeyPrivateKey = s.decrypt(e.LighterAPIKeyPrivateKey)
|
|
||||||
exchanges = append(exchanges, &e)
|
|
||||||
}
|
|
||||||
return exchanges, nil
|
return exchanges, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByID gets a specific exchange by UUID
|
// GetByID gets a specific exchange by UUID
|
||||||
func (s *ExchangeStore) GetByID(userID, id string) (*Exchange, error) {
|
func (s *ExchangeStore) GetByID(userID, id string) (*Exchange, error) {
|
||||||
var e Exchange
|
var exchange Exchange
|
||||||
var createdAt, updatedAt string
|
err := s.db.Where("id = ? AND user_id = ?", id, userID).First(&exchange).Error
|
||||||
err := s.db.QueryRow(`
|
|
||||||
SELECT id, COALESCE(exchange_type, '') as exchange_type, COALESCE(account_name, '') as account_name,
|
|
||||||
user_id, name, type, enabled, api_key, secret_key,
|
|
||||||
COALESCE(passphrase, '') as passphrase, testnet,
|
|
||||||
COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
|
|
||||||
COALESCE(aster_user, '') as aster_user,
|
|
||||||
COALESCE(aster_signer, '') as aster_signer,
|
|
||||||
COALESCE(aster_private_key, '') as aster_private_key,
|
|
||||||
COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr,
|
|
||||||
COALESCE(lighter_private_key, '') as lighter_private_key,
|
|
||||||
COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key,
|
|
||||||
COALESCE(lighter_api_key_index, 0) as lighter_api_key_index,
|
|
||||||
created_at, updated_at
|
|
||||||
FROM exchanges WHERE id = ? AND user_id = ?
|
|
||||||
`, id, userID).Scan(
|
|
||||||
&e.ID, &e.ExchangeType, &e.AccountName,
|
|
||||||
&e.UserID, &e.Name, &e.Type,
|
|
||||||
&e.Enabled, &e.APIKey, &e.SecretKey, &e.Passphrase, &e.Testnet,
|
|
||||||
&e.HyperliquidWalletAddr, &e.AsterUser, &e.AsterSigner, &e.AsterPrivateKey,
|
|
||||||
&e.LighterWalletAddr, &e.LighterPrivateKey, &e.LighterAPIKeyPrivateKey, &e.LighterAPIKeyIndex,
|
|
||||||
&createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
e.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
return &exchange, nil
|
||||||
e.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
e.APIKey = s.decrypt(e.APIKey)
|
|
||||||
e.SecretKey = s.decrypt(e.SecretKey)
|
|
||||||
e.Passphrase = s.decrypt(e.Passphrase)
|
|
||||||
e.AsterPrivateKey = s.decrypt(e.AsterPrivateKey)
|
|
||||||
e.LighterPrivateKey = s.decrypt(e.LighterPrivateKey)
|
|
||||||
e.LighterAPIKeyPrivateKey = s.decrypt(e.LighterAPIKeyPrivateKey)
|
|
||||||
return &e, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getExchangeNameAndType returns the display name and type for an exchange type
|
// getExchangeNameAndType returns the display name and type for an exchange type
|
||||||
@@ -341,7 +187,6 @@ func (s *ExchangeStore) Create(userID, exchangeType, accountName string, enabled
|
|||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
name, typ := getExchangeNameAndType(exchangeType)
|
name, typ := getExchangeNameAndType(exchangeType)
|
||||||
|
|
||||||
// If account name is empty, use "Default"
|
|
||||||
if accountName == "" {
|
if accountName == "" {
|
||||||
accountName = "Default"
|
accountName = "Default"
|
||||||
}
|
}
|
||||||
@@ -349,19 +194,29 @@ func (s *ExchangeStore) Create(userID, exchangeType, accountName string, enabled
|
|||||||
logger.Debugf("🔧 ExchangeStore.Create: userID=%s, exchangeType=%s, accountName=%s, id=%s",
|
logger.Debugf("🔧 ExchangeStore.Create: userID=%s, exchangeType=%s, accountName=%s, id=%s",
|
||||||
userID, exchangeType, accountName, id)
|
userID, exchangeType, accountName, id)
|
||||||
|
|
||||||
_, err := s.db.Exec(`
|
exchange := &Exchange{
|
||||||
INSERT INTO exchanges (id, exchange_type, account_name, user_id, name, type, enabled,
|
ID: id,
|
||||||
api_key, secret_key, passphrase, testnet,
|
ExchangeType: exchangeType,
|
||||||
hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key,
|
AccountName: accountName,
|
||||||
lighter_wallet_addr, lighter_private_key, lighter_api_key_private_key, lighter_api_key_index,
|
UserID: userID,
|
||||||
created_at, updated_at)
|
Name: name,
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))
|
Type: typ,
|
||||||
`, id, exchangeType, accountName, userID, name, typ, enabled,
|
Enabled: enabled,
|
||||||
s.encrypt(apiKey), s.encrypt(secretKey), s.encrypt(passphrase), testnet,
|
APIKey: crypto.EncryptedString(apiKey),
|
||||||
hyperliquidWalletAddr, asterUser, asterSigner, s.encrypt(asterPrivateKey),
|
SecretKey: crypto.EncryptedString(secretKey),
|
||||||
lighterWalletAddr, s.encrypt(lighterPrivateKey), s.encrypt(lighterApiKeyPrivateKey), lighterApiKeyIndex)
|
Passphrase: crypto.EncryptedString(passphrase),
|
||||||
|
Testnet: testnet,
|
||||||
|
HyperliquidWalletAddr: hyperliquidWalletAddr,
|
||||||
|
AsterUser: asterUser,
|
||||||
|
AsterSigner: asterSigner,
|
||||||
|
AsterPrivateKey: crypto.EncryptedString(asterPrivateKey),
|
||||||
|
LighterWalletAddr: lighterWalletAddr,
|
||||||
|
LighterPrivateKey: crypto.EncryptedString(lighterPrivateKey),
|
||||||
|
LighterAPIKeyPrivateKey: crypto.EncryptedString(lighterApiKeyPrivateKey),
|
||||||
|
LighterAPIKeyIndex: lighterApiKeyIndex,
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err := s.db.Create(exchange).Error; err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return id, nil
|
return id, nil
|
||||||
@@ -373,53 +228,42 @@ func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKe
|
|||||||
|
|
||||||
logger.Debugf("🔧 ExchangeStore.Update: userID=%s, id=%s, enabled=%v", userID, id, enabled)
|
logger.Debugf("🔧 ExchangeStore.Update: userID=%s, id=%s, enabled=%v", userID, id, enabled)
|
||||||
|
|
||||||
setClauses := []string{
|
updates := map[string]interface{}{
|
||||||
"enabled = ?",
|
"enabled": enabled,
|
||||||
"testnet = ?",
|
"testnet": testnet,
|
||||||
"hyperliquid_wallet_addr = ?",
|
"hyperliquid_wallet_addr": hyperliquidWalletAddr,
|
||||||
"aster_user = ?",
|
"aster_user": asterUser,
|
||||||
"aster_signer = ?",
|
"aster_signer": asterSigner,
|
||||||
"lighter_wallet_addr = ?",
|
"lighter_wallet_addr": lighterWalletAddr,
|
||||||
"lighter_api_key_index = ?",
|
"lighter_api_key_index": lighterApiKeyIndex,
|
||||||
"updated_at = datetime('now')",
|
"updated_at": time.Now(),
|
||||||
}
|
}
|
||||||
args := []interface{}{enabled, testnet, hyperliquidWalletAddr, asterUser, asterSigner, lighterWalletAddr, lighterApiKeyIndex}
|
|
||||||
|
|
||||||
|
// Only update encrypted fields if not empty
|
||||||
if apiKey != "" {
|
if apiKey != "" {
|
||||||
setClauses = append(setClauses, "api_key = ?")
|
updates["api_key"] = crypto.EncryptedString(apiKey)
|
||||||
args = append(args, s.encrypt(apiKey))
|
|
||||||
}
|
}
|
||||||
if secretKey != "" {
|
if secretKey != "" {
|
||||||
setClauses = append(setClauses, "secret_key = ?")
|
updates["secret_key"] = crypto.EncryptedString(secretKey)
|
||||||
args = append(args, s.encrypt(secretKey))
|
|
||||||
}
|
}
|
||||||
if passphrase != "" {
|
if passphrase != "" {
|
||||||
setClauses = append(setClauses, "passphrase = ?")
|
updates["passphrase"] = crypto.EncryptedString(passphrase)
|
||||||
args = append(args, s.encrypt(passphrase))
|
|
||||||
}
|
}
|
||||||
if asterPrivateKey != "" {
|
if asterPrivateKey != "" {
|
||||||
setClauses = append(setClauses, "aster_private_key = ?")
|
updates["aster_private_key"] = crypto.EncryptedString(asterPrivateKey)
|
||||||
args = append(args, s.encrypt(asterPrivateKey))
|
|
||||||
}
|
}
|
||||||
if lighterPrivateKey != "" {
|
if lighterPrivateKey != "" {
|
||||||
setClauses = append(setClauses, "lighter_private_key = ?")
|
updates["lighter_private_key"] = crypto.EncryptedString(lighterPrivateKey)
|
||||||
args = append(args, s.encrypt(lighterPrivateKey))
|
|
||||||
}
|
}
|
||||||
if lighterApiKeyPrivateKey != "" {
|
if lighterApiKeyPrivateKey != "" {
|
||||||
setClauses = append(setClauses, "lighter_api_key_private_key = ?")
|
updates["lighter_api_key_private_key"] = crypto.EncryptedString(lighterApiKeyPrivateKey)
|
||||||
args = append(args, s.encrypt(lighterApiKeyPrivateKey))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
args = append(args, id, userID)
|
result := s.db.Model(&Exchange{}).Where("id = ? AND user_id = ?", id, userID).Updates(updates)
|
||||||
query := fmt.Sprintf(`UPDATE exchanges SET %s WHERE id = ? AND user_id = ?`, strings.Join(setClauses, ", "))
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
result, err := s.db.Exec(query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
rowsAffected, _ := result.RowsAffected()
|
|
||||||
if rowsAffected == 0 {
|
|
||||||
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
|
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -427,13 +271,16 @@ func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKe
|
|||||||
|
|
||||||
// UpdateAccountName updates the account name for an exchange
|
// UpdateAccountName updates the account name for an exchange
|
||||||
func (s *ExchangeStore) UpdateAccountName(userID, id, accountName string) error {
|
func (s *ExchangeStore) UpdateAccountName(userID, id, accountName string) error {
|
||||||
result, err := s.db.Exec(`UPDATE exchanges SET account_name = ?, updated_at = datetime('now') WHERE id = ? AND user_id = ?`,
|
result := s.db.Model(&Exchange{}).
|
||||||
accountName, id, userID)
|
Where("id = ? AND user_id = ?", id, userID).
|
||||||
if err != nil {
|
Updates(map[string]interface{}{
|
||||||
return err
|
"account_name": accountName,
|
||||||
|
"updated_at": time.Now(),
|
||||||
|
})
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
}
|
}
|
||||||
rowsAffected, _ := result.RowsAffected()
|
if result.RowsAffected == 0 {
|
||||||
if rowsAffected == 0 {
|
|
||||||
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
|
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -441,12 +288,11 @@ func (s *ExchangeStore) UpdateAccountName(userID, id, accountName string) error
|
|||||||
|
|
||||||
// Delete deletes an exchange account
|
// Delete deletes an exchange account
|
||||||
func (s *ExchangeStore) Delete(userID, id string) error {
|
func (s *ExchangeStore) Delete(userID, id string) error {
|
||||||
result, err := s.db.Exec(`DELETE FROM exchanges WHERE id = ? AND user_id = ?`, id, userID)
|
result := s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Exchange{})
|
||||||
if err != nil {
|
if result.Error != nil {
|
||||||
return err
|
return result.Error
|
||||||
}
|
}
|
||||||
rowsAffected, _ := result.RowsAffected()
|
if result.RowsAffected == 0 {
|
||||||
if rowsAffected == 0 {
|
|
||||||
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
|
return fmt.Errorf("exchange not found: id=%s, userID=%s", id, userID)
|
||||||
}
|
}
|
||||||
logger.Infof("🗑️ Deleted exchange: id=%s, userID=%s", id, userID)
|
logger.Infof("🗑️ Deleted exchange: id=%s, userID=%s", id, userID)
|
||||||
@@ -460,20 +306,25 @@ func (s *ExchangeStore) CreateLegacy(userID, id, name, typ string, enabled bool,
|
|||||||
|
|
||||||
// Check if this is an old-style ID (exchange type as ID)
|
// Check if this is an old-style ID (exchange type as ID)
|
||||||
if id == "binance" || id == "bybit" || id == "okx" || id == "bitget" || id == "hyperliquid" || id == "aster" || id == "lighter" {
|
if id == "binance" || id == "bybit" || id == "okx" || id == "bitget" || id == "hyperliquid" || id == "aster" || id == "lighter" {
|
||||||
// Use new Create method with exchange type
|
|
||||||
_, err := s.Create(userID, id, "Default", enabled, apiKey, secretKey, "", testnet,
|
_, err := s.Create(userID, id, "Default", enabled, apiKey, secretKey, "", testnet,
|
||||||
hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, "", "", "", 0)
|
hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, "", "", "", 0)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise assume it's already a UUID
|
// Otherwise assume it's already a UUID
|
||||||
_, err := s.db.Exec(`
|
exchange := &Exchange{
|
||||||
INSERT OR IGNORE INTO exchanges (id, exchange_type, account_name, user_id, name, type, enabled,
|
ID: id,
|
||||||
api_key, secret_key, testnet,
|
UserID: userID,
|
||||||
hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key,
|
Name: name,
|
||||||
lighter_wallet_addr, lighter_private_key)
|
Type: typ,
|
||||||
VALUES (?, '', '', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, '', '')
|
Enabled: enabled,
|
||||||
`, id, userID, name, typ, enabled, s.encrypt(apiKey), s.encrypt(secretKey), testnet,
|
APIKey: crypto.EncryptedString(apiKey),
|
||||||
hyperliquidWalletAddr, asterUser, asterSigner, s.encrypt(asterPrivateKey))
|
SecretKey: crypto.EncryptedString(secretKey),
|
||||||
return err
|
Testnet: testnet,
|
||||||
|
HyperliquidWalletAddr: hyperliquidWalletAddr,
|
||||||
|
AsterUser: asterUser,
|
||||||
|
AsterSigner: asterSigner,
|
||||||
|
AsterPrivateKey: crypto.EncryptedString(asterPrivateKey),
|
||||||
|
}
|
||||||
|
return s.db.Where("id = ?", id).FirstOrCreate(exchange).Error
|
||||||
}
|
}
|
||||||
|
|||||||
+146
@@ -0,0 +1,146 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gorm.io/driver/postgres"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GormDB is the global GORM database connection
|
||||||
|
var gormDB *gorm.DB
|
||||||
|
|
||||||
|
// DB returns the GORM database connection
|
||||||
|
func DB() *gorm.DB {
|
||||||
|
return gormDB
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitGorm initializes GORM with SQLite
|
||||||
|
func InitGorm(dbPath string) (*gorm.DB, error) {
|
||||||
|
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open SQLite database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set connection pool for SQLite
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sqlDB.SetMaxOpenConns(1)
|
||||||
|
sqlDB.SetMaxIdleConns(1)
|
||||||
|
|
||||||
|
// Enable foreign keys for SQLite
|
||||||
|
db.Exec("PRAGMA foreign_keys = ON")
|
||||||
|
db.Exec("PRAGMA journal_mode = DELETE")
|
||||||
|
db.Exec("PRAGMA synchronous = FULL")
|
||||||
|
db.Exec("PRAGMA busy_timeout = 5000")
|
||||||
|
|
||||||
|
gormDB = db
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitGormPostgres initializes GORM with PostgreSQL
|
||||||
|
func InitGormPostgres(host string, port int, user, password, dbname, sslmode string) (*gorm.DB, error) {
|
||||||
|
dsn := fmt.Sprintf(
|
||||||
|
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||||
|
host, port, user, password, dbname, sslmode,
|
||||||
|
)
|
||||||
|
|
||||||
|
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open PostgreSQL database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set connection pool for PostgreSQL
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sqlDB.SetMaxOpenConns(25)
|
||||||
|
sqlDB.SetMaxIdleConns(5)
|
||||||
|
|
||||||
|
gormDB = db
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitGormWithConfig initializes GORM with provided configuration
|
||||||
|
// Uses DBConfig from driver.go
|
||||||
|
func InitGormWithConfig(cfg DBConfig) (*gorm.DB, error) {
|
||||||
|
switch cfg.Type {
|
||||||
|
case DBTypeSQLite:
|
||||||
|
return InitGorm(cfg.Path)
|
||||||
|
|
||||||
|
case DBTypePostgres:
|
||||||
|
return InitGormPostgres(
|
||||||
|
cfg.Host,
|
||||||
|
cfg.Port,
|
||||||
|
cfg.User,
|
||||||
|
cfg.Password,
|
||||||
|
cfg.DBName,
|
||||||
|
cfg.SSLMode,
|
||||||
|
)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported DB_TYPE: %s (use 'sqlite' or 'postgres')", cfg.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Query Scopes - Reusable query helpers
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
// ForUser returns a scope that filters by user_id
|
||||||
|
func ForUser(userID string) func(*gorm.DB) *gorm.DB {
|
||||||
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Where("user_id = ?", userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForTrader returns a scope that filters by trader_id
|
||||||
|
func ForTrader(traderID string) func(*gorm.DB) *gorm.DB {
|
||||||
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Where("trader_id = ?", traderID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenPositions returns a scope for open positions
|
||||||
|
func OpenPositions() func(*gorm.DB) *gorm.DB {
|
||||||
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Where("status = ?", "OPEN")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClosedPositions returns a scope for closed positions
|
||||||
|
func ClosedPositions() func(*gorm.DB) *gorm.DB {
|
||||||
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Where("status = ?", "CLOSED")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrderByCreatedDesc returns a scope that orders by created_at DESC
|
||||||
|
func OrderByCreatedDesc() func(*gorm.DB) *gorm.DB {
|
||||||
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Order("created_at DESC")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrderByUpdatedDesc returns a scope that orders by updated_at DESC
|
||||||
|
func OrderByUpdatedDesc() func(*gorm.DB) *gorm.DB {
|
||||||
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Order("updated_at DESC")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Paginate returns a scope for pagination
|
||||||
|
func Paginate(limit, offset int) func(*gorm.DB) *gorm.DB {
|
||||||
|
return func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Limit(limit).Offset(offset)
|
||||||
|
}
|
||||||
|
}
|
||||||
+199
-448
@@ -1,495 +1,255 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TraderOrder 订单记录(完整的订单生命周期追踪)
|
// TraderOrder order record
|
||||||
type TraderOrder struct {
|
type TraderOrder struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||||
TraderID string `json:"trader_id"`
|
TraderID string `gorm:"column:trader_id;not null;index:idx_orders_trader_id" json:"trader_id"`
|
||||||
ExchangeID string `json:"exchange_id"` // Exchange account UUID
|
ExchangeID string `gorm:"column:exchange_id;not null;default:''" json:"exchange_id"`
|
||||||
ExchangeType string `json:"exchange_type"` // Exchange type (hyperliquid/lighter/binance/etc)
|
ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"`
|
||||||
ExchangeOrderID string `json:"exchange_order_id"` // Exchange order ID
|
ExchangeOrderID string `gorm:"column:exchange_order_id;not null;uniqueIndex:idx_orders_exchange_unique,priority:2" json:"exchange_order_id"`
|
||||||
ClientOrderID string `json:"client_order_id"` // Client order ID
|
ClientOrderID string `gorm:"column:client_order_id;default:''" json:"client_order_id"`
|
||||||
Symbol string `json:"symbol"` // Trading pair
|
Symbol string `gorm:"column:symbol;not null;index:idx_orders_symbol" json:"symbol"`
|
||||||
Side string `json:"side"` // BUY/SELL
|
Side string `gorm:"column:side;not null" json:"side"`
|
||||||
PositionSide string `json:"position_side"` // LONG/SHORT (hedge mode)
|
PositionSide string `gorm:"column:position_side;default:''" json:"position_side"`
|
||||||
Type string `json:"type"` // MARKET/LIMIT/STOP/STOP_MARKET/TAKE_PROFIT/TAKE_PROFIT_MARKET
|
Type string `gorm:"column:type;not null" json:"type"`
|
||||||
TimeInForce string `json:"time_in_force"` // GTC/IOC/FOK
|
TimeInForce string `gorm:"column:time_in_force;default:GTC" json:"time_in_force"`
|
||||||
Quantity float64 `json:"quantity"` // 订单数量
|
Quantity float64 `gorm:"column:quantity;not null" json:"quantity"`
|
||||||
Price float64 `json:"price"` // 限价单价格
|
Price float64 `gorm:"column:price;default:0" json:"price"`
|
||||||
StopPrice float64 `json:"stop_price"` // 止损/止盈触发价格
|
StopPrice float64 `gorm:"column:stop_price;default:0" json:"stop_price"`
|
||||||
Status string `json:"status"` // NEW/PARTIALLY_FILLED/FILLED/CANCELED/REJECTED/EXPIRED
|
Status string `gorm:"column:status;not null;default:NEW;index:idx_orders_status" json:"status"`
|
||||||
FilledQuantity float64 `json:"filled_quantity"` // 已成交数量
|
FilledQuantity float64 `gorm:"column:filled_quantity;default:0" json:"filled_quantity"`
|
||||||
AvgFillPrice float64 `json:"avg_fill_price"` // 平均成交价格
|
AvgFillPrice float64 `gorm:"column:avg_fill_price;default:0" json:"avg_fill_price"`
|
||||||
Commission float64 `json:"commission"` // 手续费总额
|
Commission float64 `gorm:"column:commission;default:0" json:"commission"`
|
||||||
CommissionAsset string `json:"commission_asset"` // 手续费资产(USDT等)
|
CommissionAsset string `gorm:"column:commission_asset;default:USDT" json:"commission_asset"`
|
||||||
Leverage int `json:"leverage"` // 杠杆倍数
|
Leverage int `gorm:"column:leverage;default:1" json:"leverage"`
|
||||||
ReduceOnly bool `json:"reduce_only"` // 是否只减仓
|
ReduceOnly bool `gorm:"column:reduce_only;default:false" json:"reduce_only"`
|
||||||
ClosePosition bool `json:"close_position"` // 是否平仓单
|
ClosePosition bool `gorm:"column:close_position;default:false" json:"close_position"`
|
||||||
WorkingType string `json:"working_type"` // CONTRACT_PRICE/MARK_PRICE
|
WorkingType string `gorm:"column:working_type;default:CONTRACT_PRICE" json:"working_type"`
|
||||||
PriceProtect bool `json:"price_protect"` // 价格保护
|
PriceProtect bool `gorm:"column:price_protect;default:false" json:"price_protect"`
|
||||||
OrderAction string `json:"order_action"` // OPEN_LONG/OPEN_SHORT/CLOSE_LONG/CLOSE_SHORT/ADD_LONG/ADD_SHORT/STOP_LOSS/TAKE_PROFIT
|
OrderAction string `gorm:"column:order_action;default:''" json:"order_action"`
|
||||||
RelatedPositionID int64 `json:"related_position_id"` // 关联的仓位ID
|
RelatedPositionID int64 `gorm:"column:related_position_id;default:0" json:"related_position_id"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
|
||||||
FilledAt time.Time `json:"filled_at"` // 完全成交时间
|
FilledAt time.Time `gorm:"column:filled_at" json:"filled_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TraderFill trade record (one order may have multiple fills)
|
// TableName returns the table name for TraderOrder
|
||||||
|
func (TraderOrder) TableName() string {
|
||||||
|
return "trader_orders"
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraderFill trade record
|
||||||
type TraderFill struct {
|
type TraderFill struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||||
TraderID string `json:"trader_id"`
|
TraderID string `gorm:"column:trader_id;not null;index:idx_fills_trader_id" json:"trader_id"`
|
||||||
ExchangeID string `json:"exchange_id"` // Exchange account UUID
|
ExchangeID string `gorm:"column:exchange_id;not null;default:''" json:"exchange_id"`
|
||||||
ExchangeType string `json:"exchange_type"` // Exchange type (hyperliquid/lighter/binance/etc)
|
ExchangeType string `gorm:"column:exchange_type;not null;default:''" json:"exchange_type"`
|
||||||
OrderID int64 `json:"order_id"` // Related order ID
|
OrderID int64 `gorm:"column:order_id;not null;index:idx_fills_order_id" json:"order_id"`
|
||||||
ExchangeOrderID string `json:"exchange_order_id"` // Exchange order ID
|
ExchangeOrderID string `gorm:"column:exchange_order_id;not null" json:"exchange_order_id"`
|
||||||
ExchangeTradeID string `json:"exchange_trade_id"` // Exchange trade ID
|
ExchangeTradeID string `gorm:"column:exchange_trade_id;not null;uniqueIndex:idx_fills_exchange_unique,priority:2" json:"exchange_trade_id"`
|
||||||
Symbol string `json:"symbol"`
|
Symbol string `gorm:"column:symbol;not null" json:"symbol"`
|
||||||
Side string `json:"side"` // BUY/SELL
|
Side string `gorm:"column:side;not null" json:"side"`
|
||||||
Price float64 `json:"price"` // 成交价格
|
Price float64 `gorm:"column:price;not null" json:"price"`
|
||||||
Quantity float64 `json:"quantity"` // 成交数量
|
Quantity float64 `gorm:"column:quantity;not null" json:"quantity"`
|
||||||
QuoteQuantity float64 `json:"quote_quantity"` // 成交金额(USDT)
|
QuoteQuantity float64 `gorm:"column:quote_quantity;not null" json:"quote_quantity"`
|
||||||
Commission float64 `json:"commission"` // 手续费
|
Commission float64 `gorm:"column:commission;not null" json:"commission"`
|
||||||
CommissionAsset string `json:"commission_asset"`
|
CommissionAsset string `gorm:"column:commission_asset;not null" json:"commission_asset"`
|
||||||
RealizedPnL float64 `json:"realized_pnl"` // 实现盈亏(平仓时)
|
RealizedPnL float64 `gorm:"column:realized_pnl;default:0" json:"realized_pnl"`
|
||||||
IsMaker bool `json:"is_maker"` // 是否为maker
|
IsMaker bool `gorm:"column:is_maker;default:false" json:"is_maker"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// OrderStore 订单存储
|
// TableName returns the table name for TraderFill
|
||||||
|
func (TraderFill) TableName() string {
|
||||||
|
return "trader_fills"
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrderStore order storage
|
||||||
type OrderStore struct {
|
type OrderStore struct {
|
||||||
db *sql.DB
|
db *gorm.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOrderStore 创建订单存储实例
|
// NewOrderStore creates order storage instance
|
||||||
func NewOrderStore(db *sql.DB) *OrderStore {
|
func NewOrderStore(db *gorm.DB) *OrderStore {
|
||||||
return &OrderStore{db: db}
|
return &OrderStore{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitTables 初始化订单表
|
// InitTables initializes order tables
|
||||||
func (s *OrderStore) InitTables() error {
|
func (s *OrderStore) InitTables() error {
|
||||||
// 创建订单表
|
// For PostgreSQL, check if tables exist to avoid AutoMigrate index conflicts
|
||||||
_, err := s.db.Exec(`
|
if s.db.Dialector.Name() == "postgres" {
|
||||||
CREATE TABLE IF NOT EXISTS trader_orders (
|
var ordersExist, fillsExist int64
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_orders'`).Scan(&ordersExist)
|
||||||
trader_id TEXT NOT NULL,
|
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'trader_fills'`).Scan(&fillsExist)
|
||||||
exchange_id TEXT NOT NULL DEFAULT '',
|
|
||||||
exchange_type TEXT NOT NULL DEFAULT '',
|
if ordersExist > 0 && fillsExist > 0 {
|
||||||
exchange_order_id TEXT NOT NULL,
|
// Tables exist - just ensure indexes exist, skip AutoMigrate
|
||||||
client_order_id TEXT DEFAULT '',
|
s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_orders_exchange_unique ON trader_orders(exchange_id, exchange_order_id)`)
|
||||||
symbol TEXT NOT NULL,
|
s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_fills_exchange_unique ON trader_fills(exchange_id, exchange_trade_id)`)
|
||||||
side TEXT NOT NULL,
|
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_trader_id ON trader_orders(trader_id)`)
|
||||||
position_side TEXT DEFAULT '',
|
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_symbol ON trader_orders(symbol)`)
|
||||||
type TEXT NOT NULL,
|
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_status ON trader_orders(status)`)
|
||||||
time_in_force TEXT DEFAULT 'GTC',
|
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_trader_id ON trader_fills(trader_id)`)
|
||||||
quantity REAL NOT NULL,
|
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_order_id ON trader_fills(order_id)`)
|
||||||
price REAL DEFAULT 0,
|
return nil
|
||||||
stop_price REAL DEFAULT 0,
|
}
|
||||||
status TEXT NOT NULL DEFAULT 'NEW',
|
|
||||||
filled_quantity REAL DEFAULT 0,
|
|
||||||
avg_fill_price REAL DEFAULT 0,
|
|
||||||
commission REAL DEFAULT 0,
|
|
||||||
commission_asset TEXT DEFAULT 'USDT',
|
|
||||||
leverage INTEGER DEFAULT 1,
|
|
||||||
reduce_only INTEGER DEFAULT 0,
|
|
||||||
close_position INTEGER DEFAULT 0,
|
|
||||||
working_type TEXT DEFAULT 'CONTRACT_PRICE',
|
|
||||||
price_protect INTEGER DEFAULT 0,
|
|
||||||
order_action TEXT DEFAULT '',
|
|
||||||
related_position_id INTEGER DEFAULT 0,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
filled_at DATETIME,
|
|
||||||
UNIQUE(exchange_id, exchange_order_id)
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create trader_orders table: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建成交记录表
|
if err := s.db.AutoMigrate(&TraderOrder{}, &TraderFill{}); err != nil {
|
||||||
_, err = s.db.Exec(`
|
return fmt.Errorf("failed to migrate order tables: %w", err)
|
||||||
CREATE TABLE IF NOT EXISTS trader_fills (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
trader_id TEXT NOT NULL,
|
|
||||||
exchange_id TEXT NOT NULL DEFAULT '',
|
|
||||||
exchange_type TEXT NOT NULL DEFAULT '',
|
|
||||||
order_id INTEGER NOT NULL,
|
|
||||||
exchange_order_id TEXT NOT NULL,
|
|
||||||
exchange_trade_id TEXT NOT NULL,
|
|
||||||
symbol TEXT NOT NULL,
|
|
||||||
side TEXT NOT NULL,
|
|
||||||
price REAL NOT NULL,
|
|
||||||
quantity REAL NOT NULL,
|
|
||||||
quote_quantity REAL NOT NULL,
|
|
||||||
commission REAL NOT NULL,
|
|
||||||
commission_asset TEXT NOT NULL,
|
|
||||||
realized_pnl REAL DEFAULT 0,
|
|
||||||
is_maker INTEGER DEFAULT 0,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
UNIQUE(exchange_id, exchange_trade_id),
|
|
||||||
FOREIGN KEY (order_id) REFERENCES trader_orders(id)
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create trader_fills table: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建索引
|
// Create unique composite index for exchange_id + exchange_order_id
|
||||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_trader_id ON trader_orders(trader_id)`)
|
s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_orders_exchange_unique ON trader_orders(exchange_id, exchange_order_id)`)
|
||||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_symbol ON trader_orders(symbol)`)
|
// Create unique composite index for exchange_id + exchange_trade_id
|
||||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_status ON trader_orders(status)`)
|
s.db.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_fills_exchange_unique ON trader_fills(exchange_id, exchange_trade_id)`)
|
||||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_orders_exchange_order_id ON trader_orders(exchange_id, exchange_order_id)`)
|
|
||||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_order_id ON trader_fills(order_id)`)
|
|
||||||
s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_fills_trader_id ON trader_fills(trader_id)`)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateOrder 创建订单记录(去重:如果订单已存在则返回已有记录)
|
// CreateOrder creates order record
|
||||||
func (s *OrderStore) CreateOrder(order *TraderOrder) error {
|
func (s *OrderStore) CreateOrder(order *TraderOrder) error {
|
||||||
// 1. 先检查订单是否已存在(去重)
|
// Check if order already exists
|
||||||
existing, err := s.GetOrderByExchangeID(order.ExchangeID, order.ExchangeOrderID)
|
existing, err := s.GetOrderByExchangeID(order.ExchangeID, order.ExchangeOrderID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to check existing order: %w", err)
|
return fmt.Errorf("failed to check existing order: %w", err)
|
||||||
}
|
}
|
||||||
if existing != nil {
|
if existing != nil {
|
||||||
// 订单已存在,返回已有记录的ID
|
|
||||||
order.ID = existing.ID
|
order.ID = existing.ID
|
||||||
order.CreatedAt = existing.CreatedAt
|
order.CreatedAt = existing.CreatedAt
|
||||||
order.UpdatedAt = existing.UpdatedAt
|
order.UpdatedAt = existing.UpdatedAt
|
||||||
return nil // 不是错误,只是跳过插入
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 订单不存在,插入新记录
|
return s.db.Create(order).Error
|
||||||
now := time.Now()
|
|
||||||
order.CreatedAt = now
|
|
||||||
order.UpdatedAt = now
|
|
||||||
|
|
||||||
result, err := s.db.Exec(`
|
|
||||||
INSERT INTO trader_orders (
|
|
||||||
trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id,
|
|
||||||
symbol, side, position_side, type, time_in_force,
|
|
||||||
quantity, price, stop_price, status,
|
|
||||||
filled_quantity, avg_fill_price, commission, commission_asset,
|
|
||||||
leverage, reduce_only, close_position, working_type, price_protect,
|
|
||||||
order_action, related_position_id,
|
|
||||||
created_at, updated_at, filled_at
|
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
`,
|
|
||||||
order.TraderID, order.ExchangeID, order.ExchangeType, order.ExchangeOrderID, order.ClientOrderID,
|
|
||||||
order.Symbol, order.Side, order.PositionSide, order.Type, order.TimeInForce,
|
|
||||||
order.Quantity, order.Price, order.StopPrice, order.Status,
|
|
||||||
order.FilledQuantity, order.AvgFillPrice, order.Commission, order.CommissionAsset,
|
|
||||||
order.Leverage, order.ReduceOnly, order.ClosePosition, order.WorkingType, order.PriceProtect,
|
|
||||||
order.OrderAction, order.RelatedPositionID,
|
|
||||||
formatTimeOrNow(order.CreatedAt, now), formatTimeOrNow(order.UpdatedAt, now),
|
|
||||||
formatTimePtr(order.FilledAt),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create order: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
id, _ := result.LastInsertId()
|
|
||||||
order.ID = id
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateOrderStatus 更新订单状态
|
// UpdateOrderStatus updates order status
|
||||||
func (s *OrderStore) UpdateOrderStatus(id int64, status string, filledQty, avgPrice, commission float64) error {
|
func (s *OrderStore) UpdateOrderStatus(id int64, status string, filledQty, avgPrice, commission float64) error {
|
||||||
now := time.Now()
|
updates := map[string]interface{}{
|
||||||
updateSQL := `
|
"status": status,
|
||||||
UPDATE trader_orders SET
|
"filled_quantity": filledQty,
|
||||||
status = ?,
|
"avg_fill_price": avgPrice,
|
||||||
filled_quantity = ?,
|
"commission": commission,
|
||||||
avg_fill_price = ?,
|
}
|
||||||
commission = ?,
|
|
||||||
updated_at = ?
|
|
||||||
`
|
|
||||||
args := []interface{}{status, filledQty, avgPrice, commission, now.Format(time.RFC3339)}
|
|
||||||
|
|
||||||
// 如果完全成交,记录成交时间
|
|
||||||
if status == "FILLED" {
|
if status == "FILLED" {
|
||||||
updateSQL += `, filled_at = ?`
|
updates["filled_at"] = time.Now()
|
||||||
args = append(args, now.Format(time.RFC3339))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
updateSQL += ` WHERE id = ?`
|
return s.db.Model(&TraderOrder{}).Where("id = ?", id).Updates(updates).Error
|
||||||
args = append(args, id)
|
|
||||||
|
|
||||||
_, err := s.db.Exec(updateSQL, args...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to update order status: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateFill 创建成交记录(去重:如果成交已存在则跳过)
|
// CreateFill creates fill record
|
||||||
func (s *OrderStore) CreateFill(fill *TraderFill) error {
|
func (s *OrderStore) CreateFill(fill *TraderFill) error {
|
||||||
// 1. 先检查成交是否已存在(去重)
|
// Check if fill already exists
|
||||||
existing, err := s.GetFillByExchangeTradeID(fill.ExchangeID, fill.ExchangeTradeID)
|
existing, err := s.GetFillByExchangeTradeID(fill.ExchangeID, fill.ExchangeTradeID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to check existing fill: %w", err)
|
return fmt.Errorf("failed to check existing fill: %w", err)
|
||||||
}
|
}
|
||||||
if existing != nil {
|
if existing != nil {
|
||||||
// 成交已存在,返回已有记录的ID
|
|
||||||
fill.ID = existing.ID
|
fill.ID = existing.ID
|
||||||
fill.CreatedAt = existing.CreatedAt
|
fill.CreatedAt = existing.CreatedAt
|
||||||
return nil // 不是错误,只是跳过插入
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 成交不存在,插入新记录
|
return s.db.Create(fill).Error
|
||||||
now := time.Now()
|
|
||||||
fill.CreatedAt = now
|
|
||||||
|
|
||||||
result, err := s.db.Exec(`
|
|
||||||
INSERT INTO trader_fills (
|
|
||||||
trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id,
|
|
||||||
symbol, side, price, quantity, quote_quantity,
|
|
||||||
commission, commission_asset, realized_pnl, is_maker,
|
|
||||||
created_at
|
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
`,
|
|
||||||
fill.TraderID, fill.ExchangeID, fill.ExchangeType, fill.OrderID, fill.ExchangeOrderID, fill.ExchangeTradeID,
|
|
||||||
fill.Symbol, fill.Side, fill.Price, fill.Quantity, fill.QuoteQuantity,
|
|
||||||
fill.Commission, fill.CommissionAsset, fill.RealizedPnL, fill.IsMaker,
|
|
||||||
now.Format(time.RFC3339),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create fill: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
id, _ := result.LastInsertId()
|
|
||||||
fill.ID = id
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFillByExchangeTradeID 根据交易所成交ID获取成交记录
|
// GetFillByExchangeTradeID gets fill by exchange trade ID
|
||||||
func (s *OrderStore) GetFillByExchangeTradeID(exchangeID, exchangeTradeID string) (*TraderFill, error) {
|
func (s *OrderStore) GetFillByExchangeTradeID(exchangeID, exchangeTradeID string) (*TraderFill, error) {
|
||||||
row := s.db.QueryRow(`
|
|
||||||
SELECT id, trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id,
|
|
||||||
symbol, side, price, quantity, quote_quantity,
|
|
||||||
commission, commission_asset, realized_pnl, is_maker,
|
|
||||||
created_at
|
|
||||||
FROM trader_fills
|
|
||||||
WHERE exchange_id = ? AND exchange_trade_id = ?
|
|
||||||
`, exchangeID, exchangeTradeID)
|
|
||||||
|
|
||||||
var fill TraderFill
|
var fill TraderFill
|
||||||
var createdAt sql.NullString
|
err := s.db.Where("exchange_id = ? AND exchange_trade_id = ?", exchangeID, exchangeTradeID).First(&fill).Error
|
||||||
err := row.Scan(
|
|
||||||
&fill.ID, &fill.TraderID, &fill.ExchangeID, &fill.ExchangeType, &fill.OrderID, &fill.ExchangeOrderID, &fill.ExchangeTradeID,
|
|
||||||
&fill.Symbol, &fill.Side, &fill.Price, &fill.Quantity, &fill.QuoteQuantity,
|
|
||||||
&fill.Commission, &fill.CommissionAsset, &fill.RealizedPnL, &fill.IsMaker,
|
|
||||||
&createdAt,
|
|
||||||
)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == gorm.ErrRecordNotFound {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("failed to get fill: %w", err)
|
return nil, fmt.Errorf("failed to get fill: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse time
|
|
||||||
if createdAt.Valid {
|
|
||||||
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
|
|
||||||
fill.CreatedAt = t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &fill, nil
|
return &fill, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOrderByExchangeID 根据交易所订单ID获取订单
|
// GetOrderByExchangeID gets order by exchange order ID
|
||||||
func (s *OrderStore) GetOrderByExchangeID(exchangeID, exchangeOrderID string) (*TraderOrder, error) {
|
func (s *OrderStore) GetOrderByExchangeID(exchangeID, exchangeOrderID string) (*TraderOrder, error) {
|
||||||
row := s.db.QueryRow(`
|
|
||||||
SELECT id, trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id,
|
|
||||||
symbol, side, position_side, type, time_in_force,
|
|
||||||
quantity, price, stop_price, status,
|
|
||||||
filled_quantity, avg_fill_price, commission, commission_asset,
|
|
||||||
leverage, reduce_only, close_position, working_type, price_protect,
|
|
||||||
order_action, related_position_id,
|
|
||||||
created_at, updated_at, filled_at
|
|
||||||
FROM trader_orders
|
|
||||||
WHERE exchange_id = ? AND exchange_order_id = ?
|
|
||||||
`, exchangeID, exchangeOrderID)
|
|
||||||
|
|
||||||
var order TraderOrder
|
var order TraderOrder
|
||||||
var createdAt, updatedAt, filledAt sql.NullString
|
err := s.db.Where("exchange_id = ? AND exchange_order_id = ?", exchangeID, exchangeOrderID).First(&order).Error
|
||||||
err := row.Scan(
|
|
||||||
&order.ID, &order.TraderID, &order.ExchangeID, &order.ExchangeType, &order.ExchangeOrderID, &order.ClientOrderID,
|
|
||||||
&order.Symbol, &order.Side, &order.PositionSide, &order.Type, &order.TimeInForce,
|
|
||||||
&order.Quantity, &order.Price, &order.StopPrice, &order.Status,
|
|
||||||
&order.FilledQuantity, &order.AvgFillPrice, &order.Commission, &order.CommissionAsset,
|
|
||||||
&order.Leverage, &order.ReduceOnly, &order.ClosePosition, &order.WorkingType, &order.PriceProtect,
|
|
||||||
&order.OrderAction, &order.RelatedPositionID,
|
|
||||||
&createdAt, &updatedAt, &filledAt,
|
|
||||||
)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == gorm.ErrRecordNotFound {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("failed to get order: %w", err)
|
return nil, fmt.Errorf("failed to get order: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse times
|
|
||||||
if createdAt.Valid {
|
|
||||||
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
|
|
||||||
order.CreatedAt = t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if updatedAt.Valid {
|
|
||||||
if t, err := time.Parse(time.RFC3339, updatedAt.String); err == nil {
|
|
||||||
order.UpdatedAt = t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if filledAt.Valid {
|
|
||||||
if t, err := time.Parse(time.RFC3339, filledAt.String); err == nil {
|
|
||||||
order.FilledAt = t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &order, nil
|
return &order, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTraderOrders 获取trader的订单列表
|
// GetTraderOrders gets trader's order list
|
||||||
func (s *OrderStore) GetTraderOrders(traderID string, limit int) ([]*TraderOrder, error) {
|
func (s *OrderStore) GetTraderOrders(traderID string, limit int) ([]*TraderOrder, error) {
|
||||||
rows, err := s.db.Query(`
|
var orders []*TraderOrder
|
||||||
SELECT id, trader_id, exchange_id, exchange_type, exchange_order_id, client_order_id,
|
err := s.db.Where("trader_id = ?", traderID).
|
||||||
symbol, side, position_side, type, time_in_force,
|
Order("created_at DESC").
|
||||||
quantity, price, stop_price, status,
|
Limit(limit).
|
||||||
filled_quantity, avg_fill_price, commission, commission_asset,
|
Find(&orders).Error
|
||||||
leverage, reduce_only, close_position, working_type, price_protect,
|
|
||||||
order_action, related_position_id,
|
|
||||||
created_at, updated_at, filled_at
|
|
||||||
FROM trader_orders
|
|
||||||
WHERE trader_id = ?
|
|
||||||
ORDER BY created_at DESC
|
|
||||||
LIMIT ?
|
|
||||||
`, traderID, limit)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to query orders: %w", err)
|
return nil, fmt.Errorf("failed to query orders: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var orders []*TraderOrder
|
|
||||||
for rows.Next() {
|
|
||||||
var order TraderOrder
|
|
||||||
var createdAt, updatedAt, filledAt sql.NullString
|
|
||||||
err := rows.Scan(
|
|
||||||
&order.ID, &order.TraderID, &order.ExchangeID, &order.ExchangeType, &order.ExchangeOrderID, &order.ClientOrderID,
|
|
||||||
&order.Symbol, &order.Side, &order.PositionSide, &order.Type, &order.TimeInForce,
|
|
||||||
&order.Quantity, &order.Price, &order.StopPrice, &order.Status,
|
|
||||||
&order.FilledQuantity, &order.AvgFillPrice, &order.Commission, &order.CommissionAsset,
|
|
||||||
&order.Leverage, &order.ReduceOnly, &order.ClosePosition, &order.WorkingType, &order.PriceProtect,
|
|
||||||
&order.OrderAction, &order.RelatedPositionID,
|
|
||||||
&createdAt, &updatedAt, &filledAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse times
|
|
||||||
if createdAt.Valid {
|
|
||||||
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
|
|
||||||
order.CreatedAt = t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if updatedAt.Valid {
|
|
||||||
if t, err := time.Parse(time.RFC3339, updatedAt.String); err == nil {
|
|
||||||
order.UpdatedAt = t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if filledAt.Valid {
|
|
||||||
if t, err := time.Parse(time.RFC3339, filledAt.String); err == nil {
|
|
||||||
order.FilledAt = t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
orders = append(orders, &order)
|
|
||||||
}
|
|
||||||
|
|
||||||
return orders, nil
|
return orders, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOrderFills 获取订单的成交记录
|
// GetOrderFills gets order's fill records
|
||||||
func (s *OrderStore) GetOrderFills(orderID int64) ([]*TraderFill, error) {
|
func (s *OrderStore) GetOrderFills(orderID int64) ([]*TraderFill, error) {
|
||||||
rows, err := s.db.Query(`
|
var fills []*TraderFill
|
||||||
SELECT id, trader_id, exchange_id, exchange_type, order_id, exchange_order_id, exchange_trade_id,
|
err := s.db.Where("order_id = ?", orderID).
|
||||||
symbol, side, price, quantity, quote_quantity,
|
Order("created_at ASC").
|
||||||
commission, commission_asset, realized_pnl, is_maker,
|
Find(&fills).Error
|
||||||
created_at
|
|
||||||
FROM trader_fills
|
|
||||||
WHERE order_id = ?
|
|
||||||
ORDER BY created_at ASC
|
|
||||||
`, orderID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to query fills: %w", err)
|
return nil, fmt.Errorf("failed to query fills: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var fills []*TraderFill
|
|
||||||
for rows.Next() {
|
|
||||||
var fill TraderFill
|
|
||||||
var createdAt sql.NullString
|
|
||||||
err := rows.Scan(
|
|
||||||
&fill.ID, &fill.TraderID, &fill.ExchangeID, &fill.ExchangeType, &fill.OrderID, &fill.ExchangeOrderID, &fill.ExchangeTradeID,
|
|
||||||
&fill.Symbol, &fill.Side, &fill.Price, &fill.Quantity, &fill.QuoteQuantity,
|
|
||||||
&fill.Commission, &fill.CommissionAsset, &fill.RealizedPnL, &fill.IsMaker,
|
|
||||||
&createdAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if createdAt.Valid {
|
|
||||||
if t, err := time.Parse(time.RFC3339, createdAt.String); err == nil {
|
|
||||||
fill.CreatedAt = t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fills = append(fills, &fill)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fills, nil
|
return fills, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTraderOrderStats 获取trader的订单统计
|
// GetTraderOrderStats gets trader's order statistics
|
||||||
func (s *OrderStore) GetTraderOrderStats(traderID string) (map[string]interface{}, error) {
|
func (s *OrderStore) GetTraderOrderStats(traderID string) (map[string]interface{}, error) {
|
||||||
var totalOrders, filledOrders, canceledOrders int
|
type result struct {
|
||||||
var totalCommission, totalVolume float64
|
TotalOrders int
|
||||||
|
FilledOrders int
|
||||||
err := s.db.QueryRow(`
|
CanceledOrders int
|
||||||
SELECT
|
TotalCommission float64
|
||||||
COUNT(*) as total_orders,
|
TotalVolume float64
|
||||||
SUM(CASE WHEN status = 'FILLED' THEN 1 ELSE 0 END) as filled_orders,
|
}
|
||||||
SUM(CASE WHEN status = 'CANCELED' THEN 1 ELSE 0 END) as canceled_orders,
|
var r result
|
||||||
SUM(commission) as total_commission,
|
|
||||||
SUM(filled_quantity * avg_fill_price) as total_volume
|
|
||||||
FROM trader_orders
|
|
||||||
WHERE trader_id = ?
|
|
||||||
`, traderID).Scan(&totalOrders, &filledOrders, &canceledOrders, &totalCommission, &totalVolume)
|
|
||||||
|
|
||||||
|
err := s.db.Model(&TraderOrder{}).
|
||||||
|
Select(`COUNT(*) as total_orders,
|
||||||
|
SUM(CASE WHEN status = 'FILLED' THEN 1 ELSE 0 END) as filled_orders,
|
||||||
|
SUM(CASE WHEN status = 'CANCELED' THEN 1 ELSE 0 END) as canceled_orders,
|
||||||
|
SUM(commission) as total_commission,
|
||||||
|
SUM(filled_quantity * avg_fill_price) as total_volume`).
|
||||||
|
Where("trader_id = ?", traderID).
|
||||||
|
Scan(&r).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get order stats: %w", err)
|
return nil, fmt.Errorf("failed to get order stats: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return map[string]interface{}{
|
return map[string]interface{}{
|
||||||
"total_orders": totalOrders,
|
"total_orders": r.TotalOrders,
|
||||||
"filled_orders": filledOrders,
|
"filled_orders": r.FilledOrders,
|
||||||
"canceled_orders": canceledOrders,
|
"canceled_orders": r.CanceledOrders,
|
||||||
"total_commission": totalCommission,
|
"total_commission": r.TotalCommission,
|
||||||
"total_volume": totalVolume,
|
"total_volume": r.TotalVolume,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanupDuplicateOrders 清理重复的订单记录(保留最早创建的记录)
|
// CleanupDuplicateOrders cleans up duplicate order records
|
||||||
func (s *OrderStore) CleanupDuplicateOrders() (int, error) {
|
func (s *OrderStore) CleanupDuplicateOrders() (int, error) {
|
||||||
result, err := s.db.Exec(`
|
result := s.db.Exec(`
|
||||||
DELETE FROM trader_orders
|
DELETE FROM trader_orders
|
||||||
WHERE id NOT IN (
|
WHERE id NOT IN (
|
||||||
SELECT MIN(id)
|
SELECT MIN(id)
|
||||||
@@ -497,17 +257,15 @@ func (s *OrderStore) CleanupDuplicateOrders() (int, error) {
|
|||||||
GROUP BY exchange_id, exchange_order_id
|
GROUP BY exchange_id, exchange_order_id
|
||||||
)
|
)
|
||||||
`)
|
`)
|
||||||
if err != nil {
|
if result.Error != nil {
|
||||||
return 0, fmt.Errorf("failed to cleanup duplicate orders: %w", err)
|
return 0, fmt.Errorf("failed to cleanup duplicate orders: %w", result.Error)
|
||||||
}
|
}
|
||||||
|
return int(result.RowsAffected), nil
|
||||||
rowsAffected, _ := result.RowsAffected()
|
|
||||||
return int(rowsAffected), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanupDuplicateFills 清理重复的成交记录(保留最早创建的记录)
|
// CleanupDuplicateFills cleans up duplicate fill records
|
||||||
func (s *OrderStore) CleanupDuplicateFills() (int, error) {
|
func (s *OrderStore) CleanupDuplicateFills() (int, error) {
|
||||||
result, err := s.db.Exec(`
|
result := s.db.Exec(`
|
||||||
DELETE FROM trader_fills
|
DELETE FROM trader_fills
|
||||||
WHERE id NOT IN (
|
WHERE id NOT IN (
|
||||||
SELECT MIN(id)
|
SELECT MIN(id)
|
||||||
@@ -515,73 +273,66 @@ func (s *OrderStore) CleanupDuplicateFills() (int, error) {
|
|||||||
GROUP BY exchange_id, exchange_trade_id
|
GROUP BY exchange_id, exchange_trade_id
|
||||||
)
|
)
|
||||||
`)
|
`)
|
||||||
if err != nil {
|
if result.Error != nil {
|
||||||
return 0, fmt.Errorf("failed to cleanup duplicate fills: %w", err)
|
return 0, fmt.Errorf("failed to cleanup duplicate fills: %w", result.Error)
|
||||||
}
|
}
|
||||||
|
return int(result.RowsAffected), nil
|
||||||
rowsAffected, _ := result.RowsAffected()
|
|
||||||
return int(rowsAffected), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDuplicateOrdersCount 获取重复订单的数量(用于诊断)
|
// GetDuplicateOrdersCount gets duplicate orders count
|
||||||
func (s *OrderStore) GetDuplicateOrdersCount() (int, error) {
|
func (s *OrderStore) GetDuplicateOrdersCount() (int, error) {
|
||||||
var count int
|
var total, distinct int64
|
||||||
err := s.db.QueryRow(`
|
s.db.Model(&TraderOrder{}).Count(&total)
|
||||||
SELECT COUNT(*) - COUNT(DISTINCT exchange_id || ',' || exchange_order_id)
|
|
||||||
FROM trader_orders
|
// Count distinct combinations
|
||||||
`).Scan(&count)
|
var distinctResult struct{ Count int64 }
|
||||||
return count, err
|
s.db.Model(&TraderOrder{}).
|
||||||
|
Select("COUNT(DISTINCT exchange_id || ',' || exchange_order_id) as count").
|
||||||
|
Scan(&distinctResult)
|
||||||
|
distinct = distinctResult.Count
|
||||||
|
|
||||||
|
return int(total - distinct), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDuplicateFillsCount 获取重复成交的数量(用于诊断)
|
// GetDuplicateFillsCount gets duplicate fills count
|
||||||
func (s *OrderStore) GetDuplicateFillsCount() (int, error) {
|
func (s *OrderStore) GetDuplicateFillsCount() (int, error) {
|
||||||
var count int
|
var total, distinct int64
|
||||||
err := s.db.QueryRow(`
|
s.db.Model(&TraderFill{}).Count(&total)
|
||||||
SELECT COUNT(*) - COUNT(DISTINCT exchange_id || ',' || exchange_trade_id)
|
|
||||||
FROM trader_fills
|
var distinctResult struct{ Count int64 }
|
||||||
`).Scan(&count)
|
s.db.Model(&TraderFill{}).
|
||||||
return count, err
|
Select("COUNT(DISTINCT exchange_id || ',' || exchange_trade_id) as count").
|
||||||
|
Scan(&distinctResult)
|
||||||
|
distinct = distinctResult.Count
|
||||||
|
|
||||||
|
return int(total - distinct), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMaxTradeIDsByExchange returns max trade ID for each symbol for a given exchange
|
// GetMaxTradeIDsByExchange returns max trade ID for each symbol for a given exchange
|
||||||
// Used for incremental sync - only fetch trades with ID > maxTradeID
|
|
||||||
func (s *OrderStore) GetMaxTradeIDsByExchange(exchangeID string) (map[string]int64, error) {
|
func (s *OrderStore) GetMaxTradeIDsByExchange(exchangeID string) (map[string]int64, error) {
|
||||||
rows, err := s.db.Query(`
|
type symbolMaxID struct {
|
||||||
SELECT symbol, MAX(CAST(exchange_trade_id AS INTEGER)) as max_trade_id
|
Symbol string
|
||||||
FROM trader_fills
|
MaxTradeID int64
|
||||||
WHERE exchange_id = ? AND exchange_trade_id != ''
|
}
|
||||||
GROUP BY symbol
|
var results []symbolMaxID
|
||||||
`, exchangeID)
|
|
||||||
|
err := s.db.Model(&TraderFill{}).
|
||||||
|
Select("symbol, MAX(CAST(exchange_trade_id AS INTEGER)) as max_trade_id").
|
||||||
|
Where("exchange_id = ? AND exchange_trade_id != ''", exchangeID).
|
||||||
|
Group("symbol").
|
||||||
|
Find(&results).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// If CAST fails (non-numeric trade IDs), fallback to string comparison
|
||||||
|
if strings.Contains(err.Error(), "CAST") || strings.Contains(err.Error(), "invalid") {
|
||||||
|
return make(map[string]int64), nil
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("failed to query max trade IDs: %w", err)
|
return nil, fmt.Errorf("failed to query max trade IDs: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
result := make(map[string]int64)
|
result := make(map[string]int64)
|
||||||
for rows.Next() {
|
for _, r := range results {
|
||||||
var symbol string
|
result[r.Symbol] = r.MaxTradeID
|
||||||
var maxID int64
|
|
||||||
if err := rows.Scan(&symbol, &maxID); err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
result[symbol] = maxID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// formatTimePtr formats time.Time to RFC3339 string, returns NULL for zero time
|
|
||||||
func formatTimePtr(t time.Time) interface{} {
|
|
||||||
if t.IsZero() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return t.Format(time.RFC3339)
|
|
||||||
}
|
|
||||||
|
|
||||||
// formatTimeOrNow returns the formatted time if not zero, otherwise returns now
|
|
||||||
func formatTimeOrNow(t time.Time, now time.Time) string {
|
|
||||||
if t.IsZero() {
|
|
||||||
return now.Format(time.RFC3339)
|
|
||||||
}
|
|
||||||
return t.Format(time.RFC3339)
|
|
||||||
}
|
|
||||||
|
|||||||
+534
-826
File diff suppressed because it is too large
Load Diff
+105
-81
@@ -7,12 +7,15 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"nofx/logger"
|
"nofx/logger"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Store unified data storage interface
|
// Store unified data storage interface
|
||||||
type Store struct {
|
type Store struct {
|
||||||
db *sql.DB
|
gdb *gorm.DB // GORM database connection
|
||||||
driver *DBDriver // Database driver for abstraction
|
db *sql.DB // Legacy sql.DB for backward compatibility
|
||||||
|
driver *DBDriver // Database driver for abstraction (legacy)
|
||||||
|
|
||||||
// Sub-stores (lazy initialization)
|
// Sub-stores (lazy initialization)
|
||||||
user *UserStore
|
user *UserStore
|
||||||
@@ -26,105 +29,103 @@ type Store struct {
|
|||||||
equity *EquityStore
|
equity *EquityStore
|
||||||
order *OrderStore
|
order *OrderStore
|
||||||
|
|
||||||
// Encryption functions
|
|
||||||
encryptFunc func(string) string
|
|
||||||
decryptFunc func(string) string
|
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates new Store instance (SQLite mode for backward compatibility)
|
// New creates new Store instance (SQLite mode for backward compatibility)
|
||||||
func New(dbPath string) (*Store, error) {
|
func New(dbPath string) (*Store, error) {
|
||||||
driver, err := NewDBDriver(DBConfig{Type: DBTypeSQLite, Path: dbPath})
|
gdb, err := InitGorm(dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := &Store{db: driver.DB(), driver: driver}
|
// Get underlying sql.DB for legacy compatibility
|
||||||
|
sqlDB, err := gdb.DB()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Store{gdb: gdb, db: sqlDB}
|
||||||
|
|
||||||
// Initialize all table structures
|
// Initialize all table structures
|
||||||
if err := s.initTables(); err != nil {
|
if err := s.initTables(); err != nil {
|
||||||
driver.Close()
|
sqlDB.Close()
|
||||||
return nil, fmt.Errorf("failed to initialize table structure: %w", err)
|
return nil, fmt.Errorf("failed to initialize table structure: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize default data
|
// Initialize default data
|
||||||
if err := s.initDefaultData(); err != nil {
|
if err := s.initDefaultData(); err != nil {
|
||||||
driver.Close()
|
sqlDB.Close()
|
||||||
return nil, fmt.Errorf("failed to initialize default data: %w", err)
|
return nil, fmt.Errorf("failed to initialize default data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Infof("✅ Database initialized (type: %s)", driver.Type)
|
logger.Infof("✅ Database initialized (GORM, SQLite)")
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFromEnv creates new Store instance from environment variables
|
// NewWithConfig creates new Store instance with provided database configuration
|
||||||
// DB_TYPE: sqlite (default) or postgres
|
func NewWithConfig(cfg DBConfig) (*Store, error) {
|
||||||
// For SQLite: DB_PATH (default: data/data.db)
|
gdb, err := InitGormWithConfig(cfg)
|
||||||
// For PostgreSQL: DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME, DB_SSLMODE
|
|
||||||
func NewFromEnv() (*Store, error) {
|
|
||||||
driver, err := NewDBDriverFromEnv()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := &Store{db: driver.DB(), driver: driver}
|
// Get underlying sql.DB for legacy compatibility
|
||||||
|
sqlDB, err := gdb.DB()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Store{gdb: gdb, db: sqlDB}
|
||||||
|
|
||||||
// Initialize all table structures
|
// Initialize all table structures
|
||||||
if err := s.initTables(); err != nil {
|
if err := s.initTables(); err != nil {
|
||||||
driver.Close()
|
sqlDB.Close()
|
||||||
return nil, fmt.Errorf("failed to initialize table structure: %w", err)
|
return nil, fmt.Errorf("failed to initialize table structure: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize default data
|
// Initialize default data
|
||||||
if err := s.initDefaultData(); err != nil {
|
if err := s.initDefaultData(); err != nil {
|
||||||
driver.Close()
|
sqlDB.Close()
|
||||||
return nil, fmt.Errorf("failed to initialize default data: %w", err)
|
return nil, fmt.Errorf("failed to initialize default data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Infof("✅ Database initialized (type: %s)", driver.Type)
|
dbTypeStr := "SQLite"
|
||||||
|
if cfg.Type == DBTypePostgres {
|
||||||
|
dbTypeStr = "PostgreSQL"
|
||||||
|
}
|
||||||
|
logger.Infof("✅ Database initialized (GORM, %s)", dbTypeStr)
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFromDB creates Store from existing database connection
|
// NewFromGorm creates Store from existing GORM connection
|
||||||
|
func NewFromGorm(gdb *gorm.DB) (*Store, error) {
|
||||||
|
sqlDB, err := gdb.DB()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Store{gdb: gdb, db: sqlDB}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFromDB creates Store from existing database connection (legacy)
|
||||||
|
// Deprecated: Use NewFromGorm instead
|
||||||
func NewFromDB(db *sql.DB) *Store {
|
func NewFromDB(db *sql.DB) *Store {
|
||||||
return &Store{db: db}
|
return &Store{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetCryptoFuncs sets encryption/decryption functions
|
// initTables initializes all database tables using GORM AutoMigrate
|
||||||
func (s *Store) SetCryptoFuncs(encrypt, decrypt func(string) string) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
s.encryptFunc = encrypt
|
|
||||||
s.decryptFunc = decrypt
|
|
||||||
|
|
||||||
// Update already initialized sub-stores
|
|
||||||
if s.aiModel != nil {
|
|
||||||
s.aiModel.encryptFunc = encrypt
|
|
||||||
s.aiModel.decryptFunc = decrypt
|
|
||||||
}
|
|
||||||
if s.exchange != nil {
|
|
||||||
s.exchange.encryptFunc = encrypt
|
|
||||||
s.exchange.decryptFunc = decrypt
|
|
||||||
}
|
|
||||||
if s.trader != nil {
|
|
||||||
s.trader.decryptFunc = decrypt
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// initTables initializes all database tables
|
|
||||||
func (s *Store) initTables() error {
|
func (s *Store) initTables() error {
|
||||||
// Initialize system config table first
|
// Create system_config table (GORM handles this via raw SQL for simplicity)
|
||||||
if _, err := s.db.Exec(`
|
if err := s.gdb.Exec(`
|
||||||
CREATE TABLE IF NOT EXISTS system_config (
|
CREATE TABLE IF NOT EXISTS system_config (
|
||||||
key TEXT PRIMARY KEY,
|
key TEXT PRIMARY KEY,
|
||||||
value TEXT NOT NULL
|
value TEXT NOT NULL
|
||||||
)
|
)
|
||||||
`); err != nil {
|
`).Error; err != nil {
|
||||||
return fmt.Errorf("failed to create system_config table: %w", err)
|
return fmt.Errorf("failed to create system_config table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize in dependency order
|
// Initialize sub-store tables
|
||||||
if err := s.User().initTables(); err != nil {
|
if err := s.User().initTables(); err != nil {
|
||||||
return fmt.Errorf("failed to initialize user tables: %w", err)
|
return fmt.Errorf("failed to initialize user tables: %w", err)
|
||||||
}
|
}
|
||||||
@@ -183,7 +184,7 @@ func (s *Store) User() *UserStore {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.user == nil {
|
if s.user == nil {
|
||||||
s.user = &UserStore{db: s.db}
|
s.user = NewUserStore(s.gdb)
|
||||||
}
|
}
|
||||||
return s.user
|
return s.user
|
||||||
}
|
}
|
||||||
@@ -193,11 +194,7 @@ func (s *Store) AIModel() *AIModelStore {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.aiModel == nil {
|
if s.aiModel == nil {
|
||||||
s.aiModel = &AIModelStore{
|
s.aiModel = NewAIModelStore(s.gdb)
|
||||||
db: s.db,
|
|
||||||
encryptFunc: s.encryptFunc,
|
|
||||||
decryptFunc: s.decryptFunc,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return s.aiModel
|
return s.aiModel
|
||||||
}
|
}
|
||||||
@@ -207,11 +204,7 @@ func (s *Store) Exchange() *ExchangeStore {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.exchange == nil {
|
if s.exchange == nil {
|
||||||
s.exchange = &ExchangeStore{
|
s.exchange = NewExchangeStore(s.gdb)
|
||||||
db: s.db,
|
|
||||||
encryptFunc: s.encryptFunc,
|
|
||||||
decryptFunc: s.decryptFunc,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return s.exchange
|
return s.exchange
|
||||||
}
|
}
|
||||||
@@ -221,10 +214,7 @@ func (s *Store) Trader() *TraderStore {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.trader == nil {
|
if s.trader == nil {
|
||||||
s.trader = &TraderStore{
|
s.trader = NewTraderStore(s.gdb)
|
||||||
db: s.db,
|
|
||||||
decryptFunc: s.decryptFunc,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return s.trader
|
return s.trader
|
||||||
}
|
}
|
||||||
@@ -234,7 +224,7 @@ func (s *Store) Decision() *DecisionStore {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.decision == nil {
|
if s.decision == nil {
|
||||||
s.decision = &DecisionStore{db: s.db}
|
s.decision = NewDecisionStore(s.gdb)
|
||||||
}
|
}
|
||||||
return s.decision
|
return s.decision
|
||||||
}
|
}
|
||||||
@@ -244,7 +234,7 @@ func (s *Store) Backtest() *BacktestStore {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.backtest == nil {
|
if s.backtest == nil {
|
||||||
s.backtest = &BacktestStore{db: s.db}
|
s.backtest = NewBacktestStore(s.gdb)
|
||||||
}
|
}
|
||||||
return s.backtest
|
return s.backtest
|
||||||
}
|
}
|
||||||
@@ -254,7 +244,7 @@ func (s *Store) Position() *PositionStore {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.position == nil {
|
if s.position == nil {
|
||||||
s.position = NewPositionStore(s.db)
|
s.position = NewPositionStore(s.gdb)
|
||||||
}
|
}
|
||||||
return s.position
|
return s.position
|
||||||
}
|
}
|
||||||
@@ -264,7 +254,7 @@ func (s *Store) Strategy() *StrategyStore {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.strategy == nil {
|
if s.strategy == nil {
|
||||||
s.strategy = &StrategyStore{db: s.db}
|
s.strategy = NewStrategyStore(s.gdb)
|
||||||
}
|
}
|
||||||
return s.strategy
|
return s.strategy
|
||||||
}
|
}
|
||||||
@@ -274,7 +264,7 @@ func (s *Store) Equity() *EquityStore {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.equity == nil {
|
if s.equity == nil {
|
||||||
s.equity = &EquityStore{db: s.db}
|
s.equity = NewEquityStore(s.gdb)
|
||||||
}
|
}
|
||||||
return s.equity
|
return s.equity
|
||||||
}
|
}
|
||||||
@@ -284,7 +274,7 @@ func (s *Store) Order() *OrderStore {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.order == nil {
|
if s.order == nil {
|
||||||
s.order = NewOrderStore(s.db)
|
s.order = NewOrderStore(s.gdb)
|
||||||
}
|
}
|
||||||
return s.order
|
return s.order
|
||||||
}
|
}
|
||||||
@@ -294,10 +284,18 @@ func (s *Store) Close() error {
|
|||||||
if s.driver != nil {
|
if s.driver != nil {
|
||||||
return s.driver.Close()
|
return s.driver.Close()
|
||||||
}
|
}
|
||||||
return s.db.Close()
|
if s.db != nil {
|
||||||
|
return s.db.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Driver returns database driver for abstraction
|
// GormDB returns the GORM database connection
|
||||||
|
func (s *Store) GormDB() *gorm.DB {
|
||||||
|
return s.gdb
|
||||||
|
}
|
||||||
|
|
||||||
|
// Driver returns database driver for abstraction (legacy)
|
||||||
func (s *Store) Driver() *DBDriver {
|
func (s *Store) Driver() *DBDriver {
|
||||||
return s.driver
|
return s.driver
|
||||||
}
|
}
|
||||||
@@ -307,11 +305,25 @@ func (s *Store) DBType() DBType {
|
|||||||
if s.driver != nil {
|
if s.driver != nil {
|
||||||
return s.driver.Type
|
return s.driver.Type
|
||||||
}
|
}
|
||||||
|
// Detect from GORM dialector
|
||||||
|
if s.gdb != nil {
|
||||||
|
switch s.gdb.Dialector.Name() {
|
||||||
|
case "postgres":
|
||||||
|
return DBTypePostgres
|
||||||
|
default:
|
||||||
|
return DBTypeSQLite
|
||||||
|
}
|
||||||
|
}
|
||||||
return DBTypeSQLite
|
return DBTypeSQLite
|
||||||
}
|
}
|
||||||
|
|
||||||
// DB gets underlying database connection (for legacy code compatibility, gradually deprecated)
|
// q converts query placeholders for current database type (legacy helper)
|
||||||
// Deprecated: use Store methods instead
|
func (s *Store) q(query string) string {
|
||||||
|
return convertQuery(query, s.DBType())
|
||||||
|
}
|
||||||
|
|
||||||
|
// DB gets underlying database connection (for legacy code compatibility)
|
||||||
|
// Deprecated: use GormDB() instead
|
||||||
func (s *Store) DB() *sql.DB {
|
func (s *Store) DB() *sql.DB {
|
||||||
return s.db
|
return s.db
|
||||||
}
|
}
|
||||||
@@ -319,24 +331,36 @@ func (s *Store) DB() *sql.DB {
|
|||||||
// GetSystemConfig gets a system configuration value by key
|
// GetSystemConfig gets a system configuration value by key
|
||||||
func (s *Store) GetSystemConfig(key string) (string, error) {
|
func (s *Store) GetSystemConfig(key string) (string, error) {
|
||||||
var value string
|
var value string
|
||||||
err := s.db.QueryRow(`SELECT value FROM system_config WHERE key = ?`, key).Scan(&value)
|
result := s.gdb.Raw("SELECT value FROM system_config WHERE key = ?", key).Scan(&value)
|
||||||
if err == sql.ErrNoRows {
|
if result.Error != nil {
|
||||||
|
if result.Error == gorm.ErrRecordNotFound {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return "", result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
return value, err
|
return value, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSystemConfig sets a system configuration value
|
// SetSystemConfig sets a system configuration value
|
||||||
func (s *Store) SetSystemConfig(key, value string) error {
|
func (s *Store) SetSystemConfig(key, value string) error {
|
||||||
_, err := s.db.Exec(`
|
// Use GORM-compatible upsert
|
||||||
|
return s.gdb.Exec(`
|
||||||
INSERT INTO system_config (key, value) VALUES (?, ?)
|
INSERT INTO system_config (key, value) VALUES (?, ?)
|
||||||
ON CONFLICT(key) DO UPDATE SET value = excluded.value
|
ON CONFLICT(key) DO UPDATE SET value = excluded.value
|
||||||
`, key, value)
|
`, key, value).Error
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transaction executes transaction
|
// Transaction executes transaction with GORM
|
||||||
func (s *Store) Transaction(fn func(tx *sql.Tx) error) error {
|
func (s *Store) Transaction(fn func(tx *gorm.DB) error) error {
|
||||||
|
return s.gdb.Transaction(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TransactionSQL executes transaction with sql.Tx (legacy)
|
||||||
|
// Deprecated: Use Transaction() instead
|
||||||
|
func (s *Store) TransactionSQL(fn func(tx *sql.Tx) error) error {
|
||||||
tx, err := s.db.Begin()
|
tx, err := s.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||||
|
|||||||
+57
-155
@@ -1,30 +1,33 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StrategyStore strategy storage
|
// StrategyStore strategy storage
|
||||||
type StrategyStore struct {
|
type StrategyStore struct {
|
||||||
db *sql.DB
|
db *gorm.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// Strategy strategy configuration
|
// Strategy strategy configuration
|
||||||
type Strategy struct {
|
type Strategy struct {
|
||||||
ID string `json:"id"`
|
ID string `gorm:"primaryKey" json:"id"`
|
||||||
UserID string `json:"user_id"`
|
UserID string `gorm:"column:user_id;not null;default:'';index" json:"user_id"`
|
||||||
Name string `json:"name"`
|
Name string `gorm:"not null" json:"name"`
|
||||||
Description string `json:"description"`
|
Description string `gorm:"default:''" json:"description"`
|
||||||
IsActive bool `json:"is_active"` // whether it is active (a user can only have one active strategy)
|
IsActive bool `gorm:"column:is_active;default:false;index" json:"is_active"`
|
||||||
IsDefault bool `json:"is_default"` // whether it is a system default strategy
|
IsDefault bool `gorm:"column:is_default;default:false" json:"is_default"`
|
||||||
Config string `json:"config"` // strategy configuration in JSON format
|
Config string `gorm:"not null;default:'{}'" json:"config"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (Strategy) TableName() string { return "strategies" }
|
||||||
|
|
||||||
// StrategyConfig strategy configuration details (JSON structure)
|
// StrategyConfig strategy configuration details (JSON structure)
|
||||||
type StrategyConfig struct {
|
type StrategyConfig struct {
|
||||||
// coin source configuration
|
// coin source configuration
|
||||||
@@ -136,24 +139,6 @@ type ExternalDataSource struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RiskControlConfig risk control configuration
|
// RiskControlConfig risk control configuration
|
||||||
// All parameters are clearly defined without ambiguity:
|
|
||||||
//
|
|
||||||
// Position Limits:
|
|
||||||
// - MaxPositions: max number of coins held simultaneously (CODE ENFORCED)
|
|
||||||
//
|
|
||||||
// Trading Leverage (exchange leverage for opening positions):
|
|
||||||
// - BTCETHMaxLeverage: BTC/ETH max exchange leverage (AI guided)
|
|
||||||
// - AltcoinMaxLeverage: Altcoin max exchange leverage (AI guided)
|
|
||||||
//
|
|
||||||
// Position Value Limits (single position notional value / account equity):
|
|
||||||
// - BTCETHMaxPositionValueRatio: BTC/ETH max = equity × ratio (CODE ENFORCED)
|
|
||||||
// - AltcoinMaxPositionValueRatio: Altcoin max = equity × ratio (CODE ENFORCED)
|
|
||||||
//
|
|
||||||
// Risk Controls:
|
|
||||||
// - MaxMarginUsage: max margin utilization percentage (CODE ENFORCED)
|
|
||||||
// - MinPositionSize: minimum position size in USDT (CODE ENFORCED)
|
|
||||||
// - MinRiskRewardRatio: min take_profit / stop_loss ratio (AI guided)
|
|
||||||
// - MinConfidence: min AI confidence to open position (AI guided)
|
|
||||||
type RiskControlConfig struct {
|
type RiskControlConfig struct {
|
||||||
// Max number of coins held simultaneously (CODE ENFORCED)
|
// Max number of coins held simultaneously (CODE ENFORCED)
|
||||||
MaxPositions int `json:"max_positions"`
|
MaxPositions int `json:"max_positions"`
|
||||||
@@ -179,38 +164,21 @@ type RiskControlConfig struct {
|
|||||||
MinConfidence int `json:"min_confidence"`
|
MinConfidence int `json:"min_confidence"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewStrategyStore creates a new StrategyStore
|
||||||
|
func NewStrategyStore(db *gorm.DB) *StrategyStore {
|
||||||
|
return &StrategyStore{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *StrategyStore) initTables() error {
|
func (s *StrategyStore) initTables() error {
|
||||||
_, err := s.db.Exec(`
|
// For PostgreSQL with existing table, skip AutoMigrate
|
||||||
CREATE TABLE IF NOT EXISTS strategies (
|
if s.db.Dialector.Name() == "postgres" {
|
||||||
id TEXT PRIMARY KEY,
|
var tableExists int64
|
||||||
user_id TEXT NOT NULL DEFAULT '',
|
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'strategies'`).Scan(&tableExists)
|
||||||
name TEXT NOT NULL,
|
if tableExists > 0 {
|
||||||
description TEXT DEFAULT '',
|
return nil
|
||||||
is_active BOOLEAN DEFAULT 0,
|
}
|
||||||
is_default BOOLEAN DEFAULT 0,
|
|
||||||
config TEXT NOT NULL DEFAULT '{}',
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
return s.db.AutoMigrate(&Strategy{})
|
||||||
// create indexes
|
|
||||||
_, _ = s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_strategies_user_id ON strategies(user_id)`)
|
|
||||||
_, _ = s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_strategies_is_active ON strategies(is_active)`)
|
|
||||||
|
|
||||||
// trigger: automatically update updated_at on update
|
|
||||||
_, err = s.db.Exec(`
|
|
||||||
CREATE TRIGGER IF NOT EXISTS update_strategies_updated_at
|
|
||||||
AFTER UPDATE ON strategies
|
|
||||||
BEGIN
|
|
||||||
UPDATE strategies SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
|
||||||
END
|
|
||||||
`)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StrategyStore) initDefaultData() error {
|
func (s *StrategyStore) initDefaultData() error {
|
||||||
@@ -322,159 +290,93 @@ Only enter positions when multiple signals resonate. Freely use any effective an
|
|||||||
|
|
||||||
// Create create a strategy
|
// Create create a strategy
|
||||||
func (s *StrategyStore) Create(strategy *Strategy) error {
|
func (s *StrategyStore) Create(strategy *Strategy) error {
|
||||||
_, err := s.db.Exec(`
|
return s.db.Create(strategy).Error
|
||||||
INSERT INTO strategies (id, user_id, name, description, is_active, is_default, config)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
|
||||||
`, strategy.ID, strategy.UserID, strategy.Name, strategy.Description, strategy.IsActive, strategy.IsDefault, strategy.Config)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update update a strategy
|
// Update update a strategy
|
||||||
func (s *StrategyStore) Update(strategy *Strategy) error {
|
func (s *StrategyStore) Update(strategy *Strategy) error {
|
||||||
_, err := s.db.Exec(`
|
return s.db.Model(&Strategy{}).
|
||||||
UPDATE strategies SET
|
Where("id = ? AND user_id = ?", strategy.ID, strategy.UserID).
|
||||||
name = ?, description = ?, config = ?, updated_at = CURRENT_TIMESTAMP
|
Updates(map[string]interface{}{
|
||||||
WHERE id = ? AND user_id = ?
|
"name": strategy.Name,
|
||||||
`, strategy.Name, strategy.Description, strategy.Config, strategy.ID, strategy.UserID)
|
"description": strategy.Description,
|
||||||
return err
|
"config": strategy.Config,
|
||||||
|
"updated_at": time.Now(),
|
||||||
|
}).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete delete a strategy
|
// Delete delete a strategy
|
||||||
func (s *StrategyStore) Delete(userID, id string) error {
|
func (s *StrategyStore) Delete(userID, id string) error {
|
||||||
// do not allow deleting system default strategy
|
// do not allow deleting system default strategy
|
||||||
var isDefault bool
|
var st Strategy
|
||||||
s.db.QueryRow(`SELECT is_default FROM strategies WHERE id = ?`, id).Scan(&isDefault)
|
if err := s.db.Where("id = ?", id).First(&st).Error; err == nil && st.IsDefault {
|
||||||
if isDefault {
|
|
||||||
return fmt.Errorf("cannot delete system default strategy")
|
return fmt.Errorf("cannot delete system default strategy")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := s.db.Exec(`DELETE FROM strategies WHERE id = ? AND user_id = ?`, id, userID)
|
return s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Strategy{}).Error
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// List get user's strategy list
|
// List get user's strategy list
|
||||||
func (s *StrategyStore) List(userID string) ([]*Strategy, error) {
|
func (s *StrategyStore) List(userID string) ([]*Strategy, error) {
|
||||||
// get user's own strategies + system default strategy
|
var strategies []*Strategy
|
||||||
rows, err := s.db.Query(`
|
err := s.db.Where("user_id = ? OR is_default = ?", userID, true).
|
||||||
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
|
Order("is_default DESC, created_at DESC").
|
||||||
FROM strategies
|
Find(&strategies).Error
|
||||||
WHERE user_id = ? OR is_default = 1
|
|
||||||
ORDER BY is_default DESC, created_at DESC
|
|
||||||
`, userID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var strategies []*Strategy
|
|
||||||
for rows.Next() {
|
|
||||||
var st Strategy
|
|
||||||
var createdAt, updatedAt string
|
|
||||||
err := rows.Scan(
|
|
||||||
&st.ID, &st.UserID, &st.Name, &st.Description,
|
|
||||||
&st.IsActive, &st.IsDefault, &st.Config,
|
|
||||||
&createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
strategies = append(strategies, &st)
|
|
||||||
}
|
|
||||||
return strategies, nil
|
return strategies, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get get a single strategy
|
// Get get a single strategy
|
||||||
func (s *StrategyStore) Get(userID, id string) (*Strategy, error) {
|
func (s *StrategyStore) Get(userID, id string) (*Strategy, error) {
|
||||||
var st Strategy
|
var st Strategy
|
||||||
var createdAt, updatedAt string
|
err := s.db.Where("id = ? AND (user_id = ? OR is_default = ?)", id, userID, true).
|
||||||
err := s.db.QueryRow(`
|
First(&st).Error
|
||||||
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
|
|
||||||
FROM strategies
|
|
||||||
WHERE id = ? AND (user_id = ? OR is_default = 1)
|
|
||||||
`, id, userID).Scan(
|
|
||||||
&st.ID, &st.UserID, &st.Name, &st.Description,
|
|
||||||
&st.IsActive, &st.IsDefault, &st.Config,
|
|
||||||
&createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
return &st, nil
|
return &st, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActive get user's currently active strategy
|
// GetActive get user's currently active strategy
|
||||||
func (s *StrategyStore) GetActive(userID string) (*Strategy, error) {
|
func (s *StrategyStore) GetActive(userID string) (*Strategy, error) {
|
||||||
var st Strategy
|
var st Strategy
|
||||||
var createdAt, updatedAt string
|
err := s.db.Where("user_id = ? AND is_active = ?", userID, true).First(&st).Error
|
||||||
err := s.db.QueryRow(`
|
if err == gorm.ErrRecordNotFound {
|
||||||
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
|
|
||||||
FROM strategies
|
|
||||||
WHERE user_id = ? AND is_active = 1
|
|
||||||
`, userID).Scan(
|
|
||||||
&st.ID, &st.UserID, &st.Name, &st.Description,
|
|
||||||
&st.IsActive, &st.IsDefault, &st.Config,
|
|
||||||
&createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
// no active strategy, return system default strategy
|
// no active strategy, return system default strategy
|
||||||
return s.GetDefault()
|
return s.GetDefault()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
return &st, nil
|
return &st, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDefault get system default strategy
|
// GetDefault get system default strategy
|
||||||
func (s *StrategyStore) GetDefault() (*Strategy, error) {
|
func (s *StrategyStore) GetDefault() (*Strategy, error) {
|
||||||
var st Strategy
|
var st Strategy
|
||||||
var createdAt, updatedAt string
|
err := s.db.Where("is_default = ?", true).First(&st).Error
|
||||||
err := s.db.QueryRow(`
|
|
||||||
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
|
|
||||||
FROM strategies
|
|
||||||
WHERE is_default = 1
|
|
||||||
LIMIT 1
|
|
||||||
`).Scan(
|
|
||||||
&st.ID, &st.UserID, &st.Name, &st.Description,
|
|
||||||
&st.IsActive, &st.IsDefault, &st.Config,
|
|
||||||
&createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
st.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
st.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
return &st, nil
|
return &st, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetActive set active strategy (will first deactivate other strategies)
|
// SetActive set active strategy (will first deactivate other strategies)
|
||||||
func (s *StrategyStore) SetActive(userID, strategyID string) error {
|
func (s *StrategyStore) SetActive(userID, strategyID string) error {
|
||||||
// begin transaction
|
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
tx, err := s.db.Begin()
|
// first deactivate all strategies for the user
|
||||||
if err != nil {
|
if err := tx.Model(&Strategy{}).Where("user_id = ?", userID).
|
||||||
return err
|
Update("is_active", false).Error; err != nil {
|
||||||
}
|
return err
|
||||||
defer tx.Rollback()
|
}
|
||||||
|
|
||||||
// first deactivate all strategies for the user
|
// activate specified strategy
|
||||||
_, err = tx.Exec(`UPDATE strategies SET is_active = 0 WHERE user_id = ?`, userID)
|
return tx.Model(&Strategy{}).
|
||||||
if err != nil {
|
Where("id = ? AND (user_id = ? OR is_default = ?)", strategyID, userID, true).
|
||||||
return err
|
Update("is_active", true).Error
|
||||||
}
|
})
|
||||||
|
|
||||||
// activate specified strategy
|
|
||||||
_, err = tx.Exec(`UPDATE strategies SET is_active = 1 WHERE id = ? AND (user_id = ? OR is_default = 1)`, strategyID, userID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Duplicate duplicate a strategy (used to create custom strategy based on default strategy)
|
// Duplicate duplicate a strategy (used to create custom strategy based on default strategy)
|
||||||
|
|||||||
+111
-374
@@ -1,43 +1,52 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TraderStore trader storage
|
// TraderStore trader storage
|
||||||
type TraderStore struct {
|
type TraderStore struct {
|
||||||
db *sql.DB
|
db *gorm.DB
|
||||||
decryptFunc func(string) string
|
}
|
||||||
|
|
||||||
|
// NewTraderStore creates a new trader store
|
||||||
|
func NewTraderStore(db *gorm.DB) *TraderStore {
|
||||||
|
return &TraderStore{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trader trader configuration
|
// Trader trader configuration
|
||||||
type Trader struct {
|
type Trader struct {
|
||||||
ID string `json:"id"`
|
ID string `gorm:"primaryKey" json:"id"`
|
||||||
UserID string `json:"user_id"`
|
UserID string `gorm:"column:user_id;not null;default:default;index" json:"user_id"`
|
||||||
Name string `json:"name"`
|
Name string `gorm:"column:name;not null" json:"name"`
|
||||||
AIModelID string `json:"ai_model_id"`
|
AIModelID string `gorm:"column:ai_model_id;not null" json:"ai_model_id"`
|
||||||
ExchangeID string `json:"exchange_id"`
|
ExchangeID string `gorm:"column:exchange_id;not null" json:"exchange_id"`
|
||||||
StrategyID string `json:"strategy_id"` // Associated strategy ID
|
StrategyID string `gorm:"column:strategy_id;default:''" json:"strategy_id"`
|
||||||
InitialBalance float64 `json:"initial_balance"`
|
InitialBalance float64 `gorm:"column:initial_balance;not null" json:"initial_balance"`
|
||||||
ScanIntervalMinutes int `json:"scan_interval_minutes"`
|
ScanIntervalMinutes int `gorm:"column:scan_interval_minutes;default:3" json:"scan_interval_minutes"`
|
||||||
IsRunning bool `json:"is_running"`
|
IsRunning bool `gorm:"column:is_running;default:false" json:"is_running"`
|
||||||
IsCrossMargin bool `json:"is_cross_margin"`
|
IsCrossMargin bool `gorm:"column:is_cross_margin;default:true" json:"is_cross_margin"`
|
||||||
ShowInCompetition bool `json:"show_in_competition"` // Whether to show in competition page
|
ShowInCompetition bool `gorm:"column:show_in_competition;default:true" json:"show_in_competition"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
|
||||||
|
|
||||||
// Following fields are deprecated, kept for backward compatibility, new traders should use StrategyID
|
// Following fields are deprecated, kept for backward compatibility, new traders should use StrategyID
|
||||||
BTCETHLeverage int `json:"btc_eth_leverage,omitempty"`
|
BTCETHLeverage int `gorm:"column:btc_eth_leverage;default:5" json:"btc_eth_leverage,omitempty"`
|
||||||
AltcoinLeverage int `json:"altcoin_leverage,omitempty"`
|
AltcoinLeverage int `gorm:"column:altcoin_leverage;default:5" json:"altcoin_leverage,omitempty"`
|
||||||
TradingSymbols string `json:"trading_symbols,omitempty"`
|
TradingSymbols string `gorm:"column:trading_symbols;default:''" json:"trading_symbols,omitempty"`
|
||||||
UseCoinPool bool `json:"use_coin_pool,omitempty"`
|
UseCoinPool bool `gorm:"column:use_coin_pool;default:false" json:"use_coin_pool,omitempty"`
|
||||||
UseOITop bool `json:"use_oi_top,omitempty"`
|
UseOITop bool `gorm:"column:use_oi_top;default:false" json:"use_oi_top,omitempty"`
|
||||||
CustomPrompt string `json:"custom_prompt,omitempty"`
|
CustomPrompt string `gorm:"column:custom_prompt;default:''" json:"custom_prompt,omitempty"`
|
||||||
OverrideBasePrompt bool `json:"override_base_prompt,omitempty"`
|
OverrideBasePrompt bool `gorm:"column:override_base_prompt;default:false" json:"override_base_prompt,omitempty"`
|
||||||
SystemPromptTemplate string `json:"system_prompt_template,omitempty"`
|
SystemPromptTemplate string `gorm:"column:system_prompt_template;default:default" json:"system_prompt_template,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName returns the table name for Trader
|
||||||
|
func (Trader) TableName() string {
|
||||||
|
return "traders"
|
||||||
}
|
}
|
||||||
|
|
||||||
// TraderFullConfig trader full configuration (includes AI model, exchange and strategy)
|
// TraderFullConfig trader full configuration (includes AI model, exchange and strategy)
|
||||||
@@ -45,331 +54,130 @@ type TraderFullConfig struct {
|
|||||||
Trader *Trader
|
Trader *Trader
|
||||||
AIModel *AIModel
|
AIModel *AIModel
|
||||||
Exchange *Exchange
|
Exchange *Exchange
|
||||||
Strategy *Strategy // Associated strategy configuration
|
Strategy *Strategy
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TraderStore) initTables() error {
|
func (s *TraderStore) initTables() error {
|
||||||
_, err := s.db.Exec(`
|
// For PostgreSQL with existing table, skip AutoMigrate
|
||||||
CREATE TABLE IF NOT EXISTS traders (
|
if s.db.Dialector.Name() == "postgres" {
|
||||||
id TEXT PRIMARY KEY,
|
var tableExists int64
|
||||||
user_id TEXT NOT NULL DEFAULT 'default',
|
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'traders'`).Scan(&tableExists)
|
||||||
name TEXT NOT NULL,
|
if tableExists > 0 {
|
||||||
ai_model_id TEXT NOT NULL,
|
return nil
|
||||||
exchange_id TEXT NOT NULL,
|
}
|
||||||
initial_balance REAL NOT NULL,
|
|
||||||
scan_interval_minutes INTEGER DEFAULT 3,
|
|
||||||
is_running BOOLEAN DEFAULT 0,
|
|
||||||
btc_eth_leverage INTEGER DEFAULT 5,
|
|
||||||
altcoin_leverage INTEGER DEFAULT 5,
|
|
||||||
trading_symbols TEXT DEFAULT '',
|
|
||||||
use_coin_pool BOOLEAN DEFAULT 0,
|
|
||||||
use_oi_top BOOLEAN DEFAULT 0,
|
|
||||||
custom_prompt TEXT DEFAULT '',
|
|
||||||
override_base_prompt BOOLEAN DEFAULT 0,
|
|
||||||
system_prompt_template TEXT DEFAULT 'default',
|
|
||||||
is_cross_margin BOOLEAN DEFAULT 1,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
// Use GORM AutoMigrate
|
||||||
// Trigger
|
if err := s.db.AutoMigrate(&Trader{}); err != nil {
|
||||||
_, err = s.db.Exec(`
|
return fmt.Errorf("failed to migrate traders table: %w", err)
|
||||||
CREATE TRIGGER IF NOT EXISTS update_traders_updated_at
|
|
||||||
AFTER UPDATE ON traders
|
|
||||||
BEGIN
|
|
||||||
UPDATE traders SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
|
||||||
END
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Backward compatibility
|
|
||||||
alterQueries := []string{
|
|
||||||
`ALTER TABLE traders ADD COLUMN custom_prompt TEXT DEFAULT ''`,
|
|
||||||
`ALTER TABLE traders ADD COLUMN override_base_prompt BOOLEAN DEFAULT 0`,
|
|
||||||
`ALTER TABLE traders ADD COLUMN is_cross_margin BOOLEAN DEFAULT 1`,
|
|
||||||
`ALTER TABLE traders ADD COLUMN btc_eth_leverage INTEGER DEFAULT 5`,
|
|
||||||
`ALTER TABLE traders ADD COLUMN altcoin_leverage INTEGER DEFAULT 5`,
|
|
||||||
`ALTER TABLE traders ADD COLUMN trading_symbols TEXT DEFAULT ''`,
|
|
||||||
`ALTER TABLE traders ADD COLUMN use_coin_pool BOOLEAN DEFAULT 0`,
|
|
||||||
`ALTER TABLE traders ADD COLUMN use_oi_top BOOLEAN DEFAULT 0`,
|
|
||||||
`ALTER TABLE traders ADD COLUMN system_prompt_template TEXT DEFAULT 'default'`,
|
|
||||||
`ALTER TABLE traders ADD COLUMN strategy_id TEXT DEFAULT ''`,
|
|
||||||
`ALTER TABLE traders ADD COLUMN show_in_competition BOOLEAN DEFAULT 1`,
|
|
||||||
}
|
|
||||||
for _, q := range alterQueries {
|
|
||||||
s.db.Exec(q)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Migration: Remove FOREIGN KEY constraint from existing traders table
|
|
||||||
// SQLite doesn't support ALTER TABLE DROP CONSTRAINT, so we need to recreate the table
|
|
||||||
if err := s.migrateTradersRemoveFK(); err != nil {
|
|
||||||
// Log but don't fail - this is a best-effort migration
|
|
||||||
// The constraint may not exist in older databases
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// migrateTradersRemoveFK removes FOREIGN KEY constraint from traders table if it exists
|
|
||||||
func (s *TraderStore) migrateTradersRemoveFK() error {
|
|
||||||
// Check if the table has a foreign key constraint by examining the schema
|
|
||||||
var sql string
|
|
||||||
err := s.db.QueryRow(`SELECT sql FROM sqlite_master WHERE type='table' AND name='traders'`).Scan(&sql)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no FOREIGN KEY in schema, no migration needed
|
|
||||||
if !strings.Contains(sql, "FOREIGN KEY") {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recreate table without FOREIGN KEY constraint
|
|
||||||
_, err = s.db.Exec(`
|
|
||||||
-- Create new table without FOREIGN KEY
|
|
||||||
CREATE TABLE IF NOT EXISTS traders_new (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
user_id TEXT NOT NULL DEFAULT 'default',
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
ai_model_id TEXT NOT NULL,
|
|
||||||
exchange_id TEXT NOT NULL,
|
|
||||||
initial_balance REAL NOT NULL,
|
|
||||||
scan_interval_minutes INTEGER DEFAULT 3,
|
|
||||||
is_running BOOLEAN DEFAULT 0,
|
|
||||||
btc_eth_leverage INTEGER DEFAULT 5,
|
|
||||||
altcoin_leverage INTEGER DEFAULT 5,
|
|
||||||
trading_symbols TEXT DEFAULT '',
|
|
||||||
use_coin_pool BOOLEAN DEFAULT 0,
|
|
||||||
use_oi_top BOOLEAN DEFAULT 0,
|
|
||||||
custom_prompt TEXT DEFAULT '',
|
|
||||||
override_base_prompt BOOLEAN DEFAULT 0,
|
|
||||||
system_prompt_template TEXT DEFAULT 'default',
|
|
||||||
is_cross_margin BOOLEAN DEFAULT 1,
|
|
||||||
strategy_id TEXT DEFAULT '',
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
);
|
|
||||||
|
|
||||||
-- Copy data from old table
|
|
||||||
INSERT OR IGNORE INTO traders_new
|
|
||||||
SELECT id, user_id, name, ai_model_id, exchange_id, initial_balance,
|
|
||||||
scan_interval_minutes, is_running, btc_eth_leverage, altcoin_leverage,
|
|
||||||
trading_symbols, use_coin_pool, use_oi_top, custom_prompt,
|
|
||||||
override_base_prompt, system_prompt_template, is_cross_margin,
|
|
||||||
COALESCE(strategy_id, ''), created_at, updated_at
|
|
||||||
FROM traders;
|
|
||||||
|
|
||||||
-- Drop old table
|
|
||||||
DROP TABLE traders;
|
|
||||||
|
|
||||||
-- Rename new table
|
|
||||||
ALTER TABLE traders_new RENAME TO traders;
|
|
||||||
`)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recreate trigger
|
|
||||||
_, err = s.db.Exec(`
|
|
||||||
CREATE TRIGGER IF NOT EXISTS update_traders_updated_at
|
|
||||||
AFTER UPDATE ON traders
|
|
||||||
BEGIN
|
|
||||||
UPDATE traders SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
|
||||||
END
|
|
||||||
`)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *TraderStore) decrypt(encrypted string) string {
|
|
||||||
if s.decryptFunc != nil {
|
|
||||||
return s.decryptFunc(encrypted)
|
|
||||||
}
|
|
||||||
return encrypted
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create creates trader
|
// Create creates trader
|
||||||
func (s *TraderStore) Create(trader *Trader) error {
|
func (s *TraderStore) Create(trader *Trader) error {
|
||||||
_, err := s.db.Exec(`
|
return s.db.Create(trader).Error
|
||||||
INSERT INTO traders (id, user_id, name, ai_model_id, exchange_id, strategy_id, initial_balance,
|
|
||||||
scan_interval_minutes, is_running, is_cross_margin, show_in_competition,
|
|
||||||
btc_eth_leverage, altcoin_leverage, trading_symbols, use_coin_pool,
|
|
||||||
use_oi_top, custom_prompt, override_base_prompt, system_prompt_template)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
`, trader.ID, trader.UserID, trader.Name, trader.AIModelID, trader.ExchangeID, trader.StrategyID,
|
|
||||||
trader.InitialBalance, trader.ScanIntervalMinutes, trader.IsRunning, trader.IsCrossMargin, trader.ShowInCompetition,
|
|
||||||
trader.BTCETHLeverage, trader.AltcoinLeverage, trader.TradingSymbols, trader.UseCoinPool,
|
|
||||||
trader.UseOITop, trader.CustomPrompt, trader.OverrideBasePrompt, trader.SystemPromptTemplate)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// List gets user's trader list
|
// List gets user's trader list
|
||||||
func (s *TraderStore) List(userID string) ([]*Trader, error) {
|
func (s *TraderStore) List(userID string) ([]*Trader, error) {
|
||||||
rows, err := s.db.Query(`
|
var traders []*Trader
|
||||||
SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''),
|
err := s.db.Where("user_id = ?", userID).
|
||||||
initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1),
|
Order("created_at DESC").
|
||||||
COALESCE(show_in_competition, 1),
|
Find(&traders).Error
|
||||||
COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''),
|
|
||||||
COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''),
|
|
||||||
COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'),
|
|
||||||
created_at, updated_at
|
|
||||||
FROM traders WHERE user_id = ? ORDER BY created_at DESC
|
|
||||||
`, userID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var traders []*Trader
|
|
||||||
for rows.Next() {
|
|
||||||
var t Trader
|
|
||||||
var createdAt, updatedAt string
|
|
||||||
err := rows.Scan(
|
|
||||||
&t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID,
|
|
||||||
&t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin,
|
|
||||||
&t.ShowInCompetition,
|
|
||||||
&t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols,
|
|
||||||
&t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt,
|
|
||||||
&t.SystemPromptTemplate, &createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
traders = append(traders, &t)
|
|
||||||
}
|
|
||||||
return traders, nil
|
return traders, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateStatus updates trader running status
|
// UpdateStatus updates trader running status
|
||||||
func (s *TraderStore) UpdateStatus(userID, id string, isRunning bool) error {
|
func (s *TraderStore) UpdateStatus(userID, id string, isRunning bool) error {
|
||||||
_, err := s.db.Exec(`UPDATE traders SET is_running = ? WHERE id = ? AND user_id = ?`, isRunning, id, userID)
|
return s.db.Model(&Trader{}).
|
||||||
return err
|
Where("id = ? AND user_id = ?", id, userID).
|
||||||
|
Update("is_running", isRunning).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateShowInCompetition updates trader competition visibility
|
// UpdateShowInCompetition updates trader competition visibility
|
||||||
func (s *TraderStore) UpdateShowInCompetition(userID, id string, showInCompetition bool) error {
|
func (s *TraderStore) UpdateShowInCompetition(userID, id string, showInCompetition bool) error {
|
||||||
_, err := s.db.Exec(`UPDATE traders SET show_in_competition = ? WHERE id = ? AND user_id = ?`, showInCompetition, id, userID)
|
return s.db.Model(&Trader{}).
|
||||||
return err
|
Where("id = ? AND user_id = ?", id, userID).
|
||||||
|
Update("show_in_competition", showInCompetition).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update updates trader configuration
|
// Update updates trader configuration
|
||||||
func (s *TraderStore) Update(trader *Trader) error {
|
func (s *TraderStore) Update(trader *Trader) error {
|
||||||
fmt.Printf("📝 TraderStore.Update: ID=%s, Name=%s, AIModelID=%s, StrategyID=%s\n",
|
fmt.Printf("📝 TraderStore.Update: ID=%s, Name=%s, AIModelID=%s, StrategyID=%s\n",
|
||||||
trader.ID, trader.Name, trader.AIModelID, trader.StrategyID)
|
trader.ID, trader.Name, trader.AIModelID, trader.StrategyID)
|
||||||
_, err := s.db.Exec(`
|
|
||||||
UPDATE traders SET
|
updates := map[string]interface{}{
|
||||||
name = ?,
|
"name": trader.Name,
|
||||||
ai_model_id = ?,
|
"ai_model_id": trader.AIModelID,
|
||||||
exchange_id = ?,
|
"exchange_id": trader.ExchangeID,
|
||||||
strategy_id = ?,
|
"strategy_id": trader.StrategyID,
|
||||||
initial_balance = CASE WHEN ? > 0 THEN ? ELSE initial_balance END,
|
"is_cross_margin": trader.IsCrossMargin,
|
||||||
scan_interval_minutes = CASE WHEN ? > 0 THEN ? ELSE scan_interval_minutes END,
|
"show_in_competition": trader.ShowInCompetition,
|
||||||
is_cross_margin = ?,
|
}
|
||||||
show_in_competition = ?,
|
|
||||||
updated_at = CURRENT_TIMESTAMP
|
// Only update these if > 0
|
||||||
WHERE id = ? AND user_id = ?
|
if trader.InitialBalance > 0 {
|
||||||
`, trader.Name, trader.AIModelID, trader.ExchangeID, trader.StrategyID,
|
updates["initial_balance"] = trader.InitialBalance
|
||||||
trader.InitialBalance, trader.InitialBalance,
|
}
|
||||||
trader.ScanIntervalMinutes, trader.ScanIntervalMinutes,
|
if trader.ScanIntervalMinutes > 0 {
|
||||||
trader.IsCrossMargin, trader.ShowInCompetition,
|
updates["scan_interval_minutes"] = trader.ScanIntervalMinutes
|
||||||
trader.ID, trader.UserID)
|
}
|
||||||
return err
|
|
||||||
|
return s.db.Model(&Trader{}).
|
||||||
|
Where("id = ? AND user_id = ?", trader.ID, trader.UserID).
|
||||||
|
Updates(updates).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateInitialBalance updates initial balance
|
// UpdateInitialBalance updates initial balance
|
||||||
func (s *TraderStore) UpdateInitialBalance(userID, id string, newBalance float64) error {
|
func (s *TraderStore) UpdateInitialBalance(userID, id string, newBalance float64) error {
|
||||||
_, err := s.db.Exec(`UPDATE traders SET initial_balance = ? WHERE id = ? AND user_id = ?`, newBalance, id, userID)
|
return s.db.Model(&Trader{}).
|
||||||
return err
|
Where("id = ? AND user_id = ?", id, userID).
|
||||||
|
Update("initial_balance", newBalance).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateCustomPrompt updates custom prompt
|
// UpdateCustomPrompt updates custom prompt
|
||||||
func (s *TraderStore) UpdateCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error {
|
func (s *TraderStore) UpdateCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error {
|
||||||
_, err := s.db.Exec(`UPDATE traders SET custom_prompt = ?, override_base_prompt = ? WHERE id = ? AND user_id = ?`,
|
return s.db.Model(&Trader{}).
|
||||||
customPrompt, overrideBase, id, userID)
|
Where("id = ? AND user_id = ?", id, userID).
|
||||||
return err
|
Updates(map[string]interface{}{
|
||||||
|
"custom_prompt": customPrompt,
|
||||||
|
"override_base_prompt": overrideBase,
|
||||||
|
}).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete deletes trader and associated data
|
// Delete deletes trader and associated data
|
||||||
func (s *TraderStore) Delete(userID, id string) error {
|
func (s *TraderStore) Delete(userID, id string) error {
|
||||||
// Delete associated equity snapshots first
|
// Delete associated equity snapshots first
|
||||||
_, _ = s.db.Exec(`DELETE FROM trader_equity_snapshots WHERE trader_id = ?`, id)
|
s.db.Where("trader_id = ?", id).Delete(&EquitySnapshot{})
|
||||||
|
|
||||||
// Delete the trader
|
// Delete the trader
|
||||||
_, err := s.db.Exec(`DELETE FROM traders WHERE id = ? AND user_id = ?`, id, userID)
|
return s.db.Where("id = ? AND user_id = ?", id, userID).Delete(&Trader{}).Error
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFullConfig gets trader full configuration
|
// GetFullConfig gets trader full configuration
|
||||||
func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig, error) {
|
func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig, error) {
|
||||||
var trader Trader
|
var trader Trader
|
||||||
var aiModel AIModel
|
err := s.db.Where("id = ? AND user_id = ?", traderID, userID).First(&trader).Error
|
||||||
var exchange Exchange
|
|
||||||
var traderCreatedAt, traderUpdatedAt string
|
|
||||||
var aiModelCreatedAt, aiModelUpdatedAt string
|
|
||||||
var exchangeCreatedAt, exchangeUpdatedAt string
|
|
||||||
|
|
||||||
err := s.db.QueryRow(`
|
|
||||||
SELECT
|
|
||||||
t.id, t.user_id, t.name, t.ai_model_id, t.exchange_id, COALESCE(t.strategy_id, ''),
|
|
||||||
t.initial_balance, t.scan_interval_minutes, t.is_running, COALESCE(t.is_cross_margin, 1),
|
|
||||||
COALESCE(t.btc_eth_leverage, 5), COALESCE(t.altcoin_leverage, 5), COALESCE(t.trading_symbols, ''),
|
|
||||||
COALESCE(t.use_coin_pool, 0), COALESCE(t.use_oi_top, 0), COALESCE(t.custom_prompt, ''),
|
|
||||||
COALESCE(t.override_base_prompt, 0), COALESCE(t.system_prompt_template, 'default'),
|
|
||||||
t.created_at, t.updated_at,
|
|
||||||
a.id, a.user_id, a.name, a.provider, a.enabled, a.api_key,
|
|
||||||
COALESCE(a.custom_api_url, ''), COALESCE(a.custom_model_name, ''), a.created_at, a.updated_at,
|
|
||||||
e.id, COALESCE(e.exchange_type, '') as exchange_type, COALESCE(e.account_name, '') as account_name,
|
|
||||||
e.user_id, e.name, e.type, e.enabled, e.api_key, e.secret_key, COALESCE(e.passphrase, ''), e.testnet,
|
|
||||||
COALESCE(e.hyperliquid_wallet_addr, ''), COALESCE(e.aster_user, ''), COALESCE(e.aster_signer, ''),
|
|
||||||
COALESCE(e.aster_private_key, ''), COALESCE(e.lighter_wallet_addr, ''), COALESCE(e.lighter_private_key, ''),
|
|
||||||
COALESCE(e.lighter_api_key_private_key, ''), COALESCE(e.lighter_api_key_index, 0), e.created_at, e.updated_at
|
|
||||||
FROM traders t
|
|
||||||
JOIN ai_models a ON t.ai_model_id = a.id AND t.user_id = a.user_id
|
|
||||||
JOIN exchanges e ON t.exchange_id = e.id AND t.user_id = e.user_id
|
|
||||||
WHERE t.id = ? AND t.user_id = ?
|
|
||||||
`, traderID, userID).Scan(
|
|
||||||
&trader.ID, &trader.UserID, &trader.Name, &trader.AIModelID, &trader.ExchangeID, &trader.StrategyID,
|
|
||||||
&trader.InitialBalance, &trader.ScanIntervalMinutes, &trader.IsRunning, &trader.IsCrossMargin,
|
|
||||||
&trader.BTCETHLeverage, &trader.AltcoinLeverage, &trader.TradingSymbols,
|
|
||||||
&trader.UseCoinPool, &trader.UseOITop, &trader.CustomPrompt, &trader.OverrideBasePrompt,
|
|
||||||
&trader.SystemPromptTemplate, &traderCreatedAt, &traderUpdatedAt,
|
|
||||||
&aiModel.ID, &aiModel.UserID, &aiModel.Name, &aiModel.Provider, &aiModel.Enabled, &aiModel.APIKey,
|
|
||||||
&aiModel.CustomAPIURL, &aiModel.CustomModelName, &aiModelCreatedAt, &aiModelUpdatedAt,
|
|
||||||
&exchange.ID, &exchange.ExchangeType, &exchange.AccountName,
|
|
||||||
&exchange.UserID, &exchange.Name, &exchange.Type, &exchange.Enabled,
|
|
||||||
&exchange.APIKey, &exchange.SecretKey, &exchange.Passphrase, &exchange.Testnet, &exchange.HyperliquidWalletAddr,
|
|
||||||
&exchange.AsterUser, &exchange.AsterSigner, &exchange.AsterPrivateKey,
|
|
||||||
&exchange.LighterWalletAddr, &exchange.LighterPrivateKey, &exchange.LighterAPIKeyPrivateKey, &exchange.LighterAPIKeyIndex,
|
|
||||||
&exchangeCreatedAt, &exchangeUpdatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
trader.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", traderCreatedAt)
|
// Get AI model
|
||||||
trader.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", traderUpdatedAt)
|
var aiModel AIModel
|
||||||
aiModel.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelCreatedAt)
|
err = s.db.Where("id = ? AND user_id = ?", trader.AIModelID, userID).First(&aiModel).Error
|
||||||
aiModel.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelUpdatedAt)
|
if err != nil {
|
||||||
exchange.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeCreatedAt)
|
return nil, fmt.Errorf("failed to get AI model: %w", err)
|
||||||
exchange.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeUpdatedAt)
|
}
|
||||||
|
|
||||||
// Decrypt
|
// Get exchange
|
||||||
aiModel.APIKey = s.decrypt(aiModel.APIKey)
|
var exchange Exchange
|
||||||
exchange.APIKey = s.decrypt(exchange.APIKey)
|
err = s.db.Where("id = ? AND user_id = ?", trader.ExchangeID, userID).First(&exchange).Error
|
||||||
exchange.SecretKey = s.decrypt(exchange.SecretKey)
|
if err != nil {
|
||||||
exchange.Passphrase = s.decrypt(exchange.Passphrase)
|
return nil, fmt.Errorf("failed to get exchange: %w", err)
|
||||||
exchange.AsterPrivateKey = s.decrypt(exchange.AsterPrivateKey)
|
}
|
||||||
exchange.LighterPrivateKey = s.decrypt(exchange.LighterPrivateKey)
|
|
||||||
exchange.LighterAPIKeyPrivateKey = s.decrypt(exchange.LighterAPIKeyPrivateKey)
|
|
||||||
|
|
||||||
// Load associated strategy
|
// Load associated strategy
|
||||||
var strategy *Strategy
|
var strategy *Strategy
|
||||||
@@ -392,119 +200,48 @@ func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig,
|
|||||||
// getStrategyByID internal method: gets strategy by ID
|
// getStrategyByID internal method: gets strategy by ID
|
||||||
func (s *TraderStore) getStrategyByID(userID, strategyID string) (*Strategy, error) {
|
func (s *TraderStore) getStrategyByID(userID, strategyID string) (*Strategy, error) {
|
||||||
var strategy Strategy
|
var strategy Strategy
|
||||||
var createdAt, updatedAt string
|
err := s.db.Where("id = ? AND (user_id = ? OR is_default = ?)", strategyID, userID, true).
|
||||||
err := s.db.QueryRow(`
|
First(&strategy).Error
|
||||||
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
|
|
||||||
FROM strategies WHERE id = ? AND (user_id = ? OR is_default = 1)
|
|
||||||
`, strategyID, userID).Scan(
|
|
||||||
&strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description,
|
|
||||||
&strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
return &strategy, nil
|
return &strategy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getActiveOrDefaultStrategy internal method: gets user's active strategy or system default strategy
|
// getActiveOrDefaultStrategy internal method: gets user's active strategy or system default strategy
|
||||||
func (s *TraderStore) getActiveOrDefaultStrategy(userID string) (*Strategy, error) {
|
func (s *TraderStore) getActiveOrDefaultStrategy(userID string) (*Strategy, error) {
|
||||||
var strategy Strategy
|
var strategy Strategy
|
||||||
var createdAt, updatedAt string
|
|
||||||
|
|
||||||
// First try to get user's active strategy
|
// First try to get user's active strategy
|
||||||
err := s.db.QueryRow(`
|
err := s.db.Where("user_id = ? AND is_active = ?", userID, true).First(&strategy).Error
|
||||||
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
|
|
||||||
FROM strategies WHERE user_id = ? AND is_active = 1
|
|
||||||
`, userID).Scan(
|
|
||||||
&strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description,
|
|
||||||
&strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
return &strategy, nil
|
return &strategy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to system default strategy
|
// Fallback to system default strategy
|
||||||
err = s.db.QueryRow(`
|
err = s.db.Where("is_default = ?", true).First(&strategy).Error
|
||||||
SELECT id, user_id, name, description, is_active, is_default, config, created_at, updated_at
|
|
||||||
FROM strategies WHERE is_default = 1 LIMIT 1
|
|
||||||
`).Scan(
|
|
||||||
&strategy.ID, &strategy.UserID, &strategy.Name, &strategy.Description,
|
|
||||||
&strategy.IsActive, &strategy.IsDefault, &strategy.Config, &createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
strategy.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
strategy.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
return &strategy, nil
|
return &strategy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListAll gets all users' trader list
|
|
||||||
// GetByID gets a trader by ID without requiring userID (for public APIs)
|
// GetByID gets a trader by ID without requiring userID (for public APIs)
|
||||||
func (s *TraderStore) GetByID(traderID string) (*Trader, error) {
|
func (s *TraderStore) GetByID(traderID string) (*Trader, error) {
|
||||||
var t Trader
|
var trader Trader
|
||||||
var createdAt, updatedAt string
|
err := s.db.Where("id = ?", traderID).First(&trader).Error
|
||||||
err := s.db.QueryRow(`
|
|
||||||
SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''),
|
|
||||||
initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1),
|
|
||||||
COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''),
|
|
||||||
COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''),
|
|
||||||
COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'),
|
|
||||||
created_at, updated_at
|
|
||||||
FROM traders WHERE id = ?
|
|
||||||
`, traderID).Scan(
|
|
||||||
&t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID,
|
|
||||||
&t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin,
|
|
||||||
&t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols,
|
|
||||||
&t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt,
|
|
||||||
&t.SystemPromptTemplate, &createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
return &trader, nil
|
||||||
t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
return &t, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListAll gets all traders
|
||||||
func (s *TraderStore) ListAll() ([]*Trader, error) {
|
func (s *TraderStore) ListAll() ([]*Trader, error) {
|
||||||
rows, err := s.db.Query(`
|
var traders []*Trader
|
||||||
SELECT id, user_id, name, ai_model_id, exchange_id, COALESCE(strategy_id, ''),
|
err := s.db.Order("created_at DESC").Find(&traders).Error
|
||||||
initial_balance, scan_interval_minutes, is_running, COALESCE(is_cross_margin, 1),
|
|
||||||
COALESCE(show_in_competition, 1),
|
|
||||||
COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''),
|
|
||||||
COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''),
|
|
||||||
COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'),
|
|
||||||
created_at, updated_at
|
|
||||||
FROM traders ORDER BY created_at DESC
|
|
||||||
`)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var traders []*Trader
|
|
||||||
for rows.Next() {
|
|
||||||
var t Trader
|
|
||||||
var createdAt, updatedAt string
|
|
||||||
err := rows.Scan(
|
|
||||||
&t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID, &t.StrategyID,
|
|
||||||
&t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning, &t.IsCrossMargin,
|
|
||||||
&t.ShowInCompetition,
|
|
||||||
&t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols,
|
|
||||||
&t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt,
|
|
||||||
&t.SystemPromptTemplate, &createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
traders = append(traders, &t)
|
|
||||||
}
|
|
||||||
return traders, nil
|
return traders, nil
|
||||||
}
|
}
|
||||||
|
|||||||
+60
-87
@@ -2,27 +2,30 @@ package store
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"database/sql"
|
|
||||||
"encoding/base32"
|
"encoding/base32"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UserStore user storage
|
// UserStore user storage
|
||||||
type UserStore struct {
|
type UserStore struct {
|
||||||
db *sql.DB
|
db *gorm.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// User user
|
// User user model
|
||||||
type User struct {
|
type User struct {
|
||||||
ID string `json:"id"`
|
ID string `gorm:"primaryKey" json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `gorm:"uniqueIndex:idx_users_email;not null" json:"email"`
|
||||||
PasswordHash string `json:"-"`
|
PasswordHash string `gorm:"column:password_hash;not null" json:"-"`
|
||||||
OTPSecret string `json:"-"`
|
OTPSecret string `gorm:"column:otp_secret" json:"-"`
|
||||||
OTPVerified bool `json:"otp_verified"`
|
OTPVerified bool `gorm:"column:otp_verified;default:false" json:"otp_verified"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (User) TableName() string { return "users" }
|
||||||
|
|
||||||
// GenerateOTPSecret generates OTP secret
|
// GenerateOTPSecret generates OTP secret
|
||||||
func GenerateOTPSecret() (string, error) {
|
func GenerateOTPSecret() (string, error) {
|
||||||
secret := make([]byte, 20)
|
secret := make([]byte, 20)
|
||||||
@@ -33,131 +36,101 @@ func GenerateOTPSecret() (string, error) {
|
|||||||
return base32.StdEncoding.EncodeToString(secret), nil
|
return base32.StdEncoding.EncodeToString(secret), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewUserStore creates a new UserStore
|
||||||
|
func NewUserStore(db *gorm.DB) *UserStore {
|
||||||
|
return &UserStore{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *UserStore) initTables() error {
|
func (s *UserStore) initTables() error {
|
||||||
_, err := s.db.Exec(`
|
// For PostgreSQL with existing table, skip AutoMigrate to avoid index conflicts
|
||||||
CREATE TABLE IF NOT EXISTS users (
|
if s.db.Dialector.Name() == "postgres" {
|
||||||
id TEXT PRIMARY KEY,
|
var tableExists int64
|
||||||
email TEXT UNIQUE NOT NULL,
|
s.db.Raw(`SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'users'`).Scan(&tableExists)
|
||||||
password_hash TEXT NOT NULL,
|
|
||||||
otp_secret TEXT,
|
|
||||||
otp_verified BOOLEAN DEFAULT 0,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trigger
|
if tableExists > 0 {
|
||||||
_, err = s.db.Exec(`
|
// Table exists - manually ensure all columns exist
|
||||||
CREATE TRIGGER IF NOT EXISTS update_users_updated_at
|
// Core columns (should already exist)
|
||||||
AFTER UPDATE ON users
|
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS email TEXT NOT NULL DEFAULT ''`)
|
||||||
BEGIN
|
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS password_hash TEXT NOT NULL DEFAULT ''`)
|
||||||
UPDATE users SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP`)
|
||||||
END
|
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP`)
|
||||||
`)
|
// OTP columns (added later)
|
||||||
if err != nil {
|
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS otp_secret TEXT DEFAULT ''`)
|
||||||
return err
|
s.db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS otp_verified BOOLEAN DEFAULT FALSE`)
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
// Ensure unique index exists on email (don't care about the name)
|
||||||
|
var indexExists int64
|
||||||
|
s.db.Raw(`
|
||||||
|
SELECT COUNT(*) FROM pg_indexes
|
||||||
|
WHERE tablename = 'users' AND indexdef LIKE '%email%' AND indexdef LIKE '%UNIQUE%'
|
||||||
|
`).Scan(&indexExists)
|
||||||
|
|
||||||
|
if indexExists == 0 {
|
||||||
|
s.db.Exec("CREATE UNIQUE INDEX idx_users_email ON users(email)")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.db.AutoMigrate(&User{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create creates user
|
// Create creates user
|
||||||
func (s *UserStore) Create(user *User) error {
|
func (s *UserStore) Create(user *User) error {
|
||||||
_, err := s.db.Exec(`
|
return s.db.Create(user).Error
|
||||||
INSERT INTO users (id, email, password_hash, otp_secret, otp_verified)
|
|
||||||
VALUES (?, ?, ?, ?, ?)
|
|
||||||
`, user.ID, user.Email, user.PasswordHash, user.OTPSecret, user.OTPVerified)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByEmail gets user by email
|
// GetByEmail gets user by email
|
||||||
func (s *UserStore) GetByEmail(email string) (*User, error) {
|
func (s *UserStore) GetByEmail(email string) (*User, error) {
|
||||||
var user User
|
var user User
|
||||||
var createdAt, updatedAt string
|
err := s.db.Where("email = ?", email).First(&user).Error
|
||||||
err := s.db.QueryRow(`
|
|
||||||
SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at
|
|
||||||
FROM users WHERE email = ?
|
|
||||||
`, email).Scan(
|
|
||||||
&user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret,
|
|
||||||
&user.OTPVerified, &createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByID gets user by ID
|
// GetByID gets user by ID
|
||||||
func (s *UserStore) GetByID(userID string) (*User, error) {
|
func (s *UserStore) GetByID(userID string) (*User, error) {
|
||||||
var user User
|
var user User
|
||||||
var createdAt, updatedAt string
|
err := s.db.Where("id = ?", userID).First(&user).Error
|
||||||
err := s.db.QueryRow(`
|
|
||||||
SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at
|
|
||||||
FROM users WHERE id = ?
|
|
||||||
`, userID).Scan(
|
|
||||||
&user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret,
|
|
||||||
&user.OTPVerified, &createdAt, &updatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
|
||||||
user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count returns the total number of users
|
// Count returns the total number of users
|
||||||
func (s *UserStore) Count() (int, error) {
|
func (s *UserStore) Count() (int, error) {
|
||||||
var count int
|
var count int64
|
||||||
err := s.db.QueryRow(`SELECT COUNT(*) FROM users`).Scan(&count)
|
err := s.db.Model(&User{}).Count(&count).Error
|
||||||
return count, err
|
return int(count), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllIDs gets all user IDs
|
// GetAllIDs gets all user IDs
|
||||||
func (s *UserStore) GetAllIDs() ([]string, error) {
|
func (s *UserStore) GetAllIDs() ([]string, error) {
|
||||||
rows, err := s.db.Query(`SELECT id FROM users ORDER BY id`)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var userIDs []string
|
var userIDs []string
|
||||||
for rows.Next() {
|
err := s.db.Model(&User{}).Order("id").Pluck("id", &userIDs).Error
|
||||||
var userID string
|
return userIDs, err
|
||||||
if err := rows.Scan(&userID); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
userIDs = append(userIDs, userID)
|
|
||||||
}
|
|
||||||
return userIDs, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateOTPVerified updates OTP verification status
|
// UpdateOTPVerified updates OTP verification status
|
||||||
func (s *UserStore) UpdateOTPVerified(userID string, verified bool) error {
|
func (s *UserStore) UpdateOTPVerified(userID string, verified bool) error {
|
||||||
_, err := s.db.Exec(`UPDATE users SET otp_verified = ? WHERE id = ?`, verified, userID)
|
return s.db.Model(&User{}).Where("id = ?", userID).Update("otp_verified", verified).Error
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePassword updates password
|
// UpdatePassword updates password
|
||||||
func (s *UserStore) UpdatePassword(userID, passwordHash string) error {
|
func (s *UserStore) UpdatePassword(userID, passwordHash string) error {
|
||||||
_, err := s.db.Exec(`
|
return s.db.Model(&User{}).Where("id = ?", userID).Updates(map[string]interface{}{
|
||||||
UPDATE users SET password_hash = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?
|
"password_hash": passwordHash,
|
||||||
`, passwordHash, userID)
|
"updated_at": time.Now(),
|
||||||
return err
|
}).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnsureAdmin ensures admin user exists
|
// EnsureAdmin ensures admin user exists
|
||||||
func (s *UserStore) EnsureAdmin() error {
|
func (s *UserStore) EnsureAdmin() error {
|
||||||
var count int
|
var count int64
|
||||||
err := s.db.QueryRow(`SELECT COUNT(*) FROM users WHERE id = 'admin'`).Scan(&count)
|
s.db.Model(&User{}).Where("id = ?", "admin").Count(&count)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if count > 0 {
|
if count > 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
package trader
|
package trader
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"nofx/store"
|
"nofx/store"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestScenario represents a trading scenario to test
|
// TestScenario represents a trading scenario to test
|
||||||
@@ -116,11 +117,12 @@ func runStandardTests(t *testing.T, exchangeName string) {
|
|||||||
for _, scenario := range scenarios {
|
for _, scenario := range scenarios {
|
||||||
t.Run(scenario.Name, func(t *testing.T) {
|
t.Run(scenario.Name, func(t *testing.T) {
|
||||||
// Setup database
|
// Setup database
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create test database: %v", err)
|
t.Fatalf("Failed to create test database: %v", err)
|
||||||
}
|
}
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
positionStore := store.NewPositionStore(db)
|
positionStore := store.NewPositionStore(db)
|
||||||
if err := positionStore.InitTables(); err != nil {
|
if err := positionStore.InitTables(); err != nil {
|
||||||
@@ -199,11 +201,12 @@ func TestAllExchangesStandardScenarios(t *testing.T) {
|
|||||||
|
|
||||||
// TestPositionAccumulationBug tests that positions don't accumulate incorrectly
|
// TestPositionAccumulationBug tests that positions don't accumulate incorrectly
|
||||||
func TestPositionAccumulationBug(t *testing.T) {
|
func TestPositionAccumulationBug(t *testing.T) {
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create test database: %v", err)
|
t.Fatalf("Failed to create test database: %v", err)
|
||||||
}
|
}
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
positionStore := store.NewPositionStore(db)
|
positionStore := store.NewPositionStore(db)
|
||||||
if err := positionStore.InitTables(); err != nil {
|
if err := positionStore.InitTables(); err != nil {
|
||||||
@@ -283,11 +286,12 @@ func TestPositionAccumulationBug(t *testing.T) {
|
|||||||
|
|
||||||
// TestQuantityPrecision tests handling of quantity precision issues
|
// TestQuantityPrecision tests handling of quantity precision issues
|
||||||
func TestQuantityPrecision(t *testing.T) {
|
func TestQuantityPrecision(t *testing.T) {
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create test database: %v", err)
|
t.Fatalf("Failed to create test database: %v", err)
|
||||||
}
|
}
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
positionStore := store.NewPositionStore(db)
|
positionStore := store.NewPositionStore(db)
|
||||||
if err := positionStore.InitTables(); err != nil {
|
if err := positionStore.InitTables(); err != nil {
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package trader
|
package trader
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"math"
|
"math"
|
||||||
"nofx/store"
|
"nofx/store"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestHyperliquidOrderDirectionParsing tests Dir field parsing
|
// TestHyperliquidOrderDirectionParsing tests Dir field parsing
|
||||||
@@ -75,11 +76,12 @@ func TestHyperliquidOrderDirectionParsing(t *testing.T) {
|
|||||||
// TestHyperliquidPositionBuilding tests the complete flow of position building
|
// TestHyperliquidPositionBuilding tests the complete flow of position building
|
||||||
func TestHyperliquidPositionBuilding(t *testing.T) {
|
func TestHyperliquidPositionBuilding(t *testing.T) {
|
||||||
// Setup in-memory database
|
// Setup in-memory database
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create test database: %v", err)
|
t.Fatalf("Failed to create test database: %v", err)
|
||||||
}
|
}
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
// Initialize stores
|
// Initialize stores
|
||||||
positionStore := store.NewPositionStore(db)
|
positionStore := store.NewPositionStore(db)
|
||||||
@@ -304,11 +306,12 @@ func TestHyperliquidPositionBuilding(t *testing.T) {
|
|||||||
// TestHyperliquidBugScenario tests the exact bug we fixed
|
// TestHyperliquidBugScenario tests the exact bug we fixed
|
||||||
func TestHyperliquidBugScenario(t *testing.T) {
|
func TestHyperliquidBugScenario(t *testing.T) {
|
||||||
// Setup database
|
// Setup database
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create test database: %v", err)
|
t.Fatalf("Failed to create test database: %v", err)
|
||||||
}
|
}
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
positionStore := store.NewPositionStore(db)
|
positionStore := store.NewPositionStore(db)
|
||||||
if err := positionStore.InitTables(); err != nil {
|
if err := positionStore.InitTables(); err != nil {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { motion } from 'framer-motion'
|
import { motion } from 'framer-motion'
|
||||||
import { Bot, TrendingUp, Layers, Zap, Hexagon, Crosshair } from 'lucide-react'
|
import { TrendingUp, Layers, Zap, Hexagon, Crosshair } from 'lucide-react'
|
||||||
|
|
||||||
const agents = [
|
const agents = [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1,8 +1,15 @@
|
|||||||
import { motion } from 'framer-motion'
|
import { motion } from 'framer-motion'
|
||||||
import { Activity, BarChart3, Globe, Wifi, Server, Database, Lock } from 'lucide-react'
|
|
||||||
import { useState, useEffect } from 'react'
|
import { useState, useEffect } from 'react'
|
||||||
|
|
||||||
const generateLog = (id) => {
|
interface LogEntry {
|
||||||
|
id: number
|
||||||
|
time: string
|
||||||
|
type: string
|
||||||
|
msg: string
|
||||||
|
color: string
|
||||||
|
}
|
||||||
|
|
||||||
|
const generateLog = (id: number): LogEntry => {
|
||||||
const types = ['EXE', 'ARB', 'LIQ', 'NET', 'SYS']
|
const types = ['EXE', 'ARB', 'LIQ', 'NET', 'SYS']
|
||||||
const pairs = ['BTC-USDT', 'ETH-PERP', 'SOL-USDT', 'BNB-BUSD']
|
const pairs = ['BTC-USDT', 'ETH-PERP', 'SOL-USDT', 'BNB-BUSD']
|
||||||
const actions = ['BUY', 'SELL', 'SHORT', 'LONG']
|
const actions = ['BUY', 'SELL', 'SHORT', 'LONG']
|
||||||
@@ -37,7 +44,7 @@ const generateLog = (id) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function LiveFeed() {
|
export default function LiveFeed() {
|
||||||
const [logs, setLogs] = useState([])
|
const [logs, setLogs] = useState<LogEntry[]>([])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// Initial population
|
// Initial population
|
||||||
|
|||||||
@@ -159,11 +159,24 @@ export default function TerminalHero() {
|
|||||||
<span className="text-stroke-1 text-transparent bg-clip-text bg-gradient-to-r from-nofx-gold via-white to-nofx-gold animate-shimmer bg-[length:200%_auto]">TRADING</span>
|
<span className="text-stroke-1 text-transparent bg-clip-text bg-gradient-to-r from-nofx-gold via-white to-nofx-gold animate-shimmer bg-[length:200%_auto]">TRADING</span>
|
||||||
</h1>
|
</h1>
|
||||||
|
|
||||||
<p className="max-w-xl text-zinc-400 text-lg mb-10 font-light leading-relaxed">
|
<p className="max-w-xl text-zinc-400 text-lg mb-6 font-light leading-relaxed">
|
||||||
The World's First Open-Source Agentic Trading OS.
|
The World's First Open-Source Agentic Trading OS.
|
||||||
Deploy autonomous high-frequency trading agents powered by advanced LLMs.
|
Deploy autonomous high-frequency trading agents powered by advanced LLMs.
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
{/* Market Access Strip - Prominent Display */}
|
||||||
|
<div className="flex flex-wrap gap-4 mb-12 font-mono">
|
||||||
|
{['CRYPTO', 'US STOCKS', 'FOREX', 'METALS'].map((market) => (
|
||||||
|
<div key={market} className="flex items-center gap-3 px-4 py-2 rounded bg-zinc-900 border border-zinc-700 text-white font-bold tracking-wider hover:border-nofx-gold hover:shadow-[0_0_15px_rgba(255,215,0,0.3)] transition-all duration-300">
|
||||||
|
<span className="relative flex h-2 w-2">
|
||||||
|
<span className="animate-ping absolute inline-flex h-full w-full rounded-full bg-nofx-success opacity-75"></span>
|
||||||
|
<span className="relative inline-flex rounded-full h-2 w-2 bg-nofx-success"></span>
|
||||||
|
</span>
|
||||||
|
{market}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
|
||||||
{/* Command Line Input Simulation */}
|
{/* Command Line Input Simulation */}
|
||||||
<div className="w-full max-w-lg h-12 bg-black/50 border border-zinc-800 rounded flex items-center px-4 mb-10 font-mono text-sm shadow-2xl backdrop-blur-sm group hover:border-nofx-gold/50 transition-colors cursor-text" onClick={() => document.getElementById('market-scanner')?.scrollIntoView({ behavior: 'smooth' })}>
|
<div className="w-full max-w-lg h-12 bg-black/50 border border-zinc-800 rounded flex items-center px-4 mb-10 font-mono text-sm shadow-2xl backdrop-blur-sm group hover:border-nofx-gold/50 transition-colors cursor-text" onClick={() => document.getElementById('market-scanner')?.scrollIntoView({ behavior: 'smooth' })}>
|
||||||
<span className="text-nofx-success mr-2">➜</span>
|
<span className="text-nofx-success mr-2">➜</span>
|
||||||
@@ -269,7 +282,7 @@ export default function TerminalHero() {
|
|||||||
import { OFFICIAL_LINKS } from '../../../constants/branding'
|
import { OFFICIAL_LINKS } from '../../../constants/branding'
|
||||||
|
|
||||||
function CommunityStats() {
|
function CommunityStats() {
|
||||||
const { stars, forks, contributors, isLoading, error } = useGitHubStats('tinkle-community', 'nofx')
|
const { stars, forks, contributors, isLoading, error } = useGitHubStats('NoFxAiOS', 'nofx')
|
||||||
|
|
||||||
const stats = [
|
const stats = [
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user